refactor: endpoint parity and proxy improvements
Mixed changes from recent sessions: endpoint feature parity improvements, proxy bug fixes, and store cleanup.
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
//! Shared store for intercepted API usage data.
|
||||
//!
|
||||
//! The MITM proxy writes usage data here; the API handlers read from it.
|
||||
//! When custom tools are active, the MITM proxy sends real-time events
|
||||
//! through a channel instead of writing to shared state.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tracing::{debug, info};
|
||||
|
||||
/// Token usage from an intercepted API response.
|
||||
@@ -126,6 +128,29 @@ pub struct GenerationParams {
|
||||
pub google_search: bool,
|
||||
}
|
||||
|
||||
// ─── Channel-based event pipeline ────────────────────────────────────────────
|
||||
|
||||
/// Events sent from the MITM proxy to API handlers through a per-request channel.
|
||||
/// Replaces the old polling-based approach (shared atomics + RwLocks) with
|
||||
/// instant, race-free delivery.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum MitmEvent {
|
||||
/// Incremental thinking/reasoning text from the model.
|
||||
ThinkingDelta(String),
|
||||
/// Incremental response text from the model.
|
||||
TextDelta(String),
|
||||
/// Model requested function call(s).
|
||||
FunctionCall(Vec<CapturedFunctionCall>),
|
||||
/// Response streaming is complete (finishReason received).
|
||||
ResponseComplete,
|
||||
/// Google API returned an error.
|
||||
UpstreamError(UpstreamError),
|
||||
/// Grounding metadata (search results) from the response.
|
||||
Grounding(serde_json::Value),
|
||||
/// Token usage data from the response.
|
||||
Usage(ApiUsage),
|
||||
}
|
||||
|
||||
/// Thread-safe store for intercepted data.
|
||||
///
|
||||
/// Keyed by a unique request ID that we can correlate with cascade operations.
|
||||
@@ -138,20 +163,17 @@ pub struct MitmStore {
|
||||
stats: Arc<RwLock<MitmStats>>,
|
||||
/// Pending function calls captured from Google responses.
|
||||
/// Key: cascade hint or "_latest". Value: list of function calls.
|
||||
/// Used by the non-tool LS path (normal sync responses).
|
||||
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
|
||||
/// Simple flag: set when a functionCall is captured, cleared when consumed.
|
||||
/// Used to block follow-up requests regardless of cascade identification.
|
||||
has_active_function_call: Arc<AtomicBool>,
|
||||
/// Persistent flag: set when a function call is captured, cleared ONLY when
|
||||
/// a tool result is submitted. Prevents the LS from making follow-up API
|
||||
/// calls during the entire tool execution cycle.
|
||||
awaiting_tool_result: Arc<AtomicBool>,
|
||||
/// Set when the MITM forwards the first LLM request with custom tools.
|
||||
/// Blocks ALL subsequent LS requests until the API handler clears it.
|
||||
request_in_flight: Arc<AtomicBool>,
|
||||
/// Generation counter — incremented each time a new completions turn starts.
|
||||
/// Used to discard stale data from leaked LS connections.
|
||||
request_generation: Arc<AtomicU64>,
|
||||
|
||||
// ── Channel-based event pipeline (replaces old polling) ──────────────
|
||||
/// Active channel sender for the current tool-path request.
|
||||
/// When present, the MITM proxy sends events through this instead of
|
||||
/// writing to shared state. The channel's existence = request in-flight.
|
||||
active_channel: Arc<RwLock<Option<mpsc::Sender<MitmEvent>>>>,
|
||||
|
||||
// ── Tool call support ────────────────────────────────────────────────
|
||||
/// Active tool definitions (Gemini format) for MITM injection.
|
||||
@@ -173,14 +195,9 @@ pub struct MitmStore {
|
||||
/// Used by the MITM proxy to correlate intercepted traffic to cascades.
|
||||
active_cascade_id: Arc<RwLock<Option<String>>>,
|
||||
|
||||
// ── Direct response capture (bypasses LS) ────────────────────────────
|
||||
/// Captured response text from MITM when custom tools are active.
|
||||
/// The completions/responses handler reads this instead of polling LS steps.
|
||||
// ── Legacy direct response capture (used by search.rs) ───────────────
|
||||
/// Captured response text from MITM. Used as fallback by search endpoint.
|
||||
captured_response_text: Arc<RwLock<Option<String>>>,
|
||||
/// Captured thinking/reasoning text from MITM (for real-time streaming).
|
||||
captured_thinking_text: Arc<RwLock<Option<String>>>,
|
||||
/// Whether the captured response is complete (finishReason received).
|
||||
response_complete: Arc<AtomicBool>,
|
||||
|
||||
// ── Generation parameters for MITM injection ─────────────────────────
|
||||
/// Client-specified sampling parameters to inject into Google API requests.
|
||||
@@ -194,9 +211,14 @@ pub struct MitmStore {
|
||||
/// Image to inject into the next Google API request via MITM.
|
||||
pending_image: Arc<RwLock<Option<PendingImage>>>,
|
||||
|
||||
// ── Upstream error capture ───────────────────────────────────────────
|
||||
// ── Upstream error capture (legacy, used when no channel) ────────────
|
||||
/// Error from Google's API, captured by MITM for forwarding to client.
|
||||
upstream_error: Arc<RwLock<Option<UpstreamError>>>,
|
||||
|
||||
// ── Standard LS input: real user text for MITM injection ─────────────
|
||||
/// The real user text to inject into the Google API request.
|
||||
/// API handlers store this before sending a dummy prompt to the LS.
|
||||
pending_user_text: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
/// Aggregate statistics across all intercepted traffic.
|
||||
@@ -229,10 +251,8 @@ impl MitmStore {
|
||||
latest_usage: Arc::new(RwLock::new(HashMap::new())),
|
||||
stats: Arc::new(RwLock::new(MitmStats::default())),
|
||||
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
|
||||
has_active_function_call: Arc::new(AtomicBool::new(false)),
|
||||
awaiting_tool_result: Arc::new(AtomicBool::new(false)),
|
||||
request_in_flight: Arc::new(AtomicBool::new(false)),
|
||||
request_generation: Arc::new(AtomicU64::new(0)),
|
||||
active_channel: Arc::new(RwLock::new(None)),
|
||||
active_tools: Arc::new(RwLock::new(None)),
|
||||
active_tool_config: Arc::new(RwLock::new(None)),
|
||||
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
|
||||
@@ -241,12 +261,11 @@ impl MitmStore {
|
||||
tool_rounds: Arc::new(RwLock::new(Vec::new())),
|
||||
active_cascade_id: Arc::new(RwLock::new(None)),
|
||||
captured_response_text: Arc::new(RwLock::new(None)),
|
||||
captured_thinking_text: Arc::new(RwLock::new(None)),
|
||||
response_complete: Arc::new(AtomicBool::new(false)),
|
||||
generation_params: Arc::new(RwLock::new(None)),
|
||||
captured_grounding: Arc::new(RwLock::new(None)),
|
||||
pending_image: Arc::new(RwLock::new(None)),
|
||||
upstream_error: Arc::new(RwLock::new(None)),
|
||||
pending_user_text: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -381,30 +400,6 @@ impl MitmStore {
|
||||
);
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
pending.entry(key).or_default().push(fc);
|
||||
self.has_active_function_call.store(true, Ordering::SeqCst);
|
||||
self.awaiting_tool_result.store(true, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Check if there's an active (unclaimed) function call.
|
||||
pub fn has_active_function_call(&self) -> bool {
|
||||
self.has_active_function_call.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Force-clear the active function call flag (used to reset stale state).
|
||||
pub fn clear_active_function_call(&self) {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Check if we're awaiting a tool result (blocks LS follow-up requests).
|
||||
/// This persists across function call consumption — only cleared when
|
||||
/// actual tool results are submitted.
|
||||
pub fn is_awaiting_tool_result(&self) -> bool {
|
||||
self.awaiting_tool_result.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Clear the awaiting-tool-result flag (called when tool results arrive).
|
||||
pub fn clear_awaiting_tool_result(&self) {
|
||||
self.awaiting_tool_result.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Take pending function calls for a specific cascade.
|
||||
@@ -417,7 +412,6 @@ impl MitmStore {
|
||||
|
||||
// 1. Exact cascade match
|
||||
if let Some(result) = pending.remove(cascade_id) {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
return Some(result);
|
||||
}
|
||||
|
||||
@@ -425,7 +419,6 @@ impl MitmStore {
|
||||
if let Some(active) = self.active_cascade_id.read().await.as_ref() {
|
||||
if active != cascade_id {
|
||||
if let Some(result) = pending.remove(active.as_str()) {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
return Some(result);
|
||||
}
|
||||
}
|
||||
@@ -433,17 +426,12 @@ impl MitmStore {
|
||||
|
||||
// 3. Fallback to _latest
|
||||
if let Some(result) = pending.remove("_latest") {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
return Some(result);
|
||||
}
|
||||
|
||||
// 4. Last resort: any key
|
||||
if let Some(key) = pending.keys().next().cloned() {
|
||||
let result = pending.remove(&key);
|
||||
if result.is_some() {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
}
|
||||
return result;
|
||||
return pending.remove(&key);
|
||||
}
|
||||
|
||||
None
|
||||
@@ -455,19 +443,40 @@ impl MitmStore {
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
let result = pending.remove("_latest");
|
||||
if result.is_some() {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
return result;
|
||||
}
|
||||
if let Some(key) = pending.keys().next().cloned() {
|
||||
let result = pending.remove(&key);
|
||||
if result.is_some() {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
}
|
||||
return result;
|
||||
return pending.remove(&key);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// ── Channel-based event pipeline ─────────────────────────────────────
|
||||
|
||||
/// Install a channel sender for the current tool-path request.
|
||||
/// The MITM proxy will send events through this channel.
|
||||
pub async fn set_channel(&self, tx: mpsc::Sender<MitmEvent>) {
|
||||
*self.active_channel.write().await = Some(tx);
|
||||
// NOTE: Do NOT set request_in_flight here. The MITM proxy's
|
||||
// try_mark_request_in_flight() is the sole setter — setting it
|
||||
// here causes compare_exchange(false,true) to always fail,
|
||||
// blocking every real LS request.
|
||||
}
|
||||
|
||||
/// Take the active channel sender (used by MITM proxy to grab it).
|
||||
/// Returns None if no channel is active.
|
||||
pub async fn take_channel(&self) -> Option<mpsc::Sender<MitmEvent>> {
|
||||
self.active_channel.write().await.take()
|
||||
}
|
||||
|
||||
|
||||
/// Drop the active channel and clear in-flight state.
|
||||
/// Called when the API handler is done with the current request.
|
||||
pub async fn drop_channel(&self) {
|
||||
*self.active_channel.write().await = None;
|
||||
self.request_in_flight.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
// ── Tool context methods ─────────────────────────────────────────────
|
||||
|
||||
/// Set active tool definitions (already in Gemini format).
|
||||
@@ -546,10 +555,10 @@ impl MitmStore {
|
||||
self.tool_rounds.read().await.clone()
|
||||
}
|
||||
|
||||
|
||||
// ── Direct response capture (bypass LS) ──────────────────────────────
|
||||
// ── Legacy direct response capture (search.rs fallback) ──────────────
|
||||
|
||||
/// Set (replace) the captured response text.
|
||||
/// Used by MITM proxy for non-channel path (search endpoint fallback).
|
||||
pub async fn set_response_text(&self, text: &str) {
|
||||
*self.captured_response_text.write().await = Some(text.to_string());
|
||||
}
|
||||
@@ -559,28 +568,12 @@ impl MitmStore {
|
||||
self.captured_response_text.write().await.take()
|
||||
}
|
||||
|
||||
/// Peek at the captured response text without consuming it.
|
||||
pub async fn peek_response_text(&self) -> Option<String> {
|
||||
self.captured_response_text.read().await.clone()
|
||||
}
|
||||
|
||||
/// Mark the response as complete.
|
||||
pub fn mark_response_complete(&self) {
|
||||
self.response_complete.store(true, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Check if the response is complete.
|
||||
pub fn is_response_complete(&self) -> bool {
|
||||
self.response_complete.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Async version of clear_response. Bumps generation counter.
|
||||
/// Clear stale state between requests.
|
||||
/// Drops any active channel and clears in-flight flags.
|
||||
pub async fn clear_response_async(&self) {
|
||||
self.response_complete.store(false, Ordering::SeqCst);
|
||||
self.request_in_flight.store(false, Ordering::SeqCst);
|
||||
self.request_generation.fetch_add(1, Ordering::SeqCst);
|
||||
*self.active_channel.write().await = None;
|
||||
*self.captured_response_text.write().await = None;
|
||||
*self.captured_thinking_text.write().await = None;
|
||||
}
|
||||
|
||||
/// Atomically try to mark request as in-flight.
|
||||
@@ -593,6 +586,7 @@ impl MitmStore {
|
||||
}
|
||||
|
||||
/// Check if a request is currently in-flight.
|
||||
#[allow(dead_code)]
|
||||
pub fn is_request_in_flight(&self) -> bool {
|
||||
self.request_in_flight.load(Ordering::SeqCst)
|
||||
}
|
||||
@@ -602,38 +596,6 @@ impl MitmStore {
|
||||
self.request_in_flight.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Reset response_complete so we can wait for the next response.
|
||||
pub fn clear_response_complete(&self) {
|
||||
self.response_complete.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Get current generation number.
|
||||
pub fn current_generation(&self) -> u64 {
|
||||
self.request_generation.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Bump generation counter (invalidates all pending data from old generation).
|
||||
pub fn bump_generation(&self) -> u64 {
|
||||
self.request_generation.fetch_add(1, Ordering::SeqCst) + 1
|
||||
}
|
||||
|
||||
// ── Thinking text capture ────────────────────────────────────────────
|
||||
|
||||
/// Set (replace) the captured thinking text.
|
||||
pub async fn set_thinking_text(&self, text: &str) {
|
||||
*self.captured_thinking_text.write().await = Some(text.to_string());
|
||||
}
|
||||
|
||||
/// Peek at the captured thinking text without consuming it.
|
||||
pub async fn peek_thinking_text(&self) -> Option<String> {
|
||||
self.captured_thinking_text.read().await.clone()
|
||||
}
|
||||
|
||||
/// Take the captured thinking text (consumes it).
|
||||
pub async fn take_thinking_text(&self) -> Option<String> {
|
||||
self.captured_thinking_text.write().await.take()
|
||||
}
|
||||
|
||||
// ── Cascade correlation ──────────────────────────────────────────────
|
||||
|
||||
/// Set the active cascade ID (called by API handlers before sending a message).
|
||||
@@ -718,4 +680,18 @@ impl MitmStore {
|
||||
pub async fn clear_upstream_error(&self) {
|
||||
*self.upstream_error.write().await = None;
|
||||
}
|
||||
|
||||
// ── Pending user text for MITM injection ─────────────────────────────
|
||||
|
||||
/// Store the real user text for MITM injection.
|
||||
/// Called by API handlers before sending a dummy prompt to the LS.
|
||||
pub async fn set_pending_user_text(&self, text: String) {
|
||||
*self.pending_user_text.write().await = Some(text);
|
||||
}
|
||||
|
||||
/// Take (consume) the pending user text.
|
||||
/// Called by the MITM proxy when building ToolContext.
|
||||
pub async fn take_pending_user_text(&self) -> Option<String> {
|
||||
self.pending_user_text.write().await.take()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user