fix: bypass LS entirely when custom tools are active

When custom tools are set, don't forward ANY response from Google
to the LS. Instead, capture text and function calls directly into
MitmStore. The completions handler reads from MitmStore.

This eliminates the LS multi-turn loop (5 requests, 30+ seconds)
that occurred because the LS kept processing responses internally.
Tool calls now return in ~1.3s instead of timing out.
This commit is contained in:
Nikketryhard
2026-02-15 00:54:40 -06:00
parent ec1c0c700d
commit 50b53097bc
3 changed files with 229 additions and 36 deletions

View File

@@ -274,6 +274,10 @@ async fn chat_completions_stream(
let stream = async_stream::stream! {
let start = std::time::Instant::now();
let mut last_text = String::new();
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
// Clear any stale captured response from previous requests
state.mitm_store.clear_response_async().await;
// Initial role chunk
yield Ok::<_, std::convert::Infallible>(Event::default().data(serde_json::to_string(&serde_json::json!({
@@ -342,6 +346,112 @@ async fn chat_completions_stream(
}
}
// ── Check for MITM-captured response text (bypass LS) ──
if has_custom_tools {
if let Some(text) = state.mitm_store.peek_response_text().await {
if !text.is_empty() && text != last_text {
let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) {
text[last_text.len()..].to_string()
} else {
text.clone()
};
if !delta.is_empty() {
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": delta},
"finish_reason": serde_json::Value::Null,
}],
})).unwrap_or_default()));
last_text = text;
}
}
// Check if MITM response is complete
if state.mitm_store.is_response_complete() && !last_text.is_empty() {
debug!("Completions: MITM response complete (bypass), text length={}", last_text.len());
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop",
}],
})).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]"));
return;
}
} else if state.mitm_store.is_response_complete() {
// Response complete but no text — might be a tool call we already handled
// or an empty response. Give it a moment then bail.
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
// Re-check function calls one more time
let final_check = state.mitm_store.take_any_function_calls().await;
if let Some(ref calls) = final_check {
if !calls.is_empty() {
let mut tool_calls = Vec::new();
for (i, fc) in calls.iter().enumerate() {
let call_id = format!(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
tool_calls.push(serde_json::json!({
"index": i,
"id": call_id,
"type": "function",
"function": {
"name": fc.name,
"arguments": arguments,
},
}));
}
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"tool_calls": tool_calls},
"finish_reason": serde_json::Value::Null,
}],
})).unwrap_or_default()));
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "tool_calls",
}],
})).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]"));
return;
}
}
}
// When using bypass mode, skip LS step polling
keepalive_counter += 1;
if keepalive_counter % 10 == 0 {
yield Ok(Event::default().comment("keepalive"));
}
let poll_ms: u64 = rand::thread_rng().gen_range(200..350);
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
continue;
}
// ── Check LS steps for text streaming ──
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
if status == 200 {

View File

@@ -737,6 +737,8 @@ async fn handle_http_over_tls(
// Parse ORIGINAL initial body for MITM interception
let mut has_function_call = false;
let bypass_ls = modify_requests && store.get_tools().await.is_some();
if is_streaming_response && hdr_end < header_buf.len() {
let body = String::from_utf8_lossy(&header_buf[hdr_end..]);
parse_streaming_chunk(&body, &mut streaming_acc);
@@ -750,41 +752,38 @@ async fn handle_http_over_tls(
store.set_last_function_calls(streaming_acc.function_calls.clone()).await;
info!("MITM: stored {} function call(s) from initial body", streaming_acc.function_calls.len());
}
// Capture response text directly into MitmStore
if bypass_ls && !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
}
if bypass_ls && streaming_acc.is_complete {
store.mark_response_complete();
}
}
if has_function_call && modify_requests && store.get_tools().await.is_some() {
info!("MITM: functionCall detected → sending dummy STOP response to LS");
// Build a clean SSE response the LS will accept
let dummy_json = serde_json::json!({
"response": {
"candidates": [{
"content": {
"role": "model",
"parts": [{"text": "Tool call completed. Awaiting external tool result."}]
},
"finishReason": "STOP"
}],
"modelVersion": "gemini-3-flash"
},
"metadata": {}
});
let dummy_data = format!("data: {}\r\n\r\n", serde_json::to_string(&dummy_json).unwrap());
let dummy_chunk = format!("{:x}\r\n{}\r\n0\r\n\r\n", dummy_data.len(), dummy_data);
// Send headers (from original response) + dummy body
let headers_only = &header_buf[..hdr_end];
if let Err(e) = client.write_all(headers_only).await {
warn!(error = %e, "MITM: write headers failed");
if bypass_ls {
if has_function_call {
info!("MITM: functionCall captured → NOT forwarding to LS (bypass mode)");
store.mark_response_complete();
break;
}
if let Err(e) = client.write_all(dummy_chunk.as_bytes()).await {
warn!(error = %e, "MITM: write dummy body failed");
// Don't forward to LS — just continue reading chunks
// Send headers only so upstream doesn't close
if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl {
store.mark_response_complete();
break;
}
}
// Done — don't forward the real response
break;
if is_chunked && has_chunked_terminator(&response_body_buf) {
store.mark_response_complete();
break;
}
continue;
}
// Normal path: forward headers+body as-is
// Normal path (no custom tools): forward headers+body as-is
if let Err(e) = client.write_all(&header_buf).await {
warn!(error = %e, "MITM: write to client failed");
break;
@@ -804,14 +803,15 @@ async fn handle_http_over_tls(
}
// ── Response body interception ────────────────────────────────
// Parse ORIGINAL chunk for MITM interception (captures functionCalls)
let mut chunk_has_fc = false;
let bypass_ls = modify_requests && store.get_tools().await.is_some();
if is_streaming_response {
let s = String::from_utf8_lossy(chunk);
parse_streaming_chunk(&s, &mut streaming_acc);
chunk_has_fc = !streaming_acc.function_calls.is_empty();
// Immediately store captured function calls — don't wait for loop end
// Immediately store captured function calls
if chunk_has_fc {
for fc in &streaming_acc.function_calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
@@ -819,13 +819,35 @@ async fn handle_http_over_tls(
store.set_last_function_calls(streaming_acc.function_calls.clone()).await;
info!("MITM: stored {} function call(s) from body chunk", streaming_acc.function_calls.len());
}
// Capture response text directly into MitmStore
if bypass_ls && !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
}
if bypass_ls && streaming_acc.is_complete {
store.mark_response_complete();
}
}
// If functionCall detected + custom tools → send dummy + stop
if chunk_has_fc && modify_requests && store.get_tools().await.is_some() {
info!("MITM: functionCall in body chunk → sending chunked terminator to LS");
let _ = client.write_all(b"0\r\n\r\n").await;
break;
if bypass_ls {
if chunk_has_fc || streaming_acc.is_complete {
info!("MITM: response captured → NOT forwarding to LS (bypass mode)");
store.mark_response_complete();
break;
}
// Keep reading chunks without forwarding to LS
response_body_buf.extend_from_slice(chunk);
if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl {
store.mark_response_complete();
break;
}
}
if is_chunked && has_chunked_terminator(&response_body_buf) {
store.mark_response_complete();
break;
}
continue;
}
// Normal path: forward chunk to client (LS)

View File

@@ -88,6 +88,13 @@ pub struct MitmStore {
call_id_to_name: Arc<RwLock<HashMap<String, String>>>,
/// Last captured function calls (for conversation history rewriting).
last_function_calls: Arc<RwLock<Vec<CapturedFunctionCall>>>,
// ── Direct response capture (bypasses LS) ────────────────────────────
/// Captured response text from MITM when custom tools are active.
/// The completions handler reads this instead of polling LS steps.
captured_response_text: Arc<RwLock<Option<String>>>,
/// Whether the captured response is complete (finishReason received).
response_complete: Arc<AtomicBool>,
}
/// Aggregate statistics across all intercepted traffic.
@@ -126,6 +133,8 @@ impl MitmStore {
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
call_id_to_name: Arc::new(RwLock::new(HashMap::new())),
last_function_calls: Arc::new(RwLock::new(Vec::new())),
captured_response_text: Arc::new(RwLock::new(None)),
response_complete: Arc::new(AtomicBool::new(false)),
}
}
@@ -354,4 +363,56 @@ impl MitmStore {
pub async fn get_last_function_calls(&self) -> Vec<CapturedFunctionCall> {
self.last_function_calls.read().await.clone()
}
// ── Direct response capture (bypass LS) ──────────────────────────────
/// Append text to the captured response.
pub async fn append_response_text(&self, text: &str) {
let mut resp = self.captured_response_text.write().await;
if let Some(ref mut existing) = *resp {
existing.push_str(text);
} else {
*resp = Some(text.to_string());
}
}
/// Set (replace) the captured response text.
pub async fn set_response_text(&self, text: &str) {
*self.captured_response_text.write().await = Some(text.to_string());
}
/// Take the captured response text (consumes it).
pub async fn take_response_text(&self) -> Option<String> {
self.captured_response_text.write().await.take()
}
/// Peek at the captured response text without consuming it.
pub async fn peek_response_text(&self) -> Option<String> {
self.captured_response_text.read().await.clone()
}
/// Mark the response as complete.
pub fn mark_response_complete(&self) {
self.response_complete.store(true, Ordering::SeqCst);
}
/// Check if the response is complete.
pub fn is_response_complete(&self) -> bool {
self.response_complete.load(Ordering::SeqCst)
}
/// Clear captured response state (call at start of new request).
pub fn clear_response(&self) {
self.response_complete.store(false, Ordering::SeqCst);
// Can't use async in sync fn, so we spawn a task... or just use try_write
if let Ok(mut resp) = self.captured_response_text.try_write() {
*resp = None;
}
}
/// Async version of clear_response.
pub async fn clear_response_async(&self) {
self.response_complete.store(false, Ordering::SeqCst);
*self.captured_response_text.write().await = None;
}
}