fix: block ALL LS follow-up requests across connections
Move the in-flight blocking check to the top of the LLM request flow, BEFORE request modification. This catches follow-ups on ALL connections (the LS opens multiple parallel TLS connections). Only the very first modified request reaches Google — all others get fake STOP responses. Previously, each new connection independently allowed one request through before blocking, letting 4-5 requests leak per turn.
This commit is contained in:
@@ -14,12 +14,15 @@ use std::sync::Arc;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
||||
use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content};
|
||||
use super::polling::{
|
||||
extract_model_usage, extract_response_text, extract_thinking_content,
|
||||
extract_thinking_signature, is_response_done, poll_for_response,
|
||||
};
|
||||
use super::types::*;
|
||||
use super::util::{err_response, upstream_err_response, now_unix, responses_sse_event};
|
||||
use super::util::{err_response, now_unix, responses_sse_event, upstream_err_response};
|
||||
use super::AppState;
|
||||
use crate::mitm::modify::{openai_tool_choice_to_gemini, openai_tools_to_gemini};
|
||||
use crate::mitm::store::PendingToolResult;
|
||||
use crate::mitm::modify::{openai_tools_to_gemini, openai_tool_choice_to_gemini};
|
||||
|
||||
// ─── Input extraction ────────────────────────────────────────────────────────
|
||||
|
||||
@@ -35,7 +38,11 @@ struct ToolResultInput {
|
||||
fn extract_responses_input(
|
||||
input: &serde_json::Value,
|
||||
instructions: Option<&str>,
|
||||
) -> (String, Vec<ToolResultInput>, Option<crate::proto::ImageData>) {
|
||||
) -> (
|
||||
String,
|
||||
Vec<ToolResultInput>,
|
||||
Option<crate::proto::ImageData>,
|
||||
) {
|
||||
let mut tool_results: Vec<ToolResultInput> = Vec::new();
|
||||
let mut image: Option<crate::proto::ImageData> = None;
|
||||
|
||||
@@ -45,10 +52,9 @@ fn extract_responses_input(
|
||||
// Check for function_call_output items
|
||||
for item in items {
|
||||
if item["type"].as_str() == Some("function_call_output") {
|
||||
if let (Some(call_id), Some(output)) = (
|
||||
item["call_id"].as_str(),
|
||||
item["output"].as_str(),
|
||||
) {
|
||||
if let (Some(call_id), Some(output)) =
|
||||
(item["call_id"].as_str(), item["output"].as_str())
|
||||
{
|
||||
tool_results.push(ToolResultInput {
|
||||
call_id: call_id.to_string(),
|
||||
output: output.to_string(),
|
||||
@@ -230,24 +236,31 @@ pub(crate) async fn handle_responses(
|
||||
);
|
||||
}
|
||||
|
||||
let (user_text, tool_results, image) = extract_responses_input(&body.input, body.instructions.as_deref());
|
||||
let (user_text, tool_results, image) =
|
||||
extract_responses_input(&body.input, body.instructions.as_deref());
|
||||
|
||||
// Handle tool result submission (function_call_output in input)
|
||||
let is_tool_result_turn = !tool_results.is_empty();
|
||||
if is_tool_result_turn {
|
||||
for tr in &tool_results {
|
||||
// Look up function name from call_id
|
||||
let name = state.mitm_store.lookup_call_id(&tr.call_id).await
|
||||
let name = state
|
||||
.mitm_store
|
||||
.lookup_call_id(&tr.call_id)
|
||||
.await
|
||||
.unwrap_or_else(|| "unknown_function".to_string());
|
||||
|
||||
// Parse the output as JSON, fall back to string wrapper
|
||||
let result_value = serde_json::from_str::<serde_json::Value>(&tr.output)
|
||||
.unwrap_or_else(|_| serde_json::json!({"result": tr.output}));
|
||||
|
||||
state.mitm_store.add_tool_result(PendingToolResult {
|
||||
name,
|
||||
result: result_value,
|
||||
}).await;
|
||||
state
|
||||
.mitm_store
|
||||
.add_tool_result(PendingToolResult {
|
||||
name,
|
||||
result: result_value,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
info!(
|
||||
count = tool_results.len(),
|
||||
@@ -275,7 +288,10 @@ pub(crate) async fn handle_responses(
|
||||
let gemini_tools = openai_tools_to_gemini(tools);
|
||||
if !gemini_tools.is_empty() {
|
||||
state.mitm_store.set_tools(gemini_tools).await;
|
||||
info!(count = tools.len(), "Stored client tools for MITM injection");
|
||||
info!(
|
||||
count = tools.len(),
|
||||
"Stored client tools for MITM injection"
|
||||
);
|
||||
}
|
||||
}
|
||||
if let Some(ref choice) = body.tool_choice {
|
||||
@@ -289,7 +305,9 @@ pub(crate) async fn handle_responses(
|
||||
let fmt_type = text_val["format"]["type"].as_str().unwrap_or("text");
|
||||
if fmt_type == "json_schema" {
|
||||
let name = text_val["format"]["name"].as_str().map(|s| s.to_string());
|
||||
let schema = text_val["format"]["schema"].as_object().map(|o| serde_json::Value::Object(o.clone()));
|
||||
let schema = text_val["format"]["schema"]
|
||||
.as_object()
|
||||
.map(|o| serde_json::Value::Object(o.clone()));
|
||||
let strict = text_val["format"]["strict"].as_bool();
|
||||
let tf = TextFormat {
|
||||
format: TextFormatInner {
|
||||
@@ -321,9 +339,13 @@ pub(crate) async fn handle_responses(
|
||||
response_schema,
|
||||
google_search: has_web_search,
|
||||
};
|
||||
if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some()
|
||||
|| gp.reasoning_effort.is_some() || gp.response_mime_type.is_some()
|
||||
|| gp.response_schema.is_some() || gp.google_search
|
||||
if gp.temperature.is_some()
|
||||
|| gp.top_p.is_some()
|
||||
|| gp.max_output_tokens.is_some()
|
||||
|| gp.reasoning_effort.is_some()
|
||||
|| gp.response_mime_type.is_some()
|
||||
|| gp.response_schema.is_some()
|
||||
|| gp.google_search
|
||||
{
|
||||
state.mitm_store.set_generation_params(gp).await;
|
||||
} else {
|
||||
@@ -331,10 +353,7 @@ pub(crate) async fn handle_responses(
|
||||
}
|
||||
}
|
||||
|
||||
let response_id = format!(
|
||||
"resp_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||
);
|
||||
let response_id = format!("resp_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||
|
||||
// Session/conversation management
|
||||
let session_id_str = extract_conversation_id(&body.conversation);
|
||||
@@ -371,12 +390,13 @@ pub(crate) async fn handle_responses(
|
||||
// Store image for MITM injection (LS doesn't forward images to Google API)
|
||||
if let Some(ref img) = image {
|
||||
use base64::Engine;
|
||||
state.mitm_store.set_pending_image(
|
||||
crate::mitm::store::PendingImage {
|
||||
state
|
||||
.mitm_store
|
||||
.set_pending_image(crate::mitm::store::PendingImage {
|
||||
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
||||
mime_type: img.mime_type.clone(),
|
||||
}
|
||||
).await;
|
||||
})
|
||||
.await;
|
||||
}
|
||||
match state
|
||||
.backend
|
||||
@@ -419,21 +439,32 @@ pub(crate) async fn handle_responses(
|
||||
metadata: body.metadata.clone().unwrap_or(serde_json::json!({})),
|
||||
max_tool_calls: body.max_tool_calls,
|
||||
reasoning_effort: body.reasoning_effort.clone(),
|
||||
tool_choice: body.tool_choice.clone().unwrap_or(serde_json::json!("auto")),
|
||||
tool_choice: body
|
||||
.tool_choice
|
||||
.clone()
|
||||
.unwrap_or(serde_json::json!("auto")),
|
||||
tools: body.tools.clone().unwrap_or_default(),
|
||||
text_format,
|
||||
};
|
||||
|
||||
if body.stream {
|
||||
handle_responses_stream(
|
||||
state, response_id, model_name.to_string(), cascade_id,
|
||||
body.timeout, req_params,
|
||||
state,
|
||||
response_id,
|
||||
model_name.to_string(),
|
||||
cascade_id,
|
||||
body.timeout,
|
||||
req_params,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
handle_responses_sync(
|
||||
state, response_id, model_name.to_string(), cascade_id,
|
||||
body.timeout, req_params,
|
||||
state,
|
||||
response_id,
|
||||
model_name.to_string(),
|
||||
cascade_id,
|
||||
body.timeout,
|
||||
req_params,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -485,7 +516,9 @@ async fn usage_from_poll(
|
||||
if let Some(u) = mitm_store.peek_usage(key).await {
|
||||
if u.thinking_output_tokens > 0 && u.thinking_text.is_none() {
|
||||
// Call 2 hasn't arrived yet — wait briefly for the merge
|
||||
tracing::debug!("MITM: thinking tokens found but no text, waiting for summary merge...");
|
||||
tracing::debug!(
|
||||
"MITM: thinking tokens found but no text, waiting for summary merge..."
|
||||
);
|
||||
for _ in 0..10 {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
if let Some(u2) = mitm_store.peek_usage(key).await {
|
||||
@@ -526,13 +559,18 @@ async fn usage_from_poll(
|
||||
|
||||
// Priority 2: LS trajectory data (from CHECKPOINT/metadata steps)
|
||||
if let Some(u) = model_usage {
|
||||
return (Usage {
|
||||
input_tokens: u.input_tokens,
|
||||
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
|
||||
output_tokens: u.output_tokens,
|
||||
output_tokens_details: OutputTokensDetails { reasoning_tokens: 0 },
|
||||
total_tokens: u.input_tokens + u.output_tokens,
|
||||
}, None);
|
||||
return (
|
||||
Usage {
|
||||
input_tokens: u.input_tokens,
|
||||
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
|
||||
output_tokens: u.output_tokens,
|
||||
output_tokens_details: OutputTokensDetails {
|
||||
reasoning_tokens: 0,
|
||||
},
|
||||
total_tokens: u.input_tokens + u.output_tokens,
|
||||
},
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
// Priority 3: Estimate from text lengths
|
||||
@@ -575,14 +613,22 @@ async fn handle_responses_sync(
|
||||
"call_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||
);
|
||||
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
|
||||
state
|
||||
.mitm_store
|
||||
.register_call_id(call_id.clone(), fc.name.clone())
|
||||
.await;
|
||||
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||
output_items.push(build_function_call_output(&call_id, &fc.name, &arguments));
|
||||
output_items
|
||||
.push(build_function_call_output(&call_id, &fc.name, &arguments));
|
||||
}
|
||||
let (usage, _) = usage_from_poll(
|
||||
&state.mitm_store, &cascade_id, &None,
|
||||
¶ms.user_text, "",
|
||||
).await;
|
||||
&state.mitm_store,
|
||||
&cascade_id,
|
||||
&None,
|
||||
¶ms.user_text,
|
||||
"",
|
||||
)
|
||||
.await;
|
||||
let resp = build_response_object(
|
||||
ResponseData {
|
||||
id: response_id,
|
||||
@@ -602,12 +648,20 @@ async fn handle_responses_sync(
|
||||
|
||||
// Check for completed text response
|
||||
if state.mitm_store.is_response_complete() {
|
||||
let text = state.mitm_store.take_response_text().await.unwrap_or_default();
|
||||
let text = state
|
||||
.mitm_store
|
||||
.take_response_text()
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let thinking = state.mitm_store.take_thinking_text().await;
|
||||
let (usage, _) = usage_from_poll(
|
||||
&state.mitm_store, &cascade_id, &None,
|
||||
¶ms.user_text, &text,
|
||||
).await;
|
||||
&state.mitm_store,
|
||||
&cascade_id,
|
||||
&None,
|
||||
¶ms.user_text,
|
||||
&text,
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut output_items: Vec<serde_json::Value> = Vec::new();
|
||||
if let Some(ref t) = thinking {
|
||||
@@ -658,10 +712,7 @@ async fn handle_responses_sync(
|
||||
return upstream_err_response(err);
|
||||
}
|
||||
let completed_at = now_unix();
|
||||
let msg_id = format!(
|
||||
"msg_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||
);
|
||||
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||
|
||||
// Check for captured function calls from MITM (clears the active flag)
|
||||
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
|
||||
@@ -689,7 +740,10 @@ async fn handle_responses_sync(
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||
);
|
||||
// Register call_id → name mapping for tool result routing
|
||||
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
|
||||
state
|
||||
.mitm_store
|
||||
.register_call_id(call_id.clone(), fc.name.clone())
|
||||
.await;
|
||||
|
||||
// Stringify args (OpenAI sends arguments as JSON string)
|
||||
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||
@@ -697,9 +751,13 @@ async fn handle_responses_sync(
|
||||
}
|
||||
|
||||
let (usage, _) = usage_from_poll(
|
||||
&state.mitm_store, &cascade_id, &poll_result.usage,
|
||||
¶ms.user_text, &poll_result.text,
|
||||
).await;
|
||||
&state.mitm_store,
|
||||
&cascade_id,
|
||||
&poll_result.usage,
|
||||
¶ms.user_text,
|
||||
&poll_result.text,
|
||||
)
|
||||
.await;
|
||||
|
||||
let resp = build_response_object(
|
||||
ResponseData {
|
||||
@@ -719,7 +777,14 @@ async fn handle_responses_sync(
|
||||
}
|
||||
|
||||
// Normal text response (no tool calls)
|
||||
let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, ¶ms.user_text, &poll_result.text).await;
|
||||
let (usage, mitm_thinking) = usage_from_poll(
|
||||
&state.mitm_store,
|
||||
&cascade_id,
|
||||
&poll_result.usage,
|
||||
¶ms.user_text,
|
||||
&poll_result.text,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Thinking text priority: MITM-captured (raw API) > LS-extracted (steps)
|
||||
let thinking_text = mitm_thinking.or(poll_result.thinking);
|
||||
@@ -1560,4 +1625,3 @@ fn completion_events(
|
||||
|
||||
events
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user