diff --git a/transcribe_all.py b/transcribe_all.py
index de02fa6..e8df9af 100755
--- a/transcribe_all.py
+++ b/transcribe_all.py
@@ -1,6 +1,7 @@
 import os
 import sys
 import time
+import torch
 import whisper
 import concurrent.futures
 import json
@@ -18,7 +19,6 @@ with open("transcription_config.yml", "r", encoding="utf-8") as file:
     settings = yaml.safe_load(file)
     folder_list = settings.get("folder_list")
     model_name = settings.get("model_name")
-    device = settings.get("device")
 
 def load_audio_librosa(path: str, sr: int = 16_000) -> np.ndarray:
     audio, orig_sr = librosa.load(path, sr=sr)  # load + resample to 16 kHz
@@ -222,7 +222,10 @@ def process_folder(root_folder):
     else:
         print(f"Checked {checked_files} files. Start to transcribe {len(valid_files)} files.")
 
-    print("Loading Whisper model...")
+    # Choose "cuda" if available, otherwise "cpu"
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Loading Whisper model on {device}...")
+    model = whisper.load_model(model_name, device=device)
 
     # Use a thread pool to pre-load files concurrently.
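
For context, a minimal standalone sketch of the device-selection pattern this patch applies, using only the stock torch and openai-whisper APIs; the model name "base" and the input file "sample.wav" are stand-ins for the model_name read from transcription_config.yml and the files found under folder_list:

import torch
import whisper

# torch.cuda.is_available() reports whether a usable CUDA GPU is present;
# whisper.load_model accepts a device argument and places the model there.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model("base", device=device)  # "base" stands in for model_name

result = model.transcribe("sample.wav")  # hypothetical input file
print(result["text"])

Detecting the device at load time, rather than reading it from the YAML config, means the same config file works on both GPU and CPU machines and removes a setting that could silently disagree with the hardware.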