659 lines
27 KiB
Rust
659 lines
27 KiB
Rust
//! Shared store for intercepted API usage data.
|
|
//!
|
|
//! Per-request state is stored in `RequestContext`, keyed by cascade ID.
|
|
//! The MITM proxy looks up the context when intercepting LS requests,
|
|
//! enabling concurrent request processing without global locks.
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
use tokio::sync::{mpsc, RwLock};
|
|
use tracing::{debug, info};
|
|
|
|
/// Token usage from an intercepted API response.
|
|
///
|
|
/// Covers both Anthropic JSON/SSE responses and Google gRPC protobuf responses.
|
|
/// Fields map to the superset of Anthropic's `usage` object and Google's `ModelUsageStats` proto.
|
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
|
pub struct ApiUsage {
|
|
pub input_tokens: u64,
|
|
pub output_tokens: u64,
|
|
/// Anthropic: cache_creation_input_tokens / Google: cache_write_tokens
|
|
pub cache_creation_input_tokens: u64,
|
|
/// Anthropic: cache_read_input_tokens / Google: cache_read_tokens
|
|
pub cache_read_input_tokens: u64,
|
|
/// Google-specific: thinking/reasoning output tokens (extended thinking)
|
|
pub thinking_output_tokens: u64,
|
|
/// The actual thinking/reasoning text from the model.
|
|
/// Captured from Google SSE parts with `thought: true` or Anthropic thinking blocks.
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub thinking_text: Option<String>,
|
|
/// The response text captured from SSE parts (for merge detection).
|
|
#[serde(skip)]
|
|
pub response_text: Option<String>,
|
|
/// Google-specific: response output tokens (non-thinking portion)
|
|
pub response_output_tokens: u64,
|
|
|
|
/// The actual model that served the request.
|
|
pub model: Option<String>,
|
|
/// Stop reason / finish reason from the API.
|
|
pub stop_reason: Option<String>,
|
|
/// API provider (e.g. "anthropic", "google")
|
|
pub api_provider: Option<String>,
|
|
/// gRPC method path (e.g. "/google.internal.cloud.code.v1internal.PredictionService/GenerateContent")
|
|
pub grpc_method: Option<String>,
|
|
/// Timestamp when this usage was captured.
|
|
pub captured_at: u64,
|
|
/// Thinking signature from Google's response (base64 opaque blob).
|
|
/// Required for multi-turn with thinking models.
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub thinking_signature: Option<String>,
|
|
}
|
|
|
|
/// A captured function call from Google's API response.
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct CapturedFunctionCall {
|
|
pub name: String,
|
|
pub args: serde_json::Value,
|
|
/// Google's thought signature — required when injecting functionCall back
|
|
/// into conversation history. Without it, Google returns INVALID_ARGUMENT.
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub thought_signature: Option<String>,
|
|
pub captured_at: u64,
|
|
}
|
|
|
|
/// A pending tool result from a client's function_call_output.
|
|
#[derive(Debug, Clone)]
|
|
pub struct PendingToolResult {
|
|
pub name: String,
|
|
pub result: serde_json::Value,
|
|
}
|
|
|
|
/// A single round of tool calling: the model's function calls paired with
|
|
/// the client's execution results.
|
|
///
|
|
/// In multi-step tool use, each round has its own calls and results.
|
|
/// This preserves per-turn data so history rewriting can map each
|
|
/// "Tool call completed" model turn to the correct functionCall/functionResponse.
|
|
#[derive(Debug, Clone)]
|
|
pub struct ToolRound {
|
|
pub calls: Vec<CapturedFunctionCall>,
|
|
pub results: Vec<PendingToolResult>,
|
|
}
|
|
|
|
/// An error response from Google's API, observed by the MITM proxy.
///
/// It is handed to the API handler (as `MitmEvent::UpstreamError`) so the
/// client gets the error back instead of the handler waiting forever for a
/// response that will never arrive.
#[derive(Debug, Clone)]
pub struct UpstreamError {
    /// HTTP status code Google returned, e.g. 400, 429 or 500.
    pub status: u16,
    /// The unparsed error body (usually JSON). Not read directly in this
    /// module, hence the dead-code allowance; it still shows up in `Debug`.
    #[allow(dead_code)]
    pub body: String,
    /// Human-readable error message, when one could be parsed out.
    pub message: Option<String>,
    /// Google's status string, e.g. "INVALID_ARGUMENT" or "RESOURCE_EXHAUSTED".
    pub error_status: Option<String>,
}
|
|
|
|
/// An image the MITM layer splices into the outgoing Google API request.
///
/// The LS drops images carried in our `SendUserCascadeMessage` proto, so the
/// proxy has to inject them itself.
#[derive(Debug, Clone)]
pub struct PendingImage {
    /// Base64-encoded image bytes, with no prefix.
    pub base64_data: String,
    /// MIME type of the image, e.g. "image/png".
    pub mime_type: String,
}
|
|
|
|
/// Client-specified generation parameters for MITM injection.
|
|
/// Set by API handlers, consumed by the MITM modify layer.
|
|
#[derive(Debug, Clone, Default)]
|
|
pub struct GenerationParams {
|
|
pub temperature: Option<f64>,
|
|
pub top_p: Option<f64>,
|
|
pub top_k: Option<u32>,
|
|
pub max_output_tokens: Option<u64>,
|
|
pub stop_sequences: Option<Vec<String>>,
|
|
/// Frequency penalty (OpenAI) — mapped to frequencyPenalty in Gemini.
|
|
pub frequency_penalty: Option<f64>,
|
|
/// Presence penalty (OpenAI) — mapped to presencePenalty in Gemini.
|
|
pub presence_penalty: Option<f64>,
|
|
/// Reasoning effort — mapped to thinkingConfig.thinkingLevel in Gemini 3.
|
|
/// Values: "low", "medium", "high" (maps 1:1 to Google's thinkingLevel).
|
|
pub reasoning_effort: Option<String>,
|
|
/// Response MIME type — injected as generationConfig.responseMimeType.
|
|
/// e.g., "application/json" for JSON mode.
|
|
pub response_mime_type: Option<String>,
|
|
/// Response schema — injected as generationConfig.responseSchema.
|
|
/// Used for structured output (json_schema format).
|
|
pub response_schema: Option<serde_json::Value>,
|
|
/// Enable Google Search grounding — injects {"googleSearch": {}} into tools.
|
|
/// Default off. When enabled, model responses include groundingMetadata.
|
|
pub google_search: bool,
|
|
}
|
|
|
|
/// Cached context from turn 0 of a cascade.
|
|
///
|
|
/// On the first turn, the MITM proxy consumes the `RequestContext` and builds
|
|
/// a `ToolContext`. On subsequent turns (tool-call loops), the `RequestContext`
|
|
/// is gone. This cache stores the essential fields so we can rebuild a lite
|
|
/// `ToolContext` on every turn — ensuring the model always sees the real user
|
|
/// text and has access to custom tools.
|
|
#[derive(Debug, Clone)]
|
|
pub struct CascadeCache {
|
|
/// The real user text (used to replace the "." dot prompt).
|
|
pub user_text: String,
|
|
/// Custom tool definitions (Gemini format).
|
|
pub tools: Option<Vec<serde_json::Value>>,
|
|
/// Custom tool config.
|
|
pub tool_config: Option<serde_json::Value>,
|
|
/// Client generation parameters.
|
|
pub generation_params: Option<GenerationParams>,
|
|
}
|
|
|
|
// ─── Channel-based event pipeline ────────────────────────────────────────────
|
|
|
|
/// Events sent from the MITM proxy to API handlers through a per-request channel.
|
|
/// Replaces the old polling-based approach (shared atomics + RwLocks) with
|
|
/// instant, race-free delivery.
|
|
#[derive(Debug, Clone)]
|
|
pub enum MitmEvent {
|
|
/// Incremental thinking/reasoning text from the model.
|
|
ThinkingDelta(String),
|
|
/// Incremental response text from the model.
|
|
TextDelta(String),
|
|
/// Model requested function call(s).
|
|
FunctionCall(Vec<CapturedFunctionCall>),
|
|
/// Response streaming is complete (finishReason received).
|
|
ResponseComplete,
|
|
/// Google API returned an error.
|
|
UpstreamError(UpstreamError),
|
|
/// Grounding metadata (search results) from the response.
|
|
#[allow(dead_code)]
|
|
Grounding(serde_json::Value),
|
|
/// Token usage data from the response.
|
|
Usage(ApiUsage),
|
|
}
|
|
|
|
// ─── Per-request context ─────────────────────────────────────────────────────
|
|
|
|
/// All per-request state. Keyed by cascade ID in `MitmStore.pending_requests`.
|
|
///
|
|
/// API handlers build this before `send_message`, and the MITM proxy consumes
|
|
/// it when the LS's outbound request is intercepted.
|
|
#[derive(Debug)]
|
|
pub struct RequestContext {
|
|
/// Cascade ID this context belongs to.
|
|
pub cascade_id: String,
|
|
/// Real user text for MITM injection (LS receives "." instead).
|
|
pub pending_user_text: String,
|
|
/// Event channel for real-time streaming from MITM → API handler.
|
|
pub event_channel: mpsc::Sender<MitmEvent>,
|
|
/// Client-specified generation parameters (temperature, top_p, etc.).
|
|
pub generation_params: Option<GenerationParams>,
|
|
/// Image to inject into the Google API request.
|
|
pub pending_image: Option<PendingImage>,
|
|
/// Gemini-format tool declarations for MITM injection.
|
|
pub tools: Option<Vec<serde_json::Value>>,
|
|
/// Gemini-format toolConfig.
|
|
pub tool_config: Option<serde_json::Value>,
|
|
/// Pending tool results to inject as functionResponse.
|
|
pub pending_tool_results: Vec<PendingToolResult>,
|
|
/// Multi-round tool call history for history rewriting.
|
|
pub tool_rounds: Vec<ToolRound>,
|
|
/// Last captured function calls for history rewriting.
|
|
pub last_function_calls: Vec<CapturedFunctionCall>,
|
|
/// Mapping call_id → function name for tool result routing.
|
|
pub call_id_to_name: HashMap<String, String>,
|
|
/// When this context was created (for TTL cleanup).
|
|
pub created_at: Instant,
|
|
/// Gate: signaled when MITM takes this context.
|
|
/// API handlers wait on this with a timeout to detect match failures.
|
|
pub gate: Arc<tokio::sync::Notify>,
|
|
/// Debug trace handle (if tracing is enabled).
|
|
#[allow(dead_code)]
|
|
pub trace_handle: Option<crate::trace::TraceHandle>,
|
|
/// Current turn index in the trace (for multi-turn tracking).
|
|
#[allow(dead_code)]
|
|
pub trace_turn: usize,
|
|
}
|
|
|
|
// ─── MitmStore ───────────────────────────────────────────────────────────────
|
|
|
|
/// Thread-safe store for intercepted data.
|
|
///
|
|
/// Per-request state lives in `pending_requests`, keyed by cascade ID.
|
|
/// Global state (usage stats, function call capture) remains shared.
|
|
#[derive(Clone)]
|
|
pub struct MitmStore {
|
|
/// Most recent usage per cascade ID.
|
|
latest_usage: Arc<RwLock<HashMap<String, ApiUsage>>>,
|
|
/// Global aggregate stats.
|
|
stats: Arc<RwLock<MitmStats>>,
|
|
/// Pending function calls captured from Google responses.
|
|
/// Key: cascade hint or "_latest". Value: list of function calls.
|
|
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
|
|
|
|
// ── Per-request state (keyed by cascade ID) ──────────────────────────
|
|
/// Active request contexts. API handlers register before send_message,
|
|
/// MITM proxy consumes when intercepting the LS request.
|
|
pending_requests: Arc<RwLock<HashMap<String, RequestContext>>>,
|
|
|
|
/// Cached context from turn 0, keyed by cascade ID.
|
|
/// Used to rebuild ToolContext on subsequent turns of the same cascade.
|
|
cascade_cache: Arc<RwLock<HashMap<String, CascadeCache>>>,
|
|
|
|
// ── Legacy direct response capture (used by search.rs) ───────────────
|
|
/// Captured response text from MITM. Used as fallback by search endpoint.
|
|
captured_response_text: Arc<RwLock<Option<String>>>,
|
|
|
|
// ── Grounding metadata capture ──────────────────────────────────────
|
|
/// Captured grounding metadata from Google API responses (search results).
|
|
captured_grounding: Arc<RwLock<Option<serde_json::Value>>>,
|
|
}
|
|
|
|
/// Aggregate statistics across all intercepted traffic.
|
|
#[derive(Debug, Clone, Default, Serialize)]
|
|
pub struct MitmStats {
|
|
pub total_requests: u64,
|
|
pub total_input_tokens: u64,
|
|
pub total_output_tokens: u64,
|
|
pub total_cache_read_tokens: u64,
|
|
pub total_cache_creation_tokens: u64,
|
|
pub total_thinking_output_tokens: u64,
|
|
pub total_response_output_tokens: u64,
|
|
/// Per-model usage breakdown (model name → stats).
|
|
pub per_model: HashMap<String, ModelStats>,
|
|
}
|
|
|
|
/// Per-model usage counters.
|
|
#[derive(Debug, Clone, Default, Serialize)]
|
|
pub struct ModelStats {
|
|
pub requests: u64,
|
|
pub input_tokens: u64,
|
|
pub output_tokens: u64,
|
|
pub cache_read_tokens: u64,
|
|
pub cache_creation_tokens: u64,
|
|
}
|
|
|
|
impl MitmStore {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
latest_usage: Arc::new(RwLock::new(HashMap::new())),
|
|
stats: Arc::new(RwLock::new(MitmStats::default())),
|
|
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
|
|
pending_requests: Arc::new(RwLock::new(HashMap::new())),
|
|
cascade_cache: Arc::new(RwLock::new(HashMap::new())),
|
|
captured_response_text: Arc::new(RwLock::new(None)),
|
|
captured_grounding: Arc::new(RwLock::new(None)),
|
|
}
|
|
}
|
|
|
|
// ── Per-request context management ───────────────────────────────────
|
|
|
|
/// Register a request context for a cascade. Called by API handlers
|
|
/// before `send_message` so the MITM proxy can find it.
|
|
pub async fn register_request(&self, ctx: RequestContext) {
|
|
let cascade_id = ctx.cascade_id.clone();
|
|
info!(cascade = %cascade_id, "Registered request context");
|
|
self.pending_requests.write().await.insert(cascade_id, ctx);
|
|
}
|
|
|
|
/// Take (consume) the request context for a cascade.
|
|
/// Called by the MITM proxy when intercepting the LS's outbound request.
|
|
pub async fn take_request(&self, cascade_id: &str) -> Option<RequestContext> {
|
|
let ctx = self.pending_requests.write().await.remove(cascade_id);
|
|
if let Some(ref c) = ctx {
|
|
c.gate.notify_one();
|
|
debug!(cascade = %cascade_id, "Took request context (gate signaled)");
|
|
}
|
|
ctx
|
|
}
|
|
|
|
/// Take the most recently registered request context (by creation time).
|
|
/// Fallback when cascade_id can't be extracted from the Google API request.
|
|
pub async fn take_latest_request(&self) -> Option<RequestContext> {
|
|
let mut pending = self.pending_requests.write().await;
|
|
if pending.is_empty() {
|
|
return None;
|
|
}
|
|
// Find the most recently created request
|
|
let latest_key = pending
|
|
.iter()
|
|
.max_by_key(|(_, ctx)| ctx.created_at)
|
|
.map(|(k, _)| k.clone());
|
|
if let Some(key) = latest_key {
|
|
let ctx = pending.remove(&key);
|
|
if let Some(ref c) = ctx {
|
|
c.gate.notify_one();
|
|
debug!(cascade = %key, "Took latest request context (fallback, gate signaled)");
|
|
}
|
|
ctx
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
/// Update a request context in-place. Returns false if not found.
|
|
pub async fn update_request<F>(&self, cascade_id: &str, updater: F) -> bool
|
|
where
|
|
F: FnOnce(&mut RequestContext),
|
|
{
|
|
let mut map = self.pending_requests.write().await;
|
|
if let Some(ctx) = map.get_mut(cascade_id) {
|
|
updater(ctx);
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
/// Remove a request context (cleanup after response is complete).
|
|
pub async fn remove_request(&self, cascade_id: &str) {
|
|
if self
|
|
.pending_requests
|
|
.write()
|
|
.await
|
|
.remove(cascade_id)
|
|
.is_some()
|
|
{
|
|
debug!(cascade = %cascade_id, "Removed request context");
|
|
}
|
|
}
|
|
|
|
// ── Cascade cache (turn 0 context for re-injection on turn 1+) ──────
|
|
|
|
/// Cache the essential context from turn 0 so it can be re-used on
|
|
/// subsequent turns of the same cascade.
|
|
pub async fn cache_cascade(&self, cascade_id: &str, cache: CascadeCache) {
|
|
debug!(cascade = %cascade_id, user_text_len = cache.user_text.len(),
|
|
has_tools = cache.tools.is_some(),
|
|
"Cached cascade context for subsequent turns");
|
|
self.cascade_cache
|
|
.write()
|
|
.await
|
|
.insert(cascade_id.to_string(), cache);
|
|
}
|
|
|
|
/// Get cached context for a cascade (non-consuming — needed on every turn).
|
|
pub async fn get_cascade_cache(&self, cascade_id: &str) -> Option<CascadeCache> {
|
|
self.cascade_cache.read().await.get(cascade_id).cloned()
|
|
}
|
|
|
|
/// Check if a cascade has been processed (turn 0 complete).
|
|
pub async fn has_cascade_cache(&self, cascade_id: &str) -> bool {
|
|
self.cascade_cache.read().await.contains_key(cascade_id)
|
|
}
|
|
|
|
// ── Usage recording ──────────────────────────────────────────────────
|
|
|
|
/// Record a completed API exchange with usage data.
|
|
pub async fn record_usage(&self, cascade_id: Option<&str>, usage: ApiUsage) {
|
|
debug!(
|
|
input = usage.input_tokens,
|
|
output = usage.output_tokens,
|
|
cache_read = usage.cache_read_input_tokens,
|
|
cache_create = usage.cache_creation_input_tokens,
|
|
thinking = usage.thinking_output_tokens,
|
|
response = usage.response_output_tokens,
|
|
model = ?usage.model,
|
|
provider = ?usage.api_provider,
|
|
grpc = ?usage.grpc_method,
|
|
"MITM captured API usage"
|
|
);
|
|
|
|
// Update aggregate stats
|
|
{
|
|
let mut stats = self.stats.write().await;
|
|
stats.total_requests += 1;
|
|
stats.total_input_tokens += usage.input_tokens;
|
|
stats.total_output_tokens += usage.output_tokens;
|
|
stats.total_cache_read_tokens += usage.cache_read_input_tokens;
|
|
stats.total_cache_creation_tokens += usage.cache_creation_input_tokens;
|
|
stats.total_thinking_output_tokens += usage.thinking_output_tokens;
|
|
stats.total_response_output_tokens += usage.response_output_tokens;
|
|
|
|
// Per-model breakdown
|
|
if let Some(ref model_name) = usage.model {
|
|
let model_stats = stats.per_model.entry(model_name.clone()).or_default();
|
|
model_stats.requests += 1;
|
|
model_stats.input_tokens += usage.input_tokens;
|
|
model_stats.output_tokens += usage.output_tokens;
|
|
model_stats.cache_read_tokens += usage.cache_read_input_tokens;
|
|
model_stats.cache_creation_tokens += usage.cache_creation_input_tokens;
|
|
}
|
|
}
|
|
|
|
// Store latest usage for the cascade (if we can identify it).
|
|
//
|
|
// Merge logic for v1internal thinking summaries:
|
|
// The LS makes TWO Google API calls per thinking request:
|
|
// Call 1: response + thinking token count (thinking_output_tokens > 0, no thinking text)
|
|
// Call 2: thinking summary text (thinking_output_tokens == 0, response_text has the summary)
|
|
//
|
|
// When Call 2 arrives, we merge its response_text as thinking_text into Call 1's usage.
|
|
let key = cascade_id.map_or_else(|| "_latest".to_string(), |s| s.to_string());
|
|
let mut latest = self.latest_usage.write().await;
|
|
|
|
if let Some(existing) = latest.get_mut(&key) {
|
|
if existing.thinking_output_tokens > 0
|
|
&& existing.thinking_text.is_none()
|
|
&& usage.thinking_output_tokens == 0
|
|
&& usage.response_text.is_some()
|
|
{
|
|
// Call 2: thinking summary — merge into existing Call 1 usage
|
|
existing.thinking_text = usage.response_text;
|
|
debug!(
|
|
thinking_text_len = existing.thinking_text.as_ref().map_or(0, |t| t.len()),
|
|
"MITM: merged thinking summary text into existing usage"
|
|
);
|
|
} else {
|
|
// Normal case: replace existing usage
|
|
latest.insert(key, usage);
|
|
}
|
|
} else {
|
|
latest.insert(key, usage);
|
|
}
|
|
|
|
// Evict old entries to prevent unbounded memory growth
|
|
const MAX_ENTRIES: usize = 500;
|
|
if latest.len() > MAX_ENTRIES {
|
|
let oldest_key = latest
|
|
.iter()
|
|
.min_by_key(|(_, v)| v.captured_at)
|
|
.map(|(k, _)| k.clone());
|
|
if let Some(key) = oldest_key {
|
|
latest.remove(&key);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Peek at usage data for a cascade without consuming it.
|
|
pub async fn peek_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
|
|
let latest = self.latest_usage.read().await;
|
|
latest.get(cascade_id).cloned()
|
|
}
|
|
|
|
/// Only returns exact cascade_id matches — no cross-cascade fallback.
|
|
pub async fn take_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
|
|
let mut latest = self.latest_usage.write().await;
|
|
latest.remove(cascade_id)
|
|
}
|
|
|
|
/// Get aggregate stats.
|
|
pub async fn stats(&self) -> MitmStats {
|
|
self.stats.read().await.clone()
|
|
}
|
|
|
|
// ── Function call capture ────────────────────────────────────────────
|
|
|
|
/// Record a captured function call from Google's response.
|
|
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
|
|
let key = cascade_id.map_or_else(|| "_latest".to_string(), |s| s.to_string());
|
|
info!(
|
|
cascade = %key,
|
|
tool = %fc.name,
|
|
args = %fc.args,
|
|
"MITM store: captured functionCall"
|
|
);
|
|
let mut pending = self.pending_function_calls.write().await;
|
|
pending.entry(key).or_default().push(fc);
|
|
}
|
|
|
|
/// Take pending function calls for a specific cascade.
|
|
///
|
|
/// Priority: exact cascade_id → `_latest` → any key.
|
|
pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> {
|
|
let mut pending = self.pending_function_calls.write().await;
|
|
|
|
// 1. Exact cascade match
|
|
if let Some(result) = pending.remove(cascade_id) {
|
|
return Some(result);
|
|
}
|
|
|
|
// 2. Fallback to _latest
|
|
if let Some(result) = pending.remove("_latest") {
|
|
return Some(result);
|
|
}
|
|
|
|
// 3. Last resort: any key
|
|
if let Some(key) = pending.keys().next().cloned() {
|
|
return pending.remove(&key);
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
/// Take any pending function calls (ignoring cascade ID).
|
|
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
|
|
let mut pending = self.pending_function_calls.write().await;
|
|
let result = pending.remove("_latest");
|
|
if result.is_some() {
|
|
return result;
|
|
}
|
|
if let Some(key) = pending.keys().next().cloned() {
|
|
return pending.remove(&key);
|
|
}
|
|
None
|
|
}
|
|
|
|
/// Peek at the thought_signatures of recently captured function calls.
|
|
/// Returns a map of function_name → thought_signature (non-destructive).
|
|
pub async fn peek_thought_signatures(&self) -> std::collections::HashMap<String, String> {
|
|
let pending = self.pending_function_calls.read().await;
|
|
let mut sigs = std::collections::HashMap::new();
|
|
for calls in pending.values() {
|
|
for fc in calls {
|
|
if let Some(ref sig) = fc.thought_signature {
|
|
sigs.insert(fc.name.clone(), sig.clone());
|
|
}
|
|
}
|
|
}
|
|
sigs
|
|
}
|
|
|
|
// ── Legacy direct response capture (search.rs fallback) ──────────────
|
|
|
|
/// Set (replace) the captured response text.
|
|
pub async fn set_response_text(&self, text: &str) {
|
|
*self.captured_response_text.write().await = Some(text.to_string());
|
|
}
|
|
|
|
/// Take the captured response text (consumes it).
|
|
pub async fn take_response_text(&self) -> Option<String> {
|
|
self.captured_response_text.write().await.take()
|
|
}
|
|
|
|
/// Clear stale legacy response state.
|
|
pub async fn clear_response_async(&self) {
|
|
*self.captured_response_text.write().await = None;
|
|
}
|
|
|
|
// ── Grounding metadata capture ──────────────────────────────────────
|
|
|
|
/// Store captured grounding metadata from API response.
|
|
pub async fn set_grounding(&self, meta: serde_json::Value) {
|
|
*self.captured_grounding.write().await = Some(meta);
|
|
}
|
|
|
|
/// Take (consume) captured grounding metadata.
|
|
#[allow(dead_code)]
|
|
pub async fn take_grounding(&self) -> Option<serde_json::Value> {
|
|
self.captured_grounding.write().await.take()
|
|
}
|
|
|
|
/// Peek at grounding metadata without consuming.
|
|
#[allow(dead_code)]
|
|
pub async fn peek_grounding(&self) -> Option<serde_json::Value> {
|
|
self.captured_grounding.read().await.clone()
|
|
}
|
|
|
|
// ── Compat shims for streaming tool-call loops ──────────────────────
|
|
|
|
/// Update the event channel on an existing request context,
|
|
/// or re-register a minimal context if it was already consumed by `take_request`.
|
|
///
|
|
/// This is critical for thinking-only intermediate responses: the MITM proxy
|
|
/// consumes the context via `take_request`, but the handler needs to re-install
|
|
/// a channel for the LS's follow-up request.
|
|
pub async fn set_channel(&self, cascade_id: &str, tx: mpsc::Sender<MitmEvent>) {
|
|
let updated = self
|
|
.update_request(cascade_id, |ctx| {
|
|
ctx.event_channel = tx.clone();
|
|
})
|
|
.await;
|
|
if !updated {
|
|
// Context was already consumed — re-register a minimal one
|
|
// so the MITM proxy can match the follow-up request.
|
|
let gate = std::sync::Arc::new(tokio::sync::Notify::new());
|
|
self.register_request(RequestContext {
|
|
cascade_id: cascade_id.to_string(),
|
|
pending_user_text: String::new(),
|
|
event_channel: tx,
|
|
generation_params: None,
|
|
pending_image: None,
|
|
tools: None,
|
|
tool_config: None,
|
|
pending_tool_results: Vec::new(),
|
|
tool_rounds: Vec::new(),
|
|
last_function_calls: Vec::new(),
|
|
call_id_to_name: std::collections::HashMap::new(),
|
|
created_at: std::time::Instant::now(),
|
|
gate,
|
|
trace_handle: None,
|
|
trace_turn: 0,
|
|
})
|
|
.await;
|
|
tracing::debug!(
|
|
cascade = cascade_id,
|
|
"set_channel: re-registered minimal context (original was consumed)"
|
|
);
|
|
}
|
|
}
|
|
|
|
/// No-op. Upstream errors are now delivered through the event channel.
|
|
/// Kept for API handler compatibility.
|
|
pub async fn clear_upstream_error(&self) {
|
|
// Intentionally empty — errors flow through MitmEvent::UpstreamError
|
|
}
|
|
|
|
/// Returns None. Upstream errors are now captured and delivered via the
|
|
/// per-request event channel rather than stored globally.
|
|
pub async fn take_upstream_error(&self) -> Option<UpstreamError> {
|
|
None
|
|
}
|
|
|
|
/// Store a call_id → function_name mapping in the request context.
|
|
/// Used by streaming tool-call loops when the model returns function calls.
|
|
pub async fn register_call_id(&self, cascade_id: &str, call_id: String, name: String) {
|
|
self.update_request(cascade_id, |ctx| {
|
|
ctx.call_id_to_name.insert(call_id, name);
|
|
})
|
|
.await;
|
|
}
|
|
}
|