From 2882f7cce263e81bb1e097ca1dd7d0dcd573afb5 Mon Sep 17 00:00:00 2001 From: Nikketryhard Date: Sun, 15 Feb 2026 18:19:38 -0600 Subject: [PATCH] feat: propagate Google upstream errors to client When Google returns an error (400, 429, 500, etc.), the MITM proxy now captures it and the API handlers return it immediately instead of hanging until timeout. - UpstreamError struct stored in MitmStore - MITM proxy parses Google error JSON (message + status) - Polling handler checks for upstream errors each cycle - Streaming handlers emit response.failed / SSE error events - Error status mapped to OpenAI-style types (invalid_request_error, rate_limit_error, authentication_error, server_error, etc.) - All handlers clear stale errors at request start --- src/api/completions.rs | 29 +++++++++++++++++++-- src/api/gemini.rs | 8 ++++-- src/api/polling.rs | 24 ++++++++++++++++-- src/api/responses.rs | 57 +++++++++++++++++++++++++++++++++++++++--- src/api/search.rs | 5 +++- src/api/util.rs | 22 ++++++++++++++++ src/mitm/proxy.rs | 27 +++++++++++++++++--- src/mitm/store.rs | 37 +++++++++++++++++++++++++++ 8 files changed, 195 insertions(+), 14 deletions(-) diff --git a/src/api/completions.rs b/src/api/completions.rs index 7b0de4a..a085fcb 100644 --- a/src/api/completions.rs +++ b/src/api/completions.rs @@ -12,7 +12,7 @@ use tracing::{debug, info, warn}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response}; use super::types::*; -use super::util::{err_response, now_unix}; +use super::util::{err_response, upstream_err_response, now_unix}; use super::AppState; /// Extract a conversation/session ID from a flexible JSON value. @@ -488,8 +488,9 @@ async fn chat_completions_stream( let mut last_text = String::new(); let has_custom_tools = state.mitm_store.get_tools().await.is_some(); - // Clear any stale captured response from previous requests + // Clear any stale captured response and upstream errors from previous requests state.mitm_store.clear_response_async().await; + state.mitm_store.clear_upstream_error().await; // Initial role chunk yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json( @@ -513,6 +514,27 @@ async fn chat_completions_stream( }; while start.elapsed().as_secs() < timeout { + // Check for upstream errors from MITM (Google API errors) + if let Some(err) = state.mitm_store.take_upstream_error().await { + let error_msg = err.message.clone() + .unwrap_or_else(|| format!("Google API returned HTTP {}", err.status)); + let error_type = match err.error_status.as_deref() { + Some("INVALID_ARGUMENT") => "invalid_request_error", + Some("RESOURCE_EXHAUSTED") => "rate_limit_error", + Some("PERMISSION_DENIED") | Some("UNAUTHENTICATED") => "authentication_error", + _ => "upstream_error", + }; + yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ + "error": { + "message": error_msg, + "type": error_type, + "code": err.status, + } + })).unwrap())); + yield Ok(Event::default().data("[DONE]".to_string())); + break; + } + // ── Check for MITM-captured function calls FIRST ── // This runs independently of LS steps — the MITM captures tool calls // at the proxy layer, so we don't need to wait for LS processing. @@ -852,6 +874,9 @@ async fn chat_completions_sync( timeout: u64, ) -> axum::response::Response { let result = poll_for_response(&state, &cascade_id, timeout).await; + if let Some(ref err) = result.upstream_error { + return upstream_err_response(err); + } // Check MITM store first for real intercepted usage (fallback to _latest) let mitm = match state.mitm_store.take_usage(&cascade_id).await { diff --git a/src/api/gemini.rs b/src/api/gemini.rs index d3fe8e8..5f050ea 100644 --- a/src/api/gemini.rs +++ b/src/api/gemini.rs @@ -16,7 +16,7 @@ use tracing::{info, warn}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response}; -use super::util::err_response; +use super::util::{err_response, upstream_err_response}; use super::AppState; use crate::mitm::store::PendingToolResult; @@ -332,8 +332,9 @@ async fn gemini_sync( ) -> axum::response::Response { let has_custom_tools = state.mitm_store.get_tools().await.is_some(); - // Clear stale response + // Clear stale response and upstream errors state.mitm_store.clear_response_async().await; + state.mitm_store.clear_upstream_error().await; // ── MITM bypass: when tools active, poll MitmStore directly ── if has_custom_tools { @@ -417,6 +418,9 @@ async fn gemini_sync( // ── Normal LS path (no custom tools) ── let poll_result = poll_for_response(&state, &cascade_id, timeout).await; + if let Some(ref err) = poll_result.upstream_error { + return upstream_err_response(err); + } // Check for captured function calls — return in Gemini format let captured_tool_calls = state.mitm_store.take_any_function_calls().await; diff --git a/src/api/polling.rs b/src/api/polling.rs index 5cfa47a..373d6c6 100644 --- a/src/api/polling.rs +++ b/src/api/polling.rs @@ -28,6 +28,8 @@ pub(crate) struct PollResult { /// Time the model spent thinking, as reported by the LS (e.g. "0.041999832s"). #[allow(dead_code)] pub thinking_duration: Option, + /// Upstream error from Google's API, captured by MITM. + pub upstream_error: Option, } /// Extract the response text from steps — scans in REVERSE to find the latest response. @@ -190,6 +192,23 @@ pub(crate) async fn poll_for_response( let mut last_step_count: usize = 0; while start.elapsed().as_secs() < timeout { + // Check for upstream errors from MITM (Google API errors) + if let Some(err) = state.mitm_store.take_upstream_error().await { + warn!( + "Upstream error on cascade {short_id}: HTTP {} — {}", + err.status, + err.message.as_deref().unwrap_or("unknown") + ); + return PollResult { + text: String::new(), + usage: None, + thinking_signature: None, + thinking: None, + thinking_duration: None, + upstream_error: Some(err), + }; + } + if let Ok((status, data)) = state.backend.get_steps(cascade_id).await { if status == 200 { if let Some(steps) = data["steps"].as_array() { @@ -249,7 +268,7 @@ pub(crate) async fn poll_for_response( if thinking_signature.is_some() { ", has sig" } else { "" } ); } - return PollResult { text, usage, thinking_signature, thinking, thinking_duration }; + return PollResult { text, usage, thinking_signature, thinking, thinking_duration, upstream_error: None }; } } @@ -274,7 +293,7 @@ pub(crate) async fn poll_for_response( elapsed, text.len() ); - return PollResult { text, usage, thinking_signature, thinking, thinking_duration }; + return PollResult { text, usage, thinking_signature, thinking, thinking_duration, upstream_error: None }; } } } @@ -295,5 +314,6 @@ pub(crate) async fn poll_for_response( thinking_signature: None, thinking: None, thinking_duration: None, + upstream_error: None, } } diff --git a/src/api/responses.rs b/src/api/responses.rs index 3d0253b..6864dcd 100644 --- a/src/api/responses.rs +++ b/src/api/responses.rs @@ -16,7 +16,7 @@ use tracing::{debug, info}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content}; use super::types::*; -use super::util::{err_response, now_unix, responses_sse_event}; +use super::util::{err_response, upstream_err_response, now_unix, responses_sse_event}; use super::AppState; use crate::mitm::store::PendingToolResult; use crate::mitm::modify::{openai_tools_to_gemini, openai_tool_choice_to_gemini}; @@ -552,8 +552,9 @@ async fn handle_responses_sync( let created_at = now_unix(); let has_custom_tools = state.mitm_store.get_tools().await.is_some(); - // Clear stale captured response + // Clear stale captured response and upstream errors state.mitm_store.clear_response_async().await; + state.mitm_store.clear_upstream_error().await; // ── MITM bypass: poll MitmStore directly when custom tools active ── if has_custom_tools { @@ -653,6 +654,9 @@ async fn handle_responses_sync( // ── Normal LS path (no custom tools) ── let poll_result = poll_for_response(&state, &cascade_id, timeout).await; + if let Some(ref err) = poll_result.upstream_error { + return upstream_err_response(err); + } let completed_at = now_unix(); let msg_id = format!( "msg_{}", @@ -806,8 +810,9 @@ async fn handle_responses_stream( let reasoning_id = format!("rs_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); let has_custom_tools = state.mitm_store.get_tools().await.is_some(); - // Clear stale captured response + // Clear stale captured response and upstream errors state.mitm_store.clear_response_async().await; + state.mitm_store.clear_upstream_error().await; // ── MITM bypass mode (when custom tools are active) ── // Skip LS entirely — read text, thinking, and tool calls directly from MitmStore. @@ -815,6 +820,29 @@ async fn handle_responses_stream( let mut last_thinking = String::new(); while start.elapsed().as_secs() < timeout { + // Check for upstream errors from MITM (Google API errors) + if let Some(err) = state.mitm_store.take_upstream_error().await { + let error_msg = err.message.clone() + .unwrap_or_else(|| format!("Google API returned HTTP {}", err.status)); + yield Ok(responses_sse_event( + "response.failed", + serde_json::json!({ + "type": "response.failed", + "sequence_number": next_seq(), + "response": { + "id": &response_id, + "status": "failed", + "error": { + "type": err.error_status.as_deref().unwrap_or("upstream_error"), + "message": error_msg, + "code": err.status, + }, + }, + }), + )); + break; + } + // Check for function calls first let captured = state.mitm_store.take_any_function_calls().await; if let Some(ref raw_calls) = captured { @@ -1135,6 +1163,29 @@ async fn handle_responses_stream( let mut last_thinking_len: usize = 0; while start.elapsed().as_secs() < timeout { + // Check for upstream errors from MITM (Google API errors) + if let Some(err) = state.mitm_store.take_upstream_error().await { + let error_msg = err.message.clone() + .unwrap_or_else(|| format!("Google API returned HTTP {}", err.status)); + yield Ok(responses_sse_event( + "response.failed", + serde_json::json!({ + "type": "response.failed", + "sequence_number": next_seq(), + "response": { + "id": &response_id, + "status": "failed", + "error": { + "type": err.error_status.as_deref().unwrap_or("upstream_error"), + "message": error_msg, + "code": err.status, + }, + }, + }), + )); + break; + } + if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await { if status == 200 { if let Some(steps) = data["steps"].as_array() { diff --git a/src/api/search.rs b/src/api/search.rs index 73b8a4e..8c40110 100644 --- a/src/api/search.rs +++ b/src/api/search.rs @@ -17,7 +17,7 @@ use tracing::{info, warn}; use super::models::{lookup_model, MODELS}; use super::polling::poll_for_response; -use super::util::err_response; +use super::util::{err_response, upstream_err_response}; use super::AppState; /// Search request body. @@ -181,6 +181,9 @@ async fn do_search(state: Arc, body: SearchRequest) -> axum::response: // Poll for response let poll_result = poll_for_response(&state, &cascade_id, body.timeout).await; + if let Some(ref err) = poll_result.upstream_error { + return upstream_err_response(err); + } // Extract grounding metadata let grounding = state.mitm_store.take_grounding().await; diff --git a/src/api/util.rs b/src/api/util.rs index 2f4bee6..3fe57c6 100644 --- a/src/api/util.rs +++ b/src/api/util.rs @@ -25,6 +25,28 @@ pub(crate) fn err_response( (status, Json(body)).into_response() } +/// Convert a MITM-captured upstream error from Google into an HTTP response. +/// Maps Google's HTTP status codes and preserves the error message. +pub(crate) fn upstream_err_response(err: &crate::mitm::store::UpstreamError) -> axum::response::Response { + // Map Google's status code to HTTP status + let status = StatusCode::from_u16(err.status).unwrap_or(StatusCode::BAD_GATEWAY); + + // Map Google error status to OpenAI-style error type + let error_type = match err.error_status.as_deref() { + Some("INVALID_ARGUMENT") => "invalid_request_error", + Some("RESOURCE_EXHAUSTED") => "rate_limit_error", + Some("PERMISSION_DENIED") | Some("UNAUTHENTICATED") => "authentication_error", + Some("NOT_FOUND") => "not_found_error", + Some("INTERNAL") | Some("UNAVAILABLE") => "server_error", + _ => "upstream_error", + }; + + let message = err.message.clone() + .unwrap_or_else(|| format!("Google API returned HTTP {}", err.status)); + + err_response(status, message, error_type) +} + pub(crate) fn now_unix() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) diff --git a/src/mitm/proxy.rs b/src/mitm/proxy.rs index 89787e3..d5c0914 100644 --- a/src/mitm/proxy.rs +++ b/src/mitm/proxy.rs @@ -744,10 +744,29 @@ async fn handle_http_over_tls( } headers_parsed = true; - // Log error response bodies for debugging - if resp.code.unwrap_or(0) >= 400 { - let body_preview = String::from_utf8_lossy(&header_buf[hdr_end..]); - warn!(domain, status = resp.code.unwrap_or(0), body = %body_preview, "MITM: upstream error response"); + // Capture upstream errors for forwarding to client + let http_status = resp.code.unwrap_or(0) as u16; + if http_status >= 400 { + let body_str = String::from_utf8_lossy(&header_buf[hdr_end..]).to_string(); + warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response"); + + // Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}} + let (message, error_status) = serde_json::from_str::(&body_str) + .ok() + .and_then(|v| { + let err = v.get("error")?; + let msg = err.get("message").and_then(|m| m.as_str()).map(|s| s.to_string()); + let status = err.get("status").and_then(|s| s.as_str()).map(|s| s.to_string()); + Some((msg, status)) + }) + .unwrap_or((None, None)); + + store.set_upstream_error(super::store::UpstreamError { + status: http_status, + body: body_str, + message, + error_status, + }).await; } // Save body for usage parsing diff --git a/src/mitm/store.rs b/src/mitm/store.rs index f3e20c8..be2ee79 100644 --- a/src/mitm/store.rs +++ b/src/mitm/store.rs @@ -60,6 +60,21 @@ pub struct PendingToolResult { pub result: serde_json::Value, } +/// An upstream error captured from Google's API response. +/// Stored by the MITM proxy so API handlers can return it to the client +/// instead of hanging forever waiting for a response that won't come. +#[derive(Debug, Clone)] +pub struct UpstreamError { + /// HTTP status code from Google (e.g. 400, 429, 500). + pub status: u16, + /// Raw error body from Google (usually JSON). + pub body: String, + /// Parsed error message, if available. + pub message: Option, + /// Google error status string (e.g. "INVALID_ARGUMENT", "RESOURCE_EXHAUSTED"). + pub error_status: Option, +} + /// A pending image to inject via MITM into the Google API request. /// The LS doesn't forward images from our SendUserCascadeMessage proto, /// so we inject them directly at the MITM layer. @@ -152,6 +167,10 @@ pub struct MitmStore { // ── Pending image for MITM injection ───────────────────────────────── /// Image to inject into the next Google API request via MITM. pending_image: Arc>>, + + // ── Upstream error capture ─────────────────────────────────────────── + /// Error from Google's API, captured by MITM for forwarding to client. + upstream_error: Arc>>, } /// Aggregate statistics across all intercepted traffic. @@ -197,6 +216,7 @@ impl MitmStore { generation_params: Arc::new(RwLock::new(None)), captured_grounding: Arc::new(RwLock::new(None)), pending_image: Arc::new(RwLock::new(None)), + upstream_error: Arc::new(RwLock::new(None)), } } @@ -534,4 +554,21 @@ impl MitmStore { pub async fn take_pending_image(&self) -> Option { self.pending_image.write().await.take() } + + // ── Upstream error capture ─────────────────────────────────────────── + + /// Store an upstream error from Google's API. + pub async fn set_upstream_error(&self, error: UpstreamError) { + *self.upstream_error.write().await = Some(error); + } + + /// Take (consume) captured upstream error. + pub async fn take_upstream_error(&self) -> Option { + self.upstream_error.write().await.take() + } + + /// Clear any stored upstream error. + pub async fn clear_upstream_error(&self) { + *self.upstream_error.write().await = None; + } }