use std::time::Duration; use anyhow::{anyhow, Context, Result}; use reqwest::Client; use serde::{Deserialize, Serialize}; use tokio::time::sleep; use crate::models::{SubtitleSegment, TargetLanguage, TranslationConfig}; #[derive(Debug, Serialize)] struct ChatCompletionRequest { model: String, messages: Vec, temperature: f32, response_format: ResponseFormat, } #[derive(Debug, Serialize)] struct ChatMessage { role: String, content: String, } #[derive(Debug, Serialize)] struct ResponseFormat { #[serde(rename = "type")] format_type: String, } #[derive(Debug, Deserialize)] struct ChatCompletionResponse { choices: Vec, } #[derive(Debug, Deserialize)] struct ChatChoice { message: ChatMessageContent, } #[derive(Debug, Deserialize)] struct ChatMessageContent { content: String, } #[derive(Debug, Deserialize)] struct TranslationResponse { translations: Option>, items: Option>, results: Option>, } #[derive(Debug, Clone, Deserialize)] pub(crate) struct TranslatedRow { pub(crate) id: String, pub(crate) text: String, } pub struct Translator { client: Client, config: TranslationConfig, } impl Translator { pub fn new(config: TranslationConfig) -> Result { let client = Client::builder() .timeout(Duration::from_secs(60)) .build() .context("failed to build translation client")?; Ok(Self { client, config }) } pub fn batch_size(&self) -> usize { self.config.batch_size } pub fn context_size(&self) -> usize { self.config.context_size } pub async fn translate_segments_with_progress( &self, segments: &[SubtitleSegment], target_language: &TargetLanguage, mut log: LF, mut on_progress: PF, mut emit_segment: SF, ) -> Result> where LF: FnMut(String), PF: FnMut(f32), SF: FnMut(SubtitleSegment), { let batch_size = self.config.batch_size.clamp(3, 350); let context_size = self.config.context_size.min(5); let mut translated = segments.to_vec(); let target_language_name = match target_language { TargetLanguage::Zh => "简体中文", TargetLanguage::En => "英文", }; let total_batches = (segments.len() + batch_size - 1) / batch_size; for (batch_index, batch_start) in (0..segments.len()).step_by(batch_size).enumerate() { let batch_end = (batch_start + batch_size).min(segments.len()); let context_start = batch_start.saturating_sub(context_size); let context = &segments[context_start..batch_start]; let batch = &segments[batch_start..batch_end]; log(format!( "translation: batch {}-{}, segments={}", batch_start + 1, batch_end, batch .iter() .map(|segment| segment.id.as_str()) .collect::>() .join(", ") )); let batch_progress = (batch_index + 1) as f32 / total_batches.max(1) as f32; on_progress(batch_progress); let rows = self .translate_batch_with_retries(context, batch, target_language_name) .await?; log(format!("translation: batch done, translated={}", rows.len())); for row in rows { if let Some(segment) = translated.iter_mut().find(|item| item.id == row.id) { segment.translated_text = Some(row.text); emit_segment(segment.clone()); } } } Ok(translated) } pub(crate) async fn translate_batch_with_retries( &self, context: &[SubtitleSegment], batch: &[SubtitleSegment], target_language_name: &str, ) -> Result> { let mut collected = Vec::::new(); let mut pending = batch.to_vec(); for retry in 0..3 { if pending.is_empty() { break; } if retry > 0 { eprintln!( "translation retry: attempt={}, missing_segments={}", retry + 1, pending .iter() .map(|segment| segment.id.as_str()) .collect::>() .join(", ") ); } let rows = self .request_translation(context, &pending, target_language_name, retry) .await?; merge_rows(&mut collected, rows); let translated_ids = collected .iter() .map(|row| row.id.as_str()) .collect::>(); pending.retain(|segment| !translated_ids.contains(segment.id.as_str())); } if !pending.is_empty() { return Err(anyhow!( "translation missing segments after retries: {}", pending .iter() .map(|segment| segment.id.as_str()) .collect::>() .join(", ") )); } Ok(order_rows(batch, &collected)) } async fn request_translation( &self, context: &[SubtitleSegment], batch: &[SubtitleSegment], target_language_name: &str, retry: usize, ) -> Result> { let context_text = if context.is_empty() { "无".to_string() } else { context .iter() .map(|item| format!("{}: {}", item.id, item.source_text)) .collect::>() .join("\n") }; let batch_text = batch .iter() .map(|item| format!("{}: {}", item.id, item.source_text)) .collect::>() .join("\n"); let request = ChatCompletionRequest { model: self.config.model.clone(), temperature: 0.2, response_format: ResponseFormat { format_type: "json_object".to_string(), }, messages: vec![ ChatMessage { role: "system".to_string(), content: "你是专业字幕翻译助手。请保持人称、术语和语气一致,只输出 JSON。".to_string(), }, ChatMessage { role: "user".to_string(), content: format!( "{}把以下字幕翻译成{}。保持专有名词、角色称呼和上下文一致。必须逐条返回所有待翻译片段,禁止遗漏、合并、拆分或改写 id。上下文:\n{}\n\n待翻译片段:\n{}\n\n请返回 {{\"translations\":[{{\"id\":\"seg-0001\",\"text\":\"译文\"}}]}}", retry_prompt_prefix(retry), target_language_name, context_text, batch_text ), }, ], }; let url = format!( "{}/chat/completions", self.config.api_base.trim_end_matches('/') ); let request_json = serde_json::to_string_pretty(&request) .context("failed to serialize translation request")?; eprintln!( "translation request url: {}\ntranslation request headers: Authorization: Bearer {}\ntranslation request body:\n{}", url, mask_secret(&self.config.api_key), request_json ); let mut last_error: Option = None; for attempt in 0..3 { let response = self .client .post(&url) .bearer_auth(&self.config.api_key) .json(&request) .send() .await; match response { Ok(response) => { let response = response.error_for_status().context("translation http error")?; let raw_text = response.text().await.context("invalid response body")?; eprintln!("translation raw response:\n{}", raw_text); let payload: ChatCompletionResponse = serde_json::from_str(&raw_text).context("invalid response body")?; let content = payload .choices .first() .ok_or_else(|| anyhow!("translation response missing choices"))? .message .content .clone(); let rows = parse_translation_response(&content) .with_context(|| format!("translation json parse failed: {}", preview(&content)))?; return Ok(rows); } Err(error) => { last_error = Some(error.into()); sleep(Duration::from_millis(500 * (attempt + 1) as u64)).await; } } } Err(last_error.unwrap_or_else(|| anyhow!("translation request failed"))) } } fn retry_prompt_prefix(retry: usize) -> &'static str { if retry == 0 { "" } else { "这是补漏重试。你上次漏掉了部分片段,这次只需要返回当前待翻译片段对应的结果,确保每个 id 都出现且只出现一次。\n\n" } } fn merge_rows(collected: &mut Vec, rows: Vec) { for row in rows { if let Some(existing) = collected.iter_mut().find(|item| item.id == row.id) { if !row.text.trim().is_empty() { existing.text = row.text; } } else { collected.push(row); } } } fn order_rows(batch: &[SubtitleSegment], rows: &[TranslatedRow]) -> Vec { batch .iter() .filter_map(|segment| rows.iter().find(|row| row.id == segment.id).cloned()) .collect() } fn parse_translation_response(content: &str) -> Result> { let candidates = [ content.trim().to_string(), strip_code_fence(content), strip_think_block(content), extract_json_after_think(content).unwrap_or_default(), extract_last_json_object(content).unwrap_or_default(), extract_json_object(content).unwrap_or_default(), ]; for candidate in candidates { if candidate.trim().is_empty() { continue; } if let Ok(response) = serde_json::from_str::(&candidate) { if let Some(rows) = response .translations .or(response.items) .or(response.results) .filter(|rows| !rows.is_empty()) { return Ok(rows); } } if let Ok(rows) = serde_json::from_str::>(&candidate) { if !rows.is_empty() { return Ok(rows); } } let loose_rows = extract_rows_loose(&candidate); if !loose_rows.is_empty() { return Ok(loose_rows); } } Err(anyhow!("unable to parse translation response")) } fn strip_code_fence(content: &str) -> String { let trimmed = content.trim(); if !trimmed.starts_with("```") { return trimmed.to_string(); } let without_prefix = trimmed .trim_start_matches("```json") .trim_start_matches("```JSON") .trim_start_matches("```"); without_prefix .trim_end_matches("```") .trim() .to_string() } fn extract_json_object(content: &str) -> Option { let start = content.find('{')?; let end = content.rfind('}')?; (end > start).then(|| content[start..=end].trim().to_string()) } fn strip_think_block(content: &str) -> String { if let Some(end) = content.rfind("") { return content[end + "".len()..].trim().to_string(); } content.trim().to_string() } fn extract_json_after_think(content: &str) -> Option { let stripped = strip_think_block(content); extract_last_json_object(&stripped) } fn extract_last_json_object(content: &str) -> Option { let end = content.rfind('}')?; let start = content[..=end].rfind('{')?; (end > start).then(|| content[start..=end].trim().to_string()) } fn preview(content: &str) -> String { let compact = content.replace('\n', " "); compact.chars().take(240).collect() } fn mask_secret(secret: &str) -> String { if secret.len() <= 8 { return "****".to_string(); } format!("{}****{}", &secret[..4], &secret[secret.len().saturating_sub(4)..]) } fn extract_rows_loose(content: &str) -> Vec { let mut rows = Vec::new(); let mut cursor = 0usize; while let Some(id_key_pos) = content[cursor..].find("\"id\"") { let id_key_pos = cursor + id_key_pos; let Some((id, after_id)) = extract_field_value(content, id_key_pos, "id") else { cursor = id_key_pos + 4; continue; }; let Some(text_key_rel) = content[after_id..].find("\"text\"") else { cursor = after_id; continue; }; let text_key_pos = after_id + text_key_rel; let Some((text, after_text)) = extract_field_value(content, text_key_pos, "text") else { cursor = text_key_pos + 6; continue; }; rows.push(TranslatedRow { id, text }); cursor = after_text; } rows } fn extract_field_value(content: &str, key_pos: usize, key: &str) -> Option<(String, usize)> { let search_start = key_pos + key.len() + 2; let colon_rel = content[search_start..].find(':')?; let after_colon = search_start + colon_rel + 1; let first_quote_rel = content[after_colon..].find('"')?; let value_start = after_colon + first_quote_rel; let (value, next_index) = parse_json_string(content, value_start)?; Some((value, next_index)) } fn parse_json_string(content: &str, start_quote: usize) -> Option<(String, usize)> { let bytes = content.as_bytes(); if *bytes.get(start_quote)? != b'"' { return None; } let mut end = start_quote + 1; let mut escaped = false; while let Some(&byte) = bytes.get(end) { if escaped { escaped = false; end += 1; continue; } match byte { b'\\' => { escaped = true; end += 1; } b'"' => { let raw = &content[start_quote..=end]; let parsed = serde_json::from_str::(raw).ok()?; return Some((parsed, end + 1)); } _ => end += 1, } } None }