387 lines
13 KiB
Rust
387 lines
13 KiB
Rust
use std::{
|
||
fs,
|
||
path::{Path, PathBuf},
|
||
};
|
||
|
||
use anyhow::{Context, Result};
|
||
use tauri::{Emitter, Manager, Window};
|
||
use uuid::Uuid;
|
||
|
||
use crate::{
|
||
audio::AudioPipeline,
|
||
models::{
|
||
DefaultModelPaths, ErrorEvent, LogEvent, OutputMode, ProgressEvent, ResetSegmentsEvent,
|
||
StartTaskPayload, SubtitleSegment, SubtitleTask, TaskStatus, TranslationConfig,
|
||
},
|
||
state::AppState,
|
||
subtitle::{render, SubtitleFormat},
|
||
translate::Translator,
|
||
vad::{VadConfig, VadEngine},
|
||
whisper::WhisperEngine,
|
||
};
|
||
|
||
/// Default Whisper model, relative to the app resource dir (or to `CARGO_MANIFEST_DIR` as a
/// fallback, see `resolve_default_model_path`).
const DEFAULT_WHISPER_MODEL: &str = "model/ggml-small-q5_1.bin";

/// Default Silero VAD model, resolved the same way as [`DEFAULT_WHISPER_MODEL`].
const DEFAULT_VAD_MODEL: &str = "model/silero_vad.onnx";

/// Vendored ffmpeg binary for Apple Silicon macOS, relative to the resource dir.
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
const DEFAULT_FFMPEG_BINARY: &str = "vendor/ffmpeg/macos-arm64/bin/ffmpeg";

/// Vendored ffmpeg binary for Intel macOS, relative to the resource dir.
#[cfg(all(target_os = "macos", target_arch = "x86_64"))]
const DEFAULT_FFMPEG_BINARY: &str = "vendor/ffmpeg/macos-x86_64/bin/ffmpeg";

/// Fallback for every other target, so `resolve_ffmpeg_path` (which references this constant
/// unconditionally) still compiles there.
///
/// NOTE(review): no vendored build for these targets is visible here; unless one is shipped at
/// this location, `resolve_ffmpeg_path` falls through to searching `PATH`.
#[cfg(not(all(target_os = "macos", any(target_arch = "aarch64", target_arch = "x86_64"))))]
const DEFAULT_FFMPEG_BINARY: &str = "vendor/ffmpeg/bin/ffmpeg";
|
||
|
||
pub async fn start_task(
|
||
app: tauri::AppHandle,
|
||
window: Window,
|
||
state: tauri::State<'_, AppState>,
|
||
mut payload: StartTaskPayload,
|
||
) -> Result<SubtitleTask> {
|
||
if payload.whisper_model_path.as_deref().is_none_or(str::is_empty) {
|
||
payload.whisper_model_path = resolve_default_model_path(&app, DEFAULT_WHISPER_MODEL);
|
||
}
|
||
if payload.vad_model_path.as_deref().is_none_or(str::is_empty) {
|
||
payload.vad_model_path = resolve_default_model_path(&app, DEFAULT_VAD_MODEL);
|
||
}
|
||
if payload.source_lang.as_deref().is_none_or(str::is_empty) {
|
||
payload.source_lang = Some("auto".to_string());
|
||
}
|
||
|
||
let file_path = PathBuf::from(&payload.file_path);
|
||
let task = SubtitleTask {
|
||
id: Uuid::new_v4().to_string(),
|
||
file_name: file_path
|
||
.file_name()
|
||
.and_then(|item| item.to_str())
|
||
.unwrap_or("unknown")
|
||
.to_string(),
|
||
file_path: payload.file_path.clone(),
|
||
source_lang: payload.source_lang.clone(),
|
||
target_lang: payload.target_lang.clone(),
|
||
output_mode: payload.output_mode.clone(),
|
||
bilingual_output: payload.bilingual_output,
|
||
status: TaskStatus::Queued,
|
||
progress: 0.0,
|
||
segments: Vec::new(),
|
||
error: None,
|
||
};
|
||
|
||
state.upsert_task(task.clone())?;
|
||
|
||
let task_for_spawn = task.clone();
|
||
let payload_for_spawn = payload.clone();
|
||
let app_handle = app.clone();
|
||
let app_handle_for_error = app.clone();
|
||
let window_handle = window.clone();
|
||
let task_id = task.id.clone();
|
||
|
||
tauri::async_runtime::spawn(async move {
|
||
if let Err(error) = run_pipeline(app_handle, window_handle.clone(), task_for_spawn, payload_for_spawn).await {
|
||
if let Ok(mut failed_task) = app_handle_for_error.state::<AppState>().get_task(&task_id) {
|
||
failed_task.status = TaskStatus::Failed;
|
||
failed_task.error = Some(error.to_string());
|
||
let _ = app_handle_for_error.state::<AppState>().upsert_task(failed_task);
|
||
}
|
||
let _ = emit_error(&window_handle, &task_id, &error.to_string());
|
||
}
|
||
});
|
||
|
||
Ok(task)
|
||
}
|
||
|
||
pub fn get_default_model_paths(app: &tauri::AppHandle) -> Result<DefaultModelPaths> {
|
||
let whisper_model_path = resolve_default_model_path(app, DEFAULT_WHISPER_MODEL)
|
||
.ok_or_else(|| anyhow::anyhow!("未找到内置 Whisper 模型: {}", DEFAULT_WHISPER_MODEL))?;
|
||
let vad_model_path = resolve_default_model_path(app, DEFAULT_VAD_MODEL)
|
||
.ok_or_else(|| anyhow::anyhow!("未找到内置 VAD 模型: {}", DEFAULT_VAD_MODEL))?;
|
||
|
||
Ok(DefaultModelPaths {
|
||
whisper_model_path,
|
||
vad_model_path,
|
||
})
|
||
}
|
||
|
||
fn resolve_default_model_path(app: &tauri::AppHandle, relative_path: &str) -> Option<String> {
|
||
if let Ok(resource_dir) = app.path().resource_dir() {
|
||
let bundled_path = resource_dir.join(relative_path);
|
||
if bundled_path.exists() {
|
||
return Some(bundled_path.to_string_lossy().to_string());
|
||
}
|
||
}
|
||
|
||
let local_fallback_path = Path::new(env!("CARGO_MANIFEST_DIR")).join(relative_path);
|
||
if local_fallback_path.exists() {
|
||
return Some(local_fallback_path.to_string_lossy().to_string());
|
||
}
|
||
|
||
None
|
||
}
|
||
|
||
fn resolve_ffmpeg_path(app: &tauri::AppHandle) -> Option<PathBuf> {
|
||
if let Ok(resource_dir) = app.path().resource_dir() {
|
||
let bundled_path = resource_dir.join(DEFAULT_FFMPEG_BINARY);
|
||
if bundled_path.exists() {
|
||
return Some(bundled_path);
|
||
}
|
||
}
|
||
|
||
let local_bundled_path = Path::new(env!("CARGO_MANIFEST_DIR")).join(DEFAULT_FFMPEG_BINARY);
|
||
if local_bundled_path.exists() {
|
||
return Some(local_bundled_path);
|
||
}
|
||
|
||
let path_var = std::env::var_os("PATH")?;
|
||
for directory in std::env::split_paths(&path_var) {
|
||
let candidate = directory.join("ffmpeg");
|
||
if candidate.exists() {
|
||
return Some(candidate);
|
||
}
|
||
}
|
||
|
||
None
|
||
}
|
||
|
||
async fn run_pipeline(
|
||
app: tauri::AppHandle,
|
||
window: Window,
|
||
mut task: SubtitleTask,
|
||
payload: StartTaskPayload,
|
||
) -> Result<()> {
|
||
let app_state = app.state::<AppState>();
|
||
let workspace = std::env::temp_dir().join("crosssubtitle-ai").join(&task.id);
|
||
let should_translate = matches!(payload.output_mode, OutputMode::Translate);
|
||
let ffmpeg_path = resolve_ffmpeg_path(&app)
|
||
.ok_or_else(|| anyhow::anyhow!("未找到可用 ffmpeg,请重新执行打包命令或在系统中安装 ffmpeg"))?;
|
||
|
||
set_status(&window, &app_state, &mut task, TaskStatus::Extracting, 8.0, "正在抽取音频")?;
|
||
emit_log(&window, &task.id, format!("task: input file={}", payload.file_path))?;
|
||
emit_log(&window, &task.id, format!("audio: ffmpeg={}", ffmpeg_path.display()))?;
|
||
let wav_path = AudioPipeline::extract_to_wav(&ffmpeg_path, &payload.file_path, &workspace)?;
|
||
emit_log(&window, &task.id, format!("audio: normalized wav={}", wav_path.display()))?;
|
||
|
||
set_status(&window, &app_state, &mut task, TaskStatus::VadProcessing, 22.0, "正在分析语音片段")?;
|
||
let samples = AudioPipeline::load_wav_f32(&wav_path)?;
|
||
let vad = VadEngine::new(payload.vad_model_path.clone(), VadConfig::default())?;
|
||
let speech_ranges = vad.detect_segments(&samples);
|
||
emit_log(
|
||
&window,
|
||
&task.id,
|
||
format!("vad: detected {} speech ranges", speech_ranges.len()),
|
||
)?;
|
||
|
||
set_status(&window, &app_state, &mut task, TaskStatus::Transcribing, 45.0, "正在执行 Whisper")?;
|
||
let whisper = WhisperEngine::new(payload.whisper_model_path.clone());
|
||
let task_id_for_progress = task.id.clone();
|
||
let task_id_for_segment = task.id.clone();
|
||
let task_id_for_reset = task.id.clone();
|
||
let task_id_for_log = task.id.clone();
|
||
let app_state_for_segment = app_state.clone();
|
||
let app_state_for_reset = app_state.clone();
|
||
let mut segments = whisper.infer_segments(
|
||
&wav_path,
|
||
&task.id,
|
||
task.source_lang.as_deref(),
|
||
&task.target_lang,
|
||
should_translate,
|
||
&speech_ranges,
|
||
|ratio| {
|
||
let progress = 45.0 + ratio.clamp(0.0, 1.0) * 27.0;
|
||
window.emit(
|
||
"task:progress",
|
||
ProgressEvent {
|
||
task_id: task_id_for_progress.clone(),
|
||
status: TaskStatus::Transcribing,
|
||
progress,
|
||
message: "正在执行 Whisper".to_string(),
|
||
},
|
||
)?;
|
||
Ok(())
|
||
},
|
||
|| {
|
||
if let Ok(mut current_task) = app_state_for_reset.get_task(&task_id_for_reset) {
|
||
current_task.segments.clear();
|
||
let _ = app_state_for_reset.upsert_task(current_task);
|
||
}
|
||
window.emit(
|
||
"task:segments_reset",
|
||
ResetSegmentsEvent {
|
||
task_id: task_id_for_reset.clone(),
|
||
},
|
||
)?;
|
||
Ok(())
|
||
},
|
||
|segment| {
|
||
if let Ok(mut current_task) = app_state_for_segment.get_task(&task_id_for_segment) {
|
||
upsert_segment(&mut current_task.segments, segment.clone());
|
||
let _ = app_state_for_segment.upsert_task(current_task);
|
||
}
|
||
window.emit(
|
||
"task:segment",
|
||
crate::models::SegmentEvent {
|
||
task_id: task_id_for_segment.clone(),
|
||
segment,
|
||
},
|
||
)?;
|
||
Ok(())
|
||
},
|
||
|message| emit_log(&window, &task_id_for_log, message),
|
||
)?;
|
||
|
||
task.segments = segments.clone();
|
||
app_state.upsert_task(task.clone())?;
|
||
|
||
if should_translate {
|
||
let config = payload
|
||
.translation_config
|
||
.clone()
|
||
.or_else(load_translation_config)
|
||
.ok_or_else(|| anyhow::anyhow!("翻译模式需要填写 LLM API 配置,或设置 OPENAI_API_BASE / OPENAI_API_KEY"))?;
|
||
set_status(&window, &app_state, &mut task, TaskStatus::Translating, 72.0, "正在生成译文")?;
|
||
let translator = Translator::new(config)?;
|
||
let task_id_for_translate = task.id.clone();
|
||
let app_state_for_translate = app_state.clone();
|
||
let window_for_translate = window.clone();
|
||
segments = translator
|
||
.translate_segments_with_progress(
|
||
&segments,
|
||
&task.target_lang,
|
||
|message| {
|
||
let _ = emit_log(&window_for_translate, &task_id_for_translate, message);
|
||
},
|
||
|segment| {
|
||
if let Ok(mut current_task) = app_state_for_translate.get_task(&task_id_for_translate) {
|
||
upsert_segment(&mut current_task.segments, segment.clone());
|
||
let _ = app_state_for_translate.upsert_task(current_task);
|
||
}
|
||
let _ = window_for_translate.emit(
|
||
"task:segment",
|
||
crate::models::SegmentEvent {
|
||
task_id: task_id_for_translate.clone(),
|
||
segment,
|
||
},
|
||
);
|
||
},
|
||
)
|
||
.await?;
|
||
task.segments = segments.clone();
|
||
app_state.upsert_task(task.clone())?;
|
||
|
||
for segment in segments {
|
||
window.emit(
|
||
"task:segment",
|
||
crate::models::SegmentEvent {
|
||
task_id: task.id.clone(),
|
||
segment,
|
||
},
|
||
)?;
|
||
}
|
||
}
|
||
|
||
task.status = TaskStatus::Completed;
|
||
task.progress = 100.0;
|
||
app_state.upsert_task(task.clone())?;
|
||
window.emit("task:done", task)?;
|
||
Ok(())
|
||
}
|
||
|
||
fn load_translation_config() -> Option<TranslationConfig> {
|
||
let api_base = std::env::var("OPENAI_API_BASE").ok()?;
|
||
let api_key = std::env::var("OPENAI_API_KEY").ok()?;
|
||
let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "GLM-4-Flash-250414".to_string());
|
||
Some(TranslationConfig {
|
||
api_base,
|
||
api_key,
|
||
model,
|
||
batch_size: 12,
|
||
context_size: 3,
|
||
})
|
||
}
|
||
|
||
fn set_status(
|
||
window: &Window,
|
||
state: &AppState,
|
||
task: &mut SubtitleTask,
|
||
status: TaskStatus,
|
||
progress: f32,
|
||
message: &str,
|
||
) -> Result<()> {
|
||
task.status = status.clone();
|
||
task.progress = progress;
|
||
state.upsert_task(task.clone())?;
|
||
window.emit(
|
||
"task:progress",
|
||
ProgressEvent {
|
||
task_id: task.id.clone(),
|
||
status,
|
||
progress,
|
||
message: message.to_string(),
|
||
},
|
||
)?;
|
||
Ok(())
|
||
}
|
||
|
||
/// Applies an edited segment to its task and returns the updated task.
///
/// This is a thin command wrapper over `AppState::update_segment`.
pub fn update_segment_text(state: tauri::State<'_, AppState>, segment: SubtitleSegment) -> Result<SubtitleTask> {
    let store: &AppState = &state;
    store.update_segment(segment)
}
|
||
|
||
/// Returns every task currently known to the application state.
///
/// This is a thin command wrapper over `AppState::list_tasks`.
pub fn list_tasks(state: tauri::State<'_, AppState>) -> Result<Vec<SubtitleTask>> {
    let store = state.inner();
    store.list_tasks()
}
|
||
|
||
pub fn export_task(state: tauri::State<'_, AppState>, task_id: String, format: String) -> Result<String> {
|
||
let task = state.get_task(&task_id)?;
|
||
let format = SubtitleFormat::try_from(format.as_str())?;
|
||
let content = render(&task.segments, format, task.bilingual_output);
|
||
|
||
let source_path = PathBuf::from(&task.file_path);
|
||
let stem = source_path
|
||
.file_stem()
|
||
.and_then(|item| item.to_str())
|
||
.unwrap_or("subtitle");
|
||
|
||
let output_dir = source_path
|
||
.parent()
|
||
.map(PathBuf::from)
|
||
.unwrap_or(std::env::current_dir().context("failed to get current directory")?);
|
||
fs::create_dir_all(&output_dir)?;
|
||
|
||
let output_path = output_dir.join(format!("{stem}.{}", format.extension()));
|
||
fs::write(&output_path, content)?;
|
||
|
||
Ok(output_path.display().to_string())
|
||
}
|
||
|
||
fn emit_error(window: &Window, task_id: &str, message: &str) -> Result<()> {
|
||
window.emit(
|
||
"task:error",
|
||
ErrorEvent {
|
||
task_id: task_id.to_string(),
|
||
message: message.to_string(),
|
||
},
|
||
)?;
|
||
Ok(())
|
||
}
|
||
|
||
fn emit_log(window: &Window, task_id: &str, message: String) -> Result<()> {
|
||
window.emit(
|
||
"task:log",
|
||
LogEvent {
|
||
task_id: task_id.to_string(),
|
||
message,
|
||
},
|
||
)?;
|
||
Ok(())
|
||
}
|
||
|
||
/// Inserts `segment` into `segments`, replacing the entry with the same `id` if present.
///
/// After the insert, the list is kept ordered by `start`, then `end`, then `id`.
///
/// Times are compared with `total_cmp`, which is a total order even when a value is NaN. A
/// `partial_cmp(..).unwrap_or(Equal)` comparator is not transitive once a NaN appears, and
/// `sort_by` may panic on comparators that violate total order (Rust >= 1.81).
fn upsert_segment(segments: &mut Vec<SubtitleSegment>, segment: SubtitleSegment) {
    match segments.iter_mut().find(|item| item.id == segment.id) {
        Some(existing) => *existing = segment,
        None => segments.push(segment),
    }

    segments.sort_by(|left, right| {
        left.start
            .total_cmp(&right.start)
            .then_with(|| left.end.total_cmp(&right.end))
            .then_with(|| left.id.cmp(&right.id))
    });
}
|