Files
zerogravity/src/mitm/store.rs

659 lines
27 KiB
Rust

//! Shared store for intercepted API usage data.
//!
//! Per-request state is stored in `RequestContext`, keyed by cascade ID.
//! The MITM proxy looks up the context when intercepting LS requests,
//! enabling concurrent request processing without global locks.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, info};
/// Token usage from an intercepted API response.
///
/// Covers both Anthropic JSON/SSE responses and Google gRPC protobuf responses.
/// Fields map to the superset of Anthropic's `usage` object and Google's `ModelUsageStats` proto.
///
/// The derived `Default` yields zeroed counters and `None` for every optional field.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ApiUsage {
/// Input token count for the exchange; `MitmStore::record_usage` adds it to
/// the aggregate and per-model totals.
pub input_tokens: u64,
/// Output token count for the exchange; also summed by `MitmStore::record_usage`.
// NOTE(review): whether this already includes thinking tokens is not visible here — confirm.
pub output_tokens: u64,
/// Anthropic: cache_creation_input_tokens / Google: cache_write_tokens
pub cache_creation_input_tokens: u64,
/// Anthropic: cache_read_input_tokens / Google: cache_read_tokens
pub cache_read_input_tokens: u64,
/// Google-specific: thinking/reasoning output tokens (extended thinking).
/// `MitmStore::record_usage` treats an entry with a non-zero count here and
/// no `thinking_text` as still waiting for its thinking summary.
pub thinking_output_tokens: u64,
/// The actual thinking/reasoning text from the model.
/// Captured from Google SSE parts with `thought: true` or Anthropic thinking blocks.
/// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_text: Option<String>,
/// The response text captured from SSE parts (for merge detection).
/// Skipped by serde in both directions; `MitmStore::record_usage` moves it
/// into `thinking_text` of the stored entry when it merges a thinking summary.
#[serde(skip)]
pub response_text: Option<String>,
/// Google-specific: response output tokens (non-thinking portion)
pub response_output_tokens: u64,
/// The actual model that served the request.
/// When `Some`, `MitmStore::record_usage` also updates the per-model breakdown.
pub model: Option<String>,
/// Stop reason / finish reason from the API.
pub stop_reason: Option<String>,
/// API provider (e.g. "anthropic", "google")
pub api_provider: Option<String>,
/// gRPC method path (e.g. "/google.internal.cloud.code.v1internal.PredictionService/GenerateContent")
pub grpc_method: Option<String>,
/// Timestamp when this usage was captured.
/// `MitmStore::record_usage` evicts the entry with the smallest value once
/// the per-cascade map grows past its cap.
// NOTE(review): unit and epoch are not visible in this file — confirm with the producer.
pub captured_at: u64,
/// Thinking signature from Google's response (base64 opaque blob).
/// Required for multi-turn with thinking models.
/// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_signature: Option<String>,
}
/// A captured function call from Google's API response.
///
/// Stored per cascade by `MitmStore::record_function_call` and delivered to
/// handlers through `MitmEvent::FunctionCall`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CapturedFunctionCall {
/// Name of the function the model asked to call.
pub name: String,
/// Call arguments as a JSON value.
pub args: serde_json::Value,
/// Google's thought signature — required when injecting functionCall back
/// into conversation history. Without it, Google returns INVALID_ARGUMENT.
/// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub thought_signature: Option<String>,
/// Timestamp when the call was captured.
// NOTE(review): unit and epoch are not visible in this file — confirm with the producer.
pub captured_at: u64,
}
/// A pending tool result from a client's function_call_output.
///
/// Held in `RequestContext::pending_tool_results` and in `ToolRound::results`.
#[derive(Debug, Clone)]
pub struct PendingToolResult {
/// Name of the function this result belongs to.
pub name: String,
/// The tool's output as a JSON value.
pub result: serde_json::Value,
}
/// A single round of tool calling: the model's function calls paired with
/// the client's execution results.
///
/// In multi-step tool use, each round has its own calls and results.
/// This preserves per-turn data so history rewriting can map each
/// "Tool call completed" model turn to the correct functionCall/functionResponse.
#[derive(Debug, Clone)]
pub struct ToolRound {
/// Function calls the model issued in this round.
pub calls: Vec<CapturedFunctionCall>,
/// Results the client produced for this round's calls.
// NOTE(review): how results are matched to `calls` (by position or by name)
// is not shown in this file — confirm with the history-rewriting code.
pub results: Vec<PendingToolResult>,
}
/// An upstream error captured from Google's API response.
/// Stored by the MITM proxy so API handlers can return it to the client
/// instead of hanging forever waiting for a response that won't come.
///
/// Carried to handlers inside `MitmEvent::UpstreamError`.
#[derive(Debug, Clone)]
pub struct UpstreamError {
/// HTTP status code from Google (e.g. 400, 429, 500).
pub status: u16,
/// Raw error body from Google (usually JSON).
// The `allow` below silences the unused-field lint for this field.
// NOTE(review): presumably retained for debugging — confirm it is still wanted.
#[allow(dead_code)]
pub body: String,
/// Parsed error message, if available.
pub message: Option<String>,
/// Google error status string (e.g. "INVALID_ARGUMENT", "RESOURCE_EXHAUSTED").
pub error_status: Option<String>,
}
/// A pending image to inject via MITM into the Google API request.
/// The LS doesn't forward images from our SendUserCascadeMessage proto,
/// so we inject them directly at the MITM layer.
///
/// Carried in `RequestContext::pending_image`.
#[derive(Debug, Clone)]
pub struct PendingImage {
/// Base64-encoded image data (no prefix).
pub base64_data: String,
/// MIME type, e.g. "image/png".
pub mime_type: String,
}
/// Client-specified generation parameters for MITM injection.
/// Set by API handlers, consumed by the MITM modify layer.
///
/// The derived `Default` leaves every optional field `None` and
/// `google_search` off.
#[derive(Debug, Clone, Default)]
pub struct GenerationParams {
/// Sampling temperature requested by the client.
pub temperature: Option<f64>,
/// Nucleus-sampling (top-p) value requested by the client.
pub top_p: Option<f64>,
/// Top-k sampling value requested by the client.
pub top_k: Option<u32>,
/// Upper bound on generated output tokens requested by the client.
pub max_output_tokens: Option<u64>,
/// Stop sequences requested by the client.
pub stop_sequences: Option<Vec<String>>,
/// Frequency penalty (OpenAI) — mapped to frequencyPenalty in Gemini.
pub frequency_penalty: Option<f64>,
/// Presence penalty (OpenAI) — mapped to presencePenalty in Gemini.
pub presence_penalty: Option<f64>,
/// Reasoning effort — mapped to thinkingConfig.thinkingLevel in Gemini 3.
/// Values: "low", "medium", "high" (maps 1:1 to Google's thinkingLevel).
// NOTE(review): the value is a free-form string; nothing in this file validates it.
pub reasoning_effort: Option<String>,
/// Response MIME type — injected as generationConfig.responseMimeType.
/// e.g., "application/json" for JSON mode.
pub response_mime_type: Option<String>,
/// Response schema — injected as generationConfig.responseSchema.
/// Used for structured output (json_schema format).
pub response_schema: Option<serde_json::Value>,
/// Enable Google Search grounding — injects {"googleSearch": {}} into tools.
/// Default off. When enabled, model responses include groundingMetadata.
pub google_search: bool,
}
/// Cached context from turn 0 of a cascade.
///
/// On the first turn, the MITM proxy consumes the `RequestContext` and builds
/// a `ToolContext`. On subsequent turns (tool-call loops), the `RequestContext`
/// is gone. This cache stores the essential fields so we can rebuild a lite
/// `ToolContext` on every turn — ensuring the model always sees the real user
/// text and has access to custom tools.
///
/// Stored by `MitmStore::cache_cascade` and handed out as a clone by
/// `MitmStore::get_cascade_cache`.
#[derive(Debug, Clone)]
pub struct CascadeCache {
/// The real user text (used to replace the "." dot prompt).
pub user_text: String,
/// Custom tool definitions (Gemini format).
pub tools: Option<Vec<serde_json::Value>>,
/// Custom tool config.
pub tool_config: Option<serde_json::Value>,
/// Client generation parameters.
pub generation_params: Option<GenerationParams>,
}
// ─── Channel-based event pipeline ────────────────────────────────────────────
/// Events sent from the MITM proxy to API handlers through a per-request channel.
/// Replaces the old polling-based approach (shared atomics + RwLocks) with
/// instant, race-free delivery.
///
/// The sending half of that channel is `RequestContext::event_channel`.
#[derive(Debug, Clone)]
pub enum MitmEvent {
/// Incremental thinking/reasoning text from the model.
ThinkingDelta(String),
/// Incremental response text from the model.
TextDelta(String),
/// Model requested function call(s).
FunctionCall(Vec<CapturedFunctionCall>),
/// Response streaming is complete (finishReason received).
ResponseComplete,
/// Google API returned an error.
UpstreamError(UpstreamError),
/// Grounding metadata (search results) from the response.
// The `allow` below silences the dead-code lint for this variant.
// NOTE(review): presumably no consumer reads it yet — confirm.
#[allow(dead_code)]
Grounding(serde_json::Value),
/// Token usage data from the response.
Usage(ApiUsage),
}
// ─── Per-request context ─────────────────────────────────────────────────────
/// All per-request state. Keyed by cascade ID in `MitmStore.pending_requests`.
///
/// API handlers build this before `send_message`, and the MITM proxy consumes
/// it when the LS's outbound request is intercepted.
///
/// Only `Debug` is derived, so a context is moved rather than copied: the
/// `take_*` methods on `MitmStore` hand ownership to the caller.
#[derive(Debug)]
pub struct RequestContext {
/// Cascade ID this context belongs to.
/// `MitmStore::register_request` uses a clone of it as the map key.
pub cascade_id: String,
/// Real user text for MITM injection (LS receives "." instead).
pub pending_user_text: String,
/// Event channel for real-time streaming from MITM → API handler.
/// Replaced in place by `MitmStore::set_channel`.
pub event_channel: mpsc::Sender<MitmEvent>,
/// Client-specified generation parameters (temperature, top_p, etc.).
pub generation_params: Option<GenerationParams>,
/// Image to inject into the Google API request.
pub pending_image: Option<PendingImage>,
/// Gemini-format tool declarations for MITM injection.
pub tools: Option<Vec<serde_json::Value>>,
/// Gemini-format toolConfig.
pub tool_config: Option<serde_json::Value>,
/// Pending tool results to inject as functionResponse.
pub pending_tool_results: Vec<PendingToolResult>,
/// Multi-round tool call history for history rewriting.
pub tool_rounds: Vec<ToolRound>,
/// Last captured function calls for history rewriting.
pub last_function_calls: Vec<CapturedFunctionCall>,
/// Mapping call_id → function name for tool result routing.
/// Populated by `MitmStore::register_call_id`.
pub call_id_to_name: HashMap<String, String>,
/// When this context was created (for TTL cleanup).
/// `MitmStore::take_latest_request` also uses it to pick the newest context.
// NOTE(review): no TTL sweep is visible in this part of the file — confirm one exists.
pub created_at: Instant,
/// Gate: signaled when MITM takes this context.
/// API handlers wait on this with a timeout to detect match failures.
/// `MitmStore::take_request` and `take_latest_request` call `notify_one` on it.
pub gate: Arc<tokio::sync::Notify>,
/// Debug trace handle (if tracing is enabled).
// The `allow` silences the unused-field lint; nothing in this file reads the handle.
#[allow(dead_code)]
pub trace_handle: Option<crate::trace::TraceHandle>,
/// Current turn index in the trace (for multi-turn tracking).
// The `allow` silences the unused-field lint; nothing in this file reads the index.
#[allow(dead_code)]
pub trace_turn: usize,
}
// ─── MitmStore ───────────────────────────────────────────────────────────────
/// Thread-safe store for intercepted data.
///
/// Per-request state lives in `pending_requests`, keyed by cascade ID.
/// Global state (usage stats, function call capture) remains shared.
///
/// Every field is an `Arc<RwLock<..>>`, so the derived `Clone` produces a
/// handle onto the same underlying state rather than an independent copy.
#[derive(Clone)]
pub struct MitmStore {
/// Most recent usage per cascade ID.
/// Usage recorded without a cascade ID is filed under the key "_latest".
latest_usage: Arc<RwLock<HashMap<String, ApiUsage>>>,
/// Global aggregate stats.
stats: Arc<RwLock<MitmStats>>,
/// Pending function calls captured from Google responses.
/// Key: cascade hint or "_latest". Value: list of function calls.
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
// ── Per-request state (keyed by cascade ID) ──────────────────────────
/// Active request contexts. API handlers register before send_message,
/// MITM proxy consumes when intercepting the LS request.
pending_requests: Arc<RwLock<HashMap<String, RequestContext>>>,
/// Cached context from turn 0, keyed by cascade ID.
/// Used to rebuild ToolContext on subsequent turns of the same cascade.
// NOTE(review): no method in this part of the file removes entries from this map — confirm it is bounded elsewhere.
cascade_cache: Arc<RwLock<HashMap<String, CascadeCache>>>,
// ── Legacy direct response capture (used by search.rs) ───────────────
/// Captured response text from MITM. Used as fallback by search endpoint.
/// A single slot shared by all cascades; each write replaces the previous value.
captured_response_text: Arc<RwLock<Option<String>>>,
// ── Grounding metadata capture ──────────────────────────────────────
/// Captured grounding metadata from Google API responses (search results).
/// A single slot shared by all cascades; each write replaces the previous value.
captured_grounding: Arc<RwLock<Option<serde_json::Value>>>,
}
/// Aggregate statistics across all intercepted traffic.
///
/// Updated by `MitmStore::record_usage`, which adds one `ApiUsage` per call;
/// `MitmStore::stats` returns a cloned snapshot.
#[derive(Debug, Clone, Default, Serialize)]
pub struct MitmStats {
/// Number of `record_usage` calls so far (one per recorded exchange).
pub total_requests: u64,
/// Sum of `ApiUsage::input_tokens`.
pub total_input_tokens: u64,
/// Sum of `ApiUsage::output_tokens`.
pub total_output_tokens: u64,
/// Sum of `ApiUsage::cache_read_input_tokens`.
pub total_cache_read_tokens: u64,
/// Sum of `ApiUsage::cache_creation_input_tokens`.
pub total_cache_creation_tokens: u64,
/// Sum of `ApiUsage::thinking_output_tokens`.
pub total_thinking_output_tokens: u64,
/// Sum of `ApiUsage::response_output_tokens`.
pub total_response_output_tokens: u64,
/// Per-model usage breakdown (model name → stats).
/// Only exchanges whose `ApiUsage::model` is `Some` contribute here.
pub per_model: HashMap<String, ModelStats>,
}
/// Per-model usage counters.
///
/// One entry per model name in `MitmStats::per_model`. Unlike the global
/// totals there is no thinking/response output split at this level.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ModelStats {
/// Number of recorded exchanges served by this model.
pub requests: u64,
/// Sum of `ApiUsage::input_tokens` for this model.
pub input_tokens: u64,
/// Sum of `ApiUsage::output_tokens` for this model.
pub output_tokens: u64,
/// Sum of `ApiUsage::cache_read_input_tokens` for this model.
pub cache_read_tokens: u64,
/// Sum of `ApiUsage::cache_creation_input_tokens` for this model.
pub cache_creation_tokens: u64,
}
impl MitmStore {
pub fn new() -> Self {
Self {
latest_usage: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(MitmStats::default())),
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
pending_requests: Arc::new(RwLock::new(HashMap::new())),
cascade_cache: Arc::new(RwLock::new(HashMap::new())),
captured_response_text: Arc::new(RwLock::new(None)),
captured_grounding: Arc::new(RwLock::new(None)),
}
}
// ── Per-request context management ───────────────────────────────────
/// Make a request context discoverable by the MITM proxy.
///
/// API handlers call this before `send_message`. The context is stored under
/// its own `cascade_id`; a context already registered for that cascade is
/// replaced.
pub async fn register_request(&self, ctx: RequestContext) {
    let key = ctx.cascade_id.clone();
    info!(cascade = %key, "Registered request context");
    let mut contexts = self.pending_requests.write().await;
    contexts.insert(key, ctx);
}
/// Take (consume) the request context for a cascade.
/// Called by the MITM proxy when intercepting the LS's outbound request.
pub async fn take_request(&self, cascade_id: &str) -> Option<RequestContext> {
let ctx = self.pending_requests.write().await.remove(cascade_id);
if let Some(ref c) = ctx {
c.gate.notify_one();
debug!(cascade = %cascade_id, "Took request context (gate signaled)");
}
ctx
}
/// Take the most recently registered request context (by creation time).
/// Fallback when cascade_id can't be extracted from the Google API request.
pub async fn take_latest_request(&self) -> Option<RequestContext> {
let mut pending = self.pending_requests.write().await;
if pending.is_empty() {
return None;
}
// Find the most recently created request
let latest_key = pending
.iter()
.max_by_key(|(_, ctx)| ctx.created_at)
.map(|(k, _)| k.clone());
if let Some(key) = latest_key {
let ctx = pending.remove(&key);
if let Some(ref c) = ctx {
c.gate.notify_one();
debug!(cascade = %key, "Took latest request context (fallback, gate signaled)");
}
ctx
} else {
None
}
}
/// Run `updater` against the pending context for `cascade_id`, in place.
///
/// Returns `true` when a context was found and updated, and `false` (without
/// calling `updater`) when none is registered. The map's write lock is held
/// while `updater` runs.
pub async fn update_request<F>(&self, cascade_id: &str, updater: F) -> bool
where
    F: FnOnce(&mut RequestContext),
{
    let mut contexts = self.pending_requests.write().await;
    match contexts.get_mut(cascade_id) {
        Some(ctx) => {
            updater(ctx);
            true
        }
        None => false,
    }
}
/// Drop the pending context for `cascade_id`, if any (cleanup once the
/// response is complete).
///
/// Unlike `take_request` this does not signal the context's gate. A missing
/// context is not an error and is not logged.
pub async fn remove_request(&self, cascade_id: &str) {
    let was_present = self
        .pending_requests
        .write()
        .await
        .remove(cascade_id)
        .is_some();

    if was_present {
        debug!(cascade = %cascade_id, "Removed request context");
    }
}
// ── Cascade cache (turn 0 context for re-injection on turn 1+) ──────
/// Store the turn-0 context of a cascade so later turns can reuse it.
///
/// Any cache already stored for `cascade_id` is replaced.
// NOTE(review): nothing visible in this part of the file evicts entries from
// `cascade_cache` — confirm it cannot grow without bound.
pub async fn cache_cascade(&self, cascade_id: &str, cache: CascadeCache) {
    debug!(
        cascade = %cascade_id,
        user_text_len = cache.user_text.len(),
        has_tools = cache.tools.is_some(),
        "Cached cascade context for subsequent turns"
    );
    let mut cached = self.cascade_cache.write().await;
    cached.insert(cascade_id.to_string(), cache);
}
/// Return a clone of the cached turn-0 context for `cascade_id`.
///
/// The entry stays in the cache because it is needed again on every turn.
/// Returns `None` when nothing has been cached for the cascade.
pub async fn get_cascade_cache(&self, cascade_id: &str) -> Option<CascadeCache> {
    let cached = self.cascade_cache.read().await;
    cached.get(cascade_id).cloned()
}
/// Report whether a turn-0 context is cached for `cascade_id`, i.e. whether
/// turn 0 of that cascade has already been processed.
pub async fn has_cascade_cache(&self, cascade_id: &str) -> bool {
    let cached = self.cascade_cache.read().await;
    cached.contains_key(cascade_id)
}
// ── Usage recording ──────────────────────────────────────────────────
/// Record a completed API exchange with usage data.
pub async fn record_usage(&self, cascade_id: Option<&str>, usage: ApiUsage) {
debug!(
input = usage.input_tokens,
output = usage.output_tokens,
cache_read = usage.cache_read_input_tokens,
cache_create = usage.cache_creation_input_tokens,
thinking = usage.thinking_output_tokens,
response = usage.response_output_tokens,
model = ?usage.model,
provider = ?usage.api_provider,
grpc = ?usage.grpc_method,
"MITM captured API usage"
);
// Update aggregate stats
{
let mut stats = self.stats.write().await;
stats.total_requests += 1;
stats.total_input_tokens += usage.input_tokens;
stats.total_output_tokens += usage.output_tokens;
stats.total_cache_read_tokens += usage.cache_read_input_tokens;
stats.total_cache_creation_tokens += usage.cache_creation_input_tokens;
stats.total_thinking_output_tokens += usage.thinking_output_tokens;
stats.total_response_output_tokens += usage.response_output_tokens;
// Per-model breakdown
if let Some(ref model_name) = usage.model {
let model_stats = stats.per_model.entry(model_name.clone()).or_default();
model_stats.requests += 1;
model_stats.input_tokens += usage.input_tokens;
model_stats.output_tokens += usage.output_tokens;
model_stats.cache_read_tokens += usage.cache_read_input_tokens;
model_stats.cache_creation_tokens += usage.cache_creation_input_tokens;
}
}
// Store latest usage for the cascade (if we can identify it).
//
// Merge logic for v1internal thinking summaries:
// The LS makes TWO Google API calls per thinking request:
// Call 1: response + thinking token count (thinking_output_tokens > 0, no thinking text)
// Call 2: thinking summary text (thinking_output_tokens == 0, response_text has the summary)
//
// When Call 2 arrives, we merge its response_text as thinking_text into Call 1's usage.
let key = cascade_id.map_or_else(|| "_latest".to_string(), |s| s.to_string());
let mut latest = self.latest_usage.write().await;
if let Some(existing) = latest.get_mut(&key) {
if existing.thinking_output_tokens > 0
&& existing.thinking_text.is_none()
&& usage.thinking_output_tokens == 0
&& usage.response_text.is_some()
{
// Call 2: thinking summary — merge into existing Call 1 usage
existing.thinking_text = usage.response_text;
debug!(
thinking_text_len = existing.thinking_text.as_ref().map_or(0, |t| t.len()),
"MITM: merged thinking summary text into existing usage"
);
} else {
// Normal case: replace existing usage
latest.insert(key, usage);
}
} else {
latest.insert(key, usage);
}
// Evict old entries to prevent unbounded memory growth
const MAX_ENTRIES: usize = 500;
if latest.len() > MAX_ENTRIES {
let oldest_key = latest
.iter()
.min_by_key(|(_, v)| v.captured_at)
.map(|(k, _)| k.clone());
if let Some(key) = oldest_key {
latest.remove(&key);
}
}
}
/// Return a clone of the latest usage recorded for `cascade_id`, leaving the
/// stored entry in place. `None` when nothing is recorded under that key.
pub async fn peek_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
    self.latest_usage.read().await.get(cascade_id).cloned()
}
/// Remove and return the latest usage recorded for `cascade_id`.
///
/// Only an exact key match is returned — there is no fallback to
/// `"_latest"` or to another cascade's entry.
pub async fn take_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
    self.latest_usage.write().await.remove(cascade_id)
}
/// Return a snapshot (clone) of the aggregate statistics.
pub async fn stats(&self) -> MitmStats {
    let current = self.stats.read().await;
    (*current).clone()
}
// ── Function call capture ────────────────────────────────────────────
/// Append a function call captured from Google's response to the pending
/// list for its cascade.
///
/// Calls with no known cascade are collected under the `"_latest"` key.
/// The call's name and arguments are logged at info level.
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
    let slot = cascade_id.unwrap_or("_latest").to_string();
    info!(
        cascade = %slot,
        tool = %fc.name,
        args = %fc.args,
        "MITM store: captured functionCall"
    );
    self.pending_function_calls
        .write()
        .await
        .entry(slot)
        .or_default()
        .push(fc);
}
/// Remove and return pending function calls, preferring those of `cascade_id`.
///
/// Lookup order: the exact `cascade_id` key, then `"_latest"`, then whichever
/// key the map yields first. The last step can therefore return calls that
/// were recorded for a different cascade. `None` only when nothing at all
/// is pending.
pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> {
    let mut pending = self.pending_function_calls.write().await;

    // Steps 1 and 2: exact match, then the "_latest" bucket.
    let preferred = pending
        .remove(cascade_id)
        .or_else(|| pending.remove("_latest"));
    if preferred.is_some() {
        return preferred;
    }

    // Step 3: last resort — any remaining key.
    let leftover_key = pending.keys().next().cloned()?;
    pending.remove(&leftover_key)
}
/// Remove and return pending function calls without regard to cascade ID.
///
/// The `"_latest"` bucket is tried first; otherwise the calls under whichever
/// key the map yields first are taken. `None` when nothing is pending.
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
    let mut pending = self.pending_function_calls.write().await;

    if let Some(calls) = pending.remove("_latest") {
        return Some(calls);
    }

    let fallback_key = pending.keys().next().cloned()?;
    pending.remove(&fallback_key)
}
/// Peek at the thought_signatures of recently captured function calls.
/// Returns a map of function_name → thought_signature (non-destructive).
pub async fn peek_thought_signatures(&self) -> std::collections::HashMap<String, String> {
let pending = self.pending_function_calls.read().await;
let mut sigs = std::collections::HashMap::new();
for calls in pending.values() {
for fc in calls {
if let Some(ref sig) = fc.thought_signature {
sigs.insert(fc.name.clone(), sig.clone());
}
}
}
sigs
}
// ── Legacy direct response capture (search.rs fallback) ──────────────
/// Store `text` as the captured response text, replacing whatever was there.
pub async fn set_response_text(&self, text: &str) {
    let owned = text.to_string();
    let mut slot = self.captured_response_text.write().await;
    *slot = Some(owned);
}
/// Remove and return the captured response text, leaving the slot empty.
pub async fn take_response_text(&self) -> Option<String> {
    let mut slot = self.captured_response_text.write().await;
    slot.take()
}
/// Discard any captured response text left over in the legacy slot.
pub async fn clear_response_async(&self) {
    let mut slot = self.captured_response_text.write().await;
    *slot = None;
}
// ── Grounding metadata capture ──────────────────────────────────────
/// Store grounding metadata captured from an API response, replacing any
/// previously captured value.
pub async fn set_grounding(&self, meta: serde_json::Value) {
    let mut slot = self.captured_grounding.write().await;
    *slot = Some(meta);
}
/// Remove and return the captured grounding metadata, leaving the slot empty.
// The `allow` silences the dead-code lint for this method.
#[allow(dead_code)]
pub async fn take_grounding(&self) -> Option<serde_json::Value> {
    let mut slot = self.captured_grounding.write().await;
    slot.take()
}
/// Return a clone of the captured grounding metadata without clearing it.
// The `allow` silences the dead-code lint for this method.
#[allow(dead_code)]
pub async fn peek_grounding(&self) -> Option<serde_json::Value> {
    let slot = self.captured_grounding.read().await;
    (*slot).clone()
}
// ── Compat shims for streaming tool-call loops ──────────────────────
/// Update the event channel on an existing request context,
/// or re-register a minimal context if it was already consumed by `take_request`.
///
/// This is critical for thinking-only intermediate responses: the MITM proxy
/// consumes the context via `take_request`, but the handler needs to re-install
/// a channel for the LS's follow-up request.
pub async fn set_channel(&self, cascade_id: &str, tx: mpsc::Sender<MitmEvent>) {
let updated = self
.update_request(cascade_id, |ctx| {
ctx.event_channel = tx.clone();
})
.await;
if !updated {
// Context was already consumed — re-register a minimal one
// so the MITM proxy can match the follow-up request.
let gate = std::sync::Arc::new(tokio::sync::Notify::new());
self.register_request(RequestContext {
cascade_id: cascade_id.to_string(),
pending_user_text: String::new(),
event_channel: tx,
generation_params: None,
pending_image: None,
tools: None,
tool_config: None,
pending_tool_results: Vec::new(),
tool_rounds: Vec::new(),
last_function_calls: Vec::new(),
call_id_to_name: std::collections::HashMap::new(),
created_at: std::time::Instant::now(),
gate,
trace_handle: None,
trace_turn: 0,
})
.await;
tracing::debug!(
cascade = cascade_id,
"set_channel: re-registered minimal context (original was consumed)"
);
}
}
/// No-op. Upstream errors are now delivered through the event channel.
/// Kept for API handler compatibility.
///
/// Takes no lock and touches no state; callers may keep or drop the call
/// without any change in behavior.
pub async fn clear_upstream_error(&self) {
// Intentionally empty — errors flow through MitmEvent::UpstreamError
}
/// Returns None. Upstream errors are now captured and delivered via the
/// per-request event channel rather than stored globally.
///
/// Compatibility shim: the result is always `None`, so callers must read
/// `MitmEvent::UpstreamError` from their channel to observe upstream errors.
pub async fn take_upstream_error(&self) -> Option<UpstreamError> {
None
}
/// Store a call_id → function_name mapping in the request context.
/// Used by streaming tool-call loops when the model returns function calls.
pub async fn register_call_id(&self, cascade_id: &str, call_id: String, name: String) {
self.update_request(cascade_id, |ctx| {
ctx.call_id_to_name.insert(call_id, name);
})
.await;
}
}