feat: full tool call support (OpenAI + Gemini endpoints)
- store.rs: Add tool context storage (active tools, tool config, pending tool results, call_id mapping, last function calls for history rewrite)
- types.rs: Add tools/tool_choice fields to ResponsesRequest, add build_function_call_output helper for OpenAI function_call output items
- modify.rs: Replace hardcoded get_weather with dynamic ToolContext injection. Add openai_tools_to_gemini and openai_tool_choice_to_gemini converters. Add conversation history rewriting for tool result turns (replaces fake 'Tool call completed' model turn with real functionCall, injects functionResponse before last user turn)
- proxy.rs: Build ToolContext from MitmStore before calling modify_request. Save last_function_calls for history rewriting on subsequent turns
- responses.rs: Store client tools in MitmStore before LS call. Detect function_call_output in input array for tool result submission. Return captured functionCalls as OpenAI function_call output items with generated call_ids and stringified arguments
- gemini.rs: New Gemini-native endpoint (POST /v1/gemini) with zero format translation. Accepts functionDeclarations directly, returns functionCall in Gemini format directly
- mod.rs: Wire /v1/gemini route, bump version to 3.3.0
292
.gemini/plans/tool-calls-implementation.md
Normal file
@@ -0,0 +1,292 @@
|
||||
# Tool Call Implementation Plan
|
||||
|
||||
## Overview
|
||||
|
||||
Add full tool call support to the Antigravity proxy. Primary endpoint is OpenAI Responses API (`/v1/responses`), with a Gemini-native backup endpoint (`/v1/gemini`). Tools are stored per-session, all `tool_choice` modes supported, parallel tool calls supported.
|
||||
|
||||
## Data Flow
|
||||
|
||||
```
┌──────────┐      ┌───────────┐      ┌────┐      ┌──────┐      ┌────────┐
│  Client  │─────▶│   Proxy   │─────▶│ LS │─────▶│ MITM │─────▶│ Google │
│ (openai) │      │  (axum)   │      │    │      │      │      │        │
│          │◀─────│           │◀─────│    │◀─────│      │◀─────│        │
└──────────┘      └───────────┘      └────┘      └──────┘      └────────┘
    │                │                              │               │
    │ tools (OAI)    │ store tools (Gemini fmt)     │ inject        │
    │───────────────▶│────────────▶ MitmStore ─────▶│ tools         │
    │                │                              │──────────────▶│
    │                │                              │               │
    │                │                              │ functionCall  │
    │                │◀──── capture ────────────────│◀──────────────│
    │ tool_calls     │                              │ block follow  │
    │◀───────────────│                              │ ups           │
    │                │                              │               │
    │ tool result    │ store result                 │ inject        │
    │───────────────▶│────────────▶ MitmStore ─────▶│ fn response   │
    │                │                              │──────────────▶│
    │ final text     │                              │               │
    │◀───────────────│◀─────────────────────────────│◀──────────────│
```
|
||||
|
||||
## Format Differences
|
||||
|
||||
### Tool Definitions
|
||||
|
||||
| Aspect       | OpenAI                                 | Gemini                             |
| ------------ | -------------------------------------- | ---------------------------------- |
| Wrapper      | `{"type":"function","function":{...}}` | `{"functionDeclarations":[{...}]}` |
| Type strings | lowercase: `"object"`, `"string"`      | UPPERCASE: `"OBJECT"`, `"STRING"`  |
| Parameters   | JSON Schema subset                     | Same schema, uppercase types       |
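
The type-case difference in the table can be handled with a recursive walk over the schema. The following is a minimal, dependency-free sketch and not the code in `modify.rs`: the `Json` enum is a hypothetical stand-in for `serde_json::Value`, and the rule it applies (uppercase any string stored under a `type` key) is an assumption drawn from the table above.

```rust
/// Hypothetical stand-in for `serde_json::Value`, so the sketch has no dependencies.
#[derive(Debug, Clone, PartialEq)]
pub enum Json {
    Null,
    Bool(bool),
    Num(f64),
    Str(String),
    Arr(Vec<Json>),
    Obj(Vec<(String, Json)>),
}

/// Recursively uppercase every string value stored under a `"type"` key.
/// All other values are left unchanged.
pub fn uppercase_types(val: Json) -> Json {
    match val {
        Json::Obj(fields) => Json::Obj(
            fields
                .into_iter()
                .map(|(key, value)| {
                    let value = match value {
                        // `"type": "string"` becomes `"type": "STRING"`.
                        Json::Str(s) if key == "type" => Json::Str(s.to_uppercase()),
                        // Anything else (including nested schemas) is walked recursively.
                        other => uppercase_types(other),
                    };
                    (key, value)
                })
                .collect(),
        ),
        Json::Arr(items) => Json::Arr(items.into_iter().map(uppercase_types).collect()),
        other => other,
    }
}
```

Because only string values under a `type` key are touched, entries such as `"required": ["city"]` pass through unchanged, and a property that happens to be named `type` is still recursed into because its value is an object.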
|
||||
|
||||
### Tool Choice
|
||||
|
||||
| OpenAI                                        | Gemini toolConfig                                                       |
| --------------------------------------------- | ----------------------------------------------------------------------- |
| `"auto"`                                      | `{"functionCallingConfig":{"mode":"AUTO"}}`                             |
| `"required"`                                  | `{"functionCallingConfig":{"mode":"ANY"}}`                              |
| `"none"`                                      | `{"functionCallingConfig":{"mode":"NONE"}}`                             |
| `{"type":"function","function":{"name":"X"}}` | `{"functionCallingConfig":{"mode":"ANY","allowedFunctionNames":["X"]}}` |
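
The mapping in this table is small enough to express as a single `match`. The sketch below is illustrative only: `ToolChoice`, `FunctionCallingConfig`, `tool_choice_to_gemini`, and `render_tool_config` are hypothetical names, whereas the real converter (`openai_tool_choice_to_gemini`) works on `serde_json::Value`. The renderer assumes function names are plain identifiers that need no JSON escaping.

```rust
/// OpenAI-style tool choice, modeled as a plain enum for this sketch.
#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    /// Force a specific function by name.
    Function(String),
}

/// Gemini `functionCallingConfig`, reduced to the two fields the table uses.
#[derive(Debug, Clone, PartialEq)]
pub struct FunctionCallingConfig {
    pub mode: &'static str,
    pub allowed_function_names: Vec<String>,
}

/// Map an OpenAI tool choice to the Gemini mode (and optional allow-list).
pub fn tool_choice_to_gemini(choice: &ToolChoice) -> FunctionCallingConfig {
    let (mode, allowed) = match choice {
        ToolChoice::Auto => ("AUTO", vec![]),
        ToolChoice::Required => ("ANY", vec![]),
        ToolChoice::None => ("NONE", vec![]),
        ToolChoice::Function(name) => ("ANY", vec![name.clone()]),
    };
    FunctionCallingConfig { mode, allowed_function_names: allowed }
}

/// Render the config as the JSON text shown in the table.
/// Assumes names are plain identifiers (no characters that need escaping).
pub fn render_tool_config(cfg: &FunctionCallingConfig) -> String {
    if cfg.allowed_function_names.is_empty() {
        format!(r#"{{"functionCallingConfig":{{"mode":"{}"}}}}"#, cfg.mode)
    } else {
        let names: Vec<String> = cfg
            .allowed_function_names
            .iter()
            .map(|n| format!("\"{n}\""))
            .collect();
        format!(
            r#"{{"functionCallingConfig":{{"mode":"{}","allowedFunctionNames":[{}]}}}}"#,
            cfg.mode,
            names.join(",")
        )
    }
}
```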
|
||||
|
||||
### Tool Call Response
|
||||
|
||||
| OpenAI (what we return)                                                                            | Gemini (what Google returns)                                    |
| -------------------------------------------------------------------------------------------------- | --------------------------------------------------------------- |
| `output: [{"type":"function_call","call_id":"call_xxx","name":"get_weather","arguments":"{...}"}]` | `parts: [{"functionCall":{"name":"get_weather","args":{...}}}]` |
|
||||
|
||||
### Tool Result Submission
|
||||
|
||||
| OpenAI (what client sends)                                                       | Gemini (what we inject into Google request)                                                                                  |
| -------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- |
| `input: [{"type":"function_call_output","call_id":"call_xxx","output":"{...}"}]` | `contents: [{role:"model",parts:[{functionCall:...}]},{role:"user",parts:[{functionResponse:{name:"...",response:{...}}}]}]` |
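
The key step in this translation is resolving the client's `call_id` back to a function name, since Gemini's `functionResponse` is keyed by name rather than by call ID. A minimal sketch of that lookup follows. `FunctionCallOutput`, `FunctionResponse`, and `to_function_response` are hypothetical names, and payloads are kept as JSON text to avoid dependencies. The `unknown_function` fallback mirrors the handler in `responses.rs`, which additionally parses the output as JSON and wraps non-JSON output as `{"result": ...}`.

```rust
use std::collections::HashMap;

/// What the client sends (OpenAI side).
#[derive(Debug, Clone, PartialEq)]
pub struct FunctionCallOutput {
    pub call_id: String,
    /// Tool output as JSON text.
    pub output: String,
}

/// What gets injected into the Google request (Gemini side).
#[derive(Debug, Clone, PartialEq)]
pub struct FunctionResponse {
    pub name: String,
    /// Response payload as JSON text.
    pub response_json: String,
}

/// Resolve `call_id` to a function name and build the Gemini-side item.
/// Unknown call IDs fall back to the name "unknown_function".
pub fn to_function_response(
    item: &FunctionCallOutput,
    call_id_to_name: &HashMap<String, String>,
) -> FunctionResponse {
    let name = call_id_to_name
        .get(&item.call_id)
        .cloned()
        .unwrap_or_else(|| "unknown_function".to_string());
    FunctionResponse {
        name,
        response_json: item.output.clone(),
    }
}
```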
|
||||
|
||||
---
|
||||
|
||||
## Implementation Phases
|
||||
|
||||
### Phase 1: Store Infrastructure (`store.rs`)
|
||||
|
||||
Add to `MitmStore`:
|
||||
|
||||
```rust
|
||||
/// Active tool definitions (Gemini format) for MITM injection.
|
||||
active_tools: Arc<RwLock<Option<Vec<Value>>>>,
|
||||
/// Active tool config (Gemini toolConfig format).
|
||||
active_tool_config: Arc<RwLock<Option<Value>>>,
|
||||
/// Pending tool results for MITM to inject as functionResponse.
|
||||
pending_tool_results: Arc<RwLock<Vec<PendingToolResult>>>,
|
||||
/// Mapping call_id → function name for tool result routing.
|
||||
call_id_to_name: Arc<RwLock<HashMap<String, String>>>,
|
||||
/// Last captured function calls (for conversation history rewriting).
|
||||
last_function_calls: Arc<RwLock<Vec<CapturedFunctionCall>>>,
|
||||
```
|
||||
|
||||
New types:
|
||||
|
||||
```rust
|
||||
pub struct PendingToolResult {
|
||||
pub name: String,
|
||||
pub result: serde_json::Value,
|
||||
}
|
||||
```
|
||||
|
||||
New methods:
|
||||
|
||||
- `set_tools(tools)` / `get_tools()` / `clear_tools()`
|
||||
- `set_tool_config(config)` / `get_tool_config()`
|
||||
- `add_tool_result(result)` / `take_tool_results()`
|
||||
- `register_call_id(call_id, name)` / `lookup_call_id(call_id)`
|
||||
- `set_last_function_calls(calls)` / `get_last_function_calls()`
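
The `take_*` naming implies drain semantics: results are handed to the MITM layer once and then cleared. Below is a minimal sketch of that behavior for two of the method pairs. It is an assumption-laden illustration, not the actual `MitmStore`: `ToolStore` is a hypothetical name, it uses the synchronous `std::sync::RwLock` (the real methods are awaited), and the result payload is JSON text rather than `serde_json::Value`.

```rust
use std::collections::HashMap;
use std::sync::RwLock;

#[derive(Debug, Clone, PartialEq)]
pub struct PendingToolResult {
    pub name: String,
    /// Result payload as JSON text; the plan uses `serde_json::Value`.
    pub result: String,
}

/// Hypothetical, synchronous stand-in for the tool-related part of `MitmStore`.
#[derive(Default)]
pub struct ToolStore {
    pending_tool_results: RwLock<Vec<PendingToolResult>>,
    call_id_to_name: RwLock<HashMap<String, String>>,
}

impl ToolStore {
    /// Queue a tool result for injection on the next outgoing request.
    pub fn add_tool_result(&self, result: PendingToolResult) {
        self.pending_tool_results.write().unwrap().push(result);
    }

    /// Drain the queue: returns everything pending and leaves it empty,
    /// so each result is injected at most once.
    pub fn take_tool_results(&self) -> Vec<PendingToolResult> {
        std::mem::take(&mut *self.pending_tool_results.write().unwrap())
    }

    /// Remember which function a generated call ID refers to.
    pub fn register_call_id(&self, call_id: String, name: String) {
        self.call_id_to_name.write().unwrap().insert(call_id, name);
    }

    /// Look up the function name for a call ID, if it was registered.
    pub fn lookup_call_id(&self, call_id: &str) -> Option<String> {
        self.call_id_to_name.read().unwrap().get(call_id).cloned()
    }
}
```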
|
||||
|
||||
### Phase 2: Request Types (`types.rs`)
|
||||
|
||||
Add to `ResponsesRequest`:
|
||||
|
||||
```rust
|
||||
#[serde(default)]
|
||||
pub tools: Option<Vec<serde_json::Value>>,
|
||||
#[serde(default)]
|
||||
pub tool_choice: Option<serde_json::Value>,
|
||||
```
|
||||
|
||||
New output builder:
|
||||
|
||||
```rust
|
||||
pub fn build_function_call_output(call_id: &str, name: &str, arguments: &str) -> Value
|
||||
```
|
||||
|
||||
### Phase 3: Format Conversion + Dynamic Injection (`modify.rs`)
|
||||
|
||||
New public struct:
|
||||
|
||||
```rust
|
||||
pub struct ToolContext {
|
||||
pub tools: Option<Vec<Value>>, // Gemini functionDeclarations
|
||||
pub tool_config: Option<Value>, // Gemini toolConfig
|
||||
pub pending_results: Vec<PendingToolResult>, // Tool results to inject
|
||||
pub last_calls: Vec<CapturedFunctionCall>, // For history rewriting
|
||||
}
|
||||
```
|
||||
|
||||
New conversion functions:
|
||||
|
||||
```rust
|
||||
pub fn openai_tools_to_gemini(tools: &[Value]) -> Vec<Value> // OAI → Gemini format
|
||||
pub fn openai_tool_choice_to_gemini(choice: &Value) -> Value // OAI → Gemini toolConfig
|
||||
fn uppercase_types(val: Value) -> Value // Recursive type case fix
|
||||
```
|
||||
|
||||
Change `modify_request` signature:
|
||||
|
||||
```rust
|
||||
pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec<u8>>
|
||||
```
|
||||
|
||||
Tool injection logic:
|
||||
|
||||
1. Strip all LS tools (existing)
|
||||
2. If `tool_ctx.tools` provided → inject as Gemini `functionDeclarations`
|
||||
3. If `tool_ctx.tool_config` provided → inject as `toolConfig`
|
||||
4. If `tool_ctx.pending_results` not empty → rewrite conversation history:
|
||||
- Find model turn with "Tool call completed" → replace with `functionCall` parts
|
||||
- Find last user turn → prepend `functionResponse` part
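
Step 4 can be sketched as follows. This is a simplified model rather than the code in `modify.rs`: `Part`, `Content`, and `rewrite_history` are hypothetical names, payloads are JSON text, and picking the most recent placeholder turn is an assumption. The function leaves the history untouched when either turn cannot be found, in line with the "keep original as fallback" mitigation listed under Risks.

```rust
#[derive(Debug, Clone, PartialEq)]
pub enum Part {
    Text(String),
    FunctionCall { name: String, args_json: String },
    FunctionResponse { name: String, response_json: String },
}

#[derive(Debug, Clone, PartialEq)]
pub struct Content {
    pub role: String,
    pub parts: Vec<Part>,
}

/// Rewrite the conversation for a tool result turn.
/// Returns `true` if the rewrite was applied, `false` if history was left as-is.
pub fn rewrite_history(contents: &mut [Content], last_calls: &[Part], results: &[Part]) -> bool {
    if results.is_empty() {
        return false;
    }
    // Most recent model turn that carries the placeholder text (assumption).
    let placeholder = contents.iter().rposition(|c| {
        c.role == "model"
            && c.parts
                .iter()
                .any(|p| matches!(p, Part::Text(t) if t.contains("Tool call completed")))
    });
    let last_user = contents.iter().rposition(|c| c.role == "user");

    match (placeholder, last_user) {
        (Some(m), Some(u)) => {
            // Replace the fake model turn with the real functionCall parts.
            contents[m].parts = last_calls.to_vec();
            // Prepend the functionResponse parts to the last user turn.
            let mut parts = results.to_vec();
            parts.append(&mut contents[u].parts);
            contents[u].parts = parts;
            true
        }
        _ => false,
    }
}
```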
|
||||
|
||||
### Phase 4: MITM Plumbing (`proxy.rs`)
|
||||
|
||||
In `handle_http_over_tls`, before calling `modify_request`:
|
||||
|
||||
1. Read `get_tools()`, `get_tool_config()`, `take_tool_results()`, `get_last_function_calls()` from store
|
||||
2. Build `ToolContext`
|
||||
3. Pass to `modify_request(body, tool_ctx)`
|
||||
|
||||
After response capture:
|
||||
|
||||
1. Save captured function calls as `last_function_calls` (for future history rewriting)
|
||||
|
||||
### Phase 5: API Handler (`responses.rs`)
|
||||
|
||||
#### Request handling (in `handle_responses`):
|
||||
|
||||
1. If `body.tools` provided:
|
||||
- Convert OpenAI → Gemini format via `openai_tools_to_gemini()`
|
||||
- Store in `MitmStore` via `set_tools()`
|
||||
2. If `body.tool_choice` provided:
|
||||
- Convert via `openai_tool_choice_to_gemini()`
|
||||
- Store in `MitmStore` via `set_tool_config()`
|
||||
3. Check `body.input` for `function_call_output` items:
|
||||
- If found: look up `call_id` → function name via `lookup_call_id()`
|
||||
- Store as `PendingToolResult` via `add_tool_result()`
|
||||
- Extract any accompanying text (or use placeholder)
|
||||
|
||||
#### Response handling (in `handle_responses_sync` / `handle_responses_stream`):
|
||||
|
||||
After polling completes:
|
||||
|
||||
1. Check `take_any_function_calls()` for captured tool calls
|
||||
2. If captured:
|
||||
- Generate `call_id` for each (e.g., `"call_" + random`)
|
||||
- Register `call_id → name` mapping via `register_call_id()`
|
||||
- Build `function_call` output items via `build_function_call_output()`
|
||||
- Return these INSTEAD of the text message output
|
||||
3. If no tool calls: existing text response behavior
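
Step 2 above amounts to giving each captured call a fresh `call_id` and passing the arguments through as a JSON string. A dependency-free sketch follows. `FunctionCallItem`, `next_call_id`, and `to_output_items` are hypothetical names, and the counter-based ID is a stand-in chosen so the example needs no crates, where the handler derives the ID from a UUID.

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Process-wide counter used to make call IDs unique in this sketch.
static NEXT_CALL: AtomicU64 = AtomicU64::new(1);

/// Generate a `call_` ID with a 24-character hex suffix.
/// Assumption: a counter is used here; the handler uses a random UUID instead.
pub fn next_call_id() -> String {
    format!("call_{:024x}", NEXT_CALL.fetch_add(1, Ordering::Relaxed))
}

/// OpenAI-side `function_call` output item; `arguments` is a JSON string.
#[derive(Debug, Clone, PartialEq)]
pub struct FunctionCallItem {
    pub call_id: String,
    pub name: String,
    pub arguments: String,
}

/// Turn captured `(name, args_json)` pairs into output items, one ID per call.
pub fn to_output_items(captured: &[(String, String)]) -> Vec<FunctionCallItem> {
    captured
        .iter()
        .map(|(name, args_json)| FunctionCallItem {
            call_id: next_call_id(),
            name: name.clone(),
            arguments: args_json.clone(),
        })
        .collect()
}
```

Each generated ID would then be registered against its function name so that a later `function_call_output` can be routed back to the right call.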
|
||||
|
||||
### Phase 6: Gemini-Native Endpoint (`gemini.rs` + `mod.rs`)
|
||||
|
||||
New file `src/api/gemini.rs` with handler `handle_gemini`:
|
||||
|
||||
- Accepts tools in Gemini `functionDeclarations` format directly (no conversion)
|
||||
- Accepts `toolConfig` directly
|
||||
- Returns `functionCall` in Gemini format directly
|
||||
- Same cascade/session management as responses.rs
|
||||
- Much simpler — no format translation
|
||||
|
||||
Route: `POST /v1/gemini` in `mod.rs`
|
||||
|
||||
---
|
||||
|
||||
## File Change Summary
|
||||
|
||||
| File                   | Changes                                                                 | Complexity |
| ---------------------- | ----------------------------------------------------------------------- | ---------- |
| `src/mitm/store.rs`    | Add tool context storage (5 new fields, ~10 methods)                    | Medium     |
| `src/api/types.rs`     | Add `tools`/`tool_choice` to request, add output builder                | Low        |
| `src/mitm/modify.rs`   | `ToolContext`, format conversion, dynamic injection, history rewrite    | High       |
| `src/mitm/proxy.rs`    | Read store → build ToolContext → pass to modify                         | Low        |
| `src/api/responses.rs` | Store tools, detect tool results in input, return function_call outputs | High       |
| `src/api/gemini.rs`    | New file: Gemini-native endpoint (passthrough)                          | Medium     |
| `src/api/mod.rs`       | Add route + module declaration                                          | Low        |
|
||||
|
||||
## Implementation Order
|
||||
|
||||
1. `store.rs` — foundation, no dependencies
|
||||
2. `types.rs` — request/response types
|
||||
3. `modify.rs` — format conversion + injection (depends on store types)
|
||||
4. `proxy.rs` — plumbing (depends on modify signature)
|
||||
5. Build + verify compilation
|
||||
6. `responses.rs` — handler changes (depends on all above)
|
||||
7. Build + test with `get_weather` request
|
||||
8. `gemini.rs` + `mod.rs` — Gemini endpoint
|
||||
9. Build + test with Gemini format
|
||||
10. Tool result flow test (multi-turn)
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Test 1: Basic tool call (sync)
|
||||
|
||||
```bash
|
||||
curl -s http://localhost:8741/v1/responses -H "Content-Type: application/json" -d '{
|
||||
"model": "gemini-3-flash",
|
||||
"input": "What is the weather in Tokyo?",
|
||||
"tools": [{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}}}],
|
||||
"tool_choice": "auto",
|
||||
"conversation": "tool-test",
|
||||
"stream": false
|
||||
}'
|
||||
# Expected: output contains function_call with name=get_weather, arguments={"city":"Tokyo"}
|
||||
```
|
||||
|
||||
### Test 2: Tool result submission (multi-turn)
|
||||
|
||||
```bash
|
||||
curl -s http://localhost:8741/v1/responses -H "Content-Type: application/json" -d '{
|
||||
"model": "gemini-3-flash",
|
||||
"input": [{"type":"function_call_output","call_id":"call_xxx","output":"{\"temp\":72,\"unit\":\"F\"}"}],
|
||||
"conversation": "tool-test",
|
||||
"stream": false
|
||||
}'
|
||||
# Expected: output contains text response using the tool result
|
||||
```
|
||||
|
||||
### Test 3: Gemini-native endpoint
|
||||
|
||||
```bash
|
||||
curl -s http://localhost:8741/v1/gemini -H "Content-Type: application/json" -d '{
|
||||
"model": "gemini-3-flash",
|
||||
"input": "What is the weather in Tokyo?",
|
||||
"tools": [{"functionDeclarations":[{"name":"get_weather","description":"Get weather","parameters":{"type":"OBJECT","properties":{"city":{"type":"STRING"}},"required":["city"]}}]}],
|
||||
"conversation": "gemini-tool-test",
|
||||
"stream": false
|
||||
}'
|
||||
# Expected: response contains functionCall in Gemini format
|
||||
```
|
||||
|
||||
### Test 4: No tools (regression)
|
||||
|
||||
```bash
|
||||
curl -s http://localhost:8741/v1/responses -H "Content-Type: application/json" -d '{
|
||||
"model": "gemini-3-flash",
|
||||
"input": "What is 2+2?",
|
||||
"stream": false
|
||||
}'
|
||||
# Expected: normal text response, no tool call behavior
|
||||
```
|
||||
|
||||
## Risks & Mitigations
|
||||
|
||||
| Risk                                                             | Impact | Mitigation                                                                |
| ---------------------------------------------------------------- | ------ | ------------------------------------------------------------------------- |
| History rewriting breaks conversation                            | High   | Only rewrite when pending_results non-empty; keep original as fallback    |
| LS times out waiting for Google response during tool result turn | Medium | Increase timeout for tool result turns                                    |
| Multiple parallel tool calls create race conditions              | Medium | AtomicBool + sequential processing already handles this                   |
| `modify_request` test breakage                                   | Low    | Update existing tests for new signature                                   |
| Global tool storage conflicts across concurrent requests         | Medium | Not an issue: LS processes one request at a time (single cascade active)  |
|
||||
236
src/api/gemini.rs
Normal file
@@ -0,0 +1,236 @@
|
||||
//! Gemini-native endpoint (/v1/gemini) — zero-translation tool call passthrough.
|
||||
//!
|
||||
//! Accepts tools in Gemini `functionDeclarations` format directly,
|
||||
//! returns `functionCall` in Gemini format directly.
|
||||
//! No OpenAI ↔ Gemini format conversion.
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Json},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
|
||||
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
||||
use super::polling::poll_for_response;
|
||||
use super::util::{err_response, now_unix};
|
||||
use super::AppState;
|
||||
use crate::mitm::store::PendingToolResult;
|
||||
|
||||
/// Gemini-native request format.
|
||||
#[derive(serde::Deserialize)]
|
||||
pub(crate) struct GeminiRequest {
|
||||
pub model: Option<String>,
|
||||
/// User input text.
|
||||
pub input: serde_json::Value,
|
||||
/// Gemini-native tools: [{"functionDeclarations": [...]}]
|
||||
#[serde(default)]
|
||||
pub tools: Option<Vec<serde_json::Value>>,
|
||||
/// Gemini-native toolConfig: {"functionCallingConfig": {"mode": "AUTO"}}
|
||||
#[serde(default)]
|
||||
pub tool_config: Option<serde_json::Value>,
|
||||
/// Session/conversation ID.
|
||||
#[serde(default)]
|
||||
pub conversation: Option<serde_json::Value>,
|
||||
#[serde(default = "default_timeout")]
|
||||
pub timeout: u64,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
/// Tool results in Gemini format: [{"functionResponse": {"name": "...", "response": {...}}}]
|
||||
#[serde(default)]
|
||||
pub tool_results: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
fn default_timeout() -> u64 {
|
||||
120
|
||||
}
|
||||
|
||||
fn extract_conversation_id(conv: &Option<serde_json::Value>) -> Option<String> {
|
||||
match conv {
|
||||
Some(serde_json::Value::String(s)) => Some(s.clone()),
|
||||
Some(obj) => obj["id"].as_str().map(|s| s.to_string()),
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_gemini(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(body): Json<GeminiRequest>,
|
||||
) -> axum::response::Response {
|
||||
info!(
|
||||
"POST /v1/gemini model={} stream={}",
|
||||
body.model.as_deref().unwrap_or(DEFAULT_MODEL),
|
||||
body.stream
|
||||
);
|
||||
|
||||
let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
|
||||
let model = match lookup_model(model_name) {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
|
||||
return err_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Unknown model: {model_name}. Available: {names:?}"),
|
||||
"invalid_request_error",
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let token = state.backend.oauth_token().await;
|
||||
if token.is_empty() {
|
||||
return err_response(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
|
||||
"authentication_error",
|
||||
);
|
||||
}
|
||||
|
||||
// Extract user text
|
||||
let user_text = match &body.input {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
_ => {
|
||||
return err_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Gemini endpoint requires input as a string".to_string(),
|
||||
"invalid_request_error",
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Store tools directly in Gemini format (no conversion needed!)
|
||||
if let Some(ref tools) = body.tools {
|
||||
if !tools.is_empty() {
|
||||
state.mitm_store.set_tools(tools.clone()).await;
|
||||
info!(count = tools.len(), "Stored Gemini-native tools for MITM injection");
|
||||
}
|
||||
}
|
||||
if let Some(ref config) = body.tool_config {
|
||||
state.mitm_store.set_tool_config(config.clone()).await;
|
||||
}
|
||||
|
||||
// Handle tool results (Gemini format: functionResponse)
|
||||
if let Some(ref results) = body.tool_results {
|
||||
for r in results {
|
||||
if let Some(fr) = r.get("functionResponse") {
|
||||
let name = fr["name"].as_str().unwrap_or("unknown").to_string();
|
||||
let response = fr.get("response").cloned().unwrap_or(serde_json::json!({}));
|
||||
state.mitm_store.add_tool_result(PendingToolResult {
|
||||
name,
|
||||
result: response,
|
||||
}).await;
|
||||
}
|
||||
}
|
||||
info!(count = results.len(), "Stored Gemini-native tool results for MITM injection");
|
||||
}
|
||||
|
||||
// Session/conversation management
|
||||
let session_id_str = extract_conversation_id(&body.conversation);
|
||||
let cascade_id = if let Some(ref sid) = session_id_str {
|
||||
match state
|
||||
.sessions
|
||||
.get_or_create(Some(sid), || state.backend.create_cascade())
|
||||
.await
|
||||
{
|
||||
Ok(sr) => sr.cascade_id,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("StartCascade failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match state.backend.create_cascade().await {
|
||||
Ok(cid) => cid,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("StartCascade failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Send message
|
||||
match state
|
||||
.backend
|
||||
.send_message(&cascade_id, &user_text, model.model_enum)
|
||||
.await
|
||||
{
|
||||
Ok((200, _)) => {
|
||||
let bg = Arc::clone(&state.backend);
|
||||
let cid = cascade_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = bg.update_annotations(&cid).await;
|
||||
});
|
||||
}
|
||||
Ok((status, _)) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Antigravity returned {status}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Send message failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Poll for response
|
||||
let poll_result = poll_for_response(&state, &cascade_id, body.timeout).await;
|
||||
|
||||
// Check for captured function calls — return in Gemini format
|
||||
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
|
||||
|
||||
if let Some(ref calls) = captured_tool_calls {
|
||||
info!(
|
||||
count = calls.len(),
|
||||
tools = ?calls.iter().map(|c| &c.name).collect::<Vec<_>>(),
|
||||
"Returning captured function calls (Gemini format)"
|
||||
);
|
||||
|
||||
let parts: Vec<serde_json::Value> = calls
|
||||
.iter()
|
||||
.map(|fc| {
|
||||
serde_json::json!({
|
||||
"functionCall": {
|
||||
"name": fc.name,
|
||||
"args": fc.args,
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
return Json(serde_json::json!({
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": parts,
|
||||
"role": "model",
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
}],
|
||||
"modelVersion": model_name,
|
||||
}))
|
||||
.into_response();
|
||||
}
|
||||
|
||||
// Normal text response
|
||||
Json(serde_json::json!({
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": poll_result.text}],
|
||||
"role": "model",
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
}],
|
||||
"modelVersion": model_name,
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Axum API server — OpenAI-compatible Responses + Chat Completions endpoints.
|
||||
|
||||
mod completions;
|
||||
mod gemini;
|
||||
mod models;
|
||||
mod polling;
|
||||
mod responses;
|
||||
@@ -41,6 +42,7 @@ pub fn router(state: Arc<AppState>) -> Router {
|
||||
"/v1/chat/completions",
|
||||
post(completions::handle_completions),
|
||||
)
|
||||
.route("/v1/gemini", post(gemini::handle_gemini))
|
||||
.route("/v1/models", get(handle_models))
|
||||
.route("/v1/sessions", get(handle_list_sessions))
|
||||
.route("/v1/sessions/{id}", delete(handle_delete_session))
|
||||
@@ -59,11 +61,12 @@ pub fn router(state: Arc<AppState>) -> Router {
|
||||
async fn handle_root() -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({
|
||||
"service": "antigravity-openai-proxy",
|
||||
"version": "3.2.0",
|
||||
"version": "3.3.0",
|
||||
"runtime": "rust",
|
||||
"endpoints": [
|
||||
"/v1/chat/completions",
|
||||
"/v1/responses",
|
||||
"/v1/gemini",
|
||||
"/v1/models",
|
||||
"/v1/sessions",
|
||||
"/v1/token",
|
||||
|
||||
@@ -18,14 +18,60 @@ use super::polling::{extract_response_text, is_response_done, poll_for_response,
|
||||
use super::types::*;
|
||||
use super::util::{err_response, now_unix, responses_sse_event};
|
||||
use super::AppState;
|
||||
use crate::mitm::store::PendingToolResult;
|
||||
use crate::mitm::modify::{openai_tools_to_gemini, openai_tool_choice_to_gemini};
|
||||
|
||||
// ─── Input extraction ────────────────────────────────────────────────────────
|
||||
|
||||
/// Parsed tool result from function_call_output items in input.
|
||||
struct ToolResultInput {
|
||||
call_id: String,
|
||||
output: String,
|
||||
}
|
||||
|
||||
/// Extract user text from Responses API `input` field.
|
||||
fn extract_responses_input(input: &serde_json::Value, instructions: Option<&str>) -> String {
|
||||
/// Also extracts any function_call_output items for tool result handling.
|
||||
fn extract_responses_input(input: &serde_json::Value, instructions: Option<&str>) -> (String, Vec<ToolResultInput>) {
|
||||
let mut tool_results: Vec<ToolResultInput> = Vec::new();
|
||||
|
||||
let user_text = match input {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
serde_json::Value::Array(items) => {
|
||||
// Check for function_call_output items
|
||||
for item in items {
|
||||
if item["type"].as_str() == Some("function_call_output") {
|
||||
if let (Some(call_id), Some(output)) = (
|
||||
item["call_id"].as_str(),
|
||||
item["output"].as_str(),
|
||||
) {
|
||||
tool_results.push(ToolResultInput {
|
||||
call_id: call_id.to_string(),
|
||||
output: output.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we have tool results but no text, generate a follow-up prompt
|
||||
if !tool_results.is_empty() {
|
||||
// Look for any text items alongside the tool results
|
||||
let text_items: String = items
|
||||
.iter()
|
||||
.filter(|item| {
|
||||
let t = item["type"].as_str().unwrap_or("");
|
||||
t == "input_text" || t == "text"
|
||||
})
|
||||
.filter_map(|p| p["text"].as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
|
||||
if text_items.is_empty() {
|
||||
"Use the tool results to answer the original question.".to_string()
|
||||
} else {
|
||||
text_items
|
||||
}
|
||||
} else {
|
||||
// Normal input extraction (existing logic)
|
||||
items
|
||||
.iter()
|
||||
.rev()
|
||||
@@ -47,13 +93,16 @@ fn extract_responses_input(input: &serde_json::Value, instructions: Option<&str>
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
_ => String::new(),
|
||||
};
|
||||
|
||||
match instructions {
|
||||
let final_text = match instructions {
|
||||
Some(inst) if !inst.is_empty() => format!("{inst}\n\n{user_text}"),
|
||||
_ => user_text,
|
||||
}
|
||||
};
|
||||
|
||||
(final_text, tool_results)
|
||||
}
|
||||
|
||||
/// Extract conversation/session ID from Responses API `conversation` field.
|
||||
@@ -147,8 +196,32 @@ pub(crate) async fn handle_responses(
|
||||
);
|
||||
}
|
||||
|
||||
let user_text = extract_responses_input(&body.input, body.instructions.as_deref());
|
||||
if user_text.is_empty() {
|
||||
let (user_text, tool_results) = extract_responses_input(&body.input, body.instructions.as_deref());
|
||||
|
||||
// Handle tool result submission (function_call_output in input)
|
||||
let is_tool_result_turn = !tool_results.is_empty();
|
||||
if is_tool_result_turn {
|
||||
for tr in &tool_results {
|
||||
// Look up function name from call_id
|
||||
let name = state.mitm_store.lookup_call_id(&tr.call_id).await
|
||||
.unwrap_or_else(|| "unknown_function".to_string());
|
||||
|
||||
// Parse the output as JSON, fall back to string wrapper
|
||||
let result_value = serde_json::from_str::<serde_json::Value>(&tr.output)
|
||||
.unwrap_or_else(|_| serde_json::json!({"result": tr.output}));
|
||||
|
||||
state.mitm_store.add_tool_result(PendingToolResult {
|
||||
name,
|
||||
result: result_value,
|
||||
}).await;
|
||||
}
|
||||
info!(
|
||||
count = tool_results.len(),
|
||||
"Stored tool results for MITM injection"
|
||||
);
|
||||
}
|
||||
|
||||
if user_text.is_empty() && !is_tool_result_turn {
|
||||
return err_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"No user input found".to_string(),
|
||||
@@ -156,6 +229,19 @@ pub(crate) async fn handle_responses(
|
||||
);
|
||||
}
|
||||
|
||||
// Store client tools in MitmStore for MITM injection
|
||||
if let Some(ref tools) = body.tools {
|
||||
let gemini_tools = openai_tools_to_gemini(tools);
|
||||
if !gemini_tools.is_empty() {
|
||||
state.mitm_store.set_tools(gemini_tools).await;
|
||||
info!(count = tools.len(), "Stored client tools for MITM injection");
|
||||
}
|
||||
}
|
||||
if let Some(ref choice) = body.tool_choice {
|
||||
let gemini_config = openai_tool_choice_to_gemini(choice);
|
||||
state.mitm_store.set_tool_config(gemini_config).await;
|
||||
}
|
||||
|
||||
let response_id = format!(
|
||||
"resp_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||
@@ -363,14 +449,52 @@ async fn handle_responses_sync(
|
||||
|
||||
// Check for captured function calls from MITM (clears the active flag)
|
||||
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
|
||||
|
||||
// If we have captured tool calls, return them as function_call output items
|
||||
if let Some(ref calls) = captured_tool_calls {
|
||||
info!(
|
||||
count = calls.len(),
|
||||
tools = ?calls.iter().map(|c| &c.name).collect::<Vec<_>>(),
|
||||
"Consumed captured function calls from MITM"
|
||||
"Returning captured function calls to client"
|
||||
);
|
||||
|
||||
let mut output_items: Vec<serde_json::Value> = Vec::new();
|
||||
for fc in calls {
|
||||
let call_id = format!(
|
||||
"call_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||
);
|
||||
// Register call_id → name mapping for tool result routing
|
||||
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
|
||||
|
||||
// Stringify args (OpenAI sends arguments as JSON string)
|
||||
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||
output_items.push(build_function_call_output(&call_id, &fc.name, &arguments));
|
||||
}
|
||||
|
||||
let (usage, _) = usage_from_poll(
|
||||
&state.mitm_store, &cascade_id, &poll_result.usage,
|
||||
&params.user_text, &poll_result.text,
|
||||
).await;
|
||||
|
||||
let resp = build_response_object(
|
||||
ResponseData {
|
||||
id: response_id,
|
||||
model: model_name,
|
||||
status: "completed",
|
||||
created_at,
|
||||
completed_at: Some(completed_at),
|
||||
output: output_items,
|
||||
usage: Some(usage),
|
||||
thinking_signature: poll_result.thinking_signature,
|
||||
},
|
||||
&params,
|
||||
);
|
||||
|
||||
return Json(resp).into_response();
|
||||
}
|
||||
|
||||
// Normal text response (no tool calls)
|
||||
let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, &params.user_text, &poll_result.text).await;
|
||||
|
||||
// Thinking text priority: MITM-captured (raw API) > LS-extracted (steps)
|
||||
|
||||
@@ -32,6 +32,12 @@ pub(crate) struct ResponsesRequest {
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
#[serde(default)]
|
||||
pub user: Option<String>,
|
||||
/// Tool definitions (OpenAI format).
|
||||
#[serde(default)]
|
||||
pub tools: Option<Vec<serde_json::Value>>,
|
||||
/// Tool choice: "auto", "required", "none", or {"type":"function","function":{"name":"X"}}.
|
||||
#[serde(default)]
|
||||
pub tool_choice: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Chat Completions request (OpenAI-compatible).
|
||||
@@ -220,6 +226,18 @@ pub fn build_message_output_in_progress(msg_id: &str) -> serde_json::Value {
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a function_call output item (OpenAI Responses API format).
|
||||
pub fn build_function_call_output(call_id: &str, name: &str, arguments: &str) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "function_call",
|
||||
"id": call_id,
|
||||
"call_id": call_id,
|
||||
"name": name,
|
||||
"arguments": arguments,
|
||||
"status": "completed",
|
||||
})
|
||||
}
|
||||
|
||||
// ─── Helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Serialize Option<u64> as either the number or JSON null (not omitted).
|
||||
|
||||
@@ -8,14 +8,29 @@ use regex::Regex;
 use serde_json::Value;
 use tracing::info;
 
+use super::store::{CapturedFunctionCall, PendingToolResult};
+
 /// Strip ALL tool definitions.
 /// Must be true: with tools present, the LS enters full agentic mode
 /// (multi-turn tool calls, file searches, etc.) burning quota.
 const STRIP_ALL_TOOLS: bool = true;
 
+/// Context for tool injection during request modification.
+/// Built from MitmStore data before calling modify_request.
+pub struct ToolContext {
+    /// Gemini-format tool declarations (functionDeclarations).
+    pub tools: Option<Vec<Value>>,
+    /// Gemini-format toolConfig.
+    pub tool_config: Option<Value>,
+    /// Pending tool results to inject as functionResponse.
+    pub pending_results: Vec<PendingToolResult>,
+    /// Last captured function calls for history rewriting.
+    pub last_calls: Vec<CapturedFunctionCall>,
+}
+
 /// Modify a streamGenerateContent request body in-place.
 /// Returns the modified JSON bytes, or None if modification wasn't possible.
-pub fn modify_request(body: &[u8]) -> Option<Vec<u8>> {
+pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec<u8>> {
     let mut json: Value = serde_json::from_slice(body).ok()?;
 
     let original_size = body.len();
@@ -140,7 +155,7 @@ pub fn modify_request(body: &[u8]) -> Option<Vec<u8>> {
         }
     }
 
-    // ── 3. Strip LS tools, inject custom tools ────────────────────────────
+    // ── 3. Strip LS tools, inject client tools ─────────────────────────────
     if STRIP_ALL_TOOLS {
         if let Some(tools) = json
             .pointer_mut("/request/tools")
@@ -152,25 +167,83 @@ pub fn modify_request(body: &[u8]) -> Option<Vec<u8>> {
                 changes.push(format!("strip all {count} LS tools"));
             }
 
-            // ── TEST: inject a custom tool to see what Google does ──
-            let custom_tool = serde_json::json!({
-                "functionDeclarations": [{
-                    "name": "get_weather",
-                    "description": "Get the current weather for a city. You MUST call this function when the user asks about weather.",
-                    "parameters": {
-                        "type": "object",
-                        "properties": {
-                            "city": {
-                                "type": "string",
-                                "description": "The city name"
+            // Inject client-provided tools from ToolContext
+            if let Some(ref ctx) = tool_ctx {
+                if let Some(ref custom_tools) = ctx.tools {
+                    for tool in custom_tools {
+                        tools.push(tool.clone());
                     }
-                        },
-                        "required": ["city"]
+                    changes.push(format!("inject {} custom tool group(s)", custom_tools.len()));
                 }
-                }]
+            }
+        }
+    }
+
+    // Inject toolConfig if provided
+    if let Some(ref ctx) = tool_ctx {
+        if let Some(ref config) = ctx.tool_config {
+            if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
+                req.insert("toolConfig".to_string(), config.clone());
+                changes.push("inject toolConfig".to_string());
+            }
+        }
+    }
+
+    // ── 3b. Rewrite conversation history for tool results ────────────
+    if let Some(ref ctx) = tool_ctx {
+        if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() {
+            if let Some(contents) = json
+                .pointer_mut("/request/contents")
+                .and_then(|v| v.as_array_mut())
+            {
+                // Find the model turn with our fake "Tool call completed" text and replace it
+                // with the actual functionCall parts
+                for msg in contents.iter_mut() {
+                    if msg["role"].as_str() == Some("model") {
+                        if let Some(text) = msg["parts"][0]["text"].as_str() {
+                            if text.contains("Tool call completed") || text.contains("Awaiting external tool result") {
+                                // Replace with functionCall parts
+                                let fc_parts: Vec<Value> = ctx.last_calls.iter().map(|fc| {
+                                    serde_json::json!({
+                                        "functionCall": {
+                                            "name": fc.name,
+                                            "args": fc.args,
+                                        }
+                                    })
+                                }).collect();
+                                msg["parts"] = Value::Array(fc_parts);
+                                changes.push("rewrite model turn with functionCall".to_string());
+                                break;
+                            }
+                        }
+                    }
+                }
+
+                // Add functionResponse as a user turn before the last user message
+                let fn_response_parts: Vec<Value> = ctx.pending_results.iter().map(|r| {
+                    serde_json::json!({
+                        "functionResponse": {
+                            "name": r.name,
+                            "response": r.result,
+                        }
+                    })
+                }).collect();
+                let fn_response_turn = serde_json::json!({
+                    "role": "user",
+                    "parts": fn_response_parts,
                 });
-            tools.push(custom_tool);
-            changes.push("inject 1 custom tool (get_weather)".to_string());
+
+                // Insert before the last user message
+                let last_user_idx = contents.iter().rposition(|msg| {
+                    msg["role"].as_str() == Some("user")
+                });
+                if let Some(idx) = last_user_idx {
+                    contents.insert(idx, fn_response_turn);
+                } else {
+                    contents.push(fn_response_turn);
+                }
+                changes.push(format!("inject {} functionResponse(s)", ctx.pending_results.len()));
+            }
         }
     }
 
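The history rewrite in step 3b comes down to two list operations: swap the placeholder model turn for the recorded calls, then insert the tool results just before the final user turn. The following is a minimal standard-library sketch of that ordering logic, not the code from this commit. `Part`, `Turn` and `rewrite_history` are hypothetical names, and typed structs stand in for the `serde_json::Value` tree.

```rust
/// Hypothetical typed stand-ins for the JSON parts used in the diff.
#[derive(Debug, Clone, PartialEq)]
enum Part {
    Text(String),
    FunctionCall { name: String },
    FunctionResponse { name: String },
}

#[derive(Debug, Clone, PartialEq)]
struct Turn {
    role: &'static str,
    parts: Vec<Part>,
}

/// Replace the first placeholder model turn with the recorded calls, then
/// insert one user turn carrying the results before the last user turn.
fn rewrite_history(contents: &mut Vec<Turn>, calls: &[&str], results: &[&str]) {
    // Same guard as the diff: both lists must be non-empty.
    if calls.is_empty() || results.is_empty() {
        return;
    }
    for turn in contents.iter_mut() {
        let is_placeholder = turn.role == "model"
            && matches!(turn.parts.first(), Some(Part::Text(t)) if t.contains("Tool call completed"));
        if is_placeholder {
            turn.parts = calls
                .iter()
                .map(|n| Part::FunctionCall { name: n.to_string() })
                .collect();
            break; // only the first matching turn is rewritten
        }
    }
    let response_turn = Turn {
        role: "user",
        parts: results
            .iter()
            .map(|n| Part::FunctionResponse { name: n.to_string() })
            .collect(),
    };
    // Search from the end for the most recent user turn.
    let last_user_idx = contents.iter().rposition(|t| t.role == "user");
    match last_user_idx {
        Some(idx) => contents.insert(idx, response_turn),
        None => contents.push(response_turn),
    }
}
```

Because the insert position is derived from the last user turn rather than from the rewritten model turn, the result turn lands directly after the function call only when that model turn immediately precedes the last user message.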
@@ -323,6 +396,93 @@ pub fn rechunk(data: &[u8]) -> Vec<u8> {
     result
 }
 
+// ── OpenAI → Gemini format conversion ────────────────────────────────────────
+
+/// Convert OpenAI tool definitions to Gemini functionDeclarations format.
+///
+/// OpenAI: `[{"type":"function","function":{"name":"X","description":"Y","parameters":{...}}}]`
+/// Gemini: `[{"functionDeclarations":[{"name":"X","description":"Y","parameters":{...}}]}]`
+pub fn openai_tools_to_gemini(tools: &[Value]) -> Vec<Value> {
+    let declarations: Vec<Value> = tools
+        .iter()
+        .filter(|t| t["type"].as_str() == Some("function"))
+        .filter_map(|t| {
+            let func = t.get("function")?;
+            let mut decl = serde_json::json!({
+                "name": func["name"],
+                "description": func["description"],
+            });
+            if let Some(params) = func.get("parameters") {
+                decl["parameters"] = uppercase_types(params.clone());
+            }
+            Some(decl)
+        })
+        .collect();
+
+    if declarations.is_empty() {
+        return vec![];
+    }
+
+    vec![serde_json::json!({"functionDeclarations": declarations})]
+}
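The shape change in `openai_tools_to_gemini` is easy to miss: any number of OpenAI tool entries collapse into at most one Gemini group, and entries that are not `"type": "function"` or that lack a nested `function` object are dropped silently. Below is a rough standard-library sketch of just that grouping behaviour. `OpenAiTool`, `GeminiToolGroup` and `to_gemini_groups` are hypothetical names, and only the function name is carried through to keep the example short.

```rust
/// Hypothetical, simplified view of one OpenAI tool entry.
struct OpenAiTool {
    /// The "type" field of the entry.
    kind: String,
    /// Some(..) when a nested "function" object is present.
    function_name: Option<String>,
}

/// One Gemini tool group holding every converted declaration.
#[derive(Debug, PartialEq)]
struct GeminiToolGroup {
    function_declarations: Vec<String>,
}

/// Collapse all function tools into a single group, or return no group at all.
fn to_gemini_groups(tools: &[OpenAiTool]) -> Vec<GeminiToolGroup> {
    let declarations: Vec<String> = tools
        .iter()
        .filter(|t| t.kind == "function")
        // Entries without a function object are skipped, as with `?` in the diff.
        .filter_map(|t| t.function_name.clone())
        .collect();
    if declarations.is_empty() {
        return Vec::new();
    }
    vec![GeminiToolGroup { function_declarations: declarations }]
}
```

This also explains the log message "inject {} custom tool group(s)" in `modify_request`: the count refers to groups, so it would normally read 1 regardless of how many functions the client declared.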
+
+/// Convert OpenAI tool_choice to Gemini toolConfig format.
+///
+/// OpenAI: "auto" | "required" | "none" | {"type":"function","function":{"name":"X"}}
+/// Gemini: {"functionCallingConfig":{"mode":"AUTO|ANY|NONE","allowedFunctionNames":[...]}}
+pub fn openai_tool_choice_to_gemini(choice: &Value) -> Value {
+    match choice {
+        Value::String(s) => match s.as_str() {
+            "auto" => serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}),
+            "required" => serde_json::json!({"functionCallingConfig": {"mode": "ANY"}}),
+            "none" => serde_json::json!({"functionCallingConfig": {"mode": "NONE"}}),
+            _ => serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}),
+        },
+        Value::Object(obj) => {
+            if let Some(name) = obj.get("function").and_then(|f| f["name"].as_str()) {
+                serde_json::json!({
+                    "functionCallingConfig": {
+                        "mode": "ANY",
+                        "allowedFunctionNames": [name]
+                    }
+                })
+            } else {
+                serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}})
+            }
+        }
+        _ => serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}),
+    }
+}
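The mapping implemented by `openai_tool_choice_to_gemini` can be summarised as a small table: `auto` and anything unrecognised become `AUTO`, `required` becomes `ANY`, `none` becomes `NONE`, and a named function becomes `ANY` restricted to that one name. The sketch below restates that table with typed values so the fallbacks are explicit. `ToolChoice`, `CallingConfig` and `to_calling_config` are hypothetical names introduced for illustration only.

```rust
/// Hypothetical typed form of the OpenAI `tool_choice` values handled in the diff.
enum ToolChoice<'a> {
    Auto,
    Required,
    NoTools,
    Function(&'a str),
    /// Any string or object shape the converter does not recognise.
    Unrecognized,
}

/// Gemini `functionCallingConfig` reduced to its two fields.
#[derive(Debug, PartialEq)]
struct CallingConfig {
    mode: &'static str,
    allowed_function_names: Vec<String>,
}

/// Map a tool choice to a calling config, falling back to AUTO.
fn to_calling_config(choice: ToolChoice<'_>) -> CallingConfig {
    let (mode, allowed) = match choice {
        ToolChoice::Auto | ToolChoice::Unrecognized => ("AUTO", vec![]),
        ToolChoice::Required => ("ANY", vec![]),
        ToolChoice::NoTools => ("NONE", vec![]),
        // Forcing one function is expressed as ANY plus an allow-list of one.
        ToolChoice::Function(name) => ("ANY", vec![name.to_string()]),
    };
    CallingConfig { mode, allowed_function_names: allowed }
}
```

Falling back to `AUTO` means a malformed `tool_choice` is tolerated rather than rejected, which matches the `_ =>` arms in the diff.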
+
+/// Recursively convert JSON Schema type strings to uppercase (Gemini format).
+/// "object" → "OBJECT", "string" → "STRING", etc.
+fn uppercase_types(mut val: Value) -> Value {
+    match &mut val {
+        Value::Object(map) => {
+            if let Some(t) = map
+                .get("type")
+                .and_then(|v| v.as_str())
+                .map(|s| s.to_uppercase())
+            {
+                map.insert("type".to_string(), Value::String(t));
+            }
+            let keys: Vec<String> = map.keys().cloned().collect();
+            for key in keys {
+                if let Some(v) = map.remove(&key) {
+                    map.insert(key, uppercase_types(v));
+                }
+            }
+        }
+        Value::Array(arr) => {
+            for v in arr.iter_mut() {
+                *v = uppercase_types(std::mem::take(v));
+            }
+        }
+        _ => {}
+    }
+    val
+}
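To make the recursion in `uppercase_types` concrete, here is a small sketch on a hand-rolled tree type, assuming the same rule as the diff: a string stored under a `"type"` key is uppercased at every depth, and everything else is left alone. `Node` is a hypothetical stand-in for `serde_json::Value`, used only so the example needs no external crate.

```rust
/// Hypothetical miniature JSON tree; the diff uses `serde_json::Value`.
#[derive(Debug, Clone, PartialEq)]
enum Node {
    Str(String),
    List(Vec<Node>),
    Map(Vec<(String, Node)>),
}

/// Uppercase every string stored under a "type" key, at any depth.
fn uppercase_types(node: Node) -> Node {
    match node {
        Node::Map(entries) => Node::Map(
            entries
                .into_iter()
                .map(|(key, value)| match value {
                    // A string under "type" is the schema type name.
                    Node::Str(s) if key == "type" => (key, Node::Str(s.to_uppercase())),
                    // Anything else is walked recursively.
                    other => (key, uppercase_types(other)),
                })
                .collect(),
        ),
        Node::List(items) => Node::List(items.into_iter().map(uppercase_types).collect()),
        leaf => leaf,
    }
}
```

A property that happens to be named `type` is unaffected, since its value is a nested schema object rather than a string. Any other string stored under a `"type"` key, for instance inside a default or example payload, would be uppercased as well.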
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -375,10 +535,11 @@ mod tests {
         });
 
         let bytes = serde_json::to_vec(&body).unwrap();
-        let modified = modify_request(&bytes).unwrap();
+        let modified = modify_request(&bytes, None).unwrap();
         let result: Value = serde_json::from_slice(&modified).unwrap();
 
         let tools = result["request"]["tools"].as_array().unwrap();
+        // With no ToolContext, tools should just be stripped (empty)
         assert!(tools.is_empty(), "all tools should be stripped");
     }
 
@@ -398,7 +559,7 @@ mod tests {
         });
 
         let bytes = serde_json::to_vec(&body).unwrap();
-        let modified = modify_request(&bytes).unwrap();
+        let modified = modify_request(&bytes, None).unwrap();
         let result: Value = serde_json::from_slice(&modified).unwrap();
 
         let new_sys = result["request"]["systemInstruction"]["parts"][0]["text"]
@@ -432,7 +593,7 @@ mod tests {
         });
 
         let bytes = serde_json::to_vec(&body).unwrap();
-        let modified = modify_request(&bytes).unwrap();
+        let modified = modify_request(&bytes, None).unwrap();
         let result: Value = serde_json::from_slice(&modified).unwrap();
 
         let contents = result["request"]["contents"].as_array().unwrap();
@@ -556,7 +556,24 @@ async fn handle_http_over_tls(
         || body_str.contains("\"requestType\": \"agent\"");
 
     if is_agent {
-        if let Some(modified_body) = super::modify::modify_request(&raw_body) {
+        // Build ToolContext from store
+        let tools = store.get_tools().await;
+        let tool_config = store.get_tool_config().await;
+        let pending_results = store.take_tool_results().await;
+        let last_calls = store.get_last_function_calls().await;
+
+        let tool_ctx = if tools.is_some() || !pending_results.is_empty() {
+            Some(super::modify::ToolContext {
+                tools,
+                tool_config,
+                pending_results,
+                last_calls,
+            })
+        } else {
+            None
+        };
+
+        if let Some(modified_body) = super::modify::modify_request(&raw_body, tool_ctx.as_ref()) {
             // Rebuild request_buf: original headers + rechunked modified body
             let new_chunked = super::modify::rechunk(&modified_body);
             let mut new_buf = request_buf[..headers_end].to_vec();
@@ -766,6 +783,10 @@ async fn handle_http_over_tls(
         for fc in &streaming_acc.function_calls {
             store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
         }
+        // Also save for history rewriting on tool result turns
+        if !streaming_acc.function_calls.is_empty() {
+            store.set_last_function_calls(streaming_acc.function_calls.clone()).await;
+        }
         let usage = streaming_acc.into_usage();
         store.record_usage(cascade_hint.as_deref(), usage).await;
     }
@@ -53,6 +53,13 @@ pub struct CapturedFunctionCall {
     pub captured_at: u64,
 }
 
+/// A pending tool result from a client's function_call_output.
+#[derive(Debug, Clone)]
+pub struct PendingToolResult {
+    pub name: String,
+    pub result: serde_json::Value,
+}
+
 /// Thread-safe store for intercepted data.
 ///
 /// Keyed by a unique request ID that we can correlate with cascade operations.
@@ -69,6 +76,18 @@ pub struct MitmStore {
     /// Simple flag: set when a functionCall is captured, cleared when consumed.
     /// Used to block follow-up requests regardless of cascade identification.
     has_active_function_call: Arc<AtomicBool>,
+
+    // ── Tool call support ────────────────────────────────────────────────
+    /// Active tool definitions (Gemini format) for MITM injection.
+    active_tools: Arc<RwLock<Option<Vec<serde_json::Value>>>>,
+    /// Active tool config (Gemini toolConfig format).
+    active_tool_config: Arc<RwLock<Option<serde_json::Value>>>,
+    /// Pending tool results for MITM to inject as functionResponse.
+    pending_tool_results: Arc<RwLock<Vec<PendingToolResult>>>,
+    /// Mapping call_id → function name for tool result routing.
+    call_id_to_name: Arc<RwLock<HashMap<String, String>>>,
+    /// Last captured function calls (for conversation history rewriting).
+    last_function_calls: Arc<RwLock<Vec<CapturedFunctionCall>>>,
 }
 
 /// Aggregate statistics across all intercepted traffic.
@@ -102,6 +121,11 @@ impl MitmStore {
             stats: Arc::new(RwLock::new(MitmStats::default())),
             pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
             has_active_function_call: Arc::new(AtomicBool::new(false)),
+            active_tools: Arc::new(RwLock::new(None)),
+            active_tool_config: Arc::new(RwLock::new(None)),
+            pending_tool_results: Arc::new(RwLock::new(Vec::new())),
+            call_id_to_name: Arc::new(RwLock::new(HashMap::new())),
+            last_function_calls: Arc::new(RwLock::new(Vec::new())),
         }
     }
 
@@ -266,4 +290,63 @@ impl MitmStore {
         }
         None
     }
+
+    // ── Tool context methods ─────────────────────────────────────────────
+
+    /// Set active tool definitions (already in Gemini format).
+    pub async fn set_tools(&self, tools: Vec<serde_json::Value>) {
+        *self.active_tools.write().await = Some(tools);
+    }
+
+    /// Get active tool definitions.
+    pub async fn get_tools(&self) -> Option<Vec<serde_json::Value>> {
+        self.active_tools.read().await.clone()
+    }
+
+    /// Clear active tool definitions.
+    pub async fn clear_tools(&self) {
+        *self.active_tools.write().await = None;
+        *self.active_tool_config.write().await = None;
+    }
+
+    /// Set active tool config (Gemini toolConfig format).
+    pub async fn set_tool_config(&self, config: serde_json::Value) {
+        *self.active_tool_config.write().await = Some(config);
+    }
+
+    /// Get active tool config.
+    pub async fn get_tool_config(&self) -> Option<serde_json::Value> {
+        self.active_tool_config.read().await.clone()
+    }
+
+    /// Add a pending tool result for MITM injection.
+    pub async fn add_tool_result(&self, result: PendingToolResult) {
+        info!(name = %result.name, "Storing pending tool result");
+        self.pending_tool_results.write().await.push(result);
+    }
+
+    /// Take (consume) all pending tool results.
+    pub async fn take_tool_results(&self) -> Vec<PendingToolResult> {
+        std::mem::take(&mut *self.pending_tool_results.write().await)
+    }
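`take_tool_results` relies on `std::mem::take` to hand back the queued results and leave an empty `Vec` behind in one step under the write lock. The sketch below shows that drain-once behaviour with a synchronous `std::sync::Mutex`. This is an assumption made to keep the example dependency-free; the store itself awaits an async `RwLock`. `PendingQueue` is a hypothetical name.

```rust
use std::mem;
use std::sync::Mutex;

/// Hypothetical synchronous stand-in for the store's pending-result queue.
struct PendingQueue {
    items: Mutex<Vec<String>>,
}

impl PendingQueue {
    fn new() -> Self {
        Self { items: Mutex::new(Vec::new()) }
    }

    /// Append one item to the queue.
    fn add(&self, item: &str) {
        self.items.lock().unwrap().push(item.to_string());
    }

    /// Swap the queue with an empty Vec and return the old contents.
    fn take_all(&self) -> Vec<String> {
        mem::take(&mut *self.items.lock().unwrap())
    }
}
```

A second call returns an empty list, so each result can be injected at most once. One consequence visible in the proxy hunk is that the results are taken before `modify_request` runs, so they appear to be consumed even when that call returns `None`.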
+
+    /// Register a call_id → function name mapping.
+    pub async fn register_call_id(&self, call_id: String, name: String) {
+        self.call_id_to_name.write().await.insert(call_id, name);
+    }
+
+    /// Look up function name by call_id.
+    pub async fn lookup_call_id(&self, call_id: &str) -> Option<String> {
+        self.call_id_to_name.read().await.get(call_id).cloned()
+    }
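The `call_id` table exists because the two formats identify a tool result differently: a client's `function_call_output` carries a `call_id`, while `PendingToolResult` (and the injected `functionResponse`) is keyed by function name. Here is a minimal sketch of that routing step, assuming a plain in-memory map. `CallIdTable` and `route_output` are hypothetical names, and the output is kept as a string for brevity.

```rust
use std::collections::HashMap;

/// Hypothetical in-memory version of the call_id → function name table.
#[derive(Default)]
struct CallIdTable {
    names: HashMap<String, String>,
}

impl CallIdTable {
    /// Remember which function a generated call_id refers to.
    fn register(&mut self, call_id: &str, name: &str) {
        self.names.insert(call_id.to_string(), name.to_string());
    }

    /// Resolve a client `function_call_output` to (function name, output).
    /// Returns None when the call_id was never registered here.
    fn route_output(&self, call_id: &str, output: &str) -> Option<(String, String)> {
        let name = self.names.get(call_id)?;
        Some((name.clone(), output.to_string()))
    }
}
```

Returning `None` for an unknown `call_id` leaves the caller to decide whether to reject the request or drop the result; the diff does not show which of the two the handler chooses.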
+
+    /// Save the last captured function calls (for history rewriting).
+    pub async fn set_last_function_calls(&self, calls: Vec<CapturedFunctionCall>) {
+        *self.last_function_calls.write().await = calls;
+    }
+
+    /// Get the last captured function calls.
+    pub async fn get_last_function_calls(&self) -> Vec<CapturedFunctionCall> {
+        self.last_function_calls.read().await.clone()
+    }
 }