feat: capture function calls from Google + block follow-up quota waste
When MITM strips LS tools and injects custom tools:
- Google returns functionCall → captured in MitmStore
- Follow-up LS requests are blocked with fake SSE response
- Proxy consumes captured calls and clears the flag
- Result: 1 real Google API call instead of 5+ per tool call
Flow: Client → Proxy → LS → MITM(inject tool) → Google
Google returns functionCall → MITM captures it
LS tries follow-up → MITM blocks (fake response)
Proxy reads captured functionCall → returns to client
This commit is contained in:
@@ -361,6 +361,16 @@ async fn handle_responses_sync(
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||
);
|
||||
|
||||
// Check for captured function calls from MITM (clears the active flag)
|
||||
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
|
||||
if let Some(ref calls) = captured_tool_calls {
|
||||
info!(
|
||||
count = calls.len(),
|
||||
tools = ?calls.iter().map(|c| &c.name).collect::<Vec<_>>(),
|
||||
"Consumed captured function calls from MITM"
|
||||
);
|
||||
}
|
||||
|
||||
let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, ¶ms.user_text, &poll_result.text).await;
|
||||
|
||||
// Thinking text priority: MITM-captured (raw API) > LS-extracted (steps)
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
//!
|
||||
//! Handles both streaming (SSE) and non-streaming (JSON) responses.
|
||||
|
||||
use super::store::ApiUsage;
|
||||
use super::store::{ApiUsage, CapturedFunctionCall};
|
||||
use serde_json::Value;
|
||||
use tracing::{debug, trace};
|
||||
use tracing::{debug, info, trace};
|
||||
|
||||
/// Parse a complete (non-streaming) Anthropic Messages API response body.
|
||||
///
|
||||
@@ -66,6 +66,8 @@ pub struct StreamingAccumulator {
|
||||
pub stop_reason: Option<String>,
|
||||
pub is_complete: bool,
|
||||
pub api_provider: Option<String>,
|
||||
/// Captured function calls from Google's response.
|
||||
pub function_calls: Vec<CapturedFunctionCall>,
|
||||
}
|
||||
|
||||
impl StreamingAccumulator {
|
||||
@@ -96,6 +98,24 @@ impl StreamingAccumulator {
|
||||
self.thinking_text.push_str(text);
|
||||
}
|
||||
}
|
||||
// Detect functionCall from Google (tool call response)
|
||||
else if let Some(fc) = part.get("functionCall") {
|
||||
let name = fc["name"].as_str().unwrap_or("unknown").to_string();
|
||||
let args = fc["args"].clone();
|
||||
info!(
|
||||
tool_name = %name,
|
||||
tool_args = %args,
|
||||
"MITM: Google returned functionCall!"
|
||||
);
|
||||
self.function_calls.push(CapturedFunctionCall {
|
||||
name,
|
||||
args,
|
||||
captured_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
});
|
||||
}
|
||||
// Capture non-thinking response text (skip thoughtSignature parts)
|
||||
else if part.get("thoughtSignature").is_none() {
|
||||
if let Some(text) = part["text"].as_str() {
|
||||
@@ -112,6 +132,10 @@ impl StreamingAccumulator {
|
||||
if reason == "STOP" {
|
||||
self.is_complete = true;
|
||||
}
|
||||
// Log non-STOP finish reasons
|
||||
if reason != "STOP" {
|
||||
info!(finish_reason = reason, "MITM: non-STOP finish reason");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,7 +140,7 @@ pub fn modify_request(body: &[u8]) -> Option<Vec<u8>> {
|
||||
}
|
||||
}
|
||||
|
||||
// ── 3. Strip all tool definitions ────────────────────────────────────
|
||||
// ── 3. Strip LS tools, inject custom tools ────────────────────────────
|
||||
if STRIP_ALL_TOOLS {
|
||||
if let Some(tools) = json
|
||||
.pointer_mut("/request/tools")
|
||||
@@ -149,8 +149,28 @@ pub fn modify_request(body: &[u8]) -> Option<Vec<u8>> {
|
||||
let count = tools.len();
|
||||
if count > 0 {
|
||||
tools.clear();
|
||||
changes.push(format!("strip all {count} tools"));
|
||||
changes.push(format!("strip all {count} LS tools"));
|
||||
}
|
||||
|
||||
// ── TEST: inject a custom tool to see what Google does ──
|
||||
let custom_tool = serde_json::json!({
|
||||
"functionDeclarations": [{
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather for a city. You MUST call this function when the user asks about weather.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}]
|
||||
});
|
||||
tools.push(custom_tool);
|
||||
changes.push("inject 1 custom tool (get_weather)".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -565,6 +565,29 @@ async fn handle_http_over_tls(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Block follow-up requests when we already have a captured functionCall ──
|
||||
// The LS doesn't know what to do with the functionCall, so it tries more
|
||||
// Google API calls. Block those to save quota.
|
||||
if store.has_active_function_call() {
|
||||
info!(
|
||||
"MITM: blocking follow-up request — functionCall already captured"
|
||||
);
|
||||
// Return a fake SSE response that makes the LS stop
|
||||
let fake_response = "HTTP/1.1 200 OK\r\n\
|
||||
Content-Type: text/event-stream\r\n\
|
||||
Transfer-Encoding: chunked\r\n\
|
||||
\r\n";
|
||||
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Tool call completed. Awaiting external tool result.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
|
||||
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
|
||||
let mut response = fake_response.as_bytes().to_vec();
|
||||
response.extend_from_slice(&chunked_body);
|
||||
if let Err(e) = client.write_all(&response).await {
|
||||
warn!(error = %e, "MITM: failed to write fake response");
|
||||
}
|
||||
let _ = client.flush().await;
|
||||
continue; // Skip the real upstream call
|
||||
}
|
||||
} else {
|
||||
debug!(
|
||||
domain,
|
||||
@@ -739,6 +762,10 @@ async fn handle_http_over_tls(
|
||||
// Capture usage data
|
||||
if is_streaming_response {
|
||||
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
|
||||
// Save any captured function calls before consuming the accumulator
|
||||
for fc in &streaming_acc.function_calls {
|
||||
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
||||
}
|
||||
let usage = streaming_acc.into_usage();
|
||||
store.record_usage(cascade_hint.as_deref(), usage).await;
|
||||
}
|
||||
|
||||
@@ -4,9 +4,10 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use tokio::sync::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::debug;
|
||||
use tracing::{debug, info};
|
||||
|
||||
/// Token usage from an intercepted API response.
|
||||
///
|
||||
@@ -44,6 +45,14 @@ pub struct ApiUsage {
|
||||
pub captured_at: u64,
|
||||
}
|
||||
|
||||
/// A captured function call from Google's API response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CapturedFunctionCall {
|
||||
pub name: String,
|
||||
pub args: serde_json::Value,
|
||||
pub captured_at: u64,
|
||||
}
|
||||
|
||||
/// Thread-safe store for intercepted data.
|
||||
///
|
||||
/// Keyed by a unique request ID that we can correlate with cascade operations.
|
||||
@@ -54,6 +63,12 @@ pub struct MitmStore {
|
||||
latest_usage: Arc<RwLock<HashMap<String, ApiUsage>>>,
|
||||
/// Global aggregate stats.
|
||||
stats: Arc<RwLock<MitmStats>>,
|
||||
/// Pending function calls captured from Google responses.
|
||||
/// Key: cascade hint or "_latest". Value: list of function calls.
|
||||
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
|
||||
/// Simple flag: set when a functionCall is captured, cleared when consumed.
|
||||
/// Used to block follow-up requests regardless of cascade identification.
|
||||
has_active_function_call: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
/// Aggregate statistics across all intercepted traffic.
|
||||
@@ -85,6 +100,8 @@ impl MitmStore {
|
||||
Self {
|
||||
latest_usage: Arc::new(RwLock::new(HashMap::new())),
|
||||
stats: Arc::new(RwLock::new(MitmStats::default())),
|
||||
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
|
||||
has_active_function_call: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,4 +208,62 @@ impl MitmStore {
|
||||
pub async fn stats(&self) -> MitmStats {
|
||||
self.stats.read().await.clone()
|
||||
}
|
||||
|
||||
/// Record a captured function call from Google's response.
|
||||
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
|
||||
let key = cascade_id.map(|s| s.to_string()).unwrap_or_else(|| "_latest".to_string());
|
||||
info!(
|
||||
cascade = %key,
|
||||
tool = %fc.name,
|
||||
args = %fc.args,
|
||||
"MITM store: captured functionCall"
|
||||
);
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
pending.entry(key).or_default().push(fc);
|
||||
self.has_active_function_call.store(true, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Check if there's an active (unclaimed) function call.
|
||||
pub fn has_active_function_call(&self) -> bool {
|
||||
self.has_active_function_call.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Check if there are pending function calls for a cascade.
|
||||
pub async fn has_pending_function_calls(&self, cascade_id: &str) -> bool {
|
||||
let pending = self.pending_function_calls.read().await;
|
||||
pending.get(cascade_id).map_or(false, |v| !v.is_empty())
|
||||
}
|
||||
|
||||
/// Take (consume) pending function calls.
|
||||
pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> {
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
let calls = pending.remove(cascade_id);
|
||||
let result = if calls.is_none() {
|
||||
pending.remove("_latest")
|
||||
} else {
|
||||
calls
|
||||
};
|
||||
if result.is_some() {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Take any pending function calls (ignoring cascade ID).
|
||||
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
let result = pending.remove("_latest");
|
||||
if result.is_some() {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
return result;
|
||||
}
|
||||
if let Some(key) = pending.keys().next().cloned() {
|
||||
let result = pending.remove(&key);
|
||||
if result.is_some() {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user