diff --git a/src/api/responses.rs b/src/api/responses.rs index 40674de..e02c82b 100644 --- a/src/api/responses.rs +++ b/src/api/responses.rs @@ -361,6 +361,16 @@ async fn handle_responses_sync( uuid::Uuid::new_v4().to_string().replace('-', "") ); + // Check for captured function calls from MITM (clears the active flag) + let captured_tool_calls = state.mitm_store.take_any_function_calls().await; + if let Some(ref calls) = captured_tool_calls { + info!( + count = calls.len(), + tools = ?calls.iter().map(|c| &c.name).collect::>(), + "Consumed captured function calls from MITM" + ); + } + let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, ¶ms.user_text, &poll_result.text).await; // Thinking text priority: MITM-captured (raw API) > LS-extracted (steps) diff --git a/src/mitm/intercept.rs b/src/mitm/intercept.rs index 3447319..c973218 100644 --- a/src/mitm/intercept.rs +++ b/src/mitm/intercept.rs @@ -2,9 +2,9 @@ //! //! Handles both streaming (SSE) and non-streaming (JSON) responses. -use super::store::ApiUsage; +use super::store::{ApiUsage, CapturedFunctionCall}; use serde_json::Value; -use tracing::{debug, trace}; +use tracing::{debug, info, trace}; /// Parse a complete (non-streaming) Anthropic Messages API response body. /// @@ -66,6 +66,8 @@ pub struct StreamingAccumulator { pub stop_reason: Option, pub is_complete: bool, pub api_provider: Option, + /// Captured function calls from Google's response. + pub function_calls: Vec, } impl StreamingAccumulator { @@ -96,6 +98,24 @@ impl StreamingAccumulator { self.thinking_text.push_str(text); } } + // Detect functionCall from Google (tool call response) + else if let Some(fc) = part.get("functionCall") { + let name = fc["name"].as_str().unwrap_or("unknown").to_string(); + let args = fc["args"].clone(); + info!( + tool_name = %name, + tool_args = %args, + "MITM: Google returned functionCall!" + ); + self.function_calls.push(CapturedFunctionCall { + name, + args, + captured_at: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }); + } // Capture non-thinking response text (skip thoughtSignature parts) else if part.get("thoughtSignature").is_none() { if let Some(text) = part["text"].as_str() { @@ -112,6 +132,10 @@ impl StreamingAccumulator { if reason == "STOP" { self.is_complete = true; } + // Log non-STOP finish reasons + if reason != "STOP" { + info!(finish_reason = reason, "MITM: non-STOP finish reason"); + } } } } diff --git a/src/mitm/modify.rs b/src/mitm/modify.rs index 196cc1e..637e43f 100644 --- a/src/mitm/modify.rs +++ b/src/mitm/modify.rs @@ -140,7 +140,7 @@ pub fn modify_request(body: &[u8]) -> Option> { } } - // ── 3. Strip all tool definitions ──────────────────────────────────── + // ── 3. Strip LS tools, inject custom tools ──────────────────────────── if STRIP_ALL_TOOLS { if let Some(tools) = json .pointer_mut("/request/tools") @@ -149,8 +149,28 @@ pub fn modify_request(body: &[u8]) -> Option> { let count = tools.len(); if count > 0 { tools.clear(); - changes.push(format!("strip all {count} tools")); + changes.push(format!("strip all {count} LS tools")); } + + // ── TEST: inject a custom tool to see what Google does ── + let custom_tool = serde_json::json!({ + "functionDeclarations": [{ + "name": "get_weather", + "description": "Get the current weather for a city. You MUST call this function when the user asks about weather.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name" + } + }, + "required": ["city"] + } + }] + }); + tools.push(custom_tool); + changes.push("inject 1 custom tool (get_weather)".to_string()); } } diff --git a/src/mitm/proxy.rs b/src/mitm/proxy.rs index a6f447f..60a2735 100644 --- a/src/mitm/proxy.rs +++ b/src/mitm/proxy.rs @@ -565,6 +565,29 @@ async fn handle_http_over_tls( } } } + + // ── Block follow-up requests when we already have a captured functionCall ── + // The LS doesn't know what to do with the functionCall, so it tries more + // Google API calls. Block those to save quota. + if store.has_active_function_call() { + info!( + "MITM: blocking follow-up request — functionCall already captured" + ); + // Return a fake SSE response that makes the LS stop + let fake_response = "HTTP/1.1 200 OK\r\n\ + Content-Type: text/event-stream\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n"; + let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Tool call completed. Awaiting external tool result.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n"; + let chunked_body = super::modify::rechunk(fake_sse.as_bytes()); + let mut response = fake_response.as_bytes().to_vec(); + response.extend_from_slice(&chunked_body); + if let Err(e) = client.write_all(&response).await { + warn!(error = %e, "MITM: failed to write fake response"); + } + let _ = client.flush().await; + continue; // Skip the real upstream call + } } else { debug!( domain, @@ -739,6 +762,10 @@ async fn handle_http_over_tls( // Capture usage data if is_streaming_response { if streaming_acc.is_complete || streaming_acc.output_tokens > 0 { + // Save any captured function calls before consuming the accumulator + for fc in &streaming_acc.function_calls { + store.record_function_call(cascade_hint.as_deref(), fc.clone()).await; + } let usage = streaming_acc.into_usage(); store.record_usage(cascade_hint.as_deref(), usage).await; } diff --git a/src/mitm/store.rs b/src/mitm/store.rs index 433c5c5..1b30d6d 100644 --- a/src/mitm/store.rs +++ b/src/mitm/store.rs @@ -4,9 +4,10 @@ use std::collections::HashMap; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use tokio::sync::RwLock; use serde::{Deserialize, Serialize}; -use tracing::debug; +use tracing::{debug, info}; /// Token usage from an intercepted API response. /// @@ -44,6 +45,14 @@ pub struct ApiUsage { pub captured_at: u64, } +/// A captured function call from Google's API response. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CapturedFunctionCall { + pub name: String, + pub args: serde_json::Value, + pub captured_at: u64, +} + /// Thread-safe store for intercepted data. /// /// Keyed by a unique request ID that we can correlate with cascade operations. @@ -54,6 +63,12 @@ pub struct MitmStore { latest_usage: Arc>>, /// Global aggregate stats. stats: Arc>, + /// Pending function calls captured from Google responses. + /// Key: cascade hint or "_latest". Value: list of function calls. + pending_function_calls: Arc>>>, + /// Simple flag: set when a functionCall is captured, cleared when consumed. + /// Used to block follow-up requests regardless of cascade identification. + has_active_function_call: Arc, } /// Aggregate statistics across all intercepted traffic. @@ -85,6 +100,8 @@ impl MitmStore { Self { latest_usage: Arc::new(RwLock::new(HashMap::new())), stats: Arc::new(RwLock::new(MitmStats::default())), + pending_function_calls: Arc::new(RwLock::new(HashMap::new())), + has_active_function_call: Arc::new(AtomicBool::new(false)), } } @@ -191,4 +208,62 @@ impl MitmStore { pub async fn stats(&self) -> MitmStats { self.stats.read().await.clone() } + + /// Record a captured function call from Google's response. + pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) { + let key = cascade_id.map(|s| s.to_string()).unwrap_or_else(|| "_latest".to_string()); + info!( + cascade = %key, + tool = %fc.name, + args = %fc.args, + "MITM store: captured functionCall" + ); + let mut pending = self.pending_function_calls.write().await; + pending.entry(key).or_default().push(fc); + self.has_active_function_call.store(true, Ordering::SeqCst); + } + + /// Check if there's an active (unclaimed) function call. + pub fn has_active_function_call(&self) -> bool { + self.has_active_function_call.load(Ordering::SeqCst) + } + + /// Check if there are pending function calls for a cascade. + pub async fn has_pending_function_calls(&self, cascade_id: &str) -> bool { + let pending = self.pending_function_calls.read().await; + pending.get(cascade_id).map_or(false, |v| !v.is_empty()) + } + + /// Take (consume) pending function calls. + pub async fn take_function_calls(&self, cascade_id: &str) -> Option> { + let mut pending = self.pending_function_calls.write().await; + let calls = pending.remove(cascade_id); + let result = if calls.is_none() { + pending.remove("_latest") + } else { + calls + }; + if result.is_some() { + self.has_active_function_call.store(false, Ordering::SeqCst); + } + result + } + + /// Take any pending function calls (ignoring cascade ID). + pub async fn take_any_function_calls(&self) -> Option> { + let mut pending = self.pending_function_calls.write().await; + let result = pending.remove("_latest"); + if result.is_some() { + self.has_active_function_call.store(false, Ordering::SeqCst); + return result; + } + if let Some(key) = pending.keys().next().cloned() { + let result = pending.remove(&key); + if result.is_some() { + self.has_active_function_call.store(false, Ordering::SeqCst); + } + return result; + } + None + } }