//! OpenAI Responses API (/v1/responses) handler. //! //! Strictly adheres to the official OpenAI Responses API protocol: //! https://platform.openai.com/docs/api-reference/responses use axum::{ extract::State, http::StatusCode, response::{sse::Event, IntoResponse, Json, Sse}, }; use rand::Rng; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use tracing::{debug, info}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content}; use super::types::*; use super::util::{err_response, now_unix, responses_sse_event}; use super::AppState; // ─── Input extraction ──────────────────────────────────────────────────────── /// Extract user text from Responses API `input` field. fn extract_responses_input(input: &serde_json::Value, instructions: Option<&str>) -> String { let user_text = match input { serde_json::Value::String(s) => s.clone(), serde_json::Value::Array(items) => { items .iter() .rev() .find(|item| item["role"].as_str() == Some("user")) .and_then(|item| match &item["content"] { serde_json::Value::String(s) => Some(s.clone()), serde_json::Value::Array(parts) => Some( parts .iter() .filter(|p| { let t = p["type"].as_str().unwrap_or(""); t == "input_text" || t == "text" }) .filter_map(|p| p["text"].as_str()) .collect::>() .join(" "), ), _ => None, }) .unwrap_or_default() } _ => String::new(), }; match instructions { Some(inst) if !inst.is_empty() => format!("{inst}\n\n{user_text}"), _ => user_text, } } /// Extract conversation/session ID from Responses API `conversation` field. fn extract_conversation_id(conv: &Option) -> Option { match conv { Some(serde_json::Value::String(s)) => Some(s.clone()), Some(obj) => obj["id"].as_str().map(|s| s.to_string()), None => None, } } /// Response-specific data for building a Response object. 
struct ResponseData {
    // NOTE(review): generic arguments appear to be missing throughout this
    // section (`Option`, `Vec`, `Arc`, `State>`, `Json`) and `¶ms` looks like a
    // mis-encoded `&params`. Tokens are left exactly as found — confirm against
    // version control before building.

    /// Response identifier, echoed as `id`.
    id: String,
    /// Model name echoed back to the client.
    model: String,
    /// Lifecycle status string; this file uses "in_progress", "completed" and "incomplete".
    status: &'static str,
    /// Creation timestamp taken from `now_unix()`.
    created_at: u64,
    /// Completion timestamp; `None` while in progress or on timeout.
    completed_at: Option,
    /// Output items (optional reasoning item followed by the message item).
    output: Vec,
    /// Token usage, absent for the in-progress shell.
    usage: Option,
    /// Thinking signature forwarded from the poll/steps data, if any.
    thinking_signature: Option,
}

/// Build a full Response object matching the official OpenAI schema.
///
/// Per-response values come from `data`; request-echo values (instructions,
/// sampling parameters, metadata, ...) are cloned from `params`. The fields
/// `error`, `incomplete_details`, `parallel_tool_calls`, `reasoning`, `text`,
/// `tool_choice`, `tools` and `truncation` are fixed constants/defaults here.
fn build_response_object(data: ResponseData, params: &RequestParams) -> ResponsesResponse {
    ResponsesResponse {
        id: data.id,
        object: "response",
        created_at: data.created_at,
        status: data.status,
        completed_at: data.completed_at,
        error: None,
        // NOTE(review): always `None`, including when `status` is "incomplete".
        incomplete_details: None,
        instructions: params.instructions.clone(),
        max_output_tokens: params.max_output_tokens,
        model: data.model,
        output: data.output,
        parallel_tool_calls: true,
        previous_response_id: params.previous_response_id.clone(),
        reasoning: Reasoning::default(),
        store: params.store,
        temperature: params.temperature,
        text: TextFormat::default(),
        tool_choice: "auto",
        tools: vec![],
        top_p: params.top_p,
        truncation: "disabled",
        usage: data.usage,
        user: params.user.clone(),
        metadata: params.metadata.clone(),
        thinking_signature: data.thinking_signature,
    }
}

/// Serialize a ResponsesResponse to serde_json::Value for SSE embedding.
///
/// A serialization failure is swallowed and replaced by an empty JSON object.
fn response_to_json(resp: &ResponsesResponse) -> serde_json::Value {
    // NOTE(review): the fallback `json!({})` is built eagerly on every call;
    // `unwrap_or_else` would defer it.
    serde_json::to_value(resp).unwrap_or(serde_json::json!({}))
}

// ─── Handler ─────────────────────────────────────────────────────────────────

/// Entry point for `POST /v1/responses`.
///
/// Steps, in order:
/// 1. resolve the model (400 on unknown name),
/// 2. require a non-empty OAuth token (401 otherwise),
/// 3. extract the prompt text (400 when empty),
/// 4. obtain a cascade: reuse/create via the session store when a conversation
///    ID is present, otherwise create a fresh one (502 on failure),
/// 5. send the message (502 on non-200 or transport error),
/// 6. dispatch to the streaming or synchronous responder based on `body.stream`.
pub(crate) async fn handle_responses(
    State(state): State>,
    Json(body): Json,
) -> axum::response::Response {
    info!(
        "POST /v1/responses model={} stream={}",
        body.model.as_deref().unwrap_or(DEFAULT_MODEL),
        body.stream
    );

    // Fall back to the default model when the request names none.
    let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
    let model = match lookup_model(model_name) {
        Some(m) => m,
        None => {
            let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
            return err_response(
                StatusCode::BAD_REQUEST,
                format!("Unknown model: {model_name}. Available: {names:?}"),
                "invalid_request_error",
            );
        }
    };

    // An empty token string is treated as "no credentials configured".
    let token = state.backend.oauth_token().await;
    if token.is_empty() {
        return err_response(
            StatusCode::UNAUTHORIZED,
            "No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
            "authentication_error",
        );
    }

    let user_text = extract_responses_input(&body.input, body.instructions.as_deref());
    if user_text.is_empty() {
        return err_response(
            StatusCode::BAD_REQUEST,
            "No user input found".to_string(),
            "invalid_request_error",
        );
    }

    // `resp_` + 32 hex chars (UUID v4 with the dashes removed).
    let response_id = format!(
        "resp_{}",
        uuid::Uuid::new_v4().to_string().replace('-', "")
    );

    // Session/conversation management
    let session_id_str = extract_conversation_id(&body.conversation);
    let cascade_id = if let Some(ref sid) = session_id_str {
        // Conversation supplied: let the session store look up or create the cascade.
        match state
            .sessions
            .get_or_create(Some(sid), || state.backend.create_cascade())
            .await
        {
            Ok(sr) => sr.cascade_id,
            Err(e) => {
                return err_response(
                    StatusCode::BAD_GATEWAY,
                    format!("StartCascade failed: {e}"),
                    "server_error",
                );
            }
        }
    } else {
        // No conversation: a fresh cascade for this single request.
        match state.backend.create_cascade().await {
            Ok(cid) => cid,
            Err(e) => {
                return err_response(
                    StatusCode::BAD_GATEWAY,
                    format!("StartCascade failed: {e}"),
                    "server_error",
                );
            }
        }
    };

    // Send message
    match state
        .backend
        .send_message(&cascade_id, &user_text, model.model_enum)
        .await
    {
        Ok((200, _)) => {
            // Fire-and-forget: the annotation update runs detached and its
            // result is deliberately ignored.
            let bg = Arc::clone(&state.backend);
            let cid = cascade_id.clone();
            tokio::spawn(async move {
                let _ = bg.update_annotations(&cid).await;
            });
        }
        Ok((status, _)) => {
            return err_response(
                StatusCode::BAD_GATEWAY,
                format!("Antigravity returned {status}"),
                "server_error",
            );
        }
        Err(e) => {
            return err_response(
                StatusCode::BAD_GATEWAY,
                format!("Send message failed: {e}"),
                "server_error",
            );
        }
    }

    // Capture request params for response building
    let req_params = RequestParams {
        user_text: user_text.clone(),
        instructions: body.instructions.clone(),
        store: body.store,
        // Sampling parameters default to 1.0 when the request omits them.
        temperature: body.temperature.unwrap_or(1.0),
        top_p: body.top_p.unwrap_or(1.0),
        max_output_tokens: body.max_output_tokens,
        previous_response_id: body.previous_response_id.clone(),
        user: body.user.clone(),
        // Missing metadata is echoed as an empty JSON object.
        metadata: body.metadata.clone().unwrap_or(serde_json::json!({})),
    };

    if body.stream {
        handle_responses_stream(
            state,
            response_id,
            model_name.to_string(),
            cascade_id,
            body.timeout,
            req_params,
        )
        .await
    } else {
        handle_responses_sync(
            state,
            response_id,
            model_name.to_string(),
            cascade_id,
            body.timeout,
            req_params,
        )
        .await
    }
}

/// Captured request parameters needed to echo back in the response.
struct RequestParams {
    /// Prompt text as sent to the backend; also used for usage estimation.
    user_text: String,
    instructions: Option,
    store: bool,
    temperature: f64,
    top_p: f64,
    max_output_tokens: Option,
    previous_response_id: Option,
    user: Option,
    metadata: serde_json::Value,
}

/// Build Usage from the best available source, and extract thinking text from MITM:
/// 1. MITM intercepted data (real API tokens, including cache stats + thinking text)
/// 2. LS trajectory data (real tokens, no cache info)
/// 3. Estimation from text lengths (fallback)
///
/// Returns the usage together with the thinking text. The thinking text is
/// only ever `Some` on the MITM path; the other two paths return `None`.
/// The LS strips thinking text from steps, so we capture it from the raw
/// MITM-intercepted API response.
///
/// Side effect: a matching MITM record is removed from the store via
/// `take_usage`.
async fn usage_from_poll(
    mitm_store: &crate::mitm::store::MitmStore,
    cascade_id: &str,
    model_usage: &Option,
    input_text: &str,
    output_text: &str,
) -> (Usage, Option) {
    // Priority 1: MITM intercepted data (most accurate — includes cache tokens + thinking text)
    // Try exact cascade_id match first, then fall back to "_latest" (unmatched).
    //
    // Race condition: The LS makes TWO Google API calls for thinking models:
    //   Call 1: response + thinking token count (recorded first)
    //   Call 2: thinking summary text (merged into Call 1 by the store)
    // We may read the usage after Call 1 but before Call 2 arrives.
    // If we see thinking tokens but no text, wait briefly for the merge.
    let keys_to_try: Vec<&str> = vec![cascade_id, "_latest"];
    let mut mitm_usage = None;
    for key in &keys_to_try {
        if let Some(u) = mitm_store.peek_usage(key).await {
            if u.thinking_output_tokens > 0 && u.thinking_text.is_none() {
                // Call 2 hasn't arrived yet — wait briefly for the merge
                tracing::debug!("MITM: thinking tokens found but no text, waiting for summary merge...");
                // Up to 10 × 100 ms; the loop only waits, the record itself is
                // re-read by `take_usage` below.
                for _ in 0..10 {
                    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
                    if let Some(u2) = mitm_store.peek_usage(key).await {
                        if u2.thinking_text.is_some() {
                            break;
                        }
                    }
                }
            }
            // First key with a record wins; later keys are not consulted even
            // if `take_usage` returns `None` here.
            mitm_usage = mitm_store.take_usage(key).await;
            break;
        }
    }

    if let Some(mitm_usage) = mitm_usage {
        tracing::debug!(
            input = mitm_usage.input_tokens,
            output = mitm_usage.output_tokens,
            cache_read = mitm_usage.cache_read_input_tokens,
            cache_create = mitm_usage.cache_creation_input_tokens,
            thinking = mitm_usage.thinking_output_tokens,
            thinking_text_len = mitm_usage.thinking_text.as_ref().map_or(0, |t| t.len()),
            "Using MITM intercepted usage"
        );
        let thinking_text = mitm_usage.thinking_text;
        let usage = Usage {
            input_tokens: mitm_usage.input_tokens,
            input_tokens_details: InputTokensDetails {
                cached_tokens: mitm_usage.cache_read_input_tokens,
            },
            output_tokens: mitm_usage.output_tokens,
            output_tokens_details: OutputTokensDetails {
                reasoning_tokens: mitm_usage.thinking_output_tokens,
            },
            // Total is input + output; thinking tokens are not added separately.
            total_tokens: mitm_usage.input_tokens + mitm_usage.output_tokens,
        };
        return (usage, thinking_text);
    }

    // Priority 2: LS trajectory data (from CHECKPOINT/metadata steps)
    if let Some(u) = model_usage {
        return (
            Usage {
                input_tokens: u.input_tokens,
                input_tokens_details: InputTokensDetails { cached_tokens: 0 },
                output_tokens: u.output_tokens,
                output_tokens_details: OutputTokensDetails { reasoning_tokens: 0 },
                total_tokens: u.input_tokens + u.output_tokens,
            },
            None,
        );
    }

    // Priority 3: Estimate from text lengths
    (Usage::estimate(input_text, output_text), None)
}

// ─── Sync response ───────────────────────────────────────────────────────────

/// Non-streaming responder: waits on `poll_for_response`, then returns one
/// JSON Response object.
///
/// NOTE(review): the status is always "completed"; nothing here inspects
/// whether the poll itself timed out — confirm against `poll_for_response`.
async fn handle_responses_sync(
    state: Arc,
    response_id: String,
    model_name: String,
    cascade_id: String,
    timeout: u64,
    params: RequestParams,
) -> axum::response::Response {
    let created_at = now_unix();
    let poll_result = poll_for_response(&state, &cascade_id, timeout).await;
    let completed_at = now_unix();

    let msg_id = format!(
        "msg_{}",
        uuid::Uuid::new_v4().to_string().replace('-', "")
    );

    let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, ¶ms.user_text, &poll_result.text).await;

    // Thinking text priority: MITM-captured (raw API) > LS-extracted (steps)
    let thinking_text = mitm_thinking.or(poll_result.thinking);

    // Build output array: [reasoning (if present), message]
    let mut output_items: Vec = Vec::new();
    if let Some(ref thinking) = thinking_text {
        output_items.push(build_reasoning_output(thinking));
    }
    output_items.push(build_message_output(&msg_id, &poll_result.text));

    let resp = build_response_object(
        ResponseData {
            id: response_id,
            model: model_name,
            status: "completed",
            created_at,
            completed_at: Some(completed_at),
            output: output_items,
            usage: Some(usage),
            thinking_signature: poll_result.thinking_signature,
        },
        ¶ms,
    );

    Json(resp).into_response()
}

// ─── Streaming response ─────────────────────────────────────────────────────

/// Streaming responder: returns an SSE stream of Responses API events.
///
/// Event order produced by this function:
/// `response.created` → `response.in_progress` → (optional reasoning item
/// sequence) → message `output_item.added` / `content_part.added` →
/// `output_text.delta`* → the events from `completion_events`.
/// If `timeout` seconds elapse first, a single terminal event carrying an
/// "incomplete" response is emitted instead.
///
/// Every event takes its `sequence_number` from one shared atomic counter, so
/// the textual order of `next_seq()` calls below is significant.
async fn handle_responses_stream(
    state: Arc,
    response_id: String,
    model_name: String,
    cascade_id: String,
    timeout: u64,
    params: RequestParams,
) -> axum::response::Response {
    let stream = async_stream::stream! {
        let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
        let created_at = now_unix();
        // Shared sequence counter; also handed to `completion_events` by reference.
        let seq = AtomicU32::new(0);
        let next_seq = || seq.fetch_add(1, Ordering::Relaxed);
        // The message item only ever has a single content part.
        const CONTENT_IDX: u32 = 0;

        // Build the in-progress response shell (no output yet)
        let in_progress_resp = build_response_object(
            ResponseData {
                id: response_id.clone(),
                model: model_name.clone(),
                status: "in_progress",
                created_at,
                completed_at: None,
                output: vec![],
                usage: None,
                thinking_signature: None,
            },
            ¶ms,
        );
        let resp_json = response_to_json(&in_progress_resp);

        // 1. response.created
        yield Ok::<_, std::convert::Infallible>(responses_sse_event(
            "response.created",
            serde_json::json!({
                "type": "response.created",
                "sequence_number": next_seq(),
                "response": resp_json,
            }),
        ));

        // 2. response.in_progress
        yield Ok(responses_sse_event(
            "response.in_progress",
            serde_json::json!({
                "type": "response.in_progress",
                "sequence_number": next_seq(),
                "response": resp_json,
            }),
        ));

        // ── Stream cascade updates: event-driven instead of timer-based polling ──
        let start = std::time::Instant::now();
        // Full response text seen so far; deltas are computed against it.
        let mut last_text = String::new();
        let mut thinking_emitted = false;
        let mut thinking_text: Option = None;
        let mut message_started = false;
        let reasoning_id = format!("rs_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));

        // Try to open a reactive streaming connection for real-time notifications.
        // Falls back to timer-based polling if the streaming RPC is unavailable.
        let mut reactive_rx = match state.backend.stream_cascade_updates(&cascade_id).await {
            Ok(rx) => {
                debug!("Using reactive streaming for cascade updates");
                Some(rx)
            }
            Err(e) => {
                debug!("Reactive streaming unavailable, falling back to polling: {e}");
                None
            }
        };

        // Whole-second granularity: the loop condition is re-checked once per
        // fetch/wait cycle.
        while start.elapsed().as_secs() < timeout {
            // Fetch errors and non-200 statuses are skipped silently; the loop
            // simply tries again after the wait below.
            if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
                if status == 200 {
                    if let Some(steps) = data["steps"].as_array() {
                        // Check for thinking content (appears before response text)
                        // NOTE(review): this check is not gated on `message_started`.
                        // If thinking first shows up after the message item was
                        // announced at output_index 0, the reasoning item is also
                        // emitted at index 0 and later message events move to
                        // index 1 — confirm thinking always precedes text.
                        if !thinking_emitted {
                            if let Some(tc) = extract_thinking_content(steps) {
                                thinking_text = Some(tc.clone());
                                thinking_emitted = true;

                                // Emit full reasoning event sequence at output_index 0
                                // (the whole summary is sent as one delta, then closed).
                                yield Ok(responses_sse_event(
                                    "response.output_item.added",
                                    serde_json::json!({
                                        "type": "response.output_item.added",
                                        "sequence_number": next_seq(),
                                        "output_index": 0,
                                        "item": {
                                            "id": &reasoning_id,
                                            "type": "reasoning",
                                            "summary": [],
                                        },
                                    }),
                                ));
                                yield Ok(responses_sse_event(
                                    "response.reasoning_summary_part.added",
                                    serde_json::json!({
                                        "type": "response.reasoning_summary_part.added",
                                        "sequence_number": next_seq(),
                                        "item_id": &reasoning_id,
                                        "output_index": 0,
                                        "summary_index": 0,
                                        "part": { "type": "summary_text", "text": "" },
                                    }),
                                ));
                                yield Ok(responses_sse_event(
                                    "response.reasoning_summary_text.delta",
                                    serde_json::json!({
                                        "type": "response.reasoning_summary_text.delta",
                                        "sequence_number": next_seq(),
                                        "item_id": &reasoning_id,
                                        "output_index": 0,
                                        "summary_index": 0,
                                        "delta": &tc,
                                    }),
                                ));
                                yield Ok(responses_sse_event(
                                    "response.reasoning_summary_text.done",
                                    serde_json::json!({
                                        "type": "response.reasoning_summary_text.done",
                                        "sequence_number": next_seq(),
                                        "item_id": &reasoning_id,
                                        "output_index": 0,
                                        "summary_index": 0,
                                        "text": &tc,
                                    }),
                                ));
                                yield Ok(responses_sse_event(
                                    "response.reasoning_summary_part.done",
                                    serde_json::json!({
                                        "type": "response.reasoning_summary_part.done",
                                        "sequence_number": next_seq(),
                                        "item_id": &reasoning_id,
                                        "output_index": 0,
                                        "summary_index": 0,
                                        "part": { "type": "summary_text", "text": &tc },
                                    }),
                                ));
                                yield Ok(responses_sse_event(
                                    "response.output_item.done",
                                    serde_json::json!({
                                        "type": "response.output_item.done",
                                        "sequence_number": next_seq(),
                                        "output_index": 0,
                                        "item": {
                                            "id": &reasoning_id,
                                            "type": "reasoning",
                                            "summary": [{
                                                "type": "summary_text",
                                                "text": &tc,
                                            }],
                                        },
                                    }),
                                ));
                            }
                        }

                        // ── Phase 2: Stream text deltas ──
                        let text = extract_response_text(steps);
                        // The message sits after the reasoning item when one was emitted.
                        let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };

                        if !text.is_empty() && text != last_text {
                            // Emit message output_item.added on first text
                            if !message_started {
                                message_started = true;
                                yield Ok(responses_sse_event(
                                    "response.output_item.added",
                                    serde_json::json!({
                                        "type": "response.output_item.added",
                                        "sequence_number": next_seq(),
                                        "output_index": msg_output_index,
                                        "item": build_message_output_in_progress(&msg_id),
                                    }),
                                ));
                                yield Ok(responses_sse_event(
                                    "response.content_part.added",
                                    serde_json::json!({
                                        "type": "response.content_part.added",
                                        "sequence_number": next_seq(),
                                        "output_index": msg_output_index,
                                        "content_index": CONTENT_IDX,
                                        "part": {
                                            "type": "output_text",
                                            "text": "",
                                            "annotations": [],
                                        }
                                    }),
                                ));
                            }

                            // Normal case: the new text extends the old one, so only
                            // the suffix is sent. Otherwise (text was rewritten or
                            // shrank) the complete new text is sent as the delta.
                            let new_content = if text.len() > last_text.len() && text.starts_with(&*last_text) {
                                &text[last_text.len()..]
                            } else {
                                &text
                            };

                            if !new_content.is_empty() {
                                yield Ok(responses_sse_event(
                                    "response.output_text.delta",
                                    serde_json::json!({
                                        "type": "response.output_text.delta",
                                        "sequence_number": next_seq(),
                                        "item_id": &msg_id,
                                        "output_index": msg_output_index,
                                        "content_index": CONTENT_IDX,
                                        "delta": new_content,
                                    }),
                                ));
                                last_text = text.to_string();
                            }
                        }

                        // ── Check completion ──
                        // NOTE(review): completion requires non-empty text; a run that
                        // finishes without any text keeps looping until the timeout.
                        if is_response_done(steps) && !last_text.is_empty() {
                            debug!("Response done, text length={}", last_text.len());
                            let mu = extract_model_usage(steps);
                            let msg_idx: u32 = if thinking_emitted { 1 } else { 0 };
                            let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &mu, ¶ms.user_text, &last_text).await;
                            let ts = extract_thinking_signature(steps);
                            // Use already-captured thinking, or MITM thinking, or LS thinking
                            // NOTE(review): when thinking is first obtained here (not
                            // streamed earlier), the final output array gains a reasoning
                            // item although no reasoning events were sent and `msg_idx`
                            // stays 0 — confirm this is acceptable to clients.
                            let tc = thinking_text.clone()
                                .or(mitm_thinking)
                                .or_else(|| extract_thinking_content(steps));
                            for evt in completion_events(
                                &response_id, &model_name, &msg_id, &reasoning_id,
                                msg_idx, CONTENT_IDX, &last_text, usage,
                                created_at, &seq, ¶ms, ts, tc,
                            ) {
                                yield Ok(evt);
                            }
                            return;
                        }

                        // IDLE fallback
                        // Only consulted when the step count is above 4 and a multiple
                        // of 5, which limits how often the trajectory is fetched.
                        let step_count = steps.len();
                        if step_count > 4 && step_count % 5 == 0 {
                            if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
                                if ts == 200 {
                                    let run_status = td["status"].as_str().unwrap_or("");
                                    if run_status.contains("IDLE") && !last_text.is_empty() {
                                        debug!("Trajectory IDLE, text length={}", last_text.len());
                                        // Same completion sequence as the "done" branch above.
                                        let mu = extract_model_usage(steps);
                                        let msg_idx: u32 = if thinking_emitted { 1 } else { 0 };
                                        let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &mu, ¶ms.user_text, &last_text).await;
                                        // `ts` is rebound here: HTTP status above, thinking signature below.
                                        let ts = extract_thinking_signature(steps);
                                        let tc = thinking_text.clone()
                                            .or(mitm_thinking)
                                            .or_else(|| extract_thinking_content(steps));
                                        for evt in completion_events(
                                            &response_id, &model_name, &msg_id, &reasoning_id,
                                            msg_idx, CONTENT_IDX, &last_text, usage,
                                            created_at, &seq, ¶ms, ts, tc,
                                        ) {
                                            yield Ok(evt);
                                        }
                                        return;
                                    }
                                }
                            }
                        }
                    }
                }
            }

            // Wait for next update: either reactive notification or fallback timer
            match reactive_rx {
                Some(ref mut rx) => {
                    // Wait for reactive notification with a safety timeout
                    // (this local `timeout` shadows the function parameter inside this arm).
                    let timeout = tokio::time::timeout(
                        tokio::time::Duration::from_millis(500),
                        rx.recv(),
                    ).await;
                    match timeout {
                        Ok(Some(_diff)) => {
                            // Drain any additional queued notifications (coalesce)
                            while rx.try_recv().is_ok() {}
                        }
                        Ok(None) => {
                            // Stream closed — fall back to polling
                            debug!("Reactive stream closed, falling back to polling");
                            reactive_rx = None;
                        }
                        Err(_) => {} // timeout — fetch anyway as safety net
                    }
                }
                None => {
                    // Fallback: timer-based polling with a randomized 150–249 ms interval
                    let poll_ms: u64 = rand::thread_rng().gen_range(150..250);
                    tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
                }
            }
        }

        // Timeout — emit incomplete response
        // NOTE(review): the response status is "incomplete" but the event is named
        // "response.completed"; the OpenAI streaming reference lists a separate
        // `response.incomplete` event — confirm which one clients expect. The
        // output array is empty and usage is estimated against an empty output
        // even if text deltas were already streamed, and no "done" events are
        // sent for an already-opened message item.
        let timeout_resp = build_response_object(
            ResponseData {
                id: response_id.clone(),
                model: model_name.clone(),
                status: "incomplete",
                created_at,
                completed_at: None,
                output: vec![],
                usage: Some(Usage::estimate(¶ms.user_text, "")),
                thinking_signature: None,
            },
            ¶ms,
        );
        yield Ok(responses_sse_event(
            "response.completed",
            serde_json::json!({
                "type": "response.completed",
                "sequence_number": next_seq(),
                "response": response_to_json(&timeout_resp),
            }),
        ));
    };

    // Keep-alive every 15 s with empty text while the stream is otherwise quiet.
    Sse::new(stream)
        .keep_alive(
            axum::response::sse::KeepAlive::new()
                .interval(std::time::Duration::from_secs(15))
                .text(""),
        )
        .into_response()
}

// ─── SSE completion events ───────────────────────────────────────────────────

/// Build the final SSE events at completion time.
///
/// Reasoning events were already streamed during polling (when thinking
/// appeared in LS steps before response text). Message output_item.added
/// and content_part.added were also emitted when text first appeared.
///
/// This function emits only the "done" events plus the final response.completed.
#[allow(clippy::too_many_arguments)] fn completion_events( resp_id: &str, model: &str, msg_id: &str, reasoning_id: &str, msg_output_index: u32, content_idx: u32, text: &str, usage: Usage, created_at: u64, seq: &AtomicU32, params: &RequestParams, thinking_signature: Option, thinking: Option, ) -> Vec { let next_seq = || seq.fetch_add(1, Ordering::Relaxed); let completed_at = now_unix(); let output_item = build_message_output(msg_id, text); // Build output array: [reasoning (if present), message] let mut output_items: Vec = Vec::new(); if let Some(ref thinking_text) = thinking { output_items.push(serde_json::json!({ "id": reasoning_id, "type": "reasoning", "summary": [{ "type": "summary_text", "text": thinking_text, }], })); } output_items.push(build_message_output(msg_id, text)); let completed_resp = build_response_object( ResponseData { id: resp_id.to_string(), model: model.to_string(), status: "completed", created_at, completed_at: Some(completed_at), output: output_items, usage: Some(usage), thinking_signature, }, params, ); let mut events: Vec = Vec::new(); // Message done events events.push(responses_sse_event( "response.output_text.done", serde_json::json!({ "type": "response.output_text.done", "sequence_number": next_seq(), "item_id": msg_id, "output_index": msg_output_index, "content_index": content_idx, "text": text, }), )); events.push(responses_sse_event( "response.content_part.done", serde_json::json!({ "type": "response.content_part.done", "sequence_number": next_seq(), "output_index": msg_output_index, "content_index": content_idx, "part": { "type": "output_text", "text": text, "annotations": [] }, }), )); events.push(responses_sse_event( "response.output_item.done", serde_json::json!({ "type": "response.output_item.done", "sequence_number": next_seq(), "output_index": msg_output_index, "item": output_item, }), )); events.push(responses_sse_event( "response.completed", serde_json::json!({ "type": "response.completed", "sequence_number": next_seq(), 
"response": response_to_json(&completed_resp), }), )); events }