fix: bypass LS entirely when custom tools are active
When custom tools are set, don't forward ANY response from Google to the LS. Instead, capture text and function calls directly into MitmStore. The completions handler reads from MitmStore. This eliminates the LS multi-turn loop (5 requests, 30+ seconds) that occurred because the LS kept processing responses internally. Tool calls now return in ~1.3s instead of timing out.
This commit is contained in:
@@ -274,6 +274,10 @@ async fn chat_completions_stream(
|
|||||||
let stream = async_stream::stream! {
|
let stream = async_stream::stream! {
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let mut last_text = String::new();
|
let mut last_text = String::new();
|
||||||
|
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
|
||||||
|
|
||||||
|
// Clear any stale captured response from previous requests
|
||||||
|
state.mitm_store.clear_response_async().await;
|
||||||
|
|
||||||
// Initial role chunk
|
// Initial role chunk
|
||||||
yield Ok::<_, std::convert::Infallible>(Event::default().data(serde_json::to_string(&serde_json::json!({
|
yield Ok::<_, std::convert::Infallible>(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||||
@@ -342,6 +346,112 @@ async fn chat_completions_stream(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Check for MITM-captured response text (bypass LS) ──
|
||||||
|
if has_custom_tools {
|
||||||
|
if let Some(text) = state.mitm_store.peek_response_text().await {
|
||||||
|
if !text.is_empty() && text != last_text {
|
||||||
|
let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) {
|
||||||
|
text[last_text.len()..].to_string()
|
||||||
|
} else {
|
||||||
|
text.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
if !delta.is_empty() {
|
||||||
|
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": now_unix(),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"content": delta},
|
||||||
|
"finish_reason": serde_json::Value::Null,
|
||||||
|
}],
|
||||||
|
})).unwrap_or_default()));
|
||||||
|
last_text = text;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if MITM response is complete
|
||||||
|
if state.mitm_store.is_response_complete() && !last_text.is_empty() {
|
||||||
|
debug!("Completions: MITM response complete (bypass), text length={}", last_text.len());
|
||||||
|
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": now_unix(),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}],
|
||||||
|
})).unwrap_or_default()));
|
||||||
|
yield Ok(Event::default().data("[DONE]"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else if state.mitm_store.is_response_complete() {
|
||||||
|
// Response complete but no text — might be a tool call we already handled
|
||||||
|
// or an empty response. Give it a moment then bail.
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
||||||
|
// Re-check function calls one more time
|
||||||
|
let final_check = state.mitm_store.take_any_function_calls().await;
|
||||||
|
if let Some(ref calls) = final_check {
|
||||||
|
if !calls.is_empty() {
|
||||||
|
let mut tool_calls = Vec::new();
|
||||||
|
for (i, fc) in calls.iter().enumerate() {
|
||||||
|
let call_id = format!(
|
||||||
|
"call_{}",
|
||||||
|
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||||
|
);
|
||||||
|
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||||
|
tool_calls.push(serde_json::json!({
|
||||||
|
"index": i,
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": fc.name,
|
||||||
|
"arguments": arguments,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": now_unix(),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"tool_calls": tool_calls},
|
||||||
|
"finish_reason": serde_json::Value::Null,
|
||||||
|
}],
|
||||||
|
})).unwrap_or_default()));
|
||||||
|
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": now_unix(),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {},
|
||||||
|
"finish_reason": "tool_calls",
|
||||||
|
}],
|
||||||
|
})).unwrap_or_default()));
|
||||||
|
yield Ok(Event::default().data("[DONE]"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// When using bypass mode, skip LS step polling
|
||||||
|
keepalive_counter += 1;
|
||||||
|
if keepalive_counter % 10 == 0 {
|
||||||
|
yield Ok(Event::default().comment("keepalive"));
|
||||||
|
}
|
||||||
|
let poll_ms: u64 = rand::thread_rng().gen_range(200..350);
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// ── Check LS steps for text streaming ──
|
// ── Check LS steps for text streaming ──
|
||||||
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
|
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
|
||||||
if status == 200 {
|
if status == 200 {
|
||||||
|
|||||||
@@ -737,6 +737,8 @@ async fn handle_http_over_tls(
|
|||||||
|
|
||||||
// Parse ORIGINAL initial body for MITM interception
|
// Parse ORIGINAL initial body for MITM interception
|
||||||
let mut has_function_call = false;
|
let mut has_function_call = false;
|
||||||
|
let bypass_ls = modify_requests && store.get_tools().await.is_some();
|
||||||
|
|
||||||
if is_streaming_response && hdr_end < header_buf.len() {
|
if is_streaming_response && hdr_end < header_buf.len() {
|
||||||
let body = String::from_utf8_lossy(&header_buf[hdr_end..]);
|
let body = String::from_utf8_lossy(&header_buf[hdr_end..]);
|
||||||
parse_streaming_chunk(&body, &mut streaming_acc);
|
parse_streaming_chunk(&body, &mut streaming_acc);
|
||||||
@@ -750,41 +752,38 @@ async fn handle_http_over_tls(
|
|||||||
store.set_last_function_calls(streaming_acc.function_calls.clone()).await;
|
store.set_last_function_calls(streaming_acc.function_calls.clone()).await;
|
||||||
info!("MITM: stored {} function call(s) from initial body", streaming_acc.function_calls.len());
|
info!("MITM: stored {} function call(s) from initial body", streaming_acc.function_calls.len());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture response text directly into MitmStore
|
||||||
|
if bypass_ls && !streaming_acc.response_text.is_empty() {
|
||||||
|
store.set_response_text(&streaming_acc.response_text).await;
|
||||||
|
}
|
||||||
|
if bypass_ls && streaming_acc.is_complete {
|
||||||
|
store.mark_response_complete();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if has_function_call && modify_requests && store.get_tools().await.is_some() {
|
if bypass_ls {
|
||||||
info!("MITM: functionCall detected → sending dummy STOP response to LS");
|
if has_function_call {
|
||||||
|
info!("MITM: functionCall captured → NOT forwarding to LS (bypass mode)");
|
||||||
// Build a clean SSE response the LS will accept
|
store.mark_response_complete();
|
||||||
let dummy_json = serde_json::json!({
|
break;
|
||||||
"response": {
|
|
||||||
"candidates": [{
|
|
||||||
"content": {
|
|
||||||
"role": "model",
|
|
||||||
"parts": [{"text": "Tool call completed. Awaiting external tool result."}]
|
|
||||||
},
|
|
||||||
"finishReason": "STOP"
|
|
||||||
}],
|
|
||||||
"modelVersion": "gemini-3-flash"
|
|
||||||
},
|
|
||||||
"metadata": {}
|
|
||||||
});
|
|
||||||
let dummy_data = format!("data: {}\r\n\r\n", serde_json::to_string(&dummy_json).unwrap());
|
|
||||||
let dummy_chunk = format!("{:x}\r\n{}\r\n0\r\n\r\n", dummy_data.len(), dummy_data);
|
|
||||||
|
|
||||||
// Send headers (from original response) + dummy body
|
|
||||||
let headers_only = &header_buf[..hdr_end];
|
|
||||||
if let Err(e) = client.write_all(headers_only).await {
|
|
||||||
warn!(error = %e, "MITM: write headers failed");
|
|
||||||
}
|
}
|
||||||
if let Err(e) = client.write_all(dummy_chunk.as_bytes()).await {
|
// Don't forward to LS — just continue reading chunks
|
||||||
warn!(error = %e, "MITM: write dummy body failed");
|
// Send headers only so upstream doesn't close
|
||||||
|
if let Some(cl) = response_content_length {
|
||||||
|
if response_body_buf.len() >= cl {
|
||||||
|
store.mark_response_complete();
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Done — don't forward the real response
|
if is_chunked && has_chunked_terminator(&response_body_buf) {
|
||||||
break;
|
store.mark_response_complete();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normal path: forward headers+body as-is
|
// Normal path (no custom tools): forward headers+body as-is
|
||||||
if let Err(e) = client.write_all(&header_buf).await {
|
if let Err(e) = client.write_all(&header_buf).await {
|
||||||
warn!(error = %e, "MITM: write to client failed");
|
warn!(error = %e, "MITM: write to client failed");
|
||||||
break;
|
break;
|
||||||
@@ -804,14 +803,15 @@ async fn handle_http_over_tls(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ── Response body interception ────────────────────────────────
|
// ── Response body interception ────────────────────────────────
|
||||||
// Parse ORIGINAL chunk for MITM interception (captures functionCalls)
|
|
||||||
let mut chunk_has_fc = false;
|
let mut chunk_has_fc = false;
|
||||||
|
let bypass_ls = modify_requests && store.get_tools().await.is_some();
|
||||||
|
|
||||||
if is_streaming_response {
|
if is_streaming_response {
|
||||||
let s = String::from_utf8_lossy(chunk);
|
let s = String::from_utf8_lossy(chunk);
|
||||||
parse_streaming_chunk(&s, &mut streaming_acc);
|
parse_streaming_chunk(&s, &mut streaming_acc);
|
||||||
chunk_has_fc = !streaming_acc.function_calls.is_empty();
|
chunk_has_fc = !streaming_acc.function_calls.is_empty();
|
||||||
|
|
||||||
// Immediately store captured function calls — don't wait for loop end
|
// Immediately store captured function calls
|
||||||
if chunk_has_fc {
|
if chunk_has_fc {
|
||||||
for fc in &streaming_acc.function_calls {
|
for fc in &streaming_acc.function_calls {
|
||||||
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
||||||
@@ -819,13 +819,35 @@ async fn handle_http_over_tls(
|
|||||||
store.set_last_function_calls(streaming_acc.function_calls.clone()).await;
|
store.set_last_function_calls(streaming_acc.function_calls.clone()).await;
|
||||||
info!("MITM: stored {} function call(s) from body chunk", streaming_acc.function_calls.len());
|
info!("MITM: stored {} function call(s) from body chunk", streaming_acc.function_calls.len());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture response text directly into MitmStore
|
||||||
|
if bypass_ls && !streaming_acc.response_text.is_empty() {
|
||||||
|
store.set_response_text(&streaming_acc.response_text).await;
|
||||||
|
}
|
||||||
|
if bypass_ls && streaming_acc.is_complete {
|
||||||
|
store.mark_response_complete();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If functionCall detected + custom tools → send dummy + stop
|
if bypass_ls {
|
||||||
if chunk_has_fc && modify_requests && store.get_tools().await.is_some() {
|
if chunk_has_fc || streaming_acc.is_complete {
|
||||||
info!("MITM: functionCall in body chunk → sending chunked terminator to LS");
|
info!("MITM: response captured → NOT forwarding to LS (bypass mode)");
|
||||||
let _ = client.write_all(b"0\r\n\r\n").await;
|
store.mark_response_complete();
|
||||||
break;
|
break;
|
||||||
|
}
|
||||||
|
// Keep reading chunks without forwarding to LS
|
||||||
|
response_body_buf.extend_from_slice(chunk);
|
||||||
|
if let Some(cl) = response_content_length {
|
||||||
|
if response_body_buf.len() >= cl {
|
||||||
|
store.mark_response_complete();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if is_chunked && has_chunked_terminator(&response_body_buf) {
|
||||||
|
store.mark_response_complete();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normal path: forward chunk to client (LS)
|
// Normal path: forward chunk to client (LS)
|
||||||
|
|||||||
@@ -88,6 +88,13 @@ pub struct MitmStore {
|
|||||||
call_id_to_name: Arc<RwLock<HashMap<String, String>>>,
|
call_id_to_name: Arc<RwLock<HashMap<String, String>>>,
|
||||||
/// Last captured function calls (for conversation history rewriting).
|
/// Last captured function calls (for conversation history rewriting).
|
||||||
last_function_calls: Arc<RwLock<Vec<CapturedFunctionCall>>>,
|
last_function_calls: Arc<RwLock<Vec<CapturedFunctionCall>>>,
|
||||||
|
|
||||||
|
// ── Direct response capture (bypasses LS) ────────────────────────────
|
||||||
|
/// Captured response text from MITM when custom tools are active.
|
||||||
|
/// The completions handler reads this instead of polling LS steps.
|
||||||
|
captured_response_text: Arc<RwLock<Option<String>>>,
|
||||||
|
/// Whether the captured response is complete (finishReason received).
|
||||||
|
response_complete: Arc<AtomicBool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Aggregate statistics across all intercepted traffic.
|
/// Aggregate statistics across all intercepted traffic.
|
||||||
@@ -126,6 +133,8 @@ impl MitmStore {
|
|||||||
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
|
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
|
||||||
call_id_to_name: Arc::new(RwLock::new(HashMap::new())),
|
call_id_to_name: Arc::new(RwLock::new(HashMap::new())),
|
||||||
last_function_calls: Arc::new(RwLock::new(Vec::new())),
|
last_function_calls: Arc::new(RwLock::new(Vec::new())),
|
||||||
|
captured_response_text: Arc::new(RwLock::new(None)),
|
||||||
|
response_complete: Arc::new(AtomicBool::new(false)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -354,4 +363,56 @@ impl MitmStore {
|
|||||||
pub async fn get_last_function_calls(&self) -> Vec<CapturedFunctionCall> {
|
pub async fn get_last_function_calls(&self) -> Vec<CapturedFunctionCall> {
|
||||||
self.last_function_calls.read().await.clone()
|
self.last_function_calls.read().await.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Direct response capture (bypass LS) ──────────────────────────────
|
||||||
|
|
||||||
|
/// Append text to the captured response.
|
||||||
|
pub async fn append_response_text(&self, text: &str) {
|
||||||
|
let mut resp = self.captured_response_text.write().await;
|
||||||
|
if let Some(ref mut existing) = *resp {
|
||||||
|
existing.push_str(text);
|
||||||
|
} else {
|
||||||
|
*resp = Some(text.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set (replace) the captured response text.
|
||||||
|
pub async fn set_response_text(&self, text: &str) {
|
||||||
|
*self.captured_response_text.write().await = Some(text.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Take the captured response text (consumes it).
|
||||||
|
pub async fn take_response_text(&self) -> Option<String> {
|
||||||
|
self.captured_response_text.write().await.take()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Peek at the captured response text without consuming it.
|
||||||
|
pub async fn peek_response_text(&self) -> Option<String> {
|
||||||
|
self.captured_response_text.read().await.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark the response as complete.
|
||||||
|
pub fn mark_response_complete(&self) {
|
||||||
|
self.response_complete.store(true, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the response is complete.
|
||||||
|
pub fn is_response_complete(&self) -> bool {
|
||||||
|
self.response_complete.load(Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear captured response state (call at start of new request).
|
||||||
|
pub fn clear_response(&self) {
|
||||||
|
self.response_complete.store(false, Ordering::SeqCst);
|
||||||
|
// Can't use async in sync fn, so we spawn a task... or just use try_write
|
||||||
|
if let Ok(mut resp) = self.captured_response_text.try_write() {
|
||||||
|
*resp = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Async version of clear_response.
|
||||||
|
pub async fn clear_response_async(&self) {
|
||||||
|
self.response_complete.store(false, Ordering::SeqCst);
|
||||||
|
*self.captured_response_text.write().await = None;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user