diff --git a/src/api/completions.rs b/src/api/completions.rs index fc4ce0e..3d4bfa1 100644 --- a/src/api/completions.rs +++ b/src/api/completions.rs @@ -274,6 +274,10 @@ async fn chat_completions_stream( let stream = async_stream::stream! { let start = std::time::Instant::now(); let mut last_text = String::new(); + let has_custom_tools = state.mitm_store.get_tools().await.is_some(); + + // Clear any stale captured response from previous requests + state.mitm_store.clear_response_async().await; // Initial role chunk yield Ok::<_, std::convert::Infallible>(Event::default().data(serde_json::to_string(&serde_json::json!({ @@ -342,6 +346,112 @@ async fn chat_completions_stream( } } + // ── Check for MITM-captured response text (bypass LS) ── + if has_custom_tools { + if let Some(text) = state.mitm_store.peek_response_text().await { + if !text.is_empty() && text != last_text { + let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) { + text[last_text.len()..].to_string() + } else { + text.clone() + }; + + if !delta.is_empty() { + yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ + "id": completion_id, + "object": "chat.completion.chunk", + "created": now_unix(), + "model": model_name, + "choices": [{ + "index": 0, + "delta": {"content": delta}, + "finish_reason": serde_json::Value::Null, + }], + })).unwrap_or_default())); + last_text = text; + } + } + + // Check if MITM response is complete + if state.mitm_store.is_response_complete() && !last_text.is_empty() { + debug!("Completions: MITM response complete (bypass), text length={}", last_text.len()); + yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ + "id": completion_id, + "object": "chat.completion.chunk", + "created": now_unix(), + "model": model_name, + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": "stop", + }], + })).unwrap_or_default())); + yield Ok(Event::default().data("[DONE]")); + return; + } + } else if state.mitm_store.is_response_complete() { + // Response complete but no text — might be a tool call we already handled + // or an empty response. Give it a moment then bail. + tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + // Re-check function calls one more time + let final_check = state.mitm_store.take_any_function_calls().await; + if let Some(ref calls) = final_check { + if !calls.is_empty() { + let mut tool_calls = Vec::new(); + for (i, fc) in calls.iter().enumerate() { + let call_id = format!( + "call_{}", + uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() + ); + let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); + tool_calls.push(serde_json::json!({ + "index": i, + "id": call_id, + "type": "function", + "function": { + "name": fc.name, + "arguments": arguments, + }, + })); + } + yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ + "id": completion_id, + "object": "chat.completion.chunk", + "created": now_unix(), + "model": model_name, + "choices": [{ + "index": 0, + "delta": {"tool_calls": tool_calls}, + "finish_reason": serde_json::Value::Null, + }], + })).unwrap_or_default())); + yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ + "id": completion_id, + "object": "chat.completion.chunk", + "created": now_unix(), + "model": model_name, + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": "tool_calls", + }], + })).unwrap_or_default())); + yield Ok(Event::default().data("[DONE]")); + return; + } + } + } + + // When using bypass mode, skip LS step polling + keepalive_counter += 1; + if keepalive_counter % 10 == 0 { + yield Ok(Event::default().comment("keepalive")); + } + let poll_ms: u64 = rand::thread_rng().gen_range(200..350); + tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await; + continue; + } + // ── Check LS steps for text streaming ── if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await { if status == 200 { diff --git a/src/mitm/proxy.rs b/src/mitm/proxy.rs index ae57e7b..3b09dc9 100644 --- a/src/mitm/proxy.rs +++ b/src/mitm/proxy.rs @@ -737,6 +737,8 @@ async fn handle_http_over_tls( // Parse ORIGINAL initial body for MITM interception let mut has_function_call = false; + let bypass_ls = modify_requests && store.get_tools().await.is_some(); + if is_streaming_response && hdr_end < header_buf.len() { let body = String::from_utf8_lossy(&header_buf[hdr_end..]); parse_streaming_chunk(&body, &mut streaming_acc); @@ -750,41 +752,38 @@ async fn handle_http_over_tls( store.set_last_function_calls(streaming_acc.function_calls.clone()).await; info!("MITM: stored {} function call(s) from initial body", streaming_acc.function_calls.len()); } + + // Capture response text directly into MitmStore + if bypass_ls && !streaming_acc.response_text.is_empty() { + store.set_response_text(&streaming_acc.response_text).await; + } + if bypass_ls && streaming_acc.is_complete { + store.mark_response_complete(); + } } - if has_function_call && modify_requests && store.get_tools().await.is_some() { - info!("MITM: functionCall detected → sending dummy STOP response to LS"); - - // Build a clean SSE response the LS will accept - let dummy_json = serde_json::json!({ - "response": { - "candidates": [{ - "content": { - "role": "model", - "parts": [{"text": "Tool call completed. Awaiting external tool result."}] - }, - "finishReason": "STOP" - }], - "modelVersion": "gemini-3-flash" - }, - "metadata": {} - }); - let dummy_data = format!("data: {}\r\n\r\n", serde_json::to_string(&dummy_json).unwrap()); - let dummy_chunk = format!("{:x}\r\n{}\r\n0\r\n\r\n", dummy_data.len(), dummy_data); - - // Send headers (from original response) + dummy body - let headers_only = &header_buf[..hdr_end]; - if let Err(e) = client.write_all(headers_only).await { - warn!(error = %e, "MITM: write headers failed"); + if bypass_ls { + if has_function_call { + info!("MITM: functionCall captured → NOT forwarding to LS (bypass mode)"); + store.mark_response_complete(); + break; } - if let Err(e) = client.write_all(dummy_chunk.as_bytes()).await { - warn!(error = %e, "MITM: write dummy body failed"); + // Don't forward to LS — just continue reading chunks + // Send headers only so upstream doesn't close + if let Some(cl) = response_content_length { + if response_body_buf.len() >= cl { + store.mark_response_complete(); + break; + } } - // Done — don't forward the real response - break; + if is_chunked && has_chunked_terminator(&response_body_buf) { + store.mark_response_complete(); + break; + } + continue; } - // Normal path: forward headers+body as-is + // Normal path (no custom tools): forward headers+body as-is if let Err(e) = client.write_all(&header_buf).await { warn!(error = %e, "MITM: write to client failed"); break; @@ -804,14 +803,15 @@ async fn handle_http_over_tls( } // ── Response body interception ──────────────────────────────── - // Parse ORIGINAL chunk for MITM interception (captures functionCalls) let mut chunk_has_fc = false; + let bypass_ls = modify_requests && store.get_tools().await.is_some(); + if is_streaming_response { let s = String::from_utf8_lossy(chunk); parse_streaming_chunk(&s, &mut streaming_acc); chunk_has_fc = !streaming_acc.function_calls.is_empty(); - // Immediately store captured function calls — don't wait for loop end + // Immediately store captured function calls if chunk_has_fc { for fc in &streaming_acc.function_calls { store.record_function_call(cascade_hint.as_deref(), fc.clone()).await; @@ -819,13 +819,35 @@ async fn handle_http_over_tls( store.set_last_function_calls(streaming_acc.function_calls.clone()).await; info!("MITM: stored {} function call(s) from body chunk", streaming_acc.function_calls.len()); } + + // Capture response text directly into MitmStore + if bypass_ls && !streaming_acc.response_text.is_empty() { + store.set_response_text(&streaming_acc.response_text).await; + } + if bypass_ls && streaming_acc.is_complete { + store.mark_response_complete(); + } } - // If functionCall detected + custom tools → send dummy + stop - if chunk_has_fc && modify_requests && store.get_tools().await.is_some() { - info!("MITM: functionCall in body chunk → sending chunked terminator to LS"); - let _ = client.write_all(b"0\r\n\r\n").await; - break; + if bypass_ls { + if chunk_has_fc || streaming_acc.is_complete { + info!("MITM: response captured → NOT forwarding to LS (bypass mode)"); + store.mark_response_complete(); + break; + } + // Keep reading chunks without forwarding to LS + response_body_buf.extend_from_slice(chunk); + if let Some(cl) = response_content_length { + if response_body_buf.len() >= cl { + store.mark_response_complete(); + break; + } + } + if is_chunked && has_chunked_terminator(&response_body_buf) { + store.mark_response_complete(); + break; + } + continue; } // Normal path: forward chunk to client (LS) diff --git a/src/mitm/store.rs b/src/mitm/store.rs index f497721..e0cbf29 100644 --- a/src/mitm/store.rs +++ b/src/mitm/store.rs @@ -88,6 +88,13 @@ pub struct MitmStore { call_id_to_name: Arc>>, /// Last captured function calls (for conversation history rewriting). last_function_calls: Arc>>, + + // ── Direct response capture (bypasses LS) ──────────────────────────── + /// Captured response text from MITM when custom tools are active. + /// The completions handler reads this instead of polling LS steps. + captured_response_text: Arc>>, + /// Whether the captured response is complete (finishReason received). + response_complete: Arc, } /// Aggregate statistics across all intercepted traffic. @@ -126,6 +133,8 @@ impl MitmStore { pending_tool_results: Arc::new(RwLock::new(Vec::new())), call_id_to_name: Arc::new(RwLock::new(HashMap::new())), last_function_calls: Arc::new(RwLock::new(Vec::new())), + captured_response_text: Arc::new(RwLock::new(None)), + response_complete: Arc::new(AtomicBool::new(false)), } } @@ -354,4 +363,56 @@ impl MitmStore { pub async fn get_last_function_calls(&self) -> Vec { self.last_function_calls.read().await.clone() } + + // ── Direct response capture (bypass LS) ────────────────────────────── + + /// Append text to the captured response. + pub async fn append_response_text(&self, text: &str) { + let mut resp = self.captured_response_text.write().await; + if let Some(ref mut existing) = *resp { + existing.push_str(text); + } else { + *resp = Some(text.to_string()); + } + } + + /// Set (replace) the captured response text. + pub async fn set_response_text(&self, text: &str) { + *self.captured_response_text.write().await = Some(text.to_string()); + } + + /// Take the captured response text (consumes it). + pub async fn take_response_text(&self) -> Option { + self.captured_response_text.write().await.take() + } + + /// Peek at the captured response text without consuming it. + pub async fn peek_response_text(&self) -> Option { + self.captured_response_text.read().await.clone() + } + + /// Mark the response as complete. + pub fn mark_response_complete(&self) { + self.response_complete.store(true, Ordering::SeqCst); + } + + /// Check if the response is complete. + pub fn is_response_complete(&self) -> bool { + self.response_complete.load(Ordering::SeqCst) + } + + /// Clear captured response state (call at start of new request). + pub fn clear_response(&self) { + self.response_complete.store(false, Ordering::SeqCst); + // Can't use async in sync fn, so we spawn a task... or just use try_write + if let Ok(mut resp) = self.captured_response_text.try_write() { + *resp = None; + } + } + + /// Async version of clear_response. + pub async fn clear_response_async(&self) { + self.response_complete.store(false, Ordering::SeqCst); + *self.captured_response_text.write().await = None; + } }