refactor: decompose large functions and remove dead code
- Decompose modify_request() into 7 single-responsibility helpers
- Decompose handle_http_over_tls(): extract read_full_request, dispatch_stream_events
- Promote connect_upstream/resolve_upstream to module-level functions
- Split standalone.rs (1238 lines) into 4 submodules: standalone/mod.rs, spawn.rs, discovery.rs, stub.rs
- Extract proto wire primitives into proto/wire.rs
- Remove 6 dead MitmStore methods
- Remove dead SessionResult, DEFAULT_SESSION, get_or_create
- Remove dead decode_varint_at, extract_conversation_id
- Clean all unused imports across 10 files
- Suppress structural dead_code warnings on deserialization fields

Warnings: 20 -> 0. All 43 tests pass.
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
//! Shared store for intercepted API usage data.
|
||||
//!
|
||||
//! The MITM proxy writes usage data here; the API handlers read from it.
|
||||
//! When custom tools are active, the MITM proxy sends real-time events
|
||||
//! through a channel instead of writing to shared state.
|
||||
//! Per-request state is stored in `RequestContext`, keyed by cascade ID.
|
||||
//! The MITM proxy looks up the context when intercepting LS requests,
|
||||
//! enabling concurrent request processing without global locks.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tracing::{debug, info};
|
||||
|
||||
@@ -52,6 +52,10 @@ pub struct ApiUsage {
|
||||
pub struct CapturedFunctionCall {
|
||||
/// Function name; emitted as the `tool` log field when the call is recorded.
pub name: String,
|
||||
/// Call arguments, held as an arbitrary JSON value.
pub args: serde_json::Value,
|
||||
/// Google's thought signature — required when injecting functionCall back
|
||||
/// into conversation history. Without it, Google returns INVALID_ARGUMENT.
|
||||
// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub thought_signature: Option<String>,
|
||||
/// Capture time. NOTE(review): units/epoch are not visible in this view —
/// presumably a Unix timestamp; confirm against the code that fills it in.
pub captured_at: u64,
|
||||
}
|
||||
|
||||
@@ -128,6 +132,25 @@ pub struct GenerationParams {
|
||||
pub google_search: bool,
|
||||
}
|
||||
|
||||
/// Cached context from turn 0 of a cascade.
|
||||
///
|
||||
/// On the first turn, the MITM proxy consumes the `RequestContext` and builds
|
||||
/// a `ToolContext`. On subsequent turns (tool-call loops), the `RequestContext`
|
||||
/// is gone. This cache stores the essential fields so we can rebuild a lite
|
||||
/// `ToolContext` on every turn — ensuring the model always sees the real user
|
||||
/// text and has access to custom tools.
|
||||
// `Clone` is needed because `get_cascade_cache` hands out a cloned copy of
// the stored entry rather than removing it.
#[derive(Debug, Clone)]
|
||||
pub struct CascadeCache {
|
||||
/// The real user text (used to replace the "." dot prompt).
|
||||
// `cache_cascade` logs only the length of this text, not its content.
pub user_text: String,
|
||||
/// Custom tool definitions (Gemini format).
|
||||
// `cache_cascade` logs whether this is `Some` as the `has_tools` field.
pub tools: Option<Vec<serde_json::Value>>,
|
||||
/// Custom tool config.
|
||||
pub tool_config: Option<serde_json::Value>,
|
||||
/// Client generation parameters.
|
||||
pub generation_params: Option<GenerationParams>,
|
||||
}
|
||||
|
||||
// ─── Channel-based event pipeline ────────────────────────────────────────────
|
||||
|
||||
/// Events sent from the MITM proxy to API handlers through a per-request channel.
|
||||
@@ -146,15 +169,53 @@ pub enum MitmEvent {
|
||||
/// Google API returned an error.
|
||||
UpstreamError(UpstreamError),
|
||||
/// Grounding metadata (search results) from the response.
|
||||
#[allow(dead_code)]
|
||||
Grounding(serde_json::Value),
|
||||
/// Token usage data from the response.
|
||||
Usage(ApiUsage),
|
||||
}
|
||||
|
||||
// ─── Per-request context ─────────────────────────────────────────────────────
|
||||
|
||||
/// All per-request state. Keyed by cascade ID in `MitmStore.pending_requests`.
|
||||
///
|
||||
/// API handlers build this before `send_message`, and the MITM proxy consumes
|
||||
/// it when the LS's outbound request is intercepted.
|
||||
#[derive(Debug)]
|
||||
pub struct RequestContext {
|
||||
/// Cascade ID this context belongs to.
|
||||
// `register_request` clones this value to use as the map key.
pub cascade_id: String,
|
||||
/// Real user text for MITM injection (LS receives "." instead).
|
||||
pub pending_user_text: String,
|
||||
/// Event channel for real-time streaming from MITM → API handler.
|
||||
/// Only present when custom tools are active.
|
||||
// Overwritten in place by the per-cascade `set_channel` method.
pub event_channel: Option<mpsc::Sender<MitmEvent>>,
|
||||
/// Client-specified generation parameters (temperature, top_p, etc.).
|
||||
pub generation_params: Option<GenerationParams>,
|
||||
/// Image to inject into the Google API request.
|
||||
pub pending_image: Option<PendingImage>,
|
||||
/// Gemini-format tool declarations for MITM injection.
|
||||
pub tools: Option<Vec<serde_json::Value>>,
|
||||
/// Gemini-format toolConfig.
|
||||
pub tool_config: Option<serde_json::Value>,
|
||||
/// Pending tool results to inject as functionResponse.
|
||||
pub pending_tool_results: Vec<PendingToolResult>,
|
||||
/// Multi-round tool call history for history rewriting.
|
||||
pub tool_rounds: Vec<ToolRound>,
|
||||
/// Last captured function calls for history rewriting.
|
||||
pub last_function_calls: Vec<CapturedFunctionCall>,
|
||||
/// Mapping call_id → function name for tool result routing.
|
||||
// Entries are inserted by the per-cascade `register_call_id` method.
pub call_id_to_name: HashMap<String, String>,
|
||||
/// When this context was created (for TTL cleanup).
|
||||
// `take_latest_request` also uses this to pick the newest pending context.
pub created_at: Instant,
|
||||
}
|
||||
|
||||
// ─── MitmStore ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Thread-safe store for intercepted data.
|
||||
///
|
||||
/// Keyed by a unique request ID that we can correlate with cascade operations.
|
||||
/// In practice, we use the cascade ID + a sequence number.
|
||||
/// Per-request state lives in `pending_requests`, keyed by cascade ID.
|
||||
/// Global state (usage stats, function call capture) remains shared.
|
||||
#[derive(Clone)]
|
||||
pub struct MitmStore {
|
||||
/// Most recent usage per cascade ID.
|
||||
@@ -163,62 +224,24 @@ pub struct MitmStore {
|
||||
stats: Arc<RwLock<MitmStats>>,
|
||||
/// Pending function calls captured from Google responses.
|
||||
/// Key: cascade hint or "_latest". Value: list of function calls.
|
||||
/// Used by the non-tool LS path (normal sync responses).
|
||||
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
|
||||
/// Set when the MITM forwards the first LLM request with custom tools.
|
||||
/// Blocks ALL subsequent LS requests until the API handler clears it.
|
||||
request_in_flight: Arc<AtomicBool>,
|
||||
|
||||
// ── Channel-based event pipeline (replaces old polling) ──────────────
|
||||
/// Active channel sender for the current tool-path request.
|
||||
/// When present, the MITM proxy sends events through this instead of
|
||||
/// writing to shared state. The channel's existence = request in-flight.
|
||||
active_channel: Arc<RwLock<Option<mpsc::Sender<MitmEvent>>>>,
|
||||
// ── Per-request state (keyed by cascade ID) ──────────────────────────
|
||||
/// Active request contexts. API handlers register before send_message,
|
||||
/// MITM proxy consumes when intercepting the LS request.
|
||||
pending_requests: Arc<RwLock<HashMap<String, RequestContext>>>,
|
||||
|
||||
// ── Tool call support ────────────────────────────────────────────────
|
||||
/// Active tool definitions (Gemini format) for MITM injection.
|
||||
active_tools: Arc<RwLock<Option<Vec<serde_json::Value>>>>,
|
||||
/// Active tool config (Gemini toolConfig format).
|
||||
active_tool_config: Arc<RwLock<Option<serde_json::Value>>>,
|
||||
/// Pending tool results for MITM to inject as functionResponse.
|
||||
pending_tool_results: Arc<RwLock<Vec<PendingToolResult>>>,
|
||||
/// Mapping call_id → function name for tool result routing.
|
||||
call_id_to_name: Arc<RwLock<HashMap<String, String>>>,
|
||||
/// Last captured function calls (for conversation history rewriting).
|
||||
last_function_calls: Arc<RwLock<Vec<CapturedFunctionCall>>>,
|
||||
/// Multi-round tool call history for correct per-turn history rewriting.
|
||||
/// Set by completions/responses handler, consumed by modify_request.
|
||||
tool_rounds: Arc<RwLock<Vec<ToolRound>>>,
|
||||
|
||||
// ── Cascade correlation ──────────────────────────────────────────────
|
||||
/// Active cascade ID set by the API layer before sending a message.
|
||||
/// Used by the MITM proxy to correlate intercepted traffic to cascades.
|
||||
active_cascade_id: Arc<RwLock<Option<String>>>,
|
||||
/// Cached context from turn 0, keyed by cascade ID.
|
||||
/// Used to rebuild ToolContext on subsequent turns of the same cascade.
|
||||
cascade_cache: Arc<RwLock<HashMap<String, CascadeCache>>>,
|
||||
|
||||
// ── Legacy direct response capture (used by search.rs) ───────────────
|
||||
/// Captured response text from MITM. Used as fallback by search endpoint.
|
||||
captured_response_text: Arc<RwLock<Option<String>>>,
|
||||
|
||||
// ── Generation parameters for MITM injection ─────────────────────────
|
||||
/// Client-specified sampling parameters to inject into Google API requests.
|
||||
generation_params: Arc<RwLock<Option<GenerationParams>>>,
|
||||
|
||||
// ── Grounding metadata capture ──────────────────────────────────────
|
||||
/// Captured grounding metadata from Google API responses (search results).
|
||||
captured_grounding: Arc<RwLock<Option<serde_json::Value>>>,
|
||||
|
||||
// ── Pending image for MITM injection ─────────────────────────────────
|
||||
/// Image to inject into the next Google API request via MITM.
|
||||
pending_image: Arc<RwLock<Option<PendingImage>>>,
|
||||
|
||||
// ── Upstream error capture (legacy, used when no channel) ────────────
|
||||
/// Error from Google's API, captured by MITM for forwarding to client.
|
||||
upstream_error: Arc<RwLock<Option<UpstreamError>>>,
|
||||
|
||||
// ── Standard LS input: real user text for MITM injection ─────────────
|
||||
/// The real user text to inject into the Google API request.
|
||||
/// API handlers store this before sending a dummy prompt to the LS.
|
||||
pending_user_text: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
/// Aggregate statistics across all intercepted traffic.
|
||||
@@ -251,24 +274,106 @@ impl MitmStore {
|
||||
latest_usage: Arc::new(RwLock::new(HashMap::new())),
|
||||
stats: Arc::new(RwLock::new(MitmStats::default())),
|
||||
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
|
||||
request_in_flight: Arc::new(AtomicBool::new(false)),
|
||||
active_channel: Arc::new(RwLock::new(None)),
|
||||
active_tools: Arc::new(RwLock::new(None)),
|
||||
active_tool_config: Arc::new(RwLock::new(None)),
|
||||
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
|
||||
call_id_to_name: Arc::new(RwLock::new(HashMap::new())),
|
||||
last_function_calls: Arc::new(RwLock::new(Vec::new())),
|
||||
tool_rounds: Arc::new(RwLock::new(Vec::new())),
|
||||
active_cascade_id: Arc::new(RwLock::new(None)),
|
||||
pending_requests: Arc::new(RwLock::new(HashMap::new())),
|
||||
cascade_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
captured_response_text: Arc::new(RwLock::new(None)),
|
||||
generation_params: Arc::new(RwLock::new(None)),
|
||||
captured_grounding: Arc::new(RwLock::new(None)),
|
||||
pending_image: Arc::new(RwLock::new(None)),
|
||||
upstream_error: Arc::new(RwLock::new(None)),
|
||||
pending_user_text: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Per-request context management ───────────────────────────────────
|
||||
|
||||
/// Register a request context for a cascade. Called by API handlers
|
||||
/// before `send_message` so the MITM proxy can find it.
|
||||
pub async fn register_request(&self, ctx: RequestContext) {
|
||||
let cascade_id = ctx.cascade_id.clone();
|
||||
info!(cascade = %cascade_id, "Registered request context");
|
||||
self.pending_requests.write().await.insert(cascade_id, ctx);
|
||||
}
|
||||
|
||||
/// Take (consume) the request context for a cascade.
|
||||
/// Called by the MITM proxy when intercepting the LS's outbound request.
|
||||
pub async fn take_request(&self, cascade_id: &str) -> Option<RequestContext> {
|
||||
let ctx = self.pending_requests.write().await.remove(cascade_id);
|
||||
if ctx.is_some() {
|
||||
debug!(cascade = %cascade_id, "Took request context");
|
||||
}
|
||||
ctx
|
||||
}
|
||||
|
||||
/// Take the most recently registered request context (by creation time).
|
||||
/// Fallback when cascade_id can't be extracted from the Google API request.
|
||||
pub async fn take_latest_request(&self) -> Option<RequestContext> {
|
||||
let mut pending = self.pending_requests.write().await;
|
||||
if pending.is_empty() {
|
||||
return None;
|
||||
}
|
||||
// Find the most recently created request
|
||||
let latest_key = pending
|
||||
.iter()
|
||||
.max_by_key(|(_, ctx)| ctx.created_at)
|
||||
.map(|(k, _)| k.clone());
|
||||
if let Some(key) = latest_key {
|
||||
let ctx = pending.remove(&key);
|
||||
if ctx.is_some() {
|
||||
debug!(cascade = %key, "Took latest request context (fallback)");
|
||||
}
|
||||
ctx
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Update a request context in-place. Returns false if not found.
|
||||
pub async fn update_request<F>(&self, cascade_id: &str, updater: F) -> bool
|
||||
where
|
||||
F: FnOnce(&mut RequestContext),
|
||||
{
|
||||
let mut map = self.pending_requests.write().await;
|
||||
if let Some(ctx) = map.get_mut(cascade_id) {
|
||||
updater(ctx);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a request context (cleanup after response is complete).
|
||||
pub async fn remove_request(&self, cascade_id: &str) {
|
||||
if self.pending_requests.write().await.remove(cascade_id).is_some() {
|
||||
debug!(cascade = %cascade_id, "Removed request context");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
// ── Cascade cache (turn 0 context for re-injection on turn 1+) ──────
|
||||
|
||||
/// Cache the essential context from turn 0 so it can be re-used on
|
||||
/// subsequent turns of the same cascade.
|
||||
pub async fn cache_cascade(&self, cascade_id: &str, cache: CascadeCache) {
|
||||
debug!(cascade = %cascade_id, user_text_len = cache.user_text.len(),
|
||||
has_tools = cache.tools.is_some(),
|
||||
"Cached cascade context for subsequent turns");
|
||||
self.cascade_cache.write().await.insert(cascade_id.to_string(), cache);
|
||||
}
|
||||
|
||||
/// Get cached context for a cascade (non-consuming — needed on every turn).
|
||||
pub async fn get_cascade_cache(&self, cascade_id: &str) -> Option<CascadeCache> {
|
||||
self.cascade_cache.read().await.get(cascade_id).cloned()
|
||||
}
|
||||
|
||||
/// Check if a cascade has been processed (turn 0 complete).
|
||||
pub async fn has_cascade_cache(&self, cascade_id: &str) -> bool {
|
||||
self.cascade_cache.read().await.contains_key(cascade_id)
|
||||
}
|
||||
|
||||
|
||||
|
||||
// ── Usage recording ──────────────────────────────────────────────────
|
||||
|
||||
/// Record a completed API exchange with usage data.
|
||||
pub async fn record_usage(&self, cascade_id: Option<&str>, usage: ApiUsage) {
|
||||
debug!(
|
||||
@@ -314,13 +419,7 @@ impl MitmStore {
|
||||
// Call 2: thinking summary text (thinking_output_tokens == 0, response_text has the summary)
|
||||
//
|
||||
// When Call 2 arrives, we merge its response_text as thinking_text into Call 1's usage.
|
||||
let key = if let Some(cid) = cascade_id {
|
||||
cid.to_string()
|
||||
} else if let Some(active) = self.active_cascade_id.read().await.as_ref() {
|
||||
active.clone()
|
||||
} else {
|
||||
"_latest".to_string()
|
||||
};
|
||||
let key = cascade_id.map_or_else(|| "_latest".to_string(), |s| s.to_string());
|
||||
let mut latest = self.latest_usage.write().await;
|
||||
|
||||
if let Some(existing) = latest.get_mut(&key) {
|
||||
@@ -346,7 +445,6 @@ impl MitmStore {
|
||||
// Evict old entries to prevent unbounded memory growth
|
||||
const MAX_ENTRIES: usize = 500;
|
||||
if latest.len() > MAX_ENTRIES {
|
||||
// Find the oldest entry by captured_at and remove it
|
||||
let oldest_key = latest
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.captured_at)
|
||||
@@ -357,18 +455,13 @@ impl MitmStore {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the latest usage for a cascade, consuming it (one-shot read).
|
||||
///
|
||||
/// Peek at usage data for a cascade without consuming it.
|
||||
/// Used to check if thinking text has been merged before taking.
|
||||
pub async fn peek_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
|
||||
let latest = self.latest_usage.read().await;
|
||||
latest.get(cascade_id).cloned()
|
||||
}
|
||||
|
||||
/// Only returns exact cascade_id matches — no cross-cascade fallback.
|
||||
/// The `_latest` key is only consumed when the caller explicitly requests it
|
||||
/// (i.e., when the MITM couldn't identify the cascade).
|
||||
pub async fn take_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
|
||||
let mut latest = self.latest_usage.write().await;
|
||||
latest.remove(cascade_id)
|
||||
@@ -379,19 +472,11 @@ impl MitmStore {
|
||||
self.stats.read().await.clone()
|
||||
}
|
||||
|
||||
// ── Function call capture ────────────────────────────────────────────
|
||||
|
||||
/// Record a captured function call from Google's response.
|
||||
///
|
||||
/// Falls back to `active_cascade_id` (set by the API handler) when no
|
||||
/// cascade hint is available from the request body, matching
|
||||
/// `record_usage`'s fallback behavior for consistent correlation.
|
||||
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
|
||||
let key = if let Some(cid) = cascade_id {
|
||||
cid.to_string()
|
||||
} else if let Some(active) = self.active_cascade_id.read().await.as_ref() {
|
||||
active.clone()
|
||||
} else {
|
||||
"_latest".to_string()
|
||||
};
|
||||
let key = cascade_id.map_or_else(|| "_latest".to_string(), |s| s.to_string());
|
||||
info!(
|
||||
cascade = %key,
|
||||
tool = %fc.name,
|
||||
@@ -404,9 +489,7 @@ impl MitmStore {
|
||||
|
||||
/// Take pending function calls for a specific cascade.
|
||||
///
|
||||
/// Priority: exact cascade_id → active_cascade_id → `_latest` → any key.
|
||||
/// This prevents cross-cascade contamination when multiple requests are
|
||||
/// in-flight simultaneously.
|
||||
/// Priority: exact cascade_id → `_latest` → any key.
|
||||
pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> {
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
|
||||
@@ -415,21 +498,12 @@ impl MitmStore {
|
||||
return Some(result);
|
||||
}
|
||||
|
||||
// 2. Active cascade (set by API handler)
|
||||
if let Some(active) = self.active_cascade_id.read().await.as_ref() {
|
||||
if active != cascade_id {
|
||||
if let Some(result) = pending.remove(active.as_str()) {
|
||||
return Some(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Fallback to _latest
|
||||
// 2. Fallback to _latest
|
||||
if let Some(result) = pending.remove("_latest") {
|
||||
return Some(result);
|
||||
}
|
||||
|
||||
// 4. Last resort: any key
|
||||
// 3. Last resort: any key
|
||||
if let Some(key) = pending.keys().next().cloned() {
|
||||
return pending.remove(&key);
|
||||
}
|
||||
@@ -438,7 +512,6 @@ impl MitmStore {
|
||||
}
|
||||
|
||||
/// Take any pending function calls (ignoring cascade ID).
|
||||
/// Legacy method — prefer `take_function_calls(cascade_id)` for proper correlation.
|
||||
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
let result = pending.remove("_latest");
|
||||
@@ -451,114 +524,24 @@ impl MitmStore {
|
||||
None
|
||||
}
|
||||
|
||||
// ── Channel-based event pipeline ─────────────────────────────────────
|
||||
|
||||
/// Install a channel sender for the current tool-path request.
|
||||
/// The MITM proxy will send events through this channel.
|
||||
pub async fn set_channel(&self, tx: mpsc::Sender<MitmEvent>) {
|
||||
*self.active_channel.write().await = Some(tx);
|
||||
// NOTE: Do NOT set request_in_flight here. The MITM proxy's
|
||||
// try_mark_request_in_flight() is the sole setter — setting it
|
||||
// here causes compare_exchange(false,true) to always fail,
|
||||
// blocking every real LS request.
|
||||
}
|
||||
|
||||
/// Take the active channel sender (used by MITM proxy to grab it).
|
||||
/// Returns None if no channel is active.
|
||||
pub async fn take_channel(&self) -> Option<mpsc::Sender<MitmEvent>> {
|
||||
self.active_channel.write().await.take()
|
||||
}
|
||||
|
||||
|
||||
/// Drop the active channel and clear in-flight state.
|
||||
/// Called when the API handler is done with the current request.
|
||||
pub async fn drop_channel(&self) {
|
||||
*self.active_channel.write().await = None;
|
||||
self.request_in_flight.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
// ── Tool context methods ─────────────────────────────────────────────
|
||||
|
||||
/// Set active tool definitions (already in Gemini format).
|
||||
pub async fn set_tools(&self, tools: Vec<serde_json::Value>) {
|
||||
*self.active_tools.write().await = Some(tools);
|
||||
}
|
||||
|
||||
/// Get active tool definitions.
|
||||
pub async fn get_tools(&self) -> Option<Vec<serde_json::Value>> {
|
||||
self.active_tools.read().await.clone()
|
||||
}
|
||||
|
||||
/// Clear active tool definitions.
|
||||
pub async fn clear_tools(&self) {
|
||||
*self.active_tools.write().await = None;
|
||||
*self.active_tool_config.write().await = None;
|
||||
// Also clear accumulated tool rounds to prevent stale data
|
||||
self.tool_rounds.write().await.clear();
|
||||
}
|
||||
|
||||
/// Set active tool config (Gemini toolConfig format).
|
||||
pub async fn set_tool_config(&self, config: serde_json::Value) {
|
||||
*self.active_tool_config.write().await = Some(config);
|
||||
}
|
||||
|
||||
/// Get active tool config.
|
||||
pub async fn get_tool_config(&self) -> Option<serde_json::Value> {
|
||||
self.active_tool_config.read().await.clone()
|
||||
}
|
||||
|
||||
/// Add a pending tool result for MITM injection.
|
||||
pub async fn add_tool_result(&self, result: PendingToolResult) {
|
||||
info!(name = %result.name, "Storing pending tool result");
|
||||
self.pending_tool_results.write().await.push(result);
|
||||
}
|
||||
|
||||
/// Take (consume) all pending tool results.
|
||||
pub async fn take_tool_results(&self) -> Vec<PendingToolResult> {
|
||||
std::mem::take(&mut *self.pending_tool_results.write().await)
|
||||
}
|
||||
|
||||
/// Register a call_id → function name mapping.
|
||||
pub async fn register_call_id(&self, call_id: String, name: String) {
|
||||
self.call_id_to_name.write().await.insert(call_id, name);
|
||||
}
|
||||
|
||||
/// Look up function name by call_id.
|
||||
pub async fn lookup_call_id(&self, call_id: &str) -> Option<String> {
|
||||
self.call_id_to_name.read().await.get(call_id).cloned()
|
||||
}
|
||||
|
||||
/// Save the last captured function calls (for history rewriting).
|
||||
pub async fn set_last_function_calls(&self, calls: Vec<CapturedFunctionCall>) {
|
||||
*self.last_function_calls.write().await = calls;
|
||||
}
|
||||
|
||||
/// Get the last captured function calls.
|
||||
pub async fn get_last_function_calls(&self) -> Vec<CapturedFunctionCall> {
|
||||
self.last_function_calls.read().await.clone()
|
||||
}
|
||||
|
||||
/// Store multi-round tool call history for correct per-turn history rewriting.
|
||||
pub async fn set_tool_rounds(&self, rounds: Vec<ToolRound>) {
|
||||
*self.tool_rounds.write().await = rounds;
|
||||
}
|
||||
|
||||
/// Take (consume) multi-round tool call history.
|
||||
pub async fn take_tool_rounds(&self) -> Vec<ToolRound> {
|
||||
std::mem::take(&mut *self.tool_rounds.write().await)
|
||||
}
|
||||
|
||||
/// Get (non-destructive clone) multi-round tool call history.
|
||||
/// Used by proxy.rs to read rounds without consuming them, so they
|
||||
/// persist across multiple LS requests in the same cascade.
|
||||
pub async fn get_tool_rounds(&self) -> Vec<ToolRound> {
|
||||
self.tool_rounds.read().await.clone()
|
||||
/// Peek at the thought_signatures of recently captured function calls.
|
||||
/// Returns a map of function_name → thought_signature (non-destructive).
|
||||
pub async fn peek_thought_signatures(&self) -> std::collections::HashMap<String, String> {
|
||||
let pending = self.pending_function_calls.read().await;
|
||||
let mut sigs = std::collections::HashMap::new();
|
||||
for calls in pending.values() {
|
||||
for fc in calls {
|
||||
if let Some(ref sig) = fc.thought_signature {
|
||||
sigs.insert(fc.name.clone(), sig.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
sigs
|
||||
}
|
||||
|
||||
// ── Legacy direct response capture (search.rs fallback) ──────────────
|
||||
|
||||
/// Set (replace) the captured response text.
|
||||
/// Used by MITM proxy for non-channel path (search endpoint fallback).
|
||||
pub async fn set_response_text(&self, text: &str) {
|
||||
*self.captured_response_text.write().await = Some(text.to_string());
|
||||
}
|
||||
@@ -568,71 +551,11 @@ impl MitmStore {
|
||||
self.captured_response_text.write().await.take()
|
||||
}
|
||||
|
||||
/// Clear stale state between requests.
|
||||
/// Drops any active channel and clears in-flight flags.
|
||||
/// Clear stale legacy response state.
|
||||
pub async fn clear_response_async(&self) {
|
||||
self.request_in_flight.store(false, Ordering::SeqCst);
|
||||
*self.active_channel.write().await = None;
|
||||
*self.captured_response_text.write().await = None;
|
||||
}
|
||||
|
||||
/// Atomically try to mark request as in-flight.
|
||||
/// Returns true if this caller won the race (was first to set it).
|
||||
/// Returns false if already in-flight (someone else set it first).
|
||||
pub fn try_mark_request_in_flight(&self) -> bool {
|
||||
self.request_in_flight
|
||||
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
/// Check if a request is currently in-flight.
|
||||
#[allow(dead_code)]
|
||||
pub fn is_request_in_flight(&self) -> bool {
|
||||
self.request_in_flight.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Clear the in-flight flag so the LS can make follow-up requests.
|
||||
pub fn clear_request_in_flight(&self) {
|
||||
self.request_in_flight.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
// ── Cascade correlation ──────────────────────────────────────────────
|
||||
|
||||
/// Set the active cascade ID (called by API handlers before sending a message).
|
||||
/// The MITM proxy will use this to correlate intercepted traffic.
|
||||
pub async fn set_active_cascade(&self, cascade_id: &str) {
|
||||
*self.active_cascade_id.write().await = Some(cascade_id.to_string());
|
||||
}
|
||||
|
||||
/// Get the active cascade ID.
|
||||
#[allow(dead_code)]
|
||||
pub async fn get_active_cascade(&self) -> Option<String> {
|
||||
self.active_cascade_id.read().await.clone()
|
||||
}
|
||||
|
||||
/// Clear the active cascade ID (called after response is complete).
|
||||
#[allow(dead_code)]
|
||||
pub async fn clear_active_cascade(&self) {
|
||||
*self.active_cascade_id.write().await = None;
|
||||
}
|
||||
|
||||
// ── Generation parameters ────────────────────────────────────────────
|
||||
|
||||
/// Store client-specified generation parameters for MITM injection.
|
||||
pub async fn set_generation_params(&self, params: GenerationParams) {
|
||||
*self.generation_params.write().await = Some(params);
|
||||
}
|
||||
|
||||
/// Read current generation parameters (non-consuming).
|
||||
pub async fn get_generation_params(&self) -> Option<GenerationParams> {
|
||||
self.generation_params.read().await.clone()
|
||||
}
|
||||
|
||||
/// Clear generation parameters.
|
||||
pub async fn clear_generation_params(&self) {
|
||||
*self.generation_params.write().await = None;
|
||||
}
|
||||
|
||||
// ── Grounding metadata capture ──────────────────────────────────────
|
||||
|
||||
/// Store captured grounding metadata from API response.
|
||||
@@ -652,46 +575,35 @@ impl MitmStore {
|
||||
self.captured_grounding.read().await.clone()
|
||||
}
|
||||
|
||||
// ── Pending image for MITM injection ─────────────────────────────────
|
||||
// ── Compat shims for streaming tool-call loops ──────────────────────
|
||||
|
||||
/// Store a pending image for MITM injection.
|
||||
pub async fn set_pending_image(&self, image: PendingImage) {
|
||||
*self.pending_image.write().await = Some(image);
|
||||
/// Update the event channel on an existing request context.
|
||||
/// Used by streaming loop handlers when re-registering for a new tool round.
|
||||
pub async fn set_channel(&self, cascade_id: &str, tx: mpsc::Sender<MitmEvent>) {
|
||||
self.update_request(cascade_id, |ctx| {
|
||||
ctx.event_channel = Some(tx);
|
||||
}).await;
|
||||
}
|
||||
|
||||
/// Take (consume) pending image for injection.
|
||||
pub async fn take_pending_image(&self) -> Option<PendingImage> {
|
||||
self.pending_image.write().await.take()
|
||||
}
|
||||
|
||||
// ── Upstream error capture ───────────────────────────────────────────
|
||||
|
||||
/// Store an upstream error from Google's API.
|
||||
pub async fn set_upstream_error(&self, error: UpstreamError) {
|
||||
*self.upstream_error.write().await = Some(error);
|
||||
}
|
||||
|
||||
/// Take (consume) captured upstream error.
|
||||
pub async fn take_upstream_error(&self) -> Option<UpstreamError> {
|
||||
self.upstream_error.write().await.take()
|
||||
}
|
||||
|
||||
/// Clear any stored upstream error.
|
||||
/// No-op. Upstream errors are now delivered through the event channel.
|
||||
/// Kept for API handler compatibility.
|
||||
pub async fn clear_upstream_error(&self) {
|
||||
*self.upstream_error.write().await = None;
|
||||
// Intentionally empty — errors flow through MitmEvent::UpstreamError
|
||||
}
|
||||
|
||||
// ── Pending user text for MITM injection ─────────────────────────────
|
||||
|
||||
/// Store the real user text for MITM injection.
|
||||
/// Called by API handlers before sending a dummy prompt to the LS.
|
||||
pub async fn set_pending_user_text(&self, text: String) {
|
||||
*self.pending_user_text.write().await = Some(text);
|
||||
/// Returns None. Upstream errors are now captured and delivered via the
|
||||
/// per-request event channel rather than stored globally.
|
||||
pub async fn take_upstream_error(&self) -> Option<UpstreamError> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Take (consume) the pending user text.
|
||||
/// Called by the MITM proxy when building ToolContext.
|
||||
pub async fn take_pending_user_text(&self) -> Option<String> {
|
||||
self.pending_user_text.write().await.take()
|
||||
/// Store a call_id → function_name mapping in the request context.
|
||||
/// Used by streaming tool-call loops when the model returns function calls.
|
||||
pub async fn register_call_id(&self, cascade_id: &str, call_id: String, name: String) {
|
||||
self.update_request(cascade_id, |ctx| {
|
||||
ctx.call_id_to_name.insert(call_id, name);
|
||||
}).await;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user