diff --git a/src-tauri/src/audio.rs b/src-tauri/src/audio.rs index c9256cc..7c258d9 100644 --- a/src-tauri/src/audio.rs +++ b/src-tauri/src/audio.rs @@ -22,7 +22,11 @@ impl AudioPipeline { let output_path = workspace.join("normalized.wav"); let mut command = Command::new(ffmpeg_path); #[cfg(target_os = "macos")] - if let Some(lib_dir) = ffmpeg_path.parent().and_then(|bin_dir| bin_dir.parent()).map(|root| root.join("lib")) { + if let Some(lib_dir) = ffmpeg_path + .parent() + .and_then(|bin_dir| bin_dir.parent()) + .map(|root| root.join("lib")) + { if lib_dir.exists() { command.env("DYLD_FALLBACK_LIBRARY_PATH", &lib_dir); } @@ -70,7 +74,9 @@ impl AudioPipeline { } } - let status = child.wait().with_context(|| "ffmpeg process failed to wait")?; + let status = child + .wait() + .with_context(|| "ffmpeg process failed to wait")?; if !status.success() { return Err(anyhow!("ffmpeg exited with status: {}", status)); } @@ -80,8 +86,8 @@ impl AudioPipeline { } pub fn load_wav_f32(path: &Path) -> Result<Vec<f32>> { - let mut reader = - hound::WavReader::open(path).with_context(|| format!("failed to open {}", path.display()))?; + let mut reader = hound::WavReader::open(path) + .with_context(|| format!("failed to open {}", path.display()))?; let spec = reader.spec(); if spec.channels != 1 { @@ -124,7 +130,9 @@ fn parse_ffmpeg_duration(line: &str) -> Option<f32> { fn parse_ffmpeg_time(line: &str) -> Option<f32> { let pos = line.find("time=")?; let rest = &line[pos + 5..]; - let end = rest.find(|c: char| !c.is_digit(10) && c != ':' && c != '.').unwrap_or(rest.len()); + let end = rest + .find(|c: char| !c.is_digit(10) && c != ':' && c != '.') + .unwrap_or(rest.len()); let time_str = &rest[..end]; let parts: Vec<&str> = time_str.split(':').collect(); if parts.len() == 3 { diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 93be975..1ac1833 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -10,17 +10,17 @@ mod whisper; use models::{ DefaultModelPaths, StartTaskPayload,
SubtitleSegment, SubtitleTask, TranslationConfig, }; +#[cfg(target_os = "macos")] +use objc2_app_kit::NSWindow; +#[cfg(target_os = "macos")] +use objc2_foundation::NSSize; use state::AppState; -use tauri::{AppHandle, Manager, PhysicalSize, Size}; #[cfg(target_os = "macos")] use tauri::{ menu::{MenuBuilder, MenuItemBuilder, PredefinedMenuItem, SubmenuBuilder}, Emitter, }; -#[cfg(target_os = "macos")] -use objc2_app_kit::NSWindow; -#[cfg(target_os = "macos")] -use objc2_foundation::NSSize; +use tauri::{AppHandle, Manager, PhysicalSize, Size}; #[cfg(target_os = "macos")] const WINDOW_RATIO_WIDTH: f64 = 16.0; @@ -74,7 +74,9 @@ fn export_subtitles( } #[tauri::command] -fn get_default_model_paths(app: tauri::AppHandle) -> std::result::Result<DefaultModelPaths, String> { +fn get_default_model_paths( + app: tauri::AppHandle, +) -> std::result::Result<DefaultModelPaths, String> { task::get_default_model_paths(&app).map_err(error_to_string) } @@ -152,7 +154,11 @@ fn configure_macos_menu(app: &AppHandle) -> tauri::Result<()> { .build()?; let file_menu = SubmenuBuilder::new(app, "文件") - .item(&MenuItemBuilder::with_id("pick_files", "选择媒体文件").accelerator("CmdOrCtrl+O").build(app)?) + .item( + &MenuItemBuilder::with_id("pick_files", "选择媒体文件") + .accelerator("CmdOrCtrl+O") + .build(app)?, + ) .separator() .item(&MenuItemBuilder::with_id("export_srt", "导出 SRT").build(app)?) .item(&MenuItemBuilder::with_id("export_vtt", "导出 VTT").build(app)?)
diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index 6b5b48f..29a5e63 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -11,13 +11,19 @@ pub struct AppState { impl AppState { pub fn upsert_task(&self, task: SubtitleTask) -> Result<()> { - let mut guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?; + let mut guard = self + .tasks + .lock() + .map_err(|_| anyhow!("task store poisoned"))?; guard.insert(task.id.clone(), task); Ok(()) } pub fn get_task(&self, task_id: &str) -> Result<SubtitleTask> { - let guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?; + let guard = self + .tasks + .lock() + .map_err(|_| anyhow!("task store poisoned"))?; guard .get(task_id) .cloned() @@ -25,20 +31,29 @@ impl AppState { } pub fn list_tasks(&self) -> Result<Vec<SubtitleTask>> { - let guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?; + let guard = self + .tasks + .lock() + .map_err(|_| anyhow!("task store poisoned"))?; let mut tasks = guard.values().cloned().collect::<Vec<_>>(); tasks.sort_by(|left, right| right.id.cmp(&left.id)); Ok(tasks) } pub fn delete_task(&self, task_id: &str) -> Result<()> { - let mut guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?; + let mut guard = self + .tasks + .lock() + .map_err(|_| anyhow!("task store poisoned"))?; guard.remove(task_id); Ok(()) } pub fn update_segment(&self, segment: SubtitleSegment) -> Result { - let mut guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?; + let mut guard = self + .tasks + .lock() + .map_err(|_| anyhow!("task store poisoned"))?; let task = guard .get_mut(&segment.task_id) .ok_or_else(|| anyhow!("task not found: {}", segment.task_id))?; diff --git a/src-tauri/src/task.rs b/src-tauri/src/task.rs index 0c6cf25..277b540 100644 --- a/src-tauri/src/task.rs +++ b/src-tauri/src/task.rs @@ -5,6 +5,7 @@ use std::{ atomic::{AtomicU32, Ordering}, Arc, }, + time::Duration, }; use anyhow::{Context, Result}; @@ -14,9 +15,9 @@ use
uuid::Uuid; use crate::{ audio::AudioPipeline, models::{ - DefaultModelPaths, ErrorEvent, LogEvent, OutputMode, ProgressEvent, - ResetSegmentsEvent, StartTaskPayload, SubStageProgress, SubtitleSegment, - SubtitleTask, TargetLanguage, TaskStatus, TranslationConfig, + DefaultModelPaths, ErrorEvent, LogEvent, OutputMode, ProgressEvent, ResetSegmentsEvent, + StartTaskPayload, SubStageProgress, SubtitleSegment, SubtitleTask, TargetLanguage, + TaskStatus, TranslationConfig, }, state::AppState, subtitle::{render, SubtitleFormat}, @@ -45,7 +46,11 @@ pub async fn start_task( state: tauri::State<'_, AppState>, mut payload: StartTaskPayload, ) -> Result { - if payload.whisper_model_path.as_deref().is_none_or(str::is_empty) { + if payload + .whisper_model_path + .as_deref() + .is_none_or(str::is_empty) + { payload.whisper_model_path = resolve_default_model_path(&app, DEFAULT_WHISPER_MODEL); } if payload.vad_model_path.as_deref().is_none_or(str::is_empty) { @@ -85,11 +90,21 @@ pub async fn start_task( let task_id = task.id.clone(); tauri::async_runtime::spawn(async move { - if let Err(error) = run_pipeline(app_handle, window_handle.clone(), task_for_spawn, payload_for_spawn).await { - if let Ok(mut failed_task) = app_handle_for_error.state::<AppState>().get_task(&task_id) { + if let Err(error) = run_pipeline( + app_handle, + window_handle.clone(), + task_for_spawn, + payload_for_spawn, + ) + .await + { + if let Ok(mut failed_task) = app_handle_for_error.state::<AppState>().get_task(&task_id) + { failed_task.status = TaskStatus::Failed; failed_task.error = Some(error.to_string()); - let _ = app_handle_for_error.state::<AppState>().upsert_task(failed_task); + let _ = app_handle_for_error + .state::<AppState>() + .upsert_task(failed_task); } let _ = emit_error(&window_handle, &task_id, &error.to_string()); } @@ -171,12 +186,28 @@ async fn run_pipeline( } } - let ffmpeg_path = resolve_ffmpeg_path(&app) - .ok_or_else(|| anyhow::anyhow!("未找到可用 ffmpeg,请重新执行打包命令或在系统中安装 ffmpeg"))?; + let ffmpeg_path =
resolve_ffmpeg_path(&app).ok_or_else(|| { + anyhow::anyhow!("未找到可用 ffmpeg,请重新执行打包命令或在系统中安装 ffmpeg") + })?; - set_status(&window, &app_state, &mut task, TaskStatus::Extracting, 5.0, "正在抽取音频")?; - emit_log(&window, &task.id, format!("task: input file={}", payload.file_path))?; - emit_log(&window, &task.id, format!("audio: ffmpeg={}", ffmpeg_path.display()))?; + set_status( + &window, + &app_state, + &mut task, + TaskStatus::Extracting, + 5.0, + "正在抽取音频", + )?; + emit_log( + &window, + &task.id, + format!("task: input file={}", payload.file_path), + )?; + emit_log( + &window, + &task.id, + format!("audio: ffmpeg={}", ffmpeg_path.display()), + )?; let window_for_extract = window.clone(); let task_id_for_extract = task.id.clone(); @@ -204,40 +235,60 @@ async fn run_pipeline( ); }, )?; - emit_log(&window, &task.id, format!("audio: normalized wav={}", wav_path.display()))?; + emit_log( + &window, + &task.id, + format!("audio: normalized wav={}", wav_path.display()), + )?; - set_status(&window, &app_state, &mut task, TaskStatus::VadProcessing, 15.0, "正在分析语音片段")?; + set_status( + &window, + &app_state, + &mut task, + TaskStatus::VadProcessing, + 15.0, + "正在分析语音片段", + )?; let samples = AudioPipeline::load_wav_f32(&wav_path)?; let vad = VadEngine::new(payload.vad_model_path.clone(), VadConfig::default())?; let window_for_vad = window.clone(); let task_id_for_vad = task.id.clone(); - let speech_ranges = vad.detect_segments(&samples, move |ratio: f32| { - let overall = 15.0 + ratio.clamp(0.0, 1.0) * 15.0; - let sub = SubStageProgress { - extracting: 100.0, - vad: ratio.clamp(0.0, 1.0) * 100.0, - transcribing: 0.0, - translating: 0.0, - }; - let _ = window_for_vad.emit( - "task:progress", - ProgressEvent { - task_id: task_id_for_vad.clone(), - status: TaskStatus::VadProcessing, - progress: overall, - message: "正在分析语音片段".to_string(), - sub_stage_progress: sub, - }, - ); - }).await; + let speech_ranges = vad + .detect_segments(&samples, move |ratio: f32| { + let overall = 15.0 + 
ratio.clamp(0.0, 1.0) * 15.0; + let sub = SubStageProgress { + extracting: 100.0, + vad: ratio.clamp(0.0, 1.0) * 100.0, + transcribing: 0.0, + translating: 0.0, + }; + let _ = window_for_vad.emit( + "task:progress", + ProgressEvent { + task_id: task_id_for_vad.clone(), + status: TaskStatus::VadProcessing, + progress: overall, + message: "正在分析语音片段".to_string(), + sub_stage_progress: sub, + }, + ); + }) + .await; emit_log( &window, &task.id, format!("vad: detected {} speech ranges", speech_ranges.len()), )?; - set_status(&window, &app_state, &mut task, TaskStatus::Transcribing, 30.0, "正在执行 Whisper")?; + set_status( + &window, + &app_state, + &mut task, + TaskStatus::Transcribing, + 30.0, + "正在执行 Whisper", + )?; // Shared progress state between concurrent transcribing and translating let transcribing_pct: Arc = Arc::new(AtomicU32::new(0)); @@ -251,7 +302,11 @@ async fn run_pipeline( .translation_config .clone() .or_else(load_translation_config) - .ok_or_else(|| anyhow::anyhow!("翻译模式需要填写 LLM API 配置,或设置 OPENAI_API_BASE / OPENAI_API_KEY"))?; + .ok_or_else(|| { + anyhow::anyhow!( + "翻译模式需要填写 LLM API 配置,或设置 OPENAI_API_BASE / OPENAI_API_KEY" + ) + })?; let translator = Translator::new(config)?; let (tx, rx) = tokio::sync::mpsc::channel::(1024); let window_for_worker = window.clone(); @@ -399,126 +454,107 @@ async fn incremental_translate( let mut all_segments: Vec = Vec::new(); let mut buffer: Vec = Vec::new(); let mut translated_count: usize = 0; + let idle_flush_after = Duration::from_secs(3); - let emit_translate_progress = |window: &Window, task_id: &str, done: usize, total: usize| -> Result<()> { - let ratio = if total > 0 { - (done as f32 / total as f32).clamp(0.0, 1.0) - } else { - 0.0 + let emit_translate_progress = + |window: &Window, task_id: &str, done: usize, total: usize| -> Result<()> { + let ratio = if total > 0 { + (done as f32 / total as f32).clamp(0.0, 1.0) + } else { + 0.0 + }; + translating_pct.store((ratio * 100.0) as u32, Ordering::Release); + let 
transcribing = transcribing_pct.load(Ordering::Acquire) as f32 / 100.0; + // translation progress must not exceed transcription progress + let translating_capped = ratio.min(transcribing); + + let overall = 30.0 + transcribing * 40.0 + translating_capped * 25.0; + let sub = SubStageProgress { + extracting: 100.0, + vad: 100.0, + transcribing: (transcribing * 100.0).min(100.0), + translating: translating_capped * 100.0, + }; + let status = if transcribing >= 1.0 { + TaskStatus::Translating + } else { + TaskStatus::Transcribing + }; + window.emit( + "task:progress", + ProgressEvent { + task_id: task_id.to_string(), + status, + progress: overall.min(95.0), + message: "正在生成译文".to_string(), + sub_stage_progress: sub, + }, + )?; + Ok(()) }; - translating_pct.store((ratio * 100.0) as u32, Ordering::Release); - let transcribing = transcribing_pct.load(Ordering::Acquire) as f32 / 100.0; - // translation progress must not exceed transcription progress - let translating_capped = ratio.min(transcribing); - let overall = 30.0 + transcribing * 40.0 + translating_capped * 25.0; - let sub = SubStageProgress { - extracting: 100.0, - vad: 100.0, - transcribing: (transcribing * 100.0).min(100.0), - translating: translating_capped * 100.0, - }; - let status = if transcribing >= 1.0 { - TaskStatus::Translating - } else { - TaskStatus::Transcribing - }; - window.emit( - "task:progress", - ProgressEvent { - task_id: task_id.to_string(), - status, - progress: overall.min(95.0), - message: "正在生成译文".to_string(), - sub_stage_progress: sub, - }, - )?; - Ok(()) - }; + loop { + match tokio::time::timeout(idle_flush_after, rx.recv()).await { + Ok(Some(segment)) => { + all_segments.push(segment.clone()); + buffer.push(segment); - while let Some(segment) = rx.recv().await { - all_segments.push(segment.clone()); - buffer.push(segment); - - if buffer.len() >= batch_size { - let batch = std::mem::take(&mut buffer); - let context_end = all_segments.len().saturating_sub(batch.len()); - let 
context_start = context_end.saturating_sub(context_size); - let context = &all_segments[context_start..context_end]; - - emit_log(window, task_id, format!( - "translation: batch segments={}", - batch.iter().map(|s| s.id.as_str()).collect::>().join(", ") - ))?; - - let rows = translator - .translate_batch_with_retries(context, &batch, target_lang_name(target_lang)) - .await?; - - translated_count += rows.len(); - - emit_log(window, task_id, format!("translation: batch done, translated={}", rows.len()))?; - - for row in rows { - if let Some(original) = batch.iter().find(|item| item.id == row.id) { - let mut emitted = original.clone(); - emitted.translated_text = Some(row.text); - - if let Ok(mut current_task) = app_state.get_task(task_id) { - upsert_segment(&mut current_task.segments, emitted.clone()); - let _ = app_state.upsert_task(current_task); - } - let _ = window.emit( - "task:segment", - crate::models::SegmentEvent { - task_id: task_id.to_string(), - segment: emitted, - }, - ); + if buffer.len() >= batch_size { + translate_buffered_segments( + &translator, + window, + app_state, + task_id, + target_lang, + &all_segments, + &mut buffer, + &mut translated_count, + context_size, + "batch", + ) + .await?; + emit_translate_progress(window, task_id, translated_count, all_segments.len())?; } } + Ok(None) => break, + Err(_) => { + if buffer.is_empty() { + continue; + } - emit_translate_progress(window, task_id, translated_count, all_segments.len())?; + translate_buffered_segments( + &translator, + window, + app_state, + task_id, + target_lang, + &all_segments, + &mut buffer, + &mut translated_count, + context_size, + "idle batch", + ) + .await?; + emit_translate_progress(window, task_id, translated_count, all_segments.len())?; + } } } // Flush remaining segments below batch_size if !buffer.is_empty() { - let batch = std::mem::take(&mut buffer); - let context_end = all_segments.len().saturating_sub(batch.len()); - let context_start = 
context_end.saturating_sub(context_size); - let context = &all_segments[context_start..context_end]; - - emit_log(window, task_id, format!( - "translation: final batch segments={}", - batch.iter().map(|s| s.id.as_str()).collect::>().join(", ") - ))?; - - let rows = translator - .translate_batch_with_retries(context, &batch, target_lang_name(target_lang)) - .await?; - - translated_count += rows.len(); - - for row in rows { - if let Some(original) = batch.iter().find(|item| item.id == row.id) { - let mut emitted = original.clone(); - emitted.translated_text = Some(row.text); - - if let Ok(mut current_task) = app_state.get_task(task_id) { - upsert_segment(&mut current_task.segments, emitted.clone()); - let _ = app_state.upsert_task(current_task); - } - let _ = window.emit( - "task:segment", - crate::models::SegmentEvent { - task_id: task_id.to_string(), - segment: emitted, - }, - ); - } - } - + translate_buffered_segments( + &translator, + window, + app_state, + task_id, + target_lang, + &all_segments, + &mut buffer, + &mut translated_count, + context_size, + "final batch", + ) + .await?; emit_translate_progress(window, task_id, translated_count, all_segments.len())?; } @@ -544,6 +580,75 @@ async fn incremental_translate( Ok(()) } +async fn translate_buffered_segments( + translator: &Translator, + window: &Window, + app_state: &AppState, + task_id: &str, + target_lang: &TargetLanguage, + all_segments: &[SubtitleSegment], + buffer: &mut Vec, + translated_count: &mut usize, + context_size: usize, + label: &str, +) -> Result<()> { + if buffer.is_empty() { + return Ok(()); + } + + let batch = std::mem::take(buffer); + let context_end = all_segments.len().saturating_sub(batch.len()); + let context_start = context_end.saturating_sub(context_size); + let context = &all_segments[context_start..context_end]; + + emit_log( + window, + task_id, + format!( + "translation: {} segments={}", + label, + batch + .iter() + .map(|segment| segment.id.as_str()) + .collect::>() + .join(", 
") + ), + )?; + + let rows = translator + .translate_batch_with_retries(context, &batch, target_lang_name(target_lang)) + .await?; + + *translated_count += rows.len(); + + emit_log( + window, + task_id, + format!("translation: {} done, translated={}", label, rows.len()), + )?; + + for row in rows { + if let Some(original) = batch.iter().find(|item| item.id == row.id) { + let mut emitted = original.clone(); + emitted.translated_text = Some(row.text); + + if let Ok(mut current_task) = app_state.get_task(task_id) { + upsert_segment(&mut current_task.segments, emitted.clone()); + let _ = app_state.upsert_task(current_task); + } + let _ = window.emit( + "task:segment", + crate::models::SegmentEvent { + task_id: task_id.to_string(), + segment: emitted, + }, + ); + } + } + + Ok(()) +} + fn target_lang_name(target_lang: &TargetLanguage) -> &'static str { match target_lang { TargetLanguage::Zh => "简体中文", @@ -616,7 +721,10 @@ fn set_status( Ok(()) } -pub fn update_segment_text(state: tauri::State<'_, AppState>, segment: SubtitleSegment) -> Result { +pub fn update_segment_text( + state: tauri::State<'_, AppState>, + segment: SubtitleSegment, +) -> Result { state.update_segment(segment) } @@ -628,7 +736,11 @@ pub fn delete_task(state: tauri::State<'_, AppState>, task_id: String) -> Result state.delete_task(&task_id) } -pub fn export_task(state: tauri::State<'_, AppState>, task_id: String, format: String) -> Result { +pub fn export_task( + state: tauri::State<'_, AppState>, + task_id: String, + format: String, +) -> Result { let task = state.get_task(&task_id)?; let format = SubtitleFormat::try_from(format.as_str())?; let content = render(&task.segments, format, task.bilingual_output); @@ -803,7 +915,11 @@ fn upsert_segment(segments: &mut Vec, segment: SubtitleSegment) left.start .partial_cmp(&right.start) .unwrap_or(std::cmp::Ordering::Equal) - .then_with(|| left.end.partial_cmp(&right.end).unwrap_or(std::cmp::Ordering::Equal)) + .then_with(|| { + left.end + 
.partial_cmp(&right.end) + .unwrap_or(std::cmp::Ordering::Equal) + }) .then_with(|| left.id.cmp(&right.id)) }); } diff --git a/src-tauri/src/translate.rs b/src-tauri/src/translate.rs index d4253e2..4487afe 100644 --- a/src-tauri/src/translate.rs +++ b/src-tauri/src/translate.rs @@ -120,7 +120,10 @@ impl Translator { let rows = self .translate_batch_with_retries(context, batch, target_language_name) .await?; - log(format!("translation: batch done, translated={}", rows.len())); + log(format!( + "translation: batch done, translated={}", + rows.len() + )); for row in rows { if let Some(segment) = translated.iter_mut().find(|item| item.id == row.id) { @@ -257,7 +260,9 @@ impl Translator { match response { Ok(response) => { - let response = response.error_for_status().context("translation http error")?; + let response = response + .error_for_status() + .context("translation http error")?; let raw_text = response.text().await.context("invalid response body")?; eprintln!("translation raw response:\n{}", raw_text); let payload: ChatCompletionResponse = @@ -269,8 +274,9 @@ impl Translator { .message .content .clone(); - let rows = parse_translation_response(&content) - .with_context(|| format!("translation json parse failed: {}", preview(&content)))?; + let rows = parse_translation_response(&content).with_context(|| { + format!("translation json parse failed: {}", preview(&content)) + })?; return Ok(rows); } Err(error) => { @@ -362,10 +368,7 @@ fn strip_code_fence(content: &str) -> String { .trim_start_matches("```json") .trim_start_matches("```JSON") .trim_start_matches("```"); - without_prefix - .trim_end_matches("```") - .trim() - .to_string() + without_prefix.trim_end_matches("```").trim().to_string() } fn extract_json_object(content: &str) -> Option { @@ -403,7 +406,11 @@ fn mask_secret(secret: &str) -> String { return "****".to_string(); } - format!("{}****{}", &secret[..4], &secret[secret.len().saturating_sub(4)..]) + format!( + "{}****{}", + &secret[..4], + 
&secret[secret.len().saturating_sub(4)..] + ) } fn extract_rows_loose(content: &str) -> Vec { diff --git a/src-tauri/src/vad.rs b/src-tauri/src/vad.rs index 47334ef..975e5df 100644 --- a/src-tauri/src/vad.rs +++ b/src-tauri/src/vad.rs @@ -69,7 +69,8 @@ impl VadEngine { return None; } }; - Self::detect_with_onnx(&mut session, &samples_owned, &config, on_progress_onnx).ok() + Self::detect_with_onnx(&mut session, &samples_owned, &config, on_progress_onnx) + .ok() }), ) .await @@ -156,7 +157,11 @@ impl VadEngine { } on_progress(1.0); - Ok(Self::merge_probabilities(&speech_probabilities, chunk_size, config)) + Ok(Self::merge_probabilities( + &speech_probabilities, + chunk_size, + config, + )) } fn detect_segments_with_energy( @@ -192,10 +197,19 @@ impl VadEngine { energies.len(), dynamic_threshold ); - Self::merge_probabilities_with_threshold(&energies, frame_size, dynamic_threshold, &self.config) + Self::merge_probabilities_with_threshold( + &energies, + frame_size, + dynamic_threshold, + &self.config, + ) } - fn merge_probabilities(frames: &[f32], frame_size: usize, config: &VadConfig) -> Vec<(f32, f32)> { + fn merge_probabilities( + frames: &[f32], + frame_size: usize, + config: &VadConfig, + ) -> Vec<(f32, f32)> { Self::merge_probabilities_with_threshold(frames, frame_size, config.threshold, config) } @@ -228,7 +242,8 @@ impl VadEngine { let end_frame = index.saturating_sub(silent_frames); if end_frame.saturating_sub(start) >= min_speech_frames { let start_sec = (start * frame_size) as f32 / config.sample_rate as f32; - let end_sec = ((end_frame + 1) * frame_size) as f32 / config.sample_rate as f32; + let end_sec = + ((end_frame + 1) * frame_size) as f32 / config.sample_rate as f32; result.push(((start_sec - pad_seconds).max(0.0), end_sec + pad_seconds)); } start_frame = None; diff --git a/src-tauri/src/whisper.rs b/src-tauri/src/whisper.rs index 072d54b..d804238 100644 --- a/src-tauri/src/whisper.rs +++ b/src-tauri/src/whisper.rs @@ -48,12 +48,12 @@ impl 
WhisperEngine { let audio = load_audio_f32(wav_path)?; let total_seconds = audio.len() as f32 / 16_000.0; let normalized_ranges = normalize_speech_ranges(speech_ranges, audio.len()); - let context = WhisperContext::new_with_params( - model_path, - WhisperContextParameters::default(), - ) - .with_context(|| format!("failed to load whisper model: {model_path}"))?; - let mut state = context.create_state().context("failed to create whisper state")?; + let context = + WhisperContext::new_with_params(model_path, WhisperContextParameters::default()) + .with_context(|| format!("failed to load whisper model: {model_path}"))?; + let mut state = context + .create_state() + .context("failed to create whisper state")?; let detected_language = resolve_source_language(&mut state, &audio, source_lang) .context("failed to resolve source language")?; @@ -187,7 +187,10 @@ impl WhisperEngine { } } - on_log(format!("whisper: total emitted segments={}", segments.len()))?; + on_log(format!( + "whisper: total emitted segments={}", + segments.len() + ))?; Ok(segments) } } @@ -226,7 +229,9 @@ fn transcribe_clip( } } - state.full(params, clip).context("whisper inference failed")?; + state + .full(params, clip) + .context("whisper inference failed")?; let num_segments = state.full_n_segments(); on_log(format!( @@ -303,7 +308,10 @@ fn load_audio_f32(path: &Path) -> Result<Vec<f32>> { .with_context(|| format!("failed to open wav file: {}", path.display()))?; let spec = reader.spec(); if spec.sample_rate != 16_000 { - return Err(anyhow!("whisper expects 16k audio, got {}", spec.sample_rate)); + return Err(anyhow!( + "whisper expects 16k audio, got {}", + spec.sample_rate + )); } if spec.channels != 1 { return Err(anyhow!("whisper expects mono audio, got {}", spec.channels)); @@ -311,7 +319,11 @@ fn load_audio_f32(path: &Path) -> Result<Vec<f32>> { let samples = reader .into_samples::<i16>() - .map(|sample| sample.map(|value| value as f32 / i16::MAX as f32).map_err(anyhow::Error::from)) + .map(|sample| { + sample
.map(|value| value as f32 / i16::MAX as f32) + .map_err(anyhow::Error::from) + }) .collect::>>()?; Ok(samples) @@ -372,7 +384,9 @@ fn should_prefer_full_audio( full_text_len > vad_text_len + vad_text_len * 3 / 5 || full_audio_segments.len() > vad_segments.len() + 5 || full_end > vad_end + 5.0 - || (total_seconds > 60.0 && full_end + 1.5 >= total_seconds && vad_end + 5.0 < total_seconds) + || (total_seconds > 60.0 + && full_end + 1.5 >= total_seconds + && vad_end + 5.0 < total_seconds) } fn resolve_source_language<'a>(