diff --git a/.gemini/plans/tool-calls-implementation.md b/.gemini/plans/tool-calls-implementation.md new file mode 100644 index 0000000..a3f7beb --- /dev/null +++ b/.gemini/plans/tool-calls-implementation.md @@ -0,0 +1,292 @@ +# Tool Call Implementation Plan + +## Overview + +Add full tool call support to the Antigravity proxy. Primary endpoint is OpenAI Responses API (`/v1/responses`), with a Gemini-native backup endpoint (`/v1/gemini`). Tools are stored per-session, all `tool_choice` modes supported, parallel tool calls supported. + +## Data Flow + +``` +┌─────────┐ ┌───────────┐ ┌────┐ ┌──────┐ ┌────────┐ +│ Client │─────▶│ Proxy │─────▶│ LS │─────▶│ MITM │─────▶│ Google │ +│ (openai) │ │ (axum) │ │ │ │ │ │ │ +│ │◀─────│ │◀─────│ │◀─────│ │◀─────│ │ +└─────────┘ └───────────┘ └────┘ └──────┘ └────────┘ + │ │ │ │ + │ tools (OAI) │ store tools (Gemini fmt) │ inject │ + │───────────────▶│────────────▶ MitmStore ─────▶│ tools │ + │ │ │──────────────▶│ + │ │ │ │ + │ │ │ functionCall │ + │ │◀──── capture ───────────────│◀──────────────│ + │ tool_calls │ │ block follow │ + │◀───────────────│ │ ups │ + │ │ │ │ + │ tool result │ store result │ inject │ + │───────────────▶│────────────▶ MitmStore ─────▶│ fn response │ + │ │ │──────────────▶│ + │ final text │ │ │ + │◀───────────────│◀────────────────────────────│◀──────────────│ +``` + +## Format Differences + +### Tool Definitions + +| Aspect | OpenAI | Gemini | +| ------------ | -------------------------------------- | ---------------------------------- | +| Wrapper | `{"type":"function","function":{...}}` | `{"functionDeclarations":[{...}]}` | +| Type strings | lowercase: `"object"`, `"string"` | UPPERCASE: `"OBJECT"`, `"STRING"` | +| Parameters | JSON Schema subset | Same schema, uppercase types | + +### Tool Choice + +| OpenAI | Gemini toolConfig | +| --------------------------------------------- | ----------------------------------------------------------------------- | +| `"auto"` | `{"functionCallingConfig":{"mode":"AUTO"}}` | +| `"required"` | `{"functionCallingConfig":{"mode":"ANY"}}` | +| `"none"` | `{"functionCallingConfig":{"mode":"NONE"}}` | +| `{"type":"function","function":{"name":"X"}}` | `{"functionCallingConfig":{"mode":"ANY","allowedFunctionNames":["X"]}}` | + +### Tool Call Response + +| OpenAI (what we return) | Gemini (what Google returns) | +| -------------------------------------------------------------------------------------------------- | --------------------------------------------------------------- | +| `output: [{"type":"function_call","call_id":"call_xxx","name":"get_weather","arguments":"{...}"}]` | `parts: [{"functionCall":{"name":"get_weather","args":{...}}}]` | + +### Tool Result Submission + +| OpenAI (what client sends) | Gemini (what we inject into Google request) | +| -------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | +| `input: [{"type":"function_call_output","call_id":"call_xxx","output":"{...}"}]` | `contents: [{role:"model",parts:[{functionCall:...}]},{role:"user",parts:[{functionResponse:{name:"...",response:{...}}}]}]` | + +--- + +## Implementation Phases + +### Phase 1: Store Infrastructure (`store.rs`) + +Add to `MitmStore`: + +```rust +/// Active tool definitions (Gemini format) for MITM injection. +active_tools: Arc>>>, +/// Active tool config (Gemini toolConfig format). +active_tool_config: Arc>>, +/// Pending tool results for MITM to inject as functionResponse. +pending_tool_results: Arc>>, +/// Mapping call_id → function name for tool result routing. +call_id_to_name: Arc>>, +/// Last captured function calls (for conversation history rewriting). +last_function_calls: Arc>>, +``` + +New types: + +```rust +pub struct PendingToolResult { + pub name: String, + pub result: serde_json::Value, +} +``` + +New methods: + +- `set_tools(tools)` / `get_tools()` / `clear_tools()` +- `set_tool_config(config)` / `get_tool_config()` +- `add_tool_result(result)` / `take_tool_results()` +- `register_call_id(call_id, name)` / `lookup_call_id(call_id)` +- `set_last_function_calls(calls)` / `get_last_function_calls()` + +### Phase 2: Request Types (`types.rs`) + +Add to `ResponsesRequest`: + +```rust +#[serde(default)] +pub tools: Option>, +#[serde(default)] +pub tool_choice: Option, +``` + +New output builder: + +```rust +pub fn build_function_call_output(call_id: &str, name: &str, arguments: &str) -> Value +``` + +### Phase 3: Format Conversion + Dynamic Injection (`modify.rs`) + +New public struct: + +```rust +pub struct ToolContext { + pub tools: Option>, // Gemini functionDeclarations + pub tool_config: Option, // Gemini toolConfig + pub pending_results: Vec, // Tool results to inject + pub last_calls: Vec, // For history rewriting +} +``` + +New conversion functions: + +```rust +pub fn openai_tools_to_gemini(tools: &[Value]) -> Vec // OAI → Gemini format +pub fn openai_tool_choice_to_gemini(choice: &Value) -> Value // OAI → Gemini toolConfig +fn uppercase_types(val: Value) -> Value // Recursive type case fix +``` + +Change `modify_request` signature: + +```rust +pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option> +``` + +Tool injection logic: + +1. Strip all LS tools (existing) +2. If `tool_ctx.tools` provided → inject as Gemini `functionDeclarations` +3. If `tool_ctx.tool_config` provided → inject as `toolConfig` +4. If `tool_ctx.pending_results` not empty → rewrite conversation history: + - Find model turn with "Tool call completed" → replace with `functionCall` parts + - Find last user turn → prepend `functionResponse` part + +### Phase 4: MITM Plumbing (`proxy.rs`) + +In `handle_http_over_tls`, before calling `modify_request`: + +1. Read `get_tools()`, `get_tool_config()`, `take_tool_results()`, `get_last_function_calls()` from store +2. Build `ToolContext` +3. Pass to `modify_request(body, tool_ctx)` + +After response capture: + +1. Save captured function calls as `last_function_calls` (for future history rewriting) + +### Phase 5: API Handler (`responses.rs`) + +#### Request handling (in `handle_responses`): + +1. If `body.tools` provided: + - Convert OpenAI → Gemini format via `openai_tools_to_gemini()` + - Store in `MitmStore` via `set_tools()` +2. If `body.tool_choice` provided: + - Convert via `openai_tool_choice_to_gemini()` + - Store in `MitmStore` via `set_tool_config()` +3. Check `body.input` for `function_call_output` items: + - If found: look up `call_id` → function name via `lookup_call_id()` + - Store as `PendingToolResult` via `add_tool_result()` + - Extract any accompanying text (or use placeholder) + +#### Response handling (in `handle_responses_sync` / `handle_responses_stream`): + +After polling completes: + +1. Check `take_any_function_calls()` for captured tool calls +2. If captured: + - Generate `call_id` for each (e.g., `"call_" + random`) + - Register `call_id → name` mapping via `register_call_id()` + - Build `function_call` output items via `build_function_call_output()` + - Return these INSTEAD of the text message output +3. If no tool calls: existing text response behavior + +### Phase 6: Gemini-Native Endpoint (`gemini.rs` + `mod.rs`) + +New file `src/api/gemini.rs` with handler `handle_gemini`: + +- Accepts tools in Gemini `functionDeclarations` format directly (no conversion) +- Accepts `toolConfig` directly +- Returns `functionCall` in Gemini format directly +- Same cascade/session management as responses.rs +- Much simpler — no format translation + +Route: `POST /v1/gemini` in `mod.rs` + +--- + +## File Change Summary + +| File | Changes | Complexity | +| ---------------------- | ----------------------------------------------------------------------- | ---------- | +| `src/mitm/store.rs` | Add tool context storage (5 new fields, ~10 methods) | Medium | +| `src/api/types.rs` | Add `tools`/`tool_choice` to request, add output builder | Low | +| `src/mitm/modify.rs` | `ToolContext`, format conversion, dynamic injection, history rewrite | High | +| `src/mitm/proxy.rs` | Read store → build ToolContext → pass to modify | Low | +| `src/api/responses.rs` | Store tools, detect tool results in input, return function_call outputs | High | +| `src/api/gemini.rs` | New file — Gemini-native endpoint (passthrough) | Medium | +| `src/api/mod.rs` | Add route + module declaration | Low | + +## Implementation Order + +1. `store.rs` — foundation, no dependencies +2. `types.rs` — request/response types +3. `modify.rs` — format conversion + injection (depends on store types) +4. `proxy.rs` — plumbing (depends on modify signature) +5. Build + verify compilation +6. `responses.rs` — handler changes (depends on all above) +7. Build + test with `get_weather` request +8. `gemini.rs` + `mod.rs` — Gemini endpoint +9. Build + test with Gemini format +10. Tool result flow test (multi-turn) + +## Testing Strategy + +### Test 1: Basic tool call (sync) + +```bash +curl -s http://localhost:8741/v1/responses -H "Content-Type: application/json" -d '{ + "model": "gemini-3-flash", + "input": "What is the weather in Tokyo?", + "tools": [{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}}}], + "tool_choice": "auto", + "conversation": "tool-test", + "stream": false +}' +# Expected: output contains function_call with name=get_weather, arguments={"city":"Tokyo"} +``` + +### Test 2: Tool result submission (multi-turn) + +```bash +curl -s http://localhost:8741/v1/responses -H "Content-Type: application/json" -d '{ + "model": "gemini-3-flash", + "input": [{"type":"function_call_output","call_id":"call_xxx","output":"{\"temp\":72,\"unit\":\"F\"}"}], + "conversation": "tool-test", + "stream": false +}' +# Expected: output contains text response using the tool result +``` + +### Test 3: Gemini-native endpoint + +```bash +curl -s http://localhost:8741/v1/gemini -H "Content-Type: application/json" -d '{ + "model": "gemini-3-flash", + "input": "What is the weather in Tokyo?", + "tools": [{"functionDeclarations":[{"name":"get_weather","description":"Get weather","parameters":{"type":"OBJECT","properties":{"city":{"type":"STRING"}},"required":["city"]}}]}], + "conversation": "gemini-tool-test", + "stream": false +}' +# Expected: response contains functionCall in Gemini format +``` + +### Test 4: No tools (regression) + +```bash +curl -s http://localhost:8741/v1/responses -H "Content-Type: application/json" -d '{ + "model": "gemini-3-flash", + "input": "What is 2+2?", + "stream": false +}' +# Expected: normal text response, no tool call behavior +``` + +## Risks & Mitigations + +| Risk | Impact | Mitigation | +| ---------------------------------------------------------------- | ------ | ------------------------------------------------------------------------- | +| History rewriting breaks conversation | High | Only rewrite when pending_results non-empty; keep original as fallback | +| LS times out waiting for Google response during tool result turn | Medium | Increase timeout for tool result turns | +| Multiple parallel tool calls create race conditions | Medium | AtomicBool + sequential processing already handles this | +| `modify_request` test breakage | Low | Update existing tests for new signature | +| Global tool storage conflicts across concurrent requests | Medium | Not an issue — LS processes one request at a time (single cascade active) | diff --git a/src/api/gemini.rs b/src/api/gemini.rs new file mode 100644 index 0000000..add9327 --- /dev/null +++ b/src/api/gemini.rs @@ -0,0 +1,236 @@ +//! Gemini-native endpoint (/v1/gemini) — zero-translation tool call passthrough. +//! +//! Accepts tools in Gemini `functionDeclarations` format directly, +//! returns `functionCall` in Gemini format directly. +//! No OpenAI ↔ Gemini format conversion. + +use axum::{ + extract::State, + http::StatusCode, + response::{IntoResponse, Json}, +}; +use std::sync::Arc; +use tracing::info; + +use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; +use super::polling::poll_for_response; +use super::util::{err_response, now_unix}; +use super::AppState; +use crate::mitm::store::PendingToolResult; + +/// Gemini-native request format. +#[derive(serde::Deserialize)] +pub(crate) struct GeminiRequest { + pub model: Option, + /// User input text. + pub input: serde_json::Value, + /// Gemini-native tools: [{"functionDeclarations": [...]}] + #[serde(default)] + pub tools: Option>, + /// Gemini-native toolConfig: {"functionCallingConfig": {"mode": "AUTO"}} + #[serde(default)] + pub tool_config: Option, + /// Session/conversation ID. + #[serde(default)] + pub conversation: Option, + #[serde(default = "default_timeout")] + pub timeout: u64, + #[serde(default)] + pub stream: bool, + /// Tool results in Gemini format: [{"functionResponse": {"name": "...", "response": {...}}}] + #[serde(default)] + pub tool_results: Option>, +} + +fn default_timeout() -> u64 { + 120 +} + +fn extract_conversation_id(conv: &Option) -> Option { + match conv { + Some(serde_json::Value::String(s)) => Some(s.clone()), + Some(obj) => obj["id"].as_str().map(|s| s.to_string()), + None => None, + } +} + +pub(crate) async fn handle_gemini( + State(state): State>, + Json(body): Json, +) -> axum::response::Response { + info!( + "POST /v1/gemini model={} stream={}", + body.model.as_deref().unwrap_or(DEFAULT_MODEL), + body.stream + ); + + let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL); + let model = match lookup_model(model_name) { + Some(m) => m, + None => { + let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect(); + return err_response( + StatusCode::BAD_REQUEST, + format!("Unknown model: {model_name}. Available: {names:?}"), + "invalid_request_error", + ); + } + }; + + let token = state.backend.oauth_token().await; + if token.is_empty() { + return err_response( + StatusCode::UNAUTHORIZED, + "No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(), + "authentication_error", + ); + } + + // Extract user text + let user_text = match &body.input { + serde_json::Value::String(s) => s.clone(), + _ => { + return err_response( + StatusCode::BAD_REQUEST, + "Gemini endpoint requires input as a string".to_string(), + "invalid_request_error", + ); + } + }; + + // Store tools directly in Gemini format (no conversion needed!) + if let Some(ref tools) = body.tools { + if !tools.is_empty() { + state.mitm_store.set_tools(tools.clone()).await; + info!(count = tools.len(), "Stored Gemini-native tools for MITM injection"); + } + } + if let Some(ref config) = body.tool_config { + state.mitm_store.set_tool_config(config.clone()).await; + } + + // Handle tool results (Gemini format: functionResponse) + if let Some(ref results) = body.tool_results { + for r in results { + if let Some(fr) = r.get("functionResponse") { + let name = fr["name"].as_str().unwrap_or("unknown").to_string(); + let response = fr.get("response").cloned().unwrap_or(serde_json::json!({})); + state.mitm_store.add_tool_result(PendingToolResult { + name, + result: response, + }).await; + } + } + info!(count = results.len(), "Stored Gemini-native tool results for MITM injection"); + } + + // Session/conversation management + let session_id_str = extract_conversation_id(&body.conversation); + let cascade_id = if let Some(ref sid) = session_id_str { + match state + .sessions + .get_or_create(Some(sid), || state.backend.create_cascade()) + .await + { + Ok(sr) => sr.cascade_id, + Err(e) => { + return err_response( + StatusCode::BAD_GATEWAY, + format!("StartCascade failed: {e}"), + "server_error", + ); + } + } + } else { + match state.backend.create_cascade().await { + Ok(cid) => cid, + Err(e) => { + return err_response( + StatusCode::BAD_GATEWAY, + format!("StartCascade failed: {e}"), + "server_error", + ); + } + } + }; + + // Send message + match state + .backend + .send_message(&cascade_id, &user_text, model.model_enum) + .await + { + Ok((200, _)) => { + let bg = Arc::clone(&state.backend); + let cid = cascade_id.clone(); + tokio::spawn(async move { + let _ = bg.update_annotations(&cid).await; + }); + } + Ok((status, _)) => { + return err_response( + StatusCode::BAD_GATEWAY, + format!("Antigravity returned {status}"), + "server_error", + ); + } + Err(e) => { + return err_response( + StatusCode::BAD_GATEWAY, + format!("Send message failed: {e}"), + "server_error", + ); + } + } + + // Poll for response + let poll_result = poll_for_response(&state, &cascade_id, body.timeout).await; + + // Check for captured function calls — return in Gemini format + let captured_tool_calls = state.mitm_store.take_any_function_calls().await; + + if let Some(ref calls) = captured_tool_calls { + info!( + count = calls.len(), + tools = ?calls.iter().map(|c| &c.name).collect::>(), + "Returning captured function calls (Gemini format)" + ); + + let parts: Vec = calls + .iter() + .map(|fc| { + serde_json::json!({ + "functionCall": { + "name": fc.name, + "args": fc.args, + } + }) + }) + .collect(); + + return Json(serde_json::json!({ + "candidates": [{ + "content": { + "parts": parts, + "role": "model", + }, + "finishReason": "STOP", + }], + "modelVersion": model_name, + })) + .into_response(); + } + + // Normal text response + Json(serde_json::json!({ + "candidates": [{ + "content": { + "parts": [{"text": poll_result.text}], + "role": "model", + }, + "finishReason": "STOP", + }], + "modelVersion": model_name, + })) + .into_response() +} diff --git a/src/api/mod.rs b/src/api/mod.rs index 21f5a51..ec62d6d 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,6 +1,7 @@ //! Axum API server — OpenAI-compatible Responses + Chat Completions endpoints. mod completions; +mod gemini; mod models; mod polling; mod responses; @@ -41,6 +42,7 @@ pub fn router(state: Arc) -> Router { "/v1/chat/completions", post(completions::handle_completions), ) + .route("/v1/gemini", post(gemini::handle_gemini)) .route("/v1/models", get(handle_models)) .route("/v1/sessions", get(handle_list_sessions)) .route("/v1/sessions/{id}", delete(handle_delete_session)) @@ -59,11 +61,12 @@ pub fn router(state: Arc) -> Router { async fn handle_root() -> Json { Json(serde_json::json!({ "service": "antigravity-openai-proxy", - "version": "3.2.0", + "version": "3.3.0", "runtime": "rust", "endpoints": [ "/v1/chat/completions", "/v1/responses", + "/v1/gemini", "/v1/models", "/v1/sessions", "/v1/token", diff --git a/src/api/responses.rs b/src/api/responses.rs index e02c82b..188df60 100644 --- a/src/api/responses.rs +++ b/src/api/responses.rs @@ -18,42 +18,91 @@ use super::polling::{extract_response_text, is_response_done, poll_for_response, use super::types::*; use super::util::{err_response, now_unix, responses_sse_event}; use super::AppState; +use crate::mitm::store::PendingToolResult; +use crate::mitm::modify::{openai_tools_to_gemini, openai_tool_choice_to_gemini}; // ─── Input extraction ──────────────────────────────────────────────────────── +/// Parsed tool result from function_call_output items in input. +struct ToolResultInput { + call_id: String, + output: String, +} + /// Extract user text from Responses API `input` field. -fn extract_responses_input(input: &serde_json::Value, instructions: Option<&str>) -> String { +/// Also extracts any function_call_output items for tool result handling. +fn extract_responses_input(input: &serde_json::Value, instructions: Option<&str>) -> (String, Vec) { + let mut tool_results: Vec = Vec::new(); + let user_text = match input { serde_json::Value::String(s) => s.clone(), serde_json::Value::Array(items) => { - items - .iter() - .rev() - .find(|item| item["role"].as_str() == Some("user")) - .and_then(|item| match &item["content"] { - serde_json::Value::String(s) => Some(s.clone()), - serde_json::Value::Array(parts) => Some( - parts - .iter() - .filter(|p| { - let t = p["type"].as_str().unwrap_or(""); - t == "input_text" || t == "text" - }) - .filter_map(|p| p["text"].as_str()) - .collect::>() - .join(" "), - ), - _ => None, - }) - .unwrap_or_default() + // Check for function_call_output items + for item in items { + if item["type"].as_str() == Some("function_call_output") { + if let (Some(call_id), Some(output)) = ( + item["call_id"].as_str(), + item["output"].as_str(), + ) { + tool_results.push(ToolResultInput { + call_id: call_id.to_string(), + output: output.to_string(), + }); + } + } + } + + // If we have tool results but no text, generate a follow-up prompt + if !tool_results.is_empty() { + // Look for any text items alongside the tool results + let text_items: String = items + .iter() + .filter(|item| { + let t = item["type"].as_str().unwrap_or(""); + t == "input_text" || t == "text" + }) + .filter_map(|p| p["text"].as_str()) + .collect::>() + .join(" "); + + if text_items.is_empty() { + "Use the tool results to answer the original question.".to_string() + } else { + text_items + } + } else { + // Normal input extraction (existing logic) + items + .iter() + .rev() + .find(|item| item["role"].as_str() == Some("user")) + .and_then(|item| match &item["content"] { + serde_json::Value::String(s) => Some(s.clone()), + serde_json::Value::Array(parts) => Some( + parts + .iter() + .filter(|p| { + let t = p["type"].as_str().unwrap_or(""); + t == "input_text" || t == "text" + }) + .filter_map(|p| p["text"].as_str()) + .collect::>() + .join(" "), + ), + _ => None, + }) + .unwrap_or_default() + } } _ => String::new(), }; - match instructions { + let final_text = match instructions { Some(inst) if !inst.is_empty() => format!("{inst}\n\n{user_text}"), _ => user_text, - } + }; + + (final_text, tool_results) } /// Extract conversation/session ID from Responses API `conversation` field. @@ -147,8 +196,32 @@ pub(crate) async fn handle_responses( ); } - let user_text = extract_responses_input(&body.input, body.instructions.as_deref()); - if user_text.is_empty() { + let (user_text, tool_results) = extract_responses_input(&body.input, body.instructions.as_deref()); + + // Handle tool result submission (function_call_output in input) + let is_tool_result_turn = !tool_results.is_empty(); + if is_tool_result_turn { + for tr in &tool_results { + // Look up function name from call_id + let name = state.mitm_store.lookup_call_id(&tr.call_id).await + .unwrap_or_else(|| "unknown_function".to_string()); + + // Parse the output as JSON, fall back to string wrapper + let result_value = serde_json::from_str::(&tr.output) + .unwrap_or_else(|_| serde_json::json!({"result": tr.output})); + + state.mitm_store.add_tool_result(PendingToolResult { + name, + result: result_value, + }).await; + } + info!( + count = tool_results.len(), + "Stored tool results for MITM injection" + ); + } + + if user_text.is_empty() && !is_tool_result_turn { return err_response( StatusCode::BAD_REQUEST, "No user input found".to_string(), @@ -156,6 +229,19 @@ pub(crate) async fn handle_responses( ); } + // Store client tools in MitmStore for MITM injection + if let Some(ref tools) = body.tools { + let gemini_tools = openai_tools_to_gemini(tools); + if !gemini_tools.is_empty() { + state.mitm_store.set_tools(gemini_tools).await; + info!(count = tools.len(), "Stored client tools for MITM injection"); + } + } + if let Some(ref choice) = body.tool_choice { + let gemini_config = openai_tool_choice_to_gemini(choice); + state.mitm_store.set_tool_config(gemini_config).await; + } + let response_id = format!( "resp_{}", uuid::Uuid::new_v4().to_string().replace('-', "") @@ -363,14 +449,52 @@ async fn handle_responses_sync( // Check for captured function calls from MITM (clears the active flag) let captured_tool_calls = state.mitm_store.take_any_function_calls().await; + + // If we have captured tool calls, return them as function_call output items if let Some(ref calls) = captured_tool_calls { info!( count = calls.len(), tools = ?calls.iter().map(|c| &c.name).collect::>(), - "Consumed captured function calls from MITM" + "Returning captured function calls to client" ); + + let mut output_items: Vec = Vec::new(); + for fc in calls { + let call_id = format!( + "call_{}", + uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() + ); + // Register call_id → name mapping for tool result routing + state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await; + + // Stringify args (OpenAI sends arguments as JSON string) + let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); + output_items.push(build_function_call_output(&call_id, &fc.name, &arguments)); + } + + let (usage, _) = usage_from_poll( + &state.mitm_store, &cascade_id, &poll_result.usage, + ¶ms.user_text, &poll_result.text, + ).await; + + let resp = build_response_object( + ResponseData { + id: response_id, + model: model_name, + status: "completed", + created_at, + completed_at: Some(completed_at), + output: output_items, + usage: Some(usage), + thinking_signature: poll_result.thinking_signature, + }, + ¶ms, + ); + + return Json(resp).into_response(); } + // Normal text response (no tool calls) let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, ¶ms.user_text, &poll_result.text).await; // Thinking text priority: MITM-captured (raw API) > LS-extracted (steps) diff --git a/src/api/types.rs b/src/api/types.rs index d195434..46a1e44 100644 --- a/src/api/types.rs +++ b/src/api/types.rs @@ -32,6 +32,12 @@ pub(crate) struct ResponsesRequest { pub metadata: Option, #[serde(default)] pub user: Option, + /// Tool definitions (OpenAI format). + #[serde(default)] + pub tools: Option>, + /// Tool choice: "auto", "required", "none", or {"type":"function","function":{"name":"X"}}. + #[serde(default)] + pub tool_choice: Option, } /// Chat Completions request (OpenAI-compatible). @@ -220,6 +226,18 @@ pub fn build_message_output_in_progress(msg_id: &str) -> serde_json::Value { }) } +/// Build a function_call output item (OpenAI Responses API format). +pub fn build_function_call_output(call_id: &str, name: &str, arguments: &str) -> serde_json::Value { + serde_json::json!({ + "type": "function_call", + "id": call_id, + "call_id": call_id, + "name": name, + "arguments": arguments, + "status": "completed", + }) +} + // ─── Helpers ───────────────────────────────────────────────────────────────── /// Serialize Option as either the number or JSON null (not omitted). diff --git a/src/mitm/modify.rs b/src/mitm/modify.rs index 637e43f..a056150 100644 --- a/src/mitm/modify.rs +++ b/src/mitm/modify.rs @@ -8,14 +8,29 @@ use regex::Regex; use serde_json::Value; use tracing::info; +use super::store::{CapturedFunctionCall, PendingToolResult}; + /// Strip ALL tool definitions. /// Must be true: with tools present, the LS enters full agentic mode /// (multi-turn tool calls, file searches, etc.) burning quota. const STRIP_ALL_TOOLS: bool = true; +/// Context for tool injection during request modification. +/// Built from MitmStore data before calling modify_request. +pub struct ToolContext { + /// Gemini-format tool declarations (functionDeclarations). + pub tools: Option>, + /// Gemini-format toolConfig. + pub tool_config: Option, + /// Pending tool results to inject as functionResponse. + pub pending_results: Vec, + /// Last captured function calls for history rewriting. + pub last_calls: Vec, +} + /// Modify a streamGenerateContent request body in-place. /// Returns the modified JSON bytes, or None if modification wasn't possible. -pub fn modify_request(body: &[u8]) -> Option> { +pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option> { let mut json: Value = serde_json::from_slice(body).ok()?; let original_size = body.len(); @@ -140,7 +155,7 @@ pub fn modify_request(body: &[u8]) -> Option> { } } - // ── 3. Strip LS tools, inject custom tools ──────────────────────────── + // ── 3. Strip LS tools, inject client tools ───────────────────────────── if STRIP_ALL_TOOLS { if let Some(tools) = json .pointer_mut("/request/tools") @@ -152,25 +167,83 @@ pub fn modify_request(body: &[u8]) -> Option> { changes.push(format!("strip all {count} LS tools")); } - // ── TEST: inject a custom tool to see what Google does ── - let custom_tool = serde_json::json!({ - "functionDeclarations": [{ - "name": "get_weather", - "description": "Get the current weather for a city. You MUST call this function when the user asks about weather.", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The city name" - } - }, - "required": ["city"] + // Inject client-provided tools from ToolContext + if let Some(ref ctx) = tool_ctx { + if let Some(ref custom_tools) = ctx.tools { + for tool in custom_tools { + tools.push(tool.clone()); } - }] - }); - tools.push(custom_tool); - changes.push("inject 1 custom tool (get_weather)".to_string()); + changes.push(format!("inject {} custom tool group(s)", custom_tools.len())); + } + } + } + } + + // Inject toolConfig if provided + if let Some(ref ctx) = tool_ctx { + if let Some(ref config) = ctx.tool_config { + if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) { + req.insert("toolConfig".to_string(), config.clone()); + changes.push("inject toolConfig".to_string()); + } + } + } + + // ── 3b. Rewrite conversation history for tool results ──────────── + if let Some(ref ctx) = tool_ctx { + if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() { + if let Some(contents) = json + .pointer_mut("/request/contents") + .and_then(|v| v.as_array_mut()) + { + // Find the model turn with our fake "Tool call completed" text and replace it + // with the actual functionCall parts + for msg in contents.iter_mut() { + if msg["role"].as_str() == Some("model") { + if let Some(text) = msg["parts"][0]["text"].as_str() { + if text.contains("Tool call completed") || text.contains("Awaiting external tool result") { + // Replace with functionCall parts + let fc_parts: Vec = ctx.last_calls.iter().map(|fc| { + serde_json::json!({ + "functionCall": { + "name": fc.name, + "args": fc.args, + } + }) + }).collect(); + msg["parts"] = Value::Array(fc_parts); + changes.push("rewrite model turn with functionCall".to_string()); + break; + } + } + } + } + + // Add functionResponse as a user turn before the last user message + let fn_response_parts: Vec = ctx.pending_results.iter().map(|r| { + serde_json::json!({ + "functionResponse": { + "name": r.name, + "response": r.result, + } + }) + }).collect(); + let fn_response_turn = serde_json::json!({ + "role": "user", + "parts": fn_response_parts, + }); + + // Insert before the last user message + let last_user_idx = contents.iter().rposition(|msg| { + msg["role"].as_str() == Some("user") + }); + if let Some(idx) = last_user_idx { + contents.insert(idx, fn_response_turn); + } else { + contents.push(fn_response_turn); + } + changes.push(format!("inject {} functionResponse(s)", ctx.pending_results.len())); + } } } @@ -323,6 +396,93 @@ pub fn rechunk(data: &[u8]) -> Vec { result } +// ── OpenAI → Gemini format conversion ──────────────────────────────────────── + +/// Convert OpenAI tool definitions to Gemini functionDeclarations format. +/// +/// OpenAI: `[{"type":"function","function":{"name":"X","description":"Y","parameters":{...}}}]` +/// Gemini: `[{"functionDeclarations":[{"name":"X","description":"Y","parameters":{...}}]}]` +pub fn openai_tools_to_gemini(tools: &[Value]) -> Vec { + let declarations: Vec = tools + .iter() + .filter(|t| t["type"].as_str() == Some("function")) + .filter_map(|t| { + let func = t.get("function")?; + let mut decl = serde_json::json!({ + "name": func["name"], + "description": func["description"], + }); + if let Some(params) = func.get("parameters") { + decl["parameters"] = uppercase_types(params.clone()); + } + Some(decl) + }) + .collect(); + + if declarations.is_empty() { + return vec![]; + } + + vec![serde_json::json!({"functionDeclarations": declarations})] +} + +/// Convert OpenAI tool_choice to Gemini toolConfig format. +/// +/// OpenAI: "auto" | "required" | "none" | {"type":"function","function":{"name":"X"}} +/// Gemini: {"functionCallingConfig":{"mode":"AUTO|ANY|NONE","allowedFunctionNames":[...]}} +pub fn openai_tool_choice_to_gemini(choice: &Value) -> Value { + match choice { + Value::String(s) => match s.as_str() { + "auto" => serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}), + "required" => serde_json::json!({"functionCallingConfig": {"mode": "ANY"}}), + "none" => serde_json::json!({"functionCallingConfig": {"mode": "NONE"}}), + _ => serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}), + }, + Value::Object(obj) => { + if let Some(name) = obj.get("function").and_then(|f| f["name"].as_str()) { + serde_json::json!({ + "functionCallingConfig": { + "mode": "ANY", + "allowedFunctionNames": [name] + } + }) + } else { + serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}) + } + } + _ => serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}), + } +} + +/// Recursively convert JSON Schema type strings to uppercase (Gemini format). +/// "object" → "OBJECT", "string" → "STRING", etc. +fn uppercase_types(mut val: Value) -> Value { + match &mut val { + Value::Object(map) => { + if let Some(t) = map + .get("type") + .and_then(|v| v.as_str()) + .map(|s| s.to_uppercase()) + { + map.insert("type".to_string(), Value::String(t)); + } + let keys: Vec = map.keys().cloned().collect(); + for key in keys { + if let Some(v) = map.remove(&key) { + map.insert(key, uppercase_types(v)); + } + } + } + Value::Array(arr) => { + for v in arr.iter_mut() { + *v = uppercase_types(std::mem::take(v)); + } + } + _ => {} + } + val +} + #[cfg(test)] mod tests { use super::*; @@ -375,10 +535,11 @@ mod tests { }); let bytes = serde_json::to_vec(&body).unwrap(); - let modified = modify_request(&bytes).unwrap(); + let modified = modify_request(&bytes, None).unwrap(); let result: Value = serde_json::from_slice(&modified).unwrap(); let tools = result["request"]["tools"].as_array().unwrap(); + // With no ToolContext, tools should just be stripped (empty) assert!(tools.is_empty(), "all tools should be stripped"); } @@ -398,7 +559,7 @@ mod tests { }); let bytes = serde_json::to_vec(&body).unwrap(); - let modified = modify_request(&bytes).unwrap(); + let modified = modify_request(&bytes, None).unwrap(); let result: Value = serde_json::from_slice(&modified).unwrap(); let new_sys = result["request"]["systemInstruction"]["parts"][0]["text"] @@ -432,7 +593,7 @@ mod tests { }); let bytes = serde_json::to_vec(&body).unwrap(); - let modified = modify_request(&bytes).unwrap(); + let modified = modify_request(&bytes, None).unwrap(); let result: Value = serde_json::from_slice(&modified).unwrap(); let contents = result["request"]["contents"].as_array().unwrap(); diff --git a/src/mitm/proxy.rs b/src/mitm/proxy.rs index 60a2735..7e31924 100644 --- a/src/mitm/proxy.rs +++ b/src/mitm/proxy.rs @@ -556,7 +556,24 @@ async fn handle_http_over_tls( || body_str.contains("\"requestType\": \"agent\""); if is_agent { - if let Some(modified_body) = super::modify::modify_request(&raw_body) { + // Build ToolContext from store + let tools = store.get_tools().await; + let tool_config = store.get_tool_config().await; + let pending_results = store.take_tool_results().await; + let last_calls = store.get_last_function_calls().await; + + let tool_ctx = if tools.is_some() || !pending_results.is_empty() { + Some(super::modify::ToolContext { + tools, + tool_config, + pending_results, + last_calls, + }) + } else { + None + }; + + if let Some(modified_body) = super::modify::modify_request(&raw_body, tool_ctx.as_ref()) { // Rebuild request_buf: original headers + rechunked modified body let new_chunked = super::modify::rechunk(&modified_body); let mut new_buf = request_buf[..headers_end].to_vec(); @@ -766,6 +783,10 @@ async fn handle_http_over_tls( for fc in &streaming_acc.function_calls { store.record_function_call(cascade_hint.as_deref(), fc.clone()).await; } + // Also save for history rewriting on tool result turns + if !streaming_acc.function_calls.is_empty() { + store.set_last_function_calls(streaming_acc.function_calls.clone()).await; + } let usage = streaming_acc.into_usage(); store.record_usage(cascade_hint.as_deref(), usage).await; } diff --git a/src/mitm/store.rs b/src/mitm/store.rs index 1b30d6d..8144cb0 100644 --- a/src/mitm/store.rs +++ b/src/mitm/store.rs @@ -53,6 +53,13 @@ pub struct CapturedFunctionCall { pub captured_at: u64, } +/// A pending tool result from a client's function_call_output. +#[derive(Debug, Clone)] +pub struct PendingToolResult { + pub name: String, + pub result: serde_json::Value, +} + /// Thread-safe store for intercepted data. /// /// Keyed by a unique request ID that we can correlate with cascade operations. @@ -69,6 +76,18 @@ pub struct MitmStore { /// Simple flag: set when a functionCall is captured, cleared when consumed. /// Used to block follow-up requests regardless of cascade identification. has_active_function_call: Arc, + + // ── Tool call support ──────────────────────────────────────────────── + /// Active tool definitions (Gemini format) for MITM injection. + active_tools: Arc>>>, + /// Active tool config (Gemini toolConfig format). + active_tool_config: Arc>>, + /// Pending tool results for MITM to inject as functionResponse. + pending_tool_results: Arc>>, + /// Mapping call_id → function name for tool result routing. + call_id_to_name: Arc>>, + /// Last captured function calls (for conversation history rewriting). + last_function_calls: Arc>>, } /// Aggregate statistics across all intercepted traffic. @@ -102,6 +121,11 @@ impl MitmStore { stats: Arc::new(RwLock::new(MitmStats::default())), pending_function_calls: Arc::new(RwLock::new(HashMap::new())), has_active_function_call: Arc::new(AtomicBool::new(false)), + active_tools: Arc::new(RwLock::new(None)), + active_tool_config: Arc::new(RwLock::new(None)), + pending_tool_results: Arc::new(RwLock::new(Vec::new())), + call_id_to_name: Arc::new(RwLock::new(HashMap::new())), + last_function_calls: Arc::new(RwLock::new(Vec::new())), } } @@ -266,4 +290,63 @@ impl MitmStore { } None } + + // ── Tool context methods ───────────────────────────────────────────── + + /// Set active tool definitions (already in Gemini format). + pub async fn set_tools(&self, tools: Vec) { + *self.active_tools.write().await = Some(tools); + } + + /// Get active tool definitions. + pub async fn get_tools(&self) -> Option> { + self.active_tools.read().await.clone() + } + + /// Clear active tool definitions. + pub async fn clear_tools(&self) { + *self.active_tools.write().await = None; + *self.active_tool_config.write().await = None; + } + + /// Set active tool config (Gemini toolConfig format). + pub async fn set_tool_config(&self, config: serde_json::Value) { + *self.active_tool_config.write().await = Some(config); + } + + /// Get active tool config. + pub async fn get_tool_config(&self) -> Option { + self.active_tool_config.read().await.clone() + } + + /// Add a pending tool result for MITM injection. + pub async fn add_tool_result(&self, result: PendingToolResult) { + info!(name = %result.name, "Storing pending tool result"); + self.pending_tool_results.write().await.push(result); + } + + /// Take (consume) all pending tool results. + pub async fn take_tool_results(&self) -> Vec { + std::mem::take(&mut *self.pending_tool_results.write().await) + } + + /// Register a call_id → function name mapping. + pub async fn register_call_id(&self, call_id: String, name: String) { + self.call_id_to_name.write().await.insert(call_id, name); + } + + /// Look up function name by call_id. + pub async fn lookup_call_id(&self, call_id: &str) -> Option { + self.call_id_to_name.read().await.get(call_id).cloned() + } + + /// Save the last captured function calls (for history rewriting). + pub async fn set_last_function_calls(&self, calls: Vec) { + *self.last_function_calls.write().await = calls; + } + + /// Get the last captured function calls. + pub async fn get_last_function_calls(&self) -> Vec { + self.last_function_calls.read().await.clone() + } }