修复最后一批次的问题

This commit is contained in:
kura 2026-05-02 16:14:43 +08:00
parent 28294a6eb2
commit 78c750bcbf
7 changed files with 369 additions and 188 deletions

View File

@ -22,7 +22,11 @@ impl AudioPipeline {
let output_path = workspace.join("normalized.wav"); let output_path = workspace.join("normalized.wav");
let mut command = Command::new(ffmpeg_path); let mut command = Command::new(ffmpeg_path);
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
if let Some(lib_dir) = ffmpeg_path.parent().and_then(|bin_dir| bin_dir.parent()).map(|root| root.join("lib")) { if let Some(lib_dir) = ffmpeg_path
.parent()
.and_then(|bin_dir| bin_dir.parent())
.map(|root| root.join("lib"))
{
if lib_dir.exists() { if lib_dir.exists() {
command.env("DYLD_FALLBACK_LIBRARY_PATH", &lib_dir); command.env("DYLD_FALLBACK_LIBRARY_PATH", &lib_dir);
} }
@ -70,7 +74,9 @@ impl AudioPipeline {
} }
} }
let status = child.wait().with_context(|| "ffmpeg process failed to wait")?; let status = child
.wait()
.with_context(|| "ffmpeg process failed to wait")?;
if !status.success() { if !status.success() {
return Err(anyhow!("ffmpeg exited with status: {}", status)); return Err(anyhow!("ffmpeg exited with status: {}", status));
} }
@ -80,8 +86,8 @@ impl AudioPipeline {
} }
pub fn load_wav_f32(path: &Path) -> Result<Vec<f32>> { pub fn load_wav_f32(path: &Path) -> Result<Vec<f32>> {
let mut reader = let mut reader = hound::WavReader::open(path)
hound::WavReader::open(path).with_context(|| format!("failed to open {}", path.display()))?; .with_context(|| format!("failed to open {}", path.display()))?;
let spec = reader.spec(); let spec = reader.spec();
if spec.channels != 1 { if spec.channels != 1 {
@ -124,7 +130,9 @@ fn parse_ffmpeg_duration(line: &str) -> Option<f64> {
fn parse_ffmpeg_time(line: &str) -> Option<f64> { fn parse_ffmpeg_time(line: &str) -> Option<f64> {
let pos = line.find("time=")?; let pos = line.find("time=")?;
let rest = &line[pos + 5..]; let rest = &line[pos + 5..];
let end = rest.find(|c: char| !c.is_digit(10) && c != ':' && c != '.').unwrap_or(rest.len()); let end = rest
.find(|c: char| !c.is_digit(10) && c != ':' && c != '.')
.unwrap_or(rest.len());
let time_str = &rest[..end]; let time_str = &rest[..end];
let parts: Vec<&str> = time_str.split(':').collect(); let parts: Vec<&str> = time_str.split(':').collect();
if parts.len() == 3 { if parts.len() == 3 {

View File

@ -10,17 +10,17 @@ mod whisper;
use models::{ use models::{
DefaultModelPaths, StartTaskPayload, SubtitleSegment, SubtitleTask, TranslationConfig, DefaultModelPaths, StartTaskPayload, SubtitleSegment, SubtitleTask, TranslationConfig,
}; };
#[cfg(target_os = "macos")]
use objc2_app_kit::NSWindow;
#[cfg(target_os = "macos")]
use objc2_foundation::NSSize;
use state::AppState; use state::AppState;
use tauri::{AppHandle, Manager, PhysicalSize, Size};
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
use tauri::{ use tauri::{
menu::{MenuBuilder, MenuItemBuilder, PredefinedMenuItem, SubmenuBuilder}, menu::{MenuBuilder, MenuItemBuilder, PredefinedMenuItem, SubmenuBuilder},
Emitter, Emitter,
}; };
#[cfg(target_os = "macos")] use tauri::{AppHandle, Manager, PhysicalSize, Size};
use objc2_app_kit::NSWindow;
#[cfg(target_os = "macos")]
use objc2_foundation::NSSize;
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
const WINDOW_RATIO_WIDTH: f64 = 16.0; const WINDOW_RATIO_WIDTH: f64 = 16.0;
@ -74,7 +74,9 @@ fn export_subtitles(
} }
#[tauri::command] #[tauri::command]
fn get_default_model_paths(app: tauri::AppHandle) -> std::result::Result<DefaultModelPaths, String> { fn get_default_model_paths(
app: tauri::AppHandle,
) -> std::result::Result<DefaultModelPaths, String> {
task::get_default_model_paths(&app).map_err(error_to_string) task::get_default_model_paths(&app).map_err(error_to_string)
} }
@ -152,7 +154,11 @@ fn configure_macos_menu(app: &AppHandle) -> tauri::Result<()> {
.build()?; .build()?;
let file_menu = SubmenuBuilder::new(app, "文件") let file_menu = SubmenuBuilder::new(app, "文件")
.item(&MenuItemBuilder::with_id("pick_files", "选择媒体文件").accelerator("CmdOrCtrl+O").build(app)?) .item(
&MenuItemBuilder::with_id("pick_files", "选择媒体文件")
.accelerator("CmdOrCtrl+O")
.build(app)?,
)
.separator() .separator()
.item(&MenuItemBuilder::with_id("export_srt", "导出 SRT").build(app)?) .item(&MenuItemBuilder::with_id("export_srt", "导出 SRT").build(app)?)
.item(&MenuItemBuilder::with_id("export_vtt", "导出 VTT").build(app)?) .item(&MenuItemBuilder::with_id("export_vtt", "导出 VTT").build(app)?)

View File

@ -11,13 +11,19 @@ pub struct AppState {
impl AppState { impl AppState {
pub fn upsert_task(&self, task: SubtitleTask) -> Result<()> { pub fn upsert_task(&self, task: SubtitleTask) -> Result<()> {
let mut guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?; let mut guard = self
.tasks
.lock()
.map_err(|_| anyhow!("task store poisoned"))?;
guard.insert(task.id.clone(), task); guard.insert(task.id.clone(), task);
Ok(()) Ok(())
} }
pub fn get_task(&self, task_id: &str) -> Result<SubtitleTask> { pub fn get_task(&self, task_id: &str) -> Result<SubtitleTask> {
let guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?; let guard = self
.tasks
.lock()
.map_err(|_| anyhow!("task store poisoned"))?;
guard guard
.get(task_id) .get(task_id)
.cloned() .cloned()
@ -25,20 +31,29 @@ impl AppState {
} }
pub fn list_tasks(&self) -> Result<Vec<SubtitleTask>> { pub fn list_tasks(&self) -> Result<Vec<SubtitleTask>> {
let guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?; let guard = self
.tasks
.lock()
.map_err(|_| anyhow!("task store poisoned"))?;
let mut tasks = guard.values().cloned().collect::<Vec<_>>(); let mut tasks = guard.values().cloned().collect::<Vec<_>>();
tasks.sort_by(|left, right| right.id.cmp(&left.id)); tasks.sort_by(|left, right| right.id.cmp(&left.id));
Ok(tasks) Ok(tasks)
} }
pub fn delete_task(&self, task_id: &str) -> Result<()> { pub fn delete_task(&self, task_id: &str) -> Result<()> {
let mut guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?; let mut guard = self
.tasks
.lock()
.map_err(|_| anyhow!("task store poisoned"))?;
guard.remove(task_id); guard.remove(task_id);
Ok(()) Ok(())
} }
pub fn update_segment(&self, segment: SubtitleSegment) -> Result<SubtitleTask> { pub fn update_segment(&self, segment: SubtitleSegment) -> Result<SubtitleTask> {
let mut guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?; let mut guard = self
.tasks
.lock()
.map_err(|_| anyhow!("task store poisoned"))?;
let task = guard let task = guard
.get_mut(&segment.task_id) .get_mut(&segment.task_id)
.ok_or_else(|| anyhow!("task not found: {}", segment.task_id))?; .ok_or_else(|| anyhow!("task not found: {}", segment.task_id))?;

View File

@ -5,6 +5,7 @@ use std::{
atomic::{AtomicU32, Ordering}, atomic::{AtomicU32, Ordering},
Arc, Arc,
}, },
time::Duration,
}; };
use anyhow::{Context, Result}; use anyhow::{Context, Result};
@ -14,9 +15,9 @@ use uuid::Uuid;
use crate::{ use crate::{
audio::AudioPipeline, audio::AudioPipeline,
models::{ models::{
DefaultModelPaths, ErrorEvent, LogEvent, OutputMode, ProgressEvent, DefaultModelPaths, ErrorEvent, LogEvent, OutputMode, ProgressEvent, ResetSegmentsEvent,
ResetSegmentsEvent, StartTaskPayload, SubStageProgress, SubtitleSegment, StartTaskPayload, SubStageProgress, SubtitleSegment, SubtitleTask, TargetLanguage,
SubtitleTask, TargetLanguage, TaskStatus, TranslationConfig, TaskStatus, TranslationConfig,
}, },
state::AppState, state::AppState,
subtitle::{render, SubtitleFormat}, subtitle::{render, SubtitleFormat},
@ -45,7 +46,11 @@ pub async fn start_task(
state: tauri::State<'_, AppState>, state: tauri::State<'_, AppState>,
mut payload: StartTaskPayload, mut payload: StartTaskPayload,
) -> Result<SubtitleTask> { ) -> Result<SubtitleTask> {
if payload.whisper_model_path.as_deref().is_none_or(str::is_empty) { if payload
.whisper_model_path
.as_deref()
.is_none_or(str::is_empty)
{
payload.whisper_model_path = resolve_default_model_path(&app, DEFAULT_WHISPER_MODEL); payload.whisper_model_path = resolve_default_model_path(&app, DEFAULT_WHISPER_MODEL);
} }
if payload.vad_model_path.as_deref().is_none_or(str::is_empty) { if payload.vad_model_path.as_deref().is_none_or(str::is_empty) {
@ -85,11 +90,21 @@ pub async fn start_task(
let task_id = task.id.clone(); let task_id = task.id.clone();
tauri::async_runtime::spawn(async move { tauri::async_runtime::spawn(async move {
if let Err(error) = run_pipeline(app_handle, window_handle.clone(), task_for_spawn, payload_for_spawn).await { if let Err(error) = run_pipeline(
if let Ok(mut failed_task) = app_handle_for_error.state::<AppState>().get_task(&task_id) { app_handle,
window_handle.clone(),
task_for_spawn,
payload_for_spawn,
)
.await
{
if let Ok(mut failed_task) = app_handle_for_error.state::<AppState>().get_task(&task_id)
{
failed_task.status = TaskStatus::Failed; failed_task.status = TaskStatus::Failed;
failed_task.error = Some(error.to_string()); failed_task.error = Some(error.to_string());
let _ = app_handle_for_error.state::<AppState>().upsert_task(failed_task); let _ = app_handle_for_error
.state::<AppState>()
.upsert_task(failed_task);
} }
let _ = emit_error(&window_handle, &task_id, &error.to_string()); let _ = emit_error(&window_handle, &task_id, &error.to_string());
} }
@ -171,12 +186,28 @@ async fn run_pipeline(
} }
} }
let ffmpeg_path = resolve_ffmpeg_path(&app) let ffmpeg_path = resolve_ffmpeg_path(&app).ok_or_else(|| {
.ok_or_else(|| anyhow::anyhow!("未找到可用 ffmpeg请重新执行打包命令或在系统中安装 ffmpeg"))?; anyhow::anyhow!("未找到可用 ffmpeg请重新执行打包命令或在系统中安装 ffmpeg")
})?;
set_status(&window, &app_state, &mut task, TaskStatus::Extracting, 5.0, "正在抽取音频")?; set_status(
emit_log(&window, &task.id, format!("task: input file={}", payload.file_path))?; &window,
emit_log(&window, &task.id, format!("audio: ffmpeg={}", ffmpeg_path.display()))?; &app_state,
&mut task,
TaskStatus::Extracting,
5.0,
"正在抽取音频",
)?;
emit_log(
&window,
&task.id,
format!("task: input file={}", payload.file_path),
)?;
emit_log(
&window,
&task.id,
format!("audio: ffmpeg={}", ffmpeg_path.display()),
)?;
let window_for_extract = window.clone(); let window_for_extract = window.clone();
let task_id_for_extract = task.id.clone(); let task_id_for_extract = task.id.clone();
@ -204,15 +235,27 @@ async fn run_pipeline(
); );
}, },
)?; )?;
emit_log(&window, &task.id, format!("audio: normalized wav={}", wav_path.display()))?; emit_log(
&window,
&task.id,
format!("audio: normalized wav={}", wav_path.display()),
)?;
set_status(&window, &app_state, &mut task, TaskStatus::VadProcessing, 15.0, "正在分析语音片段")?; set_status(
&window,
&app_state,
&mut task,
TaskStatus::VadProcessing,
15.0,
"正在分析语音片段",
)?;
let samples = AudioPipeline::load_wav_f32(&wav_path)?; let samples = AudioPipeline::load_wav_f32(&wav_path)?;
let vad = VadEngine::new(payload.vad_model_path.clone(), VadConfig::default())?; let vad = VadEngine::new(payload.vad_model_path.clone(), VadConfig::default())?;
let window_for_vad = window.clone(); let window_for_vad = window.clone();
let task_id_for_vad = task.id.clone(); let task_id_for_vad = task.id.clone();
let speech_ranges = vad.detect_segments(&samples, move |ratio: f32| { let speech_ranges = vad
.detect_segments(&samples, move |ratio: f32| {
let overall = 15.0 + ratio.clamp(0.0, 1.0) * 15.0; let overall = 15.0 + ratio.clamp(0.0, 1.0) * 15.0;
let sub = SubStageProgress { let sub = SubStageProgress {
extracting: 100.0, extracting: 100.0,
@ -230,14 +273,22 @@ async fn run_pipeline(
sub_stage_progress: sub, sub_stage_progress: sub,
}, },
); );
}).await; })
.await;
emit_log( emit_log(
&window, &window,
&task.id, &task.id,
format!("vad: detected {} speech ranges", speech_ranges.len()), format!("vad: detected {} speech ranges", speech_ranges.len()),
)?; )?;
set_status(&window, &app_state, &mut task, TaskStatus::Transcribing, 30.0, "正在执行 Whisper")?; set_status(
&window,
&app_state,
&mut task,
TaskStatus::Transcribing,
30.0,
"正在执行 Whisper",
)?;
// Shared progress state between concurrent transcribing and translating // Shared progress state between concurrent transcribing and translating
let transcribing_pct: Arc<AtomicU32> = Arc::new(AtomicU32::new(0)); let transcribing_pct: Arc<AtomicU32> = Arc::new(AtomicU32::new(0));
@ -251,7 +302,11 @@ async fn run_pipeline(
.translation_config .translation_config
.clone() .clone()
.or_else(load_translation_config) .or_else(load_translation_config)
.ok_or_else(|| anyhow::anyhow!("翻译模式需要填写 LLM API 配置,或设置 OPENAI_API_BASE / OPENAI_API_KEY"))?; .ok_or_else(|| {
anyhow::anyhow!(
"翻译模式需要填写 LLM API 配置,或设置 OPENAI_API_BASE / OPENAI_API_KEY"
)
})?;
let translator = Translator::new(config)?; let translator = Translator::new(config)?;
let (tx, rx) = tokio::sync::mpsc::channel::<SubtitleSegment>(1024); let (tx, rx) = tokio::sync::mpsc::channel::<SubtitleSegment>(1024);
let window_for_worker = window.clone(); let window_for_worker = window.clone();
@ -399,8 +454,10 @@ async fn incremental_translate(
let mut all_segments: Vec<SubtitleSegment> = Vec::new(); let mut all_segments: Vec<SubtitleSegment> = Vec::new();
let mut buffer: Vec<SubtitleSegment> = Vec::new(); let mut buffer: Vec<SubtitleSegment> = Vec::new();
let mut translated_count: usize = 0; let mut translated_count: usize = 0;
let idle_flush_after = Duration::from_secs(3);
let emit_translate_progress = |window: &Window, task_id: &str, done: usize, total: usize| -> Result<()> { let emit_translate_progress =
|window: &Window, task_id: &str, done: usize, total: usize| -> Result<()> {
let ratio = if total > 0 { let ratio = if total > 0 {
(done as f32 / total as f32).clamp(0.0, 1.0) (done as f32 / total as f32).clamp(0.0, 1.0)
} else { } else {
@ -436,89 +493,68 @@ async fn incremental_translate(
Ok(()) Ok(())
}; };
while let Some(segment) = rx.recv().await { loop {
match tokio::time::timeout(idle_flush_after, rx.recv()).await {
Ok(Some(segment)) => {
all_segments.push(segment.clone()); all_segments.push(segment.clone());
buffer.push(segment); buffer.push(segment);
if buffer.len() >= batch_size { if buffer.len() >= batch_size {
let batch = std::mem::take(&mut buffer); translate_buffered_segments(
let context_end = all_segments.len().saturating_sub(batch.len()); &translator,
let context_start = context_end.saturating_sub(context_size); window,
let context = &all_segments[context_start..context_end]; app_state,
task_id,
emit_log(window, task_id, format!( target_lang,
"translation: batch segments={}", &all_segments,
batch.iter().map(|s| s.id.as_str()).collect::<Vec<_>>().join(", ") &mut buffer,
))?; &mut translated_count,
context_size,
let rows = translator "batch",
.translate_batch_with_retries(context, &batch, target_lang_name(target_lang)) )
.await?; .await?;
translated_count += rows.len();
emit_log(window, task_id, format!("translation: batch done, translated={}", rows.len()))?;
for row in rows {
if let Some(original) = batch.iter().find(|item| item.id == row.id) {
let mut emitted = original.clone();
emitted.translated_text = Some(row.text);
if let Ok(mut current_task) = app_state.get_task(task_id) {
upsert_segment(&mut current_task.segments, emitted.clone());
let _ = app_state.upsert_task(current_task);
}
let _ = window.emit(
"task:segment",
crate::models::SegmentEvent {
task_id: task_id.to_string(),
segment: emitted,
},
);
}
}
emit_translate_progress(window, task_id, translated_count, all_segments.len())?; emit_translate_progress(window, task_id, translated_count, all_segments.len())?;
} }
} }
Ok(None) => break,
Err(_) => {
if buffer.is_empty() {
continue;
}
translate_buffered_segments(
&translator,
window,
app_state,
task_id,
target_lang,
&all_segments,
&mut buffer,
&mut translated_count,
context_size,
"idle batch",
)
.await?;
emit_translate_progress(window, task_id, translated_count, all_segments.len())?;
}
}
}
// Flush remaining segments below batch_size // Flush remaining segments below batch_size
if !buffer.is_empty() { if !buffer.is_empty() {
let batch = std::mem::take(&mut buffer); translate_buffered_segments(
let context_end = all_segments.len().saturating_sub(batch.len()); &translator,
let context_start = context_end.saturating_sub(context_size); window,
let context = &all_segments[context_start..context_end]; app_state,
task_id,
emit_log(window, task_id, format!( target_lang,
"translation: final batch segments={}", &all_segments,
batch.iter().map(|s| s.id.as_str()).collect::<Vec<_>>().join(", ") &mut buffer,
))?; &mut translated_count,
context_size,
let rows = translator "final batch",
.translate_batch_with_retries(context, &batch, target_lang_name(target_lang)) )
.await?; .await?;
translated_count += rows.len();
for row in rows {
if let Some(original) = batch.iter().find(|item| item.id == row.id) {
let mut emitted = original.clone();
emitted.translated_text = Some(row.text);
if let Ok(mut current_task) = app_state.get_task(task_id) {
upsert_segment(&mut current_task.segments, emitted.clone());
let _ = app_state.upsert_task(current_task);
}
let _ = window.emit(
"task:segment",
crate::models::SegmentEvent {
task_id: task_id.to_string(),
segment: emitted,
},
);
}
}
emit_translate_progress(window, task_id, translated_count, all_segments.len())?; emit_translate_progress(window, task_id, translated_count, all_segments.len())?;
} }
@ -544,6 +580,75 @@ async fn incremental_translate(
Ok(()) Ok(())
} }
/// Translates everything currently held in `buffer` as a single batch and publishes the result.
///
/// The buffer is drained before any fallible step, so it is left empty even when this
/// function returns an error. Up to `context_size` entries of `all_segments` that sit
/// directly before the batch are handed to the translator as context; this relies on the
/// batch being the tail of `all_segments`.
///
/// For each returned row whose id matches a drained segment, the translated segment is
/// written into the stored task and emitted as a `task:segment` event. Both of those steps
/// are best-effort: their failures are deliberately ignored. `translated_count` is advanced
/// by the number of rows returned, whether or not each row matched a segment. `label` is
/// only used in the log lines.
///
/// # Errors
/// Returns an error if `emit_log` or `translate_batch_with_retries` fails.
async fn translate_buffered_segments(
    translator: &Translator,
    window: &Window,
    app_state: &AppState,
    task_id: &str,
    target_lang: &TargetLanguage,
    all_segments: &[SubtitleSegment],
    buffer: &mut Vec<SubtitleSegment>,
    translated_count: &mut usize,
    context_size: usize,
    label: &str,
) -> Result<()> {
    // Nothing buffered: no log line, no request.
    if buffer.is_empty() {
        return Ok(());
    }

    // Take ownership of the pending segments and leave an empty buffer behind.
    let pending = std::mem::take(buffer);

    // The context window is whatever precedes the pending segments in the full history.
    let history_end = all_segments.len().saturating_sub(pending.len());
    let history_start = history_end.saturating_sub(context_size);
    let history = &all_segments[history_start..history_end];

    let id_list = pending
        .iter()
        .map(|segment| segment.id.as_str())
        .collect::<Vec<_>>()
        .join(", ");
    emit_log(
        window,
        task_id,
        format!("translation: {} segments={}", label, id_list),
    )?;

    let language_name = target_lang_name(target_lang);
    let rows = translator
        .translate_batch_with_retries(history, &pending, language_name)
        .await?;

    let row_count = rows.len();
    *translated_count += row_count;
    emit_log(
        window,
        task_id,
        format!("translation: {} done, translated={}", label, row_count),
    )?;

    for row in rows {
        // Rows whose id is not part of this batch are dropped.
        let Some(source) = pending.iter().find(|candidate| candidate.id == row.id) else {
            continue;
        };

        let mut translated = source.clone();
        translated.translated_text = Some(row.text);

        // Best-effort store update: a missing task or a failed upsert is ignored.
        if let Ok(mut task_snapshot) = app_state.get_task(task_id) {
            upsert_segment(&mut task_snapshot.segments, translated.clone());
            let _ = app_state.upsert_task(task_snapshot);
        }

        // Best-effort notification to the frontend.
        let event = crate::models::SegmentEvent {
            task_id: task_id.to_string(),
            segment: translated,
        };
        let _ = window.emit("task:segment", event);
    }

    Ok(())
}
fn target_lang_name(target_lang: &TargetLanguage) -> &'static str { fn target_lang_name(target_lang: &TargetLanguage) -> &'static str {
match target_lang { match target_lang {
TargetLanguage::Zh => "简体中文", TargetLanguage::Zh => "简体中文",
@ -616,7 +721,10 @@ fn set_status(
Ok(()) Ok(())
} }
pub fn update_segment_text(state: tauri::State<'_, AppState>, segment: SubtitleSegment) -> Result<SubtitleTask> { pub fn update_segment_text(
state: tauri::State<'_, AppState>,
segment: SubtitleSegment,
) -> Result<SubtitleTask> {
state.update_segment(segment) state.update_segment(segment)
} }
@ -628,7 +736,11 @@ pub fn delete_task(state: tauri::State<'_, AppState>, task_id: String) -> Result
state.delete_task(&task_id) state.delete_task(&task_id)
} }
pub fn export_task(state: tauri::State<'_, AppState>, task_id: String, format: String) -> Result<String> { pub fn export_task(
state: tauri::State<'_, AppState>,
task_id: String,
format: String,
) -> Result<String> {
let task = state.get_task(&task_id)?; let task = state.get_task(&task_id)?;
let format = SubtitleFormat::try_from(format.as_str())?; let format = SubtitleFormat::try_from(format.as_str())?;
let content = render(&task.segments, format, task.bilingual_output); let content = render(&task.segments, format, task.bilingual_output);
@ -803,7 +915,11 @@ fn upsert_segment(segments: &mut Vec<SubtitleSegment>, segment: SubtitleSegment)
left.start left.start
.partial_cmp(&right.start) .partial_cmp(&right.start)
.unwrap_or(std::cmp::Ordering::Equal) .unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| left.end.partial_cmp(&right.end).unwrap_or(std::cmp::Ordering::Equal)) .then_with(|| {
left.end
.partial_cmp(&right.end)
.unwrap_or(std::cmp::Ordering::Equal)
})
.then_with(|| left.id.cmp(&right.id)) .then_with(|| left.id.cmp(&right.id))
}); });
} }

View File

@ -120,7 +120,10 @@ impl Translator {
let rows = self let rows = self
.translate_batch_with_retries(context, batch, target_language_name) .translate_batch_with_retries(context, batch, target_language_name)
.await?; .await?;
log(format!("translation: batch done, translated={}", rows.len())); log(format!(
"translation: batch done, translated={}",
rows.len()
));
for row in rows { for row in rows {
if let Some(segment) = translated.iter_mut().find(|item| item.id == row.id) { if let Some(segment) = translated.iter_mut().find(|item| item.id == row.id) {
@ -257,7 +260,9 @@ impl Translator {
match response { match response {
Ok(response) => { Ok(response) => {
let response = response.error_for_status().context("translation http error")?; let response = response
.error_for_status()
.context("translation http error")?;
let raw_text = response.text().await.context("invalid response body")?; let raw_text = response.text().await.context("invalid response body")?;
eprintln!("translation raw response:\n{}", raw_text); eprintln!("translation raw response:\n{}", raw_text);
let payload: ChatCompletionResponse = let payload: ChatCompletionResponse =
@ -269,8 +274,9 @@ impl Translator {
.message .message
.content .content
.clone(); .clone();
let rows = parse_translation_response(&content) let rows = parse_translation_response(&content).with_context(|| {
.with_context(|| format!("translation json parse failed: {}", preview(&content)))?; format!("translation json parse failed: {}", preview(&content))
})?;
return Ok(rows); return Ok(rows);
} }
Err(error) => { Err(error) => {
@ -362,10 +368,7 @@ fn strip_code_fence(content: &str) -> String {
.trim_start_matches("```json") .trim_start_matches("```json")
.trim_start_matches("```JSON") .trim_start_matches("```JSON")
.trim_start_matches("```"); .trim_start_matches("```");
without_prefix without_prefix.trim_end_matches("```").trim().to_string()
.trim_end_matches("```")
.trim()
.to_string()
} }
fn extract_json_object(content: &str) -> Option<String> { fn extract_json_object(content: &str) -> Option<String> {
@ -403,7 +406,11 @@ fn mask_secret(secret: &str) -> String {
return "****".to_string(); return "****".to_string();
} }
format!("{}****{}", &secret[..4], &secret[secret.len().saturating_sub(4)..]) format!(
"{}****{}",
&secret[..4],
&secret[secret.len().saturating_sub(4)..]
)
} }
fn extract_rows_loose(content: &str) -> Vec<TranslatedRow> { fn extract_rows_loose(content: &str) -> Vec<TranslatedRow> {

View File

@ -69,7 +69,8 @@ impl VadEngine {
return None; return None;
} }
}; };
Self::detect_with_onnx(&mut session, &samples_owned, &config, on_progress_onnx).ok() Self::detect_with_onnx(&mut session, &samples_owned, &config, on_progress_onnx)
.ok()
}), }),
) )
.await .await
@ -156,7 +157,11 @@ impl VadEngine {
} }
on_progress(1.0); on_progress(1.0);
Ok(Self::merge_probabilities(&speech_probabilities, chunk_size, config)) Ok(Self::merge_probabilities(
&speech_probabilities,
chunk_size,
config,
))
} }
fn detect_segments_with_energy<F: Fn(f32)>( fn detect_segments_with_energy<F: Fn(f32)>(
@ -192,10 +197,19 @@ impl VadEngine {
energies.len(), energies.len(),
dynamic_threshold dynamic_threshold
); );
Self::merge_probabilities_with_threshold(&energies, frame_size, dynamic_threshold, &self.config) Self::merge_probabilities_with_threshold(
&energies,
frame_size,
dynamic_threshold,
&self.config,
)
} }
fn merge_probabilities(frames: &[f32], frame_size: usize, config: &VadConfig) -> Vec<(f32, f32)> { fn merge_probabilities(
frames: &[f32],
frame_size: usize,
config: &VadConfig,
) -> Vec<(f32, f32)> {
Self::merge_probabilities_with_threshold(frames, frame_size, config.threshold, config) Self::merge_probabilities_with_threshold(frames, frame_size, config.threshold, config)
} }
@ -228,7 +242,8 @@ impl VadEngine {
let end_frame = index.saturating_sub(silent_frames); let end_frame = index.saturating_sub(silent_frames);
if end_frame.saturating_sub(start) >= min_speech_frames { if end_frame.saturating_sub(start) >= min_speech_frames {
let start_sec = (start * frame_size) as f32 / config.sample_rate as f32; let start_sec = (start * frame_size) as f32 / config.sample_rate as f32;
let end_sec = ((end_frame + 1) * frame_size) as f32 / config.sample_rate as f32; let end_sec =
((end_frame + 1) * frame_size) as f32 / config.sample_rate as f32;
result.push(((start_sec - pad_seconds).max(0.0), end_sec + pad_seconds)); result.push(((start_sec - pad_seconds).max(0.0), end_sec + pad_seconds));
} }
start_frame = None; start_frame = None;

View File

@ -48,12 +48,12 @@ impl WhisperEngine {
let audio = load_audio_f32(wav_path)?; let audio = load_audio_f32(wav_path)?;
let total_seconds = audio.len() as f32 / 16_000.0; let total_seconds = audio.len() as f32 / 16_000.0;
let normalized_ranges = normalize_speech_ranges(speech_ranges, audio.len()); let normalized_ranges = normalize_speech_ranges(speech_ranges, audio.len());
let context = WhisperContext::new_with_params( let context =
model_path, WhisperContext::new_with_params(model_path, WhisperContextParameters::default())
WhisperContextParameters::default(),
)
.with_context(|| format!("failed to load whisper model: {model_path}"))?; .with_context(|| format!("failed to load whisper model: {model_path}"))?;
let mut state = context.create_state().context("failed to create whisper state")?; let mut state = context
.create_state()
.context("failed to create whisper state")?;
let detected_language = resolve_source_language(&mut state, &audio, source_lang) let detected_language = resolve_source_language(&mut state, &audio, source_lang)
.context("failed to resolve source language")?; .context("failed to resolve source language")?;
@ -187,7 +187,10 @@ impl WhisperEngine {
} }
} }
on_log(format!("whisper: total emitted segments={}", segments.len()))?; on_log(format!(
"whisper: total emitted segments={}",
segments.len()
))?;
Ok(segments) Ok(segments)
} }
} }
@ -226,7 +229,9 @@ fn transcribe_clip(
} }
} }
state.full(params, clip).context("whisper inference failed")?; state
.full(params, clip)
.context("whisper inference failed")?;
let num_segments = state.full_n_segments(); let num_segments = state.full_n_segments();
on_log(format!( on_log(format!(
@ -303,7 +308,10 @@ fn load_audio_f32(path: &Path) -> Result<Vec<f32>> {
.with_context(|| format!("failed to open wav file: {}", path.display()))?; .with_context(|| format!("failed to open wav file: {}", path.display()))?;
let spec = reader.spec(); let spec = reader.spec();
if spec.sample_rate != 16_000 { if spec.sample_rate != 16_000 {
return Err(anyhow!("whisper expects 16k audio, got {}", spec.sample_rate)); return Err(anyhow!(
"whisper expects 16k audio, got {}",
spec.sample_rate
));
} }
if spec.channels != 1 { if spec.channels != 1 {
return Err(anyhow!("whisper expects mono audio, got {}", spec.channels)); return Err(anyhow!("whisper expects mono audio, got {}", spec.channels));
@ -311,7 +319,11 @@ fn load_audio_f32(path: &Path) -> Result<Vec<f32>> {
let samples = reader let samples = reader
.into_samples::<i16>() .into_samples::<i16>()
.map(|sample| sample.map(|value| value as f32 / i16::MAX as f32).map_err(anyhow::Error::from)) .map(|sample| {
sample
.map(|value| value as f32 / i16::MAX as f32)
.map_err(anyhow::Error::from)
})
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(samples) Ok(samples)
@ -372,7 +384,9 @@ fn should_prefer_full_audio(
full_text_len > vad_text_len + vad_text_len * 3 / 5 full_text_len > vad_text_len + vad_text_len * 3 / 5
|| full_audio_segments.len() > vad_segments.len() + 5 || full_audio_segments.len() > vad_segments.len() + 5
|| full_end > vad_end + 5.0 || full_end > vad_end + 5.0
|| (total_seconds > 60.0 && full_end + 1.5 >= total_seconds && vad_end + 5.0 < total_seconds) || (total_seconds > 60.0
&& full_end + 1.5 >= total_seconds
&& vad_end + 5.0 < total_seconds)
} }
fn resolve_source_language<'a>( fn resolve_source_language<'a>(