crosssubtitle-ai/src-tauri/src/task.rs
2026-03-19 15:37:19 +08:00

387 lines
13 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::{
fs,
path::{Path, PathBuf},
};
use anyhow::{Context, Result};
use tauri::{Emitter, Manager, Window};
use uuid::Uuid;
use crate::{
audio::AudioPipeline,
models::{
DefaultModelPaths, ErrorEvent, LogEvent, OutputMode, ProgressEvent, ResetSegmentsEvent,
StartTaskPayload, SubtitleSegment, SubtitleTask, TaskStatus, TranslationConfig,
},
state::AppState,
subtitle::{render, SubtitleFormat},
translate::Translator,
vad::{VadConfig, VadEngine},
whisper::WhisperEngine,
};
// Bundled model files, relative to the Tauri resource directory (with a
// CARGO_MANIFEST_DIR fallback for local development — see resolve_default_model_path).
const DEFAULT_WHISPER_MODEL: &str = "model/ggml-small-q5_1.bin";
const DEFAULT_VAD_MODEL: &str = "model/silero_vad.onnx";
// Platform-specific bundled ffmpeg binary. Only macOS variants are declared,
// so other targets will not compile unless another cfg variant is added.
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
const DEFAULT_FFMPEG_BINARY: &str = "vendor/ffmpeg/macos-arm64/bin/ffmpeg";
#[cfg(all(target_os = "macos", target_arch = "x86_64"))]
const DEFAULT_FFMPEG_BINARY: &str = "vendor/ffmpeg/macos-x86_64/bin/ffmpeg";
pub async fn start_task(
app: tauri::AppHandle,
window: Window,
state: tauri::State<'_, AppState>,
mut payload: StartTaskPayload,
) -> Result<SubtitleTask> {
if payload.whisper_model_path.as_deref().is_none_or(str::is_empty) {
payload.whisper_model_path = resolve_default_model_path(&app, DEFAULT_WHISPER_MODEL);
}
if payload.vad_model_path.as_deref().is_none_or(str::is_empty) {
payload.vad_model_path = resolve_default_model_path(&app, DEFAULT_VAD_MODEL);
}
if payload.source_lang.as_deref().is_none_or(str::is_empty) {
payload.source_lang = Some("auto".to_string());
}
let file_path = PathBuf::from(&payload.file_path);
let task = SubtitleTask {
id: Uuid::new_v4().to_string(),
file_name: file_path
.file_name()
.and_then(|item| item.to_str())
.unwrap_or("unknown")
.to_string(),
file_path: payload.file_path.clone(),
source_lang: payload.source_lang.clone(),
target_lang: payload.target_lang.clone(),
output_mode: payload.output_mode.clone(),
bilingual_output: payload.bilingual_output,
status: TaskStatus::Queued,
progress: 0.0,
segments: Vec::new(),
error: None,
};
state.upsert_task(task.clone())?;
let task_for_spawn = task.clone();
let payload_for_spawn = payload.clone();
let app_handle = app.clone();
let app_handle_for_error = app.clone();
let window_handle = window.clone();
let task_id = task.id.clone();
tauri::async_runtime::spawn(async move {
if let Err(error) = run_pipeline(app_handle, window_handle.clone(), task_for_spawn, payload_for_spawn).await {
if let Ok(mut failed_task) = app_handle_for_error.state::<AppState>().get_task(&task_id) {
failed_task.status = TaskStatus::Failed;
failed_task.error = Some(error.to_string());
let _ = app_handle_for_error.state::<AppState>().upsert_task(failed_task);
}
let _ = emit_error(&window_handle, &task_id, &error.to_string());
}
});
Ok(task)
}
/// Resolves the bundled Whisper and VAD model paths.
///
/// # Errors
/// Fails when either bundled model file cannot be found in the resource
/// directory or the local development fallback.
pub fn get_default_model_paths(app: &tauri::AppHandle) -> Result<DefaultModelPaths> {
    let whisper_model_path = match resolve_default_model_path(app, DEFAULT_WHISPER_MODEL) {
        Some(path) => path,
        None => anyhow::bail!("未找到内置 Whisper 模型: {}", DEFAULT_WHISPER_MODEL),
    };
    let vad_model_path = match resolve_default_model_path(app, DEFAULT_VAD_MODEL) {
        Some(path) => path,
        None => anyhow::bail!("未找到内置 VAD 模型: {}", DEFAULT_VAD_MODEL),
    };
    Ok(DefaultModelPaths {
        whisper_model_path,
        vad_model_path,
    })
}
/// Locates a bundled resource file, preferring the packaged Tauri resource
/// directory and falling back to the crate directory for development builds.
/// Returns the first existing candidate as a lossy UTF-8 string, or `None`.
fn resolve_default_model_path(app: &tauri::AppHandle, relative_path: &str) -> Option<String> {
    let packaged = app
        .path()
        .resource_dir()
        .ok()
        .map(|dir| dir.join(relative_path));
    let dev_fallback = Path::new(env!("CARGO_MANIFEST_DIR")).join(relative_path);
    [packaged, Some(dev_fallback)]
        .into_iter()
        .flatten()
        .find(|candidate| candidate.exists())
        .map(|candidate| candidate.to_string_lossy().into_owned())
}
/// Finds a usable ffmpeg binary: the packaged resource copy first, then the
/// repo-local vendored copy (dev builds), and finally any `ffmpeg` on PATH.
fn resolve_ffmpeg_path(app: &tauri::AppHandle) -> Option<PathBuf> {
    let packaged = app
        .path()
        .resource_dir()
        .ok()
        .map(|dir| dir.join(DEFAULT_FFMPEG_BINARY))
        .filter(|path| path.exists());
    if packaged.is_some() {
        return packaged;
    }
    let vendored = Path::new(env!("CARGO_MANIFEST_DIR")).join(DEFAULT_FFMPEG_BINARY);
    if vendored.exists() {
        return Some(vendored);
    }
    // Last resort: scan every PATH entry for an `ffmpeg` executable.
    std::env::split_paths(&std::env::var_os("PATH")?)
        .map(|directory| directory.join("ffmpeg"))
        .find(|candidate| candidate.exists())
}
/// Runs the full subtitle pipeline for one task: audio extraction → VAD →
/// Whisper transcription → (optional) LLM translation, emitting progress,
/// log, and segment events to `window` and persisting task snapshots in
/// `AppState` after each stage.
///
/// # Errors
/// Returns the first failure from any stage; the spawned future in
/// `start_task` records it on the task and emits `task:error`.
async fn run_pipeline(
    app: tauri::AppHandle,
    window: Window,
    mut task: SubtitleTask,
    payload: StartTaskPayload,
) -> Result<()> {
    let app_state = app.state::<AppState>();
    // Per-task scratch directory for intermediate audio files.
    let workspace = std::env::temp_dir().join("crosssubtitle-ai").join(&task.id);
    let should_translate = matches!(payload.output_mode, OutputMode::Translate);
    let ffmpeg_path = resolve_ffmpeg_path(&app)
        .ok_or_else(|| anyhow::anyhow!("未找到可用 ffmpeg请重新执行打包命令或在系统中安装 ffmpeg"))?;
    // Stage 1: extract and normalize audio (progress pinned at 8%).
    set_status(&window, &app_state, &mut task, TaskStatus::Extracting, 8.0, "正在抽取音频")?;
    emit_log(&window, &task.id, format!("task: input file={}", payload.file_path))?;
    emit_log(&window, &task.id, format!("audio: ffmpeg={}", ffmpeg_path.display()))?;
    let wav_path = AudioPipeline::extract_to_wav(&ffmpeg_path, &payload.file_path, &workspace)?;
    emit_log(&window, &task.id, format!("audio: normalized wav={}", wav_path.display()))?;
    // Stage 2: voice-activity detection over the normalized samples (22%).
    set_status(&window, &app_state, &mut task, TaskStatus::VadProcessing, 22.0, "正在分析语音片段")?;
    let samples = AudioPipeline::load_wav_f32(&wav_path)?;
    let vad = VadEngine::new(payload.vad_model_path.clone(), VadConfig::default())?;
    let speech_ranges = vad.detect_segments(&samples);
    emit_log(
        &window,
        &task.id,
        format!("vad: detected {} speech ranges", speech_ranges.len()),
    )?;
    // Stage 3: Whisper transcription; progress interpolates over 45%–72%.
    set_status(&window, &app_state, &mut task, TaskStatus::Transcribing, 45.0, "正在执行 Whisper")?;
    let whisper = WhisperEngine::new(payload.whisper_model_path.clone());
    // Each callback below is a separate closure handed to the engine, so each
    // needs its own owned copy of the id / state handle.
    let task_id_for_progress = task.id.clone();
    let task_id_for_segment = task.id.clone();
    let task_id_for_reset = task.id.clone();
    let task_id_for_log = task.id.clone();
    let app_state_for_segment = app_state.clone();
    let app_state_for_reset = app_state.clone();
    let mut segments = whisper.infer_segments(
        &wav_path,
        &task.id,
        task.source_lang.as_deref(),
        &task.target_lang,
        should_translate,
        &speech_ranges,
        // Progress callback: map the engine's 0..1 ratio onto the 45–72 band.
        |ratio| {
            let progress = 45.0 + ratio.clamp(0.0, 1.0) * 27.0;
            window.emit(
                "task:progress",
                ProgressEvent {
                    task_id: task_id_for_progress.clone(),
                    status: TaskStatus::Transcribing,
                    progress,
                    message: "正在执行 Whisper".to_string(),
                },
            )?;
            Ok(())
        },
        // Reset callback: clear stored segments and tell the UI to drop its list.
        || {
            if let Ok(mut current_task) = app_state_for_reset.get_task(&task_id_for_reset) {
                current_task.segments.clear();
                let _ = app_state_for_reset.upsert_task(current_task);
            }
            window.emit(
                "task:segments_reset",
                ResetSegmentsEvent {
                    task_id: task_id_for_reset.clone(),
                },
            )?;
            Ok(())
        },
        // Segment callback: persist each new/updated segment, then stream it to the UI.
        |segment| {
            if let Ok(mut current_task) = app_state_for_segment.get_task(&task_id_for_segment) {
                upsert_segment(&mut current_task.segments, segment.clone());
                let _ = app_state_for_segment.upsert_task(current_task);
            }
            window.emit(
                "task:segment",
                crate::models::SegmentEvent {
                    task_id: task_id_for_segment.clone(),
                    segment,
                },
            )?;
            Ok(())
        },
        // Log callback: forward engine log lines to the frontend.
        |message| emit_log(&window, &task_id_for_log, message),
    )?;
    task.segments = segments.clone();
    app_state.upsert_task(task.clone())?;
    if should_translate {
        // Config precedence: explicit payload config, then OPENAI_* env vars.
        let config = payload
            .translation_config
            .clone()
            .or_else(load_translation_config)
            .ok_or_else(|| anyhow::anyhow!("翻译模式需要填写 LLM API 配置,或设置 OPENAI_API_BASE / OPENAI_API_KEY"))?;
        // Stage 4: LLM translation (72%).
        set_status(&window, &app_state, &mut task, TaskStatus::Translating, 72.0, "正在生成译文")?;
        let translator = Translator::new(config)?;
        let task_id_for_translate = task.id.clone();
        let app_state_for_translate = app_state.clone();
        let window_for_translate = window.clone();
        segments = translator
            .translate_segments_with_progress(
                &segments,
                &task.target_lang,
                // Per-batch log messages; delivery failures are deliberately ignored.
                |message| {
                    let _ = emit_log(&window_for_translate, &task_id_for_translate, message);
                },
                // Stream each translated segment to state and the UI as it completes.
                |segment| {
                    if let Ok(mut current_task) = app_state_for_translate.get_task(&task_id_for_translate) {
                        upsert_segment(&mut current_task.segments, segment.clone());
                        let _ = app_state_for_translate.upsert_task(current_task);
                    }
                    let _ = window_for_translate.emit(
                        "task:segment",
                        crate::models::SegmentEvent {
                            task_id: task_id_for_translate.clone(),
                            segment,
                        },
                    );
                },
            )
            .await?;
        task.segments = segments.clone();
        app_state.upsert_task(task.clone())?;
        // NOTE(review): segments were already streamed individually above, so
        // this full re-emit looks redundant — confirm whether the UI depends
        // on it before removing.
        for segment in segments {
            window.emit(
                "task:segment",
                crate::models::SegmentEvent {
                    task_id: task.id.clone(),
                    segment,
                },
            )?;
        }
    }
    task.status = TaskStatus::Completed;
    task.progress = 100.0;
    app_state.upsert_task(task.clone())?;
    // `task:done` carries the final task snapshot (consumes `task`).
    window.emit("task:done", task)?;
    Ok(())
}
/// Builds a fallback translation config from environment variables.
///
/// Returns `None` unless both `OPENAI_API_BASE` and `OPENAI_API_KEY` are set;
/// `OPENAI_MODEL` is optional and defaults to "GLM-4-Flash-250414".
fn load_translation_config() -> Option<TranslationConfig> {
    let (api_base, api_key) = match (
        std::env::var("OPENAI_API_BASE"),
        std::env::var("OPENAI_API_KEY"),
    ) {
        (Ok(base), Ok(key)) => (base, key),
        _ => return None,
    };
    let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "GLM-4-Flash-250414".to_string());
    Some(TranslationConfig {
        api_base,
        api_key,
        model,
        batch_size: 12,
        context_size: 3,
    })
}
/// Updates the task's status and progress, persists the snapshot in shared
/// state, and notifies the frontend via a `task:progress` event.
fn set_status(
    window: &Window,
    state: &AppState,
    task: &mut SubtitleTask,
    status: TaskStatus,
    progress: f32,
    message: &str,
) -> Result<()> {
    task.status = status.clone();
    task.progress = progress;
    state.upsert_task(task.clone())?;
    let event = ProgressEvent {
        task_id: task.id.clone(),
        status,
        progress,
        message: message.to_string(),
    };
    window.emit("task:progress", event)?;
    Ok(())
}
/// Applies an edited segment (e.g. user-corrected text) to its owning task in
/// shared state and returns the updated task snapshot.
pub fn update_segment_text(state: tauri::State<'_, AppState>, segment: SubtitleSegment) -> Result<SubtitleTask> {
    state.update_segment(segment)
}
/// Returns a snapshot of all known subtitle tasks from shared state.
pub fn list_tasks(state: tauri::State<'_, AppState>) -> Result<Vec<SubtitleTask>> {
    state.list_tasks()
}
/// Renders the task's segments in the requested subtitle `format` and writes
/// the result next to the source media file, falling back to the current
/// working directory when the source path has no usable parent.
///
/// Returns the path of the written subtitle file as a display string.
///
/// # Errors
/// Fails if the task is unknown, the format string is unrecognized, the
/// output directory cannot be created, or the file cannot be written.
pub fn export_task(state: tauri::State<'_, AppState>, task_id: String, format: String) -> Result<String> {
    let task = state.get_task(&task_id)?;
    let format = SubtitleFormat::try_from(format.as_str())?;
    let content = render(&task.segments, format, task.bilingual_output);
    let source_path = PathBuf::from(&task.file_path);
    let stem = source_path
        .file_stem()
        .and_then(|item| item.to_str())
        .unwrap_or("subtitle");
    // Query the current directory lazily: the previous `unwrap_or(...?)` form
    // called `current_dir()` eagerly and could fail the export via `?` even
    // when the source file already had a usable parent directory. The empty
    // check also routes bare relative paths (parent == "") to the cwd.
    let output_dir = match source_path.parent() {
        Some(parent) if !parent.as_os_str().is_empty() => parent.to_path_buf(),
        _ => std::env::current_dir().context("failed to get current directory")?,
    };
    fs::create_dir_all(&output_dir)?;
    let output_path = output_dir.join(format!("{stem}.{}", format.extension()));
    fs::write(&output_path, content)?;
    Ok(output_path.display().to_string())
}
/// Sends a `task:error` event so the frontend can surface the failure.
fn emit_error(window: &Window, task_id: &str, message: &str) -> Result<()> {
    let event = ErrorEvent {
        task_id: task_id.to_owned(),
        message: message.to_owned(),
    };
    window.emit("task:error", event)?;
    Ok(())
}
/// Forwards a log line to the frontend via a `task:log` event.
fn emit_log(window: &Window, task_id: &str, message: String) -> Result<()> {
    let event = LogEvent {
        task_id: task_id.to_owned(),
        message,
    };
    window.emit("task:log", event)?;
    Ok(())
}
/// Inserts `segment` into `segments`, replacing any existing entry with the
/// same id, then keeps the list ordered by (start, end, id).
fn upsert_segment(segments: &mut Vec<SubtitleSegment>, segment: SubtitleSegment) {
    match segments.iter().position(|item| item.id == segment.id) {
        Some(index) => segments[index] = segment,
        None => segments.push(segment),
    }
    // start/end only support partial ordering, so treat incomparable values
    // as equal and break remaining ties deterministically by id.
    segments.sort_by(|a, b| {
        let by_start = a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal);
        let by_end = a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal);
        by_start.then(by_end).then_with(|| a.id.cmp(&b.id))
    });
}