# Origin: git patch 276d49ac53beed80d94b5180d5c54b90b9337b49 by lelo
# (Sun, 23 Mar 2025 21:12:33 +0100), "[PATCH] add single file transcription",
# which creates transcribe_single_file.py.
"""Transcribe a single audio file with Whisper and write a Markdown transcript.

Usage:
    python transcribe_single_file.py <audio_file>

The transcript is written to a ``Transkription`` folder next to the audio
file, one line per Whisper segment, each prefixed with its start timestamp.
"""

import json
import os
import re
import sys

import whisper

# Whisper model to load. Alternative kept from the original: "large-v3".
model_name = "medium"

# Words whose presence in the LAST transcript line causes that line to be
# dropped (see remove_lines_with_words).
_BANNED_WORDS = ("copyright", "ard", "zdf", "wdr")

# Whole-word, case-insensitive matcher. A plain substring test would also hit
# ordinary words such as "Standard" or "Gerhard" and delete legitimate text.
_BANNED_WORD_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(word) for word in _BANNED_WORDS) + r")\b",
    re.IGNORECASE,
)


def format_timestamp(seconds):
    """Format seconds as ``MM:SS``, or as ``HH:MM:SS`` from one hour upwards.

    Fractions of a second are truncated, not rounded.
    """
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    if hours == 0:
        return f"{minutes:02}:{secs:02}"
    return f"{hours:02}:{minutes:02}:{secs:02}"


def format_status_path(path):
    """Return only the immediate parent folder and the filename of *path*.

    If *path* has no parent folder component, just the filename is returned.
    """
    filename = os.path.basename(path)
    parent = os.path.basename(os.path.dirname(path))
    if parent:
        return os.path.join(parent, filename)
    return filename


def remove_lines_with_words(transcript):
    """Drop the last line of *transcript* if it contains a banned word.

    Banned words are matched case-insensitively as whole words. Only the last
    line is inspected. Trailing whitespace is stripped from the result; an
    empty or whitespace-only transcript is returned unchanged.
    """
    lines = transcript.rstrip().splitlines()
    if not lines:
        return transcript

    if _BANNED_WORD_RE.search(lines[-1]):
        lines.pop()

    return "\n".join(lines)


def apply_error_correction(text, correction_file='error_correction.json'):
    """Replace known mis-transcriptions in *text* using a JSON lookup table.

    Args:
        text: The text to correct.
        correction_file: Path to a JSON object mapping a wrong spelling to
            its replacement. A relative path is resolved against the current
            working directory.

    Returns:
        *text* with every whole-word occurrence of a key replaced by its
        value. Matching is case-sensitive.

    Raises:
        OSError: If *correction_file* cannot be opened.
        json.JSONDecodeError: If *correction_file* is not valid JSON.
    """
    with open(correction_file, 'r', encoding='utf-8') as file:
        correction_dict = json.load(file)

    # Nothing to replace; also avoids building a degenerate empty pattern.
    if not correction_dict:
        return text

    # Longest keys first: regex alternation takes the first alternative that
    # matches, so a short key must not shadow a longer key starting with it.
    keys = sorted(correction_dict, key=len, reverse=True)
    pattern = r'\b(' + '|'.join(re.escape(key) for key in keys) + r')\b'

    def replacement_func(match):
        key = match.group(0)
        return correction_dict.get(key, key)

    return re.sub(pattern, replacement_func, text)


def write_markdown(file_path, result, postfix=None):
    """Write a Whisper transcription result as a Markdown file.

    The file is created as ``Transkription/<name>[_<postfix>].md`` inside the
    directory that contains *file_path*; the folder is created if missing.

    Args:
        file_path: Path of the transcribed audio file.
        result: Mapping with a ``"segments"`` list whose items provide
            ``"start"`` (seconds) and ``"text"``.
        postfix: Optional suffix appended to the output file's base name.
    """
    # Resolve to an absolute path so that a bare filename still yields a real
    # folder name for the heading instead of an empty string.
    file_dir = os.path.dirname(os.path.abspath(file_path))
    txt_folder = os.path.join(file_dir, "Transkription")
    os.makedirs(txt_folder, exist_ok=True)

    base_name = os.path.splitext(os.path.basename(file_path))[0]
    if postfix is not None:
        base_name = f"{base_name}_{postfix}"
    output_md = os.path.join(txt_folder, base_name + ".md")

    # Markdown header: containing folder, file name, horizontal rule.
    folder_name = os.path.basename(file_dir)
    md_lines = [
        f"### {folder_name}",
        f"#### {os.path.basename(file_path)}",
        "---",
        "",
    ]

    previous_text = ""
    for segment in result["segments"]:
        start = format_timestamp(segment["start"])
        text = segment["text"].strip()
        # Suppress a segment that merely repeats the previous one.
        if previous_text != text:
            md_lines.append(f"`{start}` {text}")
        previous_text = text

    transcript_md = "\n".join(md_lines)
    transcript_md = apply_error_correction(transcript_md)
    transcript_md = remove_lines_with_words(transcript_md)

    with open(output_md, "w", encoding="utf-8") as f:
        f.write(transcript_md)

    print("... done !")


def transcribe_file(model, audio_input, language):
    """Transcribe *audio_input* with *model* and return Whisper's result.

    A German initial prompt describing a church-service recording and listing
    frequently mis-transcribed terms is passed along with *language*.
    """
    initial_prompt = (
        "Dieses Audio ist eine Aufnahme eines christlichen Gottesdienstes, "
        "das biblische Zitate, religiöse Begriffe und typische Gottesdienst-Phrasen enthält. "
        "Achte darauf auf folgende Begriffe, die häufig falsch transkribiert wurden, korrekt wiederzugeben: "
        "Stiftshütte, Bundeslade, Heiligtum, Offenbarung, Evangelium, Buße, Golgatha, "
        "Apostelgeschichte, Auferstehung, Wiedergeburt. "
        "Das Wort 'Bethaus' wird häufig als synonym für 'Gebetshaus' verwendet. "
        "Das Wort 'Abendmahl' ist wichtig und sollte zuverlässig erkannt werden. "
        "Ebenso müssen biblische Namen und Persönlichkeiten exakt transkribiert werden. "
        "Zahlenangaben, beispielsweise Psalmnummern oder Bibelverse, sollen numerisch dargestellt werden."
    )
    result = model.transcribe(audio_input, initial_prompt=initial_prompt, language=language)
    return result


def detect_language(model, audio):
    """Detect the spoken language of *audio* and return its language code.

    Progress is printed to stdout without a trailing newline so the caller
    can continue the same status line.
    """
    print(" Language detected: ", end='', flush=True)
    # Bring the audio to the fixed length Whisper works on before computing
    # the spectrogram (see the openai-whisper README).
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
    _, probs = model.detect_language(mel)
    # probs maps language code -> probability; pick the most likely one.
    lang_code = max(probs, key=probs.get)
    print(f"{lang_code}. ", end='', flush=True)
    return lang_code


def process_file(file_path, model, audio_input, language=None, postfix=None):
    """Transcribe one audio file and write the Markdown transcript.

    Args:
        file_path: Path of the audio file (used for status output and to
            derive the output location).
        model: Loaded Whisper model.
        audio_input: Audio data as accepted by ``model.transcribe``.
        language: Language code; detected from the audio when ``None``.
        postfix: Optional suffix for the output file's base name.
    """
    if language is None:
        language = detect_language(model, audio_input)

    print(f"Transcribing {format_status_path(file_path)}, lang={language} ", end='', flush=True)
    result = transcribe_file(model, audio_input, language)
    write_markdown(file_path, result, postfix)


def main():
    """Command-line entry point: transcribe the file named in ``sys.argv[1]``."""
    if len(sys.argv) != 2:
        print("Usage: python transcribe_single_file.py <audio_file>")
        sys.exit(1)

    file_name_path = sys.argv[1]

    print("Loading Whisper model...")
    model = whisper.load_model(model_name, device="cuda")
    audio = whisper.load_audio(file_name_path)
    process_file(file_name_path, model, audio, "de")


if __name__ == "__main__":
    main()