diff --git a/src/api/completions.rs b/src/api/completions.rs index 7b0de4a..a085fcb 100644 --- a/src/api/completions.rs +++ b/src/api/completions.rs @@ -12,7 +12,7 @@ use tracing::{debug, info, warn}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response}; use super::types::*; -use super::util::{err_response, now_unix}; +use super::util::{err_response, upstream_err_response, now_unix}; use super::AppState; /// Extract a conversation/session ID from a flexible JSON value. @@ -488,8 +488,9 @@ async fn chat_completions_stream( let mut last_text = String::new(); let has_custom_tools = state.mitm_store.get_tools().await.is_some(); - // Clear any stale captured response from previous requests + // Clear any stale captured response and upstream errors from previous requests state.mitm_store.clear_response_async().await; + state.mitm_store.clear_upstream_error().await; // Initial role chunk yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json( @@ -513,6 +514,27 @@ async fn chat_completions_stream( }; while start.elapsed().as_secs() < timeout { + // Check for upstream errors from MITM (Google API errors) + if let Some(err) = state.mitm_store.take_upstream_error().await { + let error_msg = err.message.clone() + .unwrap_or_else(|| format!("Google API returned HTTP {}", err.status)); + let error_type = match err.error_status.as_deref() { + Some("INVALID_ARGUMENT") => "invalid_request_error", + Some("RESOURCE_EXHAUSTED") => "rate_limit_error", + Some("PERMISSION_DENIED") | Some("UNAUTHENTICATED") => "authentication_error", + _ => "upstream_error", + }; + yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ + "error": { + "message": error_msg, + "type": error_type, + "code": err.status, + } + })).unwrap())); + yield Ok(Event::default().data("[DONE]".to_string())); + break; + } + // ── Check for MITM-captured function calls FIRST ── // This runs independently of LS steps — the MITM captures tool calls // at the proxy layer, so we don't need to wait for LS processing. @@ -852,6 +874,9 @@ async fn chat_completions_sync( timeout: u64, ) -> axum::response::Response { let result = poll_for_response(&state, &cascade_id, timeout).await; + if let Some(ref err) = result.upstream_error { + return upstream_err_response(err); + } // Check MITM store first for real intercepted usage (fallback to _latest) let mitm = match state.mitm_store.take_usage(&cascade_id).await { diff --git a/src/api/gemini.rs b/src/api/gemini.rs index d3fe8e8..5f050ea 100644 --- a/src/api/gemini.rs +++ b/src/api/gemini.rs @@ -16,7 +16,7 @@ use tracing::{info, warn}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response}; -use super::util::err_response; +use super::util::{err_response, upstream_err_response}; use super::AppState; use crate::mitm::store::PendingToolResult; @@ -332,8 +332,9 @@ async fn gemini_sync( ) -> axum::response::Response { let has_custom_tools = state.mitm_store.get_tools().await.is_some(); - // Clear stale response + // Clear stale response and upstream errors state.mitm_store.clear_response_async().await; + state.mitm_store.clear_upstream_error().await; // ── MITM bypass: when tools active, poll MitmStore directly ── if has_custom_tools { @@ -417,6 +418,9 @@ async fn gemini_sync( // ── Normal LS path (no custom tools) ── let poll_result = poll_for_response(&state, &cascade_id, timeout).await; + if let Some(ref err) = poll_result.upstream_error { + return upstream_err_response(err); + } // Check for captured function calls — return in Gemini format let captured_tool_calls = state.mitm_store.take_any_function_calls().await; diff --git a/src/api/polling.rs b/src/api/polling.rs index 5cfa47a..373d6c6 100644 --- a/src/api/polling.rs +++ b/src/api/polling.rs @@ -28,6 +28,8 @@ pub(crate) struct PollResult { /// Time the model spent thinking, as reported by the LS (e.g. "0.041999832s"). #[allow(dead_code)] pub thinking_duration: Option, + /// Upstream error from Google's API, captured by MITM. + pub upstream_error: Option, } /// Extract the response text from steps — scans in REVERSE to find the latest response. @@ -190,6 +192,23 @@ pub(crate) async fn poll_for_response( let mut last_step_count: usize = 0; while start.elapsed().as_secs() < timeout { + // Check for upstream errors from MITM (Google API errors) + if let Some(err) = state.mitm_store.take_upstream_error().await { + warn!( + "Upstream error on cascade {short_id}: HTTP {} — {}", + err.status, + err.message.as_deref().unwrap_or("unknown") + ); + return PollResult { + text: String::new(), + usage: None, + thinking_signature: None, + thinking: None, + thinking_duration: None, + upstream_error: Some(err), + }; + } + if let Ok((status, data)) = state.backend.get_steps(cascade_id).await { if status == 200 { if let Some(steps) = data["steps"].as_array() { @@ -249,7 +268,7 @@ pub(crate) async fn poll_for_response( if thinking_signature.is_some() { ", has sig" } else { "" } ); } - return PollResult { text, usage, thinking_signature, thinking, thinking_duration }; + return PollResult { text, usage, thinking_signature, thinking, thinking_duration, upstream_error: None }; } } @@ -274,7 +293,7 @@ pub(crate) async fn poll_for_response( elapsed, text.len() ); - return PollResult { text, usage, thinking_signature, thinking, thinking_duration }; + return PollResult { text, usage, thinking_signature, thinking, thinking_duration, upstream_error: None }; } } } @@ -295,5 +314,6 @@ pub(crate) async fn poll_for_response( thinking_signature: None, thinking: None, thinking_duration: None, + upstream_error: None, } } diff --git a/src/api/responses.rs b/src/api/responses.rs index 3d0253b..6864dcd 100644 --- a/src/api/responses.rs +++ b/src/api/responses.rs @@ -16,7 +16,7 @@ use tracing::{debug, info}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content}; use super::types::*; -use super::util::{err_response, now_unix, responses_sse_event}; +use super::util::{err_response, upstream_err_response, now_unix, responses_sse_event}; use super::AppState; use crate::mitm::store::PendingToolResult; use crate::mitm::modify::{openai_tools_to_gemini, openai_tool_choice_to_gemini}; @@ -552,8 +552,9 @@ async fn handle_responses_sync( let created_at = now_unix(); let has_custom_tools = state.mitm_store.get_tools().await.is_some(); - // Clear stale captured response + // Clear stale captured response and upstream errors state.mitm_store.clear_response_async().await; + state.mitm_store.clear_upstream_error().await; // ── MITM bypass: poll MitmStore directly when custom tools active ── if has_custom_tools { @@ -653,6 +654,9 @@ async fn handle_responses_sync( // ── Normal LS path (no custom tools) ── let poll_result = poll_for_response(&state, &cascade_id, timeout).await; + if let Some(ref err) = poll_result.upstream_error { + return upstream_err_response(err); + } let completed_at = now_unix(); let msg_id = format!( "msg_{}", @@ -806,8 +810,9 @@ async fn handle_responses_stream( let reasoning_id = format!("rs_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); let has_custom_tools = state.mitm_store.get_tools().await.is_some(); - // Clear stale captured response + // Clear stale captured response and upstream errors state.mitm_store.clear_response_async().await; + state.mitm_store.clear_upstream_error().await; // ── MITM bypass mode (when custom tools are active) ── // Skip LS entirely — read text, thinking, and tool calls directly from MitmStore. @@ -815,6 +820,29 @@ async fn handle_responses_stream( let mut last_thinking = String::new(); while start.elapsed().as_secs() < timeout { + // Check for upstream errors from MITM (Google API errors) + if let Some(err) = state.mitm_store.take_upstream_error().await { + let error_msg = err.message.clone() + .unwrap_or_else(|| format!("Google API returned HTTP {}", err.status)); + yield Ok(responses_sse_event( + "response.failed", + serde_json::json!({ + "type": "response.failed", + "sequence_number": next_seq(), + "response": { + "id": &response_id, + "status": "failed", + "error": { + "type": err.error_status.as_deref().unwrap_or("upstream_error"), + "message": error_msg, + "code": err.status, + }, + }, + }), + )); + break; + } + // Check for function calls first let captured = state.mitm_store.take_any_function_calls().await; if let Some(ref raw_calls) = captured { @@ -1135,6 +1163,29 @@ async fn handle_responses_stream( let mut last_thinking_len: usize = 0; while start.elapsed().as_secs() < timeout { + // Check for upstream errors from MITM (Google API errors) + if let Some(err) = state.mitm_store.take_upstream_error().await { + let error_msg = err.message.clone() + .unwrap_or_else(|| format!("Google API returned HTTP {}", err.status)); + yield Ok(responses_sse_event( + "response.failed", + serde_json::json!({ + "type": "response.failed", + "sequence_number": next_seq(), + "response": { + "id": &response_id, + "status": "failed", + "error": { + "type": err.error_status.as_deref().unwrap_or("upstream_error"), + "message": error_msg, + "code": err.status, + }, + }, + }), + )); + break; + } + if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await { if status == 200 { if let Some(steps) = data["steps"].as_array() { diff --git a/src/api/search.rs b/src/api/search.rs index 73b8a4e..8c40110 100644 --- a/src/api/search.rs +++ b/src/api/search.rs @@ -17,7 +17,7 @@ use tracing::{info, warn}; use super::models::{lookup_model, MODELS}; use super::polling::poll_for_response; -use super::util::err_response; +use super::util::{err_response, upstream_err_response}; use super::AppState; /// Search request body. @@ -181,6 +181,9 @@ async fn do_search(state: Arc, body: SearchRequest) -> axum::response: // Poll for response let poll_result = poll_for_response(&state, &cascade_id, body.timeout).await; + if let Some(ref err) = poll_result.upstream_error { + return upstream_err_response(err); + } // Extract grounding metadata let grounding = state.mitm_store.take_grounding().await; diff --git a/src/api/util.rs b/src/api/util.rs index 2f4bee6..3fe57c6 100644 --- a/src/api/util.rs +++ b/src/api/util.rs @@ -25,6 +25,28 @@ pub(crate) fn err_response( (status, Json(body)).into_response() } +/// Convert a MITM-captured upstream error from Google into an HTTP response. +/// Maps Google's HTTP status codes and preserves the error message. +pub(crate) fn upstream_err_response(err: &crate::mitm::store::UpstreamError) -> axum::response::Response { + // Map Google's status code to HTTP status + let status = StatusCode::from_u16(err.status).unwrap_or(StatusCode::BAD_GATEWAY); + + // Map Google error status to OpenAI-style error type + let error_type = match err.error_status.as_deref() { + Some("INVALID_ARGUMENT") => "invalid_request_error", + Some("RESOURCE_EXHAUSTED") => "rate_limit_error", + Some("PERMISSION_DENIED") | Some("UNAUTHENTICATED") => "authentication_error", + Some("NOT_FOUND") => "not_found_error", + Some("INTERNAL") | Some("UNAVAILABLE") => "server_error", + _ => "upstream_error", + }; + + let message = err.message.clone() + .unwrap_or_else(|| format!("Google API returned HTTP {}", err.status)); + + err_response(status, message, error_type) +} + pub(crate) fn now_unix() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) diff --git a/src/mitm/proxy.rs b/src/mitm/proxy.rs index 89787e3..d5c0914 100644 --- a/src/mitm/proxy.rs +++ b/src/mitm/proxy.rs @@ -744,10 +744,29 @@ async fn handle_http_over_tls( } headers_parsed = true; - // Log error response bodies for debugging - if resp.code.unwrap_or(0) >= 400 { - let body_preview = String::from_utf8_lossy(&header_buf[hdr_end..]); - warn!(domain, status = resp.code.unwrap_or(0), body = %body_preview, "MITM: upstream error response"); + // Capture upstream errors for forwarding to client + let http_status = resp.code.unwrap_or(0) as u16; + if http_status >= 400 { + let body_str = String::from_utf8_lossy(&header_buf[hdr_end..]).to_string(); + warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response"); + + // Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}} + let (message, error_status) = serde_json::from_str::(&body_str) + .ok() + .and_then(|v| { + let err = v.get("error")?; + let msg = err.get("message").and_then(|m| m.as_str()).map(|s| s.to_string()); + let status = err.get("status").and_then(|s| s.as_str()).map(|s| s.to_string()); + Some((msg, status)) + }) + .unwrap_or((None, None)); + + store.set_upstream_error(super::store::UpstreamError { + status: http_status, + body: body_str, + message, + error_status, + }).await; } // Save body for usage parsing diff --git a/src/mitm/store.rs b/src/mitm/store.rs index f3e20c8..be2ee79 100644 --- a/src/mitm/store.rs +++ b/src/mitm/store.rs @@ -60,6 +60,21 @@ pub struct PendingToolResult { pub result: serde_json::Value, } +/// An upstream error captured from Google's API response. +/// Stored by the MITM proxy so API handlers can return it to the client +/// instead of hanging forever waiting for a response that won't come. +#[derive(Debug, Clone)] +pub struct UpstreamError { + /// HTTP status code from Google (e.g. 400, 429, 500). + pub status: u16, + /// Raw error body from Google (usually JSON). + pub body: String, + /// Parsed error message, if available. + pub message: Option, + /// Google error status string (e.g. "INVALID_ARGUMENT", "RESOURCE_EXHAUSTED"). + pub error_status: Option, +} + /// A pending image to inject via MITM into the Google API request. /// The LS doesn't forward images from our SendUserCascadeMessage proto, /// so we inject them directly at the MITM layer. @@ -152,6 +167,10 @@ pub struct MitmStore { // ── Pending image for MITM injection ───────────────────────────────── /// Image to inject into the next Google API request via MITM. pending_image: Arc>>, + + // ── Upstream error capture ─────────────────────────────────────────── + /// Error from Google's API, captured by MITM for forwarding to client. + upstream_error: Arc>>, } /// Aggregate statistics across all intercepted traffic. @@ -197,6 +216,7 @@ impl MitmStore { generation_params: Arc::new(RwLock::new(None)), captured_grounding: Arc::new(RwLock::new(None)), pending_image: Arc::new(RwLock::new(None)), + upstream_error: Arc::new(RwLock::new(None)), } } @@ -534,4 +554,21 @@ impl MitmStore { pub async fn take_pending_image(&self) -> Option { self.pending_image.write().await.take() } + + // ── Upstream error capture ─────────────────────────────────────────── + + /// Store an upstream error from Google's API. + pub async fn set_upstream_error(&self, error: UpstreamError) { + *self.upstream_error.write().await = Some(error); + } + + /// Take (consume) captured upstream error. + pub async fn take_upstream_error(&self) -> Option { + self.upstream_error.write().await.take() + } + + /// Clear any stored upstream error. + pub async fn clear_upstream_error(&self) { + *self.upstream_error.write().await = None; + } }