修复最后一批次的问题
This commit is contained in:
parent
28294a6eb2
commit
78c750bcbf
@ -22,7 +22,11 @@ impl AudioPipeline {
|
||||
let output_path = workspace.join("normalized.wav");
|
||||
let mut command = Command::new(ffmpeg_path);
|
||||
#[cfg(target_os = "macos")]
|
||||
if let Some(lib_dir) = ffmpeg_path.parent().and_then(|bin_dir| bin_dir.parent()).map(|root| root.join("lib")) {
|
||||
if let Some(lib_dir) = ffmpeg_path
|
||||
.parent()
|
||||
.and_then(|bin_dir| bin_dir.parent())
|
||||
.map(|root| root.join("lib"))
|
||||
{
|
||||
if lib_dir.exists() {
|
||||
command.env("DYLD_FALLBACK_LIBRARY_PATH", &lib_dir);
|
||||
}
|
||||
@ -70,7 +74,9 @@ impl AudioPipeline {
|
||||
}
|
||||
}
|
||||
|
||||
let status = child.wait().with_context(|| "ffmpeg process failed to wait")?;
|
||||
let status = child
|
||||
.wait()
|
||||
.with_context(|| "ffmpeg process failed to wait")?;
|
||||
if !status.success() {
|
||||
return Err(anyhow!("ffmpeg exited with status: {}", status));
|
||||
}
|
||||
@ -80,8 +86,8 @@ impl AudioPipeline {
|
||||
}
|
||||
|
||||
pub fn load_wav_f32(path: &Path) -> Result<Vec<f32>> {
|
||||
let mut reader =
|
||||
hound::WavReader::open(path).with_context(|| format!("failed to open {}", path.display()))?;
|
||||
let mut reader = hound::WavReader::open(path)
|
||||
.with_context(|| format!("failed to open {}", path.display()))?;
|
||||
let spec = reader.spec();
|
||||
|
||||
if spec.channels != 1 {
|
||||
@ -124,7 +130,9 @@ fn parse_ffmpeg_duration(line: &str) -> Option<f64> {
|
||||
fn parse_ffmpeg_time(line: &str) -> Option<f64> {
|
||||
let pos = line.find("time=")?;
|
||||
let rest = &line[pos + 5..];
|
||||
let end = rest.find(|c: char| !c.is_digit(10) && c != ':' && c != '.').unwrap_or(rest.len());
|
||||
let end = rest
|
||||
.find(|c: char| !c.is_digit(10) && c != ':' && c != '.')
|
||||
.unwrap_or(rest.len());
|
||||
let time_str = &rest[..end];
|
||||
let parts: Vec<&str> = time_str.split(':').collect();
|
||||
if parts.len() == 3 {
|
||||
|
||||
@ -10,17 +10,17 @@ mod whisper;
|
||||
use models::{
|
||||
DefaultModelPaths, StartTaskPayload, SubtitleSegment, SubtitleTask, TranslationConfig,
|
||||
};
|
||||
#[cfg(target_os = "macos")]
|
||||
use objc2_app_kit::NSWindow;
|
||||
#[cfg(target_os = "macos")]
|
||||
use objc2_foundation::NSSize;
|
||||
use state::AppState;
|
||||
use tauri::{AppHandle, Manager, PhysicalSize, Size};
|
||||
#[cfg(target_os = "macos")]
|
||||
use tauri::{
|
||||
menu::{MenuBuilder, MenuItemBuilder, PredefinedMenuItem, SubmenuBuilder},
|
||||
Emitter,
|
||||
};
|
||||
#[cfg(target_os = "macos")]
|
||||
use objc2_app_kit::NSWindow;
|
||||
#[cfg(target_os = "macos")]
|
||||
use objc2_foundation::NSSize;
|
||||
use tauri::{AppHandle, Manager, PhysicalSize, Size};
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
const WINDOW_RATIO_WIDTH: f64 = 16.0;
|
||||
@ -74,7 +74,9 @@ fn export_subtitles(
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn get_default_model_paths(app: tauri::AppHandle) -> std::result::Result<DefaultModelPaths, String> {
|
||||
fn get_default_model_paths(
|
||||
app: tauri::AppHandle,
|
||||
) -> std::result::Result<DefaultModelPaths, String> {
|
||||
task::get_default_model_paths(&app).map_err(error_to_string)
|
||||
}
|
||||
|
||||
@ -152,7 +154,11 @@ fn configure_macos_menu(app: &AppHandle) -> tauri::Result<()> {
|
||||
.build()?;
|
||||
|
||||
let file_menu = SubmenuBuilder::new(app, "文件")
|
||||
.item(&MenuItemBuilder::with_id("pick_files", "选择媒体文件").accelerator("CmdOrCtrl+O").build(app)?)
|
||||
.item(
|
||||
&MenuItemBuilder::with_id("pick_files", "选择媒体文件")
|
||||
.accelerator("CmdOrCtrl+O")
|
||||
.build(app)?,
|
||||
)
|
||||
.separator()
|
||||
.item(&MenuItemBuilder::with_id("export_srt", "导出 SRT").build(app)?)
|
||||
.item(&MenuItemBuilder::with_id("export_vtt", "导出 VTT").build(app)?)
|
||||
|
||||
@ -11,13 +11,19 @@ pub struct AppState {
|
||||
|
||||
impl AppState {
|
||||
pub fn upsert_task(&self, task: SubtitleTask) -> Result<()> {
|
||||
let mut guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?;
|
||||
let mut guard = self
|
||||
.tasks
|
||||
.lock()
|
||||
.map_err(|_| anyhow!("task store poisoned"))?;
|
||||
guard.insert(task.id.clone(), task);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_task(&self, task_id: &str) -> Result<SubtitleTask> {
|
||||
let guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?;
|
||||
let guard = self
|
||||
.tasks
|
||||
.lock()
|
||||
.map_err(|_| anyhow!("task store poisoned"))?;
|
||||
guard
|
||||
.get(task_id)
|
||||
.cloned()
|
||||
@ -25,20 +31,29 @@ impl AppState {
|
||||
}
|
||||
|
||||
pub fn list_tasks(&self) -> Result<Vec<SubtitleTask>> {
|
||||
let guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?;
|
||||
let guard = self
|
||||
.tasks
|
||||
.lock()
|
||||
.map_err(|_| anyhow!("task store poisoned"))?;
|
||||
let mut tasks = guard.values().cloned().collect::<Vec<_>>();
|
||||
tasks.sort_by(|left, right| right.id.cmp(&left.id));
|
||||
Ok(tasks)
|
||||
}
|
||||
|
||||
pub fn delete_task(&self, task_id: &str) -> Result<()> {
|
||||
let mut guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?;
|
||||
let mut guard = self
|
||||
.tasks
|
||||
.lock()
|
||||
.map_err(|_| anyhow!("task store poisoned"))?;
|
||||
guard.remove(task_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn update_segment(&self, segment: SubtitleSegment) -> Result<SubtitleTask> {
|
||||
let mut guard = self.tasks.lock().map_err(|_| anyhow!("task store poisoned"))?;
|
||||
let mut guard = self
|
||||
.tasks
|
||||
.lock()
|
||||
.map_err(|_| anyhow!("task store poisoned"))?;
|
||||
let task = guard
|
||||
.get_mut(&segment.task_id)
|
||||
.ok_or_else(|| anyhow!("task not found: {}", segment.task_id))?;
|
||||
|
||||
@ -5,6 +5,7 @@ use std::{
|
||||
atomic::{AtomicU32, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
@ -14,9 +15,9 @@ use uuid::Uuid;
|
||||
use crate::{
|
||||
audio::AudioPipeline,
|
||||
models::{
|
||||
DefaultModelPaths, ErrorEvent, LogEvent, OutputMode, ProgressEvent,
|
||||
ResetSegmentsEvent, StartTaskPayload, SubStageProgress, SubtitleSegment,
|
||||
SubtitleTask, TargetLanguage, TaskStatus, TranslationConfig,
|
||||
DefaultModelPaths, ErrorEvent, LogEvent, OutputMode, ProgressEvent, ResetSegmentsEvent,
|
||||
StartTaskPayload, SubStageProgress, SubtitleSegment, SubtitleTask, TargetLanguage,
|
||||
TaskStatus, TranslationConfig,
|
||||
},
|
||||
state::AppState,
|
||||
subtitle::{render, SubtitleFormat},
|
||||
@ -45,7 +46,11 @@ pub async fn start_task(
|
||||
state: tauri::State<'_, AppState>,
|
||||
mut payload: StartTaskPayload,
|
||||
) -> Result<SubtitleTask> {
|
||||
if payload.whisper_model_path.as_deref().is_none_or(str::is_empty) {
|
||||
if payload
|
||||
.whisper_model_path
|
||||
.as_deref()
|
||||
.is_none_or(str::is_empty)
|
||||
{
|
||||
payload.whisper_model_path = resolve_default_model_path(&app, DEFAULT_WHISPER_MODEL);
|
||||
}
|
||||
if payload.vad_model_path.as_deref().is_none_or(str::is_empty) {
|
||||
@ -85,11 +90,21 @@ pub async fn start_task(
|
||||
let task_id = task.id.clone();
|
||||
|
||||
tauri::async_runtime::spawn(async move {
|
||||
if let Err(error) = run_pipeline(app_handle, window_handle.clone(), task_for_spawn, payload_for_spawn).await {
|
||||
if let Ok(mut failed_task) = app_handle_for_error.state::<AppState>().get_task(&task_id) {
|
||||
if let Err(error) = run_pipeline(
|
||||
app_handle,
|
||||
window_handle.clone(),
|
||||
task_for_spawn,
|
||||
payload_for_spawn,
|
||||
)
|
||||
.await
|
||||
{
|
||||
if let Ok(mut failed_task) = app_handle_for_error.state::<AppState>().get_task(&task_id)
|
||||
{
|
||||
failed_task.status = TaskStatus::Failed;
|
||||
failed_task.error = Some(error.to_string());
|
||||
let _ = app_handle_for_error.state::<AppState>().upsert_task(failed_task);
|
||||
let _ = app_handle_for_error
|
||||
.state::<AppState>()
|
||||
.upsert_task(failed_task);
|
||||
}
|
||||
let _ = emit_error(&window_handle, &task_id, &error.to_string());
|
||||
}
|
||||
@ -171,12 +186,28 @@ async fn run_pipeline(
|
||||
}
|
||||
}
|
||||
|
||||
let ffmpeg_path = resolve_ffmpeg_path(&app)
|
||||
.ok_or_else(|| anyhow::anyhow!("未找到可用 ffmpeg,请重新执行打包命令或在系统中安装 ffmpeg"))?;
|
||||
let ffmpeg_path = resolve_ffmpeg_path(&app).ok_or_else(|| {
|
||||
anyhow::anyhow!("未找到可用 ffmpeg,请重新执行打包命令或在系统中安装 ffmpeg")
|
||||
})?;
|
||||
|
||||
set_status(&window, &app_state, &mut task, TaskStatus::Extracting, 5.0, "正在抽取音频")?;
|
||||
emit_log(&window, &task.id, format!("task: input file={}", payload.file_path))?;
|
||||
emit_log(&window, &task.id, format!("audio: ffmpeg={}", ffmpeg_path.display()))?;
|
||||
set_status(
|
||||
&window,
|
||||
&app_state,
|
||||
&mut task,
|
||||
TaskStatus::Extracting,
|
||||
5.0,
|
||||
"正在抽取音频",
|
||||
)?;
|
||||
emit_log(
|
||||
&window,
|
||||
&task.id,
|
||||
format!("task: input file={}", payload.file_path),
|
||||
)?;
|
||||
emit_log(
|
||||
&window,
|
||||
&task.id,
|
||||
format!("audio: ffmpeg={}", ffmpeg_path.display()),
|
||||
)?;
|
||||
|
||||
let window_for_extract = window.clone();
|
||||
let task_id_for_extract = task.id.clone();
|
||||
@ -204,40 +235,60 @@ async fn run_pipeline(
|
||||
);
|
||||
},
|
||||
)?;
|
||||
emit_log(&window, &task.id, format!("audio: normalized wav={}", wav_path.display()))?;
|
||||
emit_log(
|
||||
&window,
|
||||
&task.id,
|
||||
format!("audio: normalized wav={}", wav_path.display()),
|
||||
)?;
|
||||
|
||||
set_status(&window, &app_state, &mut task, TaskStatus::VadProcessing, 15.0, "正在分析语音片段")?;
|
||||
set_status(
|
||||
&window,
|
||||
&app_state,
|
||||
&mut task,
|
||||
TaskStatus::VadProcessing,
|
||||
15.0,
|
||||
"正在分析语音片段",
|
||||
)?;
|
||||
let samples = AudioPipeline::load_wav_f32(&wav_path)?;
|
||||
let vad = VadEngine::new(payload.vad_model_path.clone(), VadConfig::default())?;
|
||||
|
||||
let window_for_vad = window.clone();
|
||||
let task_id_for_vad = task.id.clone();
|
||||
let speech_ranges = vad.detect_segments(&samples, move |ratio: f32| {
|
||||
let overall = 15.0 + ratio.clamp(0.0, 1.0) * 15.0;
|
||||
let sub = SubStageProgress {
|
||||
extracting: 100.0,
|
||||
vad: ratio.clamp(0.0, 1.0) * 100.0,
|
||||
transcribing: 0.0,
|
||||
translating: 0.0,
|
||||
};
|
||||
let _ = window_for_vad.emit(
|
||||
"task:progress",
|
||||
ProgressEvent {
|
||||
task_id: task_id_for_vad.clone(),
|
||||
status: TaskStatus::VadProcessing,
|
||||
progress: overall,
|
||||
message: "正在分析语音片段".to_string(),
|
||||
sub_stage_progress: sub,
|
||||
},
|
||||
);
|
||||
}).await;
|
||||
let speech_ranges = vad
|
||||
.detect_segments(&samples, move |ratio: f32| {
|
||||
let overall = 15.0 + ratio.clamp(0.0, 1.0) * 15.0;
|
||||
let sub = SubStageProgress {
|
||||
extracting: 100.0,
|
||||
vad: ratio.clamp(0.0, 1.0) * 100.0,
|
||||
transcribing: 0.0,
|
||||
translating: 0.0,
|
||||
};
|
||||
let _ = window_for_vad.emit(
|
||||
"task:progress",
|
||||
ProgressEvent {
|
||||
task_id: task_id_for_vad.clone(),
|
||||
status: TaskStatus::VadProcessing,
|
||||
progress: overall,
|
||||
message: "正在分析语音片段".to_string(),
|
||||
sub_stage_progress: sub,
|
||||
},
|
||||
);
|
||||
})
|
||||
.await;
|
||||
emit_log(
|
||||
&window,
|
||||
&task.id,
|
||||
format!("vad: detected {} speech ranges", speech_ranges.len()),
|
||||
)?;
|
||||
|
||||
set_status(&window, &app_state, &mut task, TaskStatus::Transcribing, 30.0, "正在执行 Whisper")?;
|
||||
set_status(
|
||||
&window,
|
||||
&app_state,
|
||||
&mut task,
|
||||
TaskStatus::Transcribing,
|
||||
30.0,
|
||||
"正在执行 Whisper",
|
||||
)?;
|
||||
|
||||
// Shared progress state between concurrent transcribing and translating
|
||||
let transcribing_pct: Arc<AtomicU32> = Arc::new(AtomicU32::new(0));
|
||||
@ -251,7 +302,11 @@ async fn run_pipeline(
|
||||
.translation_config
|
||||
.clone()
|
||||
.or_else(load_translation_config)
|
||||
.ok_or_else(|| anyhow::anyhow!("翻译模式需要填写 LLM API 配置,或设置 OPENAI_API_BASE / OPENAI_API_KEY"))?;
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"翻译模式需要填写 LLM API 配置,或设置 OPENAI_API_BASE / OPENAI_API_KEY"
|
||||
)
|
||||
})?;
|
||||
let translator = Translator::new(config)?;
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<SubtitleSegment>(1024);
|
||||
let window_for_worker = window.clone();
|
||||
@ -399,126 +454,107 @@ async fn incremental_translate(
|
||||
let mut all_segments: Vec<SubtitleSegment> = Vec::new();
|
||||
let mut buffer: Vec<SubtitleSegment> = Vec::new();
|
||||
let mut translated_count: usize = 0;
|
||||
let idle_flush_after = Duration::from_secs(3);
|
||||
|
||||
let emit_translate_progress = |window: &Window, task_id: &str, done: usize, total: usize| -> Result<()> {
|
||||
let ratio = if total > 0 {
|
||||
(done as f32 / total as f32).clamp(0.0, 1.0)
|
||||
} else {
|
||||
0.0
|
||||
let emit_translate_progress =
|
||||
|window: &Window, task_id: &str, done: usize, total: usize| -> Result<()> {
|
||||
let ratio = if total > 0 {
|
||||
(done as f32 / total as f32).clamp(0.0, 1.0)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
translating_pct.store((ratio * 100.0) as u32, Ordering::Release);
|
||||
let transcribing = transcribing_pct.load(Ordering::Acquire) as f32 / 100.0;
|
||||
// translation progress must not exceed transcription progress
|
||||
let translating_capped = ratio.min(transcribing);
|
||||
|
||||
let overall = 30.0 + transcribing * 40.0 + translating_capped * 25.0;
|
||||
let sub = SubStageProgress {
|
||||
extracting: 100.0,
|
||||
vad: 100.0,
|
||||
transcribing: (transcribing * 100.0).min(100.0),
|
||||
translating: translating_capped * 100.0,
|
||||
};
|
||||
let status = if transcribing >= 1.0 {
|
||||
TaskStatus::Translating
|
||||
} else {
|
||||
TaskStatus::Transcribing
|
||||
};
|
||||
window.emit(
|
||||
"task:progress",
|
||||
ProgressEvent {
|
||||
task_id: task_id.to_string(),
|
||||
status,
|
||||
progress: overall.min(95.0),
|
||||
message: "正在生成译文".to_string(),
|
||||
sub_stage_progress: sub,
|
||||
},
|
||||
)?;
|
||||
Ok(())
|
||||
};
|
||||
translating_pct.store((ratio * 100.0) as u32, Ordering::Release);
|
||||
let transcribing = transcribing_pct.load(Ordering::Acquire) as f32 / 100.0;
|
||||
// translation progress must not exceed transcription progress
|
||||
let translating_capped = ratio.min(transcribing);
|
||||
|
||||
let overall = 30.0 + transcribing * 40.0 + translating_capped * 25.0;
|
||||
let sub = SubStageProgress {
|
||||
extracting: 100.0,
|
||||
vad: 100.0,
|
||||
transcribing: (transcribing * 100.0).min(100.0),
|
||||
translating: translating_capped * 100.0,
|
||||
};
|
||||
let status = if transcribing >= 1.0 {
|
||||
TaskStatus::Translating
|
||||
} else {
|
||||
TaskStatus::Transcribing
|
||||
};
|
||||
window.emit(
|
||||
"task:progress",
|
||||
ProgressEvent {
|
||||
task_id: task_id.to_string(),
|
||||
status,
|
||||
progress: overall.min(95.0),
|
||||
message: "正在生成译文".to_string(),
|
||||
sub_stage_progress: sub,
|
||||
},
|
||||
)?;
|
||||
Ok(())
|
||||
};
|
||||
loop {
|
||||
match tokio::time::timeout(idle_flush_after, rx.recv()).await {
|
||||
Ok(Some(segment)) => {
|
||||
all_segments.push(segment.clone());
|
||||
buffer.push(segment);
|
||||
|
||||
while let Some(segment) = rx.recv().await {
|
||||
all_segments.push(segment.clone());
|
||||
buffer.push(segment);
|
||||
|
||||
if buffer.len() >= batch_size {
|
||||
let batch = std::mem::take(&mut buffer);
|
||||
let context_end = all_segments.len().saturating_sub(batch.len());
|
||||
let context_start = context_end.saturating_sub(context_size);
|
||||
let context = &all_segments[context_start..context_end];
|
||||
|
||||
emit_log(window, task_id, format!(
|
||||
"translation: batch segments={}",
|
||||
batch.iter().map(|s| s.id.as_str()).collect::<Vec<_>>().join(", ")
|
||||
))?;
|
||||
|
||||
let rows = translator
|
||||
.translate_batch_with_retries(context, &batch, target_lang_name(target_lang))
|
||||
.await?;
|
||||
|
||||
translated_count += rows.len();
|
||||
|
||||
emit_log(window, task_id, format!("translation: batch done, translated={}", rows.len()))?;
|
||||
|
||||
for row in rows {
|
||||
if let Some(original) = batch.iter().find(|item| item.id == row.id) {
|
||||
let mut emitted = original.clone();
|
||||
emitted.translated_text = Some(row.text);
|
||||
|
||||
if let Ok(mut current_task) = app_state.get_task(task_id) {
|
||||
upsert_segment(&mut current_task.segments, emitted.clone());
|
||||
let _ = app_state.upsert_task(current_task);
|
||||
}
|
||||
let _ = window.emit(
|
||||
"task:segment",
|
||||
crate::models::SegmentEvent {
|
||||
task_id: task_id.to_string(),
|
||||
segment: emitted,
|
||||
},
|
||||
);
|
||||
if buffer.len() >= batch_size {
|
||||
translate_buffered_segments(
|
||||
&translator,
|
||||
window,
|
||||
app_state,
|
||||
task_id,
|
||||
target_lang,
|
||||
&all_segments,
|
||||
&mut buffer,
|
||||
&mut translated_count,
|
||||
context_size,
|
||||
"batch",
|
||||
)
|
||||
.await?;
|
||||
emit_translate_progress(window, task_id, translated_count, all_segments.len())?;
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(_) => {
|
||||
if buffer.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
emit_translate_progress(window, task_id, translated_count, all_segments.len())?;
|
||||
translate_buffered_segments(
|
||||
&translator,
|
||||
window,
|
||||
app_state,
|
||||
task_id,
|
||||
target_lang,
|
||||
&all_segments,
|
||||
&mut buffer,
|
||||
&mut translated_count,
|
||||
context_size,
|
||||
"idle batch",
|
||||
)
|
||||
.await?;
|
||||
emit_translate_progress(window, task_id, translated_count, all_segments.len())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush remaining segments below batch_size
|
||||
if !buffer.is_empty() {
|
||||
let batch = std::mem::take(&mut buffer);
|
||||
let context_end = all_segments.len().saturating_sub(batch.len());
|
||||
let context_start = context_end.saturating_sub(context_size);
|
||||
let context = &all_segments[context_start..context_end];
|
||||
|
||||
emit_log(window, task_id, format!(
|
||||
"translation: final batch segments={}",
|
||||
batch.iter().map(|s| s.id.as_str()).collect::<Vec<_>>().join(", ")
|
||||
))?;
|
||||
|
||||
let rows = translator
|
||||
.translate_batch_with_retries(context, &batch, target_lang_name(target_lang))
|
||||
.await?;
|
||||
|
||||
translated_count += rows.len();
|
||||
|
||||
for row in rows {
|
||||
if let Some(original) = batch.iter().find(|item| item.id == row.id) {
|
||||
let mut emitted = original.clone();
|
||||
emitted.translated_text = Some(row.text);
|
||||
|
||||
if let Ok(mut current_task) = app_state.get_task(task_id) {
|
||||
upsert_segment(&mut current_task.segments, emitted.clone());
|
||||
let _ = app_state.upsert_task(current_task);
|
||||
}
|
||||
let _ = window.emit(
|
||||
"task:segment",
|
||||
crate::models::SegmentEvent {
|
||||
task_id: task_id.to_string(),
|
||||
segment: emitted,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
translate_buffered_segments(
|
||||
&translator,
|
||||
window,
|
||||
app_state,
|
||||
task_id,
|
||||
target_lang,
|
||||
&all_segments,
|
||||
&mut buffer,
|
||||
&mut translated_count,
|
||||
context_size,
|
||||
"final batch",
|
||||
)
|
||||
.await?;
|
||||
emit_translate_progress(window, task_id, translated_count, all_segments.len())?;
|
||||
}
|
||||
|
||||
@ -544,6 +580,75 @@ async fn incremental_translate(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn translate_buffered_segments(
|
||||
translator: &Translator,
|
||||
window: &Window,
|
||||
app_state: &AppState,
|
||||
task_id: &str,
|
||||
target_lang: &TargetLanguage,
|
||||
all_segments: &[SubtitleSegment],
|
||||
buffer: &mut Vec<SubtitleSegment>,
|
||||
translated_count: &mut usize,
|
||||
context_size: usize,
|
||||
label: &str,
|
||||
) -> Result<()> {
|
||||
if buffer.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let batch = std::mem::take(buffer);
|
||||
let context_end = all_segments.len().saturating_sub(batch.len());
|
||||
let context_start = context_end.saturating_sub(context_size);
|
||||
let context = &all_segments[context_start..context_end];
|
||||
|
||||
emit_log(
|
||||
window,
|
||||
task_id,
|
||||
format!(
|
||||
"translation: {} segments={}",
|
||||
label,
|
||||
batch
|
||||
.iter()
|
||||
.map(|segment| segment.id.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
),
|
||||
)?;
|
||||
|
||||
let rows = translator
|
||||
.translate_batch_with_retries(context, &batch, target_lang_name(target_lang))
|
||||
.await?;
|
||||
|
||||
*translated_count += rows.len();
|
||||
|
||||
emit_log(
|
||||
window,
|
||||
task_id,
|
||||
format!("translation: {} done, translated={}", label, rows.len()),
|
||||
)?;
|
||||
|
||||
for row in rows {
|
||||
if let Some(original) = batch.iter().find(|item| item.id == row.id) {
|
||||
let mut emitted = original.clone();
|
||||
emitted.translated_text = Some(row.text);
|
||||
|
||||
if let Ok(mut current_task) = app_state.get_task(task_id) {
|
||||
upsert_segment(&mut current_task.segments, emitted.clone());
|
||||
let _ = app_state.upsert_task(current_task);
|
||||
}
|
||||
let _ = window.emit(
|
||||
"task:segment",
|
||||
crate::models::SegmentEvent {
|
||||
task_id: task_id.to_string(),
|
||||
segment: emitted,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn target_lang_name(target_lang: &TargetLanguage) -> &'static str {
|
||||
match target_lang {
|
||||
TargetLanguage::Zh => "简体中文",
|
||||
@ -616,7 +721,10 @@ fn set_status(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn update_segment_text(state: tauri::State<'_, AppState>, segment: SubtitleSegment) -> Result<SubtitleTask> {
|
||||
pub fn update_segment_text(
|
||||
state: tauri::State<'_, AppState>,
|
||||
segment: SubtitleSegment,
|
||||
) -> Result<SubtitleTask> {
|
||||
state.update_segment(segment)
|
||||
}
|
||||
|
||||
@ -628,7 +736,11 @@ pub fn delete_task(state: tauri::State<'_, AppState>, task_id: String) -> Result
|
||||
state.delete_task(&task_id)
|
||||
}
|
||||
|
||||
pub fn export_task(state: tauri::State<'_, AppState>, task_id: String, format: String) -> Result<String> {
|
||||
pub fn export_task(
|
||||
state: tauri::State<'_, AppState>,
|
||||
task_id: String,
|
||||
format: String,
|
||||
) -> Result<String> {
|
||||
let task = state.get_task(&task_id)?;
|
||||
let format = SubtitleFormat::try_from(format.as_str())?;
|
||||
let content = render(&task.segments, format, task.bilingual_output);
|
||||
@ -803,7 +915,11 @@ fn upsert_segment(segments: &mut Vec<SubtitleSegment>, segment: SubtitleSegment)
|
||||
left.start
|
||||
.partial_cmp(&right.start)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
.then_with(|| left.end.partial_cmp(&right.end).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.then_with(|| {
|
||||
left.end
|
||||
.partial_cmp(&right.end)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.then_with(|| left.id.cmp(&right.id))
|
||||
});
|
||||
}
|
||||
|
||||
@ -120,7 +120,10 @@ impl Translator {
|
||||
let rows = self
|
||||
.translate_batch_with_retries(context, batch, target_language_name)
|
||||
.await?;
|
||||
log(format!("translation: batch done, translated={}", rows.len()));
|
||||
log(format!(
|
||||
"translation: batch done, translated={}",
|
||||
rows.len()
|
||||
));
|
||||
|
||||
for row in rows {
|
||||
if let Some(segment) = translated.iter_mut().find(|item| item.id == row.id) {
|
||||
@ -257,7 +260,9 @@ impl Translator {
|
||||
|
||||
match response {
|
||||
Ok(response) => {
|
||||
let response = response.error_for_status().context("translation http error")?;
|
||||
let response = response
|
||||
.error_for_status()
|
||||
.context("translation http error")?;
|
||||
let raw_text = response.text().await.context("invalid response body")?;
|
||||
eprintln!("translation raw response:\n{}", raw_text);
|
||||
let payload: ChatCompletionResponse =
|
||||
@ -269,8 +274,9 @@ impl Translator {
|
||||
.message
|
||||
.content
|
||||
.clone();
|
||||
let rows = parse_translation_response(&content)
|
||||
.with_context(|| format!("translation json parse failed: {}", preview(&content)))?;
|
||||
let rows = parse_translation_response(&content).with_context(|| {
|
||||
format!("translation json parse failed: {}", preview(&content))
|
||||
})?;
|
||||
return Ok(rows);
|
||||
}
|
||||
Err(error) => {
|
||||
@ -362,10 +368,7 @@ fn strip_code_fence(content: &str) -> String {
|
||||
.trim_start_matches("```json")
|
||||
.trim_start_matches("```JSON")
|
||||
.trim_start_matches("```");
|
||||
without_prefix
|
||||
.trim_end_matches("```")
|
||||
.trim()
|
||||
.to_string()
|
||||
without_prefix.trim_end_matches("```").trim().to_string()
|
||||
}
|
||||
|
||||
fn extract_json_object(content: &str) -> Option<String> {
|
||||
@ -403,7 +406,11 @@ fn mask_secret(secret: &str) -> String {
|
||||
return "****".to_string();
|
||||
}
|
||||
|
||||
format!("{}****{}", &secret[..4], &secret[secret.len().saturating_sub(4)..])
|
||||
format!(
|
||||
"{}****{}",
|
||||
&secret[..4],
|
||||
&secret[secret.len().saturating_sub(4)..]
|
||||
)
|
||||
}
|
||||
|
||||
fn extract_rows_loose(content: &str) -> Vec<TranslatedRow> {
|
||||
|
||||
@ -69,7 +69,8 @@ impl VadEngine {
|
||||
return None;
|
||||
}
|
||||
};
|
||||
Self::detect_with_onnx(&mut session, &samples_owned, &config, on_progress_onnx).ok()
|
||||
Self::detect_with_onnx(&mut session, &samples_owned, &config, on_progress_onnx)
|
||||
.ok()
|
||||
}),
|
||||
)
|
||||
.await
|
||||
@ -156,7 +157,11 @@ impl VadEngine {
|
||||
}
|
||||
|
||||
on_progress(1.0);
|
||||
Ok(Self::merge_probabilities(&speech_probabilities, chunk_size, config))
|
||||
Ok(Self::merge_probabilities(
|
||||
&speech_probabilities,
|
||||
chunk_size,
|
||||
config,
|
||||
))
|
||||
}
|
||||
|
||||
fn detect_segments_with_energy<F: Fn(f32)>(
|
||||
@ -192,10 +197,19 @@ impl VadEngine {
|
||||
energies.len(),
|
||||
dynamic_threshold
|
||||
);
|
||||
Self::merge_probabilities_with_threshold(&energies, frame_size, dynamic_threshold, &self.config)
|
||||
Self::merge_probabilities_with_threshold(
|
||||
&energies,
|
||||
frame_size,
|
||||
dynamic_threshold,
|
||||
&self.config,
|
||||
)
|
||||
}
|
||||
|
||||
fn merge_probabilities(frames: &[f32], frame_size: usize, config: &VadConfig) -> Vec<(f32, f32)> {
|
||||
fn merge_probabilities(
|
||||
frames: &[f32],
|
||||
frame_size: usize,
|
||||
config: &VadConfig,
|
||||
) -> Vec<(f32, f32)> {
|
||||
Self::merge_probabilities_with_threshold(frames, frame_size, config.threshold, config)
|
||||
}
|
||||
|
||||
@ -228,7 +242,8 @@ impl VadEngine {
|
||||
let end_frame = index.saturating_sub(silent_frames);
|
||||
if end_frame.saturating_sub(start) >= min_speech_frames {
|
||||
let start_sec = (start * frame_size) as f32 / config.sample_rate as f32;
|
||||
let end_sec = ((end_frame + 1) * frame_size) as f32 / config.sample_rate as f32;
|
||||
let end_sec =
|
||||
((end_frame + 1) * frame_size) as f32 / config.sample_rate as f32;
|
||||
result.push(((start_sec - pad_seconds).max(0.0), end_sec + pad_seconds));
|
||||
}
|
||||
start_frame = None;
|
||||
|
||||
@ -48,12 +48,12 @@ impl WhisperEngine {
|
||||
let audio = load_audio_f32(wav_path)?;
|
||||
let total_seconds = audio.len() as f32 / 16_000.0;
|
||||
let normalized_ranges = normalize_speech_ranges(speech_ranges, audio.len());
|
||||
let context = WhisperContext::new_with_params(
|
||||
model_path,
|
||||
WhisperContextParameters::default(),
|
||||
)
|
||||
.with_context(|| format!("failed to load whisper model: {model_path}"))?;
|
||||
let mut state = context.create_state().context("failed to create whisper state")?;
|
||||
let context =
|
||||
WhisperContext::new_with_params(model_path, WhisperContextParameters::default())
|
||||
.with_context(|| format!("failed to load whisper model: {model_path}"))?;
|
||||
let mut state = context
|
||||
.create_state()
|
||||
.context("failed to create whisper state")?;
|
||||
let detected_language = resolve_source_language(&mut state, &audio, source_lang)
|
||||
.context("failed to resolve source language")?;
|
||||
|
||||
@ -187,7 +187,10 @@ impl WhisperEngine {
|
||||
}
|
||||
}
|
||||
|
||||
on_log(format!("whisper: total emitted segments={}", segments.len()))?;
|
||||
on_log(format!(
|
||||
"whisper: total emitted segments={}",
|
||||
segments.len()
|
||||
))?;
|
||||
Ok(segments)
|
||||
}
|
||||
}
|
||||
@ -226,7 +229,9 @@ fn transcribe_clip(
|
||||
}
|
||||
}
|
||||
|
||||
state.full(params, clip).context("whisper inference failed")?;
|
||||
state
|
||||
.full(params, clip)
|
||||
.context("whisper inference failed")?;
|
||||
|
||||
let num_segments = state.full_n_segments();
|
||||
on_log(format!(
|
||||
@ -303,7 +308,10 @@ fn load_audio_f32(path: &Path) -> Result<Vec<f32>> {
|
||||
.with_context(|| format!("failed to open wav file: {}", path.display()))?;
|
||||
let spec = reader.spec();
|
||||
if spec.sample_rate != 16_000 {
|
||||
return Err(anyhow!("whisper expects 16k audio, got {}", spec.sample_rate));
|
||||
return Err(anyhow!(
|
||||
"whisper expects 16k audio, got {}",
|
||||
spec.sample_rate
|
||||
));
|
||||
}
|
||||
if spec.channels != 1 {
|
||||
return Err(anyhow!("whisper expects mono audio, got {}", spec.channels));
|
||||
@ -311,7 +319,11 @@ fn load_audio_f32(path: &Path) -> Result<Vec<f32>> {
|
||||
|
||||
let samples = reader
|
||||
.into_samples::<i16>()
|
||||
.map(|sample| sample.map(|value| value as f32 / i16::MAX as f32).map_err(anyhow::Error::from))
|
||||
.map(|sample| {
|
||||
sample
|
||||
.map(|value| value as f32 / i16::MAX as f32)
|
||||
.map_err(anyhow::Error::from)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
Ok(samples)
|
||||
@ -372,7 +384,9 @@ fn should_prefer_full_audio(
|
||||
full_text_len > vad_text_len + vad_text_len * 3 / 5
|
||||
|| full_audio_segments.len() > vad_segments.len() + 5
|
||||
|| full_end > vad_end + 5.0
|
||||
|| (total_seconds > 60.0 && full_end + 1.5 >= total_seconds && vad_end + 5.0 < total_seconds)
|
||||
|| (total_seconds > 60.0
|
||||
&& full_end + 1.5 >= total_seconds
|
||||
&& vad_end + 5.0 < total_seconds)
|
||||
}
|
||||
|
||||
fn resolve_source_language<'a>(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user