feat: capture function calls from Google + block follow-up quota waste
When MITM strips LS tools and injects custom tools:
- Google returns functionCall → captured in MitmStore
- Follow-up LS requests are blocked with fake SSE response
- Proxy consumes captured calls and clears the flag
- Result: 1 real Google API call instead of 5+ per tool call
Flow: Client → Proxy → LS → MITM(inject tool) → Google
Google returns functionCall → MITM captures it
LS tries follow-up → MITM blocks (fake response)
Proxy reads captured functionCall → returns to client
This commit is contained in:
@@ -4,9 +4,10 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use tokio::sync::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::debug;
|
||||
use tracing::{debug, info};
|
||||
|
||||
/// Token usage from an intercepted API response.
|
||||
///
|
||||
@@ -44,6 +45,14 @@ pub struct ApiUsage {
|
||||
pub captured_at: u64,
|
||||
}
|
||||
|
||||
/// A captured function call from Google's API response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CapturedFunctionCall {
|
||||
pub name: String,
|
||||
pub args: serde_json::Value,
|
||||
pub captured_at: u64,
|
||||
}
|
||||
|
||||
/// Thread-safe store for intercepted data.
|
||||
///
|
||||
/// Keyed by a unique request ID that we can correlate with cascade operations.
|
||||
@@ -54,6 +63,12 @@ pub struct MitmStore {
|
||||
latest_usage: Arc<RwLock<HashMap<String, ApiUsage>>>,
|
||||
/// Global aggregate stats.
|
||||
stats: Arc<RwLock<MitmStats>>,
|
||||
/// Pending function calls captured from Google responses.
|
||||
/// Key: cascade hint or "_latest". Value: list of function calls.
|
||||
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
|
||||
/// Simple flag: set when a functionCall is captured, cleared when consumed.
|
||||
/// Used to block follow-up requests regardless of cascade identification.
|
||||
has_active_function_call: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
/// Aggregate statistics across all intercepted traffic.
|
||||
@@ -85,6 +100,8 @@ impl MitmStore {
|
||||
Self {
|
||||
latest_usage: Arc::new(RwLock::new(HashMap::new())),
|
||||
stats: Arc::new(RwLock::new(MitmStats::default())),
|
||||
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
|
||||
has_active_function_call: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,4 +208,62 @@ impl MitmStore {
|
||||
pub async fn stats(&self) -> MitmStats {
|
||||
self.stats.read().await.clone()
|
||||
}
|
||||
|
||||
/// Record a captured function call from Google's response.
|
||||
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
|
||||
let key = cascade_id.map(|s| s.to_string()).unwrap_or_else(|| "_latest".to_string());
|
||||
info!(
|
||||
cascade = %key,
|
||||
tool = %fc.name,
|
||||
args = %fc.args,
|
||||
"MITM store: captured functionCall"
|
||||
);
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
pending.entry(key).or_default().push(fc);
|
||||
self.has_active_function_call.store(true, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Check if there's an active (unclaimed) function call.
|
||||
pub fn has_active_function_call(&self) -> bool {
|
||||
self.has_active_function_call.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Check if there are pending function calls for a cascade.
|
||||
pub async fn has_pending_function_calls(&self, cascade_id: &str) -> bool {
|
||||
let pending = self.pending_function_calls.read().await;
|
||||
pending.get(cascade_id).map_or(false, |v| !v.is_empty())
|
||||
}
|
||||
|
||||
/// Take (consume) pending function calls.
|
||||
pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> {
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
let calls = pending.remove(cascade_id);
|
||||
let result = if calls.is_none() {
|
||||
pending.remove("_latest")
|
||||
} else {
|
||||
calls
|
||||
};
|
||||
if result.is_some() {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Take any pending function calls (ignoring cascade ID).
|
||||
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
let result = pending.remove("_latest");
|
||||
if result.is_some() {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
return result;
|
||||
}
|
||||
if let Some(key) = pending.keys().next().cloned() {
|
||||
let result = pending.remove(&key);
|
||||
if result.is_some() {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user