// crosssubtitle-ai/src-tauri/src/whisper.rs

use std::path::Path;
use anyhow::{anyhow, Context, Result};
use whisper_rs::{
    get_lang_str, install_logging_hooks, FullParams, SamplingStrategy, WhisperContext,
    WhisperContextParameters,
};
use crate::models::{SubtitleSegment, TargetLanguage};

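/// Offline transcription engine backed by whisper.cpp (ggml) through `whisper-rs`.
/// Only the model path is stored; the heavyweight `WhisperContext` is created
/// fresh for every call to [`WhisperEngine::infer_segments`].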
pub struct WhisperEngine {
    model_path: Option<String>,
}

impl WhisperEngine {
    pub fn new(model_path: Option<String>) -> Self {
        install_logging_hooks();
        Self { model_path }
    }

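    /// Transcribes `wav_path` (16 kHz mono PCM) and streams results through the
    /// callbacks: `on_progress` with a 0.0..=1.0 ratio, `on_segment` for every
    /// emitted subtitle, `on_reset_segments` when a retry discards earlier
    /// output, and `on_log` for diagnostics. VAD `speech_ranges` are transcribed
    /// clip by clip first; if that transcript looks incomplete, the full audio
    /// is decoded once more and the better result is kept.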
    pub fn infer_segments<F>(
        &self,
        wav_path: &Path,
        task_id: &str,
        source_lang: Option<&str>,
        target_lang: &TargetLanguage,
        should_translate: bool,
        speech_ranges: &[(f32, f32)],
        mut on_progress: F,
        mut on_reset_segments: impl FnMut() -> Result<()>,
        mut on_segment: impl FnMut(SubtitleSegment) -> Result<()>,
        mut on_log: impl FnMut(String) -> Result<()>,
    ) -> Result<Vec<SubtitleSegment>>
    where
        F: FnMut(f32) -> Result<()>,
    {
        let Some(model_path) = &self.model_path else {
            return Err(anyhow!(
                "whisper model path is missing. Please provide a local ggml model path for task {}",
                task_id
            ));
        };
        if !Path::new(model_path).exists() {
            return Err(anyhow!("whisper model not found: {model_path}"));
        }
        let audio = load_audio_f32(wav_path)?;
        let total_seconds = audio.len() as f32 / 16_000.0;
        let normalized_ranges = normalize_speech_ranges(speech_ranges, audio.len());
        let context =
            WhisperContext::new_with_params(model_path, WhisperContextParameters::default())
                .with_context(|| format!("failed to load whisper model: {model_path}"))?;
        let mut state = context
            .create_state()
            .context("failed to create whisper state")?;
        let detected_language = resolve_source_language(&mut state, &audio, source_lang)
            .context("failed to resolve source language")?;
        if let Some(lang) = detected_language {
            on_log(format!("whisper: source language={lang}"))?;
        } else {
            on_log("whisper: source language unresolved, falling back to auto decode".to_string())?;
        }
        let mut segments = Vec::new();
        on_log(format!(
            "whisper: processing {} speech ranges (normalized from {}), coverage={:.1}%",
            normalized_ranges.len(),
            speech_ranges.len(),
            speech_coverage_ratio(&normalized_ranges, total_seconds) * 100.0
        ))?;
        for (range_index, (start, end)) in normalized_ranges.iter().enumerate() {
            let clip = slice_audio(&audio, *start, *end);
            if clip.is_empty() {
                continue;
            }
            let progress_base = range_index as f32 / normalized_ranges.len().max(1) as f32;
            let progress_span = 1.0 / normalized_ranges.len().max(1) as f32;
            on_progress(progress_base)?;
            let clip_segments = transcribe_clip(
                &mut state,
                &clip,
                range_index,
                *start,
                *end,
                task_id,
                detected_language,
                target_lang,
                should_translate,
                segments.len(),
                &mut on_segment,
                &mut on_log,
            )?;
            segments.extend(clip_segments);
            on_progress((progress_base + progress_span).min(1.0))?;
        }
        let vad_text_len = text_len(&segments);
        let vad_end = last_end(&segments);
        let vad_coverage = speech_coverage_ratio(&normalized_ranges, total_seconds);
        let should_retry_full_audio = !audio.is_empty()
            && (segments.is_empty()
                || vad_coverage < 0.60
                || vad_end + 5.0 < total_seconds
                || (total_seconds > 60.0 && vad_text_len < (total_seconds / 3.0) as usize));
        if should_retry_full_audio {
            on_log(format!(
                "whisper: VAD result looks incomplete, retrying full audio (segments={}, chars={}, end={:.2}s/{:.2}s, coverage={:.1}%)",
                segments.len(),
                vad_text_len,
                vad_end,
                total_seconds,
                vad_coverage * 100.0
            ))?;
            on_reset_segments()?;
            let full_audio_segments = transcribe_clip(
                &mut state,
                &audio,
                0,
                0.0,
                total_seconds,
                task_id,
                detected_language,
                target_lang,
                should_translate,
                0,
                &mut on_segment,
                &mut on_log,
            )?;
            if should_prefer_full_audio(&segments, &full_audio_segments, total_seconds) {
                on_log(format!(
                    "whisper: using full-audio transcript (vad_segments={}, full_segments={})",
                    segments.len(),
                    full_audio_segments.len()
                ))?;
                segments = full_audio_segments;
            } else {
                on_log(format!(
                    "whisper: keeping VAD-based transcript (vad_segments={}, full_segments={})",
                    segments.len(),
                    full_audio_segments.len()
                ))?;
                on_reset_segments()?;
                segments.iter().cloned().try_for_each(&mut on_segment)?;
            }
        }
        on_log(format!("whisper: total emitted segments={}", segments.len()))?;
        Ok(segments)
    }
}

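/// Runs a single whisper decode over `clip` (a slice of 16 kHz samples that
/// starts at `start` seconds in the source audio) and emits each non-empty
/// segment through `on_segment`. Whisper reports timestamps in centiseconds
/// relative to the clip, so they are converted to seconds and offset by
/// `start` back into source time. Translation is handled downstream, which is
/// why `_target_lang` and `_should_translate` are accepted but unused here.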
#[allow(clippy::too_many_arguments)]
fn transcribe_clip(
    state: &mut whisper_rs::WhisperState,
    clip: &[f32],
    range_index: usize,
    start: f32,
    end: f32,
    task_id: &str,
    source_lang: Option<&str>,
    _target_lang: &TargetLanguage,
    _should_translate: bool,
    segment_offset: usize,
    on_segment: &mut impl FnMut(SubtitleSegment) -> Result<()>,
    on_log: &mut impl FnMut(String) -> Result<()>,
) -> Result<Vec<SubtitleSegment>> {
    let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
    params.set_n_threads(4);
    params.set_print_special(false);
    params.set_print_progress(false);
    params.set_print_realtime(false);
    params.set_print_timestamps(false);
    params.set_token_timestamps(false);
    params.set_translate(false);
    match source_lang {
        Some(lang) => {
            params.set_detect_language(false);
            params.set_language(Some(lang));
        }
        None => {
            params.set_detect_language(true);
            params.set_language(None);
        }
    }
    state.full(params, clip).context("whisper inference failed")?;
    let num_segments = state.full_n_segments();
    on_log(format!(
        "whisper: range #{}, {:.2}-{:.2}s, samples={}, segments={}",
        range_index + 1,
        start,
        end,
        clip.len(),
        num_segments
    ))?;
    let mut results = Vec::new();
    for offset in 0..num_segments {
        let segment = state
            .get_segment(offset)
            .ok_or_else(|| anyhow!("failed to access whisper segment {offset}"))?;
        let text = segment
            .to_str_lossy()
            .context("failed to get whisper segment text")?
            .trim()
            .to_string();
        if text.is_empty() {
            continue;
        }
        on_log(format!("whisper text: {text}"))?;
        let local_start = segment.start_timestamp() as f32 / 100.0;
        let local_end = segment.end_timestamp() as f32 / 100.0;
        let emitted = SubtitleSegment {
            id: format!("seg-{:04}", segment_offset + results.len() + 1),
            task_id: task_id.to_string(),
            start: start + local_start,
            end: start + local_end,
            source_text: text.clone(),
            translated_text: None,
        };
        on_segment(emitted.clone())?;
        results.push(emitted);
    }
    Ok(results)
}

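/// Merges VAD ranges that are close together (gap <= 1.2 s) or whose current
/// chunk is still shorter than 8 s, then pads each merged range by 0.35 s on
/// both sides, clamped to the audio bounds. An empty input falls back to a
/// single range covering the whole file. For example, `[(0.0, 2.0), (2.5, 4.0)]`
/// merges into one padded range because the 0.5 s gap is under the threshold.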
fn normalize_speech_ranges(ranges: &[(f32, f32)], total_samples: usize) -> Vec<(f32, f32)> {
    if ranges.is_empty() {
        return vec![(0.0, total_samples as f32 / 16_000.0)];
    }
    let total_seconds = total_samples as f32 / 16_000.0;
    let mut merged = Vec::new();
    let mut current = ranges[0];
    for &(start, end) in &ranges[1..] {
        let current_duration = current.1 - current.0;
        let gap = start - current.1;
        if gap <= 1.2 || current_duration < 8.0 {
            current.1 = end;
        } else {
            merged.push(current);
            current = (start, end);
        }
    }
    merged.push(current);
    merged
        .into_iter()
        .map(|(start, end)| ((start - 0.35).max(0.0), (end + 0.35).min(total_seconds)))
        .collect()
}

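/// Loads a WAV file as f32 samples in -1.0..=1.0, rejecting anything that is
/// not 16 kHz mono i16 PCM, since that is the input format whisper.cpp expects.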
fn load_audio_f32(path: &Path) -> Result<Vec<f32>> {
    let reader = hound::WavReader::open(path)
        .with_context(|| format!("failed to open wav file: {}", path.display()))?;
    let spec = reader.spec();
    if spec.sample_rate != 16_000 {
        return Err(anyhow!("whisper expects 16k audio, got {}", spec.sample_rate));
    }
    if spec.channels != 1 {
        return Err(anyhow!("whisper expects mono audio, got {}", spec.channels));
    }
    let samples = reader
        .into_samples::<i16>()
        .map(|sample| {
            sample
                .map(|value| value as f32 / i16::MAX as f32)
                .map_err(anyhow::Error::from)
        })
        .collect::<Result<Vec<_>>>()?;
    Ok(samples)
}

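/// Copies the `start..end` window (in seconds) out of the sample buffer,
/// clamping both endpoints so out-of-range requests return a shorter or
/// empty clip instead of panicking.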
fn slice_audio(audio: &[f32], start: f32, end: f32) -> Vec<f32> {
    let begin = (start * 16_000.0).floor() as usize;
    let finish = (end * 16_000.0).ceil() as usize;
    audio
        .get(begin.min(audio.len())..finish.min(audio.len()))
        .unwrap_or(&[])
        .to_vec()
}

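/// Fraction of the audio covered by the given ranges, clamped to 0.0..=1.0.
/// Overlapping ranges can over-count the covered time, so the clamp keeps
/// the ratio bounded.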
fn speech_coverage_ratio(ranges: &[(f32, f32)], total_seconds: f32) -> f32 {
    if total_seconds <= 0.0 {
        return 0.0;
    }
    let covered = ranges
        .iter()
        .map(|(start, end)| (end - start).max(0.0))
        .sum::<f32>();
    (covered / total_seconds).clamp(0.0, 1.0)
}

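/// Total character count across all segment source texts (Unicode scalar
/// values, so CJK text is counted per character rather than per byte).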
fn text_len(segments: &[SubtitleSegment]) -> usize {
    segments
        .iter()
        .map(|segment| segment.source_text.chars().count())
        .sum()
}

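/// Latest end timestamp across all segments, or 0.0 when there are none.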
fn last_end(segments: &[SubtitleSegment]) -> f32 {
    segments
        .iter()
        .map(|segment| segment.end)
        .fold(0.0_f32, f32::max)
}

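/// Heuristic choice between the VAD-clipped transcript and the full-audio
/// retry: prefer the full-audio result when it is substantially longer
/// (more than ~1.6x the characters), has noticeably more segments, or
/// reaches meaningfully further toward the end of the audio.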
fn should_prefer_full_audio(
    vad_segments: &[SubtitleSegment],
    full_audio_segments: &[SubtitleSegment],
    total_seconds: f32,
) -> bool {
    if full_audio_segments.is_empty() {
        return vad_segments.is_empty();
    }
    if vad_segments.is_empty() {
        return true;
    }
    let vad_text_len = text_len(vad_segments);
    let full_text_len = text_len(full_audio_segments);
    let vad_end = last_end(vad_segments);
    let full_end = last_end(full_audio_segments);
    full_text_len > vad_text_len + vad_text_len * 3 / 5
        || full_audio_segments.len() > vad_segments.len() + 5
        || full_end > vad_end + 5.0
        || (total_seconds > 60.0 && full_end + 1.5 >= total_seconds && vad_end + 5.0 < total_seconds)
}

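/// Resolves the decode language: an explicit language code is passed through
/// unchanged, while `"auto"`, empty, or missing input triggers detection on
/// the first 30 s of audio. Detections below 0.35 probability are treated as
/// unresolved so the caller falls back to per-clip auto decoding.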
fn resolve_source_language<'a>(
    state: &mut whisper_rs::WhisperState,
    audio: &[f32],
    source_lang: Option<&'a str>,
) -> Result<Option<&'a str>> {
    match source_lang.map(str::trim).filter(|lang| !lang.is_empty()) {
        Some("auto") | None => {
            let detect_samples = audio.len().min(16_000 * 30);
            let sample = &audio[..detect_samples];
            state
                .pcm_to_mel(sample, 4)
                .context("failed to build mel spectrogram for language detection")?;
            let (lang_id, probabilities) = state
                .lang_detect(0, 4)
                .context("whisper language detection failed")?;
            let lang = get_lang_str(lang_id)
                .ok_or_else(|| anyhow!("unknown whisper language id: {lang_id}"))?;
            let probability = probabilities
                .get(lang_id as usize)
                .copied()
                .unwrap_or_default();
            if probability < 0.35 {
                Ok(None)
            } else {
                Ok(Some(lang))
            }
        }
        Some(lang) => Ok(Some(lang)),
    }
}
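
// A minimal usage sketch, not part of the original module: it assumes
// `TargetLanguage` in `crate::models` exposes an `English` variant and that
// the model/wav paths below exist locally, so the test stays `#[ignore]`d.
// Adjust the paths and the enum variant to the real project before enabling.
#[cfg(test)]
mod usage_example {
    use super::*;

    #[test]
    #[ignore = "requires a local ggml model and a 16 kHz mono wav fixture"]
    fn transcribe_whole_file() -> Result<()> {
        let engine = WhisperEngine::new(Some("models/ggml-base.bin".to_string()));
        let segments = engine.infer_segments(
            Path::new("fixtures/sample-16k-mono.wav"),
            "task-0001",
            Some("auto"),             // let whisper detect the source language
            &TargetLanguage::English, // assumed variant; see note above
            false,                    // transcription only, no translation
            &[],                      // no VAD ranges: the whole file is one clip
            |progress| {
                println!("progress: {:.0}%", progress * 100.0);
                Ok(())
            },
            || Ok(()), // nothing to reset in this sketch
            |segment| {
                println!(
                    "{} {:.2}-{:.2}s {}",
                    segment.id, segment.start, segment.end, segment.source_text
                );
                Ok(())
            },
            |line| {
                println!("{line}");
                Ok(())
            },
        )?;
        assert!(!segments.is_empty());
        Ok(())
    }
}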