//! Shared store for intercepted API usage data. //! //! Per-request state is stored in `RequestContext`, keyed by cascade ID. //! The MITM proxy looks up the context when intercepting LS requests, //! enabling concurrent request processing without global locks. use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; use tokio::sync::{mpsc, RwLock}; use tracing::{debug, info}; /// Token usage from an intercepted API response. /// /// Covers both Anthropic JSON/SSE responses and Google gRPC protobuf responses. /// Fields map to the superset of Anthropic's `usage` object and Google's `ModelUsageStats` proto. #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct ApiUsage { pub input_tokens: u64, pub output_tokens: u64, /// Anthropic: cache_creation_input_tokens / Google: cache_write_tokens pub cache_creation_input_tokens: u64, /// Anthropic: cache_read_input_tokens / Google: cache_read_tokens pub cache_read_input_tokens: u64, /// Google-specific: thinking/reasoning output tokens (extended thinking) pub thinking_output_tokens: u64, /// The actual thinking/reasoning text from the model. /// Captured from Google SSE parts with `thought: true` or Anthropic thinking blocks. #[serde(skip_serializing_if = "Option::is_none")] pub thinking_text: Option, /// The response text captured from SSE parts (for merge detection). #[serde(skip)] pub response_text: Option, /// Google-specific: response output tokens (non-thinking portion) pub response_output_tokens: u64, /// The actual model that served the request. pub model: Option, /// Stop reason / finish reason from the API. pub stop_reason: Option, /// API provider (e.g. "anthropic", "google") pub api_provider: Option, /// gRPC method path (e.g. "/google.internal.cloud.code.v1internal.PredictionService/GenerateContent") pub grpc_method: Option, /// Timestamp when this usage was captured. 
pub captured_at: u64, /// Thinking signature from Google's response (base64 opaque blob). /// Required for multi-turn with thinking models. #[serde(skip_serializing_if = "Option::is_none")] pub thinking_signature: Option, } /// A captured function call from Google's API response. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CapturedFunctionCall { pub name: String, pub args: serde_json::Value, /// Google's thought signature — required when injecting functionCall back /// into conversation history. Without it, Google returns INVALID_ARGUMENT. #[serde(skip_serializing_if = "Option::is_none")] pub thought_signature: Option, pub captured_at: u64, } /// A pending tool result from a client's function_call_output. #[derive(Debug, Clone)] pub struct PendingToolResult { pub name: String, pub result: serde_json::Value, } /// A single round of tool calling: the model's function calls paired with /// the client's execution results. /// /// In multi-step tool use, each round has its own calls and results. /// This preserves per-turn data so history rewriting can map each /// "Tool call completed" model turn to the correct functionCall/functionResponse. #[derive(Debug, Clone)] pub struct ToolRound { pub calls: Vec, pub results: Vec, } /// An upstream error captured from Google's API response. /// Stored by the MITM proxy so API handlers can return it to the client /// instead of hanging forever waiting for a response that won't come. #[derive(Debug, Clone)] pub struct UpstreamError { /// HTTP status code from Google (e.g. 400, 429, 500). pub status: u16, /// Raw error body from Google (usually JSON). #[allow(dead_code)] pub body: String, /// Parsed error message, if available. pub message: Option, /// Google error status string (e.g. "INVALID_ARGUMENT", "RESOURCE_EXHAUSTED"). pub error_status: Option, } /// A pending image to inject via MITM into the Google API request. 
/// The LS doesn't forward images from our SendUserCascadeMessage proto, /// so we inject them directly at the MITM layer. #[derive(Debug, Clone)] pub struct PendingImage { /// Base64-encoded image data (no prefix). pub base64_data: String, /// MIME type, e.g. "image/png". pub mime_type: String, } /// Client-specified generation parameters for MITM injection. /// Set by API handlers, consumed by the MITM modify layer. #[derive(Debug, Clone, Default)] pub struct GenerationParams { pub temperature: Option, pub top_p: Option, pub top_k: Option, pub max_output_tokens: Option, pub stop_sequences: Option>, /// Frequency penalty (OpenAI) — mapped to frequencyPenalty in Gemini. pub frequency_penalty: Option, /// Presence penalty (OpenAI) — mapped to presencePenalty in Gemini. pub presence_penalty: Option, /// Reasoning effort — mapped to thinkingConfig.thinkingLevel in Gemini 3. /// Values: "low", "medium", "high" (maps 1:1 to Google's thinkingLevel). pub reasoning_effort: Option, /// Response MIME type — injected as generationConfig.responseMimeType. /// e.g., "application/json" for JSON mode. pub response_mime_type: Option, /// Response schema — injected as generationConfig.responseSchema. /// Used for structured output (json_schema format). pub response_schema: Option, /// Enable Google Search grounding — injects {"googleSearch": {}} into tools. /// Default off. When enabled, model responses include groundingMetadata. pub google_search: bool, } /// Cached context from turn 0 of a cascade. /// /// On the first turn, the MITM proxy consumes the `RequestContext` and builds /// a `ToolContext`. On subsequent turns (tool-call loops), the `RequestContext` /// is gone. This cache stores the essential fields so we can rebuild a lite /// `ToolContext` on every turn — ensuring the model always sees the real user /// text and has access to custom tools. #[derive(Debug, Clone)] pub struct CascadeCache { /// The real user text (used to replace the "." dot prompt). 
pub user_text: String, /// Custom tool definitions (Gemini format). pub tools: Option>, /// Custom tool config. pub tool_config: Option, /// Client generation parameters. pub generation_params: Option, } // ─── Channel-based event pipeline ──────────────────────────────────────────── /// Events sent from the MITM proxy to API handlers through a per-request channel. /// Replaces the old polling-based approach (shared atomics + RwLocks) with /// instant, race-free delivery. #[derive(Debug, Clone)] pub enum MitmEvent { /// Incremental thinking/reasoning text from the model. ThinkingDelta(String), /// Incremental response text from the model. TextDelta(String), /// Model requested function call(s). FunctionCall(Vec), /// Response streaming is complete (finishReason received). ResponseComplete, /// Google API returned an error. UpstreamError(UpstreamError), /// Grounding metadata (search results) from the response. #[allow(dead_code)] Grounding(serde_json::Value), /// Token usage data from the response. Usage(ApiUsage), } // ─── Per-request context ───────────────────────────────────────────────────── /// All per-request state. Keyed by cascade ID in `MitmStore.pending_requests`. /// /// API handlers build this before `send_message`, and the MITM proxy consumes /// it when the LS's outbound request is intercepted. #[derive(Debug)] pub struct RequestContext { /// Cascade ID this context belongs to. pub cascade_id: String, /// Real user text for MITM injection (LS receives "." instead). pub pending_user_text: String, /// Event channel for real-time streaming from MITM → API handler. pub event_channel: mpsc::Sender, /// Client-specified generation parameters (temperature, top_p, etc.). pub generation_params: Option, /// Image to inject into the Google API request. pub pending_image: Option, /// Gemini-format tool declarations for MITM injection. pub tools: Option>, /// Gemini-format toolConfig. 
pub tool_config: Option, /// Pending tool results to inject as functionResponse. pub pending_tool_results: Vec, /// Multi-round tool call history for history rewriting. pub tool_rounds: Vec, /// Last captured function calls for history rewriting. pub last_function_calls: Vec, /// Mapping call_id → function name for tool result routing. pub call_id_to_name: HashMap, /// When this context was created (for TTL cleanup). pub created_at: Instant, /// Gate: signaled when MITM takes this context. /// API handlers wait on this with a timeout to detect match failures. pub gate: Arc, /// Debug trace handle (if tracing is enabled). #[allow(dead_code)] pub trace_handle: Option, /// Current turn index in the trace (for multi-turn tracking). #[allow(dead_code)] pub trace_turn: usize, } // ─── MitmStore ─────────────────────────────────────────────────────────────── /// Thread-safe store for intercepted data. /// /// Per-request state lives in `pending_requests`, keyed by cascade ID. /// Global state (usage stats, function call capture) remains shared. #[derive(Clone)] pub struct MitmStore { /// Most recent usage per cascade ID. latest_usage: Arc>>, /// Global aggregate stats. stats: Arc>, /// Pending function calls captured from Google responses. /// Key: cascade hint or "_latest". Value: list of function calls. pending_function_calls: Arc>>>, // ── Per-request state (keyed by cascade ID) ────────────────────────── /// Active request contexts. API handlers register before send_message, /// MITM proxy consumes when intercepting the LS request. pending_requests: Arc>>, /// Cached context from turn 0, keyed by cascade ID. /// Used to rebuild ToolContext on subsequent turns of the same cascade. cascade_cache: Arc>>, // ── Legacy direct response capture (used by search.rs) ─────────────── /// Captured response text from MITM. Used as fallback by search endpoint. 
captured_response_text: Arc>>, // ── Grounding metadata capture ────────────────────────────────────── /// Captured grounding metadata from Google API responses (search results). captured_grounding: Arc>>, } /// Aggregate statistics across all intercepted traffic. #[derive(Debug, Clone, Default, Serialize)] pub struct MitmStats { pub total_requests: u64, pub total_input_tokens: u64, pub total_output_tokens: u64, pub total_cache_read_tokens: u64, pub total_cache_creation_tokens: u64, pub total_thinking_output_tokens: u64, pub total_response_output_tokens: u64, /// Per-model usage breakdown (model name → stats). pub per_model: HashMap, } /// Per-model usage counters. #[derive(Debug, Clone, Default, Serialize)] pub struct ModelStats { pub requests: u64, pub input_tokens: u64, pub output_tokens: u64, pub cache_read_tokens: u64, pub cache_creation_tokens: u64, } impl MitmStore { pub fn new() -> Self { Self { latest_usage: Arc::new(RwLock::new(HashMap::new())), stats: Arc::new(RwLock::new(MitmStats::default())), pending_function_calls: Arc::new(RwLock::new(HashMap::new())), pending_requests: Arc::new(RwLock::new(HashMap::new())), cascade_cache: Arc::new(RwLock::new(HashMap::new())), captured_response_text: Arc::new(RwLock::new(None)), captured_grounding: Arc::new(RwLock::new(None)), } } // ── Per-request context management ─────────────────────────────────── /// Register a request context for a cascade. Called by API handlers /// before `send_message` so the MITM proxy can find it. pub async fn register_request(&self, ctx: RequestContext) { let cascade_id = ctx.cascade_id.clone(); info!(cascade = %cascade_id, "Registered request context"); self.pending_requests.write().await.insert(cascade_id, ctx); } /// Take (consume) the request context for a cascade. /// Called by the MITM proxy when intercepting the LS's outbound request. 
pub async fn take_request(&self, cascade_id: &str) -> Option<RequestContext> {
    // The write guard is a temporary of this statement, so the lock is
    // released before the gate is signaled.
    let ctx = self.pending_requests.write().await.remove(cascade_id)?;
    ctx.gate.notify_one();
    debug!(cascade = %cascade_id, "Took request context (gate signaled)");
    Some(ctx)
}

/// Take the most recently registered request context (by creation time).
/// Fallback when cascade_id can't be extracted from the Google API request.
///
/// Returns `None` when no request is pending.
pub async fn take_latest_request(&self) -> Option<RequestContext> {
    let mut pending = self.pending_requests.write().await;

    // Find the most recently created request; an empty map yields `None`.
    let latest_key = pending
        .iter()
        .max_by_key(|(_, ctx)| ctx.created_at)
        .map(|(key, _)| key.clone())?;

    let ctx = pending.remove(&latest_key)?;
    ctx.gate.notify_one();
    debug!(cascade = %latest_key, "Took latest request context (fallback, gate signaled)");
    Some(ctx)
}

/// Update a request context in-place. Returns false if not found.
///
/// `updater` runs while the write lock on the pending-request map is held.
pub async fn update_request<F>(&self, cascade_id: &str, updater: F) -> bool
where
    F: FnOnce(&mut RequestContext),
{
    let mut map = self.pending_requests.write().await;
    match map.get_mut(cascade_id) {
        Some(ctx) => {
            updater(ctx);
            true
        }
        None => false,
    }
}

/// Remove a request context (cleanup after response is complete).
/// Does nothing if no context is registered for `cascade_id`.
pub async fn remove_request(&self, cascade_id: &str) {
    let removed = self.pending_requests.write().await.remove(cascade_id);
    if removed.is_some() {
        debug!(cascade = %cascade_id, "Removed request context");
    }
}

// ── Cascade cache (turn 0 context for re-injection on turn 1+) ──────

/// Cache the essential context from turn 0 so it can be re-used on
/// subsequent turns of the same cascade.
pub async fn cache_cascade(&self, cascade_id: &str, cache: CascadeCache) { debug!(cascade = %cascade_id, user_text_len = cache.user_text.len(), has_tools = cache.tools.is_some(), "Cached cascade context for subsequent turns"); self.cascade_cache .write() .await .insert(cascade_id.to_string(), cache); } /// Get cached context for a cascade (non-consuming — needed on every turn). pub async fn get_cascade_cache(&self, cascade_id: &str) -> Option { self.cascade_cache.read().await.get(cascade_id).cloned() } /// Check if a cascade has been processed (turn 0 complete). pub async fn has_cascade_cache(&self, cascade_id: &str) -> bool { self.cascade_cache.read().await.contains_key(cascade_id) } // ── Usage recording ────────────────────────────────────────────────── /// Record a completed API exchange with usage data. pub async fn record_usage(&self, cascade_id: Option<&str>, usage: ApiUsage) { debug!( input = usage.input_tokens, output = usage.output_tokens, cache_read = usage.cache_read_input_tokens, cache_create = usage.cache_creation_input_tokens, thinking = usage.thinking_output_tokens, response = usage.response_output_tokens, model = ?usage.model, provider = ?usage.api_provider, grpc = ?usage.grpc_method, "MITM captured API usage" ); // Update aggregate stats { let mut stats = self.stats.write().await; stats.total_requests += 1; stats.total_input_tokens += usage.input_tokens; stats.total_output_tokens += usage.output_tokens; stats.total_cache_read_tokens += usage.cache_read_input_tokens; stats.total_cache_creation_tokens += usage.cache_creation_input_tokens; stats.total_thinking_output_tokens += usage.thinking_output_tokens; stats.total_response_output_tokens += usage.response_output_tokens; // Per-model breakdown if let Some(ref model_name) = usage.model { let model_stats = stats.per_model.entry(model_name.clone()).or_default(); model_stats.requests += 1; model_stats.input_tokens += usage.input_tokens; model_stats.output_tokens += usage.output_tokens; 
model_stats.cache_read_tokens += usage.cache_read_input_tokens; model_stats.cache_creation_tokens += usage.cache_creation_input_tokens; } } // Store latest usage for the cascade (if we can identify it). // // Merge logic for v1internal thinking summaries: // The LS makes TWO Google API calls per thinking request: // Call 1: response + thinking token count (thinking_output_tokens > 0, no thinking text) // Call 2: thinking summary text (thinking_output_tokens == 0, response_text has the summary) // // When Call 2 arrives, we merge its response_text as thinking_text into Call 1's usage. let key = cascade_id.map_or_else(|| "_latest".to_string(), |s| s.to_string()); let mut latest = self.latest_usage.write().await; if let Some(existing) = latest.get_mut(&key) { if existing.thinking_output_tokens > 0 && existing.thinking_text.is_none() && usage.thinking_output_tokens == 0 && usage.response_text.is_some() { // Call 2: thinking summary — merge into existing Call 1 usage existing.thinking_text = usage.response_text; debug!( thinking_text_len = existing.thinking_text.as_ref().map_or(0, |t| t.len()), "MITM: merged thinking summary text into existing usage" ); } else { // Normal case: replace existing usage latest.insert(key, usage); } } else { latest.insert(key, usage); } // Evict old entries to prevent unbounded memory growth const MAX_ENTRIES: usize = 500; if latest.len() > MAX_ENTRIES { let oldest_key = latest .iter() .min_by_key(|(_, v)| v.captured_at) .map(|(k, _)| k.clone()); if let Some(key) = oldest_key { latest.remove(&key); } } } /// Peek at usage data for a cascade without consuming it. pub async fn peek_usage(&self, cascade_id: &str) -> Option { let latest = self.latest_usage.read().await; latest.get(cascade_id).cloned() } /// Only returns exact cascade_id matches — no cross-cascade fallback. pub async fn take_usage(&self, cascade_id: &str) -> Option { let mut latest = self.latest_usage.write().await; latest.remove(cascade_id) } /// Get aggregate stats. 
pub async fn stats(&self) -> MitmStats { self.stats.read().await.clone() } // ── Function call capture ──────────────────────────────────────────── /// Record a captured function call from Google's response. pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) { let key = cascade_id.map_or_else(|| "_latest".to_string(), |s| s.to_string()); info!( cascade = %key, tool = %fc.name, args = %fc.args, "MITM store: captured functionCall" ); let mut pending = self.pending_function_calls.write().await; pending.entry(key).or_default().push(fc); } /// Take pending function calls for a specific cascade. /// /// Priority: exact cascade_id → `_latest` → any key. pub async fn take_function_calls(&self, cascade_id: &str) -> Option> { let mut pending = self.pending_function_calls.write().await; // 1. Exact cascade match if let Some(result) = pending.remove(cascade_id) { return Some(result); } // 2. Fallback to _latest if let Some(result) = pending.remove("_latest") { return Some(result); } // 3. Last resort: any key if let Some(key) = pending.keys().next().cloned() { return pending.remove(&key); } None } /// Take any pending function calls (ignoring cascade ID). pub async fn take_any_function_calls(&self) -> Option> { let mut pending = self.pending_function_calls.write().await; let result = pending.remove("_latest"); if result.is_some() { return result; } if let Some(key) = pending.keys().next().cloned() { return pending.remove(&key); } None } /// Peek at the thought_signatures of recently captured function calls. /// Returns a map of function_name → thought_signature (non-destructive). 
pub async fn peek_thought_signatures(&self) -> std::collections::HashMap { let pending = self.pending_function_calls.read().await; let mut sigs = std::collections::HashMap::new(); for calls in pending.values() { for fc in calls { if let Some(ref sig) = fc.thought_signature { sigs.insert(fc.name.clone(), sig.clone()); } } } sigs } // ── Legacy direct response capture (search.rs fallback) ────────────── /// Set (replace) the captured response text. pub async fn set_response_text(&self, text: &str) { *self.captured_response_text.write().await = Some(text.to_string()); } /// Take the captured response text (consumes it). pub async fn take_response_text(&self) -> Option { self.captured_response_text.write().await.take() } /// Clear stale legacy response state. pub async fn clear_response_async(&self) { *self.captured_response_text.write().await = None; } // ── Grounding metadata capture ────────────────────────────────────── /// Store captured grounding metadata from API response. pub async fn set_grounding(&self, meta: serde_json::Value) { *self.captured_grounding.write().await = Some(meta); } /// Take (consume) captured grounding metadata. #[allow(dead_code)] pub async fn take_grounding(&self) -> Option { self.captured_grounding.write().await.take() } /// Peek at grounding metadata without consuming. #[allow(dead_code)] pub async fn peek_grounding(&self) -> Option { self.captured_grounding.read().await.clone() } // ── Compat shims for streaming tool-call loops ────────────────────── /// Update the event channel on an existing request context, /// or re-register a minimal context if it was already consumed by `take_request`. /// /// This is critical for thinking-only intermediate responses: the MITM proxy /// consumes the context via `take_request`, but the handler needs to re-install /// a channel for the LS's follow-up request. 
pub async fn set_channel(&self, cascade_id: &str, tx: mpsc::Sender) { let updated = self .update_request(cascade_id, |ctx| { ctx.event_channel = tx.clone(); }) .await; if !updated { // Context was already consumed — re-register a minimal one // so the MITM proxy can match the follow-up request. let gate = std::sync::Arc::new(tokio::sync::Notify::new()); self.register_request(RequestContext { cascade_id: cascade_id.to_string(), pending_user_text: String::new(), event_channel: tx, generation_params: None, pending_image: None, tools: None, tool_config: None, pending_tool_results: Vec::new(), tool_rounds: Vec::new(), last_function_calls: Vec::new(), call_id_to_name: std::collections::HashMap::new(), created_at: std::time::Instant::now(), gate, trace_handle: None, trace_turn: 0, }) .await; tracing::debug!( cascade = cascade_id, "set_channel: re-registered minimal context (original was consumed)" ); } } /// No-op. Upstream errors are now delivered through the event channel. /// Kept for API handler compatibility. pub async fn clear_upstream_error(&self) { // Intentionally empty — errors flow through MitmEvent::UpstreamError } /// Returns None. Upstream errors are now captured and delivered via the /// per-request event channel rather than stored globally. pub async fn take_upstream_error(&self) -> Option { None } /// Store a call_id → function_name mapping in the request context. /// Used by streaming tool-call loops when the model returns function calls. pub async fn register_call_id(&self, cascade_id: &str, call_id: String, name: String) { self.update_request(cascade_id, |ctx| { ctx.call_id_to_name.insert(call_id, name); }) .await; } }