Files
zerogravity/src/api/completions.rs
Nikketryhard 48674f65da refactor: decompose large functions and remove dead code
- Decompose modify_request() into 7 single-responsibility helpers
- Decompose handle_http_over_tls(): extract read_full_request, dispatch_stream_events
- Promote connect_upstream/resolve_upstream to module-level functions
- Split standalone.rs (1238 lines) into 4 submodules:
  standalone/mod.rs, spawn.rs, discovery.rs, stub.rs
- Extract proto wire primitives into proto/wire.rs
- Remove 6 dead MitmStore methods
- Remove dead SessionResult, DEFAULT_SESSION, get_or_create
- Remove dead decode_varint_at, extract_conversation_id
- Clean all unused imports across 10 files
- Suppress structural dead_code warnings on deserialization fields

Warnings: 20 -> 0. All 43 tests pass.
2026-02-17 22:27:26 -06:00

1114 lines
49 KiB
Rust

//! OpenAI Chat Completions API (/v1/chat/completions) handler.
use axum::{
extract::State,
http::StatusCode,
response::{sse::Event, IntoResponse, Json, Sse},
};
use rand::Rng;
use std::sync::Arc;
use tracing::{debug, info, warn};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{
extract_response_text, extract_thinking_content, is_response_done, poll_for_response,
};
use super::types::*;
use super::util::{err_response, now_unix, upstream_err_response};
use super::AppState;
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};
/// Returns the `system_fingerprint` value reported in completions responses.
///
/// The value is `fp_` followed by the crate version (read at compile time
/// from `CARGO_PKG_VERSION`) with every `.` removed.
fn system_fingerprint() -> String {
    let version: &str = env!("CARGO_PKG_VERSION");
    let mut fingerprint = String::with_capacity("fp_".len() + version.len());
    fingerprint.push_str("fp_");
    // Drop the dots between version components; keep everything else as-is.
    fingerprint.extend(version.chars().filter(|ch| *ch != '.'));
    fingerprint
}
/// Serializes one `chat.completion.chunk` object to a JSON string.
///
/// The chunk always carries `id`, `object`, `created`, `model`,
/// `system_fingerprint`, `service_tier` and `choices`. A `usage` member is
/// added only when `usage` is `Some`. If serialization fails, an empty string
/// is returned instead of an error.
fn chunk_json(
    id: &str,
    model: &str,
    choices: serde_json::Value,
    usage: Option<serde_json::Value>,
) -> String {
    let mut payload = serde_json::json!({
        "id": id,
        "object": "chat.completion.chunk",
        "created": now_unix(),
        "model": model,
        "system_fingerprint": system_fingerprint(),
        "service_tier": "default",
        "choices": choices,
    });

    // Attach the optional usage block as an extra top-level member.
    if let (Some(usage_value), Some(fields)) = (usage, payload.as_object_mut()) {
        fields.insert("usage".to_owned(), usage_value);
    }

    match serde_json::to_string(&payload) {
        Ok(serialized) => serialized,
        Err(_) => String::new(),
    }
}
/// Builds one entry of a streaming chunk's `choices` array.
///
/// The entry contains `index`, the given `delta`, a `logprobs` member that is
/// always `null`, and `finish_reason` (`null` when `finish_reason` is `None`).
fn chunk_choice(
    index: u32,
    delta: serde_json::Value,
    finish_reason: Option<&str>,
) -> serde_json::Value {
    let reason = match finish_reason {
        Some(text) => serde_json::Value::String(text.to_owned()),
        None => serde_json::Value::Null,
    };
    serde_json::json!({
        "index": index,
        "delta": delta,
        "logprobs": null,
        "finish_reason": reason,
    })
}
// ─── Finish reason mapping ───────────────────────────────────────────────────
/// Translates a Google `finishReason` into an OpenAI `finish_reason`.
///
/// * `MAX_TOKENS` becomes `"length"`.
/// * `SAFETY`, `RECITATION`, `BLOCKLIST` and `PROHIBITED_CONTENT` become
///   `"content_filter"`.
/// * Anything else — including `STOP`, `OTHER`, unrecognized values and
///   `None` — becomes `"stop"`.
///
/// `tool_calls` is not produced here; callers emit it themselves.
fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str {
    // Google reasons that OpenAI clients should see as a content filter hit.
    const FILTERED: [&str; 4] = ["SAFETY", "RECITATION", "BLOCKLIST", "PROHIBITED_CONTENT"];

    match stop_reason {
        Some("MAX_TOKENS") => "length",
        Some(reason) if FILTERED.contains(&reason) => "content_filter",
        _ => "stop",
    }
}
// ─── Input extraction ────────────────────────────────────────────────────────
/// Turns a Chat Completions `messages` array into the backend input.
///
/// Returns a pair:
/// * the flattened conversation text built from every message (see
///   `build_conversation_with_tools`), and
/// * the first image found in the content of the *last* message whose role is
///   `"user"`, if there is one.
fn extract_chat_input(messages: &[CompletionMessage]) -> (String, Option<crate::proto::ImageData>) {
    // Only the most recent user turn is searched for an image.
    let last_user_turn = messages.iter().rfind(|m| m.role == "user");
    let image = match last_user_turn {
        Some(turn) => super::util::extract_first_image(&turn.content),
        None => None,
    };

    // The text side always covers the whole history, not just the last turn.
    let conversation = build_conversation_with_tools(messages);
    (conversation, image)
}
/// Extract text content from a message's content field (string or array).
fn extract_message_text(content: &serde_json::Value) -> String {
match content {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(arr) => arr
.iter()
.filter_map(|item| item["text"].as_str())
.collect::<Vec<_>>()
.join("\n"),
_ => String::new(),
}
}
/// Flattens the message history into one text block, tool traffic included.
///
/// Sections are emitted in message order and separated by a blank line:
///
/// * `system`, `developer`, `user`: the message text, when non-empty.
/// * `assistant`: the message text (when non-empty), followed by one
///   `[Tool call: name(args)]` line per tool call that has a `function` member.
/// * `tool`: `[Tool result (call_id)]:` followed by the text, when non-empty.
/// * any other role is ignored.
fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String {
    let mut sections: Vec<String> = Vec::new();

    for message in messages {
        let role = message.role.as_str();
        if !matches!(role, "system" | "developer" | "user" | "assistant" | "tool") {
            continue;
        }

        let body = extract_message_text(&message.content);

        if role == "tool" {
            // Tool results are labelled with the id of the call they answer.
            let call_id = message.tool_call_id.as_deref().unwrap_or("unknown");
            if !body.is_empty() {
                sections.push(format!("[Tool result ({})]:\n{}", call_id, body));
            }
            continue;
        }

        if !body.is_empty() {
            sections.push(body);
        }

        if role == "assistant" {
            // Replay the assistant's tool invocations so the history stays coherent.
            if let Some(requested) = message.tool_calls.as_ref() {
                for call in requested {
                    let function = match call.get("function") {
                        Some(function) => function,
                        None => continue,
                    };
                    let name = function["name"].as_str().unwrap_or("unknown");
                    let args = function["arguments"].as_str().unwrap_or("{}");
                    sections.push(format!("[Tool call: {}({})]", name, args));
                }
            }
        }
    }

    sections.join("\n\n")
}
// ─── Handler ─────────────────────────────────────────────────────────────────
/// POST /v1/chat/completions — OpenAI Chat Completions API compatibility shim.
/// Accepts standard messages format, reuses the same backend cascade, and
/// outputs in the Chat Completions streaming/sync format.
///
/// Steps performed, in order:
/// 1. Resolve the model name (falls back to `DEFAULT_MODEL`); unknown names
///    yield 400 with the list of available models.
/// 2. Translate client tools, tool-call history and generation parameters
///    into the structures stored in the MITM request context.
/// 3. Require a non-empty OAuth token (401) and non-empty conversation text (400).
/// 4. Create a cascade (502 on failure), register the request context, and
///    send the marker message to the backend (502 on failure or non-200).
/// 5. Dispatch to the streaming path, the single sync path, or the `n > 1`
///    fan-out path that aggregates several cascades into one response.
pub(crate) async fn handle_completions(
State(state): State<Arc<AppState>>,
Json(body): Json<CompletionRequest>,
) -> axum::response::Response {
let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
info!(
"POST /v1/chat/completions model={} stream={}",
model_name, body.stream
);
let model = match lookup_model(model_name) {
Some(m) => m,
None => {
let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
return err_response(
StatusCode::BAD_REQUEST,
format!("Unknown model: {model_name}. Available: {names:?}"),
"invalid_request_error",
);
}
};
// ── Build per-request state locally ──────────────────────────────────
// Convert OpenAI tools to Gemini format
// An empty conversion result is treated the same as "no tools supplied".
let tools = body.tools.as_ref().and_then(|t| {
let gemini_tools = crate::mitm::modify::openai_tools_to_gemini(t);
if gemini_tools.is_empty() { None } else {
info!(count = t.len(), "Completions: client tools for MITM injection");
Some(gemini_tools)
}
});
// tool_choice is only honoured when the request also carries a `tools` field.
let tool_config = body.tools.as_ref().and_then(|_| {
body.tool_choice.as_ref().map(|choice| {
crate::mitm::modify::openai_tool_choice_to_gemini(choice)
})
});
// ── Extract tool results from messages for MITM injection ──────────
// Build ToolRounds from message history: each round pairs assistant tool_calls
// with subsequent tool result messages. Local call_id_to_name mapping.
let mut tool_rounds: Vec<ToolRound> = Vec::new();
let mut call_id_to_name: std::collections::HashMap<String, String> = std::collections::HashMap::new();
{
// The round currently being assembled; closed by the next assistant
// message, by any non-tool message, or at the end of the history.
let mut current_round: Option<ToolRound> = None;
for msg in &body.messages {
match msg.role.as_str() {
"assistant" => {
// Finalize any open round
if let Some(round) = current_round.take() {
if !round.calls.is_empty() {
tool_rounds.push(round);
}
}
// Start new round if this assistant has tool_calls
if let Some(ref tool_calls) = msg.tool_calls {
let mut calls = Vec::new();
for tc in tool_calls {
if let Some(func) = tc.get("function") {
let name = func["name"].as_str().unwrap_or("unknown").to_string();
// `arguments` is expected to be a JSON-encoded string; anything
// missing or unparsable degrades to an empty object.
let args_str = func["arguments"].as_str().unwrap_or("{}");
let args = serde_json::from_str::<serde_json::Value>(args_str)
.unwrap_or(serde_json::json!({}));
let call_id = tc["id"].as_str().unwrap_or("").to_string();
// Register call_id → name locally
if !call_id.is_empty() {
call_id_to_name.insert(call_id, name.clone());
}
calls.push(CapturedFunctionCall {
name,
args,
thought_signature: None,
captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
});
}
}
if !calls.is_empty() {
current_round = Some(ToolRound {
calls,
results: Vec::new(),
});
}
}
}
"tool" => {
let text = extract_message_text(&msg.content);
// Tool messages without a tool_call_id are ignored entirely.
if let Some(ref call_id) = msg.tool_call_id {
let result_index = current_round
.as_ref()
.map(|r| r.results.len())
.unwrap_or(0);
// Resolve the function name by call id first; otherwise fall back to
// the call at the same position in the open round, then to a placeholder.
let name = call_id_to_name
.get(call_id.as_str())
.cloned()
.unwrap_or_else(|| {
current_round
.as_ref()
.and_then(|r| r.calls.get(result_index))
.map(|fc| fc.name.clone())
.unwrap_or_else(|| "unknown_function".to_string())
});
// Results that are not valid JSON are wrapped as {"result": <text>}.
let result_value = serde_json::from_str::<serde_json::Value>(&text)
.unwrap_or_else(|_| serde_json::json!({"result": text}));
// A tool result with no open round is dropped.
if let Some(ref mut round) = current_round {
round.results.push(PendingToolResult {
name,
result: result_value,
});
}
}
}
_ => {
// Any other role (user, system) finalizes the current round
if let Some(round) = current_round.take() {
if !round.calls.is_empty() {
tool_rounds.push(round);
}
}
}
}
}
// Finalize last round
if let Some(round) = current_round.take() {
if !round.calls.is_empty() {
tool_rounds.push(round);
}
}
if !tool_rounds.is_empty() {
info!(
round_count = tool_rounds.len(),
calls = ?tool_rounds.iter().map(|r| r.calls.iter().map(|c| &c.name).collect::<Vec<_>>()).collect::<Vec<_>>(),
"Completions: {} tool round(s) for MITM history rewrite",
tool_rounds.len(),
);
// Merge thought_signatures from MITM-captured function calls.
// OpenAI format doesn't carry thought signatures, but Google requires
// them when injecting functionCall parts back into history.
// Signatures are looked up by function name, so every call with the same
// name receives the same signature.
let sigs = state.mitm_store.peek_thought_signatures().await;
if !sigs.is_empty() {
let mut merged = 0usize;
for round in &mut tool_rounds {
for fc in &mut round.calls {
if fc.thought_signature.is_none() {
if let Some(sig) = sigs.get(&fc.name) {
fc.thought_signature = Some(sig.clone());
merged += 1;
}
}
}
}
if merged > 0 {
info!(
merged_count = merged,
"Completions: merged {} thought_signature(s) from MITM capture",
merged,
);
}
}
}
}
// Build generation parameters locally
use crate::mitm::store::GenerationParams;
// response_format → (MIME type, optional schema). Unrecognized format
// types leave both unset.
let (response_mime_type, response_schema) = match body.response_format.as_ref() {
Some(rf) => match rf.format_type.as_str() {
"json_object" | "json" => (Some("application/json".to_string()), None),
"json_schema" => {
let schema = rf.json_schema.as_ref().and_then(|js| js.schema.clone());
(Some("application/json".to_string()), schema)
}
_ => (None, None),
},
None => (None, None),
};
// `max_tokens` takes precedence over `max_completion_tokens`; `top_k` is never set here.
let gp = GenerationParams {
temperature: body.temperature,
top_p: body.top_p,
top_k: None,
max_output_tokens: body.max_tokens.or(body.max_completion_tokens),
stop_sequences: body.stop.clone().map(|s| s.into_vec()),
frequency_penalty: body.frequency_penalty,
presence_penalty: body.presence_penalty,
reasoning_effort: body.reasoning_effort.clone(),
response_mime_type,
response_schema,
google_search: body.web_search,
};
// Only pass generation params along when the client actually set something.
let generation_params = if gp.temperature.is_some()
|| gp.top_p.is_some()
|| gp.max_output_tokens.is_some()
|| gp.frequency_penalty.is_some()
|| gp.presence_penalty.is_some()
|| gp.reasoning_effort.is_some()
|| gp.stop_sequences.is_some()
|| gp.response_mime_type.is_some()
|| gp.response_schema.is_some()
|| gp.google_search
{
Some(gp)
} else {
None
};
let token = state.backend.oauth_token().await;
if token.is_empty() {
return err_response(
StatusCode::UNAUTHORIZED,
"No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
"authentication_error",
);
}
// `user_text` is the flattened text of the whole conversation, not just the
// last user message; it is empty only when no message produced any text.
let (user_text, image) = extract_chat_input(&body.messages);
if user_text.is_empty() {
return err_response(
StatusCode::BAD_REQUEST,
"No user message found".to_string(),
"invalid_request_error",
);
}
let n = (body.n.max(1)).min(5); // Cap at 5 to prevent abuse
if n > 1 && body.stream {
warn!("n={n} requested with streaming — streaming only supports n=1, ignoring n");
}
// Always create a new cascade for every request
let cascade_id = match state.backend.create_cascade().await {
Ok(cid) => cid,
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("StartCascade failed: {e}"),
"server_error",
);
}
};
// Image for MITM injection
let pending_image = image.as_ref().map(|img| {
use base64::Engine;
crate::mitm::store::PendingImage {
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
mime_type: img.mime_type.clone(),
}
});
// Get last calls from the latest tool round (if any) for proxy recording compat
let last_function_calls = tool_rounds.last()
.map(|r| r.calls.clone())
.unwrap_or_default();
// Build event channel for streaming
// The channel exists only for streaming requests that supplied tools; its
// presence is what later selects the channel-based streaming pipeline.
let has_custom_tools = tools.is_some();
let (mitm_rx, event_tx) = if has_custom_tools && body.stream {
let (tx, rx) = tokio::sync::mpsc::channel(64);
(Some(rx), Some(tx))
} else {
(None, None)
};
// Build pending tool results from latest round
let pending_tool_results = tool_rounds.last()
.map(|r| r.results.clone())
.unwrap_or_default();
// Register all per-request state atomically
state.mitm_store.register_request(crate::mitm::store::RequestContext {
cascade_id: cascade_id.clone(),
pending_user_text: user_text.clone(),
event_channel: event_tx,
generation_params,
pending_image,
tools,
tool_config,
pending_tool_results,
tool_rounds,
last_function_calls,
call_id_to_name,
created_at: std::time::Instant::now(),
}).await;
// Send the `.<cid:…>` marker message to the backend. The real conversation
// text is not sent here; it was stored above as `pending_user_text`.
// NOTE(review): presumably the MITM layer swaps the marker for the stored
// text using the embedded cascade id — confirm against the mitm module.
match state
.backend
.send_message_with_image(&cascade_id, &format!(".<cid:{}>", cascade_id), model.model_enum, image.as_ref())
.await
{
Ok((200, _)) => {
// Fire-and-forget: the annotation update result is ignored.
let bg = Arc::clone(&state.backend);
let cid = cascade_id.clone();
tokio::spawn(async move {
let _ = bg.update_annotations(&cid).await;
});
}
Ok((status, _)) => {
// Undo the registration before reporting the failure.
state.mitm_store.remove_request(&cascade_id).await;
return err_response(
StatusCode::BAD_GATEWAY,
format!("Backend returned {status}"),
"server_error",
);
}
Err(e) => {
state.mitm_store.remove_request(&cascade_id).await;
return err_response(
StatusCode::BAD_GATEWAY,
format!("Send failed: {e}"),
"server_error",
);
}
}
let completion_id = format!(
"chatcmpl-{}",
uuid::Uuid::new_v4().to_string().replace('-', "")
);
let include_usage = body
.stream_options
.as_ref()
.map_or(false, |o| o.include_usage);
if body.stream {
chat_completions_stream(
state,
completion_id,
model_name.to_string(),
cascade_id,
body.timeout,
include_usage,
mitm_rx,
)
.await
} else if n <= 1 {
chat_completions_sync(
state,
completion_id,
model_name.to_string(),
cascade_id,
body.timeout,
)
.await
} else {
// n > 1: fire additional (n-1) parallel cascades
// NOTE(review): the extra cascades get no register_request call in this
// function, so only the first cascade has a RequestContext — confirm
// whether the extra ones are expected to work without one.
let mut extra_cascade_ids = Vec::with_capacity((n - 1) as usize);
for _ in 1..n {
match state.backend.create_cascade().await {
Ok(cid) => {
// Send the same message on each extra cascade
match state
.backend
.send_message_with_image(&cid, &format!(".<cid:{}>", cid), model.model_enum, image.as_ref())
.await
{
Ok((200, _)) => {
let bg = Arc::clone(&state.backend);
let cid2 = cid.clone();
tokio::spawn(async move {
let _ = bg.update_annotations(&cid2).await;
});
extra_cascade_ids.push(cid);
}
_ => {} // Skip failed cascades
}
}
Err(_) => {} // Skip failed cascade creation
}
}
// Poll all cascades in parallel
let mut handles = Vec::with_capacity(n as usize);
let all_cascade_ids: Vec<String> = std::iter::once(cascade_id.clone())
.chain(extra_cascade_ids)
.collect();
for cid in &all_cascade_ids {
let st = Arc::clone(&state);
let cid = cid.clone();
let timeout = body.timeout;
handles.push(tokio::spawn(async move {
let result = poll_for_response(&st, &cid, timeout).await;
// Prefer usage keyed by this cascade; fall back to the "_latest" slot.
let mitm = match st.mitm_store.take_usage(&cid).await {
Some(u) => Some(u),
None => st.mitm_store.take_usage("_latest").await,
};
(result, mitm)
}));
}
let mut choices = Vec::with_capacity(n as usize);
let mut total_prompt = 0u64;
let mut total_completion = 0u64;
let mut total_cached = 0u64;
let mut total_thinking = 0u64;
// Tasks are awaited in spawn order. A task whose join fails contributes no
// choice, so `index` values can have gaps in that case.
// NOTE(review): unlike the single sync path, `result.upstream_error` is not
// inspected here — confirm that is intended.
for (i, handle) in handles.into_iter().enumerate() {
if let Ok((result, mitm)) = handle.await {
let finish_reason = google_to_openai_finish_reason(
mitm.as_ref().and_then(|u| u.stop_reason.as_deref()),
);
// Token counts: intercepted usage first, then the poll result's usage
// (which has no cached/thinking breakdown), else zeros.
let (pt, ct, cached, thinking) = if let Some(ref mu) = mitm {
(
mu.input_tokens,
mu.output_tokens,
mu.cache_read_input_tokens,
mu.thinking_output_tokens,
)
} else if let Some(u) = &result.usage {
(u.input_tokens, u.output_tokens, 0, 0)
} else {
(0, 0, 0, 0)
};
total_prompt += pt;
total_completion += ct;
total_cached += cached;
total_thinking += thinking;
let mut message = serde_json::json!({
"role": "assistant",
"content": result.text,
});
if let Some(ref thinking_text) = result.thinking {
message["reasoning_content"] = serde_json::json!(thinking_text);
}
choices.push(serde_json::json!({
"index": i,
"message": message,
"logprobs": serde_json::Value::Null,
"finish_reason": finish_reason,
}));
}
}
// Usage in the combined response is the sum across all collected choices.
Json(serde_json::json!({
"id": completion_id,
"object": "chat.completion",
"created": now_unix(),
"model": model_name,
"system_fingerprint": system_fingerprint(),
"service_tier": "default",
"choices": choices,
"usage": {
"prompt_tokens": total_prompt,
"completion_tokens": total_completion,
"total_tokens": total_prompt + total_completion,
"prompt_tokens_details": {
"cached_tokens": total_cached,
},
"completion_tokens_details": {
"reasoning_tokens": total_thinking,
},
},
}))
.into_response()
}
}
// ─── Streaming ───────────────────────────────────────────────────────────────
/// Streaming output in Chat Completions format.
///
/// Produces a server-sent-event response made of `chat.completion.chunk`
/// payloads and terminated by a `[DONE]` event. Two sources are supported:
///
/// * **Channel mode** (`mitm_rx` is `Some`): events are read from the MITM
///   channel and translated into content, reasoning and tool-call chunks.
/// * **Fallback mode** (`mitm_rx` is `None`): the backend's steps are polled
///   repeatedly and text/thinking growth is emitted as deltas.
///
/// `timeout` is in seconds. When `include_usage` is true, a usage-only chunk
/// (empty `choices`) is emitted before `[DONE]` on the normal completion paths.
async fn chat_completions_stream(
state: Arc<AppState>,
completion_id: String,
model_name: String,
cascade_id: String,
timeout: u64,
include_usage: bool,
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
) -> axum::response::Response {
let stream = async_stream::stream! {
let start = std::time::Instant::now();
// Text already sent to the client (used for delta computation in fallback mode).
let mut last_text = String::new();
let has_custom_tools = mitm_rx.is_some();
// Without a dedicated channel, reset the store's response/error/function-call
// state before polling.
// NOTE(review): presumably this avoids reading leftovers from an earlier
// request — confirm against the store implementation.
if !has_custom_tools {
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
let _ = state.mitm_store.take_any_function_calls().await;
}
// Initial role chunk
yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"role": "assistant", "content": ""}), None)]),
None,
)));
let mut keepalive_counter: u64 = 0;
// Byte length of the thinking text already sent.
let mut last_thinking_len: usize = 0;
// Counts ResponseComplete events that arrived without usable text; shared by
// the thinking-only (limit 25) and empty-response (limit 4) branches below.
let mut complete_polls: u32 = 0;
let mut did_unblock_ls = false; // Prevents infinite unblock loops
// Helper: build usage JSON from MITM tokens
let build_usage = |pt: u64, ct: u64, crt: u64, tt: u64| -> serde_json::Value {
serde_json::json!({
"prompt_tokens": pt,
"completion_tokens": ct,
"total_tokens": pt + ct,
"prompt_tokens_details": { "cached_tokens": crt },
"completion_tokens_details": { "reasoning_tokens": tt },
})
};
// Take ownership of the pre-installed channel receiver
let mut rx_opt = mitm_rx;
// In channel mode this loop body runs once: the branch below always returns.
// In fallback mode it is the polling loop, bounded by `timeout` seconds.
while start.elapsed().as_secs() < timeout {
if let Some(ref mut rx) = rx_opt {
// ── Channel-based MITM pipeline ──
// Track accumulated text for delta computation
let mut acc_text = String::new();
let mut acc_thinking = String::new();
let mut last_usage: Option<crate::mitm::store::ApiUsage> = None;
// Each recv waits at most the time remaining until the overall deadline.
// The loop ends when the channel closes or that wait elapses.
'channel_loop: while let Some(event) = tokio::time::timeout(
std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
rx.recv(),
).await.ok().flatten() {
use crate::mitm::store::MitmEvent;
match event {
MitmEvent::ThinkingDelta(full_thinking) => {
// The payload is treated as the cumulative thinking text: only the part
// beyond what was already sent is forwarded.
// NOTE(review): the slice is by byte offset and assumes the previous text
// is a prefix of the new one; a non-char-boundary offset would panic.
if full_thinking.len() > acc_thinking.len() {
let delta = full_thinking[acc_thinking.len()..].to_string();
acc_thinking = full_thinking;
last_thinking_len = acc_thinking.len();
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]),
None,
)));
}
}
MitmEvent::TextDelta(full_text) => {
// Same cumulative-text handling as ThinkingDelta, for visible content.
if full_text.len() > acc_text.len() {
let delta = full_text[acc_text.len()..].to_string();
acc_text = full_text;
last_text = acc_text.clone();
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]),
None,
)));
}
}
MitmEvent::FunctionCall(calls) => {
// Emit all calls in one tool_calls chunk, then a finish chunk with
// finish_reason "tool_calls", then end the stream.
let mut tool_calls = Vec::new();
for (i, fc) in calls.iter().enumerate() {
// Fresh call id: "call_" + first 24 hex chars of a random UUID.
let call_id = format!(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
tool_calls.push(serde_json::json!({
"index": i,
"id": call_id,
"type": "function",
"function": {
"name": fc.name,
"arguments": arguments,
},
}));
}
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]),
None,
)));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]),
None,
)));
if include_usage {
// NOTE(review): `.or(...)` evaluates its argument eagerly, so the "_latest"
// usage is taken from the store even when the cascade-specific lookup
// succeeded. The sync path does this lazily — confirm which is intended.
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else if let Some(ref u) = last_usage {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.remove_request(&cascade_id).await;
return;
}
MitmEvent::ResponseComplete => {
if !acc_text.is_empty() {
// Have response text — done
debug!("Completions: channel response complete, text_len={}, thinking_len={}",
acc_text.len(), acc_thinking.len());
// Usage preference: store entry for this cascade, then "_latest", then the
// last Usage event seen on the channel (both store lookups always run).
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await)
.or(last_usage.take());
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.remove_request(&cascade_id).await;
return;
} else if !acc_thinking.is_empty() && !did_unblock_ls {
// Thinking-only response — LS needs follow-up API calls.
// Create a new channel and unblock the gate.
// This happens at most once per stream (guarded by did_unblock_ls).
did_unblock_ls = true;
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(&cascade_id, new_tx).await;
let _ = state.mitm_store.take_any_function_calls().await;
*rx = new_rx;
debug!(
"Completions: thinking-only — new channel for follow-up, thinking_len={}",
acc_thinking.len()
);
continue 'channel_loop;
} else if !acc_thinking.is_empty() && did_unblock_ls {
// Already unblocked once, still thinking-only.
// Wait a bit for potential follow-up events.
// The stream is closed after the 25th such completion event.
complete_polls += 1;
if complete_polls >= 25 {
info!("Completions: thinking-only timeout, thinking_len={}", acc_thinking.len());
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await)
.or(last_usage.take());
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.remove_request(&cascade_id).await;
return;
}
// Don't break — wait for more channel events
continue 'channel_loop;
} else {
// Empty response (no text, no thinking, no tools)
// Tolerated up to three times; the fourth ends the stream with "stop"
// and no usage chunk.
complete_polls += 1;
if complete_polls >= 4 {
info!("Completions: channel response complete but empty, ending stream");
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
None,
)));
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.remove_request(&cascade_id).await;
return;
}
continue 'channel_loop;
}
}
MitmEvent::UpstreamError(err) => {
// Forward the upstream failure as an OpenAI-style error object, then end.
let error_msg = super::util::upstream_error_message(&err);
let error_type = super::util::upstream_error_type(&err);
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"error": {
"message": error_msg,
"type": error_type,
"code": err.status,
}
})).unwrap()));
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.remove_request(&cascade_id).await;
return;
}
MitmEvent::Usage(u) => {
// Remember the most recent usage as a fallback for the final chunk.
last_usage = Some(u);
}
MitmEvent::Grounding(_) => {
// Grounding metadata handled by store directly
}
}
}
// Channel closed or timeout — clean up
state.mitm_store.remove_request(&cascade_id).await;
// If we got here from timeout with content, emit what we have
// Note: this path ends with a plain "stop" (if anything was sent) and
// [DONE]; no timeout error event and no usage chunk are emitted here.
if !last_text.is_empty() || last_thinking_len > 0 {
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
None,
)));
}
yield Ok(Event::default().data("[DONE]"));
return;
} else {
// ── Fallback: LS steps (no MITM capture active) ──
// Failed or non-200 step fetches are skipped silently; the loop retries
// after the sleep below.
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
if status == 200 {
if let Some(steps) = data["steps"].as_array() {
// Stream thinking deltas (reasoning_content)
if let Some(tc) = extract_thinking_content(steps) {
if tc.len() > last_thinking_len {
let delta = &tc[last_thinking_len..];
last_thinking_len = tc.len();
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]),
None,
)));
}
}
let text = extract_response_text(steps);
if !text.is_empty() && text != last_text {
// Send only the new suffix when the text grew from what was already
// sent; otherwise (text was rewritten) resend the whole text.
let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) {
&text[last_text.len()..]
} else {
&text
};
if !delta.is_empty() {
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]),
None,
)));
last_text = text.to_string();
}
}
// Done check
// Completion is only accepted once some text or thinking has been sent.
// NOTE(review): this path returns without calling remove_request, unlike
// the channel and timeout paths — confirm cleanup happens elsewhere.
let has_content = !last_text.is_empty() || last_thinking_len > 0;
if is_response_done(steps) && has_content {
debug!("Completions stream done, text length={}, thinking_len={}", last_text.len(), last_thinking_len);
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
return;
}
// IDLE fallback
// Only consulted when the step count is a multiple of 5 (and at least 5):
// if the trajectory status contains "IDLE" and content was sent, finish.
// NOTE(review): this path also returns without calling remove_request.
let step_count = steps.len();
if step_count > 4 && step_count % 5 == 0 {
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
if ts == 200 {
let run_status = td["status"].as_str().unwrap_or("");
let has_content_idle = !last_text.is_empty() || last_thinking_len > 0;
if run_status.contains("IDLE") && has_content_idle {
debug!("Completions IDLE, text length={}, thinking_len={}", last_text.len(), last_thinking_len);
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
return;
}
}
}
}
}
}
}
}
// Keep-alive comment every ~5 iterations
keepalive_counter += 1;
if keepalive_counter % 5 == 0 {
yield Ok(Event::default().comment("keepalive"));
}
// Fast poll — randomized 250–399 ms so we pick up MITM captures quickly
let poll_ms: u64 = rand::thread_rng().gen_range(250..400);
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
}
// Timeout — emit error, not placeholder content
warn!("Completions stream timeout after {}s", timeout);
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"error": {
"message": format!("Timeout: no response from Google API after {timeout}s"),
"type": "upstream_error",
"code": 504,
}
})).unwrap()));
// Always clear in-flight flag when stream ends
state.mitm_store.remove_request(&cascade_id).await;
yield Ok(Event::default().data("[DONE]"));
};
// The SSE layer adds its own keep-alive every 15 s with empty text, on top of
// the explicit "keepalive" comments emitted by the polling loop.
Sse::new(stream)
.keep_alive(
axum::response::sse::KeepAlive::new()
.interval(std::time::Duration::from_secs(15))
.text(""),
)
.into_response()
}
// ─── Sync ────────────────────────────────────────────────────────────────────
/// Non-streaming output in Chat Completions format.
///
/// Waits for the cascade's response via `poll_for_response` (bounded by
/// `timeout`) and returns a single `chat.completion` JSON body with one
/// choice. An upstream error reported by the poll is returned as the error
/// response instead.
///
/// Token usage comes from the MITM store entry for this cascade, then from
/// the store's `_latest` entry, then from the poll result's own usage (which
/// has no cached/reasoning breakdown), and is zero when none is available.
async fn chat_completions_sync(
    state: Arc<AppState>,
    completion_id: String,
    model_name: String,
    cascade_id: String,
    timeout: u64,
) -> axum::response::Response {
    let outcome = poll_for_response(&state, &cascade_id, timeout).await;

    if let Some(upstream) = outcome.upstream_error.as_ref() {
        return upstream_err_response(upstream);
    }

    // Intercepted usage: cascade-specific first; `_latest` is only consulted
    // (and consumed) when the cascade-specific lookup found nothing.
    let mut intercepted = state.mitm_store.take_usage(&cascade_id).await;
    if intercepted.is_none() {
        intercepted = state.mitm_store.take_usage("_latest").await;
    }

    let stop_reason = intercepted.as_ref().and_then(|u| u.stop_reason.as_deref());
    let finish_reason = google_to_openai_finish_reason(stop_reason);

    let (prompt_tokens, completion_tokens, cached_tokens, thinking_tokens) =
        match (&intercepted, &outcome.usage) {
            (Some(captured), _) => (
                captured.input_tokens,
                captured.output_tokens,
                captured.cache_read_input_tokens,
                captured.thinking_output_tokens,
            ),
            (None, Some(polled)) => (polled.input_tokens, polled.output_tokens, 0, 0),
            (None, None) => (0, 0, 0, 0),
        };

    // The assistant message carries `reasoning_content` only when thinking
    // text was produced.
    let mut message = serde_json::json!({
        "role": "assistant",
        "content": outcome.text,
    });
    if let Some(thinking) = outcome.thinking.as_ref() {
        message["reasoning_content"] = serde_json::json!(thinking);
    }

    let choice = serde_json::json!({
        "index": 0,
        "message": message,
        "logprobs": serde_json::Value::Null,
        "finish_reason": finish_reason,
    });

    let usage = serde_json::json!({
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
        "prompt_tokens_details": {
            "cached_tokens": cached_tokens,
        },
        "completion_tokens_details": {
            "reasoning_tokens": thinking_tokens,
        },
    });

    let response_body = serde_json::json!({
        "id": completion_id,
        "object": "chat.completion",
        "created": now_unix(),
        "model": model_name,
        "system_fingerprint": system_fingerprint(),
        "service_tier": "default",
        "choices": [choice],
        "usage": usage,
    });

    Json(response_body).into_response()
}