use std::path::Path; use anyhow::{anyhow, Context, Result}; use whisper_rs::{ get_lang_str, install_logging_hooks, FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters, }; use crate::models::{SubtitleSegment, TargetLanguage}; pub struct WhisperEngine { model_path: Option, } impl WhisperEngine { pub fn new(model_path: Option) -> Self { install_logging_hooks(); Self { model_path } } pub fn infer_segments( &self, wav_path: &Path, task_id: &str, source_lang: Option<&str>, target_lang: &TargetLanguage, should_translate: bool, speech_ranges: &[(f32, f32)], mut on_progress: F, mut on_reset_segments: impl FnMut() -> Result<()>, mut on_segment: impl FnMut(SubtitleSegment) -> Result<()>, mut on_log: impl FnMut(String) -> Result<()>, ) -> Result> where F: FnMut(f32) -> Result<()>, { let Some(model_path) = &self.model_path else { return Err(anyhow!( "whisper model path is missing. Please provide a local ggml model path for task {}", task_id )); }; if !Path::new(model_path).exists() { return Err(anyhow!("whisper model not found: {model_path}")); } let audio = load_audio_f32(wav_path)?; let total_seconds = audio.len() as f32 / 16_000.0; let normalized_ranges = normalize_speech_ranges(speech_ranges, audio.len()); let context = WhisperContext::new_with_params( model_path, WhisperContextParameters::default(), ) .with_context(|| format!("failed to load whisper model: {model_path}"))?; let mut state = context.create_state().context("failed to create whisper state")?; let detected_language = resolve_source_language(&mut state, &audio, source_lang) .context("failed to resolve source language")?; if let Some(lang) = detected_language { on_log(format!("whisper: source language={lang}"))?; } else { on_log("whisper: source language unresolved, fallback to auto decode".to_string())?; } let mut segments = Vec::new(); on_log(format!( "whisper: processing {} speech ranges (normalized from {}), coverage={:.1}%", normalized_ranges.len(), speech_ranges.len(), speech_coverage_ratio(&normalized_ranges, total_seconds) * 100.0 ))?; for (range_index, (start, end)) in normalized_ranges.iter().enumerate() { let clip = slice_audio(&audio, *start, *end); if clip.is_empty() { continue; } let progress_base = range_index as f32 / normalized_ranges.len().max(1) as f32; let progress_span = 1.0 / normalized_ranges.len().max(1) as f32; on_progress(progress_base)?; let clip_segments = transcribe_clip( &mut state, &clip, range_index, *start, *end, task_id, detected_language, target_lang, should_translate, segments.len(), &mut on_segment, &mut on_log, )?; segments.extend(clip_segments); on_progress((progress_base + progress_span).min(1.0))?; } let vad_text_len = text_len(&segments); let vad_end = last_end(&segments); let vad_coverage = speech_coverage_ratio(&normalized_ranges, total_seconds); let should_retry_full_audio = !audio.is_empty() && (segments.is_empty() || vad_coverage < 0.60 || vad_end + 5.0 < total_seconds || (total_seconds > 60.0 && vad_text_len < (total_seconds / 3.0) as usize)); if should_retry_full_audio { on_log(format!( "whisper: VAD result looks incomplete, retrying full audio (segments={}, chars={}, end={:.2}s/{:.2}s, coverage={:.1}%)", segments.len(), vad_text_len, vad_end, total_seconds, vad_coverage * 100.0 ))?; on_reset_segments()?; let full_audio_segments = transcribe_clip( &mut state, &audio, 0, 0.0, total_seconds, task_id, detected_language, target_lang, should_translate, 0, &mut on_segment, &mut on_log, )?; if should_prefer_full_audio(&segments, &full_audio_segments, total_seconds) { on_log(format!( "whisper: using full-audio transcript (vad_segments={}, full_segments={})", segments.len(), full_audio_segments.len() ))?; segments = full_audio_segments; } else { on_log(format!( "whisper: keeping VAD-based transcript (vad_segments={}, full_segments={})", segments.len(), full_audio_segments.len() ))?; on_reset_segments()?; segments.iter().cloned().try_for_each(&mut on_segment)?; } } on_log(format!("whisper: total emitted segments={}", segments.len()))?; Ok(segments) } } #[allow(clippy::too_many_arguments)] fn transcribe_clip( state: &mut whisper_rs::WhisperState, clip: &[f32], range_index: usize, start: f32, end: f32, task_id: &str, source_lang: Option<&str>, _target_lang: &TargetLanguage, _should_translate: bool, segment_offset: usize, on_segment: &mut impl FnMut(SubtitleSegment) -> Result<()>, on_log: &mut impl FnMut(String) -> Result<()>, ) -> Result> { let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }); params.set_n_threads(4); params.set_print_special(false); params.set_print_progress(false); params.set_print_realtime(false); params.set_print_timestamps(false); params.set_token_timestamps(false); params.set_translate(false); match source_lang { Some(lang) => { params.set_detect_language(false); params.set_language(Some(lang)); } None => { params.set_detect_language(true); params.set_language(None); } } state.full(params, clip).context("whisper inference failed")?; let num_segments = state.full_n_segments(); on_log(format!( "whisper: range #{}, {:.2}-{:.2}s, samples={}, segments={}", range_index + 1, start, end, clip.len(), num_segments ))?; let mut results = Vec::new(); for offset in 0..num_segments { let segment = state .get_segment(offset) .ok_or_else(|| anyhow!("failed to access whisper segment {offset}"))?; let text = segment .to_str_lossy() .context("failed to get whisper segment text")? .trim() .to_string(); if text.is_empty() { continue; } on_log(format!("whisper text: {}", text))?; let local_start = segment.start_timestamp() as f32 / 100.0; let local_end = segment.end_timestamp() as f32 / 100.0; let emitted = SubtitleSegment { id: format!("seg-{:04}", segment_offset + results.len() + 1), task_id: task_id.to_string(), start: start + local_start, end: start + local_end, source_text: text.clone(), translated_text: None, }; on_segment(emitted.clone())?; results.push(emitted); } Ok(results) } fn normalize_speech_ranges(ranges: &[(f32, f32)], total_samples: usize) -> Vec<(f32, f32)> { if ranges.is_empty() { return vec![(0.0, total_samples as f32 / 16_000.0)]; } let total_seconds = total_samples as f32 / 16_000.0; let mut merged = Vec::new(); let mut current = ranges[0]; for &(start, end) in &ranges[1..] { let current_duration = current.1 - current.0; let gap = start - current.1; if gap <= 1.2 || current_duration < 8.0 { current.1 = end; } else { merged.push(current); current = (start, end); } } merged.push(current); merged .into_iter() .map(|(start, end)| ((start - 0.35).max(0.0), (end + 0.35).min(total_seconds))) .collect() } fn load_audio_f32(path: &Path) -> Result> { let reader = hound::WavReader::open(path) .with_context(|| format!("failed to open wav file: {}", path.display()))?; let spec = reader.spec(); if spec.sample_rate != 16_000 { return Err(anyhow!("whisper expects 16k audio, got {}", spec.sample_rate)); } if spec.channels != 1 { return Err(anyhow!("whisper expects mono audio, got {}", spec.channels)); } let samples = reader .into_samples::() .map(|sample| sample.map(|value| value as f32 / i16::MAX as f32).map_err(anyhow::Error::from)) .collect::>>()?; Ok(samples) } fn slice_audio(audio: &[f32], start: f32, end: f32) -> Vec { let begin = (start * 16_000.0).floor() as usize; let finish = (end * 16_000.0).ceil() as usize; audio .get(begin.min(audio.len())..finish.min(audio.len())) .unwrap_or(&[]) .to_vec() } fn speech_coverage_ratio(ranges: &[(f32, f32)], total_seconds: f32) -> f32 { if total_seconds <= 0.0 { return 0.0; } let covered = ranges .iter() .map(|(start, end)| (end - start).max(0.0)) .sum::(); (covered / total_seconds).clamp(0.0, 1.0) } fn text_len(segments: &[SubtitleSegment]) -> usize { segments .iter() .map(|segment| segment.source_text.chars().count()) .sum() } fn last_end(segments: &[SubtitleSegment]) -> f32 { segments .iter() .map(|segment| segment.end) .fold(0.0_f32, f32::max) } fn should_prefer_full_audio( vad_segments: &[SubtitleSegment], full_audio_segments: &[SubtitleSegment], total_seconds: f32, ) -> bool { if full_audio_segments.is_empty() { return vad_segments.is_empty(); } if vad_segments.is_empty() { return true; } let vad_text_len = text_len(vad_segments); let full_text_len = text_len(full_audio_segments); let vad_end = last_end(vad_segments); let full_end = last_end(full_audio_segments); full_text_len > vad_text_len + vad_text_len * 3 / 5 || full_audio_segments.len() > vad_segments.len() + 5 || full_end > vad_end + 5.0 || (total_seconds > 60.0 && full_end + 1.5 >= total_seconds && vad_end + 5.0 < total_seconds) } fn resolve_source_language<'a>( state: &mut whisper_rs::WhisperState, audio: &[f32], source_lang: Option<&'a str>, ) -> Result> { match source_lang.map(str::trim).filter(|lang| !lang.is_empty()) { Some("auto") | None => { let detect_samples = audio.len().min(16_000 * 30); let sample = &audio[..detect_samples]; state .pcm_to_mel(sample, 4) .context("failed to build mel spectrogram for language detection")?; let (lang_id, probabilities) = state .lang_detect(0, 4) .context("whisper language detection failed")?; let lang = get_lang_str(lang_id) .ok_or_else(|| anyhow!("unknown whisper language id: {lang_id}"))?; let probability = probabilities .get(lang_id as usize) .copied() .unwrap_or_default(); if probability < 0.35 { Ok(None) } else { Ok(Some(lang)) } } Some(lang) => Ok(Some(lang)), } }