refactor: endpoint parity and proxy improvements

Mixed changes from recent sessions: endpoint feature parity
improvements, proxy bug fixes, and store cleanup.
This commit is contained in:
Nikketryhard
2026-02-16 21:47:00 -06:00
parent 86675fd960
commit 637fbc0e54
5 changed files with 763 additions and 692 deletions

View File

@@ -1,12 +1,14 @@
//! Shared store for intercepted API usage data.
//!
//! The MITM proxy writes usage data here; the API handlers read from it.
//! When custom tools are active, the MITM proxy sends real-time events
//! through a channel instead of writing to shared state.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, info};
/// Token usage from an intercepted API response.
@@ -126,6 +128,29 @@ pub struct GenerationParams {
pub google_search: bool,
}
// ─── Channel-based event pipeline ────────────────────────────────────────────
/// Events sent from the MITM proxy to API handlers through a per-request channel.
/// Replaces the old polling-based approach (shared atomics + RwLocks) with
/// instant, race-free delivery.
#[derive(Debug, Clone)]
pub enum MitmEvent {
/// Incremental thinking/reasoning text from the model.
ThinkingDelta(String),
/// Incremental response text from the model.
TextDelta(String),
/// Model requested function call(s).
FunctionCall(Vec<CapturedFunctionCall>),
/// Response streaming is complete (finishReason received).
ResponseComplete,
/// Google API returned an error.
UpstreamError(UpstreamError),
/// Grounding metadata (search results) from the response.
Grounding(serde_json::Value),
/// Token usage data from the response.
Usage(ApiUsage),
}
/// Thread-safe store for intercepted data.
///
/// Keyed by a unique request ID that we can correlate with cascade operations.
@@ -138,20 +163,17 @@ pub struct MitmStore {
stats: Arc<RwLock<MitmStats>>,
/// Pending function calls captured from Google responses.
/// Key: cascade hint or "_latest". Value: list of function calls.
/// Used by the non-tool LS path (normal sync responses).
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
/// Simple flag: set when a functionCall is captured, cleared when consumed.
/// Used to block follow-up requests regardless of cascade identification.
has_active_function_call: Arc<AtomicBool>,
/// Persistent flag: set when a function call is captured, cleared ONLY when
/// a tool result is submitted. Prevents the LS from making follow-up API
/// calls during the entire tool execution cycle.
awaiting_tool_result: Arc<AtomicBool>,
/// Set when the MITM forwards the first LLM request with custom tools.
/// Blocks ALL subsequent LS requests until the API handler clears it.
request_in_flight: Arc<AtomicBool>,
/// Generation counter — incremented each time a new completions turn starts.
/// Used to discard stale data from leaked LS connections.
request_generation: Arc<AtomicU64>,
// ── Channel-based event pipeline (replaces old polling) ──────────────
/// Active channel sender for the current tool-path request.
/// When present, the MITM proxy sends events through this instead of
/// writing to shared state. The channel's existence = request in-flight.
active_channel: Arc<RwLock<Option<mpsc::Sender<MitmEvent>>>>,
// ── Tool call support ────────────────────────────────────────────────
/// Active tool definitions (Gemini format) for MITM injection.
@@ -173,14 +195,9 @@ pub struct MitmStore {
/// Used by the MITM proxy to correlate intercepted traffic to cascades.
active_cascade_id: Arc<RwLock<Option<String>>>,
// ── Direct response capture (bypasses LS) ────────────────────────────
/// Captured response text from MITM when custom tools are active.
/// The completions/responses handler reads this instead of polling LS steps.
// ── Legacy direct response capture (used by search.rs) ───────────────
/// Captured response text from MITM. Used as fallback by search endpoint.
captured_response_text: Arc<RwLock<Option<String>>>,
/// Captured thinking/reasoning text from MITM (for real-time streaming).
captured_thinking_text: Arc<RwLock<Option<String>>>,
/// Whether the captured response is complete (finishReason received).
response_complete: Arc<AtomicBool>,
// ── Generation parameters for MITM injection ─────────────────────────
/// Client-specified sampling parameters to inject into Google API requests.
@@ -194,9 +211,14 @@ pub struct MitmStore {
/// Image to inject into the next Google API request via MITM.
pending_image: Arc<RwLock<Option<PendingImage>>>,
// ── Upstream error capture ───────────────────────────────────────────
// ── Upstream error capture (legacy, used when no channel) ────────────
/// Error from Google's API, captured by MITM for forwarding to client.
upstream_error: Arc<RwLock<Option<UpstreamError>>>,
// ── Standard LS input: real user text for MITM injection ─────────────
/// The real user text to inject into the Google API request.
/// API handlers store this before sending a dummy prompt to the LS.
pending_user_text: Arc<RwLock<Option<String>>>,
}
/// Aggregate statistics across all intercepted traffic.
@@ -229,10 +251,8 @@ impl MitmStore {
latest_usage: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(MitmStats::default())),
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
has_active_function_call: Arc::new(AtomicBool::new(false)),
awaiting_tool_result: Arc::new(AtomicBool::new(false)),
request_in_flight: Arc::new(AtomicBool::new(false)),
request_generation: Arc::new(AtomicU64::new(0)),
active_channel: Arc::new(RwLock::new(None)),
active_tools: Arc::new(RwLock::new(None)),
active_tool_config: Arc::new(RwLock::new(None)),
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
@@ -241,12 +261,11 @@ impl MitmStore {
tool_rounds: Arc::new(RwLock::new(Vec::new())),
active_cascade_id: Arc::new(RwLock::new(None)),
captured_response_text: Arc::new(RwLock::new(None)),
captured_thinking_text: Arc::new(RwLock::new(None)),
response_complete: Arc::new(AtomicBool::new(false)),
generation_params: Arc::new(RwLock::new(None)),
captured_grounding: Arc::new(RwLock::new(None)),
pending_image: Arc::new(RwLock::new(None)),
upstream_error: Arc::new(RwLock::new(None)),
pending_user_text: Arc::new(RwLock::new(None)),
}
}
@@ -381,30 +400,6 @@ impl MitmStore {
);
let mut pending = self.pending_function_calls.write().await;
pending.entry(key).or_default().push(fc);
self.has_active_function_call.store(true, Ordering::SeqCst);
self.awaiting_tool_result.store(true, Ordering::SeqCst);
}
/// Check if there's an active (unclaimed) function call.
pub fn has_active_function_call(&self) -> bool {
self.has_active_function_call.load(Ordering::SeqCst)
}
/// Force-clear the active function call flag (used to reset stale state).
pub fn clear_active_function_call(&self) {
self.has_active_function_call.store(false, Ordering::SeqCst);
}
/// Check if we're awaiting a tool result (blocks LS follow-up requests).
/// This persists across function call consumption — only cleared when
/// actual tool results are submitted.
pub fn is_awaiting_tool_result(&self) -> bool {
self.awaiting_tool_result.load(Ordering::SeqCst)
}
/// Clear the awaiting-tool-result flag (called when tool results arrive).
pub fn clear_awaiting_tool_result(&self) {
self.awaiting_tool_result.store(false, Ordering::SeqCst);
}
/// Take pending function calls for a specific cascade.
@@ -417,7 +412,6 @@ impl MitmStore {
// 1. Exact cascade match
if let Some(result) = pending.remove(cascade_id) {
self.has_active_function_call.store(false, Ordering::SeqCst);
return Some(result);
}
@@ -425,7 +419,6 @@ impl MitmStore {
if let Some(active) = self.active_cascade_id.read().await.as_ref() {
if active != cascade_id {
if let Some(result) = pending.remove(active.as_str()) {
self.has_active_function_call.store(false, Ordering::SeqCst);
return Some(result);
}
}
@@ -433,17 +426,12 @@ impl MitmStore {
// 3. Fallback to _latest
if let Some(result) = pending.remove("_latest") {
self.has_active_function_call.store(false, Ordering::SeqCst);
return Some(result);
}
// 4. Last resort: any key
if let Some(key) = pending.keys().next().cloned() {
let result = pending.remove(&key);
if result.is_some() {
self.has_active_function_call.store(false, Ordering::SeqCst);
}
return result;
return pending.remove(&key);
}
None
@@ -455,19 +443,40 @@ impl MitmStore {
let mut pending = self.pending_function_calls.write().await;
let result = pending.remove("_latest");
if result.is_some() {
self.has_active_function_call.store(false, Ordering::SeqCst);
return result;
}
if let Some(key) = pending.keys().next().cloned() {
let result = pending.remove(&key);
if result.is_some() {
self.has_active_function_call.store(false, Ordering::SeqCst);
}
return result;
return pending.remove(&key);
}
None
}
// ── Channel-based event pipeline ─────────────────────────────────────
/// Install a channel sender for the current tool-path request.
/// The MITM proxy will send events through this channel.
pub async fn set_channel(&self, tx: mpsc::Sender<MitmEvent>) {
*self.active_channel.write().await = Some(tx);
// NOTE: Do NOT set request_in_flight here. The MITM proxy's
// try_mark_request_in_flight() is the sole setter — setting it
// here causes compare_exchange(false,true) to always fail,
// blocking every real LS request.
}
/// Take the active channel sender (used by MITM proxy to grab it).
/// Returns None if no channel is active.
pub async fn take_channel(&self) -> Option<mpsc::Sender<MitmEvent>> {
self.active_channel.write().await.take()
}
/// Drop the active channel and clear in-flight state.
/// Called when the API handler is done with the current request.
pub async fn drop_channel(&self) {
*self.active_channel.write().await = None;
self.request_in_flight.store(false, Ordering::SeqCst);
}
// ── Tool context methods ─────────────────────────────────────────────
/// Set active tool definitions (already in Gemini format).
@@ -546,10 +555,10 @@ impl MitmStore {
self.tool_rounds.read().await.clone()
}
// ── Direct response capture (bypass LS) ──────────────────────────────
// ── Legacy direct response capture (search.rs fallback) ──────────────
/// Set (replace) the captured response text.
/// Used by MITM proxy for non-channel path (search endpoint fallback).
pub async fn set_response_text(&self, text: &str) {
*self.captured_response_text.write().await = Some(text.to_string());
}
@@ -559,28 +568,12 @@ impl MitmStore {
self.captured_response_text.write().await.take()
}
/// Peek at the captured response text without consuming it.
pub async fn peek_response_text(&self) -> Option<String> {
self.captured_response_text.read().await.clone()
}
/// Mark the response as complete.
pub fn mark_response_complete(&self) {
self.response_complete.store(true, Ordering::SeqCst);
}
/// Check if the response is complete.
pub fn is_response_complete(&self) -> bool {
self.response_complete.load(Ordering::SeqCst)
}
/// Async version of clear_response. Bumps generation counter.
/// Clear stale state between requests.
/// Drops any active channel and clears in-flight flags.
pub async fn clear_response_async(&self) {
self.response_complete.store(false, Ordering::SeqCst);
self.request_in_flight.store(false, Ordering::SeqCst);
self.request_generation.fetch_add(1, Ordering::SeqCst);
*self.active_channel.write().await = None;
*self.captured_response_text.write().await = None;
*self.captured_thinking_text.write().await = None;
}
/// Atomically try to mark request as in-flight.
@@ -593,6 +586,7 @@ impl MitmStore {
}
/// Check if a request is currently in-flight.
#[allow(dead_code)]
pub fn is_request_in_flight(&self) -> bool {
self.request_in_flight.load(Ordering::SeqCst)
}
@@ -602,38 +596,6 @@ impl MitmStore {
self.request_in_flight.store(false, Ordering::SeqCst);
}
/// Reset response_complete so we can wait for the next response.
pub fn clear_response_complete(&self) {
self.response_complete.store(false, Ordering::SeqCst);
}
/// Get current generation number.
pub fn current_generation(&self) -> u64 {
self.request_generation.load(Ordering::SeqCst)
}
/// Bump generation counter (invalidates all pending data from old generation).
pub fn bump_generation(&self) -> u64 {
self.request_generation.fetch_add(1, Ordering::SeqCst) + 1
}
// ── Thinking text capture ────────────────────────────────────────────
/// Set (replace) the captured thinking text.
pub async fn set_thinking_text(&self, text: &str) {
*self.captured_thinking_text.write().await = Some(text.to_string());
}
/// Peek at the captured thinking text without consuming it.
pub async fn peek_thinking_text(&self) -> Option<String> {
self.captured_thinking_text.read().await.clone()
}
/// Take the captured thinking text (consumes it).
pub async fn take_thinking_text(&self) -> Option<String> {
self.captured_thinking_text.write().await.take()
}
// ── Cascade correlation ──────────────────────────────────────────────
/// Set the active cascade ID (called by API handlers before sending a message).
@@ -718,4 +680,18 @@ impl MitmStore {
pub async fn clear_upstream_error(&self) {
*self.upstream_error.write().await = None;
}
// ── Pending user text for MITM injection ─────────────────────────────
/// Store the real user text for MITM injection.
/// Called by API handlers before sending a dummy prompt to the LS.
pub async fn set_pending_user_text(&self, text: String) {
*self.pending_user_text.write().await = Some(text);
}
/// Take (consume) the pending user text.
/// Called by the MITM proxy when building ToolContext.
pub async fn take_pending_user_text(&self) -> Option<String> {
self.pending_user_text.write().await.take()
}
}