diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index cd3d8ec..8e74949 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -7,7 +7,9 @@ mod translate; mod vad; mod whisper; -use models::{DefaultModelPaths, StartTaskPayload, SubtitleSegment, SubtitleTask}; +use models::{ + DefaultModelPaths, StartTaskPayload, SubtitleSegment, SubtitleTask, TranslationConfig, +}; use state::AppState; use tauri::{ menu::{MenuBuilder, MenuItemBuilder, PredefinedMenuItem, SubmenuBuilder}, @@ -72,6 +74,19 @@ fn get_default_model_paths(app: tauri::AppHandle) -> std::result::Result, + task_id: String, + translation_config: TranslationConfig, +) -> std::result::Result { + task::retry_translation(app, window, state, task_id, translation_config) + .await + .map_err(error_to_string) +} + fn error_to_string(error: anyhow::Error) -> String { format!("{error:#}") } @@ -93,7 +108,8 @@ pub fn run() { update_segment_text, delete_task, export_subtitles, - get_default_model_paths + get_default_model_paths, + retry_translation ]) .run(tauri::generate_context!()) .expect("error while running tauri application"); diff --git a/src-tauri/src/task.rs b/src-tauri/src/task.rs index 7fa4a68..071c73d 100644 --- a/src-tauri/src/task.rs +++ b/src-tauri/src/task.rs @@ -10,8 +10,9 @@ use uuid::Uuid; use crate::{ audio::AudioPipeline, models::{ - DefaultModelPaths, ErrorEvent, LogEvent, OutputMode, ProgressEvent, ResetSegmentsEvent, - StartTaskPayload, SubtitleSegment, SubtitleTask, TaskStatus, TranslationConfig, + DefaultModelPaths, ErrorEvent, LogEvent, OutputMode, ProgressEvent, + ResetSegmentsEvent, StartTaskPayload, SubtitleSegment, SubtitleTask, + TargetLanguage, TaskStatus, TranslationConfig, }, state::AppState, subtitle::{render, SubtitleFormat}, @@ -185,6 +186,42 @@ async fn run_pipeline( )?; set_status(&window, &app_state, &mut task, TaskStatus::Transcribing, 30.0, "正在执行 Whisper")?; + + // Setup concurrent translation: as whisper emits segments, send them to + // the translation worker so it can start translating immediately in batches + let app_handle = app.clone(); + let (segment_tx, translate_join_handle) = if should_translate { + let config = payload + .translation_config + .clone() + .or_else(load_translation_config) + .ok_or_else(|| anyhow::anyhow!("翻译模式需要填写 LLM API 配置,或设置 OPENAI_API_BASE / OPENAI_API_KEY"))?; + let translator = Translator::new(config)?; + let (tx, rx) = tokio::sync::mpsc::channel::(1024); + let window_for_worker = window.clone(); + let task_id_for_worker = task.id.clone(); + let target_lang_for_worker = task.target_lang.clone(); + let app_handle_for_worker = app_handle.clone(); + let handle = tauri::async_runtime::spawn(async move { + let state = app_handle_for_worker.state::(); + if let Err(error) = incremental_translate( + translator, + rx, + &window_for_worker, + &state, + &task_id_for_worker, + &target_lang_for_worker, + ) + .await + { + eprintln!("incremental translation error: {error:#}"); + } + }); + (Some(tx), Some(handle)) + } else { + (None, None) + }; + let whisper = WhisperEngine::new(payload.whisper_model_path.clone()); let task_id_for_progress = task.id.clone(); let task_id_for_segment = task.id.clone(); @@ -192,7 +229,8 @@ async fn run_pipeline( let task_id_for_log = task.id.clone(); let app_state_for_segment = app_state.clone(); let app_state_for_reset = app_state.clone(); - let mut segments = whisper.infer_segments( + let seg_tx_for_callback = segment_tx.clone(); + let _segments = whisper.infer_segments( &wav_path, &task.id, task.source_lang.as_deref(), @@ -226,6 +264,9 @@ async fn run_pipeline( Ok(()) }, |segment| { + if let Some(ref tx) = seg_tx_for_callback { + let _ = tx.try_send(segment.clone()); + } if let Ok(mut current_task) = app_state_for_segment.get_task(&task_id_for_segment) { upsert_segment(&mut current_task.segments, segment.clone()); let _ = app_state_for_segment.upsert_task(current_task); @@ -242,70 +283,17 @@ async fn run_pipeline( |message| emit_log(&window, &task_id_for_log, message), )?; - task.segments = segments.clone(); - app_state.upsert_task(task.clone())?; - - if should_translate { - let config = payload - .translation_config - .clone() - .or_else(load_translation_config) - .ok_or_else(|| anyhow::anyhow!("翻译模式需要填写 LLM API 配置,或设置 OPENAI_API_BASE / OPENAI_API_KEY"))?; - set_status(&window, &app_state, &mut task, TaskStatus::Translating, 70.0, "正在生成译文")?; - let translator = Translator::new(config)?; - let task_id_for_translate = task.id.clone(); - let app_state_for_translate = app_state.clone(); - let window_for_translate = window.clone(); - let task_id_for_translate_progress = task.id.clone(); - let window_for_translate_progress = window.clone(); - segments = translator - .translate_segments_with_progress( - &segments, - &task.target_lang, - |message| { - let _ = emit_log(&window_for_translate, &task_id_for_translate, message); - }, - |ratio| { - let progress = 70.0 + ratio.clamp(0.0, 1.0) * 25.0; - let _ = window_for_translate_progress.emit( - "task:progress", - ProgressEvent { - task_id: task_id_for_translate_progress.clone(), - status: TaskStatus::Translating, - progress, - message: "正在生成译文".to_string(), - }, - ); - }, - |segment| { - if let Ok(mut current_task) = app_state_for_translate.get_task(&task_id_for_translate) { - upsert_segment(&mut current_task.segments, segment.clone()); - let _ = app_state_for_translate.upsert_task(current_task); - } - let _ = window_for_translate.emit( - "task:segment", - crate::models::SegmentEvent { - task_id: task_id_for_translate.clone(), - segment, - }, - ); - }, - ) - .await?; - task.segments = segments.clone(); - app_state.upsert_task(task.clone())?; - - for segment in segments { - window.emit( - "task:segment", - crate::models::SegmentEvent { - task_id: task.id.clone(), - segment, - }, - )?; - } + // Close channel to signal translation worker to flush and finish + drop(segment_tx); + if let Some(handle) = translate_join_handle { + handle.await.unwrap_or_else(|join_error| { + eprintln!("translation worker panicked: {join_error:?}"); + }); } + // Reload task from state (all segments and translations applied by callbacks) + task = app_state.get_task(&task.id)?; + task.status = TaskStatus::Completed; task.progress = 100.0; app_state.upsert_task(task.clone())?; @@ -313,6 +301,107 @@ async fn run_pipeline( Ok(()) } +async fn incremental_translate( + translator: Translator, + mut rx: tokio::sync::mpsc::Receiver, + window: &Window, + app_state: &AppState, + task_id: &str, + target_lang: &TargetLanguage, +) -> Result<()> { + let batch_size = translator.batch_size().clamp(10, 15); + let context_size = translator.context_size().min(5); + let mut all_segments: Vec = Vec::new(); + let mut buffer: Vec = Vec::new(); + + while let Some(segment) = rx.recv().await { + all_segments.push(segment.clone()); + buffer.push(segment); + + if buffer.len() >= batch_size { + let batch = std::mem::take(&mut buffer); + let context_end = all_segments.len().saturating_sub(batch.len()); + let context_start = context_end.saturating_sub(context_size); + let context = &all_segments[context_start..context_end]; + + emit_log(window, task_id, format!( + "translation: batch segments={}", + batch.iter().map(|s| s.id.as_str()).collect::>().join(", ") + ))?; + + let rows = translator + .translate_batch_with_retries(context, &batch, target_lang_name(target_lang)) + .await?; + + emit_log(window, task_id, format!("translation: batch done, translated={}", rows.len()))?; + + for row in rows { + if let Some(original) = batch.iter().find(|item| item.id == row.id) { + let mut emitted = original.clone(); + emitted.translated_text = Some(row.text); + + if let Ok(mut current_task) = app_state.get_task(task_id) { + upsert_segment(&mut current_task.segments, emitted.clone()); + let _ = app_state.upsert_task(current_task); + } + let _ = window.emit( + "task:segment", + crate::models::SegmentEvent { + task_id: task_id.to_string(), + segment: emitted, + }, + ); + } + } + } + } + + // Flush remaining segments below batch_size + if !buffer.is_empty() { + let batch = std::mem::take(&mut buffer); + let context_end = all_segments.len().saturating_sub(batch.len()); + let context_start = context_end.saturating_sub(context_size); + let context = &all_segments[context_start..context_end]; + + emit_log(window, task_id, format!( + "translation: final batch segments={}", + batch.iter().map(|s| s.id.as_str()).collect::>().join(", ") + ))?; + + let rows = translator + .translate_batch_with_retries(context, &batch, target_lang_name(target_lang)) + .await?; + + for row in rows { + if let Some(original) = batch.iter().find(|item| item.id == row.id) { + let mut emitted = original.clone(); + emitted.translated_text = Some(row.text); + + if let Ok(mut current_task) = app_state.get_task(task_id) { + upsert_segment(&mut current_task.segments, emitted.clone()); + let _ = app_state.upsert_task(current_task); + } + let _ = window.emit( + "task:segment", + crate::models::SegmentEvent { + task_id: task_id.to_string(), + segment: emitted, + }, + ); + } + } + } + + Ok(()) +} + +fn target_lang_name(target_lang: &TargetLanguage) -> &'static str { + match target_lang { + TargetLanguage::Zh => "简体中文", + TargetLanguage::En => "英文", + } +} + fn load_translation_config() -> Option { let api_base = std::env::var("OPENAI_API_BASE").ok()?; let api_key = std::env::var("OPENAI_API_KEY").ok()?; @@ -406,6 +495,118 @@ fn emit_log(window: &Window, task_id: &str, message: String) -> Result<()> { Ok(()) } +pub async fn retry_translation( + app: tauri::AppHandle, + window: Window, + state: tauri::State<'_, AppState>, + task_id: String, + translation_config: TranslationConfig, +) -> Result { + let task = state.get_task(&task_id)?; + + if task.segments.is_empty() { + return Err(anyhow::anyhow!("任务没有可翻译的字幕片段,请重新添加任务")); + } + + let mut initial_task = task.clone(); + set_status( + &window, + &state, + &mut initial_task, + TaskStatus::Translating, + 5.0, + "正在生成译文", + )?; + + let app_handle = app.clone(); + let window_handle = window.clone(); + let task_id_for_spawn = task.id.clone(); + let segments = task.segments.clone(); + let target_lang = task.target_lang.clone(); + + tauri::async_runtime::spawn(async move { + let result = async { + let state = app_handle.state::(); + let translator = Translator::new(translation_config)?; + + let task_id_for_progress = task_id_for_spawn.clone(); + let window_for_progress = window_handle.clone(); + let task_id_for_segment = task_id_for_spawn.clone(); + let window_for_segment = window_handle.clone(); + let app_handle_for_closures = app_handle.clone(); + + let translated_segments = translator + .translate_segments_with_progress( + &segments, + &target_lang, + |message| { + let _ = emit_log(&window_for_segment, &task_id_for_segment, message); + }, + |ratio| { + let progress = 5.0 + ratio.clamp(0.0, 1.0) * 90.0; + let _ = window_for_progress.emit( + "task:progress", + ProgressEvent { + task_id: task_id_for_progress.clone(), + status: TaskStatus::Translating, + progress, + message: "正在生成译文".to_string(), + }, + ); + }, + |segment| { + let state = app_handle_for_closures.state::(); + if let Ok(mut current_task) = state.get_task(&task_id_for_segment) { + upsert_segment(&mut current_task.segments, segment.clone()); + let _ = state.upsert_task(current_task); + } + let _ = window_for_segment.emit( + "task:segment", + crate::models::SegmentEvent { + task_id: task_id_for_segment.clone(), + segment, + }, + ); + }, + ) + .await?; + + let mut current_task = state.get_task(&task_id_for_spawn)?; + current_task.segments = translated_segments.clone(); + current_task.status = TaskStatus::Completed; + current_task.progress = 100.0; + state.upsert_task(current_task.clone())?; + + for segment in translated_segments { + window_handle.emit( + "task:segment", + crate::models::SegmentEvent { + task_id: task_id_for_spawn.clone(), + segment, + }, + )?; + } + + window_handle.emit("task:done", current_task)?; + + Ok::<_, anyhow::Error>(()) + } + .await; + + if let Err(error) = result { + let state = app_handle.state::(); + if let Ok(mut failed_task) = state.get_task(&task_id_for_spawn) { + failed_task.status = TaskStatus::Failed; + failed_task.error = Some(error.to_string()); + let _ = state.upsert_task(failed_task); + } + let _ = emit_error(&window_handle, &task_id_for_spawn, &error.to_string()); + } + }); + + Ok(task) +} + fn upsert_segment(segments: &mut Vec, segment: SubtitleSegment) { if let Some(existing) = segments.iter_mut().find(|item| item.id == segment.id) { *existing = segment; diff --git a/src-tauri/src/translate.rs b/src-tauri/src/translate.rs index b45e45a..306a202 100644 --- a/src-tauri/src/translate.rs +++ b/src-tauri/src/translate.rs @@ -50,9 +50,9 @@ struct TranslationResponse { } #[derive(Debug, Clone, Deserialize)] -struct TranslatedRow { - id: String, - text: String, +pub(crate) struct TranslatedRow { + pub(crate) id: String, + pub(crate) text: String, } pub struct Translator { @@ -70,6 +70,14 @@ impl Translator { Ok(Self { client, config }) } + pub fn batch_size(&self) -> usize { + self.config.batch_size + } + + pub fn context_size(&self) -> usize { + self.config.context_size + } + pub async fn translate_segments_with_progress( &self, segments: &[SubtitleSegment], @@ -125,7 +133,7 @@ impl Translator { Ok(translated) } - async fn translate_batch_with_retries( + pub(crate) async fn translate_batch_with_retries( &self, context: &[SubtitleSegment], batch: &[SubtitleSegment], @@ -177,6 +185,31 @@ impl Translator { Ok(order_rows(batch, &collected)) } + /// Translate a batch of segments with retries, returning the batch with `translated_text` filled in. + pub async fn translate_batch( + &self, + context: &[SubtitleSegment], + batch: &[SubtitleSegment], + target_language: &TargetLanguage, + ) -> Result> { + let target_language_name = match target_language { + TargetLanguage::Zh => "简体中文", + TargetLanguage::En => "英文", + }; + let rows = self + .translate_batch_with_retries(context, batch, target_language_name) + .await?; + + let mut result = batch.to_vec(); + for row in rows { + if let Some(segment) = result.iter_mut().find(|item| item.id == row.id) { + segment.translated_text = Some(row.text); + } + } + + Ok(result) + } + async fn request_translation( &self, context: &[SubtitleSegment], diff --git a/src/App.vue b/src/App.vue index ea5eeeb..b84bf37 100644 --- a/src/App.vue +++ b/src/App.vue @@ -241,6 +241,25 @@ async function handleExport(format: 'srt' | 'vtt' | 'ass') { const output = await taskStore.exportTask(selectedTask.value.id, format) feedback.value = output } + +async function handleRetryTranslate(taskId: string) { + persistTranslationConfig() + if (!translationConfig.value.apiKey.trim()) { + feedback.value = t('app.feedback.noApiKey') + return + } + try { + await taskStore.retryTranslation(taskId, translationConfig.value) + feedback.value = t('app.feedback.translationStarted') + } catch (error) { + feedback.value = error instanceof Error ? error.message : t('app.feedback.translationFailed') + } +} + +async function handleTranslateFromEditor() { + if (!selectedTask.value) return + await handleRetryTranslate(selectedTask.value.id) +}