373 lines
12 KiB
Rust
373 lines
12 KiB
Rust
use std::path::Path;
|
|
|
|
use anyhow::{anyhow, Context, Result};
|
|
use whisper_rs::{
|
|
get_lang_str, install_logging_hooks, FullParams, SamplingStrategy, WhisperContext,
|
|
WhisperContextParameters,
|
|
};
|
|
|
|
use crate::models::{SubtitleSegment, TargetLanguage};
|
|
|
|
/// Thin wrapper around whisper.cpp (via `whisper_rs`) that turns a 16 kHz
/// mono WAV file into [`SubtitleSegment`]s.
pub struct WhisperEngine {
    // Filesystem path to a local ggml whisper model. `None` is tolerated at
    // construction time; `infer_segments` returns an error if it is missing.
    model_path: Option<String>,
}
|
|
|
|
impl WhisperEngine {
    /// Create an engine around an optional local ggml model path and install
    /// whisper.cpp's logging hooks so native log output is routed through the
    /// Rust logging infrastructure.
    pub fn new(model_path: Option<String>) -> Self {
        install_logging_hooks();
        Self { model_path }
    }

    /// Transcribe `wav_path` (16 kHz mono PCM, enforced by `load_audio_f32`)
    /// into subtitle segments, streaming results through the callbacks.
    ///
    /// Strategy: pass 1 transcribes clip-by-clip along the (normalized) VAD
    /// `speech_ranges`; if that result looks incomplete, pass 2 retries the
    /// whole audio in one shot and whichever transcript scores better is kept.
    ///
    /// Callbacks:
    /// * `on_progress` — coarse progress in `0.0..=1.0`;
    /// * `on_reset_segments` — tells the caller to discard previously emitted
    ///   segments before a re-emit;
    /// * `on_segment` — one call per emitted segment;
    /// * `on_log` — human-readable diagnostics.
    ///
    /// `target_lang` / `should_translate` are forwarded to `transcribe_clip`
    /// (currently unused there). Errors if no model path was configured, the
    /// model file is missing, or whisper inference fails.
    pub fn infer_segments<F>(
        &self,
        wav_path: &Path,
        task_id: &str,
        source_lang: Option<&str>,
        target_lang: &TargetLanguage,
        should_translate: bool,
        speech_ranges: &[(f32, f32)],
        mut on_progress: F,
        mut on_reset_segments: impl FnMut() -> Result<()>,
        mut on_segment: impl FnMut(SubtitleSegment) -> Result<()>,
        mut on_log: impl FnMut(String) -> Result<()>,
    ) -> Result<Vec<SubtitleSegment>>
    where
        F: FnMut(f32) -> Result<()>,
    {
        let Some(model_path) = &self.model_path else {
            return Err(anyhow!(
                "whisper model path is missing. Please provide a local ggml model path for task {}",
                task_id
            ));
        };

        if !Path::new(model_path).exists() {
            return Err(anyhow!("whisper model not found: {model_path}"));
        }

        let audio = load_audio_f32(wav_path)?;
        // All timing below assumes 16 kHz input; load_audio_f32 enforces it.
        let total_seconds = audio.len() as f32 / 16_000.0;
        let normalized_ranges = normalize_speech_ranges(speech_ranges, audio.len());
        let context = WhisperContext::new_with_params(
            model_path,
            WhisperContextParameters::default(),
        )
        .with_context(|| format!("failed to load whisper model: {model_path}"))?;
        let mut state = context.create_state().context("failed to create whisper state")?;
        // None means "let whisper auto-detect the language while decoding".
        let detected_language = resolve_source_language(&mut state, &audio, source_lang)
            .context("failed to resolve source language")?;

        if let Some(lang) = detected_language {
            on_log(format!("whisper: source language={lang}"))?;
        } else {
            on_log("whisper: source language unresolved, fallback to auto decode".to_string())?;
        }

        let mut segments = Vec::new();
        on_log(format!(
            "whisper: processing {} speech ranges (normalized from {}), coverage={:.1}%",
            normalized_ranges.len(),
            speech_ranges.len(),
            speech_coverage_ratio(&normalized_ranges, total_seconds) * 100.0
        ))?;
        // Pass 1: transcribe each VAD-derived clip independently;
        // transcribe_clip shifts timestamps by the clip start so they
        // stay absolute.
        for (range_index, (start, end)) in normalized_ranges.iter().enumerate() {
            let clip = slice_audio(&audio, *start, *end);
            if clip.is_empty() {
                continue;
            }

            // Each range owns an equal share of the progress bar.
            let progress_base = range_index as f32 / normalized_ranges.len().max(1) as f32;
            let progress_span = 1.0 / normalized_ranges.len().max(1) as f32;
            on_progress(progress_base)?;

            let clip_segments = transcribe_clip(
                &mut state,
                &clip,
                range_index,
                *start,
                *end,
                task_id,
                detected_language,
                target_lang,
                should_translate,
                segments.len(),
                &mut on_segment,
                &mut on_log,
            )?;
            segments.extend(clip_segments);

            on_progress((progress_base + progress_span).min(1.0))?;
        }

        // Heuristics for "the VAD pass probably missed speech": nothing
        // produced, low speech coverage, transcript ends >5 s before the
        // audio does, or a long file yielded implausibly little text.
        let vad_text_len = text_len(&segments);
        let vad_end = last_end(&segments);
        let vad_coverage = speech_coverage_ratio(&normalized_ranges, total_seconds);
        let should_retry_full_audio = !audio.is_empty()
            && (segments.is_empty()
                || vad_coverage < 0.60
                || vad_end + 5.0 < total_seconds
                || (total_seconds > 60.0 && vad_text_len < (total_seconds / 3.0) as usize));

        if should_retry_full_audio {
            on_log(format!(
                "whisper: VAD result looks incomplete, retrying full audio (segments={}, chars={}, end={:.2}s/{:.2}s, coverage={:.1}%)",
                segments.len(),
                vad_text_len,
                vad_end,
                total_seconds,
                vad_coverage * 100.0
            ))?;
            // Pass 2 re-emits from scratch, so the caller must drop what
            // pass 1 already sent.
            on_reset_segments()?;
            let full_audio_segments = transcribe_clip(
                &mut state,
                &audio,
                0,
                0.0,
                total_seconds,
                task_id,
                detected_language,
                target_lang,
                should_translate,
                0,
                &mut on_segment,
                &mut on_log,
            )?;

            if should_prefer_full_audio(&segments, &full_audio_segments, total_seconds) {
                on_log(format!(
                    "whisper: using full-audio transcript (vad_segments={}, full_segments={})",
                    segments.len(),
                    full_audio_segments.len()
                ))?;
                segments = full_audio_segments;
            } else {
                on_log(format!(
                    "whisper: keeping VAD-based transcript (vad_segments={}, full_segments={})",
                    segments.len(),
                    full_audio_segments.len()
                ))?;
                // Pass 2's segments were already streamed out; reset once
                // more and replay the kept VAD segments.
                on_reset_segments()?;
                segments.iter().cloned().try_for_each(&mut on_segment)?;
            }
        }

        on_log(format!("whisper: total emitted segments={}", segments.len()))?;
        Ok(segments)
    }
}
|
|
|
|
/// Run one whisper inference over `clip` and stream its segments out.
///
/// `clip` holds `[start, end]` seconds of the source audio; emitted segment
/// timestamps are shifted by `start` so they are absolute. `source_lang` of
/// `None` enables whisper's built-in language detection. `segment_offset`
/// keeps the emitted `seg-NNNN` ids contiguous across multiple clips.
/// `_target_lang` / `_should_translate` are accepted for signature stability
/// but unused here — translation is explicitly disabled (`set_translate(false)`).
/// Returns the segments that were emitted via `on_segment`.
#[allow(clippy::too_many_arguments)]
fn transcribe_clip(
    state: &mut whisper_rs::WhisperState,
    clip: &[f32],
    range_index: usize,
    start: f32,
    end: f32,
    task_id: &str,
    source_lang: Option<&str>,
    _target_lang: &TargetLanguage,
    _should_translate: bool,
    segment_offset: usize,
    on_segment: &mut impl FnMut(SubtitleSegment) -> Result<()>,
    on_log: &mut impl FnMut(String) -> Result<()>,
) -> Result<Vec<SubtitleSegment>> {
    // Greedy decoding with all native console output silenced.
    let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
    params.set_n_threads(4);
    params.set_print_special(false);
    params.set_print_progress(false);
    params.set_print_realtime(false);
    params.set_print_timestamps(false);
    params.set_token_timestamps(false);
    params.set_translate(false);
    match source_lang {
        Some(lang) => {
            params.set_detect_language(false);
            params.set_language(Some(lang));
        }
        None => {
            params.set_detect_language(true);
            params.set_language(None);
        }
    }

    state.full(params, clip).context("whisper inference failed")?;

    let num_segments = state.full_n_segments();
    on_log(format!(
        "whisper: range #{}, {:.2}-{:.2}s, samples={}, segments={}",
        range_index + 1,
        start,
        end,
        clip.len(),
        num_segments
    ))?;

    let mut results = Vec::new();
    for offset in 0..num_segments {
        let segment = state
            .get_segment(offset)
            .ok_or_else(|| anyhow!("failed to access whisper segment {offset}"))?;
        let text = segment
            .to_str_lossy()
            .context("failed to get whisper segment text")?
            .trim()
            .to_string();
        // Whisper can produce blank/whitespace-only segments; skip them so
        // they neither consume ids nor reach the caller.
        if text.is_empty() {
            continue;
        }
        on_log(format!("whisper text: {}", text))?;

        // Segment timestamps are in 1/100 s units (hence / 100.0), relative
        // to the clip; `start` is added below to make them absolute.
        let local_start = segment.start_timestamp() as f32 / 100.0;
        let local_end = segment.end_timestamp() as f32 / 100.0;

        let emitted = SubtitleSegment {
            // Ids continue from segments emitted by earlier clips.
            id: format!("seg-{:04}", segment_offset + results.len() + 1),
            task_id: task_id.to_string(),
            start: start + local_start,
            end: start + local_end,
            source_text: text.clone(),
            translated_text: None,
        };
        on_segment(emitted.clone())?;
        results.push(emitted);
    }

    Ok(results)
}
|
|
|
|
/// Merge raw VAD speech ranges (seconds) into larger clips and pad each
/// merged range slightly so whisper gets leading/trailing context.
///
/// Rules:
/// * empty input falls back to one range covering the whole audio;
/// * adjacent ranges merge when the silence gap is <= 1.2 s, or when the
///   accumulated clip is still shorter than 8 s;
/// * every merged range is padded by 0.35 s on both sides and clamped to
///   `[0, total_seconds]`.
///
/// `total_samples` is the audio length in 16 kHz samples.
fn normalize_speech_ranges(ranges: &[(f32, f32)], total_samples: usize) -> Vec<(f32, f32)> {
    let total_seconds = total_samples as f32 / 16_000.0;
    if ranges.is_empty() {
        return vec![(0.0, total_seconds)];
    }

    let mut merged = Vec::new();
    let mut current = ranges[0];

    for &(start, end) in &ranges[1..] {
        let current_duration = current.1 - current.0;
        let gap = start - current.1;
        if gap <= 1.2 || current_duration < 8.0 {
            // Extend rather than overwrite: `max` keeps the merged range from
            // shrinking when the input contains an overlapping range whose
            // end lies before the already-accumulated end.
            current.1 = current.1.max(end);
        } else {
            merged.push(current);
            current = (start, end);
        }
    }
    merged.push(current);

    merged
        .into_iter()
        .map(|(start, end)| ((start - 0.35).max(0.0), (end + 0.35).min(total_seconds)))
        .collect()
}
|
|
|
|
fn load_audio_f32(path: &Path) -> Result<Vec<f32>> {
|
|
let reader = hound::WavReader::open(path)
|
|
.with_context(|| format!("failed to open wav file: {}", path.display()))?;
|
|
let spec = reader.spec();
|
|
if spec.sample_rate != 16_000 {
|
|
return Err(anyhow!("whisper expects 16k audio, got {}", spec.sample_rate));
|
|
}
|
|
if spec.channels != 1 {
|
|
return Err(anyhow!("whisper expects mono audio, got {}", spec.channels));
|
|
}
|
|
|
|
let samples = reader
|
|
.into_samples::<i16>()
|
|
.map(|sample| sample.map(|value| value as f32 / i16::MAX as f32).map_err(anyhow::Error::from))
|
|
.collect::<Result<Vec<_>>>()?;
|
|
|
|
Ok(samples)
|
|
}
|
|
|
|
/// Copy the sample window `[start, end)` (in seconds) out of 16 kHz audio.
///
/// Both bounds are clamped to the buffer, so out-of-range or inverted
/// windows simply yield an empty clip rather than panicking.
fn slice_audio(audio: &[f32], start: f32, end: f32) -> Vec<f32> {
    let len = audio.len();
    let lo = ((start * 16_000.0).floor() as usize).min(len);
    let hi = ((end * 16_000.0).ceil() as usize).min(len);
    match audio.get(lo..hi) {
        Some(window) => window.to_vec(),
        None => Vec::new(),
    }
}
|
|
|
|
/// Fraction of the audio (`0.0..=1.0`) covered by the given speech ranges.
///
/// Negative-length ranges contribute nothing, and the result is clamped so
/// overlapping ranges cannot report more than 100% coverage. Zero or
/// negative `total_seconds` yields 0.
fn speech_coverage_ratio(ranges: &[(f32, f32)], total_seconds: f32) -> f32 {
    if total_seconds <= 0.0 {
        return 0.0;
    }

    let mut covered = 0.0_f32;
    for &(start, end) in ranges {
        covered += (end - start).max(0.0);
    }
    (covered / total_seconds).clamp(0.0, 1.0)
}
|
|
|
|
fn text_len(segments: &[SubtitleSegment]) -> usize {
|
|
segments
|
|
.iter()
|
|
.map(|segment| segment.source_text.chars().count())
|
|
.sum()
|
|
}
|
|
|
|
fn last_end(segments: &[SubtitleSegment]) -> f32 {
|
|
segments
|
|
.iter()
|
|
.map(|segment| segment.end)
|
|
.fold(0.0_f32, f32::max)
|
|
}
|
|
|
|
fn should_prefer_full_audio(
|
|
vad_segments: &[SubtitleSegment],
|
|
full_audio_segments: &[SubtitleSegment],
|
|
total_seconds: f32,
|
|
) -> bool {
|
|
if full_audio_segments.is_empty() {
|
|
return vad_segments.is_empty();
|
|
}
|
|
if vad_segments.is_empty() {
|
|
return true;
|
|
}
|
|
|
|
let vad_text_len = text_len(vad_segments);
|
|
let full_text_len = text_len(full_audio_segments);
|
|
let vad_end = last_end(vad_segments);
|
|
let full_end = last_end(full_audio_segments);
|
|
|
|
full_text_len > vad_text_len + vad_text_len * 3 / 5
|
|
|| full_audio_segments.len() > vad_segments.len() + 5
|
|
|| full_end > vad_end + 5.0
|
|
|| (total_seconds > 60.0 && full_end + 1.5 >= total_seconds && vad_end + 5.0 < total_seconds)
|
|
}
|
|
|
|
fn resolve_source_language<'a>(
|
|
state: &mut whisper_rs::WhisperState,
|
|
audio: &[f32],
|
|
source_lang: Option<&'a str>,
|
|
) -> Result<Option<&'a str>> {
|
|
match source_lang.map(str::trim).filter(|lang| !lang.is_empty()) {
|
|
Some("auto") | None => {
|
|
let detect_samples = audio.len().min(16_000 * 30);
|
|
let sample = &audio[..detect_samples];
|
|
state
|
|
.pcm_to_mel(sample, 4)
|
|
.context("failed to build mel spectrogram for language detection")?;
|
|
let (lang_id, probabilities) = state
|
|
.lang_detect(0, 4)
|
|
.context("whisper language detection failed")?;
|
|
let lang = get_lang_str(lang_id)
|
|
.ok_or_else(|| anyhow!("unknown whisper language id: {lang_id}"))?;
|
|
let probability = probabilities
|
|
.get(lang_id as usize)
|
|
.copied()
|
|
.unwrap_or_default();
|
|
|
|
if probability < 0.35 {
|
|
Ok(None)
|
|
} else {
|
|
Ok(Some(lang))
|
|
}
|
|
}
|
|
Some(lang) => Ok(Some(lang)),
|
|
}
|
|
}
|