fix: block ALL LS follow-up requests across connections

Move the in-flight blocking check to the top of the LLM request flow,
BEFORE request modification. This catches follow-ups on ALL connections
(the LS opens multiple parallel TLS connections). Only the very first
modified request reaches Google — all others get fake STOP responses.

Previously, the blocking check was applied per connection, so each new
connection let one request through before blocking and 4-5 requests leaked per turn.
This commit is contained in:
Nikketryhard
2026-02-16 00:57:33 -06:00
parent a8f3c8915f
commit 3fdd0368a0
23 changed files with 992 additions and 568 deletions

View File

@@ -14,12 +14,15 @@ use std::sync::Arc;
use tracing::{debug, info};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content};
use super::polling::{
extract_model_usage, extract_response_text, extract_thinking_content,
extract_thinking_signature, is_response_done, poll_for_response,
};
use super::types::*;
use super::util::{err_response, upstream_err_response, now_unix, responses_sse_event};
use super::util::{err_response, now_unix, responses_sse_event, upstream_err_response};
use super::AppState;
use crate::mitm::modify::{openai_tool_choice_to_gemini, openai_tools_to_gemini};
use crate::mitm::store::PendingToolResult;
use crate::mitm::modify::{openai_tools_to_gemini, openai_tool_choice_to_gemini};
// ─── Input extraction ────────────────────────────────────────────────────────
@@ -35,7 +38,11 @@ struct ToolResultInput {
fn extract_responses_input(
input: &serde_json::Value,
instructions: Option<&str>,
) -> (String, Vec<ToolResultInput>, Option<crate::proto::ImageData>) {
) -> (
String,
Vec<ToolResultInput>,
Option<crate::proto::ImageData>,
) {
let mut tool_results: Vec<ToolResultInput> = Vec::new();
let mut image: Option<crate::proto::ImageData> = None;
@@ -45,10 +52,9 @@ fn extract_responses_input(
// Check for function_call_output items
for item in items {
if item["type"].as_str() == Some("function_call_output") {
if let (Some(call_id), Some(output)) = (
item["call_id"].as_str(),
item["output"].as_str(),
) {
if let (Some(call_id), Some(output)) =
(item["call_id"].as_str(), item["output"].as_str())
{
tool_results.push(ToolResultInput {
call_id: call_id.to_string(),
output: output.to_string(),
@@ -230,24 +236,31 @@ pub(crate) async fn handle_responses(
);
}
let (user_text, tool_results, image) = extract_responses_input(&body.input, body.instructions.as_deref());
let (user_text, tool_results, image) =
extract_responses_input(&body.input, body.instructions.as_deref());
// Handle tool result submission (function_call_output in input)
let is_tool_result_turn = !tool_results.is_empty();
if is_tool_result_turn {
for tr in &tool_results {
// Look up function name from call_id
let name = state.mitm_store.lookup_call_id(&tr.call_id).await
let name = state
.mitm_store
.lookup_call_id(&tr.call_id)
.await
.unwrap_or_else(|| "unknown_function".to_string());
// Parse the output as JSON, fall back to string wrapper
let result_value = serde_json::from_str::<serde_json::Value>(&tr.output)
.unwrap_or_else(|_| serde_json::json!({"result": tr.output}));
state.mitm_store.add_tool_result(PendingToolResult {
name,
result: result_value,
}).await;
state
.mitm_store
.add_tool_result(PendingToolResult {
name,
result: result_value,
})
.await;
}
info!(
count = tool_results.len(),
@@ -275,7 +288,10 @@ pub(crate) async fn handle_responses(
let gemini_tools = openai_tools_to_gemini(tools);
if !gemini_tools.is_empty() {
state.mitm_store.set_tools(gemini_tools).await;
info!(count = tools.len(), "Stored client tools for MITM injection");
info!(
count = tools.len(),
"Stored client tools for MITM injection"
);
}
}
if let Some(ref choice) = body.tool_choice {
@@ -289,7 +305,9 @@ pub(crate) async fn handle_responses(
let fmt_type = text_val["format"]["type"].as_str().unwrap_or("text");
if fmt_type == "json_schema" {
let name = text_val["format"]["name"].as_str().map(|s| s.to_string());
let schema = text_val["format"]["schema"].as_object().map(|o| serde_json::Value::Object(o.clone()));
let schema = text_val["format"]["schema"]
.as_object()
.map(|o| serde_json::Value::Object(o.clone()));
let strict = text_val["format"]["strict"].as_bool();
let tf = TextFormat {
format: TextFormatInner {
@@ -321,9 +339,13 @@ pub(crate) async fn handle_responses(
response_schema,
google_search: has_web_search,
};
if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some()
|| gp.reasoning_effort.is_some() || gp.response_mime_type.is_some()
|| gp.response_schema.is_some() || gp.google_search
if gp.temperature.is_some()
|| gp.top_p.is_some()
|| gp.max_output_tokens.is_some()
|| gp.reasoning_effort.is_some()
|| gp.response_mime_type.is_some()
|| gp.response_schema.is_some()
|| gp.google_search
{
state.mitm_store.set_generation_params(gp).await;
} else {
@@ -331,10 +353,7 @@ pub(crate) async fn handle_responses(
}
}
let response_id = format!(
"resp_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")
);
let response_id = format!("resp_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
// Session/conversation management
let session_id_str = extract_conversation_id(&body.conversation);
@@ -371,12 +390,13 @@ pub(crate) async fn handle_responses(
// Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image {
use base64::Engine;
state.mitm_store.set_pending_image(
crate::mitm::store::PendingImage {
state
.mitm_store
.set_pending_image(crate::mitm::store::PendingImage {
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
mime_type: img.mime_type.clone(),
}
).await;
})
.await;
}
match state
.backend
@@ -419,21 +439,32 @@ pub(crate) async fn handle_responses(
metadata: body.metadata.clone().unwrap_or(serde_json::json!({})),
max_tool_calls: body.max_tool_calls,
reasoning_effort: body.reasoning_effort.clone(),
tool_choice: body.tool_choice.clone().unwrap_or(serde_json::json!("auto")),
tool_choice: body
.tool_choice
.clone()
.unwrap_or(serde_json::json!("auto")),
tools: body.tools.clone().unwrap_or_default(),
text_format,
};
if body.stream {
handle_responses_stream(
state, response_id, model_name.to_string(), cascade_id,
body.timeout, req_params,
state,
response_id,
model_name.to_string(),
cascade_id,
body.timeout,
req_params,
)
.await
} else {
handle_responses_sync(
state, response_id, model_name.to_string(), cascade_id,
body.timeout, req_params,
state,
response_id,
model_name.to_string(),
cascade_id,
body.timeout,
req_params,
)
.await
}
@@ -485,7 +516,9 @@ async fn usage_from_poll(
if let Some(u) = mitm_store.peek_usage(key).await {
if u.thinking_output_tokens > 0 && u.thinking_text.is_none() {
// Call 2 hasn't arrived yet — wait briefly for the merge
tracing::debug!("MITM: thinking tokens found but no text, waiting for summary merge...");
tracing::debug!(
"MITM: thinking tokens found but no text, waiting for summary merge..."
);
for _ in 0..10 {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
if let Some(u2) = mitm_store.peek_usage(key).await {
@@ -526,13 +559,18 @@ async fn usage_from_poll(
// Priority 2: LS trajectory data (from CHECKPOINT/metadata steps)
if let Some(u) = model_usage {
return (Usage {
input_tokens: u.input_tokens,
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
output_tokens: u.output_tokens,
output_tokens_details: OutputTokensDetails { reasoning_tokens: 0 },
total_tokens: u.input_tokens + u.output_tokens,
}, None);
return (
Usage {
input_tokens: u.input_tokens,
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
output_tokens: u.output_tokens,
output_tokens_details: OutputTokensDetails {
reasoning_tokens: 0,
},
total_tokens: u.input_tokens + u.output_tokens,
},
None,
);
}
// Priority 3: Estimate from text lengths
@@ -575,14 +613,22 @@ async fn handle_responses_sync(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
state
.mitm_store
.register_call_id(call_id.clone(), fc.name.clone())
.await;
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
output_items.push(build_function_call_output(&call_id, &fc.name, &arguments));
output_items
.push(build_function_call_output(&call_id, &fc.name, &arguments));
}
let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &None,
&params.user_text, "",
).await;
&state.mitm_store,
&cascade_id,
&None,
&params.user_text,
"",
)
.await;
let resp = build_response_object(
ResponseData {
id: response_id,
@@ -602,12 +648,20 @@ async fn handle_responses_sync(
// Check for completed text response
if state.mitm_store.is_response_complete() {
let text = state.mitm_store.take_response_text().await.unwrap_or_default();
let text = state
.mitm_store
.take_response_text()
.await
.unwrap_or_default();
let thinking = state.mitm_store.take_thinking_text().await;
let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &None,
&params.user_text, &text,
).await;
&state.mitm_store,
&cascade_id,
&None,
&params.user_text,
&text,
)
.await;
let mut output_items: Vec<serde_json::Value> = Vec::new();
if let Some(ref t) = thinking {
@@ -658,10 +712,7 @@ async fn handle_responses_sync(
return upstream_err_response(err);
}
let completed_at = now_unix();
let msg_id = format!(
"msg_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")
);
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
// Check for captured function calls from MITM (clears the active flag)
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
@@ -689,7 +740,10 @@ async fn handle_responses_sync(
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
// Register call_id → name mapping for tool result routing
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
state
.mitm_store
.register_call_id(call_id.clone(), fc.name.clone())
.await;
// Stringify args (OpenAI sends arguments as JSON string)
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
@@ -697,9 +751,13 @@ async fn handle_responses_sync(
}
let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &poll_result.usage,
&params.user_text, &poll_result.text,
).await;
&state.mitm_store,
&cascade_id,
&poll_result.usage,
&params.user_text,
&poll_result.text,
)
.await;
let resp = build_response_object(
ResponseData {
@@ -719,7 +777,14 @@ async fn handle_responses_sync(
}
// Normal text response (no tool calls)
let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, &params.user_text, &poll_result.text).await;
let (usage, mitm_thinking) = usage_from_poll(
&state.mitm_store,
&cascade_id,
&poll_result.usage,
&params.user_text,
&poll_result.text,
)
.await;
// Thinking text priority: MITM-captured (raw API) > LS-extracted (steps)
let thinking_text = mitm_thinking.or(poll_result.thinking);
@@ -1560,4 +1625,3 @@ fn completion_events(
events
}