fix: gemini route, usage capture, search timeout, and trace finalization

- Add missing /v1/gemini POST route and handler
- Capture MitmEvent::Usage in gemini sync/streaming handlers
- Add retry counter (max 3) to search handler to prevent hang
- Add trace finalization at all gemini_sync channel exit points
- Fix UpstreamError trace outcome label
- Add timeout trace with error recording
- Dispatch Usage before ResponseComplete in SSE flush
This commit is contained in:
Nikketryhard
2026-02-18 01:31:18 -06:00
parent 48674f65da
commit 28d3296c87
11 changed files with 1480 additions and 221 deletions

View File

@@ -435,21 +435,33 @@ pub(crate) async fn handle_completions(
.map(|r| r.calls.clone()) .map(|r| r.calls.clone())
.unwrap_or_default(); .unwrap_or_default();
// Build event channel for streaming // Build event channel — always created for MITM response path
let has_custom_tools = tools.is_some(); let (tx, rx) = tokio::sync::mpsc::channel(64);
let (mitm_rx, event_tx) = if has_custom_tools && body.stream { let (mitm_rx, event_tx) = (Some(rx), tx);
let (tx, rx) = tokio::sync::mpsc::channel(64);
(Some(rx), Some(tx))
} else {
(None, None)
};
// Build pending tool results from latest round // Build pending tool results from latest round
let pending_tool_results = tool_rounds.last() let pending_tool_results = tool_rounds.last()
.map(|r| r.results.clone()) .map(|r| r.results.clone())
.unwrap_or_default(); .unwrap_or_default();
// Register all per-request state atomically // Start debug trace
let trace = state.trace.start(&cascade_id, "POST /v1/chat/completions", model_name, body.stream);
if let Some(ref t) = trace {
t.set_client_request(crate::trace::ClientRequestSummary {
message_count: body.messages.len(),
tool_count: body.tools.as_ref().map_or(0, |t| t.len()),
tool_round_count: tool_rounds.len(),
user_text_len: user_text.len(),
user_text_preview: user_text.chars().take(200).collect(),
system_prompt: body.messages.iter().any(|m| m.role == "system"),
has_image: image.is_some(),
}).await;
// Start turn 0
t.start_turn().await;
}
let mitm_gate = std::sync::Arc::new(tokio::sync::Notify::new());
let mitm_gate_clone = mitm_gate.clone();
state.mitm_store.register_request(crate::mitm::store::RequestContext { state.mitm_store.register_request(crate::mitm::store::RequestContext {
cascade_id: cascade_id.clone(), cascade_id: cascade_id.clone(),
pending_user_text: user_text.clone(), pending_user_text: user_text.clone(),
@@ -463,6 +475,9 @@ pub(crate) async fn handle_completions(
last_function_calls, last_function_calls,
call_id_to_name, call_id_to_name,
created_at: std::time::Instant::now(), created_at: std::time::Instant::now(),
gate: mitm_gate_clone,
trace_handle: trace.clone(),
trace_turn: 0,
}).await; }).await;
// Send REAL user text to LS // Send REAL user text to LS
@@ -480,6 +495,7 @@ pub(crate) async fn handle_completions(
} }
Ok((status, _)) => { Ok((status, _)) => {
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
if let Some(ref t) = trace { t.record_error(format!("Backend returned {status}")).await; t.finish("backend_error").await; }
return err_response( return err_response(
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Backend returned {status}"), format!("Backend returned {status}"),
@@ -488,6 +504,7 @@ pub(crate) async fn handle_completions(
} }
Err(e) => { Err(e) => {
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
if let Some(ref t) = trace { t.record_error(format!("Send failed: {e}")).await; t.finish("send_error").await; }
return err_response( return err_response(
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Send failed: {e}"), format!("Send failed: {e}"),
@@ -496,6 +513,34 @@ pub(crate) async fn handle_completions(
} }
} }
// Wait for MITM gate: 5s → 502 if MITM enabled
let gate_start = std::time::Instant::now();
let gate_matched = tokio::time::timeout(
std::time::Duration::from_secs(5),
mitm_gate.notified(),
).await;
let gate_wait_ms = gate_start.elapsed().as_millis() as u64;
if gate_matched.is_err() {
if state.mitm_enabled {
state.mitm_store.remove_request(&cascade_id).await;
if let Some(ref t) = trace {
t.record_error("MITM gate timeout (5s)".to_string()).await;
t.finish("mitm_timeout").await;
}
return err_response(
StatusCode::BAD_GATEWAY,
"MITM proxy did not match request within 5s".to_string(),
"mitm_timeout",
);
}
warn!(cascade = %cascade_id, "MITM gate timeout (--no-mitm mode)");
} else {
debug!(cascade = %cascade_id, gate_wait_ms, "MITM gate signaled — request matched");
if let Some(ref t) = trace {
t.record_mitm_match(0, gate_wait_ms).await;
}
}
let completion_id = format!( let completion_id = format!(
"chatcmpl-{}", "chatcmpl-{}",
uuid::Uuid::new_v4().to_string().replace('-', "") uuid::Uuid::new_v4().to_string().replace('-', "")
@@ -515,6 +560,7 @@ pub(crate) async fn handle_completions(
body.timeout, body.timeout,
include_usage, include_usage,
mitm_rx, mitm_rx,
trace,
) )
.await .await
} else if n <= 1 { } else if n <= 1 {
@@ -524,6 +570,7 @@ pub(crate) async fn handle_completions(
model_name.to_string(), model_name.to_string(),
cascade_id, cascade_id,
body.timeout, body.timeout,
trace,
) )
.await .await
} else { } else {
@@ -653,6 +700,7 @@ async fn chat_completions_stream(
timeout: u64, timeout: u64,
include_usage: bool, include_usage: bool,
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>, mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
trace: Option<crate::trace::TraceHandle>,
) -> axum::response::Response { ) -> axum::response::Response {
let stream = async_stream::stream! { let stream = async_stream::stream! {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
@@ -774,6 +822,21 @@ async fn chat_completions_stream(
} }
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
if let Some(ref t) = trace {
let (ipt, opt, crt2, tht) = if let Some(ref u) = last_usage {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
t.record_response(0, crate::trace::ResponseSummary {
text_len: 0, thinking_len: 0, text_preview: String::new(),
finish_reason: Some("tool_calls".to_string()),
function_calls: calls.iter().map(|fc| crate::trace::FunctionCallSummary {
name: fc.name.clone(), args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(),
}).collect(),
grounding: false,
}).await;
t.set_usage(crate::trace::TrackedUsage { input_tokens: ipt, output_tokens: opt, thinking_tokens: tht, cache_read: crt2 }).await;
t.finish("tool_call").await;
}
return; return;
} }
MitmEvent::ResponseComplete => { MitmEvent::ResponseComplete => {
@@ -802,6 +865,19 @@ async fn chat_completions_stream(
} }
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
if let Some(ref t) = trace {
let (ipt, opt, crt2, tht) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
t.record_response(0, crate::trace::ResponseSummary {
text_len: acc_text.len(), thinking_len: acc_thinking.len(),
text_preview: acc_text.chars().take(200).collect(),
finish_reason: Some("stop".to_string()),
function_calls: Vec::new(), grounding: false,
}).await;
t.set_usage(crate::trace::TrackedUsage { input_tokens: ipt, output_tokens: opt, thinking_tokens: tht, cache_read: crt2 }).await;
t.finish("completed").await;
}
return; return;
} else if !acc_thinking.is_empty() && !did_unblock_ls { } else if !acc_thinking.is_empty() && !did_unblock_ls {
// Thinking-only response — LS needs follow-up API calls. // Thinking-only response — LS needs follow-up API calls.
@@ -844,6 +920,19 @@ async fn chat_completions_stream(
} }
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
if let Some(ref t) = trace {
let (ipt, opt, crt2, tht) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
t.record_response(0, crate::trace::ResponseSummary {
text_len: 0, thinking_len: acc_thinking.len(),
text_preview: String::new(),
finish_reason: Some("stop".to_string()),
function_calls: Vec::new(), grounding: false,
}).await;
t.set_usage(crate::trace::TrackedUsage { input_tokens: ipt, output_tokens: opt, thinking_tokens: tht, cache_read: crt2 }).await;
t.finish("thinking_timeout").await;
}
return; return;
} }
// Don't break — wait for more channel events // Don't break — wait for more channel events
@@ -860,6 +949,14 @@ async fn chat_completions_stream(
))); )));
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
if let Some(ref t) = trace {
t.record_response(0, crate::trace::ResponseSummary {
text_len: 0, thinking_len: 0, text_preview: String::new(),
finish_reason: Some("stop".to_string()),
function_calls: Vec::new(), grounding: false,
}).await;
t.finish("empty_response").await;
}
return; return;
} }
continue 'channel_loop; continue 'channel_loop;
@@ -900,6 +997,15 @@ async fn chat_completions_stream(
))); )));
} }
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
if let Some(ref t) = trace {
t.record_response(0, crate::trace::ResponseSummary {
text_len: last_text.len(), thinking_len: last_thinking_len,
text_preview: last_text.chars().take(200).collect(),
finish_reason: Some("stop".to_string()),
function_calls: Vec::new(), grounding: false,
}).await;
t.finish("channel_closed").await;
}
return; return;
} else { } else {
// ── Fallback: LS steps (no MITM capture active) ── // ── Fallback: LS steps (no MITM capture active) ──
@@ -1046,6 +1152,7 @@ async fn chat_completions_sync(
model_name: String, model_name: String,
cascade_id: String, cascade_id: String,
timeout: u64, timeout: u64,
trace: Option<crate::trace::TraceHandle>,
) -> axum::response::Response { ) -> axum::response::Response {
let result = poll_for_response(&state, &cascade_id, timeout).await; let result = poll_for_response(&state, &cascade_id, timeout).await;
if let Some(ref err) = result.upstream_error { if let Some(ref err) = result.upstream_error {
@@ -1084,6 +1191,27 @@ async fn chat_completions_sync(
message["reasoning_content"] = serde_json::json!(thinking); message["reasoning_content"] = serde_json::json!(thinking);
} }
// Record trace data
if let Some(ref t) = trace {
t.record_response(0, crate::trace::ResponseSummary {
text_len: result.text.len(),
thinking_len: result.thinking.as_ref().map_or(0, |s| s.len()),
text_preview: result.text.chars().take(200).collect(),
finish_reason: Some(finish_reason.to_string()),
function_calls: Vec::new(),
grounding: false,
}).await;
if prompt_tokens > 0 || completion_tokens > 0 {
t.set_usage(crate::trace::TrackedUsage {
input_tokens: prompt_tokens,
output_tokens: completion_tokens,
thinking_tokens: thinking_tokens,
cache_read: cached_tokens,
}).await;
}
t.finish("completed").await;
}
Json(serde_json::json!({ Json(serde_json::json!({
"id": completion_id, "id": completion_id,
"object": "chat.completion", "object": "chat.completion",

View File

@@ -16,7 +16,7 @@ use axum::{
}; };
use rand::Rng; use rand::Rng;
use std::sync::Arc; use std::sync::Arc;
use tracing::{debug, info}; use tracing::{debug, info, warn};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{ use super::polling::{
@@ -30,8 +30,15 @@ use crate::mitm::store::PendingToolResult;
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
pub(crate) struct GeminiRequest { pub(crate) struct GeminiRequest {
pub model: Option<String>, pub model: Option<String>,
/// User input text. /// User input text (our custom format).
pub input: serde_json::Value, #[serde(default)]
pub input: Option<serde_json::Value>,
/// Official Gemini API format: [{"role": "user", "parts": [{"text": "..."}]}]
#[serde(default)]
pub contents: Option<Vec<serde_json::Value>>,
/// Shorthand: single text message (alias for simple requests).
#[serde(default)]
pub message: Option<String>,
/// Gemini-native tools: [{"functionDeclarations": [...]}] /// Gemini-native tools: [{"functionDeclarations": [...]}]
#[serde(default)] #[serde(default)]
pub tools: Option<Vec<serde_json::Value>>, pub tools: Option<Vec<serde_json::Value>>,
@@ -111,6 +118,14 @@ async fn build_usage_metadata(
} }
} }
/// POST /v1/gemini — simple custom endpoint
pub(crate) async fn handle_gemini(
State(state): State<Arc<AppState>>,
Json(body): Json<GeminiRequest>,
) -> axum::response::Response {
handle_gemini_inner(state, body).await
}
/// POST /v1beta/*path — handles both :generateContent and :streamGenerateContent /// POST /v1beta/*path — handles both :generateContent and :streamGenerateContent
/// ///
/// Parses paths like: /// Parses paths like:
@@ -185,58 +200,105 @@ async fn handle_gemini_inner(
); );
} }
// Extract user text and optional image // Extract user text and optional image.
// Priority: contents (official Gemini API) > input (our format) > message (shorthand)
let mut image: Option<crate::proto::ImageData> = None; let mut image: Option<crate::proto::ImageData> = None;
let user_text = match &body.input { let user_text = if let Some(ref contents) = body.contents {
serde_json::Value::String(s) => s.clone(), // Official Gemini API format: [{"role": "user", "parts": [{"text": "..."}]}]
serde_json::Value::Array(arr) => { // Extract text from the last user message.
// Support array input: strings, {text: "..."}, or {inlineData: {mimeType, data}} let mut text_parts: Vec<String> = Vec::new();
let mut parts: Vec<String> = Vec::new(); for content in contents.iter().rev() {
for item in arr { let role = content.get("role").and_then(|r| r.as_str()).unwrap_or("user");
match item { if role != "user" { continue; }
serde_json::Value::String(s) => parts.push(s.clone()), if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
serde_json::Value::Object(obj) => { for part in parts {
if let Some(text) = obj.get("text").and_then(|v| v.as_str()) { if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
parts.push(text.to_string()); text_parts.push(text.to_string());
} }
// Gemini-native inlineData format // Handle inlineData image
if image.is_none() { if image.is_none() {
if let Some(inline) = obj.get("inlineData") { if let Some(inline) = part.get("inlineData") {
if let (Some(mime), Some(b64)) = if let (Some(mime), Some(b64)) =
(inline["mimeType"].as_str(), inline["data"].as_str()) (inline["mimeType"].as_str(), inline["data"].as_str())
{ {
if let Some(img) = super::util::parse_data_uri(&format!( if let Some(img) = super::util::parse_data_uri(&format!(
"data:{mime};base64,{b64}" "data:{mime};base64,{b64}"
)) { )) {
image = Some(img); image = Some(img);
}
} }
} }
// Also support OpenAI-style image_url in Gemini input
if let Some(img) = super::util::extract_image_from_content(item) {
image = Some(img);
}
} }
} }
_ => {}
} }
} }
if parts.is_empty() { if !text_parts.is_empty() { break; }
return err_response(
StatusCode::BAD_REQUEST,
"Gemini input array contains no text parts".to_string(),
"invalid_request_error",
);
}
parts.join("\n")
} }
_ => { if text_parts.is_empty() {
return err_response( return err_response(
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
"Gemini endpoint requires input as a string or array of text parts".to_string(), "No text found in contents array".to_string(),
"invalid_request_error", "invalid_request_error",
); );
} }
text_parts.join("\n")
} else if let Some(ref input) = body.input {
// Our custom format: input as string or array
match input {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(arr) => {
let mut parts: Vec<String> = Vec::new();
for item in arr {
match item {
serde_json::Value::String(s) => parts.push(s.clone()),
serde_json::Value::Object(obj) => {
if let Some(text) = obj.get("text").and_then(|v| v.as_str()) {
parts.push(text.to_string());
}
if image.is_none() {
if let Some(inline) = obj.get("inlineData") {
if let (Some(mime), Some(b64)) =
(inline["mimeType"].as_str(), inline["data"].as_str())
{
if let Some(img) = super::util::parse_data_uri(&format!(
"data:{mime};base64,{b64}"
)) {
image = Some(img);
}
}
}
if let Some(img) = super::util::extract_image_from_content(item) {
image = Some(img);
}
}
}
_ => {}
}
}
if parts.is_empty() {
return err_response(
StatusCode::BAD_REQUEST,
"Gemini input array contains no text parts".to_string(),
"invalid_request_error",
);
}
parts.join("\n")
}
_ => {
return err_response(
StatusCode::BAD_REQUEST,
"Gemini input must be a string or array of text parts".to_string(),
"invalid_request_error",
);
}
}
} else if let Some(ref msg) = body.message {
msg.clone()
} else {
return err_response(
StatusCode::BAD_REQUEST,
"Request must include 'contents' (Gemini API), 'input', or 'message'".to_string(),
"invalid_request_error",
);
}; };
// ── Build per-request state locally ────────────────────────────────── // ── Build per-request state locally ──────────────────────────────────
@@ -320,14 +382,9 @@ async fn handle_gemini_inner(
} }
}); });
// Build event channel for streaming // Build event channel — always created for MITM response path
let has_custom_tools = tools.is_some(); let (tx, rx) = tokio::sync::mpsc::channel(64);
let (mitm_rx, event_tx) = if has_custom_tools { let (mitm_rx, event_tx) = (Some(rx), tx);
let (tx, rx) = tokio::sync::mpsc::channel(64);
(Some(rx), Some(tx))
} else {
(None, None)
};
// Build tool rounds now that cascade_id is known // Build tool rounds now that cascade_id is known
let mut tool_rounds: Vec<crate::mitm::store::ToolRound> = Vec::new(); let mut tool_rounds: Vec<crate::mitm::store::ToolRound> = Vec::new();
@@ -340,7 +397,23 @@ async fn handle_gemini_inner(
}); });
} }
// Register all per-request state atomically // Start debug trace
let trace = state.trace.start(&cascade_id, "POST gemini", &model_name, body.stream);
if let Some(ref t) = trace {
t.set_client_request(crate::trace::ClientRequestSummary {
message_count: 1,
tool_count: body.tools.as_ref().map_or(0, |t| t.len()),
tool_round_count: tool_rounds.len(),
user_text_len: user_text.len(),
user_text_preview: user_text.chars().take(200).collect(),
system_prompt: false,
has_image: image.is_some(),
}).await;
t.start_turn().await;
}
let mitm_gate = std::sync::Arc::new(tokio::sync::Notify::new());
let mitm_gate_clone = mitm_gate.clone();
state.mitm_store.register_request(crate::mitm::store::RequestContext { state.mitm_store.register_request(crate::mitm::store::RequestContext {
cascade_id: cascade_id.clone(), cascade_id: cascade_id.clone(),
pending_user_text: user_text.clone(), pending_user_text: user_text.clone(),
@@ -354,6 +427,9 @@ async fn handle_gemini_inner(
last_function_calls: Vec::new(), last_function_calls: Vec::new(),
call_id_to_name: std::collections::HashMap::new(), call_id_to_name: std::collections::HashMap::new(),
created_at: std::time::Instant::now(), created_at: std::time::Instant::now(),
gate: mitm_gate_clone,
trace_handle: trace.clone(),
trace_turn: 0,
}).await; }).await;
// Send REAL user text to LS (no more dummy ".") // Send REAL user text to LS (no more dummy ".")
@@ -387,13 +463,36 @@ async fn handle_gemini_inner(
} }
} }
// Wait for MITM gate: 5s -> 502 if MITM enabled
let gate_start = std::time::Instant::now();
let gate_matched = tokio::time::timeout(
std::time::Duration::from_secs(5),
mitm_gate.notified(),
).await;
let gate_wait_ms = gate_start.elapsed().as_millis() as u64;
if gate_matched.is_err() {
if state.mitm_enabled {
state.mitm_store.remove_request(&cascade_id).await;
if let Some(ref t) = trace { t.record_error("MITM gate timeout (5s)".to_string()).await; t.finish("mitm_timeout").await; }
return err_response(
StatusCode::BAD_GATEWAY,
"MITM proxy did not match request within 5s".to_string(),
"mitm_timeout",
);
}
warn!(cascade = %cascade_id, "MITM gate timeout (--no-mitm mode)");
} else {
debug!(cascade = %cascade_id, gate_wait_ms, "MITM gate signaled -- request matched");
if let Some(ref t) = trace { t.record_mitm_match(0, gate_wait_ms).await; }
}
// Dispatch to sync or stream // Dispatch to sync or stream
let model_name = model_name.to_string(); let model_name = model_name.to_string();
let timeout = body.timeout; let timeout = body.timeout;
if body.stream { if body.stream {
gemini_stream(state, model_name, cascade_id, timeout, mitm_rx).await gemini_stream(state, model_name, cascade_id, timeout, mitm_rx, trace).await
} else { } else {
gemini_sync(state, model_name, cascade_id, timeout, mitm_rx).await gemini_sync(state, model_name, cascade_id, timeout, mitm_rx, trace).await
} }
} }
@@ -405,6 +504,7 @@ async fn gemini_sync(
cascade_id: String, cascade_id: String,
timeout: u64, timeout: u64,
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>, mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
trace: Option<crate::trace::TraceHandle>,
) -> axum::response::Response { ) -> axum::response::Response {
// Clear stale response and upstream errors (only if no pre-installed channel) // Clear stale response and upstream errors (only if no pre-installed channel)
if mitm_rx.is_none() { if mitm_rx.is_none() {
@@ -418,6 +518,7 @@ async fn gemini_sync(
let mut acc_text = String::new(); let mut acc_text = String::new();
let mut acc_thinking: Option<String> = None; let mut acc_thinking: Option<String> = None;
let mut last_usage: Option<crate::mitm::store::ApiUsage> = None;
while let Some(event) = tokio::time::timeout( while let Some(event) = tokio::time::timeout(
std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())), std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
@@ -427,7 +528,8 @@ async fn gemini_sync(
match event { match event {
MitmEvent::ThinkingDelta(t) => { acc_thinking = Some(t); } MitmEvent::ThinkingDelta(t) => { acc_thinking = Some(t); }
MitmEvent::TextDelta(t) => { acc_text = t; } MitmEvent::TextDelta(t) => { acc_text = t; }
MitmEvent::Usage(_) | MitmEvent::Grounding(_) => {} MitmEvent::Usage(u) => { last_usage = Some(u); }
MitmEvent::Grounding(_) => {}
MitmEvent::FunctionCall(calls) => { MitmEvent::FunctionCall(calls) => {
let parts: Vec<serde_json::Value> = calls let parts: Vec<serde_json::Value> = calls
.iter() .iter()
@@ -440,6 +542,21 @@ async fn gemini_sync(
}) })
}) })
.collect(); .collect();
if let Some(ref t) = trace {
let fc_summaries: Vec<crate::trace::FunctionCallSummary> = calls.iter().map(|fc| {
crate::trace::FunctionCallSummary {
name: fc.name.clone(),
args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(),
}
}).collect();
t.record_response(0, crate::trace::ResponseSummary {
text_len: 0, thinking_len: acc_thinking.as_ref().map_or(0, |s| s.len()),
text_preview: String::new(),
finish_reason: Some("STOP".to_string()),
function_calls: fc_summaries, grounding: false,
}).await;
t.finish("tool_call").await;
}
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
return Json(serde_json::json!({ return Json(serde_json::json!({
"candidates": [{ "candidates": [{
@@ -477,6 +594,18 @@ async fn gemini_sync(
parts.push(serde_json::json!({"text": t, "thought": true})); parts.push(serde_json::json!({"text": t, "thought": true}));
} }
parts.push(serde_json::json!({"text": acc_text})); parts.push(serde_json::json!({"text": acc_text}));
if let Some(ref t) = trace {
t.record_response(0, crate::trace::ResponseSummary {
text_len: acc_text.len(), thinking_len: acc_thinking.as_ref().map_or(0, |s| s.len()),
text_preview: acc_text.chars().take(200).collect(),
finish_reason: Some("STOP".to_string()),
function_calls: Vec::new(), grounding: false,
}).await;
if let Some(ref u) = last_usage {
t.set_usage(crate::trace::TrackedUsage { input_tokens: u.input_tokens, output_tokens: u.output_tokens, thinking_tokens: u.thinking_output_tokens, cache_read: u.cache_read_input_tokens }).await;
}
t.finish("completed").await;
}
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
return Json(serde_json::json!({ return Json(serde_json::json!({
"candidates": [{ "candidates": [{
@@ -487,11 +616,33 @@ async fn gemini_sync(
"finishReason": "STOP", "finishReason": "STOP",
}], }],
"modelVersion": model_name, "modelVersion": model_name,
"usageMetadata": build_usage_metadata(&state.mitm_store, &cascade_id).await, "usageMetadata": if let Some(ref u) = last_usage {
serde_json::json!({
"promptTokenCount": u.input_tokens,
"candidatesTokenCount": u.output_tokens,
"totalTokenCount": u.input_tokens + u.output_tokens,
"thoughtsTokenCount": u.thinking_output_tokens,
"cachedContentTokenCount": u.cache_read_input_tokens,
})
} else {
build_usage_metadata(&state.mitm_store, &cascade_id).await
},
})) }))
.into_response(); .into_response();
} }
MitmEvent::UpstreamError(err) => { MitmEvent::UpstreamError(err) => {
if let Some(ref t) = trace {
t.record_response(0, crate::trace::ResponseSummary {
text_len: acc_text.len(), thinking_len: acc_thinking.as_ref().map_or(0, |s| s.len()),
text_preview: acc_text.chars().take(200).collect(),
finish_reason: Some("STOP".to_string()),
function_calls: Vec::new(), grounding: false,
}).await;
if let Some(ref u) = last_usage {
t.set_usage(crate::trace::TrackedUsage { input_tokens: u.input_tokens, output_tokens: u.output_tokens, thinking_tokens: u.thinking_output_tokens, cache_read: u.cache_read_input_tokens }).await;
}
t.finish("upstream_error").await;
}
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
return upstream_err_response(&err); return upstream_err_response(&err);
} }
@@ -499,6 +650,10 @@ async fn gemini_sync(
} }
// Timeout // Timeout
if let Some(ref t) = trace {
t.record_error(format!("Timeout: no response after {timeout}s")).await;
t.finish("timeout").await;
}
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
return ( return (
axum::http::StatusCode::GATEWAY_TIMEOUT, axum::http::StatusCode::GATEWAY_TIMEOUT,
@@ -541,6 +696,25 @@ async fn gemini_sync(
}) })
.collect(); .collect();
// Record trace
if let Some(ref t) = trace {
let fc_summaries: Vec<crate::trace::FunctionCallSummary> = calls.iter().map(|fc| {
crate::trace::FunctionCallSummary {
name: fc.name.clone(),
args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(),
}
}).collect();
t.record_response(0, crate::trace::ResponseSummary {
text_len: 0,
thinking_len: 0,
text_preview: String::new(),
finish_reason: Some("STOP".to_string()),
function_calls: fc_summaries,
grounding: false,
}).await;
t.finish("tool_call").await;
}
return Json(serde_json::json!({ return Json(serde_json::json!({
"candidates": [{ "candidates": [{
"content": { "content": {
@@ -562,6 +736,19 @@ async fn gemini_sync(
} }
parts.push(serde_json::json!({"text": poll_result.text})); parts.push(serde_json::json!({"text": poll_result.text}));
// Record trace
if let Some(ref t) = trace {
t.record_response(0, crate::trace::ResponseSummary {
text_len: poll_result.text.len(),
thinking_len: poll_result.thinking.as_ref().map_or(0, |s| s.len()),
text_preview: poll_result.text.chars().take(200).collect(),
finish_reason: Some("STOP".to_string()),
function_calls: Vec::new(),
grounding: false,
}).await;
t.finish("completed").await;
}
Json(serde_json::json!({ Json(serde_json::json!({
"candidates": [{ "candidates": [{
"content": { "content": {
@@ -584,11 +771,13 @@ async fn gemini_stream(
cascade_id: String, cascade_id: String,
timeout: u64, timeout: u64,
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>, mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
trace: Option<crate::trace::TraceHandle>,
) -> axum::response::Response { ) -> axum::response::Response {
let stream = async_stream::stream! { let stream = async_stream::stream! {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let mut last_text = String::new(); let mut last_text = String::new();
let mut last_thinking = String::new(); let mut last_thinking = String::new();
let mut last_usage: Option<crate::mitm::store::ApiUsage> = None;
// Clear stale response (only if no pre-installed channel) // Clear stale response (only if no pre-installed channel)
if mitm_rx.is_none() { if mitm_rx.is_none() {
@@ -665,13 +854,34 @@ async fn gemini_stream(
"modelVersion": model_name, "modelVersion": model_name,
})).unwrap_or_default())); })).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
if let Some(ref t) = trace {
let fc_summaries: Vec<crate::trace::FunctionCallSummary> = calls.iter().map(|fc| crate::trace::FunctionCallSummary {
name: fc.name.clone(), args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(),
}).collect();
t.record_response(0, crate::trace::ResponseSummary {
text_len: 0, thinking_len: last_thinking.len(), text_preview: String::new(),
finish_reason: Some("STOP".to_string()),
function_calls: fc_summaries, grounding: false,
}).await;
t.finish("tool_call").await;
}
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} }
MitmEvent::ResponseComplete => { MitmEvent::ResponseComplete => {
if !last_text.is_empty() { if !last_text.is_empty() {
// Final chunk with finishReason + usageMetadata // Final chunk with finishReason + usageMetadata
let usage_meta = build_usage_metadata(&state.mitm_store, &cascade_id).await; let usage_meta = if let Some(ref u) = last_usage {
serde_json::json!({
"promptTokenCount": u.input_tokens,
"candidatesTokenCount": u.output_tokens,
"totalTokenCount": u.input_tokens + u.output_tokens,
"thoughtsTokenCount": u.thinking_output_tokens,
"cachedContentTokenCount": u.cache_read_input_tokens,
})
} else {
build_usage_metadata(&state.mitm_store, &cascade_id).await
};
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"candidates": [{ "candidates": [{
"content": { "content": {
@@ -684,6 +894,15 @@ async fn gemini_stream(
"modelVersion": model_name, "modelVersion": model_name,
})).unwrap_or_default())); })).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
if let Some(ref t) = trace {
t.record_response(0, crate::trace::ResponseSummary {
text_len: last_text.len(), thinking_len: last_thinking.len(),
text_preview: last_text.chars().take(200).collect(),
finish_reason: Some("STOP".to_string()),
function_calls: Vec::new(), grounding: false,
}).await;
t.finish("completed").await;
}
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} else if !last_thinking.is_empty() && !did_unblock_ls { } else if !last_thinking.is_empty() && !did_unblock_ls {
@@ -714,10 +933,15 @@ async fn gemini_stream(
} }
})).unwrap())); })).unwrap()));
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
if let Some(ref t) = trace {
t.record_error(format!("Upstream: {}", error_msg)).await;
t.finish("upstream_error").await;
}
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} }
MitmEvent::Usage(_) | MitmEvent::Grounding(_) => {} MitmEvent::Usage(u) => { last_usage = Some(u); }
MitmEvent::Grounding(_) => {}
} }
} }
@@ -730,6 +954,10 @@ async fn gemini_stream(
"code": 504, "code": 504,
} }
})).unwrap())); })).unwrap()));
if let Some(ref t) = trace {
t.record_error(format!("Timeout: {timeout}s")).await;
t.finish("timeout").await;
}
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
return; return;
} }

View File

@@ -33,6 +33,10 @@ pub struct AppState {
pub sessions: SessionManager, pub sessions: SessionManager,
pub mitm_store: crate::mitm::store::MitmStore, pub mitm_store: crate::mitm::store::MitmStore,
pub quota_store: crate::quota::QuotaStore, pub quota_store: crate::quota::QuotaStore,
/// Whether the MITM proxy is active (false when --no-mitm).
pub mitm_enabled: bool,
/// Per-call debug trace collector.
pub trace: crate::trace::TraceCollector,
} }
// ─── Router ────────────────────────────────────────────────────────────────── // ─── Router ──────────────────────────────────────────────────────────────────
@@ -44,6 +48,7 @@ pub fn router(state: Arc<AppState>) -> Router {
"/v1/chat/completions", "/v1/chat/completions",
post(completions::handle_completions), post(completions::handle_completions),
) )
.route("/v1/gemini", post(gemini::handle_gemini))
.route( .route(
"/v1beta/{*path}", "/v1beta/{*path}",
post(gemini::handle_gemini_v1beta), post(gemini::handle_gemini_v1beta),

View File

@@ -11,7 +11,7 @@ use axum::{
use rand::Rng; use rand::Rng;
use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc; use std::sync::Arc;
use tracing::{debug, info}; use tracing::{debug, info, warn};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{ use super::polling::{
@@ -364,14 +364,9 @@ pub(crate) async fn handle_responses(
} }
}); });
// Build event channel // Build event channel — always created for MITM response path
let has_custom_tools = tools.is_some(); let (tx, rx) = tokio::sync::mpsc::channel(64);
let (mitm_rx, event_tx) = if has_custom_tools { let (mitm_rx, event_tx) = (Some(rx), tx);
let (tx, rx) = tokio::sync::mpsc::channel(64);
(Some(rx), Some(tx))
} else {
(None, None)
};
// Build tool rounds now that cascade_id is known // Build tool rounds now that cascade_id is known
let mut tool_rounds: Vec<crate::mitm::store::ToolRound> = Vec::new(); let mut tool_rounds: Vec<crate::mitm::store::ToolRound> = Vec::new();
@@ -385,7 +380,23 @@ pub(crate) async fn handle_responses(
}); });
} }
// Register all per-request state atomically // Start debug trace
let trace = state.trace.start(&cascade_id, "POST /v1/responses", &model.name, body.stream);
if let Some(ref t) = trace {
t.set_client_request(crate::trace::ClientRequestSummary {
message_count: if is_tool_result_turn { 0 } else { 1 },
tool_count: body.tools.as_ref().map_or(0, |t| t.len()),
tool_round_count: tool_rounds.len(),
user_text_len: user_text.len(),
user_text_preview: user_text.chars().take(200).collect(),
system_prompt: body.instructions.is_some(),
has_image: image.is_some(),
}).await;
t.start_turn().await;
}
let mitm_gate = std::sync::Arc::new(tokio::sync::Notify::new());
let mitm_gate_clone = mitm_gate.clone();
state.mitm_store.register_request(crate::mitm::store::RequestContext { state.mitm_store.register_request(crate::mitm::store::RequestContext {
cascade_id: cascade_id.clone(), cascade_id: cascade_id.clone(),
pending_user_text: user_text.clone(), pending_user_text: user_text.clone(),
@@ -399,6 +410,9 @@ pub(crate) async fn handle_responses(
last_function_calls: Vec::new(), last_function_calls: Vec::new(),
call_id_to_name: std::collections::HashMap::new(), call_id_to_name: std::collections::HashMap::new(),
created_at: std::time::Instant::now(), created_at: std::time::Instant::now(),
gate: mitm_gate_clone,
trace_handle: trace.clone(),
trace_turn: 0,
}).await; }).await;
// Send REAL user text to LS // Send REAL user text to LS
@@ -432,6 +446,29 @@ pub(crate) async fn handle_responses(
} }
} }
// Wait for MITM gate: 5s → 502 if MITM enabled
let gate_start = std::time::Instant::now();
let gate_matched = tokio::time::timeout(
std::time::Duration::from_secs(5),
mitm_gate.notified(),
).await;
let gate_wait_ms = gate_start.elapsed().as_millis() as u64;
if gate_matched.is_err() {
if state.mitm_enabled {
state.mitm_store.remove_request(&cascade_id).await;
if let Some(ref t) = trace { t.record_error("MITM gate timeout (5s)".to_string()).await; t.finish("mitm_timeout").await; }
return err_response(
StatusCode::BAD_GATEWAY,
"MITM proxy did not match request within 5s".to_string(),
"mitm_timeout",
);
}
warn!(cascade = %cascade_id, "MITM gate timeout (--no-mitm mode)");
} else {
debug!(cascade = %cascade_id, gate_wait_ms, "MITM gate signaled — request matched");
if let Some(ref t) = trace { t.record_mitm_match(0, gate_wait_ms).await; }
}
// Capture request params for response building // Capture request params for response building
let req_params = RequestParams { let req_params = RequestParams {
user_text: user_text.clone(), user_text: user_text.clone(),
@@ -462,6 +499,7 @@ pub(crate) async fn handle_responses(
body.timeout, body.timeout,
req_params, req_params,
mitm_rx, mitm_rx,
trace,
) )
.await .await
} else { } else {
@@ -473,6 +511,7 @@ pub(crate) async fn handle_responses(
body.timeout, body.timeout,
req_params, req_params,
mitm_rx, mitm_rx,
trace,
) )
.await .await
} }
@@ -595,6 +634,7 @@ async fn handle_responses_sync(
timeout: u64, timeout: u64,
params: RequestParams, params: RequestParams,
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>, mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
trace: Option<crate::trace::TraceHandle>,
) -> axum::response::Response { ) -> axum::response::Response {
let created_at = now_unix(); let created_at = now_unix();
@@ -642,6 +682,30 @@ async fn handle_responses_sync(
&state.mitm_store, &cascade_id, &None, &params.user_text, "", &state.mitm_store, &cascade_id, &None, &params.user_text, "",
).await; ).await;
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
// Record trace before usage is moved
if let Some(ref t) = trace {
let fc_summaries: Vec<crate::trace::FunctionCallSummary> = calls.iter().map(|fc| {
crate::trace::FunctionCallSummary {
name: fc.name.clone(),
args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(),
}
}).collect();
t.record_response(0, crate::trace::ResponseSummary {
text_len: 0,
thinking_len: 0,
text_preview: String::new(),
finish_reason: Some("tool_calls".to_string()),
function_calls: fc_summaries,
grounding: false,
}).await;
t.set_usage(crate::trace::TrackedUsage {
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
thinking_tokens: usage.output_tokens_details.reasoning_tokens,
cache_read: usage.input_tokens_details.cached_tokens,
}).await;
t.finish("tool_call").await;
}
let resp = build_response_object( let resp = build_response_object(
ResponseData { ResponseData {
id: response_id, id: response_id,
@@ -688,6 +752,24 @@ async fn handle_responses_sync(
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
output_items.push(build_message_output(&msg_id, &acc_text)); output_items.push(build_message_output(&msg_id, &acc_text));
// Record trace before usage is moved
if let Some(ref t) = trace {
t.record_response(0, crate::trace::ResponseSummary {
text_len: acc_text.len(),
thinking_len: acc_thinking.as_ref().map_or(0, |s| s.len()),
text_preview: acc_text.chars().take(200).collect(),
finish_reason: Some("stop".to_string()),
function_calls: Vec::new(),
grounding: false,
}).await;
t.set_usage(crate::trace::TrackedUsage {
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
thinking_tokens: usage.output_tokens_details.reasoning_tokens,
cache_read: usage.input_tokens_details.cached_tokens,
}).await;
t.finish("completed").await;
}
let resp = build_response_object( let resp = build_response_object(
ResponseData { ResponseData {
id: response_id, id: response_id,
@@ -705,6 +787,7 @@ async fn handle_responses_sync(
} }
MitmEvent::UpstreamError(err) => { MitmEvent::UpstreamError(err) => {
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
if let Some(ref t) = trace { t.record_error(format!("Upstream: {}", err.message.as_deref().unwrap_or("unknown"))).await; t.finish("upstream_error").await; }
return upstream_err_response(&err); return upstream_err_response(&err);
} }
} }
@@ -712,6 +795,7 @@ async fn handle_responses_sync(
// Timeout // Timeout
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
if let Some(ref t) = trace { t.record_error(format!("Timeout: {}s", timeout)).await; t.finish("timeout").await; }
return err_response( return err_response(
StatusCode::GATEWAY_TIMEOUT, StatusCode::GATEWAY_TIMEOUT,
format!("Timeout: no response from Google API after {timeout}s"), format!("Timeout: no response from Google API after {timeout}s"),
@@ -772,6 +856,31 @@ async fn handle_responses_sync(
) )
.await; .await;
// Record trace before usage is moved
if let Some(ref t) = trace {
let fc_summaries: Vec<crate::trace::FunctionCallSummary> = calls.iter().map(|fc| {
crate::trace::FunctionCallSummary {
name: fc.name.clone(),
args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(),
}
}).collect();
t.record_response(0, crate::trace::ResponseSummary {
text_len: poll_result.text.len(),
thinking_len: poll_result.thinking.as_ref().map_or(0, |s| s.len()),
text_preview: String::new(),
finish_reason: Some("tool_calls".to_string()),
function_calls: fc_summaries,
grounding: false,
}).await;
t.set_usage(crate::trace::TrackedUsage {
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
thinking_tokens: usage.output_tokens_details.reasoning_tokens,
cache_read: usage.input_tokens_details.cached_tokens,
}).await;
t.finish("tool_call").await;
}
let resp = build_response_object( let resp = build_response_object(
ResponseData { ResponseData {
id: response_id, id: response_id,
@@ -809,6 +918,25 @@ async fn handle_responses_sync(
} }
output_items.push(build_message_output(&msg_id, &poll_result.text)); output_items.push(build_message_output(&msg_id, &poll_result.text));
// Record trace before usage is moved
if let Some(ref t) = trace {
t.record_response(0, crate::trace::ResponseSummary {
text_len: poll_result.text.len(),
thinking_len: thinking_text.as_ref().map_or(0, |s| s.len()),
text_preview: poll_result.text.chars().take(200).collect(),
finish_reason: Some("stop".to_string()),
function_calls: Vec::new(),
grounding: false,
}).await;
t.set_usage(crate::trace::TrackedUsage {
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
thinking_tokens: usage.output_tokens_details.reasoning_tokens,
cache_read: usage.input_tokens_details.cached_tokens,
}).await;
t.finish("completed").await;
}
let resp = build_response_object( let resp = build_response_object(
ResponseData { ResponseData {
id: response_id, id: response_id,
@@ -836,6 +964,7 @@ async fn handle_responses_stream(
timeout: u64, timeout: u64,
params: RequestParams, params: RequestParams,
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>, mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
trace: Option<crate::trace::TraceHandle>,
) -> axum::response::Response { ) -> axum::response::Response {
let stream = async_stream::stream! { let stream = async_stream::stream! {
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
@@ -1111,6 +1240,14 @@ async fn handle_responses_stream(
&params.user_text, "", &params.user_text, "",
).await; ).await;
// Save trace usage before move
let trace_usage = crate::trace::TrackedUsage {
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
thinking_tokens: usage.output_tokens_details.reasoning_tokens,
cache_read: usage.input_tokens_details.cached_tokens,
};
let final_resp = build_response_object( let final_resp = build_response_object(
ResponseData { ResponseData {
id: response_id.clone(), id: response_id.clone(),
@@ -1132,6 +1269,19 @@ async fn handle_responses_stream(
"response": response_to_json(&final_resp), "response": response_to_json(&final_resp),
}), }),
)); ));
if let Some(ref t) = trace {
let fc_summaries: Vec<crate::trace::FunctionCallSummary> = calls.iter().map(|fc| crate::trace::FunctionCallSummary {
name: fc.name.clone(), args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(),
}).collect();
t.record_response(0, crate::trace::ResponseSummary {
text_len: 0, thinking_len: last_thinking.len(),
text_preview: String::new(),
finish_reason: Some("tool_calls".to_string()),
function_calls: fc_summaries, grounding: false,
}).await;
t.set_usage(trace_usage).await;
t.finish("tool_call").await;
}
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} }
@@ -1150,6 +1300,16 @@ async fn handle_responses_stream(
) { ) {
yield Ok(evt); yield Ok(evt);
} }
if let Some(ref t) = trace {
t.record_response(0, crate::trace::ResponseSummary {
text_len: last_text.len(),
thinking_len: thinking_text.as_ref().map_or(0, |s| s.len()),
text_preview: last_text.chars().take(200).collect(),
finish_reason: Some("stop".to_string()),
function_calls: Vec::new(), grounding: false,
}).await;
t.finish("completed").await;
}
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} else if !last_thinking.is_empty() { } else if !last_thinking.is_empty() {
@@ -1186,6 +1346,10 @@ async fn handle_responses_stream(
}, },
}), }),
)); ));
if let Some(ref t) = trace {
t.record_error(format!("Upstream: {}", error_msg)).await;
t.finish("upstream_error").await;
}
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} }
@@ -1213,6 +1377,10 @@ async fn handle_responses_stream(
}, },
}), }),
)); ));
if let Some(ref t) = trace {
t.record_error(format!("Timeout: {timeout}s")).await;
t.finish("timeout").await;
}
return; return;
} }

View File

@@ -138,12 +138,29 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
} }
}; };
// Register per-request state — no tools, just generation params for search grounding // Start debug trace
let trace = state.trace.start(&cascade_id, "POST /v1/search", model.name, false);
if let Some(ref t) = trace {
t.set_client_request(crate::trace::ClientRequestSummary {
message_count: 1,
tool_count: 0,
tool_round_count: 0,
user_text_len: body.query.len(),
user_text_preview: body.query.chars().take(200).collect(),
system_prompt: false,
has_image: false,
}).await;
t.start_turn().await;
}
let mitm_gate = std::sync::Arc::new(tokio::sync::Notify::new());
let mitm_gate_clone = mitm_gate.clone();
let (mitm_tx, mut mitm_rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.register_request(crate::mitm::store::RequestContext { state.mitm_store.register_request(crate::mitm::store::RequestContext {
cascade_id: cascade_id.clone(), cascade_id: cascade_id.clone(),
pending_user_text: search_prompt.clone(), pending_user_text: search_prompt.clone(),
event_channel: None, event_channel: mitm_tx,
generation_params: Some(gp), generation_params: Some(gp.clone()),
pending_image: None, pending_image: None,
tools: None, tools: None,
tool_config: None, tool_config: None,
@@ -152,6 +169,9 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
last_function_calls: Vec::new(), last_function_calls: Vec::new(),
call_id_to_name: std::collections::HashMap::new(), call_id_to_name: std::collections::HashMap::new(),
created_at: std::time::Instant::now(), created_at: std::time::Instant::now(),
gate: mitm_gate_clone,
trace_handle: trace.clone(),
trace_turn: 0,
}).await; }).await;
// Send dot to LS — real search prompt injected by MITM proxy // Send dot to LS — real search prompt injected by MITM proxy
@@ -168,32 +188,176 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
); );
} }
// Poll for response // ── Strict timeout cascade ───────────────────────────────────────────────
// 5s gate → MITM didn't match → 502
let gate_matched = tokio::time::timeout(
std::time::Duration::from_secs(5),
mitm_gate.notified(),
).await;
if gate_matched.is_err() {
if state.mitm_enabled {
state.mitm_store.remove_request(&cascade_id).await;
return err_response(
StatusCode::BAD_GATEWAY,
"MITM proxy did not match request within 5s".to_string(),
"mitm_timeout",
);
}
// --no-mitm fallback: use polling
tracing::warn!(cascade = %cascade_id, "MITM gate timeout (--no-mitm mode, falling back to polling)");
}
// ── Channel-based response path (primary) ────────────────────────────────
if state.mitm_enabled {
let timeout = body.timeout;
let mut response_text = String::new();
let mut last_usage: Option<crate::mitm::store::ApiUsage> = None;
let mut retries = 0u32;
const MAX_RETRIES: u32 = 3;
while let Some(event) = tokio::time::timeout(
std::time::Duration::from_secs(timeout),
mitm_rx.recv(),
).await.ok().flatten() {
use crate::mitm::store::MitmEvent;
match event {
MitmEvent::TextDelta(t) => { response_text.push_str(&t); }
MitmEvent::ThinkingDelta(_) => {} // search doesn't use thinking
MitmEvent::Usage(u) => { last_usage = Some(u); }
MitmEvent::Grounding(_) => {} // stored by proxy directly
MitmEvent::FunctionCall(_) => {} // not expected for search
MitmEvent::ResponseComplete => {
// Check if we got actual content — if not, this was a
// thinking-only intermediate response. The LS will make
// a follow-up request; re-register context and keep waiting.
let grounding_peek = state.mitm_store.peek_grounding().await;
if response_text.is_empty() && grounding_peek.is_none() {
retries += 1;
if retries >= MAX_RETRIES {
tracing::warn!(cascade = %cascade_id, retries, "Search: max retries reached with no content — giving up");
break;
}
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
let new_gate = std::sync::Arc::new(tokio::sync::Notify::new());
state.mitm_store.register_request(crate::mitm::store::RequestContext {
cascade_id: cascade_id.clone(),
pending_user_text: search_prompt.clone(),
event_channel: new_tx,
generation_params: Some(gp.clone()),
pending_image: None,
tools: None,
tool_config: None,
pending_tool_results: Vec::new(),
tool_rounds: Vec::new(),
last_function_calls: Vec::new(),
call_id_to_name: std::collections::HashMap::new(),
created_at: std::time::Instant::now(),
gate: new_gate,
trace_handle: trace.clone(),
trace_turn: 0,
}).await;
mitm_rx = new_rx;
tracing::debug!(
cascade = %cascade_id, retries,
"Search: empty response — re-registered context for follow-up"
);
continue;
}
break;
}
MitmEvent::UpstreamError(err) => {
if let Some(ref t) = trace {
t.record_error(format!("Upstream: {}", super::util::upstream_error_message(&err))).await;
t.finish("upstream_error").await;
}
state.mitm_store.remove_request(&cascade_id).await;
return upstream_err_response(&err);
}
}
}
// Extract grounding metadata (stored by dispatch_stream_events)
let grounding = state.mitm_store.take_grounding().await;
state.mitm_store.remove_request(&cascade_id).await;
if response_text.is_empty() && grounding.is_none() {
if let Some(ref t) = trace {
t.record_error(format!("Timeout: no search response after {timeout}s (retries: {retries})")).await;
t.finish("timeout").await;
}
return err_response(
StatusCode::GATEWAY_TIMEOUT,
format!("Timeout: no search response after {timeout}s"),
"upstream_error",
);
}
return {
// Finalize trace for channel-based path
if let Some(ref t) = trace {
t.record_response(0, crate::trace::ResponseSummary {
text_len: response_text.len(), thinking_len: 0,
text_preview: response_text.chars().take(200).collect(),
finish_reason: Some("stop".to_string()),
function_calls: Vec::new(), grounding: grounding.is_some(),
}).await;
if let Some((it, ot)) = last_usage.as_ref().map(|u| (u.input_tokens, u.output_tokens)) {
t.set_usage(crate::trace::TrackedUsage {
input_tokens: it, output_tokens: ot,
thinking_tokens: 0, cache_read: 0,
}).await;
}
t.finish("completed").await;
}
build_search_response(&body.query, model.name, response_text, grounding, last_usage.map(|u| (u.input_tokens, u.output_tokens)))
};
}
// ── Fallback: polling path (--no-mitm only) ──────────────────────────────
let poll_result = poll_for_response(&state, &cascade_id, body.timeout).await; let poll_result = poll_for_response(&state, &cascade_id, body.timeout).await;
if let Some(ref err) = poll_result.upstream_error { if let Some(ref err) = poll_result.upstream_error {
return upstream_err_response(err); return upstream_err_response(err);
} }
// Extract grounding metadata
let grounding = state.mitm_store.take_grounding().await; let grounding = state.mitm_store.take_grounding().await;
// The poll result text contains the model's summary (grounded response)
let response_text = if !poll_result.text.is_empty() { let response_text = if !poll_result.text.is_empty() {
poll_result.text.clone() poll_result.text.clone()
} else { } else {
// Fall back to MITM captured text
state.mitm_store.take_response_text().await.unwrap_or_default() state.mitm_store.take_response_text().await.unwrap_or_default()
}; };
// Clean up
state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.remove_request(&cascade_id).await;
state.mitm_store.clear_response_async().await; state.mitm_store.clear_response_async().await;
// Build the search response // Finalize trace for polling path
if let Some(ref t) = trace {
t.record_response(0, crate::trace::ResponseSummary {
text_len: response_text.len(), thinking_len: 0,
text_preview: response_text.chars().take(200).collect(),
finish_reason: Some("stop".to_string()),
function_calls: Vec::new(), grounding: grounding.is_some(),
}).await;
t.finish("completed").await;
}
build_search_response(&body.query, model.name, response_text, grounding, poll_result.usage.map(|u| (u.input_tokens, u.output_tokens)))
}
fn build_search_response(
query: &str,
model_name: &str,
response_text: String,
grounding: Option<serde_json::Value>,
usage: Option<(u64, u64)>,
) -> axum::response::Response {
use axum::Json;
let mut response = serde_json::json!({ let mut response = serde_json::json!({
"object": "search_result", "object": "search_result",
"query": body.query, "query": query,
"model": model.name, "model": model_name,
"summary": response_text, "summary": response_text,
}); });
@@ -267,11 +431,11 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
} }
// Include usage if available // Include usage if available
if let Some(ref u) = poll_result.usage { if let Some((input, output)) = usage {
response["usage"] = serde_json::json!({ response["usage"] = serde_json::json!({
"input_tokens": u.input_tokens, "input_tokens": input,
"output_tokens": u.output_tokens, "output_tokens": output,
"total_tokens": u.input_tokens + u.output_tokens, "total_tokens": input + output,
}); });
} }

View File

@@ -12,6 +12,7 @@ mod proto;
mod quota; mod quota;
mod session; mod session;
mod standalone; mod standalone;
mod trace;
mod warmup; mod warmup;
use api::AppState; use api::AppState;
@@ -62,6 +63,10 @@ struct Cli {
/// Classic mode — requires a running Antigravity app. Alias for --no-headless. /// Classic mode — requires a running Antigravity app. Alias for --no-headless.
#[arg(long, conflicts_with = "headless")] #[arg(long, conflicts_with = "headless")]
classic: bool, classic: bool,
/// Disable per-call debug traces (on by default, writes JSON to ~/.config/antigravity-proxy/traces/)
#[arg(long)]
no_trace: bool,
} }
#[tokio::main] #[tokio::main]
@@ -272,11 +277,21 @@ async fn main() {
quota_store.clone().start_polling(Arc::clone(&backend)); quota_store.clone().start_polling(Arc::clone(&backend));
info!("Quota monitor started (polling every 60s)"); info!("Quota monitor started (polling every 60s)");
// ── Step 4c: Debug trace collector ────────────────────────────────────────
let trace_enabled = !cli.no_trace;
let trace_collector = trace::TraceCollector::new(trace_enabled);
if trace_enabled {
trace_collector.cleanup_old_traces(7);
info!("Debug tracing enabled → ~/.config/antigravity-proxy/traces/");
}
let state = Arc::new(AppState { let state = Arc::new(AppState {
backend, backend,
sessions: SessionManager::new(), sessions: SessionManager::new(),
mitm_store, mitm_store,
quota_store, quota_store,
mitm_enabled: mitm_handle.is_some(),
trace: trace_collector,
}); });
// Periodic backend refresh — keeps LS connection details fresh // Periodic backend refresh — keeps LS connection details fresh

View File

@@ -89,6 +89,8 @@ pub struct StreamingAccumulator {
pub grounding_metadata: Option<serde_json::Value>, pub grounding_metadata: Option<serde_json::Value>,
/// Buffer for reassembling lines split across TCP reads. /// Buffer for reassembling lines split across TCP reads.
pub pending_data: String, pub pending_data: String,
/// Thinking signature (base64 opaque blob) from non-function-call response parts.
pub thinking_signature: Option<String>,
} }
impl StreamingAccumulator { impl StreamingAccumulator {
@@ -150,8 +152,12 @@ impl StreamingAccumulator {
.as_secs(), .as_secs(),
}); });
} }
// Capture non-thinking response text (skip thoughtSignature parts) // Capture non-thinking response text
else if part.get("thoughtSignature").is_none() { else {
// Capture thoughtSignature from response parts (not function call parts)
if let Some(sig) = part.get("thoughtSignature").and_then(|v| v.as_str()) {
self.thinking_signature = Some(sig.to_string());
}
if let Some(text) = part["text"].as_str() { if let Some(text) = part["text"].as_str() {
if !text.is_empty() { if !text.is_empty() {
self.response_text.push_str(text); self.response_text.push_str(text);
@@ -277,6 +283,7 @@ impl StreamingAccumulator {
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default() .unwrap_or_default()
.as_secs(), .as_secs(),
thinking_signature: self.thinking_signature,
} }
} }
} }
@@ -302,6 +309,7 @@ fn extract_usage_from_message(msg: &Value) -> Option<ApiUsage> {
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default() .unwrap_or_default()
.as_secs(), .as_secs(),
thinking_signature: None,
}) })
} }

View File

@@ -95,6 +95,7 @@ impl GrpcUsage {
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default() .unwrap_or_default()
.as_secs(), .as_secs(),
thinking_signature: None,
} }
} }
} }

View File

@@ -436,143 +436,111 @@ async fn handle_http_over_tls(
// checkpoints) from stealing the RequestContext. // checkpoints) from stealing the RequestContext.
// ── Request modification ───────────────────────────────────── // ── Request modification ─────────────────────────────────────
// Dechunk body → check if agent request → modify → rechunk // Dechunk body → check for our <cid:UUID> nonce → modify → rechunk
//
// Detection is deterministic: if the raw body bytes contain our
// <cid:CASCADE_ID> nonce tag AND we have a pending RequestContext
// for that cascade, it's our request. No JSON parsing needed.
if modify_requests && body_len > 0 { if modify_requests && body_len > 0 {
let body_slice = &request_buf[headers_end..]; let body_slice = &request_buf[headers_end..];
let raw_body = super::modify::dechunk(body_slice); let raw_body = super::modify::dechunk(body_slice);
// Only modify "agent" requests, not "checkpoint" (LS internal) // Fast nonce detection: search raw bytes for <cid:UUID> tag.
let body_str = String::from_utf8_lossy(&raw_body); // This is the sole signal — no requestType check, no
let is_agent = body_str.contains("\"requestType\":\"agent\"") // USER_REQUEST wrapper scanning, no JSON parsing for detection.
|| body_str.contains("\"requestType\": \"agent\""); let nonce_cascade = extract_cascade_hint(&raw_body);
let effective_cascade = nonce_cascade.or(cascade_hint.clone());
if is_agent { // Only take RequestContext if we found our nonce tag
// Re-extract cascade_hint from the dechunked (JSON-parseable) body. let has_nonce = effective_cascade.is_some() && {
// The chunked transfer encoding body at `request_buf[headers_end..]` let body_str = String::from_utf8_lossy(&raw_body);
// can't be JSON-parsed, but `raw_body` (dechunked) can. // The nonce is `<cid:UUID>` — check raw bytes
let precise_cascade = extract_cascade_hint(&raw_body); if let Some(ref cid) = effective_cascade {
body_str.contains(&format!("<cid:{}>", cid))
} else {
false
}
};
let mut request_ctx: Option<super::store::RequestContext> = if has_nonce {
debug!( debug!(
cascade = ?precise_cascade, cascade = ?effective_cascade,
"MITM: cascade from dechunked requestId" "MITM: nonce matched — taking RequestContext"
); );
let ctx = if let Some(ref cid) = effective_cascade {
// Check if ANY user message contains our dummy dot prompt store.take_request(cid).await
// within a <USER_REQUEST> wrapper. } else {
// Only then should we consume the pending RequestContext. None
// This prevents LS internal requests (title gen, etc.) from };
// consuming the context meant for the user's actual request. if ctx.is_some() {
// NOTE: We check ALL user messages because the LS appends context ctx
// messages AFTER the dot prompt (conversation summaries, etc.). } else if let Some(ref cid) = effective_cascade {
// We look for <USER_REQUEST> + dot specifically to avoid matching // Subsequent turn of an already-processed cascade
// old <cid:> markers in history (which are in model messages). if store.has_cascade_cache(cid).await {
let contains_our_dot = serde_json::from_slice::<serde_json::Value>(&raw_body) debug!(cascade = %cid, "MITM: subsequent turn — using cached context");
.ok()
.and_then(|json| {
let contents = json.pointer("/request/contents")?.as_array()?;
for msg in contents.iter() {
let is_user = msg.get("role")
.and_then(|r| r.as_str())
.map_or(true, |r| r == "user");
if !is_user { continue; }
if let Some(text) = msg.pointer("/parts/0/text").and_then(|v| v.as_str()) {
// Check for dot in <USER_REQUEST> wrapper
if text.contains("<USER_REQUEST>") {
if let (Some(s), Some(e)) = (text.find("<USER_REQUEST>"), text.find("</USER_REQUEST>")) {
let inner = &text[s + 14..e]; // 14 = len("<USER_REQUEST>")
let it = inner.trim();
if it == "." || it.starts_with(".<cid:") {
return Some(true);
}
}
}
// Check for bare dot (no wrapper)
let t = text.trim();
if t == "." || t == ".<cid:" || (t.starts_with(".<cid:") && t.ends_with(">")) {
return Some(true);
}
}
}
Some(false)
})
.unwrap_or(false);
// Only take the RequestContext if this request has our dot
let effective_cascade = precise_cascade.or(cascade_hint.clone());
let mut request_ctx: Option<super::store::RequestContext> = if contains_our_dot {
let ctx = if let Some(ref cid) = effective_cascade {
store.take_request(cid).await
} else {
None None
};
if ctx.is_some() {
ctx
} else if let Some(ref cid) = effective_cascade {
// Check if this is a subsequent turn (turn 1+) of an
// already-processed cascade. If so, DON'T fall through
// to take_latest_request — that would steal an unrelated
// cascade's context.
if store.has_cascade_cache(cid).await {
debug!(cascade = %cid, "MITM: subsequent turn — using cached context");
None
} else {
// Unknown cascade, try latest fallback
store.take_latest_request().await
}
} else { } else {
// Unknown cascade with our nonce, try latest fallback
store.take_latest_request().await store.take_latest_request().await
} }
} else { } else {
None store.take_latest_request().await
};
// Extract event channel from matched context
if let Some(ref mut ctx) = request_ctx {
event_tx = ctx.event_channel.take();
} }
} else {
// No nonce → LS internal request (title gen, checkpoint, etc.)
// Don't touch it.
None
};
// Build ToolContext from RequestContext (turn 0) or cached // Extract event channel from matched context
// context (turn 1+). On turn 0, we also cache the context if let Some(ref ctx) = request_ctx {
// for subsequent turns. event_tx = Some(ctx.event_channel.clone());
let tool_ctx = if let Some(ctx) = request_ctx.take() { }
// Turn 0: cache context for subsequent turns
if let Some(ref cid) = effective_cascade { // Build ToolContext from RequestContext (turn 0) or cached
store.cache_cascade(cid, super::store::CascadeCache { // context (turn 1+). On turn 0, we also cache the context
user_text: ctx.pending_user_text.clone(), // for subsequent turns.
tools: ctx.tools.clone(), let tool_ctx = if let Some(ctx) = request_ctx.take() {
tool_config: ctx.tool_config.clone(), // Turn 0: cache context for subsequent turns
generation_params: ctx.generation_params.clone(), if let Some(ref cid) = effective_cascade {
}).await; store.cache_cascade(cid, super::store::CascadeCache {
} user_text: ctx.pending_user_text.clone(),
tools: ctx.tools.clone(),
tool_config: ctx.tool_config.clone(),
generation_params: ctx.generation_params.clone(),
}).await;
}
Some(super::modify::ToolContext {
pending_user_text: ctx.pending_user_text,
tools: ctx.tools,
tool_config: ctx.tool_config,
pending_results: ctx.pending_tool_results,
last_calls: ctx.last_function_calls,
generation_params: ctx.generation_params,
pending_image: ctx.pending_image,
tool_rounds: ctx.tool_rounds,
})
} else if let Some(ref cid) = effective_cascade {
// Turn 1+: rebuild lite ToolContext from cache
if let Some(cached) = store.get_cascade_cache(cid).await {
Some(super::modify::ToolContext { Some(super::modify::ToolContext {
pending_user_text: ctx.pending_user_text, pending_user_text: cached.user_text,
tools: ctx.tools, tools: cached.tools,
tool_config: ctx.tool_config, tool_config: cached.tool_config,
pending_results: ctx.pending_tool_results, pending_results: vec![],
last_calls: ctx.last_function_calls, last_calls: vec![],
generation_params: ctx.generation_params, generation_params: cached.generation_params,
pending_image: ctx.pending_image, pending_image: None,
tool_rounds: ctx.tool_rounds, tool_rounds: vec![],
}) })
} else if let Some(ref cid) = effective_cascade {
// Turn 1+: rebuild lite ToolContext from cache
if let Some(cached) = store.get_cascade_cache(cid).await {
Some(super::modify::ToolContext {
pending_user_text: cached.user_text,
tools: cached.tools,
tool_config: cached.tool_config,
pending_results: vec![],
last_calls: vec![],
generation_params: cached.generation_params,
pending_image: None,
tool_rounds: vec![],
})
} else {
None
}
} else { } else {
None None
}; }
} else {
None
};
if tool_ctx.is_some() || has_nonce {
if let Some(modified_body) = if let Some(modified_body) =
super::modify::modify_request(&raw_body, tool_ctx.as_ref()) super::modify::modify_request(&raw_body, tool_ctx.as_ref())
{ {
@@ -1014,6 +982,29 @@ async fn dispatch_stream_events(
let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await; let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await;
} }
if acc.is_complete { if acc.is_complete {
// Send usage BEFORE ResponseComplete so handlers have it when processing completion
if acc.output_tokens > 0 || acc.input_tokens > 0 {
let usage_snapshot = super::store::ApiUsage {
input_tokens: acc.input_tokens,
output_tokens: acc.output_tokens,
cache_creation_input_tokens: acc.cache_creation_input_tokens,
cache_read_input_tokens: acc.cache_read_input_tokens,
thinking_output_tokens: acc.thinking_tokens,
thinking_text: None,
response_text: None,
response_output_tokens: 0,
model: acc.model.clone(),
stop_reason: acc.stop_reason.clone(),
api_provider: acc.api_provider.clone().unwrap_or_else(|| "unknown".to_string()).into(),
grpc_method: None,
captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
thinking_signature: acc.thinking_signature.clone(),
};
let _ = tx.send(super::store::MitmEvent::Usage(usage_snapshot)).await;
}
info!( info!(
response_text_len = acc.response_text.len(), response_text_len = acc.response_text.len(),
thinking_text_len = acc.thinking_text.len(), thinking_text_len = acc.thinking_text.len(),

View File

@@ -45,6 +45,10 @@ pub struct ApiUsage {
pub grpc_method: Option<String>, pub grpc_method: Option<String>,
/// Timestamp when this usage was captured. /// Timestamp when this usage was captured.
pub captured_at: u64, pub captured_at: u64,
/// Thinking signature from Google's response (base64 opaque blob).
/// Required for multi-turn with thinking models.
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_signature: Option<String>,
} }
/// A captured function call from Google's API response. /// A captured function call from Google's API response.
@@ -188,8 +192,7 @@ pub struct RequestContext {
/// Real user text for MITM injection (LS receives "." instead). /// Real user text for MITM injection (LS receives "." instead).
pub pending_user_text: String, pub pending_user_text: String,
/// Event channel for real-time streaming from MITM → API handler. /// Event channel for real-time streaming from MITM → API handler.
/// Only present when custom tools are active. pub event_channel: mpsc::Sender<MitmEvent>,
pub event_channel: Option<mpsc::Sender<MitmEvent>>,
/// Client-specified generation parameters (temperature, top_p, etc.). /// Client-specified generation parameters (temperature, top_p, etc.).
pub generation_params: Option<GenerationParams>, pub generation_params: Option<GenerationParams>,
/// Image to inject into the Google API request. /// Image to inject into the Google API request.
@@ -208,6 +211,13 @@ pub struct RequestContext {
pub call_id_to_name: HashMap<String, String>, pub call_id_to_name: HashMap<String, String>,
/// When this context was created (for TTL cleanup). /// When this context was created (for TTL cleanup).
pub created_at: Instant, pub created_at: Instant,
/// Gate: signaled when MITM takes this context.
/// API handlers wait on this with a timeout to detect match failures.
pub gate: Arc<tokio::sync::Notify>,
/// Debug trace handle (if tracing is enabled).
pub trace_handle: Option<crate::trace::TraceHandle>,
/// Current turn index in the trace (for multi-turn tracking).
pub trace_turn: usize,
} }
// ─── MitmStore ─────────────────────────────────────────────────────────────── // ─── MitmStore ───────────────────────────────────────────────────────────────
@@ -295,8 +305,9 @@ impl MitmStore {
/// Called by the MITM proxy when intercepting the LS's outbound request. /// Called by the MITM proxy when intercepting the LS's outbound request.
pub async fn take_request(&self, cascade_id: &str) -> Option<RequestContext> { pub async fn take_request(&self, cascade_id: &str) -> Option<RequestContext> {
let ctx = self.pending_requests.write().await.remove(cascade_id); let ctx = self.pending_requests.write().await.remove(cascade_id);
if ctx.is_some() { if let Some(ref c) = ctx {
debug!(cascade = %cascade_id, "Took request context"); c.gate.notify_one();
debug!(cascade = %cascade_id, "Took request context (gate signaled)");
} }
ctx ctx
} }
@@ -315,8 +326,9 @@ impl MitmStore {
.map(|(k, _)| k.clone()); .map(|(k, _)| k.clone());
if let Some(key) = latest_key { if let Some(key) = latest_key {
let ctx = pending.remove(&key); let ctx = pending.remove(&key);
if ctx.is_some() { if let Some(ref c) = ctx {
debug!(cascade = %key, "Took latest request context (fallback)"); c.gate.notify_one();
debug!(cascade = %key, "Took latest request context (fallback, gate signaled)");
} }
ctx ctx
} else { } else {
@@ -577,12 +589,42 @@ impl MitmStore {
// ── Compat shims for streaming tool-call loops ────────────────────── // ── Compat shims for streaming tool-call loops ──────────────────────
/// Update the event channel on an existing request context. /// Update the event channel on an existing request context,
/// Used by streaming loop handlers when re-registering for a new tool round. /// or re-register a minimal context if it was already consumed by `take_request`.
///
/// This is critical for thinking-only intermediate responses: the MITM proxy
/// consumes the context via `take_request`, but the handler needs to re-install
/// a channel for the LS's follow-up request.
pub async fn set_channel(&self, cascade_id: &str, tx: mpsc::Sender<MitmEvent>) { pub async fn set_channel(&self, cascade_id: &str, tx: mpsc::Sender<MitmEvent>) {
self.update_request(cascade_id, |ctx| { let updated = self.update_request(cascade_id, |ctx| {
ctx.event_channel = Some(tx); ctx.event_channel = tx.clone();
}).await; }).await;
if !updated {
// Context was already consumed — re-register a minimal one
// so the MITM proxy can match the follow-up request.
let gate = std::sync::Arc::new(tokio::sync::Notify::new());
self.register_request(RequestContext {
cascade_id: cascade_id.to_string(),
pending_user_text: String::new(),
event_channel: tx,
generation_params: None,
pending_image: None,
tools: None,
tool_config: None,
pending_tool_results: Vec::new(),
tool_rounds: Vec::new(),
last_function_calls: Vec::new(),
call_id_to_name: std::collections::HashMap::new(),
created_at: std::time::Instant::now(),
gate,
trace_handle: None,
trace_turn: 0,
}).await;
tracing::debug!(
cascade = cascade_id,
"set_channel: re-registered minimal context (original was consumed)"
);
}
} }
/// No-op. Upstream errors are now delivered through the event channel. /// No-op. Upstream errors are now delivered through the event channel.

509
src/trace.rs Normal file
View File

@@ -0,0 +1,509 @@
//! Per-call debug trace system.
//!
//! Every API call gets a structured JSON trace file written to
//! `~/.config/antigravity-proxy/traces/{YYYY-MM-DD}/{HH-MM-SS}_{cascade_short}.json`.
//!
//! Designed for LLM consumption: compact, structured, no raw bodies.
use serde::Serialize;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Mutex;
/// Shared trace state for `AppState`.
#[derive(Clone)]
pub struct TraceCollector {
enabled: bool,
traces_dir: PathBuf,
}
impl TraceCollector {
pub fn new(enabled: bool) -> Self {
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
let traces_dir = PathBuf::from(home)
.join(".config")
.join("antigravity-proxy")
.join("traces");
Self {
enabled,
traces_dir,
}
}
/// Whether tracing is enabled.
pub fn enabled(&self) -> bool {
self.enabled
}
/// Start a new trace for an API call. Returns `None` if tracing is disabled.
pub fn start(&self, cascade_id: &str, endpoint: &str, model: &str, stream: bool) -> Option<TraceHandle> {
if !self.enabled {
return None;
}
let now = chrono::Utc::now();
Some(TraceHandle {
inner: Arc::new(Mutex::new(TraceData {
cascade_id: cascade_id.to_string(),
endpoint: endpoint.to_string(),
model: model.to_string(),
stream,
started_at: now.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string(),
finished_at: None,
duration_ms: 0,
outcome: "in_progress".to_string(),
client_request: None,
turns: Vec::new(),
usage: None,
errors: Vec::new(),
})),
traces_dir: self.traces_dir.clone(),
started_at_chrono: now,
started_instant: Instant::now(),
})
}
/// Delete trace directories older than `max_age_days`.
pub fn cleanup_old_traces(&self, max_age_days: u32) {
if !self.enabled {
return;
}
let Ok(entries) = std::fs::read_dir(&self.traces_dir) else {
return;
};
let cutoff = chrono::Utc::now() - chrono::Duration::days(max_age_days as i64);
let cutoff_str = cutoff.format("%Y-%m-%d").to_string();
for entry in entries.flatten() {
let name = entry.file_name().to_string_lossy().to_string();
// Directory names are YYYY-MM-DD — lexicographic comparison works
if name < cutoff_str {
if let Err(e) = std::fs::remove_dir_all(entry.path()) {
tracing::warn!(dir = %name, error = %e, "trace: failed to remove old trace dir");
} else {
tracing::info!(dir = %name, "trace: cleaned up old trace dir");
}
}
}
}
}
/// Handle to a single in-flight trace. Cloneable, thread-safe.
#[derive(Clone)]
pub struct TraceHandle {
inner: Arc<Mutex<TraceData>>,
traces_dir: PathBuf,
started_at_chrono: chrono::DateTime<chrono::Utc>,
started_instant: Instant,
}
impl std::fmt::Debug for TraceHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TraceHandle")
.field("traces_dir", &self.traces_dir)
.finish()
}
}
impl TraceHandle {
/// Record the client request summary.
pub async fn set_client_request(&self, req: ClientRequestSummary) {
self.inner.lock().await.client_request = Some(req);
}
/// Start a new turn (thinking/tool-call round). Returns the turn index.
pub async fn start_turn(&self) -> usize {
let mut data = self.inner.lock().await;
let idx = data.turns.len();
data.turns.push(TraceTurn {
turn: idx,
mitm_matched: false,
gate_wait_ms: None,
modify_summary: None,
request_bytes: None,
upstream_wait_ms: None,
response: None,
events_sent: Vec::new(),
handler_action: None,
});
idx
}
/// Record MITM match for a turn.
pub async fn record_mitm_match(&self, turn: usize, gate_wait_ms: u64) {
let mut data = self.inner.lock().await;
if let Some(t) = data.turns.get_mut(turn) {
t.mitm_matched = true;
t.gate_wait_ms = Some(gate_wait_ms);
}
}
/// Record MITM modify summary for a turn.
pub async fn record_modify(&self, turn: usize, summary: String, original: u64, modified: u64) {
let mut data = self.inner.lock().await;
if let Some(t) = data.turns.get_mut(turn) {
t.modify_summary = Some(summary);
t.request_bytes = Some(RequestBytes { original, modified });
}
}
/// Record upstream wait time.
pub async fn record_upstream_wait(&self, turn: usize, wait_ms: u64) {
let mut data = self.inner.lock().await;
if let Some(t) = data.turns.get_mut(turn) {
t.upstream_wait_ms = Some(wait_ms);
}
}
/// Record the response summary for a turn.
pub async fn record_response(&self, turn: usize, resp: ResponseSummary) {
let mut data = self.inner.lock().await;
if let Some(t) = data.turns.get_mut(turn) {
t.response = Some(resp);
}
}
/// Record an event sent via channel.
pub async fn record_event(&self, turn: usize, event_name: &str) {
let mut data = self.inner.lock().await;
if let Some(t) = data.turns.get_mut(turn) {
t.events_sent.push(event_name.to_string());
}
}
/// Record the handler action for a turn.
pub async fn record_action(&self, turn: usize, action: &str) {
let mut data = self.inner.lock().await;
if let Some(t) = data.turns.get_mut(turn) {
t.handler_action = Some(action.to_string());
}
}
/// Record an error.
pub async fn record_error(&self, error: String) {
self.inner.lock().await.errors.push(error);
}
/// Record final usage.
pub async fn set_usage(&self, usage: TrackedUsage) {
self.inner.lock().await.usage = Some(usage);
}
/// Finalize the trace and write to disk as a per-call folder.
///
/// Layout: `traces/{YYYY-MM-DD}/{HH-MM-SS}_{cascade_short}/`
/// - `summary.md` — always written, rich LLM-readable overview
/// - `request.json` — client request details (always)
/// - `turns.json` — per-turn MITM/gate/modify data (always)
/// - `response.json` — response summary + usage (if present)
/// - `events.json` — channel events (if non-empty)
/// - `errors.json` — errors (if any)
pub async fn finish(&self, outcome: &str) {
let mut data = self.inner.lock().await;
let now = chrono::Utc::now();
data.finished_at = Some(now.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string());
data.duration_ms = self.started_instant.elapsed().as_millis() as u64;
data.outcome = outcome.to_string();
// Build folder path: traces/{YYYY-MM-DD}/{HH-MM-SS}_{cascade_short}/
let date_str = self.started_at_chrono.format("%Y-%m-%d").to_string();
let time_str = self.started_at_chrono.format("%H-%M-%S%.3f").to_string();
let cascade_short = &data.cascade_id[..8.min(data.cascade_id.len())];
let dir = self.traces_dir.join(&date_str).join(format!("{}_{}", time_str, cascade_short));
// Build all file contents while holding lock
let summary = generate_summary(&data);
let request_json = serde_json::to_string_pretty(&data.client_request).unwrap_or_default();
let turns_json = serde_json::to_string_pretty(&data.turns).unwrap_or_default();
let response_json = if data.usage.is_some() || data.turns.iter().any(|t| t.response.is_some()) {
let resp = ResponseFile {
usage: data.usage.clone(),
};
Some(serde_json::to_string_pretty(&resp).unwrap_or_default())
} else {
None
};
let events_json = {
let all_events: Vec<_> = data.turns.iter()
.enumerate()
.filter(|(_, t)| !t.events_sent.is_empty())
.map(|(i, t)| serde_json::json!({ "turn": i, "events": t.events_sent }))
.collect();
if all_events.is_empty() { None }
else { Some(serde_json::to_string_pretty(&all_events).unwrap_or_default()) }
};
let errors_json = if data.errors.is_empty() { None }
else { Some(serde_json::to_string_pretty(&data.errors).unwrap_or_default()) };
// Build meta.txt for grep
let meta_txt = format!(
"cascade={} endpoint={} model={} outcome={} duration={}ms stream={}",
cascade_short, data.endpoint, data.model, data.outcome, data.duration_ms, data.stream
);
drop(data); // release lock before I/O
tokio::spawn(async move {
if let Err(e) = tokio::fs::create_dir_all(&dir).await {
tracing::warn!(error = %e, "trace: failed to create dir");
return;
}
// Always write summary + request + turns + meta
let _ = tokio::fs::write(dir.join("summary.md"), summary).await;
let _ = tokio::fs::write(dir.join("request.json"), request_json).await;
let _ = tokio::fs::write(dir.join("turns.json"), turns_json).await;
let _ = tokio::fs::write(dir.join("meta.txt"), meta_txt).await;
// Conditionally write response, events, errors
if let Some(j) = response_json {
let _ = tokio::fs::write(dir.join("response.json"), j).await;
}
if let Some(j) = events_json {
let _ = tokio::fs::write(dir.join("events.json"), j).await;
}
if let Some(j) = errors_json {
let _ = tokio::fs::write(dir.join("errors.json"), j).await;
}
tracing::debug!(path = %dir.display(), "trace: folder written");
});
}
}
// ── Summary generation ─────────────────────────────────────────────────
#[derive(Serialize)]
struct ResponseFile {
#[serde(skip_serializing_if = "Option::is_none")]
usage: Option<TrackedUsage>,
}
/// Build a rich markdown summary from trace data.
fn generate_summary(data: &TraceData) -> String {
let mut s = String::with_capacity(2048);
let cascade_short = &data.cascade_id[..8.min(data.cascade_id.len())];
// Header
s.push_str(&format!("# Trace: {}{}\n\n", cascade_short, data.endpoint));
// Overview table
s.push_str("| Field | Value |\n|-------|-------|\n");
s.push_str(&format!("| Cascade ID | `{}` |\n", data.cascade_id));
s.push_str(&format!("| Model | {} |\n", data.model));
s.push_str(&format!("| Stream | {} |\n", data.stream));
s.push_str(&format!("| Started | {} |\n", data.started_at));
if let Some(ref fin) = data.finished_at {
s.push_str(&format!("| Finished | {} |\n", fin));
}
s.push_str(&format!("| Duration | {}ms |\n", data.duration_ms));
s.push_str(&format!("| Outcome | **{}** |\n", data.outcome));
s.push('\n');
// Client request
s.push_str("## Client Request\n\n");
if let Some(ref req) = data.client_request {
s.push_str(&format!("- **Messages:** {} (user text: {} chars)\n", req.message_count, req.user_text_len));
if !req.user_text_preview.is_empty() {
s.push_str(&format!("- **Preview:** `{}`\n", req.user_text_preview));
}
s.push_str(&format!("- **Tools:** {} | **Tool rounds:** {}\n", req.tool_count, req.tool_round_count));
if req.system_prompt { s.push_str("- **System prompt:** yes\n"); }
s.push_str(&format!("- **Image:** {}\n", if req.has_image { "yes" } else { "no" }));
} else {
s.push_str("(not recorded)\n");
}
s.push_str("\n→ Full details in [request.json](./request.json)\n\n");
// Turns
s.push_str(&format!("## Turns ({} total)\n\n", data.turns.len()));
for turn in &data.turns {
s.push_str(&format!("### Turn {}\n\n", turn.turn));
// MITM match
if turn.mitm_matched {
s.push_str(&format!("- **MITM matched:** ✓ (gate wait: {}ms)\n",
turn.gate_wait_ms.unwrap_or(0)));
} else {
s.push_str("- **MITM matched:** ✗\n");
}
// Modify
if let Some(ref mod_sum) = turn.modify_summary {
s.push_str(&format!("- **Modify:** {}", mod_sum));
if let Some(ref bytes) = turn.request_bytes {
s.push_str(&format!(" ({}B → {}B)", bytes.original, bytes.modified));
}
s.push('\n');
}
// Upstream wait
if let Some(wait) = turn.upstream_wait_ms {
s.push_str(&format!("- **Upstream wait:** {}ms\n", wait));
}
// Response
if let Some(ref resp) = turn.response {
s.push_str(&format!("- **Response:** {} chars text, {} chars thinking",
resp.text_len, resp.thinking_len));
if let Some(ref fr) = resp.finish_reason {
s.push_str(&format!(", finish_reason={}", fr));
}
if !resp.function_calls.is_empty() {
let names: Vec<&str> = resp.function_calls.iter().map(|f| f.name.as_str()).collect();
s.push_str(&format!(", tool_calls=[{}]", names.join(", ")));
}
if resp.grounding {
s.push_str(", grounding=yes");
}
s.push('\n');
if !resp.text_preview.is_empty() {
s.push_str(&format!("- **Output preview:** `{}`\n", resp.text_preview));
}
}
// Events
if !turn.events_sent.is_empty() {
s.push_str(&format!("- **Events:** {} sent ({})\n",
turn.events_sent.len(),
turn.events_sent.join(", ")));
}
// Handler action
if let Some(ref action) = turn.handler_action {
s.push_str(&format!("- **Action:** {}\n", action));
}
s.push('\n');
}
if data.turns.iter().any(|t| t.response.is_some()) {
s.push_str("→ Full turn details in [turns.json](./turns.json)\n\n");
}
// Usage
if let Some(ref u) = data.usage {
s.push_str("## Usage\n\n");
s.push_str(&format!("| Metric | Tokens |\n|--------|--------|\n"));
s.push_str(&format!("| Input | {} |\n", u.input_tokens));
s.push_str(&format!("| Output | {} |\n", u.output_tokens));
if u.thinking_tokens > 0 {
s.push_str(&format!("| Thinking | {} |\n", u.thinking_tokens));
}
if u.cache_read > 0 {
s.push_str(&format!("| Cache read | {} |\n", u.cache_read));
}
s.push_str("\n→ Full details in [response.json](./response.json)\n\n");
}
// Errors
if !data.errors.is_empty() {
s.push_str("## Errors\n\n");
for err in &data.errors {
s.push_str(&format!("- ❌ {}\n", err));
}
s.push_str("\n→ Full details in [errors.json](./errors.json)\n\n");
}
// Files index
s.push_str("## Files\n\n");
s.push_str("| File | Contains |\n|------|----------|\n");
s.push_str("| [request.json](./request.json) | Client request summary |\n");
s.push_str("| [turns.json](./turns.json) | Per-turn MITM/gate/modify/response data |\n");
if data.usage.is_some() || data.turns.iter().any(|t| t.response.is_some()) {
s.push_str("| [response.json](./response.json) | Response summaries + token usage |\n");
}
if data.turns.iter().any(|t| !t.events_sent.is_empty()) {
s.push_str("| [events.json](./events.json) | Channel events per turn |\n");
}
if !data.errors.is_empty() {
s.push_str("| [errors.json](./errors.json) | Error messages |\n");
}
s
}
// ── Serializable data structures ───────────────────────────────────────
#[derive(Serialize)]
struct TraceData {
cascade_id: String,
endpoint: String,
model: String,
stream: bool,
started_at: String,
#[serde(skip_serializing_if = "Option::is_none")]
finished_at: Option<String>,
duration_ms: u64,
outcome: String,
#[serde(skip_serializing_if = "Option::is_none")]
client_request: Option<ClientRequestSummary>,
turns: Vec<TraceTurn>,
#[serde(skip_serializing_if = "Option::is_none")]
usage: Option<TrackedUsage>,
#[serde(skip_serializing_if = "Vec::is_empty")]
errors: Vec<String>,
}
#[derive(Serialize, Clone)]
pub struct ClientRequestSummary {
pub message_count: usize,
pub tool_count: usize,
pub tool_round_count: usize,
pub user_text_len: usize,
pub user_text_preview: String,
pub system_prompt: bool,
pub has_image: bool,
}
#[derive(Serialize)]
struct TraceTurn {
turn: usize,
mitm_matched: bool,
#[serde(skip_serializing_if = "Option::is_none")]
gate_wait_ms: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
modify_summary: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
request_bytes: Option<RequestBytes>,
#[serde(skip_serializing_if = "Option::is_none")]
upstream_wait_ms: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
response: Option<ResponseSummary>,
#[serde(skip_serializing_if = "Vec::is_empty")]
events_sent: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
handler_action: Option<String>,
}
#[derive(Serialize)]
struct RequestBytes {
original: u64,
modified: u64,
}
#[derive(Serialize, Clone)]
pub struct ResponseSummary {
pub text_len: usize,
pub thinking_len: usize,
#[serde(skip_serializing_if = "String::is_empty")]
pub text_preview: String,
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub function_calls: Vec<FunctionCallSummary>,
pub grounding: bool,
}
#[derive(Serialize, Clone)]
pub struct FunctionCallSummary {
pub name: String,
pub args_preview: String,
}
#[derive(Serialize, Clone)]
pub struct TrackedUsage {
pub input_tokens: u64,
pub output_tokens: u64,
pub thinking_tokens: u64,
pub cache_read: u64,
}