Files
zerogravity/src/api/gemini.rs
Nikketryhard b3af73cebd feat: sync all endpoints with MITM LS bypass + real-time thinking streaming
- Responses API (streaming): MITM bypass path polls MitmStore directly
  when custom tools are active, skipping LS step polling entirely.
  Streams thinking text deltas in real-time as they arrive from the MITM.
  Handles function calls, text response, and thinking/reasoning events.

- Responses API (sync): Same MITM bypass for non-streaming responses.
  Polls MitmStore for function calls or completed text before falling
  back to LS path.

- Gemini endpoint: MITM bypass polls MitmStore directly for tool call
  responses, eliminating LS overhead.

- MitmStore: Added captured_thinking_text field with set/peek/take methods
  for real-time thinking text capture from MITM SSE.

- MITM proxy: Now captures both thinking_text and response_text from
  StreamingAccumulator into MitmStore when bypass mode is active.
2026-02-15 01:03:39 -06:00

306 lines
9.8 KiB
Rust

//! Gemini-native endpoint (/v1/gemini) — zero-translation tool call passthrough.
//!
//! Accepts tools in Gemini `functionDeclarations` format directly,
//! returns `functionCall` in Gemini format directly.
//! No OpenAI ↔ Gemini format conversion.
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Json},
};
use std::sync::Arc;
use tracing::info;
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::poll_for_response;
use super::util::{err_response, now_unix};
use super::AppState;
use crate::mitm::store::PendingToolResult;
/// Request body for `POST /v1/gemini`, in Gemini-native format.
///
/// Tools, tool config and tool results are passed through to the MITM store
/// as raw JSON, without any OpenAI ↔ Gemini conversion.
#[derive(serde::Deserialize)]
pub(crate) struct GeminiRequest {
    /// Model name; when absent the handler falls back to `DEFAULT_MODEL`.
    pub model: Option<String>,
    /// User input. `handle_gemini` currently only accepts a JSON string here.
    pub input: serde_json::Value,
    /// Gemini-native tools: `[{"functionDeclarations": [...]}]`.
    #[serde(default)]
    pub tools: Option<Vec<serde_json::Value>>,
    /// Gemini-native tool config: `{"functionCallingConfig": {"mode": "AUTO"}}`.
    ///
    /// Accepted under both `tool_config` and Gemini's own camelCase key
    /// `toolConfig`, so requests written against the Gemini API are not
    /// silently stripped of their tool config.
    #[serde(default, alias = "toolConfig")]
    pub tool_config: Option<serde_json::Value>,
    /// Session/conversation ID: either a string or an object with a string `id`.
    #[serde(default)]
    pub conversation: Option<serde_json::Value>,
    /// Maximum time to wait for a response, in seconds (defaults to 120).
    #[serde(default = "default_timeout")]
    pub timeout: u64,
    /// Streaming flag. NOTE(review): `handle_gemini` only logs this value;
    /// responses are always returned as a single JSON body.
    #[serde(default)]
    pub stream: bool,
    /// Tool results in Gemini format:
    /// `[{"functionResponse": {"name": "...", "response": {...}}}]`.
    ///
    /// Accepted under both `tool_results` and camelCase `toolResults`.
    #[serde(default, alias = "toolResults")]
    pub tool_results: Option<Vec<serde_json::Value>>,
}
/// Serde default for [`GeminiRequest::timeout`]: two minutes, in seconds.
fn default_timeout() -> u64 {
    const TWO_MINUTES_SECS: u64 = 2 * 60;
    TWO_MINUTES_SECS
}
/// Pulls a session/conversation ID out of the request's `conversation` field.
///
/// Accepts either a bare string (`"abc"`) or an object carrying a string `id`
/// (`{"id": "abc"}`). Any other shape, or a missing field, yields `None`.
fn extract_conversation_id(conv: &Option<serde_json::Value>) -> Option<String> {
    let value = conv.as_ref()?;
    if let serde_json::Value::String(id) = value {
        return Some(id.clone());
    }
    value
        .get("id")
        .and_then(serde_json::Value::as_str)
        .map(str::to_owned)
}
pub(crate) async fn handle_gemini(
State(state): State<Arc<AppState>>,
Json(body): Json<GeminiRequest>,
) -> axum::response::Response {
info!(
"POST /v1/gemini model={} stream={}",
body.model.as_deref().unwrap_or(DEFAULT_MODEL),
body.stream
);
let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
let model = match lookup_model(model_name) {
Some(m) => m,
None => {
let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
return err_response(
StatusCode::BAD_REQUEST,
format!("Unknown model: {model_name}. Available: {names:?}"),
"invalid_request_error",
);
}
};
let token = state.backend.oauth_token().await;
if token.is_empty() {
return err_response(
StatusCode::UNAUTHORIZED,
"No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
"authentication_error",
);
}
// Extract user text
let user_text = match &body.input {
serde_json::Value::String(s) => s.clone(),
_ => {
return err_response(
StatusCode::BAD_REQUEST,
"Gemini endpoint requires input as a string".to_string(),
"invalid_request_error",
);
}
};
// Store tools directly in Gemini format (no conversion needed!)
if let Some(ref tools) = body.tools {
if !tools.is_empty() {
state.mitm_store.set_tools(tools.clone()).await;
info!(count = tools.len(), "Stored Gemini-native tools for MITM injection");
}
}
if let Some(ref config) = body.tool_config {
state.mitm_store.set_tool_config(config.clone()).await;
}
// Handle tool results (Gemini format: functionResponse)
if let Some(ref results) = body.tool_results {
for r in results {
if let Some(fr) = r.get("functionResponse") {
let name = fr["name"].as_str().unwrap_or("unknown").to_string();
let response = fr.get("response").cloned().unwrap_or(serde_json::json!({}));
state.mitm_store.add_tool_result(PendingToolResult {
name,
result: response,
}).await;
}
}
info!(count = results.len(), "Stored Gemini-native tool results for MITM injection");
}
// Session/conversation management
let session_id_str = extract_conversation_id(&body.conversation);
let cascade_id = if let Some(ref sid) = session_id_str {
match state
.sessions
.get_or_create(Some(sid), || state.backend.create_cascade())
.await
{
Ok(sr) => sr.cascade_id,
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("StartCascade failed: {e}"),
"server_error",
);
}
}
} else {
match state.backend.create_cascade().await {
Ok(cid) => cid,
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("StartCascade failed: {e}"),
"server_error",
);
}
}
};
// Send message
match state
.backend
.send_message(&cascade_id, &user_text, model.model_enum)
.await
{
Ok((200, _)) => {
let bg = Arc::clone(&state.backend);
let cid = cascade_id.clone();
tokio::spawn(async move {
let _ = bg.update_annotations(&cid).await;
});
}
Ok((status, _)) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("Antigravity returned {status}"),
"server_error",
);
}
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("Send message failed: {e}"),
"server_error",
);
}
}
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
// Clear stale response
state.mitm_store.clear_response_async().await;
// ── MITM bypass: when tools active, poll MitmStore directly ──
if has_custom_tools {
let start = std::time::Instant::now();
while start.elapsed().as_secs() < body.timeout {
// Check for function calls
let captured = state.mitm_store.take_any_function_calls().await;
if let Some(ref calls) = captured {
if !calls.is_empty() {
let parts: Vec<serde_json::Value> = calls
.iter()
.map(|fc| {
serde_json::json!({
"functionCall": {
"name": fc.name,
"args": fc.args,
}
})
})
.collect();
return Json(serde_json::json!({
"candidates": [{
"content": {
"parts": parts,
"role": "model",
},
"finishReason": "STOP",
}],
"modelVersion": model_name,
}))
.into_response();
}
}
// Check for completed text response
if state.mitm_store.is_response_complete() {
let text = state.mitm_store.take_response_text().await.unwrap_or_default();
return Json(serde_json::json!({
"candidates": [{
"content": {
"parts": [{"text": text}],
"role": "model",
},
"finishReason": "STOP",
}],
"modelVersion": model_name,
}))
.into_response();
}
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
}
// Timeout
return Json(serde_json::json!({
"error": {
"message": "Request timed out",
"type": "timeout_error",
}
}))
.into_response();
}
// ── Normal LS path (no custom tools) ──
// Poll for response
let poll_result = poll_for_response(&state, &cascade_id, body.timeout).await;
// Check for captured function calls — return in Gemini format
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
if let Some(ref calls) = captured_tool_calls {
info!(
count = calls.len(),
tools = ?calls.iter().map(|c| &c.name).collect::<Vec<_>>(),
"Returning captured function calls (Gemini format)"
);
let parts: Vec<serde_json::Value> = calls
.iter()
.map(|fc| {
serde_json::json!({
"functionCall": {
"name": fc.name,
"args": fc.args,
}
})
})
.collect();
return Json(serde_json::json!({
"candidates": [{
"content": {
"parts": parts,
"role": "model",
},
"finishReason": "STOP",
}],
"modelVersion": model_name,
}))
.into_response();
}
// Normal text response
Json(serde_json::json!({
"candidates": [{
"content": {
"parts": [{"text": poll_result.text}],
"role": "model",
},
"finishReason": "STOP",
}],
"modelVersion": model_name,
}))
.into_response()
}