//! Gemini-native endpoint (/v1/gemini) — zero-translation tool call passthrough. //! //! Accepts tools in Gemini `functionDeclarations` format directly, //! returns `functionCall` in Gemini format directly. //! No OpenAI ↔ Gemini format conversion. use axum::{ extract::State, http::StatusCode, response::{IntoResponse, Json}, }; use std::sync::Arc; use tracing::info; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::polling::poll_for_response; use super::util::{err_response, now_unix}; use super::AppState; use crate::mitm::store::PendingToolResult; /// Gemini-native request format. #[derive(serde::Deserialize)] pub(crate) struct GeminiRequest { pub model: Option, /// User input text. pub input: serde_json::Value, /// Gemini-native tools: [{"functionDeclarations": [...]}] #[serde(default)] pub tools: Option>, /// Gemini-native toolConfig: {"functionCallingConfig": {"mode": "AUTO"}} #[serde(default)] pub tool_config: Option, /// Session/conversation ID. #[serde(default)] pub conversation: Option, #[serde(default = "default_timeout")] pub timeout: u64, #[serde(default)] pub stream: bool, /// Tool results in Gemini format: [{"functionResponse": {"name": "...", "response": {...}}}] #[serde(default)] pub tool_results: Option>, } fn default_timeout() -> u64 { 120 } fn extract_conversation_id(conv: &Option) -> Option { match conv { Some(serde_json::Value::String(s)) => Some(s.clone()), Some(obj) => obj["id"].as_str().map(|s| s.to_string()), None => None, } } pub(crate) async fn handle_gemini( State(state): State>, Json(body): Json, ) -> axum::response::Response { info!( "POST /v1/gemini model={} stream={}", body.model.as_deref().unwrap_or(DEFAULT_MODEL), body.stream ); let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL); let model = match lookup_model(model_name) { Some(m) => m, None => { let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect(); return err_response( StatusCode::BAD_REQUEST, format!("Unknown model: {model_name}. Available: {names:?}"), "invalid_request_error", ); } }; let token = state.backend.oauth_token().await; if token.is_empty() { return err_response( StatusCode::UNAUTHORIZED, "No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(), "authentication_error", ); } // Extract user text let user_text = match &body.input { serde_json::Value::String(s) => s.clone(), _ => { return err_response( StatusCode::BAD_REQUEST, "Gemini endpoint requires input as a string".to_string(), "invalid_request_error", ); } }; // Store tools directly in Gemini format (no conversion needed!) if let Some(ref tools) = body.tools { if !tools.is_empty() { state.mitm_store.set_tools(tools.clone()).await; info!(count = tools.len(), "Stored Gemini-native tools for MITM injection"); } } if let Some(ref config) = body.tool_config { state.mitm_store.set_tool_config(config.clone()).await; } // Handle tool results (Gemini format: functionResponse) if let Some(ref results) = body.tool_results { for r in results { if let Some(fr) = r.get("functionResponse") { let name = fr["name"].as_str().unwrap_or("unknown").to_string(); let response = fr.get("response").cloned().unwrap_or(serde_json::json!({})); state.mitm_store.add_tool_result(PendingToolResult { name, result: response, }).await; } } info!(count = results.len(), "Stored Gemini-native tool results for MITM injection"); } // Session/conversation management let session_id_str = extract_conversation_id(&body.conversation); let cascade_id = if let Some(ref sid) = session_id_str { match state .sessions .get_or_create(Some(sid), || state.backend.create_cascade()) .await { Ok(sr) => sr.cascade_id, Err(e) => { return err_response( StatusCode::BAD_GATEWAY, format!("StartCascade failed: {e}"), "server_error", ); } } } else { match state.backend.create_cascade().await { Ok(cid) => cid, Err(e) => { return err_response( StatusCode::BAD_GATEWAY, format!("StartCascade failed: {e}"), "server_error", ); } } }; // Send message match state .backend .send_message(&cascade_id, &user_text, model.model_enum) .await { Ok((200, _)) => { let bg = Arc::clone(&state.backend); let cid = cascade_id.clone(); tokio::spawn(async move { let _ = bg.update_annotations(&cid).await; }); } Ok((status, _)) => { return err_response( StatusCode::BAD_GATEWAY, format!("Antigravity returned {status}"), "server_error", ); } Err(e) => { return err_response( StatusCode::BAD_GATEWAY, format!("Send message failed: {e}"), "server_error", ); } } let has_custom_tools = state.mitm_store.get_tools().await.is_some(); // Clear stale response state.mitm_store.clear_response_async().await; // ── MITM bypass: when tools active, poll MitmStore directly ── if has_custom_tools { let start = std::time::Instant::now(); while start.elapsed().as_secs() < body.timeout { // Check for function calls let captured = state.mitm_store.take_any_function_calls().await; if let Some(ref calls) = captured { if !calls.is_empty() { let parts: Vec = calls .iter() .map(|fc| { serde_json::json!({ "functionCall": { "name": fc.name, "args": fc.args, } }) }) .collect(); return Json(serde_json::json!({ "candidates": [{ "content": { "parts": parts, "role": "model", }, "finishReason": "STOP", }], "modelVersion": model_name, })) .into_response(); } } // Check for completed text response if state.mitm_store.is_response_complete() { let text = state.mitm_store.take_response_text().await.unwrap_or_default(); return Json(serde_json::json!({ "candidates": [{ "content": { "parts": [{"text": text}], "role": "model", }, "finishReason": "STOP", }], "modelVersion": model_name, })) .into_response(); } tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; } // Timeout return Json(serde_json::json!({ "error": { "message": "Request timed out", "type": "timeout_error", } })) .into_response(); } // ── Normal LS path (no custom tools) ── // Poll for response let poll_result = poll_for_response(&state, &cascade_id, body.timeout).await; // Check for captured function calls — return in Gemini format let captured_tool_calls = state.mitm_store.take_any_function_calls().await; if let Some(ref calls) = captured_tool_calls { info!( count = calls.len(), tools = ?calls.iter().map(|c| &c.name).collect::>(), "Returning captured function calls (Gemini format)" ); let parts: Vec = calls .iter() .map(|fc| { serde_json::json!({ "functionCall": { "name": fc.name, "args": fc.args, } }) }) .collect(); return Json(serde_json::json!({ "candidates": [{ "content": { "parts": parts, "role": "model", }, "finishReason": "STOP", }], "modelVersion": model_name, })) .into_response(); } // Normal text response Json(serde_json::json!({ "candidates": [{ "content": { "parts": [{"text": poll_result.text}], "role": "model", }, "finishReason": "STOP", }], "modelVersion": model_name, })) .into_response() }