crosssubtitle-ai/src-tauri/src/translate.rs
2026-05-02 16:10:27 +08:00

476 lines
15 KiB
Rust

use std::time::Duration;
use anyhow::{anyhow, Context, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tokio::time::sleep;
use crate::models::{SubtitleSegment, TargetLanguage, TranslationConfig};
/// JSON body that `Translator::request_translation` serializes and POSTs to the
/// `{api_base}/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
/// Model identifier; filled from the translator's config when the request is built.
model: String,
/// Conversation sent to the model (a system prompt followed by the user prompt).
messages: Vec<ChatMessage>,
/// Sampling temperature; the translator sends `0.2`.
temperature: f32,
/// Requested output format; the translator asks for `json_object`.
response_format: ResponseFormat,
}
/// One message of the outgoing chat conversation.
#[derive(Debug, Serialize)]
struct ChatMessage {
/// Speaker role; this file sends `"system"` and `"user"`.
role: String,
/// Prompt text of the message.
content: String,
}
/// `response_format` object of the chat-completions request.
#[derive(Debug, Serialize)]
struct ResponseFormat {
/// Serialized under the JSON key `type` (which is a Rust keyword, hence the rename).
#[serde(rename = "type")]
format_type: String,
}
/// Subset of the chat-completions response body that this module reads.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
/// Completion candidates; only the first one is consumed by the translator.
choices: Vec<ChatChoice>,
}
/// One completion candidate inside a `ChatCompletionResponse`.
#[derive(Debug, Deserialize)]
struct ChatChoice {
/// Message produced by the model for this candidate.
message: ChatMessageContent,
}
/// Message payload of a completion candidate.
#[derive(Debug, Deserialize)]
struct ChatMessageContent {
/// Raw text returned by the model; later handed to `parse_translation_response`.
content: String,
}
/// Object-shaped translation payload produced by the model.
///
/// The row list is accepted under any of three keys; the parser takes the first key
/// that is present, checking `translations`, then `items`, then `results`.
#[derive(Debug, Deserialize)]
struct TranslationResponse {
/// Rows under the key the prompt asks for.
translations: Option<Vec<TranslatedRow>>,
/// Alternative key accepted for the same row list.
items: Option<Vec<TranslatedRow>>,
/// Alternative key accepted for the same row list.
results: Option<Vec<TranslatedRow>>,
}
/// A single translated subtitle row as returned by the model.
#[derive(Debug, Clone, Deserialize)]
pub(crate) struct TranslatedRow {
/// Segment id; compared against `SubtitleSegment::id` to match rows to segments.
pub(crate) id: String,
/// Translated text for that segment.
pub(crate) text: String,
}
/// Translates subtitle segments through a chat-completions HTTP API.
pub struct Translator {
/// HTTP client used for every translation request.
client: Client,
/// Translation settings (this file reads `api_base`, `api_key`, `model`,
/// `batch_size` and `context_size`).
config: TranslationConfig,
}
impl Translator {
pub fn new(config: TranslationConfig) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(60))
.build()
.context("failed to build translation client")?;
Ok(Self { client, config })
}
pub fn batch_size(&self) -> usize {
self.config.batch_size
}
pub fn context_size(&self) -> usize {
self.config.context_size
}
pub async fn translate_segments_with_progress<LF, PF, SF>(
&self,
segments: &[SubtitleSegment],
target_language: &TargetLanguage,
mut log: LF,
mut on_progress: PF,
mut emit_segment: SF,
) -> Result<Vec<SubtitleSegment>>
where
LF: FnMut(String),
PF: FnMut(f32),
SF: FnMut(SubtitleSegment),
{
let batch_size = self.config.batch_size.clamp(3, 350);
let context_size = self.config.context_size.min(5);
let mut translated = segments.to_vec();
let target_language_name = match target_language {
TargetLanguage::Zh => "简体中文",
TargetLanguage::En => "英文",
};
let total_batches = (segments.len() + batch_size - 1) / batch_size;
for (batch_index, batch_start) in (0..segments.len()).step_by(batch_size).enumerate() {
let batch_end = (batch_start + batch_size).min(segments.len());
let context_start = batch_start.saturating_sub(context_size);
let context = &segments[context_start..batch_start];
let batch = &segments[batch_start..batch_end];
log(format!(
"translation: batch {}-{}, segments={}",
batch_start + 1,
batch_end,
batch
.iter()
.map(|segment| segment.id.as_str())
.collect::<Vec<_>>()
.join(", ")
));
let batch_progress = (batch_index + 1) as f32 / total_batches.max(1) as f32;
on_progress(batch_progress);
let rows = self
.translate_batch_with_retries(context, batch, target_language_name)
.await?;
log(format!("translation: batch done, translated={}", rows.len()));
for row in rows {
if let Some(segment) = translated.iter_mut().find(|item| item.id == row.id) {
segment.translated_text = Some(row.text);
emit_segment(segment.clone());
}
}
}
Ok(translated)
}
pub(crate) async fn translate_batch_with_retries(
&self,
context: &[SubtitleSegment],
batch: &[SubtitleSegment],
target_language_name: &str,
) -> Result<Vec<TranslatedRow>> {
let mut collected = Vec::<TranslatedRow>::new();
let mut pending = batch.to_vec();
for retry in 0..3 {
if pending.is_empty() {
break;
}
if retry > 0 {
eprintln!(
"translation retry: attempt={}, missing_segments={}",
retry + 1,
pending
.iter()
.map(|segment| segment.id.as_str())
.collect::<Vec<_>>()
.join(", ")
);
}
let rows = self
.request_translation(context, &pending, target_language_name, retry)
.await?;
merge_rows(&mut collected, rows);
let translated_ids = collected
.iter()
.map(|row| row.id.as_str())
.collect::<std::collections::HashSet<_>>();
pending.retain(|segment| !translated_ids.contains(segment.id.as_str()));
}
if !pending.is_empty() {
return Err(anyhow!(
"translation missing segments after retries: {}",
pending
.iter()
.map(|segment| segment.id.as_str())
.collect::<Vec<_>>()
.join(", ")
));
}
Ok(order_rows(batch, &collected))
}
async fn request_translation(
&self,
context: &[SubtitleSegment],
batch: &[SubtitleSegment],
target_language_name: &str,
retry: usize,
) -> Result<Vec<TranslatedRow>> {
let context_text = if context.is_empty() {
"".to_string()
} else {
context
.iter()
.map(|item| format!("{}: {}", item.id, item.source_text))
.collect::<Vec<_>>()
.join("\n")
};
let batch_text = batch
.iter()
.map(|item| format!("{}: {}", item.id, item.source_text))
.collect::<Vec<_>>()
.join("\n");
let request = ChatCompletionRequest {
model: self.config.model.clone(),
temperature: 0.2,
response_format: ResponseFormat {
format_type: "json_object".to_string(),
},
messages: vec![
ChatMessage {
role: "system".to_string(),
content: "你是专业字幕翻译助手。请保持人称、术语和语气一致,只输出 JSON。".to_string(),
},
ChatMessage {
role: "user".to_string(),
content: format!(
"{}把以下字幕翻译成{}。保持专有名词、角色称呼和上下文一致。必须逐条返回所有待翻译片段,禁止遗漏、合并、拆分或改写 id。上下文:\n{}\n\n待翻译片段:\n{}\n\n请返回 {{\"translations\":[{{\"id\":\"seg-0001\",\"text\":\"译文\"}}]}}",
retry_prompt_prefix(retry),
target_language_name,
context_text,
batch_text
),
},
],
};
let url = format!(
"{}/chat/completions",
self.config.api_base.trim_end_matches('/')
);
let request_json = serde_json::to_string_pretty(&request)
.context("failed to serialize translation request")?;
eprintln!(
"translation request url: {}\ntranslation request headers: Authorization: Bearer {}\ntranslation request body:\n{}",
url,
mask_secret(&self.config.api_key),
request_json
);
let mut last_error: Option<anyhow::Error> = None;
for attempt in 0..3 {
let response = self
.client
.post(&url)
.bearer_auth(&self.config.api_key)
.json(&request)
.send()
.await;
match response {
Ok(response) => {
let response = response.error_for_status().context("translation http error")?;
let raw_text = response.text().await.context("invalid response body")?;
eprintln!("translation raw response:\n{}", raw_text);
let payload: ChatCompletionResponse =
serde_json::from_str(&raw_text).context("invalid response body")?;
let content = payload
.choices
.first()
.ok_or_else(|| anyhow!("translation response missing choices"))?
.message
.content
.clone();
let rows = parse_translation_response(&content)
.with_context(|| format!("translation json parse failed: {}", preview(&content)))?;
return Ok(rows);
}
Err(error) => {
last_error = Some(error.into());
sleep(Duration::from_millis(500 * (attempt + 1) as u64)).await;
}
}
}
Err(last_error.unwrap_or_else(|| anyhow!("translation request failed")))
}
}
/// Returns the text placed in front of the user prompt for request round `retry`.
///
/// The first round (`0`) has no prefix; every later round gets a note telling the model
/// that this is a fill-in retry for the segments it left out.
fn retry_prompt_prefix(retry: usize) -> &'static str {
    match retry {
        0 => "",
        _ => "这是补漏重试。你上次漏掉了部分片段,这次只需要返回当前待翻译片段对应的结果,确保每个 id 都出现且只出现一次。\n\n",
    }
}
/// Folds `rows` into `collected`, keyed by row id.
///
/// A row with an unseen id is appended as is. A row whose id is already present
/// replaces the stored text, unless the new text is empty or whitespace only.
fn merge_rows(collected: &mut Vec<TranslatedRow>, rows: Vec<TranslatedRow>) {
    for incoming in rows {
        match collected.iter().position(|known| known.id == incoming.id) {
            Some(index) => {
                let has_text = !incoming.text.trim().is_empty();
                if has_text {
                    collected[index].text = incoming.text;
                }
            }
            None => collected.push(incoming),
        }
    }
}
/// Returns clones of `rows` arranged in the order of `batch`.
///
/// For each segment the first row with the same id is taken; segments without a
/// matching row are skipped, and rows whose id is not in `batch` are dropped.
fn order_rows(batch: &[SubtitleSegment], rows: &[TranslatedRow]) -> Vec<TranslatedRow> {
    let mut ordered = Vec::new();
    for segment in batch {
        let matching = rows.iter().find(|row| row.id == segment.id);
        if let Some(row) = matching {
            ordered.push(row.clone());
        }
    }
    ordered
}
/// Extracts translated rows from the raw text returned by the model.
///
/// Several views of `content` are tried in order: the trimmed text, the text without a
/// Markdown code fence, the text after a `</think>` block, and brace-delimited
/// substrings.
///
/// Parsing runs in two passes. Pass one accepts a view only if it is valid JSON: either
/// an object carrying the rows under `translations`, `items` or `results`, or a bare
/// array of rows. Pass two runs only if no view parsed strictly, and falls back to the
/// loose `"id"`/`"text"` scanner. Strict parsing of every view comes first because the
/// loose scanner pairs each id with the next text key it sees, which mis-pairs rows
/// when a well-formed payload wrapped in prose lists `"text"` before `"id"`.
///
/// # Errors
/// Returns an error when no view yields at least one row.
fn parse_translation_response(content: &str) -> Result<Vec<TranslatedRow>> {
    let raw_candidates = [
        content.trim().to_string(),
        strip_code_fence(content),
        strip_think_block(content),
        extract_json_after_think(content).unwrap_or_default(),
        extract_last_json_object(content).unwrap_or_default(),
        extract_json_object(content).unwrap_or_default(),
    ];
    let candidates: Vec<&str> = raw_candidates
        .iter()
        .map(|candidate| candidate.trim())
        .filter(|candidate| !candidate.is_empty())
        .collect();

    // Pass 1: well-formed JSON only.
    for &candidate in &candidates {
        if let Ok(response) = serde_json::from_str::<TranslationResponse>(candidate) {
            if let Some(rows) = response
                .translations
                .or(response.items)
                .or(response.results)
                .filter(|rows| !rows.is_empty())
            {
                return Ok(rows);
            }
        }
        if let Ok(rows) = serde_json::from_str::<Vec<TranslatedRow>>(candidate) {
            if !rows.is_empty() {
                return Ok(rows);
            }
        }
    }

    // Pass 2: best-effort scan for malformed or truncated output.
    for &candidate in &candidates {
        let loose_rows = extract_rows_loose(candidate);
        if !loose_rows.is_empty() {
            return Ok(loose_rows);
        }
    }

    Err(anyhow!("unable to parse translation response"))
}
/// Removes a surrounding Markdown code fence from `content`.
///
/// The input is trimmed first. If it then starts with three backticks, the opening
/// fence (optionally tagged `json` or `JSON`) and a trailing fence are stripped and
/// the remainder is trimmed again. Otherwise the trimmed input is returned unchanged.
fn strip_code_fence(content: &str) -> String {
    let body = content.trim();
    let unfenced = if body.starts_with("```") {
        body.trim_start_matches("```json")
            .trim_start_matches("```JSON")
            .trim_start_matches("```")
            .trim_end_matches("```")
            .trim()
    } else {
        body
    };
    unfenced.to_string()
}
/// Returns the substring from the first `{` to the last `}` of `content`, trimmed.
///
/// Yields `None` when either brace is missing or the last `}` comes before the first `{`.
fn extract_json_object(content: &str) -> Option<String> {
    match (content.find('{'), content.rfind('}')) {
        (Some(open), Some(close)) if close > open => {
            Some(content[open..=close].trim().to_string())
        }
        _ => None,
    }
}
/// Returns the trimmed text that follows the last `</think>` tag in `content`.
///
/// When no such tag exists, the whole input is returned trimmed.
fn strip_think_block(content: &str) -> String {
    const CLOSING_TAG: &str = "</think>";
    let visible = match content.rfind(CLOSING_TAG) {
        Some(tag_start) => &content[tag_start + CLOSING_TAG.len()..],
        None => content,
    };
    visible.trim().to_string()
}
/// Drops everything up to the last `</think>` tag, then extracts the last JSON object
/// from what remains.
fn extract_json_after_think(content: &str) -> Option<String> {
    let visible = strip_think_block(content);
    extract_last_json_object(visible.as_str())
}
/// Returns the brace-balanced object that is closed by the last matched `}` in `content`.
///
/// Braces are paired with a stack, so a nested payload such as `{"a":{"b":1}}` is
/// returned whole rather than cut at the innermost `{`. Inside an object, braces that
/// appear in double-quoted strings (with backslash escapes) are ignored. A `}` that has
/// no open partner is skipped, and a `{` that is never closed does not produce a result.
///
/// Yields `None` when `content` holds no matched pair of braces.
fn extract_last_json_object(content: &str) -> Option<String> {
    let mut open_positions: Vec<usize> = Vec::new();
    let mut last_object: Option<(usize, usize)> = None;
    let mut in_string = false;
    let mut escaped = false;

    // Every byte tested below is ASCII, so the recorded offsets are char boundaries.
    for (index, byte) in content.bytes().enumerate() {
        if in_string {
            if escaped {
                escaped = false;
            } else if byte == b'\\' {
                escaped = true;
            } else if byte == b'"' {
                in_string = false;
            }
            continue;
        }
        match byte {
            // Quotes in the prose around an object say nothing about JSON strings,
            // so string tracking starts only once an object is open.
            b'"' if !open_positions.is_empty() => in_string = true,
            b'{' => open_positions.push(index),
            b'}' => {
                if let Some(start) = open_positions.pop() {
                    last_object = Some((start, index));
                }
            }
            _ => {}
        }
    }

    last_object.map(|(start, end)| content[start..=end].to_string())
}
/// Builds a single-line excerpt of `content` for error messages: every newline becomes
/// a space and the result is cut to at most 240 characters.
fn preview(content: &str) -> String {
    content
        .chars()
        .map(|ch| if ch == '\n' { ' ' } else { ch })
        .take(240)
        .collect()
}
fn mask_secret(secret: &str) -> String {
if secret.len() <= 8 {
return "****".to_string();
}
format!("{}****{}", &secret[..4], &secret[secret.len().saturating_sub(4)..])
}
/// Best-effort extraction of `id`/`text` pairs from text that is not valid JSON.
///
/// The scanner walks `content` looking for `"id"` keys and pairs each id with the
/// `"text"` key that follows it. The text key must come before the next `"id"` key.
/// Without that bound, an entry that lacks a text field would take the translation of
/// a later entry, which would then be lost. Entries whose id or text value cannot be
/// read are skipped.
fn extract_rows_loose(content: &str) -> Vec<TranslatedRow> {
    const ID_KEY: &str = "\"id\"";
    const TEXT_KEY: &str = "\"text\"";

    let mut rows = Vec::new();
    let mut cursor = 0usize;
    while let Some(id_rel) = content[cursor..].find(ID_KEY) {
        let id_key_pos = cursor + id_rel;
        let Some((id, after_id)) = extract_field_value(content, id_key_pos, "id") else {
            cursor = id_key_pos + ID_KEY.len();
            continue;
        };
        let Some(text_rel) = content[after_id..].find(TEXT_KEY) else {
            cursor = after_id;
            continue;
        };
        let text_key_pos = after_id + text_rel;
        // The nearest text key belongs to a later entry when another id key sits in
        // between; leave it for that entry.
        let next_id_pos = content[after_id..].find(ID_KEY).map(|rel| after_id + rel);
        if matches!(next_id_pos, Some(next) if next < text_key_pos) {
            cursor = after_id;
            continue;
        }
        let Some((text, after_text)) = extract_field_value(content, text_key_pos, "text") else {
            cursor = text_key_pos + TEXT_KEY.len();
            continue;
        };
        rows.push(TranslatedRow { id, text });
        cursor = after_text;
    }
    rows
}
/// Reads the string value of the JSON field whose quoted key starts at `key_pos`.
///
/// `key` is the key name without quotes. After the key, only whitespace, a colon and
/// more whitespace may come before the opening quote of the value. Any other layout
/// (for example `"text": null`, a numeric value, or the key text appearing as a value)
/// yields `None`, so a quote further along is never mistaken for this field's value.
///
/// Returns the decoded value and the byte offset just past its closing quote.
fn extract_field_value(content: &str, key_pos: usize, key: &str) -> Option<(String, usize)> {
    // Skip the key itself plus its two surrounding quotes.
    let after_key = key_pos + key.len() + 2;
    let rest = content.get(after_key..)?;
    let after_colon = rest.trim_start().strip_prefix(':')?.trim_start();
    if !after_colon.starts_with('"') {
        return None;
    }
    // `after_colon` is a suffix of `content`, so its length gives the quote's offset.
    let value_start = content.len() - after_colon.len();
    parse_json_string(content, value_start)
}
/// Decodes the JSON string literal whose opening quote is at byte `start_quote`.
///
/// The literal ends at the first quote that is not preceded by a backslash escape; the
/// quoted slice is then decoded with `serde_json`. Returns the decoded text and the byte
/// offset just past the closing quote, or `None` when `start_quote` is not a quote, the
/// literal is unterminated, or decoding fails.
fn parse_json_string(content: &str, start_quote: usize) -> Option<(String, usize)> {
    let bytes = content.as_bytes();
    if bytes.get(start_quote).copied() != Some(b'"') {
        return None;
    }
    let mut skip_next = false;
    for (index, &byte) in bytes.iter().enumerate().skip(start_quote + 1) {
        if skip_next {
            // This byte is the target of a backslash escape; it cannot end the literal.
            skip_next = false;
            continue;
        }
        match byte {
            b'\\' => skip_next = true,
            b'"' => {
                let literal = &content[start_quote..=index];
                return serde_json::from_str::<String>(literal)
                    .ok()
                    .map(|value| (value, index + 1));
            }
            _ => {}
        }
    }
    None
}