//! OpenAI Chat Completions API (/v1/chat/completions) handler. use axum::{ extract::State, http::StatusCode, response::{sse::Event, IntoResponse, Json, Sse}, }; use rand::Rng; use std::sync::Arc; use tracing::{debug, info, warn}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::polling::{ extract_response_text, extract_thinking_content, is_response_done, poll_for_response, }; use super::types::*; use super::util::{err_response, now_unix, upstream_err_response}; use super::AppState; use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound}; /// System fingerprint for completions responses (derived from crate version at compile time). fn system_fingerprint() -> String { format!("fp_{}", env!("CARGO_PKG_VERSION").replace('.', "")) } /// Build a streaming chunk JSON with all required OpenAI fields. /// Includes system_fingerprint, service_tier, and logprobs:null in choices. fn chunk_json( id: &str, model: &str, choices: serde_json::Value, usage: Option, ) -> String { let mut obj = serde_json::json!({ "id": id, "object": "chat.completion.chunk", "created": now_unix(), "model": model, "system_fingerprint": system_fingerprint(), "service_tier": "default", "choices": choices, }); if let Some(u) = usage { obj["usage"] = u; } serde_json::to_string(&obj).unwrap_or_default() } /// Build a single choice for a streaming chunk (delta + finish_reason + logprobs). fn chunk_choice( index: u32, delta: serde_json::Value, finish_reason: Option<&str>, ) -> serde_json::Value { serde_json::json!({ "index": index, "delta": delta, "logprobs": serde_json::Value::Null, "finish_reason": finish_reason, }) } // ─── Finish reason mapping ─────────────────────────────────────────────────── /// Map Google's finishReason → OpenAI's finish_reason. /// Google: STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER, BLOCKLIST, PROHIBITED_CONTENT /// OpenAI: stop, length, content_filter, tool_calls (handled separately) fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str { match stop_reason { Some("MAX_TOKENS") => "length", Some("SAFETY") | Some("RECITATION") | Some("BLOCKLIST") | Some("PROHIBITED_CONTENT") => { "content_filter" } _ => "stop", } } // ─── Input extraction ──────────────────────────────────────────────────────── /// Extract user text from Chat Completions messages array. /// /// Builds the full conversation context including all messages (system, user, /// assistant, tool) so the model has complete history — matching how OpenAI /// sends the entire messages array to the model. fn extract_chat_input(messages: &[CompletionMessage]) -> (String, Option) { // Extract image from last user message content array let image = messages .iter() .rev() .find(|m| m.role == "user") .and_then(|m| super::util::extract_first_image(&m.content)); // Always build the full conversation (build_conversation_with_tools(messages), image) } /// Extract text content from a message's content field (string or array). fn extract_message_text(content: &serde_json::Value) -> String { match content { serde_json::Value::String(s) => s.clone(), serde_json::Value::Array(arr) => arr .iter() .filter_map(|item| item["text"].as_str()) .collect::>() .join("\n"), _ => String::new(), } } /// Build conversation text that includes tool call results. /// /// Format: /// [system prompt] /// [user message] /// [assistant called tool X with args Y] /// [tool result: Z] /// [user followup if any] fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String { let mut parts = Vec::new(); for msg in messages { match msg.role.as_str() { "system" | "developer" => { let text = extract_message_text(&msg.content); if !text.is_empty() { parts.push(text); } } "user" => { let text = extract_message_text(&msg.content); if !text.is_empty() { parts.push(text); } } "assistant" => { // Include assistant text if any let text = extract_message_text(&msg.content); if !text.is_empty() { parts.push(text); } // Include tool calls as context if let Some(ref tool_calls) = msg.tool_calls { for tc in tool_calls { if let Some(func) = tc.get("function") { let name = func["name"].as_str().unwrap_or("unknown"); let args = func["arguments"].as_str().unwrap_or("{}"); parts.push(format!("[Tool call: {}({})]", name, args)); } } } } "tool" => { let text = extract_message_text(&msg.content); let tool_id = msg.tool_call_id.as_deref().unwrap_or("unknown"); if !text.is_empty() { parts.push(format!("[Tool result ({})]:\n{}", tool_id, text)); } } _ => {} } } parts.join("\n\n") } // ─── Handler ───────────────────────────────────────────────────────────────── /// POST /v1/chat/completions — OpenAI Chat Completions API compatibility shim. /// Accepts standard messages format, reuses the same backend cascade, and /// outputs in the Chat Completions streaming/sync format. pub(crate) async fn handle_completions( State(state): State>, Json(body): Json, ) -> axum::response::Response { let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL); info!( "POST /v1/chat/completions model={} stream={}", model_name, body.stream ); let model = match lookup_model(model_name) { Some(m) => m, None => { let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect(); return err_response( StatusCode::BAD_REQUEST, format!("Unknown model: {model_name}. Available: {names:?}"), "invalid_request_error", ); } }; // ── Build per-request state locally ────────────────────────────────── // Convert OpenAI tools to Gemini format let tools = body.tools.as_ref().and_then(|t| { let gemini_tools = crate::mitm::modify::openai_tools_to_gemini(t); if gemini_tools.is_empty() { None } else { info!( count = t.len(), "Completions: client tools for MITM injection" ); Some(gemini_tools) } }); let tool_config = body.tools.as_ref().and_then(|_| { body.tool_choice .as_ref() .map(crate::mitm::modify::openai_tool_choice_to_gemini) }); // ── Extract tool results from messages for MITM injection ────────── // Build ToolRounds from message history: each round pairs assistant tool_calls // with subsequent tool result messages. Local call_id_to_name mapping. let mut tool_rounds: Vec = Vec::new(); let mut call_id_to_name: std::collections::HashMap = std::collections::HashMap::new(); { let mut current_round: Option = None; for msg in &body.messages { match msg.role.as_str() { "assistant" => { // Finalize any open round if let Some(round) = current_round.take() { if !round.calls.is_empty() { tool_rounds.push(round); } } // Start new round if this assistant has tool_calls if let Some(ref tool_calls) = msg.tool_calls { let mut calls = Vec::new(); for tc in tool_calls { if let Some(func) = tc.get("function") { let name = func["name"].as_str().unwrap_or("unknown").to_string(); let args_str = func["arguments"].as_str().unwrap_or("{}"); let args = serde_json::from_str::(args_str) .unwrap_or(serde_json::json!({})); let call_id = tc["id"].as_str().unwrap_or("").to_string(); // Register call_id → name locally if !call_id.is_empty() { call_id_to_name.insert(call_id, name.clone()); } calls.push(CapturedFunctionCall { name, args, thought_signature: None, captured_at: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), }); } } if !calls.is_empty() { current_round = Some(ToolRound { calls, results: Vec::new(), }); } } } "tool" => { let text = extract_message_text(&msg.content); if let Some(ref call_id) = msg.tool_call_id { let result_index = current_round.as_ref().map(|r| r.results.len()).unwrap_or(0); let name = call_id_to_name .get(call_id.as_str()) .cloned() .unwrap_or_else(|| { current_round .as_ref() .and_then(|r| r.calls.get(result_index)) .map(|fc| fc.name.clone()) .unwrap_or_else(|| "unknown_function".to_string()) }); let result_value = serde_json::from_str::(&text) .unwrap_or_else(|_| serde_json::json!({"result": text})); if let Some(ref mut round) = current_round { round.results.push(PendingToolResult { name, result: result_value, }); } } } _ => { // Any other role (user, system) finalizes the current round if let Some(round) = current_round.take() { if !round.calls.is_empty() { tool_rounds.push(round); } } } } } // Finalize last round if let Some(round) = current_round.take() { if !round.calls.is_empty() { tool_rounds.push(round); } } if !tool_rounds.is_empty() { info!( round_count = tool_rounds.len(), calls = ?tool_rounds.iter().map(|r| r.calls.iter().map(|c| &c.name).collect::>()).collect::>(), "Completions: {} tool round(s) for MITM history rewrite", tool_rounds.len(), ); // Merge thought_signatures from MITM-captured function calls. // OpenAI format doesn't carry thought signatures, but Google requires // them when injecting functionCall parts back into history. let sigs = state.mitm_store.peek_thought_signatures().await; if !sigs.is_empty() { let mut merged = 0usize; for round in &mut tool_rounds { for fc in &mut round.calls { if fc.thought_signature.is_none() { if let Some(sig) = sigs.get(&fc.name) { fc.thought_signature = Some(sig.clone()); merged += 1; } } } } if merged > 0 { info!( merged_count = merged, "Completions: merged {} thought_signature(s) from MITM capture", merged, ); } } } } // Build generation parameters locally use crate::mitm::store::GenerationParams; let (response_mime_type, response_schema) = match body.response_format.as_ref() { Some(rf) => match rf.format_type.as_str() { "json_object" | "json" => (Some("application/json".to_string()), None), "json_schema" => { let schema = rf.json_schema.as_ref().and_then(|js| js.schema.clone()); (Some("application/json".to_string()), schema) } _ => (None, None), }, None => (None, None), }; let gp = GenerationParams { temperature: body.temperature, top_p: body.top_p, top_k: None, max_output_tokens: body.max_tokens.or(body.max_completion_tokens), stop_sequences: body.stop.clone().map(|s| s.into_vec()), frequency_penalty: body.frequency_penalty, presence_penalty: body.presence_penalty, reasoning_effort: body.reasoning_effort.clone(), response_mime_type, response_schema, google_search: body.web_search, }; let generation_params = if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some() || gp.frequency_penalty.is_some() || gp.presence_penalty.is_some() || gp.reasoning_effort.is_some() || gp.stop_sequences.is_some() || gp.response_mime_type.is_some() || gp.response_schema.is_some() || gp.google_search { Some(gp) } else { None }; let token = state.backend.oauth_token().await; if token.is_empty() { return err_response( StatusCode::UNAUTHORIZED, "No OAuth token. POST to /v1/token or set ZEROGRAVITY_TOKEN env var.".into(), "authentication_error", ); } let (user_text, image) = extract_chat_input(&body.messages); if user_text.is_empty() { return err_response( StatusCode::BAD_REQUEST, "No user message found".to_string(), "invalid_request_error", ); } let n = body.n.clamp(1, 5); // Cap at 5 to prevent abuse if n > 1 && body.stream { warn!("n={n} requested with streaming — streaming only supports n=1, ignoring n"); } // Always create a new cascade for every request let cascade_id = match state.backend.create_cascade().await { Ok(cid) => cid, Err(e) => { return err_response( StatusCode::BAD_GATEWAY, format!("StartCascade failed: {e}"), "server_error", ); } }; // Image for MITM injection let pending_image = image.as_ref().map(|img| { use base64::Engine; crate::mitm::store::PendingImage { base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data), mime_type: img.mime_type.clone(), } }); // Get last calls from the latest tool round (if any) for proxy recording compat let last_function_calls = tool_rounds .last() .map(|r| r.calls.clone()) .unwrap_or_default(); // Build event channel — always created for MITM response path let (tx, rx) = tokio::sync::mpsc::channel(64); let (mitm_rx, event_tx) = (Some(rx), tx); // Build pending tool results from latest round let pending_tool_results = tool_rounds .last() .map(|r| r.results.clone()) .unwrap_or_default(); // Start debug trace let trace = state.trace.start( &cascade_id, "POST /v1/chat/completions", model_name, body.stream, ); if let Some(ref t) = trace { t.set_client_request(crate::trace::ClientRequestSummary { message_count: body.messages.len(), tool_count: body.tools.as_ref().map_or(0, |t| t.len()), tool_round_count: tool_rounds.len(), user_text_len: user_text.len(), user_text_preview: user_text.chars().take(200).collect(), system_prompt: body.messages.iter().any(|m| m.role == "system"), has_image: image.is_some(), }) .await; // Start turn 0 t.start_turn().await; } let mitm_gate = std::sync::Arc::new(tokio::sync::Notify::new()); let mitm_gate_clone = mitm_gate.clone(); state .mitm_store .register_request(crate::mitm::store::RequestContext { cascade_id: cascade_id.clone(), pending_user_text: user_text.clone(), event_channel: event_tx, generation_params, pending_image, tools, tool_config, pending_tool_results, tool_rounds, last_function_calls, call_id_to_name, created_at: std::time::Instant::now(), gate: mitm_gate_clone, trace_handle: trace.clone(), trace_turn: 0, }) .await; // Send REAL user text to LS match state .backend .send_message_with_image( &cascade_id, &format!(".", cascade_id), model.model_enum, image.as_ref(), ) .await { Ok((200, _)) => { let bg = Arc::clone(&state.backend); let cid = cascade_id.clone(); tokio::spawn(async move { let _ = bg.update_annotations(&cid).await; }); } Ok((status, _)) => { state.mitm_store.remove_request(&cascade_id).await; if let Some(ref t) = trace { t.record_error(format!("Backend returned {status}")).await; t.finish("backend_error").await; } return err_response( StatusCode::BAD_GATEWAY, format!("Backend returned {status}"), "server_error", ); } Err(e) => { state.mitm_store.remove_request(&cascade_id).await; if let Some(ref t) = trace { t.record_error(format!("Send failed: {e}")).await; t.finish("send_error").await; } return err_response( StatusCode::BAD_GATEWAY, format!("Send failed: {e}"), "server_error", ); } } // Wait for MITM gate: 5s → 502 if MITM enabled let gate_start = std::time::Instant::now(); let gate_matched = tokio::time::timeout(std::time::Duration::from_secs(5), mitm_gate.notified()).await; let gate_wait_ms = gate_start.elapsed().as_millis() as u64; if gate_matched.is_err() { if state.mitm_enabled { state.mitm_store.remove_request(&cascade_id).await; if let Some(ref t) = trace { t.record_error("MITM gate timeout (5s)".to_string()).await; t.finish("mitm_timeout").await; } return err_response( StatusCode::BAD_GATEWAY, "MITM proxy did not match request within 5s".to_string(), "mitm_timeout", ); } warn!(cascade = %cascade_id, "MITM gate timeout (--no-mitm mode)"); } else { debug!(cascade = %cascade_id, gate_wait_ms, "MITM gate signaled — request matched"); if let Some(ref t) = trace { t.record_mitm_match(0, gate_wait_ms).await; } } let completion_id = format!( "chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace('-', "") ); let include_usage = body .stream_options .as_ref() .is_some_and(|o| o.include_usage); if body.stream { chat_completions_stream( state, completion_id, model_name.to_string(), cascade_id, body.timeout, include_usage, mitm_rx, trace, ) .await } else if n <= 1 { chat_completions_sync( state, completion_id, model_name.to_string(), cascade_id, body.timeout, trace, ) .await } else { // n > 1: fire additional (n-1) parallel cascades let mut extra_cascade_ids = Vec::with_capacity((n - 1) as usize); for _ in 1..n { if let Ok(cid) = state.backend.create_cascade().await { // Send the same message on each extra cascade if let Ok((200, _)) = state .backend .send_message_with_image( &cid, &format!(".", cid), model.model_enum, image.as_ref(), ) .await { let bg = Arc::clone(&state.backend); let cid2 = cid.clone(); tokio::spawn(async move { let _ = bg.update_annotations(&cid2).await; }); extra_cascade_ids.push(cid); } } } // Poll all cascades in parallel let mut handles = Vec::with_capacity(n as usize); let all_cascade_ids: Vec = std::iter::once(cascade_id.clone()) .chain(extra_cascade_ids) .collect(); for cid in &all_cascade_ids { let st = Arc::clone(&state); let cid = cid.clone(); let timeout = body.timeout; handles.push(tokio::spawn(async move { let result = poll_for_response(&st, &cid, timeout).await; let mitm = match st.mitm_store.take_usage(&cid).await { Some(u) => Some(u), None => st.mitm_store.take_usage("_latest").await, }; (result, mitm) })); } let mut choices = Vec::with_capacity(n as usize); let mut total_prompt = 0u64; let mut total_completion = 0u64; let mut total_cached = 0u64; let mut total_thinking = 0u64; for (i, handle) in handles.into_iter().enumerate() { if let Ok((result, mitm)) = handle.await { let finish_reason = google_to_openai_finish_reason( mitm.as_ref().and_then(|u| u.stop_reason.as_deref()), ); let (pt, ct, cached, thinking) = if let Some(ref mu) = mitm { ( mu.input_tokens, mu.output_tokens, mu.cache_read_input_tokens, mu.thinking_output_tokens, ) } else if let Some(u) = &result.usage { (u.input_tokens, u.output_tokens, 0, 0) } else { (0, 0, 0, 0) }; total_prompt += pt; total_completion += ct; total_cached += cached; total_thinking += thinking; let mut message = serde_json::json!({ "role": "assistant", "content": result.text, }); if let Some(ref thinking_text) = result.thinking { message["reasoning_content"] = serde_json::json!(thinking_text); } choices.push(serde_json::json!({ "index": i, "message": message, "logprobs": serde_json::Value::Null, "finish_reason": finish_reason, })); } } Json(serde_json::json!({ "id": completion_id, "object": "chat.completion", "created": now_unix(), "model": model_name, "system_fingerprint": system_fingerprint(), "service_tier": "default", "choices": choices, "usage": { "prompt_tokens": total_prompt, "completion_tokens": total_completion, "total_tokens": total_prompt + total_completion, "prompt_tokens_details": { "cached_tokens": total_cached, }, "completion_tokens_details": { "reasoning_tokens": total_thinking, }, }, })) .into_response() } } // ─── Streaming ─────────────────────────────────────────────────────────────── /// Streaming output in Chat Completions format. #[allow(clippy::too_many_arguments)] async fn chat_completions_stream( state: Arc, completion_id: String, model_name: String, cascade_id: String, timeout: u64, include_usage: bool, mitm_rx: Option>, trace: Option, ) -> axum::response::Response { let stream = async_stream::stream! { let start = std::time::Instant::now(); let mut last_text = String::new(); let has_custom_tools = mitm_rx.is_some(); if !has_custom_tools { state.mitm_store.clear_response_async().await; state.mitm_store.clear_upstream_error().await; let _ = state.mitm_store.take_any_function_calls().await; } // Initial role chunk yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([chunk_choice(0, serde_json::json!({"role": "assistant", "content": ""}), None)]), None, ))); let mut keepalive_counter: u64 = 0; let mut last_thinking_len: usize = 0; let mut complete_polls: u32 = 0; let mut did_unblock_ls = false; // Prevents infinite unblock loops // Helper: build usage JSON from MITM tokens let build_usage = |pt: u64, ct: u64, crt: u64, tt: u64| -> serde_json::Value { serde_json::json!({ "prompt_tokens": pt, "completion_tokens": ct, "total_tokens": pt + ct, "prompt_tokens_details": { "cached_tokens": crt }, "completion_tokens_details": { "reasoning_tokens": tt }, }) }; // Take ownership of the pre-installed channel receiver let mut rx_opt = mitm_rx; while start.elapsed().as_secs() < timeout { if let Some(ref mut rx) = rx_opt { // ── Channel-based MITM pipeline ── // Track accumulated text for delta computation let mut acc_text = String::new(); let mut acc_thinking = String::new(); let mut last_usage: Option = None; 'channel_loop: while let Some(event) = tokio::time::timeout( std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())), rx.recv(), ).await.ok().flatten() { use crate::mitm::store::MitmEvent; match event { MitmEvent::ThinkingDelta(full_thinking) => { if full_thinking.len() > acc_thinking.len() { let delta = full_thinking[acc_thinking.len()..].to_string(); acc_thinking = full_thinking; last_thinking_len = acc_thinking.len(); yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]), None, ))); } } MitmEvent::TextDelta(full_text) => { if full_text.len() > acc_text.len() { let delta = full_text[acc_text.len()..].to_string(); acc_text = full_text; last_text = acc_text.clone(); yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]), None, ))); } } MitmEvent::FunctionCall(calls) => { let mut tool_calls = Vec::new(); for (i, fc) in calls.iter().enumerate() { let call_id = format!( "call_{}", &uuid::Uuid::new_v4().to_string().replace('-', "")[..24] ); let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); tool_calls.push(serde_json::json!({ "index": i, "id": call_id, "type": "function", "function": { "name": fc.name, "arguments": arguments, }, })); } yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]), None, ))); yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]), None, ))); if include_usage { let mitm = state.mitm_store.take_usage(&cascade_id).await .or(state.mitm_store.take_usage("_latest").await); let (pt, ct, crt, tt) = if let Some(ref u) = mitm { (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) } else if let Some(ref u) = last_usage { (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) } else { (0, 0, 0, 0) }; yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([]), Some(build_usage(pt, ct, crt, tt)), ))); } yield Ok(Event::default().data("[DONE]")); state.mitm_store.remove_request(&cascade_id).await; if let Some(ref t) = trace { let (ipt, opt, crt2, tht) = if let Some(ref u) = last_usage { (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) } else { (0, 0, 0, 0) }; t.record_response(0, crate::trace::ResponseSummary { text_len: 0, thinking_len: 0, text_preview: String::new(), finish_reason: Some("tool_calls".to_string()), function_calls: calls.iter().map(|fc| crate::trace::FunctionCallSummary { name: fc.name.clone(), args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(), }).collect(), grounding: false, }).await; t.set_usage(crate::trace::TrackedUsage { input_tokens: ipt, output_tokens: opt, thinking_tokens: tht, cache_read: crt2 }).await; t.finish("tool_call").await; } return; } MitmEvent::ResponseComplete => { if !acc_text.is_empty() { // Have response text — done debug!("Completions: channel response complete, text_len={}, thinking_len={}", acc_text.len(), acc_thinking.len()); let mitm = state.mitm_store.take_usage(&cascade_id).await .or(state.mitm_store.take_usage("_latest").await) .or(last_usage.take()); let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]), None, ))); if include_usage { let (pt, ct, crt, tt) = if let Some(ref u) = mitm { (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) } else { (0, 0, 0, 0) }; yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([]), Some(build_usage(pt, ct, crt, tt)), ))); } yield Ok(Event::default().data("[DONE]")); state.mitm_store.remove_request(&cascade_id).await; if let Some(ref t) = trace { let (ipt, opt, crt2, tht) = if let Some(ref u) = mitm { (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) } else { (0, 0, 0, 0) }; t.record_response(0, crate::trace::ResponseSummary { text_len: acc_text.len(), thinking_len: acc_thinking.len(), text_preview: acc_text.chars().take(200).collect(), finish_reason: Some("stop".to_string()), function_calls: Vec::new(), grounding: false, }).await; t.set_usage(crate::trace::TrackedUsage { input_tokens: ipt, output_tokens: opt, thinking_tokens: tht, cache_read: crt2 }).await; t.finish("completed").await; } return; } else if !acc_thinking.is_empty() && !did_unblock_ls { // Thinking-only response — LS needs follow-up API calls. // Create a new channel and unblock the gate. did_unblock_ls = true; let (new_tx, new_rx) = tokio::sync::mpsc::channel(64); state.mitm_store.set_channel(&cascade_id, new_tx).await; let _ = state.mitm_store.take_any_function_calls().await; *rx = new_rx; debug!( "Completions: thinking-only — new channel for follow-up, thinking_len={}", acc_thinking.len() ); continue 'channel_loop; } else if !acc_thinking.is_empty() && did_unblock_ls { // Already unblocked once, still thinking-only. // Wait a bit for potential follow-up events. complete_polls += 1; if complete_polls >= 25 { info!("Completions: thinking-only timeout, thinking_len={}", acc_thinking.len()); let mitm = state.mitm_store.take_usage(&cascade_id).await .or(state.mitm_store.take_usage("_latest").await) .or(last_usage.take()); let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]), None, ))); if include_usage { let (pt, ct, crt, tt) = if let Some(ref u) = mitm { (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) } else { (0, 0, 0, 0) }; yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([]), Some(build_usage(pt, ct, crt, tt)), ))); } yield Ok(Event::default().data("[DONE]")); state.mitm_store.remove_request(&cascade_id).await; if let Some(ref t) = trace { let (ipt, opt, crt2, tht) = if let Some(ref u) = mitm { (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) } else { (0, 0, 0, 0) }; t.record_response(0, crate::trace::ResponseSummary { text_len: 0, thinking_len: acc_thinking.len(), text_preview: String::new(), finish_reason: Some("stop".to_string()), function_calls: Vec::new(), grounding: false, }).await; t.set_usage(crate::trace::TrackedUsage { input_tokens: ipt, output_tokens: opt, thinking_tokens: tht, cache_read: crt2 }).await; t.finish("thinking_timeout").await; } return; } // Don't break — wait for more channel events continue 'channel_loop; } else { // Empty response (no text, no thinking, no tools) complete_polls += 1; if complete_polls >= 4 { info!("Completions: channel response complete but empty, ending stream"); yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]), None, ))); yield Ok(Event::default().data("[DONE]")); state.mitm_store.remove_request(&cascade_id).await; if let Some(ref t) = trace { t.record_response(0, crate::trace::ResponseSummary { text_len: 0, thinking_len: 0, text_preview: String::new(), finish_reason: Some("stop".to_string()), function_calls: Vec::new(), grounding: false, }).await; t.finish("empty_response").await; } return; } continue 'channel_loop; } } MitmEvent::UpstreamError(err) => { let error_msg = super::util::upstream_error_message(&err); let error_type = super::util::upstream_error_type(&err); yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ "error": { "message": error_msg, "type": error_type, "code": err.status, } })).unwrap())); yield Ok(Event::default().data("[DONE]")); state.mitm_store.remove_request(&cascade_id).await; return; } MitmEvent::Usage(u) => { last_usage = Some(u); } MitmEvent::Grounding(_) => { // Grounding metadata handled by store directly } } } // Channel closed or timeout — clean up state.mitm_store.remove_request(&cascade_id).await; // If we got here from timeout with content, emit what we have if !last_text.is_empty() || last_thinking_len > 0 { yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]), None, ))); } yield Ok(Event::default().data("[DONE]")); if let Some(ref t) = trace { t.record_response(0, crate::trace::ResponseSummary { text_len: last_text.len(), thinking_len: last_thinking_len, text_preview: last_text.chars().take(200).collect(), finish_reason: Some("stop".to_string()), function_calls: Vec::new(), grounding: false, }).await; t.finish("channel_closed").await; } return; } else { // ── Fallback: LS steps (no MITM capture active) ── if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await { if status == 200 { if let Some(steps) = data["steps"].as_array() { // Stream thinking deltas (reasoning_content) if let Some(tc) = extract_thinking_content(steps) { if tc.len() > last_thinking_len { let delta = &tc[last_thinking_len..]; last_thinking_len = tc.len(); yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]), None, ))); } } let text = extract_response_text(steps); if !text.is_empty() && text != last_text { let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) { &text[last_text.len()..] } else { &text }; if !delta.is_empty() { yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]), None, ))); last_text = text.to_string(); } } // Done check let has_content = !last_text.is_empty() || last_thinking_len > 0; if is_response_done(steps) && has_content { debug!("Completions stream done, text length={}, thinking_len={}", last_text.len(), last_thinking_len); let mitm = state.mitm_store.take_usage(&cascade_id).await .or(state.mitm_store.take_usage("_latest").await); let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]), None, ))); if include_usage { let (pt, ct, crt, tt) = if let Some(ref u) = mitm { (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) } else { (0, 0, 0, 0) }; yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([]), Some(build_usage(pt, ct, crt, tt)), ))); } yield Ok(Event::default().data("[DONE]")); return; } // IDLE fallback let step_count = steps.len(); if step_count > 4 && step_count % 5 == 0 { if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await { if ts == 200 { let run_status = td["status"].as_str().unwrap_or(""); let has_content_idle = !last_text.is_empty() || last_thinking_len > 0; if run_status.contains("IDLE") && has_content_idle { debug!("Completions IDLE, text length={}, thinking_len={}", last_text.len(), last_thinking_len); let mitm = state.mitm_store.take_usage(&cascade_id).await .or(state.mitm_store.take_usage("_latest").await); let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]), None, ))); if include_usage { let (pt, ct, crt, tt) = if let Some(ref u) = mitm { (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) } else { (0, 0, 0, 0) }; yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, serde_json::json!([]), Some(build_usage(pt, ct, crt, tt)), ))); } yield Ok(Event::default().data("[DONE]")); return; } } } } } } } } // Keep-alive comment every ~5 iterations keepalive_counter += 1; if keepalive_counter.is_multiple_of(5) { yield Ok(Event::default().comment("keepalive")); } // Fast poll — 300ms so we pick up MITM captures quickly let poll_ms: u64 = rand::thread_rng().gen_range(250..400); tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await; } // Timeout — emit error, not placeholder content warn!("Completions stream timeout after {}s", timeout); yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ "error": { "message": format!("Timeout: no response from Google API after {timeout}s"), "type": "upstream_error", "code": 504, } })).unwrap())); // Always clear in-flight flag when stream ends state.mitm_store.remove_request(&cascade_id).await; yield Ok(Event::default().data("[DONE]")); }; Sse::new(stream) .keep_alive( axum::response::sse::KeepAlive::new() .interval(std::time::Duration::from_secs(15)) .text(""), ) .into_response() } // ─── Sync ──────────────────────────────────────────────────────────────────── /// Sync output in Chat Completions format. async fn chat_completions_sync( state: Arc, completion_id: String, model_name: String, cascade_id: String, timeout: u64, trace: Option, ) -> axum::response::Response { let result = poll_for_response(&state, &cascade_id, timeout).await; if let Some(ref err) = result.upstream_error { return upstream_err_response(err); } // Check MITM store first for real intercepted usage (fallback to _latest) let mitm = match state.mitm_store.take_usage(&cascade_id).await { Some(u) => Some(u), None => state.mitm_store.take_usage("_latest").await, }; let finish_reason = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); let (prompt_tokens, completion_tokens, cached_tokens, thinking_tokens) = if let Some(ref mitm_usage) = mitm { ( mitm_usage.input_tokens, mitm_usage.output_tokens, mitm_usage.cache_read_input_tokens, mitm_usage.thinking_output_tokens, ) } else if let Some(u) = &result.usage { (u.input_tokens, u.output_tokens, 0, 0) } else { (0, 0, 0, 0) }; // Build message object, including reasoning_content if thinking is present let mut message = serde_json::json!({ "role": "assistant", "content": result.text, }); if let Some(ref thinking) = result.thinking { message["reasoning_content"] = serde_json::json!(thinking); } // Record trace data if let Some(ref t) = trace { t.record_response( 0, crate::trace::ResponseSummary { text_len: result.text.len(), thinking_len: result.thinking.as_ref().map_or(0, |s| s.len()), text_preview: result.text.chars().take(200).collect(), finish_reason: Some(finish_reason.to_string()), function_calls: Vec::new(), grounding: false, }, ) .await; if prompt_tokens > 0 || completion_tokens > 0 { t.set_usage(crate::trace::TrackedUsage { input_tokens: prompt_tokens, output_tokens: completion_tokens, thinking_tokens, cache_read: cached_tokens, }) .await; } t.finish("completed").await; } Json(serde_json::json!({ "id": completion_id, "object": "chat.completion", "created": now_unix(), "model": model_name, "system_fingerprint": system_fingerprint(), "service_tier": "default", "choices": [{ "index": 0, "message": message, "logprobs": serde_json::Value::Null, "finish_reason": finish_reason, }], "usage": { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, "prompt_tokens_details": { "cached_tokens": cached_tokens, }, "completion_tokens_details": { "reasoning_tokens": thinking_tokens, }, }, })) .into_response() }