//! OpenAI Chat Completions API (/v1/chat/completions) handler. use axum::{ extract::State, http::StatusCode, response::{sse::Event, IntoResponse, Json, Sse}, }; use rand::Rng; use std::sync::Arc; use tracing::{debug, info, warn}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::polling::{extract_response_text, is_response_done, poll_for_response}; use super::types::*; use super::util::{err_response, now_unix}; use super::AppState; // ─── Input extraction ──────────────────────────────────────────────────────── /// Extract user text from Chat Completions messages array. fn extract_chat_input(messages: &[CompletionMessage]) -> String { let mut system_parts = Vec::new(); let mut user_parts = Vec::new(); for msg in messages { let text = match &msg.content { serde_json::Value::String(s) => s.clone(), serde_json::Value::Array(arr) => arr .iter() .filter_map(|item| item["text"].as_str()) .collect::>() .join("\n"), _ => continue, }; match msg.role.as_str() { "system" | "developer" => system_parts.push(text), "user" => user_parts.push(text), _ => {} } } let mut result = String::new(); if !system_parts.is_empty() { result.push_str(&system_parts.join("\n")); result.push_str("\n\n"); } // Use the last user message if let Some(last) = user_parts.last() { result.push_str(last); } result.trim().to_string() } // ─── Handler ───────────────────────────────────────────────────────────────── /// POST /v1/chat/completions — OpenAI Chat Completions API compatibility shim. /// Accepts standard messages format, reuses the same backend cascade, and /// outputs in the Chat Completions streaming/sync format. pub(crate) async fn handle_completions( State(state): State>, Json(body): Json, ) -> axum::response::Response { let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL); info!( "POST /v1/chat/completions model={} stream={}", model_name, body.stream ); let model = match lookup_model(model_name) { Some(m) => m, None => { let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect(); return err_response( StatusCode::BAD_REQUEST, format!("Unknown model: {model_name}. Available: {names:?}"), "invalid_request_error", ); } }; // Store client tools from this request (or clear stale ones from other endpoints) if let Some(ref tools) = body.tools { let gemini_tools = crate::mitm::modify::openai_tools_to_gemini(tools); if !gemini_tools.is_empty() { state.mitm_store.set_tools(gemini_tools).await; if let Some(ref choice) = body.tool_choice { let gemini_config = crate::mitm::modify::openai_tool_choice_to_gemini(choice); state.mitm_store.set_tool_config(gemini_config).await; } info!(count = tools.len(), "Completions: stored client tools for MITM injection"); } else { state.mitm_store.clear_tools().await; } } else { state.mitm_store.clear_tools().await; } state.mitm_store.clear_active_function_call(); let token = state.backend.oauth_token().await; if token.is_empty() { return err_response( StatusCode::UNAUTHORIZED, "No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(), "authentication_error", ); } let user_text = extract_chat_input(&body.messages); if user_text.is_empty() { return err_response( StatusCode::BAD_REQUEST, "No user message found".to_string(), "invalid_request_error", ); } // Fresh cascade per request let cascade_id = match state.backend.create_cascade().await { Ok(cid) => cid, Err(e) => { return err_response( StatusCode::BAD_GATEWAY, format!("StartCascade failed: {e}"), "server_error", ); } }; // Send message match state .backend .send_message(&cascade_id, &user_text, model.model_enum) .await { Ok((200, _)) => { let bg = Arc::clone(&state.backend); let cid = cascade_id.clone(); tokio::spawn(async move { let _ = bg.update_annotations(&cid).await; }); } Ok((status, _)) => { return err_response( StatusCode::BAD_GATEWAY, format!("Backend returned {status}"), "server_error", ); } Err(e) => { return err_response( StatusCode::BAD_GATEWAY, format!("Send failed: {e}"), "server_error", ); } } let completion_id = format!( "chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace('-', "") ); if body.stream { chat_completions_stream( state, completion_id, model_name.to_string(), cascade_id, body.timeout, ) .await } else { chat_completions_sync( state, completion_id, model_name.to_string(), cascade_id, body.timeout, ) .await } } // ─── Streaming ─────────────────────────────────────────────────────────────── /// Streaming output in Chat Completions format. async fn chat_completions_stream( state: Arc, completion_id: String, model_name: String, cascade_id: String, timeout: u64, ) -> axum::response::Response { let stream = async_stream::stream! { let start = std::time::Instant::now(); let mut last_text = String::new(); // Initial role chunk yield Ok::<_, std::convert::Infallible>(Event::default().data(serde_json::to_string(&serde_json::json!({ "id": completion_id, "object": "chat.completion.chunk", "created": now_unix(), "model": model_name, "choices": [{ "index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": serde_json::Value::Null, }], })).unwrap_or_default())); while start.elapsed().as_secs() < timeout { if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await { if status == 200 { if let Some(steps) = data["steps"].as_array() { // Check for MITM-captured function calls FIRST (before text) // This prevents dummy placeholder text from leaking to client let captured = state.mitm_store.take_any_function_calls().await; if let Some(ref calls) = captured { if !calls.is_empty() { // Emit tool_calls in OpenAI streaming format — NO text let mut tool_calls = Vec::new(); for (i, fc) in calls.iter().enumerate() { let call_id = format!( "call_{}", uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() ); let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); tool_calls.push(serde_json::json!({ "index": i, "id": call_id, "type": "function", "function": { "name": fc.name, "arguments": arguments, }, })); } yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ "id": completion_id, "object": "chat.completion.chunk", "created": now_unix(), "model": model_name, "choices": [{ "index": 0, "delta": {"tool_calls": tool_calls}, "finish_reason": serde_json::Value::Null, }], })).unwrap_or_default())); // Finish with tool_calls reason yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ "id": completion_id, "object": "chat.completion.chunk", "created": now_unix(), "model": model_name, "choices": [{ "index": 0, "delta": {}, "finish_reason": "tool_calls", }], })).unwrap_or_default())); yield Ok(Event::default().data("[DONE]")); return; } } // Normal text streaming (only when no function calls) let text = extract_response_text(steps); if !text.is_empty() && text != last_text { let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) { &text[last_text.len()..] } else { &text }; if !delta.is_empty() { yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ "id": completion_id, "object": "chat.completion.chunk", "created": now_unix(), "model": model_name, "choices": [{ "index": 0, "delta": {"content": delta}, "finish_reason": serde_json::Value::Null, }], })).unwrap_or_default())); last_text = text.to_string(); } } // Done check: need DONE status AND non-empty text if is_response_done(steps) && !last_text.is_empty() { debug!("Completions stream done, text length={}", last_text.len()); yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ "id": completion_id, "object": "chat.completion.chunk", "created": now_unix(), "model": model_name, "choices": [{ "index": 0, "delta": {}, "finish_reason": "stop", }], })).unwrap_or_default())); yield Ok(Event::default().data("[DONE]")); return; } // IDLE fallback: check trajectory status periodically // Only check every 5th step count to reduce backend traffic let step_count = steps.len(); if step_count > 4 && step_count % 5 == 0 { if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await { if ts == 200 { let run_status = td["status"].as_str().unwrap_or(""); if run_status.contains("IDLE") && !last_text.is_empty() { debug!("Completions IDLE, text length={}", last_text.len()); yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ "id": completion_id, "object": "chat.completion.chunk", "created": now_unix(), "model": model_name, "choices": [{ "index": 0, "delta": {}, "finish_reason": "stop", }], })).unwrap_or_default())); yield Ok(Event::default().data("[DONE]")); return; } } } } } } } let poll_ms: u64 = rand::thread_rng().gen_range(800..1200); tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await; } // Timeout warn!("Completions stream timeout after {}s", timeout); yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ "id": completion_id, "object": "chat.completion.chunk", "created": now_unix(), "model": model_name, "choices": [{ "index": 0, "delta": {"content": if last_text.is_empty() { "[Timeout waiting for response]" } else { "" }}, "finish_reason": "stop", }], })).unwrap_or_default())); yield Ok(Event::default().data("[DONE]")); }; Sse::new(stream) .keep_alive( axum::response::sse::KeepAlive::new() .interval(std::time::Duration::from_secs(15)) .text(""), ) .into_response() } // ─── Sync ──────────────────────────────────────────────────────────────────── /// Sync output in Chat Completions format. async fn chat_completions_sync( state: Arc, completion_id: String, model_name: String, cascade_id: String, timeout: u64, ) -> axum::response::Response { let result = poll_for_response(&state, &cascade_id, timeout).await; // Check MITM store first for real intercepted usage (fallback to _latest) let mitm = match state.mitm_store.take_usage(&cascade_id).await { Some(u) => Some(u), None => state.mitm_store.take_usage("_latest").await, }; let (prompt_tokens, completion_tokens, cached_tokens) = if let Some(mitm_usage) = mitm { (mitm_usage.input_tokens, mitm_usage.output_tokens, mitm_usage.cache_read_input_tokens) } else if let Some(u) = &result.usage { (u.input_tokens, u.output_tokens, 0) } else { (0, 0, 0) }; Json(serde_json::json!({ "id": completion_id, "object": "chat.completion", "created": now_unix(), "model": model_name, "choices": [{ "index": 0, "message": { "role": "assistant", "content": result.text, }, "finish_reason": "stop", }], "usage": { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, "prompt_tokens_details": { "cached_tokens": cached_tokens, }, }, })) .into_response() }