- Responses API (streaming): MITM bypass path polls MitmStore directly when custom tools are active, skipping LS step polling entirely. Streams thinking text deltas in real-time as they arrive from the MITM. Handles function calls, text response, and thinking/reasoning events. - Responses API (sync): Same MITM bypass for non-streaming responses. Polls MitmStore for function calls or completed text before falling back to LS path. - Gemini endpoint: MITM bypass polls MitmStore directly for tool call responses, eliminating LS overhead. - MitmStore: Added captured_thinking_text field with set/peek/take methods for real-time thinking text capture from MITM SSE. - MITM proxy: Now captures both thinking_text and response_text from StreamingAccumulator into MitmStore when bypass mode is active.
306 lines
9.8 KiB
Rust
306 lines
9.8 KiB
Rust
//! Gemini-native endpoint (/v1/gemini) — zero-translation tool call passthrough.
//!
//! Accepts tools in Gemini `functionDeclarations` format directly and
//! returns `functionCall` parts in Gemini format directly;
//! no OpenAI ↔ Gemini format conversion is performed.
|
|
|
|
use axum::{
|
|
extract::State,
|
|
http::StatusCode,
|
|
response::{IntoResponse, Json},
|
|
};
|
|
use std::sync::Arc;
|
|
use tracing::info;
|
|
|
|
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
|
use super::polling::poll_for_response;
|
|
use super::util::{err_response, now_unix};
|
|
use super::AppState;
|
|
use crate::mitm::store::PendingToolResult;
|
|
|
|
/// Gemini-native request format.
|
|
#[derive(serde::Deserialize)]
|
|
pub(crate) struct GeminiRequest {
|
|
pub model: Option<String>,
|
|
/// User input text.
|
|
pub input: serde_json::Value,
|
|
/// Gemini-native tools: [{"functionDeclarations": [...]}]
|
|
#[serde(default)]
|
|
pub tools: Option<Vec<serde_json::Value>>,
|
|
/// Gemini-native toolConfig: {"functionCallingConfig": {"mode": "AUTO"}}
|
|
#[serde(default)]
|
|
pub tool_config: Option<serde_json::Value>,
|
|
/// Session/conversation ID.
|
|
#[serde(default)]
|
|
pub conversation: Option<serde_json::Value>,
|
|
#[serde(default = "default_timeout")]
|
|
pub timeout: u64,
|
|
#[serde(default)]
|
|
pub stream: bool,
|
|
/// Tool results in Gemini format: [{"functionResponse": {"name": "...", "response": {...}}}]
|
|
#[serde(default)]
|
|
pub tool_results: Option<Vec<serde_json::Value>>,
|
|
}
|
|
|
|
/// Serde default for `GeminiRequest::timeout`: 120 seconds.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 120;
    DEFAULT_TIMEOUT_SECS
}
|
|
|
|
/// Pulls the conversation/session ID out of the optional `conversation` field.
///
/// Accepts either a bare JSON string (`"abc"`) or a value carrying a string `id`
/// (`{"id": "abc"}`). Anything else, including a missing field, yields `None`.
fn extract_conversation_id(conv: &Option<serde_json::Value>) -> Option<String> {
    let value = conv.as_ref()?;
    if let serde_json::Value::String(id) = value {
        return Some(id.clone());
    }
    value
        .get("id")
        .and_then(serde_json::Value::as_str)
        .map(String::from)
}
|
|
|
|
pub(crate) async fn handle_gemini(
|
|
State(state): State<Arc<AppState>>,
|
|
Json(body): Json<GeminiRequest>,
|
|
) -> axum::response::Response {
|
|
info!(
|
|
"POST /v1/gemini model={} stream={}",
|
|
body.model.as_deref().unwrap_or(DEFAULT_MODEL),
|
|
body.stream
|
|
);
|
|
|
|
let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
|
|
let model = match lookup_model(model_name) {
|
|
Some(m) => m,
|
|
None => {
|
|
let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
|
|
return err_response(
|
|
StatusCode::BAD_REQUEST,
|
|
format!("Unknown model: {model_name}. Available: {names:?}"),
|
|
"invalid_request_error",
|
|
);
|
|
}
|
|
};
|
|
|
|
let token = state.backend.oauth_token().await;
|
|
if token.is_empty() {
|
|
return err_response(
|
|
StatusCode::UNAUTHORIZED,
|
|
"No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
|
|
"authentication_error",
|
|
);
|
|
}
|
|
|
|
// Extract user text
|
|
let user_text = match &body.input {
|
|
serde_json::Value::String(s) => s.clone(),
|
|
_ => {
|
|
return err_response(
|
|
StatusCode::BAD_REQUEST,
|
|
"Gemini endpoint requires input as a string".to_string(),
|
|
"invalid_request_error",
|
|
);
|
|
}
|
|
};
|
|
|
|
// Store tools directly in Gemini format (no conversion needed!)
|
|
if let Some(ref tools) = body.tools {
|
|
if !tools.is_empty() {
|
|
state.mitm_store.set_tools(tools.clone()).await;
|
|
info!(count = tools.len(), "Stored Gemini-native tools for MITM injection");
|
|
}
|
|
}
|
|
if let Some(ref config) = body.tool_config {
|
|
state.mitm_store.set_tool_config(config.clone()).await;
|
|
}
|
|
|
|
// Handle tool results (Gemini format: functionResponse)
|
|
if let Some(ref results) = body.tool_results {
|
|
for r in results {
|
|
if let Some(fr) = r.get("functionResponse") {
|
|
let name = fr["name"].as_str().unwrap_or("unknown").to_string();
|
|
let response = fr.get("response").cloned().unwrap_or(serde_json::json!({}));
|
|
state.mitm_store.add_tool_result(PendingToolResult {
|
|
name,
|
|
result: response,
|
|
}).await;
|
|
}
|
|
}
|
|
info!(count = results.len(), "Stored Gemini-native tool results for MITM injection");
|
|
}
|
|
|
|
// Session/conversation management
|
|
let session_id_str = extract_conversation_id(&body.conversation);
|
|
let cascade_id = if let Some(ref sid) = session_id_str {
|
|
match state
|
|
.sessions
|
|
.get_or_create(Some(sid), || state.backend.create_cascade())
|
|
.await
|
|
{
|
|
Ok(sr) => sr.cascade_id,
|
|
Err(e) => {
|
|
return err_response(
|
|
StatusCode::BAD_GATEWAY,
|
|
format!("StartCascade failed: {e}"),
|
|
"server_error",
|
|
);
|
|
}
|
|
}
|
|
} else {
|
|
match state.backend.create_cascade().await {
|
|
Ok(cid) => cid,
|
|
Err(e) => {
|
|
return err_response(
|
|
StatusCode::BAD_GATEWAY,
|
|
format!("StartCascade failed: {e}"),
|
|
"server_error",
|
|
);
|
|
}
|
|
}
|
|
};
|
|
|
|
// Send message
|
|
match state
|
|
.backend
|
|
.send_message(&cascade_id, &user_text, model.model_enum)
|
|
.await
|
|
{
|
|
Ok((200, _)) => {
|
|
let bg = Arc::clone(&state.backend);
|
|
let cid = cascade_id.clone();
|
|
tokio::spawn(async move {
|
|
let _ = bg.update_annotations(&cid).await;
|
|
});
|
|
}
|
|
Ok((status, _)) => {
|
|
return err_response(
|
|
StatusCode::BAD_GATEWAY,
|
|
format!("Antigravity returned {status}"),
|
|
"server_error",
|
|
);
|
|
}
|
|
Err(e) => {
|
|
return err_response(
|
|
StatusCode::BAD_GATEWAY,
|
|
format!("Send message failed: {e}"),
|
|
"server_error",
|
|
);
|
|
}
|
|
}
|
|
|
|
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
|
|
|
|
// Clear stale response
|
|
state.mitm_store.clear_response_async().await;
|
|
|
|
// ── MITM bypass: when tools active, poll MitmStore directly ──
|
|
if has_custom_tools {
|
|
let start = std::time::Instant::now();
|
|
while start.elapsed().as_secs() < body.timeout {
|
|
// Check for function calls
|
|
let captured = state.mitm_store.take_any_function_calls().await;
|
|
if let Some(ref calls) = captured {
|
|
if !calls.is_empty() {
|
|
let parts: Vec<serde_json::Value> = calls
|
|
.iter()
|
|
.map(|fc| {
|
|
serde_json::json!({
|
|
"functionCall": {
|
|
"name": fc.name,
|
|
"args": fc.args,
|
|
}
|
|
})
|
|
})
|
|
.collect();
|
|
|
|
return Json(serde_json::json!({
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": parts,
|
|
"role": "model",
|
|
},
|
|
"finishReason": "STOP",
|
|
}],
|
|
"modelVersion": model_name,
|
|
}))
|
|
.into_response();
|
|
}
|
|
}
|
|
|
|
// Check for completed text response
|
|
if state.mitm_store.is_response_complete() {
|
|
let text = state.mitm_store.take_response_text().await.unwrap_or_default();
|
|
return Json(serde_json::json!({
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": [{"text": text}],
|
|
"role": "model",
|
|
},
|
|
"finishReason": "STOP",
|
|
}],
|
|
"modelVersion": model_name,
|
|
}))
|
|
.into_response();
|
|
}
|
|
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
|
}
|
|
|
|
// Timeout
|
|
return Json(serde_json::json!({
|
|
"error": {
|
|
"message": "Request timed out",
|
|
"type": "timeout_error",
|
|
}
|
|
}))
|
|
.into_response();
|
|
}
|
|
|
|
// ── Normal LS path (no custom tools) ──
|
|
// Poll for response
|
|
let poll_result = poll_for_response(&state, &cascade_id, body.timeout).await;
|
|
|
|
// Check for captured function calls — return in Gemini format
|
|
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
|
|
|
|
if let Some(ref calls) = captured_tool_calls {
|
|
info!(
|
|
count = calls.len(),
|
|
tools = ?calls.iter().map(|c| &c.name).collect::<Vec<_>>(),
|
|
"Returning captured function calls (Gemini format)"
|
|
);
|
|
|
|
let parts: Vec<serde_json::Value> = calls
|
|
.iter()
|
|
.map(|fc| {
|
|
serde_json::json!({
|
|
"functionCall": {
|
|
"name": fc.name,
|
|
"args": fc.args,
|
|
}
|
|
})
|
|
})
|
|
.collect();
|
|
|
|
return Json(serde_json::json!({
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": parts,
|
|
"role": "model",
|
|
},
|
|
"finishReason": "STOP",
|
|
}],
|
|
"modelVersion": model_name,
|
|
}))
|
|
.into_response();
|
|
}
|
|
|
|
// Normal text response
|
|
Json(serde_json::json!({
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": [{"text": poll_result.text}],
|
|
"role": "model",
|
|
},
|
|
"finishReason": "STOP",
|
|
}],
|
|
"modelVersion": model_name,
|
|
}))
|
|
.into_response()
|
|
}
|