diff --git a/src-tauri/src/task.rs b/src-tauri/src/task.rs index 1ae9a20..4d9dc51 100644 --- a/src-tauri/src/task.rs +++ b/src-tauri/src/task.rs @@ -158,7 +158,7 @@ async fn run_pipeline( set_status(&window, &app_state, &mut task, TaskStatus::VadProcessing, 22.0, "正在分析语音片段")?; let samples = AudioPipeline::load_wav_f32(&wav_path)?; let vad = VadEngine::new(payload.vad_model_path.clone(), VadConfig::default())?; - let speech_ranges = vad.detect_segments(&samples); + let speech_ranges = vad.detect_segments(&samples).await; emit_log( &window, &task.id, diff --git a/src-tauri/src/vad.rs b/src-tauri/src/vad.rs index 6fc8b3d..77fea74 100644 --- a/src-tauri/src/vad.rs +++ b/src-tauri/src/vad.rs @@ -1,6 +1,5 @@ use std::{ path::{Path, PathBuf}, - sync::mpsc, time::Duration, }; @@ -15,6 +14,7 @@ pub struct VadConfig { pub min_speech_ms: usize, pub min_silence_ms: usize, pub pad_ms: usize, + pub timeout_seconds: u64, } impl Default for VadConfig { @@ -25,6 +25,7 @@ impl Default for VadConfig { min_speech_ms: 180, min_silence_ms: 320, pad_ms: 220, + timeout_seconds: 60, } } } @@ -46,49 +47,65 @@ impl VadEngine { Ok(Self { model_path, config }) } - pub fn detect_segments(&self, samples: &[f32]) -> Vec<(f32, f32)> { - if let Some(model_path) = &self.model_path { - let model_path = model_path.clone(); - let samples = samples.to_vec(); + pub async fn detect_segments(&self, samples: &[f32]) -> Vec<(f32, f32)> { + if self.model_path.is_some() { + let samples_owned = samples.to_vec(); + let model_path = self.model_path.clone().unwrap(); let config = self.config.clone(); - let (sender, receiver) = mpsc::channel(); + let timeout_secs = self.config.timeout_seconds; - std::thread::spawn(move || { - let engine = VadEngine { - model_path: Some(model_path.clone()), - config, - }; - let result = engine.detect_segments_with_onnx(&samples, &model_path); - let _ = sender.send(result); - }); - - match receiver.recv_timeout(Duration::from_secs(3)) { - Ok(Ok(result)) if !result.is_empty() => return result, + match tokio::time::timeout( + Duration::from_secs(timeout_secs), + tokio::task::spawn_blocking(move || { + let mut session = match Self::load_onnx_session(&model_path) { + Ok(s) => s, + Err(e) => { + eprintln!("vad: failed to load onnx session: {e:#}"); + return None; + } + }; + Self::detect_with_onnx(&mut session, &samples_owned, &config).ok() + }), + ) + .await + { + Ok(Ok(Some(ranges))) if !ranges.is_empty() => { + eprintln!("vad: onnx detected {} speech ranges", ranges.len()); + return ranges; + } Ok(Ok(_)) => {} - Ok(Err(error)) => { - eprintln!("silero vad failed, falling back to energy detection: {error:#}"); + Ok(Err(e)) => { + eprintln!("vad: onnx error: {e:#}, falling back to energy detection"); } - Err(mpsc::RecvTimeoutError::Timeout) => { - eprintln!("silero vad timed out, falling back to energy detection"); - } - Err(mpsc::RecvTimeoutError::Disconnected) => { - eprintln!("silero vad worker disconnected, falling back to energy detection"); + Err(_) => { + eprintln!( + "vad: onnx timed out after {}s, falling back to energy detection", + timeout_secs + ); } } } - self.detect_segments_with_energy(samples) + let ranges = self.detect_segments_with_energy(samples); + eprintln!("vad: energy detection found {} speech ranges", ranges.len()); + ranges } - fn detect_segments_with_onnx(&self, samples: &[f32], model_path: &Path) -> Result> { - let mut session = Session::builder() + fn load_onnx_session(model_path: &Path) -> Result { + Session::builder() .context("failed to build onnx session")? .commit_from_file(model_path) - .with_context(|| format!("failed to load silero vad model: {}", model_path.display()))?; + .with_context(|| format!("failed to load silero vad model: {}", model_path.display())) + } + fn detect_with_onnx( + session: &mut Session, + samples: &[f32], + config: &VadConfig, + ) -> Result> { let chunk_size = 512usize; let mut state = Array3::::zeros((2, 1, 128)); - let sr = Array1::::from_vec(vec![self.config.sample_rate as i64]); + let sr = Array1::::from_vec(vec![config.sample_rate as i64]); let mut speech_probabilities = Vec::new(); for chunk in samples.chunks(chunk_size) { @@ -109,10 +126,7 @@ impl VadEngine { let (_, probs) = first .try_extract_tensor::() .context("failed to extract vad probabilities")?; - let probability = probs - .iter() - .copied() - .fold(0.0_f32, f32::max); + let probability = probs.iter().copied().fold(0.0_f32, f32::max); speech_probabilities.push(probability); if outputs.len() > 1 { @@ -127,7 +141,7 @@ impl VadEngine { } } - Ok(self.merge_probabilities(&speech_probabilities, chunk_size)) + Ok(Self::merge_probabilities(&speech_probabilities, chunk_size, config)) } fn detect_segments_with_energy(&self, samples: &[f32]) -> Vec<(f32, f32)> { @@ -148,22 +162,22 @@ impl VadEngine { energies.len(), dynamic_threshold ); - self.merge_probabilities_with_threshold(&energies, frame_size, dynamic_threshold) + Self::merge_probabilities_with_threshold(&energies, frame_size, dynamic_threshold, &self.config) } - fn merge_probabilities(&self, frames: &[f32], frame_size: usize) -> Vec<(f32, f32)> { - self.merge_probabilities_with_threshold(frames, frame_size, self.config.threshold) + fn merge_probabilities(frames: &[f32], frame_size: usize, config: &VadConfig) -> Vec<(f32, f32)> { + Self::merge_probabilities_with_threshold(frames, frame_size, config.threshold, config) } fn merge_probabilities_with_threshold( - &self, frames: &[f32], frame_size: usize, threshold: f32, + config: &VadConfig, ) -> Vec<(f32, f32)> { - let min_speech_frames = (self.config.min_speech_ms / 20).max(1); - let min_silence_frames = (self.config.min_silence_ms / 20).max(1); - let pad_seconds = self.config.pad_ms as f32 / 1000.0; + let min_speech_frames = (config.min_speech_ms / 20).max(1); + let min_silence_frames = (config.min_silence_ms / 20).max(1); + let pad_seconds = config.pad_ms as f32 / 1000.0; let mut result = Vec::new(); let mut start_frame: Option = None; @@ -183,8 +197,8 @@ impl VadEngine { if silent_frames >= min_silence_frames { let end_frame = index.saturating_sub(silent_frames); if end_frame.saturating_sub(start) >= min_speech_frames { - let start_sec = (start * frame_size) as f32 / self.config.sample_rate as f32; - let end_sec = ((end_frame + 1) * frame_size) as f32 / self.config.sample_rate as f32; + let start_sec = (start * frame_size) as f32 / config.sample_rate as f32; + let end_sec = ((end_frame + 1) * frame_size) as f32 / config.sample_rate as f32; result.push(((start_sec - pad_seconds).max(0.0), end_sec + pad_seconds)); } start_frame = None; @@ -196,14 +210,14 @@ impl VadEngine { if let Some(start) = start_frame { let end_frame = frames.len().saturating_sub(1); if end_frame.saturating_sub(start) >= min_speech_frames { - let start_sec = (start * frame_size) as f32 / self.config.sample_rate as f32; - let end_sec = ((end_frame + 1) * frame_size) as f32 / self.config.sample_rate as f32; + let start_sec = (start * frame_size) as f32 / config.sample_rate as f32; + let end_sec = ((end_frame + 1) * frame_size) as f32 / config.sample_rate as f32; result.push(((start_sec - pad_seconds).max(0.0), end_sec + pad_seconds)); } } if result.is_empty() && !frames.is_empty() { - let total_seconds = (frames.len() * frame_size) as f32 / self.config.sample_rate as f32; + let total_seconds = (frames.len() * frame_size) as f32 / config.sample_rate as f32; result.push((0.0, total_seconds)); }