修复 VAD 超时机制可能导致线程泄漏和结果不一致的问题

This commit is contained in:
kura 2026-04-30 15:44:30 +08:00
parent c855cf5be7
commit a7046eba8c
2 changed files with 61 additions and 47 deletions

View File

@ -158,7 +158,7 @@ async fn run_pipeline(
set_status(&window, &app_state, &mut task, TaskStatus::VadProcessing, 22.0, "正在分析语音片段")?; set_status(&window, &app_state, &mut task, TaskStatus::VadProcessing, 22.0, "正在分析语音片段")?;
let samples = AudioPipeline::load_wav_f32(&wav_path)?; let samples = AudioPipeline::load_wav_f32(&wav_path)?;
let vad = VadEngine::new(payload.vad_model_path.clone(), VadConfig::default())?; let vad = VadEngine::new(payload.vad_model_path.clone(), VadConfig::default())?;
let speech_ranges = vad.detect_segments(&samples); let speech_ranges = vad.detect_segments(&samples).await;
emit_log( emit_log(
&window, &window,
&task.id, &task.id,

View File

@ -1,6 +1,5 @@
use std::{ use std::{
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::mpsc,
time::Duration, time::Duration,
}; };
@ -15,6 +14,7 @@ pub struct VadConfig {
pub min_speech_ms: usize, pub min_speech_ms: usize,
pub min_silence_ms: usize, pub min_silence_ms: usize,
pub pad_ms: usize, pub pad_ms: usize,
pub timeout_seconds: u64,
} }
impl Default for VadConfig { impl Default for VadConfig {
@ -25,6 +25,7 @@ impl Default for VadConfig {
min_speech_ms: 180, min_speech_ms: 180,
min_silence_ms: 320, min_silence_ms: 320,
pad_ms: 220, pad_ms: 220,
timeout_seconds: 60,
} }
} }
} }
@ -46,49 +47,65 @@ impl VadEngine {
Ok(Self { model_path, config }) Ok(Self { model_path, config })
} }
pub fn detect_segments(&self, samples: &[f32]) -> Vec<(f32, f32)> { pub async fn detect_segments(&self, samples: &[f32]) -> Vec<(f32, f32)> {
if let Some(model_path) = &self.model_path { if self.model_path.is_some() {
let model_path = model_path.clone(); let samples_owned = samples.to_vec();
let samples = samples.to_vec(); let model_path = self.model_path.clone().unwrap();
let config = self.config.clone(); let config = self.config.clone();
let (sender, receiver) = mpsc::channel(); let timeout_secs = self.config.timeout_seconds;
std::thread::spawn(move || { match tokio::time::timeout(
let engine = VadEngine { Duration::from_secs(timeout_secs),
model_path: Some(model_path.clone()), tokio::task::spawn_blocking(move || {
config, let mut session = match Self::load_onnx_session(&model_path) {
}; Ok(s) => s,
let result = engine.detect_segments_with_onnx(&samples, &model_path); Err(e) => {
let _ = sender.send(result); eprintln!("vad: failed to load onnx session: {e:#}");
}); return None;
}
match receiver.recv_timeout(Duration::from_secs(3)) { };
Ok(Ok(result)) if !result.is_empty() => return result, Self::detect_with_onnx(&mut session, &samples_owned, &config).ok()
}),
)
.await
{
Ok(Ok(Some(ranges))) if !ranges.is_empty() => {
eprintln!("vad: onnx detected {} speech ranges", ranges.len());
return ranges;
}
Ok(Ok(_)) => {} Ok(Ok(_)) => {}
Ok(Err(error)) => { Ok(Err(e)) => {
eprintln!("silero vad failed, falling back to energy detection: {error:#}"); eprintln!("vad: onnx error: {e:#}, falling back to energy detection");
} }
Err(mpsc::RecvTimeoutError::Timeout) => { Err(_) => {
eprintln!("silero vad timed out, falling back to energy detection"); eprintln!(
} "vad: onnx timed out after {}s, falling back to energy detection",
Err(mpsc::RecvTimeoutError::Disconnected) => { timeout_secs
eprintln!("silero vad worker disconnected, falling back to energy detection"); );
} }
} }
} }
self.detect_segments_with_energy(samples) let ranges = self.detect_segments_with_energy(samples);
eprintln!("vad: energy detection found {} speech ranges", ranges.len());
ranges
} }
fn detect_segments_with_onnx(&self, samples: &[f32], model_path: &Path) -> Result<Vec<(f32, f32)>> { fn load_onnx_session(model_path: &Path) -> Result<Session> {
let mut session = Session::builder() Session::builder()
.context("failed to build onnx session")? .context("failed to build onnx session")?
.commit_from_file(model_path) .commit_from_file(model_path)
.with_context(|| format!("failed to load silero vad model: {}", model_path.display()))?; .with_context(|| format!("failed to load silero vad model: {}", model_path.display()))
}
fn detect_with_onnx(
session: &mut Session,
samples: &[f32],
config: &VadConfig,
) -> Result<Vec<(f32, f32)>> {
let chunk_size = 512usize; let chunk_size = 512usize;
let mut state = Array3::<f32>::zeros((2, 1, 128)); let mut state = Array3::<f32>::zeros((2, 1, 128));
let sr = Array1::<i64>::from_vec(vec![self.config.sample_rate as i64]); let sr = Array1::<i64>::from_vec(vec![config.sample_rate as i64]);
let mut speech_probabilities = Vec::new(); let mut speech_probabilities = Vec::new();
for chunk in samples.chunks(chunk_size) { for chunk in samples.chunks(chunk_size) {
@ -109,10 +126,7 @@ impl VadEngine {
let (_, probs) = first let (_, probs) = first
.try_extract_tensor::<f32>() .try_extract_tensor::<f32>()
.context("failed to extract vad probabilities")?; .context("failed to extract vad probabilities")?;
let probability = probs let probability = probs.iter().copied().fold(0.0_f32, f32::max);
.iter()
.copied()
.fold(0.0_f32, f32::max);
speech_probabilities.push(probability); speech_probabilities.push(probability);
if outputs.len() > 1 { if outputs.len() > 1 {
@ -127,7 +141,7 @@ impl VadEngine {
} }
} }
Ok(self.merge_probabilities(&speech_probabilities, chunk_size)) Ok(Self::merge_probabilities(&speech_probabilities, chunk_size, config))
} }
fn detect_segments_with_energy(&self, samples: &[f32]) -> Vec<(f32, f32)> { fn detect_segments_with_energy(&self, samples: &[f32]) -> Vec<(f32, f32)> {
@ -148,22 +162,22 @@ impl VadEngine {
energies.len(), energies.len(),
dynamic_threshold dynamic_threshold
); );
self.merge_probabilities_with_threshold(&energies, frame_size, dynamic_threshold) Self::merge_probabilities_with_threshold(&energies, frame_size, dynamic_threshold, &self.config)
} }
fn merge_probabilities(&self, frames: &[f32], frame_size: usize) -> Vec<(f32, f32)> { fn merge_probabilities(frames: &[f32], frame_size: usize, config: &VadConfig) -> Vec<(f32, f32)> {
self.merge_probabilities_with_threshold(frames, frame_size, self.config.threshold) Self::merge_probabilities_with_threshold(frames, frame_size, config.threshold, config)
} }
fn merge_probabilities_with_threshold( fn merge_probabilities_with_threshold(
&self,
frames: &[f32], frames: &[f32],
frame_size: usize, frame_size: usize,
threshold: f32, threshold: f32,
config: &VadConfig,
) -> Vec<(f32, f32)> { ) -> Vec<(f32, f32)> {
let min_speech_frames = (self.config.min_speech_ms / 20).max(1); let min_speech_frames = (config.min_speech_ms / 20).max(1);
let min_silence_frames = (self.config.min_silence_ms / 20).max(1); let min_silence_frames = (config.min_silence_ms / 20).max(1);
let pad_seconds = self.config.pad_ms as f32 / 1000.0; let pad_seconds = config.pad_ms as f32 / 1000.0;
let mut result = Vec::new(); let mut result = Vec::new();
let mut start_frame: Option<usize> = None; let mut start_frame: Option<usize> = None;
@ -183,8 +197,8 @@ impl VadEngine {
if silent_frames >= min_silence_frames { if silent_frames >= min_silence_frames {
let end_frame = index.saturating_sub(silent_frames); let end_frame = index.saturating_sub(silent_frames);
if end_frame.saturating_sub(start) >= min_speech_frames { if end_frame.saturating_sub(start) >= min_speech_frames {
let start_sec = (start * frame_size) as f32 / self.config.sample_rate as f32; let start_sec = (start * frame_size) as f32 / config.sample_rate as f32;
let end_sec = ((end_frame + 1) * frame_size) as f32 / self.config.sample_rate as f32; let end_sec = ((end_frame + 1) * frame_size) as f32 / config.sample_rate as f32;
result.push(((start_sec - pad_seconds).max(0.0), end_sec + pad_seconds)); result.push(((start_sec - pad_seconds).max(0.0), end_sec + pad_seconds));
} }
start_frame = None; start_frame = None;
@ -196,14 +210,14 @@ impl VadEngine {
if let Some(start) = start_frame { if let Some(start) = start_frame {
let end_frame = frames.len().saturating_sub(1); let end_frame = frames.len().saturating_sub(1);
if end_frame.saturating_sub(start) >= min_speech_frames { if end_frame.saturating_sub(start) >= min_speech_frames {
let start_sec = (start * frame_size) as f32 / self.config.sample_rate as f32; let start_sec = (start * frame_size) as f32 / config.sample_rate as f32;
let end_sec = ((end_frame + 1) * frame_size) as f32 / self.config.sample_rate as f32; let end_sec = ((end_frame + 1) * frame_size) as f32 / config.sample_rate as f32;
result.push(((start_sec - pad_seconds).max(0.0), end_sec + pad_seconds)); result.push(((start_sec - pad_seconds).max(0.0), end_sec + pad_seconds));
} }
} }
if result.is_empty() && !frames.is_empty() { if result.is_empty() && !frames.is_empty() {
let total_seconds = (frames.len() * frame_size) as f32 / self.config.sample_rate as f32; let total_seconds = (frames.len() * frame_size) as f32 / config.sample_rate as f32;
result.push((0.0, total_seconds)); result.push((0.0, total_seconds));
} }