feat: capture function calls from Google + block follow-up quota waste

When MITM strips LS tools and injects custom tools:
- Google returns functionCall → captured in MitmStore
- Follow-up LS requests are blocked with fake SSE response
- Proxy consumes captured calls and clears the flag
- Result: 1 real Google API call instead of 5+ per tool call

Flow: Client → Proxy → LS → MITM(inject tool) → Google
      Google returns functionCall → MITM captures it
      LS tries follow-up → MITM blocks (fake response)
      Proxy reads captured functionCall → returns to client
This commit is contained in:
Nikketryhard
2026-02-14 22:37:28 -06:00
parent 146be139a2
commit 8455aa674f
5 changed files with 161 additions and 5 deletions

View File

@@ -4,9 +4,10 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};
use tracing::debug;
use tracing::{debug, info};
/// Token usage from an intercepted API response.
///
@@ -44,6 +45,14 @@ pub struct ApiUsage {
pub captured_at: u64,
}
/// A captured function call from Google's API response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CapturedFunctionCall {
pub name: String,
pub args: serde_json::Value,
pub captured_at: u64,
}
/// Thread-safe store for intercepted data.
///
/// Keyed by a unique request ID that we can correlate with cascade operations.
@@ -54,6 +63,12 @@ pub struct MitmStore {
latest_usage: Arc<RwLock<HashMap<String, ApiUsage>>>,
/// Global aggregate stats.
stats: Arc<RwLock<MitmStats>>,
/// Pending function calls captured from Google responses.
/// Key: cascade hint or "_latest". Value: list of function calls.
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
/// Simple flag: set when a functionCall is captured, cleared when consumed.
/// Used to block follow-up requests regardless of cascade identification.
has_active_function_call: Arc<AtomicBool>,
}
/// Aggregate statistics across all intercepted traffic.
@@ -85,6 +100,8 @@ impl MitmStore {
Self {
latest_usage: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(MitmStats::default())),
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
has_active_function_call: Arc::new(AtomicBool::new(false)),
}
}
@@ -191,4 +208,62 @@ impl MitmStore {
pub async fn stats(&self) -> MitmStats {
self.stats.read().await.clone()
}
/// Record a captured function call from Google's response.
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
let key = cascade_id.map(|s| s.to_string()).unwrap_or_else(|| "_latest".to_string());
info!(
cascade = %key,
tool = %fc.name,
args = %fc.args,
"MITM store: captured functionCall"
);
let mut pending = self.pending_function_calls.write().await;
pending.entry(key).or_default().push(fc);
self.has_active_function_call.store(true, Ordering::SeqCst);
}
/// Check if there's an active (unclaimed) function call.
pub fn has_active_function_call(&self) -> bool {
self.has_active_function_call.load(Ordering::SeqCst)
}
/// Check if there are pending function calls for a cascade.
pub async fn has_pending_function_calls(&self, cascade_id: &str) -> bool {
let pending = self.pending_function_calls.read().await;
pending.get(cascade_id).map_or(false, |v| !v.is_empty())
}
/// Take (consume) pending function calls.
pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> {
let mut pending = self.pending_function_calls.write().await;
let calls = pending.remove(cascade_id);
let result = if calls.is_none() {
pending.remove("_latest")
} else {
calls
};
if result.is_some() {
self.has_active_function_call.store(false, Ordering::SeqCst);
}
result
}
/// Take any pending function calls (ignoring cascade ID).
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
let mut pending = self.pending_function_calls.write().await;
let result = pending.remove("_latest");
if result.is_some() {
self.has_active_function_call.store(false, Ordering::SeqCst);
return result;
}
if let Some(key) = pending.keys().next().cloned() {
let result = pending.remove(&key);
if result.is_some() {
self.has_active_function_call.store(false, Ordering::SeqCst);
}
return result;
}
None
}
}