feat: capture function calls from Google + block follow-up quota waste

When MITM strips LS tools and injects custom tools:
- Google returns functionCall → captured in MitmStore
- Follow-up LS requests are blocked with a fake SSE response
- Proxy consumes captured calls and clears the flag
- Result: 1 real Google API call instead of 5+ per tool call

Flow: Client → Proxy → LS → MITM(inject tool) → Google
      Google returns functionCall → MITM captures it
      LS tries follow-up → MITM blocks (fake response)
      Proxy reads captured functionCall → returns to client
This commit is contained in:
Nikketryhard
2026-02-14 22:37:28 -06:00
parent 146be139a2
commit 8455aa674f
5 changed files with 161 additions and 5 deletions

View File

@@ -361,6 +361,16 @@ async fn handle_responses_sync(
uuid::Uuid::new_v4().to_string().replace('-', "") uuid::Uuid::new_v4().to_string().replace('-', "")
); );
// Check for captured function calls from MITM (clears the active flag)
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
if let Some(ref calls) = captured_tool_calls {
info!(
count = calls.len(),
tools = ?calls.iter().map(|c| &c.name).collect::<Vec<_>>(),
"Consumed captured function calls from MITM"
);
}
let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, &params.user_text, &poll_result.text).await; let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, &params.user_text, &poll_result.text).await;
// Thinking text priority: MITM-captured (raw API) > LS-extracted (steps) // Thinking text priority: MITM-captured (raw API) > LS-extracted (steps)

View File

@@ -2,9 +2,9 @@
//! //!
//! Handles both streaming (SSE) and non-streaming (JSON) responses. //! Handles both streaming (SSE) and non-streaming (JSON) responses.
use super::store::ApiUsage; use super::store::{ApiUsage, CapturedFunctionCall};
use serde_json::Value; use serde_json::Value;
use tracing::{debug, trace}; use tracing::{debug, info, trace};
/// Parse a complete (non-streaming) Anthropic Messages API response body. /// Parse a complete (non-streaming) Anthropic Messages API response body.
/// ///
@@ -66,6 +66,8 @@ pub struct StreamingAccumulator {
pub stop_reason: Option<String>, pub stop_reason: Option<String>,
pub is_complete: bool, pub is_complete: bool,
pub api_provider: Option<String>, pub api_provider: Option<String>,
/// Captured function calls from Google's response.
pub function_calls: Vec<CapturedFunctionCall>,
} }
impl StreamingAccumulator { impl StreamingAccumulator {
@@ -96,6 +98,24 @@ impl StreamingAccumulator {
self.thinking_text.push_str(text); self.thinking_text.push_str(text);
} }
} }
// Detect functionCall from Google (tool call response)
else if let Some(fc) = part.get("functionCall") {
let name = fc["name"].as_str().unwrap_or("unknown").to_string();
let args = fc["args"].clone();
info!(
tool_name = %name,
tool_args = %args,
"MITM: Google returned functionCall!"
);
self.function_calls.push(CapturedFunctionCall {
name,
args,
captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
});
}
// Capture non-thinking response text (skip thoughtSignature parts) // Capture non-thinking response text (skip thoughtSignature parts)
else if part.get("thoughtSignature").is_none() { else if part.get("thoughtSignature").is_none() {
if let Some(text) = part["text"].as_str() { if let Some(text) = part["text"].as_str() {
@@ -112,6 +132,10 @@ impl StreamingAccumulator {
if reason == "STOP" { if reason == "STOP" {
self.is_complete = true; self.is_complete = true;
} }
// Log non-STOP finish reasons
if reason != "STOP" {
info!(finish_reason = reason, "MITM: non-STOP finish reason");
}
} }
} }
} }

View File

@@ -140,7 +140,7 @@ pub fn modify_request(body: &[u8]) -> Option<Vec<u8>> {
} }
} }
// ── 3. Strip all tool definitions ──────────────────────────────────── // ── 3. Strip LS tools, inject custom tools ────────────────────────────
if STRIP_ALL_TOOLS { if STRIP_ALL_TOOLS {
if let Some(tools) = json if let Some(tools) = json
.pointer_mut("/request/tools") .pointer_mut("/request/tools")
@@ -149,8 +149,28 @@ pub fn modify_request(body: &[u8]) -> Option<Vec<u8>> {
let count = tools.len(); let count = tools.len();
if count > 0 { if count > 0 {
tools.clear(); tools.clear();
changes.push(format!("strip all {count} tools")); changes.push(format!("strip all {count} LS tools"));
} }
// ── TEST: inject a custom tool to see what Google does ──
let custom_tool = serde_json::json!({
"functionDeclarations": [{
"name": "get_weather",
"description": "Get the current weather for a city. You MUST call this function when the user asks about weather.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name"
}
},
"required": ["city"]
}
}]
});
tools.push(custom_tool);
changes.push("inject 1 custom tool (get_weather)".to_string());
} }
} }

View File

@@ -565,6 +565,29 @@ async fn handle_http_over_tls(
} }
} }
} }
// ── Block follow-up requests when we already have a captured functionCall ──
// The LS doesn't know what to do with the functionCall, so it tries more
// Google API calls. Block those to save quota.
if store.has_active_function_call() {
info!(
"MITM: blocking follow-up request — functionCall already captured"
);
// Return a fake SSE response that makes the LS stop
let fake_response = "HTTP/1.1 200 OK\r\n\
Content-Type: text/event-stream\r\n\
Transfer-Encoding: chunked\r\n\
\r\n";
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Tool call completed. Awaiting external tool result.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
let mut response = fake_response.as_bytes().to_vec();
response.extend_from_slice(&chunked_body);
if let Err(e) = client.write_all(&response).await {
warn!(error = %e, "MITM: failed to write fake response");
}
let _ = client.flush().await;
continue; // Skip the real upstream call
}
} else { } else {
debug!( debug!(
domain, domain,
@@ -739,6 +762,10 @@ async fn handle_http_over_tls(
// Capture usage data // Capture usage data
if is_streaming_response { if is_streaming_response {
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 { if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
// Save any captured function calls before consuming the accumulator
for fc in &streaming_acc.function_calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
}
let usage = streaming_acc.into_usage(); let usage = streaming_acc.into_usage();
store.record_usage(cascade_hint.as_deref(), usage).await; store.record_usage(cascade_hint.as_deref(), usage).await;
} }

View File

@@ -4,9 +4,10 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::debug; use tracing::{debug, info};
/// Token usage from an intercepted API response. /// Token usage from an intercepted API response.
/// ///
@@ -44,6 +45,14 @@ pub struct ApiUsage {
pub captured_at: u64, pub captured_at: u64,
} }
/// A captured function call from Google's API response.
///
/// One value is created per `functionCall` part seen while accumulating a
/// streamed response, and is later queued in the MITM store until consumed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CapturedFunctionCall {
/// Tool name taken from the `name` field of the `functionCall` part
/// (the capture site substitutes `"unknown"` when it is not a string).
pub name: String,
/// Raw JSON copied from the `args` field of the `functionCall` part.
pub args: serde_json::Value,
/// Capture time in whole seconds since the Unix epoch.
pub captured_at: u64,
}
/// Thread-safe store for intercepted data. /// Thread-safe store for intercepted data.
/// ///
/// Keyed by a unique request ID that we can correlate with cascade operations. /// Keyed by a unique request ID that we can correlate with cascade operations.
@@ -54,6 +63,12 @@ pub struct MitmStore {
latest_usage: Arc<RwLock<HashMap<String, ApiUsage>>>, latest_usage: Arc<RwLock<HashMap<String, ApiUsage>>>,
/// Global aggregate stats. /// Global aggregate stats.
stats: Arc<RwLock<MitmStats>>, stats: Arc<RwLock<MitmStats>>,
/// Pending function calls captured from Google responses.
/// Key: cascade hint or "_latest". Value: list of function calls.
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
/// Simple flag: set when a functionCall is captured, cleared when consumed.
/// Used to block follow-up requests regardless of cascade identification.
has_active_function_call: Arc<AtomicBool>,
} }
/// Aggregate statistics across all intercepted traffic. /// Aggregate statistics across all intercepted traffic.
@@ -85,6 +100,8 @@ impl MitmStore {
Self { Self {
latest_usage: Arc::new(RwLock::new(HashMap::new())), latest_usage: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(MitmStats::default())), stats: Arc::new(RwLock::new(MitmStats::default())),
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
has_active_function_call: Arc::new(AtomicBool::new(false)),
} }
} }
@@ -191,4 +208,62 @@ impl MitmStore {
pub async fn stats(&self) -> MitmStats { pub async fn stats(&self) -> MitmStats {
self.stats.read().await.clone() self.stats.read().await.clone()
} }
/// Queues a function call captured from Google's response and raises the
/// active-call flag.
///
/// The call is appended to the list stored under `cascade_id`. When no
/// cascade ID is supplied, it is filed under the `"_latest"` key instead.
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
    let bucket = cascade_id.unwrap_or("_latest").to_owned();
    info!(
        cascade = %bucket,
        tool = %fc.name,
        args = %fc.args,
        "MITM store: captured functionCall"
    );

    // Hold the write guard until after the flag is set so the queue update
    // and the flag change happen under the same lock.
    let mut queued = self.pending_function_calls.write().await;
    queued.entry(bucket).or_default().push(fc);
    self.has_active_function_call.store(true, Ordering::SeqCst);
}
/// Check if there's an active (unclaimed) function call.
///
/// Reads the atomic flag only; it does not take the lock on the pending
/// queue, so it can be called from synchronous code.
pub fn has_active_function_call(&self) -> bool {
self.has_active_function_call.load(Ordering::SeqCst)
}
/// Reports whether at least one function call is queued under `cascade_id`.
///
/// Only the exact key is consulted; the `"_latest"` bucket is not checked.
pub async fn has_pending_function_calls(&self, cascade_id: &str) -> bool {
    let queued = self.pending_function_calls.read().await;
    match queued.get(cascade_id) {
        Some(calls) => !calls.is_empty(),
        None => false,
    }
}
/// Removes and returns the function calls queued for `cascade_id`.
///
/// If nothing is queued under `cascade_id`, the `"_latest"` bucket is taken
/// instead. Returns `None` when neither key has an entry. Whenever an entry
/// is returned, the active-call flag is cleared.
pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> {
    let mut queued = self.pending_function_calls.write().await;

    // Prefer the cascade-specific bucket; fall back to the shared one.
    let taken = queued
        .remove(cascade_id)
        .or_else(|| queued.remove("_latest"));

    // NOTE(review): the flag is cleared even if other keys still hold queued
    // calls — confirm this is the intended behavior.
    if taken.is_some() {
        self.has_active_function_call.store(false, Ordering::SeqCst);
    }
    taken
}
/// Removes and returns one bucket of queued function calls, regardless of
/// which cascade it belongs to.
///
/// The `"_latest"` bucket is preferred. Otherwise the first key yielded by
/// the map's iterator is used, which for a `HashMap` is an arbitrary one.
/// Returns `None` when the queue is empty. Whenever a bucket is returned,
/// the active-call flag is cleared.
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
    let mut queued = self.pending_function_calls.write().await;

    let taken = match queued.remove("_latest") {
        found @ Some(_) => found,
        None => {
            // Nothing in the shared bucket: pick whichever key comes first.
            let fallback_key = queued.keys().next().cloned()?;
            queued.remove(&fallback_key)
        }
    };

    // NOTE(review): only one bucket is drained per call, yet the flag is
    // cleared even if other buckets remain — confirm this is intended.
    if taken.is_some() {
        self.has_active_function_call.store(false, Ordering::SeqCst);
    }
    taken
}
} }