476 lines
15 KiB
Rust
476 lines
15 KiB
Rust
use std::time::Duration;
|
|
|
|
use anyhow::{anyhow, Context, Result};
|
|
use reqwest::Client;
|
|
use serde::{Deserialize, Serialize};
|
|
use tokio::time::sleep;
|
|
|
|
use crate::models::{SubtitleSegment, TargetLanguage, TranslationConfig};
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct ChatCompletionRequest {
|
|
model: String,
|
|
messages: Vec<ChatMessage>,
|
|
temperature: f32,
|
|
response_format: ResponseFormat,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct ChatMessage {
|
|
role: String,
|
|
content: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct ResponseFormat {
|
|
#[serde(rename = "type")]
|
|
format_type: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ChatCompletionResponse {
|
|
choices: Vec<ChatChoice>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ChatChoice {
|
|
message: ChatMessageContent,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ChatMessageContent {
|
|
content: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct TranslationResponse {
|
|
translations: Option<Vec<TranslatedRow>>,
|
|
items: Option<Vec<TranslatedRow>>,
|
|
results: Option<Vec<TranslatedRow>>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
pub(crate) struct TranslatedRow {
|
|
pub(crate) id: String,
|
|
pub(crate) text: String,
|
|
}
|
|
|
|
pub struct Translator {
|
|
client: Client,
|
|
config: TranslationConfig,
|
|
}
|
|
|
|
impl Translator {
|
|
pub fn new(config: TranslationConfig) -> Result<Self> {
|
|
let client = Client::builder()
|
|
.timeout(Duration::from_secs(60))
|
|
.build()
|
|
.context("failed to build translation client")?;
|
|
|
|
Ok(Self { client, config })
|
|
}
|
|
|
|
pub fn batch_size(&self) -> usize {
|
|
self.config.batch_size
|
|
}
|
|
|
|
pub fn context_size(&self) -> usize {
|
|
self.config.context_size
|
|
}
|
|
|
|
pub async fn translate_segments_with_progress<LF, PF, SF>(
|
|
&self,
|
|
segments: &[SubtitleSegment],
|
|
target_language: &TargetLanguage,
|
|
mut log: LF,
|
|
mut on_progress: PF,
|
|
mut emit_segment: SF,
|
|
) -> Result<Vec<SubtitleSegment>>
|
|
where
|
|
LF: FnMut(String),
|
|
PF: FnMut(f32),
|
|
SF: FnMut(SubtitleSegment),
|
|
{
|
|
let batch_size = self.config.batch_size.clamp(3, 350);
|
|
let context_size = self.config.context_size.min(5);
|
|
let mut translated = segments.to_vec();
|
|
let target_language_name = match target_language {
|
|
TargetLanguage::Zh => "简体中文",
|
|
TargetLanguage::En => "英文",
|
|
};
|
|
let total_batches = (segments.len() + batch_size - 1) / batch_size;
|
|
|
|
for (batch_index, batch_start) in (0..segments.len()).step_by(batch_size).enumerate() {
|
|
let batch_end = (batch_start + batch_size).min(segments.len());
|
|
let context_start = batch_start.saturating_sub(context_size);
|
|
let context = &segments[context_start..batch_start];
|
|
let batch = &segments[batch_start..batch_end];
|
|
log(format!(
|
|
"translation: batch {}-{}, segments={}",
|
|
batch_start + 1,
|
|
batch_end,
|
|
batch
|
|
.iter()
|
|
.map(|segment| segment.id.as_str())
|
|
.collect::<Vec<_>>()
|
|
.join(", ")
|
|
));
|
|
let batch_progress = (batch_index + 1) as f32 / total_batches.max(1) as f32;
|
|
on_progress(batch_progress);
|
|
let rows = self
|
|
.translate_batch_with_retries(context, batch, target_language_name)
|
|
.await?;
|
|
log(format!("translation: batch done, translated={}", rows.len()));
|
|
|
|
for row in rows {
|
|
if let Some(segment) = translated.iter_mut().find(|item| item.id == row.id) {
|
|
segment.translated_text = Some(row.text);
|
|
emit_segment(segment.clone());
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(translated)
|
|
}
|
|
|
|
pub(crate) async fn translate_batch_with_retries(
|
|
&self,
|
|
context: &[SubtitleSegment],
|
|
batch: &[SubtitleSegment],
|
|
target_language_name: &str,
|
|
) -> Result<Vec<TranslatedRow>> {
|
|
let mut collected = Vec::<TranslatedRow>::new();
|
|
let mut pending = batch.to_vec();
|
|
|
|
for retry in 0..3 {
|
|
if pending.is_empty() {
|
|
break;
|
|
}
|
|
|
|
if retry > 0 {
|
|
eprintln!(
|
|
"translation retry: attempt={}, missing_segments={}",
|
|
retry + 1,
|
|
pending
|
|
.iter()
|
|
.map(|segment| segment.id.as_str())
|
|
.collect::<Vec<_>>()
|
|
.join(", ")
|
|
);
|
|
}
|
|
|
|
let rows = self
|
|
.request_translation(context, &pending, target_language_name, retry)
|
|
.await?;
|
|
merge_rows(&mut collected, rows);
|
|
|
|
let translated_ids = collected
|
|
.iter()
|
|
.map(|row| row.id.as_str())
|
|
.collect::<std::collections::HashSet<_>>();
|
|
pending.retain(|segment| !translated_ids.contains(segment.id.as_str()));
|
|
}
|
|
|
|
if !pending.is_empty() {
|
|
return Err(anyhow!(
|
|
"translation missing segments after retries: {}",
|
|
pending
|
|
.iter()
|
|
.map(|segment| segment.id.as_str())
|
|
.collect::<Vec<_>>()
|
|
.join(", ")
|
|
));
|
|
}
|
|
|
|
Ok(order_rows(batch, &collected))
|
|
}
|
|
|
|
async fn request_translation(
|
|
&self,
|
|
context: &[SubtitleSegment],
|
|
batch: &[SubtitleSegment],
|
|
target_language_name: &str,
|
|
retry: usize,
|
|
) -> Result<Vec<TranslatedRow>> {
|
|
let context_text = if context.is_empty() {
|
|
"无".to_string()
|
|
} else {
|
|
context
|
|
.iter()
|
|
.map(|item| format!("{}: {}", item.id, item.source_text))
|
|
.collect::<Vec<_>>()
|
|
.join("\n")
|
|
};
|
|
|
|
let batch_text = batch
|
|
.iter()
|
|
.map(|item| format!("{}: {}", item.id, item.source_text))
|
|
.collect::<Vec<_>>()
|
|
.join("\n");
|
|
|
|
let request = ChatCompletionRequest {
|
|
model: self.config.model.clone(),
|
|
temperature: 0.2,
|
|
response_format: ResponseFormat {
|
|
format_type: "json_object".to_string(),
|
|
},
|
|
messages: vec![
|
|
ChatMessage {
|
|
role: "system".to_string(),
|
|
content: "你是专业字幕翻译助手。请保持人称、术语和语气一致,只输出 JSON。".to_string(),
|
|
},
|
|
ChatMessage {
|
|
role: "user".to_string(),
|
|
content: format!(
|
|
"{}把以下字幕翻译成{}。保持专有名词、角色称呼和上下文一致。必须逐条返回所有待翻译片段,禁止遗漏、合并、拆分或改写 id。上下文:\n{}\n\n待翻译片段:\n{}\n\n请返回 {{\"translations\":[{{\"id\":\"seg-0001\",\"text\":\"译文\"}}]}}",
|
|
retry_prompt_prefix(retry),
|
|
target_language_name,
|
|
context_text,
|
|
batch_text
|
|
),
|
|
},
|
|
],
|
|
};
|
|
|
|
let url = format!(
|
|
"{}/chat/completions",
|
|
self.config.api_base.trim_end_matches('/')
|
|
);
|
|
let request_json = serde_json::to_string_pretty(&request)
|
|
.context("failed to serialize translation request")?;
|
|
eprintln!(
|
|
"translation request url: {}\ntranslation request headers: Authorization: Bearer {}\ntranslation request body:\n{}",
|
|
url,
|
|
mask_secret(&self.config.api_key),
|
|
request_json
|
|
);
|
|
|
|
let mut last_error: Option<anyhow::Error> = None;
|
|
for attempt in 0..3 {
|
|
let response = self
|
|
.client
|
|
.post(&url)
|
|
.bearer_auth(&self.config.api_key)
|
|
.json(&request)
|
|
.send()
|
|
.await;
|
|
|
|
match response {
|
|
Ok(response) => {
|
|
let response = response.error_for_status().context("translation http error")?;
|
|
let raw_text = response.text().await.context("invalid response body")?;
|
|
eprintln!("translation raw response:\n{}", raw_text);
|
|
let payload: ChatCompletionResponse =
|
|
serde_json::from_str(&raw_text).context("invalid response body")?;
|
|
let content = payload
|
|
.choices
|
|
.first()
|
|
.ok_or_else(|| anyhow!("translation response missing choices"))?
|
|
.message
|
|
.content
|
|
.clone();
|
|
let rows = parse_translation_response(&content)
|
|
.with_context(|| format!("translation json parse failed: {}", preview(&content)))?;
|
|
return Ok(rows);
|
|
}
|
|
Err(error) => {
|
|
last_error = Some(error.into());
|
|
sleep(Duration::from_millis(500 * (attempt + 1) as u64)).await;
|
|
}
|
|
}
|
|
}
|
|
|
|
Err(last_error.unwrap_or_else(|| anyhow!("translation request failed")))
|
|
}
|
|
}
|
|
|
|
/// Extra instruction prepended to the user prompt on missing-segment retries.
///
/// The first attempt (`retry == 0`) gets no prefix. Every later attempt gets the
/// same reminder to return each pending id exactly once.
fn retry_prompt_prefix(retry: usize) -> &'static str {
    match retry {
        0 => "",
        _ => "这是补漏重试。你上次漏掉了部分片段,这次只需要返回当前待翻译片段对应的结果,确保每个 id 都出现且只出现一次。\n\n",
    }
}
|
|
|
|
/// Folds freshly parsed `rows` into `collected`, keyed by row id.
///
/// A row with an unseen id is appended as-is, even when its text is empty. For
/// an id that is already present, the stored text is replaced only when the new
/// text is not blank, so a later empty answer never erases an earlier translation.
fn merge_rows(collected: &mut Vec<TranslatedRow>, rows: Vec<TranslatedRow>) {
    for incoming in rows {
        match collected.iter().position(|known| known.id == incoming.id) {
            None => collected.push(incoming),
            Some(_) if incoming.text.trim().is_empty() => {}
            Some(slot) => collected[slot].text = incoming.text,
        }
    }
}
|
|
|
|
/// Returns the rows for `batch` in batch order, one clone per segment.
///
/// When several rows share an id, the first one in `rows` is used. Segments
/// without a matching row are skipped.
fn order_rows(batch: &[SubtitleSegment], rows: &[TranslatedRow]) -> Vec<TranslatedRow> {
    let mut ordered = Vec::with_capacity(batch.len());
    for segment in batch {
        if let Some(hit) = rows.iter().find(|candidate| candidate.id == segment.id) {
            ordered.push(hit.clone());
        }
    }
    ordered
}
|
|
|
|
fn parse_translation_response(content: &str) -> Result<Vec<TranslatedRow>> {
|
|
let candidates = [
|
|
content.trim().to_string(),
|
|
strip_code_fence(content),
|
|
strip_think_block(content),
|
|
extract_json_after_think(content).unwrap_or_default(),
|
|
extract_last_json_object(content).unwrap_or_default(),
|
|
extract_json_object(content).unwrap_or_default(),
|
|
];
|
|
|
|
for candidate in candidates {
|
|
if candidate.trim().is_empty() {
|
|
continue;
|
|
}
|
|
|
|
if let Ok(response) = serde_json::from_str::<TranslationResponse>(&candidate) {
|
|
if let Some(rows) = response
|
|
.translations
|
|
.or(response.items)
|
|
.or(response.results)
|
|
.filter(|rows| !rows.is_empty())
|
|
{
|
|
return Ok(rows);
|
|
}
|
|
}
|
|
|
|
if let Ok(rows) = serde_json::from_str::<Vec<TranslatedRow>>(&candidate) {
|
|
if !rows.is_empty() {
|
|
return Ok(rows);
|
|
}
|
|
}
|
|
|
|
let loose_rows = extract_rows_loose(&candidate);
|
|
if !loose_rows.is_empty() {
|
|
return Ok(loose_rows);
|
|
}
|
|
}
|
|
|
|
Err(anyhow!("unable to parse translation response"))
|
|
}
|
|
|
|
/// Removes a surrounding Markdown code fence from `content`, if present.
///
/// If the trimmed text starts with a triple backtick, the function does three things:
/// - it drops an info string such as `json`, `JSON` or `javascript` on the
///   opening line (or a bare `json` tag when the fence is single-line);
/// - it cuts the text at the last triple backtick, so commentary after the
///   closing fence is discarded;
/// - it trims the result.
///
/// Text that does not start with a fence is returned trimmed and otherwise unchanged.
fn strip_code_fence(content: &str) -> String {
    let trimmed = content.trim();
    let Some(after_open) = trimmed.strip_prefix("```") else {
        return trimmed.to_string();
    };

    let is_info_string =
        |info: &str| info.trim().chars().all(|c| c.is_ascii_alphanumeric() || "-_+.".contains(c));
    let body = match after_open.split_once('\n') {
        Some((info, rest)) if is_info_string(info) => rest,
        // Single-line fence (or JSON starting on the fence line): only strip a `json` tag.
        _ => match after_open.get(..4) {
            Some(tag) if tag.eq_ignore_ascii_case("json") => &after_open[4..],
            _ => after_open,
        },
    };

    let body = match body.rfind("```") {
        Some(close) => &body[..close],
        None => body,
    };
    body.trim().to_string()
}
|
|
|
|
/// Returns the span from the first `{` to the last `}` (inclusive), trimmed.
///
/// Returns `None` if either brace is missing or the last `}` comes before the
/// first `{`. Brace balance is not checked.
fn extract_json_object(content: &str) -> Option<String> {
    let open = content.find('{')?;
    let close = content.rfind('}')?;
    if close <= open {
        return None;
    }
    Some(content[open..=close].trim().to_string())
}
|
|
|
|
/// Returns the trimmed text after the last `</think>` tag.
///
/// If there is no closing tag, the whole text is returned, trimmed.
fn strip_think_block(content: &str) -> String {
    const CLOSE_TAG: &str = "</think>";
    let tail = match content.rfind(CLOSE_TAG) {
        Some(tag_at) => &content[tag_at + CLOSE_TAG.len()..],
        None => content,
    };
    tail.trim().to_string()
}
|
|
|
|
fn extract_json_after_think(content: &str) -> Option<String> {
|
|
let stripped = strip_think_block(content);
|
|
extract_last_json_object(&stripped)
|
|
}
|
|
|
|
/// Returns the last balanced `{ ... }` object in `content`, trimmed.
///
/// The scan starts at the final `}` and walks backwards, counting braces until
/// the matching `{` is reached. Braces inside JSON string literals are ignored;
/// a quote counts as escaped when an odd number of backslashes precedes it. For
/// nested payloads such as `{"translations":[{...}]}` this yields the whole
/// outer object rather than a trailing fragment.
///
/// Returns `None` when there is no `}` or it has no matching `{`.
fn extract_last_json_object(content: &str) -> Option<String> {
    let bytes = content.as_bytes();
    let end = content.rfind('}')?;
    let is_escaped = |index: usize| {
        bytes[..index].iter().rev().take_while(|&&b| b == b'\\').count() % 2 == 1
    };

    let mut depth = 0usize;
    let mut in_string = false;
    // Byte-wise scanning is UTF-8 safe: `{`, `}`, `"` and `\` never occur inside
    // multi-byte sequences, so every index we slice at is a char boundary.
    for index in (0..=end).rev() {
        match bytes[index] {
            b'"' if !is_escaped(index) => in_string = !in_string,
            _ if in_string => {}
            b'}' => depth += 1,
            b'{' => {
                depth -= 1;
                if depth == 0 {
                    return Some(content[index..=end].trim().to_string());
                }
            }
            _ => {}
        }
    }

    None
}
|
|
|
|
/// Single-line, length-limited excerpt of `content` for error messages.
///
/// Newlines become spaces and the result holds at most 240 characters (`char`s,
/// not bytes).
fn preview(content: &str) -> String {
    const MAX_CHARS: usize = 240;
    content
        .chars()
        .map(|ch| if ch == '\n' { ' ' } else { ch })
        .take(MAX_CHARS)
        .collect()
}
|
|
|
|
/// Masks an API key for logging, keeping only its first and last four characters.
///
/// Secrets of eight characters or fewer are replaced entirely by `****`.
/// Counting and slicing work on `char`s rather than bytes, so a key containing
/// non-ASCII characters cannot trigger a char-boundary slicing panic.
fn mask_secret(secret: &str) -> String {
    let char_count = secret.chars().count();
    if char_count <= 8 {
        return "****".to_string();
    }

    let head: String = secret.chars().take(4).collect();
    let tail: String = secret.chars().skip(char_count - 4).collect();
    format!("{}****{}", head, tail)
}
|
|
|
|
fn extract_rows_loose(content: &str) -> Vec<TranslatedRow> {
|
|
let mut rows = Vec::new();
|
|
let mut cursor = 0usize;
|
|
|
|
while let Some(id_key_pos) = content[cursor..].find("\"id\"") {
|
|
let id_key_pos = cursor + id_key_pos;
|
|
let Some((id, after_id)) = extract_field_value(content, id_key_pos, "id") else {
|
|
cursor = id_key_pos + 4;
|
|
continue;
|
|
};
|
|
let Some(text_key_rel) = content[after_id..].find("\"text\"") else {
|
|
cursor = after_id;
|
|
continue;
|
|
};
|
|
let text_key_pos = after_id + text_key_rel;
|
|
let Some((text, after_text)) = extract_field_value(content, text_key_pos, "text") else {
|
|
cursor = text_key_pos + 6;
|
|
continue;
|
|
};
|
|
|
|
rows.push(TranslatedRow { id, text });
|
|
cursor = after_text;
|
|
}
|
|
|
|
rows
|
|
}
|
|
|
|
fn extract_field_value(content: &str, key_pos: usize, key: &str) -> Option<(String, usize)> {
|
|
let search_start = key_pos + key.len() + 2;
|
|
let colon_rel = content[search_start..].find(':')?;
|
|
let after_colon = search_start + colon_rel + 1;
|
|
let first_quote_rel = content[after_colon..].find('"')?;
|
|
let value_start = after_colon + first_quote_rel;
|
|
let (value, next_index) = parse_json_string(content, value_start)?;
|
|
Some((value, next_index))
|
|
}
|
|
|
|
fn parse_json_string(content: &str, start_quote: usize) -> Option<(String, usize)> {
|
|
let bytes = content.as_bytes();
|
|
if *bytes.get(start_quote)? != b'"' {
|
|
return None;
|
|
}
|
|
|
|
let mut end = start_quote + 1;
|
|
let mut escaped = false;
|
|
while let Some(&byte) = bytes.get(end) {
|
|
if escaped {
|
|
escaped = false;
|
|
end += 1;
|
|
continue;
|
|
}
|
|
match byte {
|
|
b'\\' => {
|
|
escaped = true;
|
|
end += 1;
|
|
}
|
|
b'"' => {
|
|
let raw = &content[start_quote..=end];
|
|
let parsed = serde_json::from_str::<String>(raw).ok()?;
|
|
return Some((parsed, end + 1));
|
|
}
|
|
_ => end += 1,
|
|
}
|
|
}
|
|
|
|
None
|
|
}
|