From 28d3296c87a08a4a55ee8018213a8b184cb71989 Mon Sep 17 00:00:00 2001 From: Nikketryhard Date: Wed, 18 Feb 2026 01:31:18 -0600 Subject: [PATCH] fix: gemini route, usage capture, search timeout, and trace finalization - Add missing /v1/gemini POST route and handler - Capture MitmEvent::Usage in gemini sync/streaming handlers - Add retry counter (max 3) to search handler to prevent hang - Add trace finalization at all gemini_sync channel exit points - Fix UpstreamError trace outcome label - Add timeout trace with error recording - Dispatch Usage before ResponseComplete in SSE flush --- src/api/completions.rs | 146 +++++++++++- src/api/gemini.rs | 342 ++++++++++++++++++++++----- src/api/mod.rs | 5 + src/api/responses.rs | 188 ++++++++++++++- src/api/search.rs | 194 ++++++++++++++-- src/main.rs | 15 ++ src/mitm/intercept.rs | 12 +- src/mitm/proto.rs | 1 + src/mitm/proxy.rs | 227 +++++++++--------- src/mitm/store.rs | 62 ++++- src/trace.rs | 509 +++++++++++++++++++++++++++++++++++++++++ 11 files changed, 1480 insertions(+), 221 deletions(-) create mode 100644 src/trace.rs diff --git a/src/api/completions.rs b/src/api/completions.rs index 8240370..38f9feb 100644 --- a/src/api/completions.rs +++ b/src/api/completions.rs @@ -435,21 +435,33 @@ pub(crate) async fn handle_completions( .map(|r| r.calls.clone()) .unwrap_or_default(); - // Build event channel for streaming - let has_custom_tools = tools.is_some(); - let (mitm_rx, event_tx) = if has_custom_tools && body.stream { - let (tx, rx) = tokio::sync::mpsc::channel(64); - (Some(rx), Some(tx)) - } else { - (None, None) - }; + // Build event channel — always created for MITM response path + let (tx, rx) = tokio::sync::mpsc::channel(64); + let (mitm_rx, event_tx) = (Some(rx), tx); // Build pending tool results from latest round let pending_tool_results = tool_rounds.last() .map(|r| r.results.clone()) .unwrap_or_default(); - // Register all per-request state atomically + // Start debug trace + let trace = state.trace.start(&cascade_id, "POST /v1/chat/completions", model_name, body.stream); + if let Some(ref t) = trace { + t.set_client_request(crate::trace::ClientRequestSummary { + message_count: body.messages.len(), + tool_count: body.tools.as_ref().map_or(0, |t| t.len()), + tool_round_count: tool_rounds.len(), + user_text_len: user_text.len(), + user_text_preview: user_text.chars().take(200).collect(), + system_prompt: body.messages.iter().any(|m| m.role == "system"), + has_image: image.is_some(), + }).await; + // Start turn 0 + t.start_turn().await; + } + + let mitm_gate = std::sync::Arc::new(tokio::sync::Notify::new()); + let mitm_gate_clone = mitm_gate.clone(); state.mitm_store.register_request(crate::mitm::store::RequestContext { cascade_id: cascade_id.clone(), pending_user_text: user_text.clone(), @@ -463,6 +475,9 @@ pub(crate) async fn handle_completions( last_function_calls, call_id_to_name, created_at: std::time::Instant::now(), + gate: mitm_gate_clone, + trace_handle: trace.clone(), + trace_turn: 0, }).await; // Send REAL user text to LS @@ -480,6 +495,7 @@ pub(crate) async fn handle_completions( } Ok((status, _)) => { state.mitm_store.remove_request(&cascade_id).await; + if let Some(ref t) = trace { t.record_error(format!("Backend returned {status}")).await; t.finish("backend_error").await; } return err_response( StatusCode::BAD_GATEWAY, format!("Backend returned {status}"), @@ -488,6 +504,7 @@ pub(crate) async fn handle_completions( } Err(e) => { state.mitm_store.remove_request(&cascade_id).await; + if let Some(ref t) = trace { t.record_error(format!("Send failed: {e}")).await; t.finish("send_error").await; } return err_response( StatusCode::BAD_GATEWAY, format!("Send failed: {e}"), @@ -496,6 +513,34 @@ pub(crate) async fn handle_completions( } } + // Wait for MITM gate: 5s → 502 if MITM enabled + let gate_start = std::time::Instant::now(); + let gate_matched = tokio::time::timeout( + std::time::Duration::from_secs(5), + mitm_gate.notified(), + ).await; + let gate_wait_ms = gate_start.elapsed().as_millis() as u64; + if gate_matched.is_err() { + if state.mitm_enabled { + state.mitm_store.remove_request(&cascade_id).await; + if let Some(ref t) = trace { + t.record_error("MITM gate timeout (5s)".to_string()).await; + t.finish("mitm_timeout").await; + } + return err_response( + StatusCode::BAD_GATEWAY, + "MITM proxy did not match request within 5s".to_string(), + "mitm_timeout", + ); + } + warn!(cascade = %cascade_id, "MITM gate timeout (--no-mitm mode)"); + } else { + debug!(cascade = %cascade_id, gate_wait_ms, "MITM gate signaled — request matched"); + if let Some(ref t) = trace { + t.record_mitm_match(0, gate_wait_ms).await; + } + } + let completion_id = format!( "chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace('-', "") @@ -515,6 +560,7 @@ pub(crate) async fn handle_completions( body.timeout, include_usage, mitm_rx, + trace, ) .await } else if n <= 1 { @@ -524,6 +570,7 @@ pub(crate) async fn handle_completions( model_name.to_string(), cascade_id, body.timeout, + trace, ) .await } else { @@ -653,6 +700,7 @@ async fn chat_completions_stream( timeout: u64, include_usage: bool, mitm_rx: Option>, + trace: Option, ) -> axum::response::Response { let stream = async_stream::stream! { let start = std::time::Instant::now(); @@ -774,6 +822,21 @@ async fn chat_completions_stream( } yield Ok(Event::default().data("[DONE]")); state.mitm_store.remove_request(&cascade_id).await; + if let Some(ref t) = trace { + let (ipt, opt, crt2, tht) = if let Some(ref u) = last_usage { + (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) + } else { (0, 0, 0, 0) }; + t.record_response(0, crate::trace::ResponseSummary { + text_len: 0, thinking_len: 0, text_preview: String::new(), + finish_reason: Some("tool_calls".to_string()), + function_calls: calls.iter().map(|fc| crate::trace::FunctionCallSummary { + name: fc.name.clone(), args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(), + }).collect(), + grounding: false, + }).await; + t.set_usage(crate::trace::TrackedUsage { input_tokens: ipt, output_tokens: opt, thinking_tokens: tht, cache_read: crt2 }).await; + t.finish("tool_call").await; + } return; } MitmEvent::ResponseComplete => { @@ -802,6 +865,19 @@ async fn chat_completions_stream( } yield Ok(Event::default().data("[DONE]")); state.mitm_store.remove_request(&cascade_id).await; + if let Some(ref t) = trace { + let (ipt, opt, crt2, tht) = if let Some(ref u) = mitm { + (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) + } else { (0, 0, 0, 0) }; + t.record_response(0, crate::trace::ResponseSummary { + text_len: acc_text.len(), thinking_len: acc_thinking.len(), + text_preview: acc_text.chars().take(200).collect(), + finish_reason: Some("stop".to_string()), + function_calls: Vec::new(), grounding: false, + }).await; + t.set_usage(crate::trace::TrackedUsage { input_tokens: ipt, output_tokens: opt, thinking_tokens: tht, cache_read: crt2 }).await; + t.finish("completed").await; + } return; } else if !acc_thinking.is_empty() && !did_unblock_ls { // Thinking-only response — LS needs follow-up API calls. @@ -844,6 +920,19 @@ async fn chat_completions_stream( } yield Ok(Event::default().data("[DONE]")); state.mitm_store.remove_request(&cascade_id).await; + if let Some(ref t) = trace { + let (ipt, opt, crt2, tht) = if let Some(ref u) = mitm { + (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) + } else { (0, 0, 0, 0) }; + t.record_response(0, crate::trace::ResponseSummary { + text_len: 0, thinking_len: acc_thinking.len(), + text_preview: String::new(), + finish_reason: Some("stop".to_string()), + function_calls: Vec::new(), grounding: false, + }).await; + t.set_usage(crate::trace::TrackedUsage { input_tokens: ipt, output_tokens: opt, thinking_tokens: tht, cache_read: crt2 }).await; + t.finish("thinking_timeout").await; + } return; } // Don't break — wait for more channel events @@ -860,6 +949,14 @@ async fn chat_completions_stream( ))); yield Ok(Event::default().data("[DONE]")); state.mitm_store.remove_request(&cascade_id).await; + if let Some(ref t) = trace { + t.record_response(0, crate::trace::ResponseSummary { + text_len: 0, thinking_len: 0, text_preview: String::new(), + finish_reason: Some("stop".to_string()), + function_calls: Vec::new(), grounding: false, + }).await; + t.finish("empty_response").await; + } return; } continue 'channel_loop; @@ -900,6 +997,15 @@ async fn chat_completions_stream( ))); } yield Ok(Event::default().data("[DONE]")); + if let Some(ref t) = trace { + t.record_response(0, crate::trace::ResponseSummary { + text_len: last_text.len(), thinking_len: last_thinking_len, + text_preview: last_text.chars().take(200).collect(), + finish_reason: Some("stop".to_string()), + function_calls: Vec::new(), grounding: false, + }).await; + t.finish("channel_closed").await; + } return; } else { // ── Fallback: LS steps (no MITM capture active) ── @@ -1046,6 +1152,7 @@ async fn chat_completions_sync( model_name: String, cascade_id: String, timeout: u64, + trace: Option, ) -> axum::response::Response { let result = poll_for_response(&state, &cascade_id, timeout).await; if let Some(ref err) = result.upstream_error { @@ -1084,6 +1191,27 @@ async fn chat_completions_sync( message["reasoning_content"] = serde_json::json!(thinking); } + // Record trace data + if let Some(ref t) = trace { + t.record_response(0, crate::trace::ResponseSummary { + text_len: result.text.len(), + thinking_len: result.thinking.as_ref().map_or(0, |s| s.len()), + text_preview: result.text.chars().take(200).collect(), + finish_reason: Some(finish_reason.to_string()), + function_calls: Vec::new(), + grounding: false, + }).await; + if prompt_tokens > 0 || completion_tokens > 0 { + t.set_usage(crate::trace::TrackedUsage { + input_tokens: prompt_tokens, + output_tokens: completion_tokens, + thinking_tokens: thinking_tokens, + cache_read: cached_tokens, + }).await; + } + t.finish("completed").await; + } + Json(serde_json::json!({ "id": completion_id, "object": "chat.completion", diff --git a/src/api/gemini.rs b/src/api/gemini.rs index dd0fe23..1b27884 100644 --- a/src/api/gemini.rs +++ b/src/api/gemini.rs @@ -16,7 +16,7 @@ use axum::{ }; use rand::Rng; use std::sync::Arc; -use tracing::{debug, info}; +use tracing::{debug, info, warn}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::polling::{ @@ -30,8 +30,15 @@ use crate::mitm::store::PendingToolResult; #[derive(serde::Deserialize)] pub(crate) struct GeminiRequest { pub model: Option, - /// User input text. - pub input: serde_json::Value, + /// User input text (our custom format). + #[serde(default)] + pub input: Option, + /// Official Gemini API format: [{"role": "user", "parts": [{"text": "..."}]}] + #[serde(default)] + pub contents: Option>, + /// Shorthand: single text message (alias for simple requests). + #[serde(default)] + pub message: Option, /// Gemini-native tools: [{"functionDeclarations": [...]}] #[serde(default)] pub tools: Option>, @@ -111,6 +118,14 @@ async fn build_usage_metadata( } } +/// POST /v1/gemini — simple custom endpoint +pub(crate) async fn handle_gemini( + State(state): State>, + Json(body): Json, +) -> axum::response::Response { + handle_gemini_inner(state, body).await +} + /// POST /v1beta/*path — handles both :generateContent and :streamGenerateContent /// /// Parses paths like: @@ -185,58 +200,105 @@ async fn handle_gemini_inner( ); } - // Extract user text and optional image + // Extract user text and optional image. + // Priority: contents (official Gemini API) > input (our format) > message (shorthand) let mut image: Option = None; - let user_text = match &body.input { - serde_json::Value::String(s) => s.clone(), - serde_json::Value::Array(arr) => { - // Support array input: strings, {text: "..."}, or {inlineData: {mimeType, data}} - let mut parts: Vec = Vec::new(); - for item in arr { - match item { - serde_json::Value::String(s) => parts.push(s.clone()), - serde_json::Value::Object(obj) => { - if let Some(text) = obj.get("text").and_then(|v| v.as_str()) { - parts.push(text.to_string()); - } - // Gemini-native inlineData format - if image.is_none() { - if let Some(inline) = obj.get("inlineData") { - if let (Some(mime), Some(b64)) = - (inline["mimeType"].as_str(), inline["data"].as_str()) - { - if let Some(img) = super::util::parse_data_uri(&format!( - "data:{mime};base64,{b64}" - )) { - image = Some(img); - } + let user_text = if let Some(ref contents) = body.contents { + // Official Gemini API format: [{"role": "user", "parts": [{"text": "..."}]}] + // Extract text from the last user message. + let mut text_parts: Vec = Vec::new(); + for content in contents.iter().rev() { + let role = content.get("role").and_then(|r| r.as_str()).unwrap_or("user"); + if role != "user" { continue; } + if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) { + for part in parts { + if let Some(text) = part.get("text").and_then(|t| t.as_str()) { + text_parts.push(text.to_string()); + } + // Handle inlineData image + if image.is_none() { + if let Some(inline) = part.get("inlineData") { + if let (Some(mime), Some(b64)) = + (inline["mimeType"].as_str(), inline["data"].as_str()) + { + if let Some(img) = super::util::parse_data_uri(&format!( + "data:{mime};base64,{b64}" + )) { + image = Some(img); } } - // Also support OpenAI-style image_url in Gemini input - if let Some(img) = super::util::extract_image_from_content(item) { - image = Some(img); - } } } - _ => {} } } - if parts.is_empty() { - return err_response( - StatusCode::BAD_REQUEST, - "Gemini input array contains no text parts".to_string(), - "invalid_request_error", - ); - } - parts.join("\n") + if !text_parts.is_empty() { break; } } - _ => { + if text_parts.is_empty() { return err_response( StatusCode::BAD_REQUEST, - "Gemini endpoint requires input as a string or array of text parts".to_string(), + "No text found in contents array".to_string(), "invalid_request_error", ); } + text_parts.join("\n") + } else if let Some(ref input) = body.input { + // Our custom format: input as string or array + match input { + serde_json::Value::String(s) => s.clone(), + serde_json::Value::Array(arr) => { + let mut parts: Vec = Vec::new(); + for item in arr { + match item { + serde_json::Value::String(s) => parts.push(s.clone()), + serde_json::Value::Object(obj) => { + if let Some(text) = obj.get("text").and_then(|v| v.as_str()) { + parts.push(text.to_string()); + } + if image.is_none() { + if let Some(inline) = obj.get("inlineData") { + if let (Some(mime), Some(b64)) = + (inline["mimeType"].as_str(), inline["data"].as_str()) + { + if let Some(img) = super::util::parse_data_uri(&format!( + "data:{mime};base64,{b64}" + )) { + image = Some(img); + } + } + } + if let Some(img) = super::util::extract_image_from_content(item) { + image = Some(img); + } + } + } + _ => {} + } + } + if parts.is_empty() { + return err_response( + StatusCode::BAD_REQUEST, + "Gemini input array contains no text parts".to_string(), + "invalid_request_error", + ); + } + parts.join("\n") + } + _ => { + return err_response( + StatusCode::BAD_REQUEST, + "Gemini input must be a string or array of text parts".to_string(), + "invalid_request_error", + ); + } + } + } else if let Some(ref msg) = body.message { + msg.clone() + } else { + return err_response( + StatusCode::BAD_REQUEST, + "Request must include 'contents' (Gemini API), 'input', or 'message'".to_string(), + "invalid_request_error", + ); }; // ── Build per-request state locally ────────────────────────────────── @@ -320,14 +382,9 @@ async fn handle_gemini_inner( } }); - // Build event channel for streaming - let has_custom_tools = tools.is_some(); - let (mitm_rx, event_tx) = if has_custom_tools { - let (tx, rx) = tokio::sync::mpsc::channel(64); - (Some(rx), Some(tx)) - } else { - (None, None) - }; + // Build event channel — always created for MITM response path + let (tx, rx) = tokio::sync::mpsc::channel(64); + let (mitm_rx, event_tx) = (Some(rx), tx); // Build tool rounds now that cascade_id is known let mut tool_rounds: Vec = Vec::new(); @@ -340,7 +397,23 @@ async fn handle_gemini_inner( }); } - // Register all per-request state atomically + // Start debug trace + let trace = state.trace.start(&cascade_id, "POST gemini", &model_name, body.stream); + if let Some(ref t) = trace { + t.set_client_request(crate::trace::ClientRequestSummary { + message_count: 1, + tool_count: body.tools.as_ref().map_or(0, |t| t.len()), + tool_round_count: tool_rounds.len(), + user_text_len: user_text.len(), + user_text_preview: user_text.chars().take(200).collect(), + system_prompt: false, + has_image: image.is_some(), + }).await; + t.start_turn().await; + } + + let mitm_gate = std::sync::Arc::new(tokio::sync::Notify::new()); + let mitm_gate_clone = mitm_gate.clone(); state.mitm_store.register_request(crate::mitm::store::RequestContext { cascade_id: cascade_id.clone(), pending_user_text: user_text.clone(), @@ -354,6 +427,9 @@ async fn handle_gemini_inner( last_function_calls: Vec::new(), call_id_to_name: std::collections::HashMap::new(), created_at: std::time::Instant::now(), + gate: mitm_gate_clone, + trace_handle: trace.clone(), + trace_turn: 0, }).await; // Send REAL user text to LS (no more dummy ".") @@ -387,13 +463,36 @@ async fn handle_gemini_inner( } } + // Wait for MITM gate: 5s -> 502 if MITM enabled + let gate_start = std::time::Instant::now(); + let gate_matched = tokio::time::timeout( + std::time::Duration::from_secs(5), + mitm_gate.notified(), + ).await; + let gate_wait_ms = gate_start.elapsed().as_millis() as u64; + if gate_matched.is_err() { + if state.mitm_enabled { + state.mitm_store.remove_request(&cascade_id).await; + if let Some(ref t) = trace { t.record_error("MITM gate timeout (5s)".to_string()).await; t.finish("mitm_timeout").await; } + return err_response( + StatusCode::BAD_GATEWAY, + "MITM proxy did not match request within 5s".to_string(), + "mitm_timeout", + ); + } + warn!(cascade = %cascade_id, "MITM gate timeout (--no-mitm mode)"); + } else { + debug!(cascade = %cascade_id, gate_wait_ms, "MITM gate signaled -- request matched"); + if let Some(ref t) = trace { t.record_mitm_match(0, gate_wait_ms).await; } + } + // Dispatch to sync or stream let model_name = model_name.to_string(); let timeout = body.timeout; if body.stream { - gemini_stream(state, model_name, cascade_id, timeout, mitm_rx).await + gemini_stream(state, model_name, cascade_id, timeout, mitm_rx, trace).await } else { - gemini_sync(state, model_name, cascade_id, timeout, mitm_rx).await + gemini_sync(state, model_name, cascade_id, timeout, mitm_rx, trace).await } } @@ -405,6 +504,7 @@ async fn gemini_sync( cascade_id: String, timeout: u64, mitm_rx: Option>, + trace: Option, ) -> axum::response::Response { // Clear stale response and upstream errors (only if no pre-installed channel) if mitm_rx.is_none() { @@ -418,6 +518,7 @@ async fn gemini_sync( let mut acc_text = String::new(); let mut acc_thinking: Option = None; + let mut last_usage: Option = None; while let Some(event) = tokio::time::timeout( std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())), @@ -427,7 +528,8 @@ async fn gemini_sync( match event { MitmEvent::ThinkingDelta(t) => { acc_thinking = Some(t); } MitmEvent::TextDelta(t) => { acc_text = t; } - MitmEvent::Usage(_) | MitmEvent::Grounding(_) => {} + MitmEvent::Usage(u) => { last_usage = Some(u); } + MitmEvent::Grounding(_) => {} MitmEvent::FunctionCall(calls) => { let parts: Vec = calls .iter() @@ -440,6 +542,21 @@ async fn gemini_sync( }) }) .collect(); + if let Some(ref t) = trace { + let fc_summaries: Vec = calls.iter().map(|fc| { + crate::trace::FunctionCallSummary { + name: fc.name.clone(), + args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(), + } + }).collect(); + t.record_response(0, crate::trace::ResponseSummary { + text_len: 0, thinking_len: acc_thinking.as_ref().map_or(0, |s| s.len()), + text_preview: String::new(), + finish_reason: Some("STOP".to_string()), + function_calls: fc_summaries, grounding: false, + }).await; + t.finish("tool_call").await; + } state.mitm_store.remove_request(&cascade_id).await; return Json(serde_json::json!({ "candidates": [{ @@ -477,6 +594,18 @@ async fn gemini_sync( parts.push(serde_json::json!({"text": t, "thought": true})); } parts.push(serde_json::json!({"text": acc_text})); + if let Some(ref t) = trace { + t.record_response(0, crate::trace::ResponseSummary { + text_len: acc_text.len(), thinking_len: acc_thinking.as_ref().map_or(0, |s| s.len()), + text_preview: acc_text.chars().take(200).collect(), + finish_reason: Some("STOP".to_string()), + function_calls: Vec::new(), grounding: false, + }).await; + if let Some(ref u) = last_usage { + t.set_usage(crate::trace::TrackedUsage { input_tokens: u.input_tokens, output_tokens: u.output_tokens, thinking_tokens: u.thinking_output_tokens, cache_read: u.cache_read_input_tokens }).await; + } + t.finish("completed").await; + } state.mitm_store.remove_request(&cascade_id).await; return Json(serde_json::json!({ "candidates": [{ @@ -487,11 +616,33 @@ async fn gemini_sync( "finishReason": "STOP", }], "modelVersion": model_name, - "usageMetadata": build_usage_metadata(&state.mitm_store, &cascade_id).await, + "usageMetadata": if let Some(ref u) = last_usage { + serde_json::json!({ + "promptTokenCount": u.input_tokens, + "candidatesTokenCount": u.output_tokens, + "totalTokenCount": u.input_tokens + u.output_tokens, + "thoughtsTokenCount": u.thinking_output_tokens, + "cachedContentTokenCount": u.cache_read_input_tokens, + }) + } else { + build_usage_metadata(&state.mitm_store, &cascade_id).await + }, })) .into_response(); } MitmEvent::UpstreamError(err) => { + if let Some(ref t) = trace { + t.record_response(0, crate::trace::ResponseSummary { + text_len: acc_text.len(), thinking_len: acc_thinking.as_ref().map_or(0, |s| s.len()), + text_preview: acc_text.chars().take(200).collect(), + finish_reason: Some("STOP".to_string()), + function_calls: Vec::new(), grounding: false, + }).await; + if let Some(ref u) = last_usage { + t.set_usage(crate::trace::TrackedUsage { input_tokens: u.input_tokens, output_tokens: u.output_tokens, thinking_tokens: u.thinking_output_tokens, cache_read: u.cache_read_input_tokens }).await; + } + t.finish("upstream_error").await; + } state.mitm_store.remove_request(&cascade_id).await; return upstream_err_response(&err); } @@ -499,6 +650,10 @@ async fn gemini_sync( } // Timeout + if let Some(ref t) = trace { + t.record_error(format!("Timeout: no response after {timeout}s")).await; + t.finish("timeout").await; + } state.mitm_store.remove_request(&cascade_id).await; return ( axum::http::StatusCode::GATEWAY_TIMEOUT, @@ -541,6 +696,25 @@ async fn gemini_sync( }) .collect(); + // Record trace + if let Some(ref t) = trace { + let fc_summaries: Vec = calls.iter().map(|fc| { + crate::trace::FunctionCallSummary { + name: fc.name.clone(), + args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(), + } + }).collect(); + t.record_response(0, crate::trace::ResponseSummary { + text_len: 0, + thinking_len: 0, + text_preview: String::new(), + finish_reason: Some("STOP".to_string()), + function_calls: fc_summaries, + grounding: false, + }).await; + t.finish("tool_call").await; + } + return Json(serde_json::json!({ "candidates": [{ "content": { @@ -562,6 +736,19 @@ async fn gemini_sync( } parts.push(serde_json::json!({"text": poll_result.text})); + // Record trace + if let Some(ref t) = trace { + t.record_response(0, crate::trace::ResponseSummary { + text_len: poll_result.text.len(), + thinking_len: poll_result.thinking.as_ref().map_or(0, |s| s.len()), + text_preview: poll_result.text.chars().take(200).collect(), + finish_reason: Some("STOP".to_string()), + function_calls: Vec::new(), + grounding: false, + }).await; + t.finish("completed").await; + } + Json(serde_json::json!({ "candidates": [{ "content": { @@ -584,11 +771,13 @@ async fn gemini_stream( cascade_id: String, timeout: u64, mitm_rx: Option>, + trace: Option, ) -> axum::response::Response { let stream = async_stream::stream! { let start = std::time::Instant::now(); let mut last_text = String::new(); let mut last_thinking = String::new(); + let mut last_usage: Option = None; // Clear stale response (only if no pre-installed channel) if mitm_rx.is_none() { @@ -665,13 +854,34 @@ async fn gemini_stream( "modelVersion": model_name, })).unwrap_or_default())); yield Ok(Event::default().data("[DONE]")); + if let Some(ref t) = trace { + let fc_summaries: Vec = calls.iter().map(|fc| crate::trace::FunctionCallSummary { + name: fc.name.clone(), args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(), + }).collect(); + t.record_response(0, crate::trace::ResponseSummary { + text_len: 0, thinking_len: last_thinking.len(), text_preview: String::new(), + finish_reason: Some("STOP".to_string()), + function_calls: fc_summaries, grounding: false, + }).await; + t.finish("tool_call").await; + } state.mitm_store.remove_request(&cascade_id).await; return; } MitmEvent::ResponseComplete => { if !last_text.is_empty() { // Final chunk with finishReason + usageMetadata - let usage_meta = build_usage_metadata(&state.mitm_store, &cascade_id).await; + let usage_meta = if let Some(ref u) = last_usage { + serde_json::json!({ + "promptTokenCount": u.input_tokens, + "candidatesTokenCount": u.output_tokens, + "totalTokenCount": u.input_tokens + u.output_tokens, + "thoughtsTokenCount": u.thinking_output_tokens, + "cachedContentTokenCount": u.cache_read_input_tokens, + }) + } else { + build_usage_metadata(&state.mitm_store, &cascade_id).await + }; yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ "candidates": [{ "content": { @@ -684,6 +894,15 @@ async fn gemini_stream( "modelVersion": model_name, })).unwrap_or_default())); yield Ok(Event::default().data("[DONE]")); + if let Some(ref t) = trace { + t.record_response(0, crate::trace::ResponseSummary { + text_len: last_text.len(), thinking_len: last_thinking.len(), + text_preview: last_text.chars().take(200).collect(), + finish_reason: Some("STOP".to_string()), + function_calls: Vec::new(), grounding: false, + }).await; + t.finish("completed").await; + } state.mitm_store.remove_request(&cascade_id).await; return; } else if !last_thinking.is_empty() && !did_unblock_ls { @@ -714,10 +933,15 @@ async fn gemini_stream( } })).unwrap())); yield Ok(Event::default().data("[DONE]")); + if let Some(ref t) = trace { + t.record_error(format!("Upstream: {}", error_msg)).await; + t.finish("upstream_error").await; + } state.mitm_store.remove_request(&cascade_id).await; return; } - MitmEvent::Usage(_) | MitmEvent::Grounding(_) => {} + MitmEvent::Usage(u) => { last_usage = Some(u); } + MitmEvent::Grounding(_) => {} } } @@ -730,6 +954,10 @@ async fn gemini_stream( "code": 504, } })).unwrap())); + if let Some(ref t) = trace { + t.record_error(format!("Timeout: {timeout}s")).await; + t.finish("timeout").await; + } yield Ok(Event::default().data("[DONE]")); return; } diff --git a/src/api/mod.rs b/src/api/mod.rs index f5a1492..008b5c7 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -33,6 +33,10 @@ pub struct AppState { pub sessions: SessionManager, pub mitm_store: crate::mitm::store::MitmStore, pub quota_store: crate::quota::QuotaStore, + /// Whether the MITM proxy is active (false when --no-mitm). + pub mitm_enabled: bool, + /// Per-call debug trace collector. + pub trace: crate::trace::TraceCollector, } // ─── Router ────────────────────────────────────────────────────────────────── @@ -44,6 +48,7 @@ pub fn router(state: Arc) -> Router { "/v1/chat/completions", post(completions::handle_completions), ) + .route("/v1/gemini", post(gemini::handle_gemini)) .route( "/v1beta/{*path}", post(gemini::handle_gemini_v1beta), diff --git a/src/api/responses.rs b/src/api/responses.rs index 0359649..91007d6 100644 --- a/src/api/responses.rs +++ b/src/api/responses.rs @@ -11,7 +11,7 @@ use axum::{ use rand::Rng; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; -use tracing::{debug, info}; +use tracing::{debug, info, warn}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::polling::{ @@ -364,14 +364,9 @@ pub(crate) async fn handle_responses( } }); - // Build event channel - let has_custom_tools = tools.is_some(); - let (mitm_rx, event_tx) = if has_custom_tools { - let (tx, rx) = tokio::sync::mpsc::channel(64); - (Some(rx), Some(tx)) - } else { - (None, None) - }; + // Build event channel — always created for MITM response path + let (tx, rx) = tokio::sync::mpsc::channel(64); + let (mitm_rx, event_tx) = (Some(rx), tx); // Build tool rounds now that cascade_id is known let mut tool_rounds: Vec = Vec::new(); @@ -385,7 +380,23 @@ pub(crate) async fn handle_responses( }); } - // Register all per-request state atomically + // Start debug trace + let trace = state.trace.start(&cascade_id, "POST /v1/responses", &model.name, body.stream); + if let Some(ref t) = trace { + t.set_client_request(crate::trace::ClientRequestSummary { + message_count: if is_tool_result_turn { 0 } else { 1 }, + tool_count: body.tools.as_ref().map_or(0, |t| t.len()), + tool_round_count: tool_rounds.len(), + user_text_len: user_text.len(), + user_text_preview: user_text.chars().take(200).collect(), + system_prompt: body.instructions.is_some(), + has_image: image.is_some(), + }).await; + t.start_turn().await; + } + + let mitm_gate = std::sync::Arc::new(tokio::sync::Notify::new()); + let mitm_gate_clone = mitm_gate.clone(); state.mitm_store.register_request(crate::mitm::store::RequestContext { cascade_id: cascade_id.clone(), pending_user_text: user_text.clone(), @@ -399,6 +410,9 @@ pub(crate) async fn handle_responses( last_function_calls: Vec::new(), call_id_to_name: std::collections::HashMap::new(), created_at: std::time::Instant::now(), + gate: mitm_gate_clone, + trace_handle: trace.clone(), + trace_turn: 0, }).await; // Send REAL user text to LS @@ -432,6 +446,29 @@ pub(crate) async fn handle_responses( } } + // Wait for MITM gate: 5s → 502 if MITM enabled + let gate_start = std::time::Instant::now(); + let gate_matched = tokio::time::timeout( + std::time::Duration::from_secs(5), + mitm_gate.notified(), + ).await; + let gate_wait_ms = gate_start.elapsed().as_millis() as u64; + if gate_matched.is_err() { + if state.mitm_enabled { + state.mitm_store.remove_request(&cascade_id).await; + if let Some(ref t) = trace { t.record_error("MITM gate timeout (5s)".to_string()).await; t.finish("mitm_timeout").await; } + return err_response( + StatusCode::BAD_GATEWAY, + "MITM proxy did not match request within 5s".to_string(), + "mitm_timeout", + ); + } + warn!(cascade = %cascade_id, "MITM gate timeout (--no-mitm mode)"); + } else { + debug!(cascade = %cascade_id, gate_wait_ms, "MITM gate signaled — request matched"); + if let Some(ref t) = trace { t.record_mitm_match(0, gate_wait_ms).await; } + } + // Capture request params for response building let req_params = RequestParams { user_text: user_text.clone(), @@ -462,6 +499,7 @@ pub(crate) async fn handle_responses( body.timeout, req_params, mitm_rx, + trace, ) .await } else { @@ -473,6 +511,7 @@ pub(crate) async fn handle_responses( body.timeout, req_params, mitm_rx, + trace, ) .await } @@ -595,6 +634,7 @@ async fn handle_responses_sync( timeout: u64, params: RequestParams, mitm_rx: Option>, + trace: Option, ) -> axum::response::Response { let created_at = now_unix(); @@ -642,6 +682,30 @@ async fn handle_responses_sync( &state.mitm_store, &cascade_id, &None, ¶ms.user_text, "", ).await; state.mitm_store.remove_request(&cascade_id).await; + // Record trace before usage is moved + if let Some(ref t) = trace { + let fc_summaries: Vec = calls.iter().map(|fc| { + crate::trace::FunctionCallSummary { + name: fc.name.clone(), + args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(), + } + }).collect(); + t.record_response(0, crate::trace::ResponseSummary { + text_len: 0, + thinking_len: 0, + text_preview: String::new(), + finish_reason: Some("tool_calls".to_string()), + function_calls: fc_summaries, + grounding: false, + }).await; + t.set_usage(crate::trace::TrackedUsage { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + thinking_tokens: usage.output_tokens_details.reasoning_tokens, + cache_read: usage.input_tokens_details.cached_tokens, + }).await; + t.finish("tool_call").await; + } let resp = build_response_object( ResponseData { id: response_id, @@ -688,6 +752,24 @@ async fn handle_responses_sync( let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); output_items.push(build_message_output(&msg_id, &acc_text)); + // Record trace before usage is moved + if let Some(ref t) = trace { + t.record_response(0, crate::trace::ResponseSummary { + text_len: acc_text.len(), + thinking_len: acc_thinking.as_ref().map_or(0, |s| s.len()), + text_preview: acc_text.chars().take(200).collect(), + finish_reason: Some("stop".to_string()), + function_calls: Vec::new(), + grounding: false, + }).await; + t.set_usage(crate::trace::TrackedUsage { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + thinking_tokens: usage.output_tokens_details.reasoning_tokens, + cache_read: usage.input_tokens_details.cached_tokens, + }).await; + t.finish("completed").await; + } let resp = build_response_object( ResponseData { id: response_id, @@ -705,6 +787,7 @@ async fn handle_responses_sync( } MitmEvent::UpstreamError(err) => { state.mitm_store.remove_request(&cascade_id).await; + if let Some(ref t) = trace { t.record_error(format!("Upstream: {}", err.message.as_deref().unwrap_or("unknown"))).await; t.finish("upstream_error").await; } return upstream_err_response(&err); } } @@ -712,6 +795,7 @@ async fn handle_responses_sync( // Timeout state.mitm_store.remove_request(&cascade_id).await; + if let Some(ref t) = trace { t.record_error(format!("Timeout: {}s", timeout)).await; t.finish("timeout").await; } return err_response( StatusCode::GATEWAY_TIMEOUT, format!("Timeout: no response from Google API after {timeout}s"), @@ -772,6 +856,31 @@ async fn handle_responses_sync( ) .await; + // Record trace before usage is moved + if let Some(ref t) = trace { + let fc_summaries: Vec = calls.iter().map(|fc| { + crate::trace::FunctionCallSummary { + name: fc.name.clone(), + args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(), + } + }).collect(); + t.record_response(0, crate::trace::ResponseSummary { + text_len: poll_result.text.len(), + thinking_len: poll_result.thinking.as_ref().map_or(0, |s| s.len()), + text_preview: String::new(), + finish_reason: Some("tool_calls".to_string()), + function_calls: fc_summaries, + grounding: false, + }).await; + t.set_usage(crate::trace::TrackedUsage { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + thinking_tokens: usage.output_tokens_details.reasoning_tokens, + cache_read: usage.input_tokens_details.cached_tokens, + }).await; + t.finish("tool_call").await; + } + let resp = build_response_object( ResponseData { id: response_id, @@ -809,6 +918,25 @@ async fn handle_responses_sync( } output_items.push(build_message_output(&msg_id, &poll_result.text)); + // Record trace before usage is moved + if let Some(ref t) = trace { + t.record_response(0, crate::trace::ResponseSummary { + text_len: poll_result.text.len(), + thinking_len: thinking_text.as_ref().map_or(0, |s| s.len()), + text_preview: poll_result.text.chars().take(200).collect(), + finish_reason: Some("stop".to_string()), + function_calls: Vec::new(), + grounding: false, + }).await; + t.set_usage(crate::trace::TrackedUsage { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + thinking_tokens: usage.output_tokens_details.reasoning_tokens, + cache_read: usage.input_tokens_details.cached_tokens, + }).await; + t.finish("completed").await; + } + let resp = build_response_object( ResponseData { id: response_id, @@ -836,6 +964,7 @@ async fn handle_responses_stream( timeout: u64, params: RequestParams, mitm_rx: Option>, + trace: Option, ) -> axum::response::Response { let stream = async_stream::stream! { let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); @@ -1111,6 +1240,14 @@ async fn handle_responses_stream( ¶ms.user_text, "", ).await; + // Save trace usage before move + let trace_usage = crate::trace::TrackedUsage { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + thinking_tokens: usage.output_tokens_details.reasoning_tokens, + cache_read: usage.input_tokens_details.cached_tokens, + }; + let final_resp = build_response_object( ResponseData { id: response_id.clone(), @@ -1132,6 +1269,19 @@ async fn handle_responses_stream( "response": response_to_json(&final_resp), }), )); + if let Some(ref t) = trace { + let fc_summaries: Vec = calls.iter().map(|fc| crate::trace::FunctionCallSummary { + name: fc.name.clone(), args_preview: serde_json::to_string(&fc.args).unwrap_or_default().chars().take(200).collect(), + }).collect(); + t.record_response(0, crate::trace::ResponseSummary { + text_len: 0, thinking_len: last_thinking.len(), + text_preview: String::new(), + finish_reason: Some("tool_calls".to_string()), + function_calls: fc_summaries, grounding: false, + }).await; + t.set_usage(trace_usage).await; + t.finish("tool_call").await; + } state.mitm_store.remove_request(&cascade_id).await; return; } @@ -1150,6 +1300,16 @@ async fn handle_responses_stream( ) { yield Ok(evt); } + if let Some(ref t) = trace { + t.record_response(0, crate::trace::ResponseSummary { + text_len: last_text.len(), + thinking_len: thinking_text.as_ref().map_or(0, |s| s.len()), + text_preview: last_text.chars().take(200).collect(), + finish_reason: Some("stop".to_string()), + function_calls: Vec::new(), grounding: false, + }).await; + t.finish("completed").await; + } state.mitm_store.remove_request(&cascade_id).await; return; } else if !last_thinking.is_empty() { @@ -1186,6 +1346,10 @@ async fn handle_responses_stream( }, }), )); + if let Some(ref t) = trace { + t.record_error(format!("Upstream: {}", error_msg)).await; + t.finish("upstream_error").await; + } state.mitm_store.remove_request(&cascade_id).await; return; } @@ -1213,6 +1377,10 @@ async fn handle_responses_stream( }, }), )); + if let Some(ref t) = trace { + t.record_error(format!("Timeout: {timeout}s")).await; + t.finish("timeout").await; + } return; } diff --git a/src/api/search.rs b/src/api/search.rs index de3f2c7..579f73f 100644 --- a/src/api/search.rs +++ b/src/api/search.rs @@ -138,12 +138,29 @@ async fn do_search(state: Arc, body: SearchRequest) -> axum::response: } }; - // Register per-request state — no tools, just generation params for search grounding + // Start debug trace + let trace = state.trace.start(&cascade_id, "POST /v1/search", model.name, false); + if let Some(ref t) = trace { + t.set_client_request(crate::trace::ClientRequestSummary { + message_count: 1, + tool_count: 0, + tool_round_count: 0, + user_text_len: body.query.len(), + user_text_preview: body.query.chars().take(200).collect(), + system_prompt: false, + has_image: false, + }).await; + t.start_turn().await; + } + + let mitm_gate = std::sync::Arc::new(tokio::sync::Notify::new()); + let mitm_gate_clone = mitm_gate.clone(); + let (mitm_tx, mut mitm_rx) = tokio::sync::mpsc::channel(64); state.mitm_store.register_request(crate::mitm::store::RequestContext { cascade_id: cascade_id.clone(), pending_user_text: search_prompt.clone(), - event_channel: None, - generation_params: Some(gp), + event_channel: mitm_tx, + generation_params: Some(gp.clone()), pending_image: None, tools: None, tool_config: None, @@ -152,6 +169,9 @@ async fn do_search(state: Arc, body: SearchRequest) -> axum::response: last_function_calls: Vec::new(), call_id_to_name: std::collections::HashMap::new(), created_at: std::time::Instant::now(), + gate: mitm_gate_clone, + trace_handle: trace.clone(), + trace_turn: 0, }).await; // Send dot to LS — real search prompt injected by MITM proxy @@ -168,32 +188,176 @@ async fn do_search(state: Arc, body: SearchRequest) -> axum::response: ); } - // Poll for response + // ── Strict timeout cascade ─────────────────────────────────────────────── + // 5s gate → MITM didn't match → 502 + let gate_matched = tokio::time::timeout( + std::time::Duration::from_secs(5), + mitm_gate.notified(), + ).await; + + if gate_matched.is_err() { + if state.mitm_enabled { + state.mitm_store.remove_request(&cascade_id).await; + return err_response( + StatusCode::BAD_GATEWAY, + "MITM proxy did not match request within 5s".to_string(), + "mitm_timeout", + ); + } + // --no-mitm fallback: use polling + tracing::warn!(cascade = %cascade_id, "MITM gate timeout (--no-mitm mode, falling back to polling)"); + } + + // ── Channel-based response path (primary) ──────────────────────────────── + if state.mitm_enabled { + let timeout = body.timeout; + let mut response_text = String::new(); + let mut last_usage: Option = None; + let mut retries = 0u32; + const MAX_RETRIES: u32 = 3; + + while let Some(event) = tokio::time::timeout( + std::time::Duration::from_secs(timeout), + mitm_rx.recv(), + ).await.ok().flatten() { + use crate::mitm::store::MitmEvent; + match event { + MitmEvent::TextDelta(t) => { response_text.push_str(&t); } + MitmEvent::ThinkingDelta(_) => {} // search doesn't use thinking + MitmEvent::Usage(u) => { last_usage = Some(u); } + MitmEvent::Grounding(_) => {} // stored by proxy directly + MitmEvent::FunctionCall(_) => {} // not expected for search + MitmEvent::ResponseComplete => { + // Check if we got actual content — if not, this was a + // thinking-only intermediate response. The LS will make + // a follow-up request; re-register context and keep waiting. + let grounding_peek = state.mitm_store.peek_grounding().await; + if response_text.is_empty() && grounding_peek.is_none() { + retries += 1; + if retries >= MAX_RETRIES { + tracing::warn!(cascade = %cascade_id, retries, "Search: max retries reached with no content — giving up"); + break; + } + let (new_tx, new_rx) = tokio::sync::mpsc::channel(64); + let new_gate = std::sync::Arc::new(tokio::sync::Notify::new()); + state.mitm_store.register_request(crate::mitm::store::RequestContext { + cascade_id: cascade_id.clone(), + pending_user_text: search_prompt.clone(), + event_channel: new_tx, + generation_params: Some(gp.clone()), + pending_image: None, + tools: None, + tool_config: None, + pending_tool_results: Vec::new(), + tool_rounds: Vec::new(), + last_function_calls: Vec::new(), + call_id_to_name: std::collections::HashMap::new(), + created_at: std::time::Instant::now(), + gate: new_gate, + trace_handle: trace.clone(), + trace_turn: 0, + }).await; + mitm_rx = new_rx; + tracing::debug!( + cascade = %cascade_id, retries, + "Search: empty response — re-registered context for follow-up" + ); + continue; + } + break; + } + MitmEvent::UpstreamError(err) => { + if let Some(ref t) = trace { + t.record_error(format!("Upstream: {}", super::util::upstream_error_message(&err))).await; + t.finish("upstream_error").await; + } + state.mitm_store.remove_request(&cascade_id).await; + return upstream_err_response(&err); + } + } + } + + // Extract grounding metadata (stored by dispatch_stream_events) + let grounding = state.mitm_store.take_grounding().await; + state.mitm_store.remove_request(&cascade_id).await; + + if response_text.is_empty() && grounding.is_none() { + if let Some(ref t) = trace { + t.record_error(format!("Timeout: no search response after {timeout}s (retries: {retries})")).await; + t.finish("timeout").await; + } + return err_response( + StatusCode::GATEWAY_TIMEOUT, + format!("Timeout: no search response after {timeout}s"), + "upstream_error", + ); + } + + return { + // Finalize trace for channel-based path + if let Some(ref t) = trace { + t.record_response(0, crate::trace::ResponseSummary { + text_len: response_text.len(), thinking_len: 0, + text_preview: response_text.chars().take(200).collect(), + finish_reason: Some("stop".to_string()), + function_calls: Vec::new(), grounding: grounding.is_some(), + }).await; + if let Some((it, ot)) = last_usage.as_ref().map(|u| (u.input_tokens, u.output_tokens)) { + t.set_usage(crate::trace::TrackedUsage { + input_tokens: it, output_tokens: ot, + thinking_tokens: 0, cache_read: 0, + }).await; + } + t.finish("completed").await; + } + build_search_response(&body.query, model.name, response_text, grounding, last_usage.map(|u| (u.input_tokens, u.output_tokens))) + }; + } + + // ── Fallback: polling path (--no-mitm only) ────────────────────────────── let poll_result = poll_for_response(&state, &cascade_id, body.timeout).await; if let Some(ref err) = poll_result.upstream_error { return upstream_err_response(err); } - // Extract grounding metadata let grounding = state.mitm_store.take_grounding().await; - // The poll result text contains the model's summary (grounded response) let response_text = if !poll_result.text.is_empty() { poll_result.text.clone() } else { - // Fall back to MITM captured text state.mitm_store.take_response_text().await.unwrap_or_default() }; - // Clean up state.mitm_store.remove_request(&cascade_id).await; state.mitm_store.clear_response_async().await; - // Build the search response + // Finalize trace for polling path + if let Some(ref t) = trace { + t.record_response(0, crate::trace::ResponseSummary { + text_len: response_text.len(), thinking_len: 0, + text_preview: response_text.chars().take(200).collect(), + finish_reason: Some("stop".to_string()), + function_calls: Vec::new(), grounding: grounding.is_some(), + }).await; + t.finish("completed").await; + } + + build_search_response(&body.query, model.name, response_text, grounding, poll_result.usage.map(|u| (u.input_tokens, u.output_tokens))) +} + +fn build_search_response( + query: &str, + model_name: &str, + response_text: String, + grounding: Option, + usage: Option<(u64, u64)>, +) -> axum::response::Response { + use axum::Json; + let mut response = serde_json::json!({ "object": "search_result", - "query": body.query, - "model": model.name, + "query": query, + "model": model_name, "summary": response_text, }); @@ -267,11 +431,11 @@ async fn do_search(state: Arc, body: SearchRequest) -> axum::response: } // Include usage if available - if let Some(ref u) = poll_result.usage { + if let Some((input, output)) = usage { response["usage"] = serde_json::json!({ - "input_tokens": u.input_tokens, - "output_tokens": u.output_tokens, - "total_tokens": u.input_tokens + u.output_tokens, + "input_tokens": input, + "output_tokens": output, + "total_tokens": input + output, }); } diff --git a/src/main.rs b/src/main.rs index 666c9ab..b6371e1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,6 +12,7 @@ mod proto; mod quota; mod session; mod standalone; +mod trace; mod warmup; use api::AppState; @@ -62,6 +63,10 @@ struct Cli { /// Classic mode — requires a running Antigravity app. Alias for --no-headless. #[arg(long, conflicts_with = "headless")] classic: bool, + + /// Disable per-call debug traces (on by default, writes JSON to ~/.config/antigravity-proxy/traces/) + #[arg(long)] + no_trace: bool, } #[tokio::main] @@ -272,11 +277,21 @@ async fn main() { quota_store.clone().start_polling(Arc::clone(&backend)); info!("Quota monitor started (polling every 60s)"); + // ── Step 4c: Debug trace collector ──────────────────────────────────────── + let trace_enabled = !cli.no_trace; + let trace_collector = trace::TraceCollector::new(trace_enabled); + if trace_enabled { + trace_collector.cleanup_old_traces(7); + info!("Debug tracing enabled → ~/.config/antigravity-proxy/traces/"); + } + let state = Arc::new(AppState { backend, sessions: SessionManager::new(), mitm_store, quota_store, + mitm_enabled: mitm_handle.is_some(), + trace: trace_collector, }); // Periodic backend refresh — keeps LS connection details fresh diff --git a/src/mitm/intercept.rs b/src/mitm/intercept.rs index 87462e7..2fe0f22 100644 --- a/src/mitm/intercept.rs +++ b/src/mitm/intercept.rs @@ -89,6 +89,8 @@ pub struct StreamingAccumulator { pub grounding_metadata: Option, /// Buffer for reassembling lines split across TCP reads. pub pending_data: String, + /// Thinking signature (base64 opaque blob) from non-function-call response parts. + pub thinking_signature: Option, } impl StreamingAccumulator { @@ -150,8 +152,12 @@ impl StreamingAccumulator { .as_secs(), }); } - // Capture non-thinking response text (skip thoughtSignature parts) - else if part.get("thoughtSignature").is_none() { + // Capture non-thinking response text + else { + // Capture thoughtSignature from response parts (not function call parts) + if let Some(sig) = part.get("thoughtSignature").and_then(|v| v.as_str()) { + self.thinking_signature = Some(sig.to_string()); + } if let Some(text) = part["text"].as_str() { if !text.is_empty() { self.response_text.push_str(text); @@ -277,6 +283,7 @@ impl StreamingAccumulator { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thinking_signature: self.thinking_signature, } } } @@ -302,6 +309,7 @@ fn extract_usage_from_message(msg: &Value) -> Option { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thinking_signature: None, }) } diff --git a/src/mitm/proto.rs b/src/mitm/proto.rs index fe7aa3b..0dbdb95 100644 --- a/src/mitm/proto.rs +++ b/src/mitm/proto.rs @@ -95,6 +95,7 @@ impl GrpcUsage { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thinking_signature: None, } } } diff --git a/src/mitm/proxy.rs b/src/mitm/proxy.rs index b943065..44532d7 100644 --- a/src/mitm/proxy.rs +++ b/src/mitm/proxy.rs @@ -436,143 +436,111 @@ async fn handle_http_over_tls( // checkpoints) from stealing the RequestContext. // ── Request modification ───────────────────────────────────── - // Dechunk body → check if agent request → modify → rechunk + // Dechunk body → check for our nonce → modify → rechunk + // + // Detection is deterministic: if the raw body bytes contain our + // nonce tag AND we have a pending RequestContext + // for that cascade, it's our request. No JSON parsing needed. if modify_requests && body_len > 0 { let body_slice = &request_buf[headers_end..]; let raw_body = super::modify::dechunk(body_slice); - // Only modify "agent" requests, not "checkpoint" (LS internal) - let body_str = String::from_utf8_lossy(&raw_body); - let is_agent = body_str.contains("\"requestType\":\"agent\"") - || body_str.contains("\"requestType\": \"agent\""); + // Fast nonce detection: search raw bytes for tag. + // This is the sole signal — no requestType check, no + // USER_REQUEST wrapper scanning, no JSON parsing for detection. + let nonce_cascade = extract_cascade_hint(&raw_body); + let effective_cascade = nonce_cascade.or(cascade_hint.clone()); - if is_agent { - // Re-extract cascade_hint from the dechunked (JSON-parseable) body. - // The chunked transfer encoding body at `request_buf[headers_end..]` - // can't be JSON-parsed, but `raw_body` (dechunked) can. - let precise_cascade = extract_cascade_hint(&raw_body); + // Only take RequestContext if we found our nonce tag + let has_nonce = effective_cascade.is_some() && { + let body_str = String::from_utf8_lossy(&raw_body); + // The nonce is `` — check raw bytes + if let Some(ref cid) = effective_cascade { + body_str.contains(&format!("", cid)) + } else { + false + } + }; + + let mut request_ctx: Option = if has_nonce { debug!( - cascade = ?precise_cascade, - "MITM: cascade from dechunked requestId" + cascade = ?effective_cascade, + "MITM: nonce matched — taking RequestContext" ); - - // Check if ANY user message contains our dummy dot prompt - // within a wrapper. - // Only then should we consume the pending RequestContext. - // This prevents LS internal requests (title gen, etc.) from - // consuming the context meant for the user's actual request. - // NOTE: We check ALL user messages because the LS appends context - // messages AFTER the dot prompt (conversation summaries, etc.). - // We look for + dot specifically to avoid matching - // old markers in history (which are in model messages). - let contains_our_dot = serde_json::from_slice::(&raw_body) - .ok() - .and_then(|json| { - let contents = json.pointer("/request/contents")?.as_array()?; - for msg in contents.iter() { - let is_user = msg.get("role") - .and_then(|r| r.as_str()) - .map_or(true, |r| r == "user"); - if !is_user { continue; } - if let Some(text) = msg.pointer("/parts/0/text").and_then(|v| v.as_str()) { - // Check for dot in wrapper - if text.contains("") { - if let (Some(s), Some(e)) = (text.find(""), text.find("")) { - let inner = &text[s + 14..e]; // 14 = len("") - let it = inner.trim(); - if it == "." || it.starts_with(".")) { - return Some(true); - } - } - } - Some(false) - }) - .unwrap_or(false); - - // Only take the RequestContext if this request has our dot - let effective_cascade = precise_cascade.or(cascade_hint.clone()); - let mut request_ctx: Option = if contains_our_dot { - let ctx = if let Some(ref cid) = effective_cascade { - store.take_request(cid).await - } else { + let ctx = if let Some(ref cid) = effective_cascade { + store.take_request(cid).await + } else { + None + }; + if ctx.is_some() { + ctx + } else if let Some(ref cid) = effective_cascade { + // Subsequent turn of an already-processed cascade + if store.has_cascade_cache(cid).await { + debug!(cascade = %cid, "MITM: subsequent turn — using cached context"); None - }; - if ctx.is_some() { - ctx - } else if let Some(ref cid) = effective_cascade { - // Check if this is a subsequent turn (turn 1+) of an - // already-processed cascade. If so, DON'T fall through - // to take_latest_request — that would steal an unrelated - // cascade's context. - if store.has_cascade_cache(cid).await { - debug!(cascade = %cid, "MITM: subsequent turn — using cached context"); - None - } else { - // Unknown cascade, try latest fallback - store.take_latest_request().await - } } else { + // Unknown cascade with our nonce, try latest fallback store.take_latest_request().await } } else { - None - }; - - // Extract event channel from matched context - if let Some(ref mut ctx) = request_ctx { - event_tx = ctx.event_channel.take(); + store.take_latest_request().await } + } else { + // No nonce → LS internal request (title gen, checkpoint, etc.) + // Don't touch it. + None + }; - // Build ToolContext from RequestContext (turn 0) or cached - // context (turn 1+). On turn 0, we also cache the context - // for subsequent turns. - let tool_ctx = if let Some(ctx) = request_ctx.take() { - // Turn 0: cache context for subsequent turns - if let Some(ref cid) = effective_cascade { - store.cache_cascade(cid, super::store::CascadeCache { - user_text: ctx.pending_user_text.clone(), - tools: ctx.tools.clone(), - tool_config: ctx.tool_config.clone(), - generation_params: ctx.generation_params.clone(), - }).await; - } + // Extract event channel from matched context + if let Some(ref ctx) = request_ctx { + event_tx = Some(ctx.event_channel.clone()); + } + + // Build ToolContext from RequestContext (turn 0) or cached + // context (turn 1+). On turn 0, we also cache the context + // for subsequent turns. + let tool_ctx = if let Some(ctx) = request_ctx.take() { + // Turn 0: cache context for subsequent turns + if let Some(ref cid) = effective_cascade { + store.cache_cascade(cid, super::store::CascadeCache { + user_text: ctx.pending_user_text.clone(), + tools: ctx.tools.clone(), + tool_config: ctx.tool_config.clone(), + generation_params: ctx.generation_params.clone(), + }).await; + } + Some(super::modify::ToolContext { + pending_user_text: ctx.pending_user_text, + tools: ctx.tools, + tool_config: ctx.tool_config, + pending_results: ctx.pending_tool_results, + last_calls: ctx.last_function_calls, + generation_params: ctx.generation_params, + pending_image: ctx.pending_image, + tool_rounds: ctx.tool_rounds, + }) + } else if let Some(ref cid) = effective_cascade { + // Turn 1+: rebuild lite ToolContext from cache + if let Some(cached) = store.get_cascade_cache(cid).await { Some(super::modify::ToolContext { - pending_user_text: ctx.pending_user_text, - tools: ctx.tools, - tool_config: ctx.tool_config, - pending_results: ctx.pending_tool_results, - last_calls: ctx.last_function_calls, - generation_params: ctx.generation_params, - pending_image: ctx.pending_image, - tool_rounds: ctx.tool_rounds, + pending_user_text: cached.user_text, + tools: cached.tools, + tool_config: cached.tool_config, + pending_results: vec![], + last_calls: vec![], + generation_params: cached.generation_params, + pending_image: None, + tool_rounds: vec![], }) - } else if let Some(ref cid) = effective_cascade { - // Turn 1+: rebuild lite ToolContext from cache - if let Some(cached) = store.get_cascade_cache(cid).await { - Some(super::modify::ToolContext { - pending_user_text: cached.user_text, - tools: cached.tools, - tool_config: cached.tool_config, - pending_results: vec![], - last_calls: vec![], - generation_params: cached.generation_params, - pending_image: None, - tool_rounds: vec![], - }) - } else { - None - } } else { None - }; + } + } else { + None + }; + if tool_ctx.is_some() || has_nonce { if let Some(modified_body) = super::modify::modify_request(&raw_body, tool_ctx.as_ref()) { @@ -1014,6 +982,29 @@ async fn dispatch_stream_events( let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await; } if acc.is_complete { + // Send usage BEFORE ResponseComplete so handlers have it when processing completion + if acc.output_tokens > 0 || acc.input_tokens > 0 { + let usage_snapshot = super::store::ApiUsage { + input_tokens: acc.input_tokens, + output_tokens: acc.output_tokens, + cache_creation_input_tokens: acc.cache_creation_input_tokens, + cache_read_input_tokens: acc.cache_read_input_tokens, + thinking_output_tokens: acc.thinking_tokens, + thinking_text: None, + response_text: None, + response_output_tokens: 0, + model: acc.model.clone(), + stop_reason: acc.stop_reason.clone(), + api_provider: acc.api_provider.clone().unwrap_or_else(|| "unknown".to_string()).into(), + grpc_method: None, + captured_at: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + thinking_signature: acc.thinking_signature.clone(), + }; + let _ = tx.send(super::store::MitmEvent::Usage(usage_snapshot)).await; + } info!( response_text_len = acc.response_text.len(), thinking_text_len = acc.thinking_text.len(), diff --git a/src/mitm/store.rs b/src/mitm/store.rs index 18cb35a..e8d288a 100644 --- a/src/mitm/store.rs +++ b/src/mitm/store.rs @@ -45,6 +45,10 @@ pub struct ApiUsage { pub grpc_method: Option, /// Timestamp when this usage was captured. pub captured_at: u64, + /// Thinking signature from Google's response (base64 opaque blob). + /// Required for multi-turn with thinking models. + #[serde(skip_serializing_if = "Option::is_none")] + pub thinking_signature: Option, } /// A captured function call from Google's API response. @@ -188,8 +192,7 @@ pub struct RequestContext { /// Real user text for MITM injection (LS receives "." instead). pub pending_user_text: String, /// Event channel for real-time streaming from MITM → API handler. - /// Only present when custom tools are active. - pub event_channel: Option>, + pub event_channel: mpsc::Sender, /// Client-specified generation parameters (temperature, top_p, etc.). pub generation_params: Option, /// Image to inject into the Google API request. @@ -208,6 +211,13 @@ pub struct RequestContext { pub call_id_to_name: HashMap, /// When this context was created (for TTL cleanup). pub created_at: Instant, + /// Gate: signaled when MITM takes this context. + /// API handlers wait on this with a timeout to detect match failures. + pub gate: Arc, + /// Debug trace handle (if tracing is enabled). + pub trace_handle: Option, + /// Current turn index in the trace (for multi-turn tracking). + pub trace_turn: usize, } // ─── MitmStore ─────────────────────────────────────────────────────────────── @@ -295,8 +305,9 @@ impl MitmStore { /// Called by the MITM proxy when intercepting the LS's outbound request. pub async fn take_request(&self, cascade_id: &str) -> Option { let ctx = self.pending_requests.write().await.remove(cascade_id); - if ctx.is_some() { - debug!(cascade = %cascade_id, "Took request context"); + if let Some(ref c) = ctx { + c.gate.notify_one(); + debug!(cascade = %cascade_id, "Took request context (gate signaled)"); } ctx } @@ -315,8 +326,9 @@ impl MitmStore { .map(|(k, _)| k.clone()); if let Some(key) = latest_key { let ctx = pending.remove(&key); - if ctx.is_some() { - debug!(cascade = %key, "Took latest request context (fallback)"); + if let Some(ref c) = ctx { + c.gate.notify_one(); + debug!(cascade = %key, "Took latest request context (fallback, gate signaled)"); } ctx } else { @@ -577,12 +589,42 @@ impl MitmStore { // ── Compat shims for streaming tool-call loops ────────────────────── - /// Update the event channel on an existing request context. - /// Used by streaming loop handlers when re-registering for a new tool round. + /// Update the event channel on an existing request context, + /// or re-register a minimal context if it was already consumed by `take_request`. + /// + /// This is critical for thinking-only intermediate responses: the MITM proxy + /// consumes the context via `take_request`, but the handler needs to re-install + /// a channel for the LS's follow-up request. pub async fn set_channel(&self, cascade_id: &str, tx: mpsc::Sender) { - self.update_request(cascade_id, |ctx| { - ctx.event_channel = Some(tx); + let updated = self.update_request(cascade_id, |ctx| { + ctx.event_channel = tx.clone(); }).await; + if !updated { + // Context was already consumed — re-register a minimal one + // so the MITM proxy can match the follow-up request. + let gate = std::sync::Arc::new(tokio::sync::Notify::new()); + self.register_request(RequestContext { + cascade_id: cascade_id.to_string(), + pending_user_text: String::new(), + event_channel: tx, + generation_params: None, + pending_image: None, + tools: None, + tool_config: None, + pending_tool_results: Vec::new(), + tool_rounds: Vec::new(), + last_function_calls: Vec::new(), + call_id_to_name: std::collections::HashMap::new(), + created_at: std::time::Instant::now(), + gate, + trace_handle: None, + trace_turn: 0, + }).await; + tracing::debug!( + cascade = cascade_id, + "set_channel: re-registered minimal context (original was consumed)" + ); + } } /// No-op. Upstream errors are now delivered through the event channel. diff --git a/src/trace.rs b/src/trace.rs new file mode 100644 index 0000000..6ca6625 --- /dev/null +++ b/src/trace.rs @@ -0,0 +1,509 @@ +//! Per-call debug trace system. +//! +//! Every API call gets a structured JSON trace file written to +//! `~/.config/antigravity-proxy/traces/{YYYY-MM-DD}/{HH-MM-SS}_{cascade_short}.json`. +//! +//! Designed for LLM consumption: compact, structured, no raw bodies. + +use serde::Serialize; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Instant; +use tokio::sync::Mutex; + +/// Shared trace state for `AppState`. +#[derive(Clone)] +pub struct TraceCollector { + enabled: bool, + traces_dir: PathBuf, +} + +impl TraceCollector { + pub fn new(enabled: bool) -> Self { + let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string()); + let traces_dir = PathBuf::from(home) + .join(".config") + .join("antigravity-proxy") + .join("traces"); + Self { + enabled, + traces_dir, + } + } + + /// Whether tracing is enabled. + pub fn enabled(&self) -> bool { + self.enabled + } + + /// Start a new trace for an API call. Returns `None` if tracing is disabled. + pub fn start(&self, cascade_id: &str, endpoint: &str, model: &str, stream: bool) -> Option { + if !self.enabled { + return None; + } + let now = chrono::Utc::now(); + Some(TraceHandle { + inner: Arc::new(Mutex::new(TraceData { + cascade_id: cascade_id.to_string(), + endpoint: endpoint.to_string(), + model: model.to_string(), + stream, + started_at: now.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string(), + finished_at: None, + duration_ms: 0, + outcome: "in_progress".to_string(), + client_request: None, + turns: Vec::new(), + usage: None, + errors: Vec::new(), + })), + traces_dir: self.traces_dir.clone(), + started_at_chrono: now, + started_instant: Instant::now(), + }) + } + + /// Delete trace directories older than `max_age_days`. + pub fn cleanup_old_traces(&self, max_age_days: u32) { + if !self.enabled { + return; + } + let Ok(entries) = std::fs::read_dir(&self.traces_dir) else { + return; + }; + let cutoff = chrono::Utc::now() - chrono::Duration::days(max_age_days as i64); + let cutoff_str = cutoff.format("%Y-%m-%d").to_string(); + + for entry in entries.flatten() { + let name = entry.file_name().to_string_lossy().to_string(); + // Directory names are YYYY-MM-DD — lexicographic comparison works + if name < cutoff_str { + if let Err(e) = std::fs::remove_dir_all(entry.path()) { + tracing::warn!(dir = %name, error = %e, "trace: failed to remove old trace dir"); + } else { + tracing::info!(dir = %name, "trace: cleaned up old trace dir"); + } + } + } + } +} + +/// Handle to a single in-flight trace. Cloneable, thread-safe. +#[derive(Clone)] +pub struct TraceHandle { + inner: Arc>, + traces_dir: PathBuf, + started_at_chrono: chrono::DateTime, + started_instant: Instant, +} + +impl std::fmt::Debug for TraceHandle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TraceHandle") + .field("traces_dir", &self.traces_dir) + .finish() + } +} + +impl TraceHandle { + /// Record the client request summary. + pub async fn set_client_request(&self, req: ClientRequestSummary) { + self.inner.lock().await.client_request = Some(req); + } + + /// Start a new turn (thinking/tool-call round). Returns the turn index. + pub async fn start_turn(&self) -> usize { + let mut data = self.inner.lock().await; + let idx = data.turns.len(); + data.turns.push(TraceTurn { + turn: idx, + mitm_matched: false, + gate_wait_ms: None, + modify_summary: None, + request_bytes: None, + upstream_wait_ms: None, + response: None, + events_sent: Vec::new(), + handler_action: None, + }); + idx + } + + /// Record MITM match for a turn. + pub async fn record_mitm_match(&self, turn: usize, gate_wait_ms: u64) { + let mut data = self.inner.lock().await; + if let Some(t) = data.turns.get_mut(turn) { + t.mitm_matched = true; + t.gate_wait_ms = Some(gate_wait_ms); + } + } + + /// Record MITM modify summary for a turn. + pub async fn record_modify(&self, turn: usize, summary: String, original: u64, modified: u64) { + let mut data = self.inner.lock().await; + if let Some(t) = data.turns.get_mut(turn) { + t.modify_summary = Some(summary); + t.request_bytes = Some(RequestBytes { original, modified }); + } + } + + /// Record upstream wait time. + pub async fn record_upstream_wait(&self, turn: usize, wait_ms: u64) { + let mut data = self.inner.lock().await; + if let Some(t) = data.turns.get_mut(turn) { + t.upstream_wait_ms = Some(wait_ms); + } + } + + /// Record the response summary for a turn. + pub async fn record_response(&self, turn: usize, resp: ResponseSummary) { + let mut data = self.inner.lock().await; + if let Some(t) = data.turns.get_mut(turn) { + t.response = Some(resp); + } + } + + /// Record an event sent via channel. + pub async fn record_event(&self, turn: usize, event_name: &str) { + let mut data = self.inner.lock().await; + if let Some(t) = data.turns.get_mut(turn) { + t.events_sent.push(event_name.to_string()); + } + } + + /// Record the handler action for a turn. + pub async fn record_action(&self, turn: usize, action: &str) { + let mut data = self.inner.lock().await; + if let Some(t) = data.turns.get_mut(turn) { + t.handler_action = Some(action.to_string()); + } + } + + /// Record an error. + pub async fn record_error(&self, error: String) { + self.inner.lock().await.errors.push(error); + } + + /// Record final usage. + pub async fn set_usage(&self, usage: TrackedUsage) { + self.inner.lock().await.usage = Some(usage); + } + + /// Finalize the trace and write to disk as a per-call folder. + /// + /// Layout: `traces/{YYYY-MM-DD}/{HH-MM-SS}_{cascade_short}/` + /// - `summary.md` — always written, rich LLM-readable overview + /// - `request.json` — client request details (always) + /// - `turns.json` — per-turn MITM/gate/modify data (always) + /// - `response.json` — response summary + usage (if present) + /// - `events.json` — channel events (if non-empty) + /// - `errors.json` — errors (if any) + pub async fn finish(&self, outcome: &str) { + let mut data = self.inner.lock().await; + let now = chrono::Utc::now(); + data.finished_at = Some(now.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string()); + data.duration_ms = self.started_instant.elapsed().as_millis() as u64; + data.outcome = outcome.to_string(); + + // Build folder path: traces/{YYYY-MM-DD}/{HH-MM-SS}_{cascade_short}/ + let date_str = self.started_at_chrono.format("%Y-%m-%d").to_string(); + let time_str = self.started_at_chrono.format("%H-%M-%S%.3f").to_string(); + let cascade_short = &data.cascade_id[..8.min(data.cascade_id.len())]; + let dir = self.traces_dir.join(&date_str).join(format!("{}_{}", time_str, cascade_short)); + + // Build all file contents while holding lock + let summary = generate_summary(&data); + let request_json = serde_json::to_string_pretty(&data.client_request).unwrap_or_default(); + let turns_json = serde_json::to_string_pretty(&data.turns).unwrap_or_default(); + + let response_json = if data.usage.is_some() || data.turns.iter().any(|t| t.response.is_some()) { + let resp = ResponseFile { + usage: data.usage.clone(), + }; + Some(serde_json::to_string_pretty(&resp).unwrap_or_default()) + } else { + None + }; + + let events_json = { + let all_events: Vec<_> = data.turns.iter() + .enumerate() + .filter(|(_, t)| !t.events_sent.is_empty()) + .map(|(i, t)| serde_json::json!({ "turn": i, "events": t.events_sent })) + .collect(); + if all_events.is_empty() { None } + else { Some(serde_json::to_string_pretty(&all_events).unwrap_or_default()) } + }; + + let errors_json = if data.errors.is_empty() { None } + else { Some(serde_json::to_string_pretty(&data.errors).unwrap_or_default()) }; + + // Build meta.txt for grep + let meta_txt = format!( + "cascade={} endpoint={} model={} outcome={} duration={}ms stream={}", + cascade_short, data.endpoint, data.model, data.outcome, data.duration_ms, data.stream + ); + + drop(data); // release lock before I/O + + tokio::spawn(async move { + if let Err(e) = tokio::fs::create_dir_all(&dir).await { + tracing::warn!(error = %e, "trace: failed to create dir"); + return; + } + // Always write summary + request + turns + meta + let _ = tokio::fs::write(dir.join("summary.md"), summary).await; + let _ = tokio::fs::write(dir.join("request.json"), request_json).await; + let _ = tokio::fs::write(dir.join("turns.json"), turns_json).await; + let _ = tokio::fs::write(dir.join("meta.txt"), meta_txt).await; + // Conditionally write response, events, errors + if let Some(j) = response_json { + let _ = tokio::fs::write(dir.join("response.json"), j).await; + } + if let Some(j) = events_json { + let _ = tokio::fs::write(dir.join("events.json"), j).await; + } + if let Some(j) = errors_json { + let _ = tokio::fs::write(dir.join("errors.json"), j).await; + } + tracing::debug!(path = %dir.display(), "trace: folder written"); + }); + } +} + +// ── Summary generation ───────────────────────────────────────────────── + +#[derive(Serialize)] +struct ResponseFile { + #[serde(skip_serializing_if = "Option::is_none")] + usage: Option, +} + +/// Build a rich markdown summary from trace data. +fn generate_summary(data: &TraceData) -> String { + let mut s = String::with_capacity(2048); + let cascade_short = &data.cascade_id[..8.min(data.cascade_id.len())]; + + // Header + s.push_str(&format!("# Trace: {} — {}\n\n", cascade_short, data.endpoint)); + + // Overview table + s.push_str("| Field | Value |\n|-------|-------|\n"); + s.push_str(&format!("| Cascade ID | `{}` |\n", data.cascade_id)); + s.push_str(&format!("| Model | {} |\n", data.model)); + s.push_str(&format!("| Stream | {} |\n", data.stream)); + s.push_str(&format!("| Started | {} |\n", data.started_at)); + if let Some(ref fin) = data.finished_at { + s.push_str(&format!("| Finished | {} |\n", fin)); + } + s.push_str(&format!("| Duration | {}ms |\n", data.duration_ms)); + s.push_str(&format!("| Outcome | **{}** |\n", data.outcome)); + s.push('\n'); + + // Client request + s.push_str("## Client Request\n\n"); + if let Some(ref req) = data.client_request { + s.push_str(&format!("- **Messages:** {} (user text: {} chars)\n", req.message_count, req.user_text_len)); + if !req.user_text_preview.is_empty() { + s.push_str(&format!("- **Preview:** `{}`\n", req.user_text_preview)); + } + s.push_str(&format!("- **Tools:** {} | **Tool rounds:** {}\n", req.tool_count, req.tool_round_count)); + if req.system_prompt { s.push_str("- **System prompt:** yes\n"); } + s.push_str(&format!("- **Image:** {}\n", if req.has_image { "yes" } else { "no" })); + } else { + s.push_str("(not recorded)\n"); + } + s.push_str("\n→ Full details in [request.json](./request.json)\n\n"); + + // Turns + s.push_str(&format!("## Turns ({} total)\n\n", data.turns.len())); + for turn in &data.turns { + s.push_str(&format!("### Turn {}\n\n", turn.turn)); + + // MITM match + if turn.mitm_matched { + s.push_str(&format!("- **MITM matched:** ✓ (gate wait: {}ms)\n", + turn.gate_wait_ms.unwrap_or(0))); + } else { + s.push_str("- **MITM matched:** ✗\n"); + } + + // Modify + if let Some(ref mod_sum) = turn.modify_summary { + s.push_str(&format!("- **Modify:** {}", mod_sum)); + if let Some(ref bytes) = turn.request_bytes { + s.push_str(&format!(" ({}B → {}B)", bytes.original, bytes.modified)); + } + s.push('\n'); + } + + // Upstream wait + if let Some(wait) = turn.upstream_wait_ms { + s.push_str(&format!("- **Upstream wait:** {}ms\n", wait)); + } + + // Response + if let Some(ref resp) = turn.response { + s.push_str(&format!("- **Response:** {} chars text, {} chars thinking", + resp.text_len, resp.thinking_len)); + if let Some(ref fr) = resp.finish_reason { + s.push_str(&format!(", finish_reason={}", fr)); + } + if !resp.function_calls.is_empty() { + let names: Vec<&str> = resp.function_calls.iter().map(|f| f.name.as_str()).collect(); + s.push_str(&format!(", tool_calls=[{}]", names.join(", "))); + } + if resp.grounding { + s.push_str(", grounding=yes"); + } + s.push('\n'); + if !resp.text_preview.is_empty() { + s.push_str(&format!("- **Output preview:** `{}`\n", resp.text_preview)); + } + } + + // Events + if !turn.events_sent.is_empty() { + s.push_str(&format!("- **Events:** {} sent ({})\n", + turn.events_sent.len(), + turn.events_sent.join(", "))); + } + + // Handler action + if let Some(ref action) = turn.handler_action { + s.push_str(&format!("- **Action:** {}\n", action)); + } + + s.push('\n'); + } + + if data.turns.iter().any(|t| t.response.is_some()) { + s.push_str("→ Full turn details in [turns.json](./turns.json)\n\n"); + } + + // Usage + if let Some(ref u) = data.usage { + s.push_str("## Usage\n\n"); + s.push_str(&format!("| Metric | Tokens |\n|--------|--------|\n")); + s.push_str(&format!("| Input | {} |\n", u.input_tokens)); + s.push_str(&format!("| Output | {} |\n", u.output_tokens)); + if u.thinking_tokens > 0 { + s.push_str(&format!("| Thinking | {} |\n", u.thinking_tokens)); + } + if u.cache_read > 0 { + s.push_str(&format!("| Cache read | {} |\n", u.cache_read)); + } + s.push_str("\n→ Full details in [response.json](./response.json)\n\n"); + } + + // Errors + if !data.errors.is_empty() { + s.push_str("## Errors\n\n"); + for err in &data.errors { + s.push_str(&format!("- ❌ {}\n", err)); + } + s.push_str("\n→ Full details in [errors.json](./errors.json)\n\n"); + } + + // Files index + s.push_str("## Files\n\n"); + s.push_str("| File | Contains |\n|------|----------|\n"); + s.push_str("| [request.json](./request.json) | Client request summary |\n"); + s.push_str("| [turns.json](./turns.json) | Per-turn MITM/gate/modify/response data |\n"); + if data.usage.is_some() || data.turns.iter().any(|t| t.response.is_some()) { + s.push_str("| [response.json](./response.json) | Response summaries + token usage |\n"); + } + if data.turns.iter().any(|t| !t.events_sent.is_empty()) { + s.push_str("| [events.json](./events.json) | Channel events per turn |\n"); + } + if !data.errors.is_empty() { + s.push_str("| [errors.json](./errors.json) | Error messages |\n"); + } + + s +} + +// ── Serializable data structures ─────────────────────────────────────── + +#[derive(Serialize)] +struct TraceData { + cascade_id: String, + endpoint: String, + model: String, + stream: bool, + started_at: String, + #[serde(skip_serializing_if = "Option::is_none")] + finished_at: Option, + duration_ms: u64, + outcome: String, + #[serde(skip_serializing_if = "Option::is_none")] + client_request: Option, + turns: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + usage: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + errors: Vec, +} + +#[derive(Serialize, Clone)] +pub struct ClientRequestSummary { + pub message_count: usize, + pub tool_count: usize, + pub tool_round_count: usize, + pub user_text_len: usize, + pub user_text_preview: String, + pub system_prompt: bool, + pub has_image: bool, +} + +#[derive(Serialize)] +struct TraceTurn { + turn: usize, + mitm_matched: bool, + #[serde(skip_serializing_if = "Option::is_none")] + gate_wait_ms: Option, + #[serde(skip_serializing_if = "Option::is_none")] + modify_summary: Option, + #[serde(skip_serializing_if = "Option::is_none")] + request_bytes: Option, + #[serde(skip_serializing_if = "Option::is_none")] + upstream_wait_ms: Option, + #[serde(skip_serializing_if = "Option::is_none")] + response: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + events_sent: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + handler_action: Option, +} + +#[derive(Serialize)] +struct RequestBytes { + original: u64, + modified: u64, +} + +#[derive(Serialize, Clone)] +pub struct ResponseSummary { + pub text_len: usize, + pub thinking_len: usize, + #[serde(skip_serializing_if = "String::is_empty")] + pub text_preview: String, + pub finish_reason: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub function_calls: Vec, + pub grounding: bool, +} + +#[derive(Serialize, Clone)] +pub struct FunctionCallSummary { + pub name: String, + pub args_preview: String, +} + +#[derive(Serialize, Clone)] +pub struct TrackedUsage { + pub input_tokens: u64, + pub output_tokens: u64, + pub thinking_tokens: u64, + pub cache_read: u64, +}