feat: capture function calls from Google + block follow-up quota waste
When MITM strips LS tools and injects custom tools:
- Google returns functionCall → captured in MitmStore
- Follow-up LS requests are blocked with fake SSE response
- Proxy consumes captured calls and clears the flag
- Result: 1 real Google API call instead of 5+ per tool call
Flow: Client → Proxy → LS → MITM(inject tool) → Google
Google returns functionCall → MITM captures it
LS tries follow-up → MITM blocks (fake response)
Proxy reads captured functionCall → returns to client
This commit is contained in:
@@ -361,6 +361,16 @@ async fn handle_responses_sync(
|
|||||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Check for captured function calls from MITM (clears the active flag)
|
||||||
|
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
|
||||||
|
if let Some(ref calls) = captured_tool_calls {
|
||||||
|
info!(
|
||||||
|
count = calls.len(),
|
||||||
|
tools = ?calls.iter().map(|c| &c.name).collect::<Vec<_>>(),
|
||||||
|
"Consumed captured function calls from MITM"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, &params.user_text, &poll_result.text).await;
|
let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, &params.user_text, &poll_result.text).await;
|
||||||
|
|
||||||
// Thinking text priority: MITM-captured (raw API) > LS-extracted (steps)
|
// Thinking text priority: MITM-captured (raw API) > LS-extracted (steps)
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
//!
|
//!
|
||||||
//! Handles both streaming (SSE) and non-streaming (JSON) responses.
|
//! Handles both streaming (SSE) and non-streaming (JSON) responses.
|
||||||
|
|
||||||
use super::store::ApiUsage;
|
use super::store::{ApiUsage, CapturedFunctionCall};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use tracing::{debug, trace};
|
use tracing::{debug, info, trace};
|
||||||
|
|
||||||
/// Parse a complete (non-streaming) Anthropic Messages API response body.
|
/// Parse a complete (non-streaming) Anthropic Messages API response body.
|
||||||
///
|
///
|
||||||
@@ -66,6 +66,8 @@ pub struct StreamingAccumulator {
|
|||||||
pub stop_reason: Option<String>,
|
pub stop_reason: Option<String>,
|
||||||
pub is_complete: bool,
|
pub is_complete: bool,
|
||||||
pub api_provider: Option<String>,
|
pub api_provider: Option<String>,
|
||||||
|
/// Captured function calls from Google's response.
|
||||||
|
pub function_calls: Vec<CapturedFunctionCall>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingAccumulator {
|
impl StreamingAccumulator {
|
||||||
@@ -96,6 +98,24 @@ impl StreamingAccumulator {
|
|||||||
self.thinking_text.push_str(text);
|
self.thinking_text.push_str(text);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Detect functionCall from Google (tool call response)
|
||||||
|
else if let Some(fc) = part.get("functionCall") {
|
||||||
|
let name = fc["name"].as_str().unwrap_or("unknown").to_string();
|
||||||
|
let args = fc["args"].clone();
|
||||||
|
info!(
|
||||||
|
tool_name = %name,
|
||||||
|
tool_args = %args,
|
||||||
|
"MITM: Google returned functionCall!"
|
||||||
|
);
|
||||||
|
self.function_calls.push(CapturedFunctionCall {
|
||||||
|
name,
|
||||||
|
args,
|
||||||
|
captured_at: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs(),
|
||||||
|
});
|
||||||
|
}
|
||||||
// Capture non-thinking response text (skip thoughtSignature parts)
|
// Capture non-thinking response text (skip thoughtSignature parts)
|
||||||
else if part.get("thoughtSignature").is_none() {
|
else if part.get("thoughtSignature").is_none() {
|
||||||
if let Some(text) = part["text"].as_str() {
|
if let Some(text) = part["text"].as_str() {
|
||||||
@@ -112,6 +132,10 @@ impl StreamingAccumulator {
|
|||||||
if reason == "STOP" {
|
if reason == "STOP" {
|
||||||
self.is_complete = true;
|
self.is_complete = true;
|
||||||
}
|
}
|
||||||
|
// Log non-STOP finish reasons
|
||||||
|
if reason != "STOP" {
|
||||||
|
info!(finish_reason = reason, "MITM: non-STOP finish reason");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ pub fn modify_request(body: &[u8]) -> Option<Vec<u8>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── 3. Strip all tool definitions ────────────────────────────────────
|
// ── 3. Strip LS tools, inject custom tools ────────────────────────────
|
||||||
if STRIP_ALL_TOOLS {
|
if STRIP_ALL_TOOLS {
|
||||||
if let Some(tools) = json
|
if let Some(tools) = json
|
||||||
.pointer_mut("/request/tools")
|
.pointer_mut("/request/tools")
|
||||||
@@ -149,8 +149,28 @@ pub fn modify_request(body: &[u8]) -> Option<Vec<u8>> {
|
|||||||
let count = tools.len();
|
let count = tools.len();
|
||||||
if count > 0 {
|
if count > 0 {
|
||||||
tools.clear();
|
tools.clear();
|
||||||
changes.push(format!("strip all {count} tools"));
|
changes.push(format!("strip all {count} LS tools"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── TEST: inject a custom tool to see what Google does ──
|
||||||
|
let custom_tool = serde_json::json!({
|
||||||
|
"functionDeclarations": [{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the current weather for a city. You MUST call this function when the user asks about weather.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city name"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
});
|
||||||
|
tools.push(custom_tool);
|
||||||
|
changes.push("inject 1 custom tool (get_weather)".to_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -565,6 +565,29 @@ async fn handle_http_over_tls(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Block follow-up requests when we already have a captured functionCall ──
|
||||||
|
// The LS doesn't know what to do with the functionCall, so it tries more
|
||||||
|
// Google API calls. Block those to save quota.
|
||||||
|
if store.has_active_function_call() {
|
||||||
|
info!(
|
||||||
|
"MITM: blocking follow-up request — functionCall already captured"
|
||||||
|
);
|
||||||
|
// Return a fake SSE response that makes the LS stop
|
||||||
|
let fake_response = "HTTP/1.1 200 OK\r\n\
|
||||||
|
Content-Type: text/event-stream\r\n\
|
||||||
|
Transfer-Encoding: chunked\r\n\
|
||||||
|
\r\n";
|
||||||
|
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Tool call completed. Awaiting external tool result.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
|
||||||
|
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
|
||||||
|
let mut response = fake_response.as_bytes().to_vec();
|
||||||
|
response.extend_from_slice(&chunked_body);
|
||||||
|
if let Err(e) = client.write_all(&response).await {
|
||||||
|
warn!(error = %e, "MITM: failed to write fake response");
|
||||||
|
}
|
||||||
|
let _ = client.flush().await;
|
||||||
|
continue; // Skip the real upstream call
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
debug!(
|
debug!(
|
||||||
domain,
|
domain,
|
||||||
@@ -739,6 +762,10 @@ async fn handle_http_over_tls(
|
|||||||
// Capture usage data
|
// Capture usage data
|
||||||
if is_streaming_response {
|
if is_streaming_response {
|
||||||
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
|
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
|
||||||
|
// Save any captured function calls before consuming the accumulator
|
||||||
|
for fc in &streaming_acc.function_calls {
|
||||||
|
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
||||||
|
}
|
||||||
let usage = streaming_acc.into_usage();
|
let usage = streaming_acc.into_usage();
|
||||||
store.record_usage(cascade_hint.as_deref(), usage).await;
|
store.record_usage(cascade_hint.as_deref(), usage).await;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,9 +4,10 @@
|
|||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tracing::debug;
|
use tracing::{debug, info};
|
||||||
|
|
||||||
/// Token usage from an intercepted API response.
|
/// Token usage from an intercepted API response.
|
||||||
///
|
///
|
||||||
@@ -44,6 +45,14 @@ pub struct ApiUsage {
|
|||||||
pub captured_at: u64,
|
pub captured_at: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A captured function call from Google's API response.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct CapturedFunctionCall {
|
||||||
|
pub name: String,
|
||||||
|
pub args: serde_json::Value,
|
||||||
|
pub captured_at: u64,
|
||||||
|
}
|
||||||
|
|
||||||
/// Thread-safe store for intercepted data.
|
/// Thread-safe store for intercepted data.
|
||||||
///
|
///
|
||||||
/// Keyed by a unique request ID that we can correlate with cascade operations.
|
/// Keyed by a unique request ID that we can correlate with cascade operations.
|
||||||
@@ -54,6 +63,12 @@ pub struct MitmStore {
|
|||||||
latest_usage: Arc<RwLock<HashMap<String, ApiUsage>>>,
|
latest_usage: Arc<RwLock<HashMap<String, ApiUsage>>>,
|
||||||
/// Global aggregate stats.
|
/// Global aggregate stats.
|
||||||
stats: Arc<RwLock<MitmStats>>,
|
stats: Arc<RwLock<MitmStats>>,
|
||||||
|
/// Pending function calls captured from Google responses.
|
||||||
|
/// Key: cascade hint or "_latest". Value: list of function calls.
|
||||||
|
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
|
||||||
|
/// Simple flag: set when a functionCall is captured, cleared when consumed.
|
||||||
|
/// Used to block follow-up requests regardless of cascade identification.
|
||||||
|
has_active_function_call: Arc<AtomicBool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Aggregate statistics across all intercepted traffic.
|
/// Aggregate statistics across all intercepted traffic.
|
||||||
@@ -85,6 +100,8 @@ impl MitmStore {
|
|||||||
Self {
|
Self {
|
||||||
latest_usage: Arc::new(RwLock::new(HashMap::new())),
|
latest_usage: Arc::new(RwLock::new(HashMap::new())),
|
||||||
stats: Arc::new(RwLock::new(MitmStats::default())),
|
stats: Arc::new(RwLock::new(MitmStats::default())),
|
||||||
|
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
has_active_function_call: Arc::new(AtomicBool::new(false)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -191,4 +208,62 @@ impl MitmStore {
|
|||||||
pub async fn stats(&self) -> MitmStats {
|
pub async fn stats(&self) -> MitmStats {
|
||||||
self.stats.read().await.clone()
|
self.stats.read().await.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Record a captured function call from Google's response.
|
||||||
|
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
|
||||||
|
let key = cascade_id.map(|s| s.to_string()).unwrap_or_else(|| "_latest".to_string());
|
||||||
|
info!(
|
||||||
|
cascade = %key,
|
||||||
|
tool = %fc.name,
|
||||||
|
args = %fc.args,
|
||||||
|
"MITM store: captured functionCall"
|
||||||
|
);
|
||||||
|
let mut pending = self.pending_function_calls.write().await;
|
||||||
|
pending.entry(key).or_default().push(fc);
|
||||||
|
self.has_active_function_call.store(true, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if there's an active (unclaimed) function call.
|
||||||
|
pub fn has_active_function_call(&self) -> bool {
|
||||||
|
self.has_active_function_call.load(Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if there are pending function calls for a cascade.
|
||||||
|
pub async fn has_pending_function_calls(&self, cascade_id: &str) -> bool {
|
||||||
|
let pending = self.pending_function_calls.read().await;
|
||||||
|
pending.get(cascade_id).map_or(false, |v| !v.is_empty())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Take (consume) pending function calls.
|
||||||
|
pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> {
|
||||||
|
let mut pending = self.pending_function_calls.write().await;
|
||||||
|
let calls = pending.remove(cascade_id);
|
||||||
|
let result = if calls.is_none() {
|
||||||
|
pending.remove("_latest")
|
||||||
|
} else {
|
||||||
|
calls
|
||||||
|
};
|
||||||
|
if result.is_some() {
|
||||||
|
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Take any pending function calls (ignoring cascade ID).
|
||||||
|
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
|
||||||
|
let mut pending = self.pending_function_calls.write().await;
|
||||||
|
let result = pending.remove("_latest");
|
||||||
|
if result.is_some() {
|
||||||
|
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
if let Some(key) = pending.keys().next().cloned() {
|
||||||
|
let result = pending.remove(&key);
|
||||||
|
if result.is_some() {
|
||||||
|
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user