fix: block ALL LS follow-up requests across connections
Move the in-flight blocking check to the top of the LLM request flow, BEFORE request modification. This catches follow-ups on ALL connections (the LS opens multiple parallel TLS connections). Only the very first modified request reaches Google — all others get fake STOP responses. Previously, each new connection independently allowed one request through before blocking, letting 4-5 requests leak per turn.
This commit is contained in:
@@ -11,8 +11,7 @@
|
||||
|
||||
use super::ca::MitmCa;
|
||||
use super::intercept::{
|
||||
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk,
|
||||
StreamingAccumulator,
|
||||
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk, StreamingAccumulator,
|
||||
};
|
||||
use super::store::MitmStore;
|
||||
use std::sync::Arc;
|
||||
@@ -54,7 +53,6 @@ pub struct MitmConfig {
|
||||
pub modify_requests: bool,
|
||||
}
|
||||
|
||||
|
||||
/// Run the MITM proxy server.
|
||||
///
|
||||
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
|
||||
@@ -84,7 +82,8 @@ pub async fn run(
|
||||
let ca = ca.clone();
|
||||
let store = store.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await {
|
||||
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await
|
||||
{
|
||||
warn!(error = %e, "MITM connection error");
|
||||
}
|
||||
});
|
||||
@@ -131,8 +130,7 @@ async fn handle_connection(
|
||||
.await
|
||||
.map_err(|e| format!("Peek ClientHello: {e}"))?;
|
||||
|
||||
let domain = extract_sni(&hello_buf[..n])
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let domain = extract_sni(&hello_buf[..n]).unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
info!(domain, "MITM: transparent redirect (iptables)");
|
||||
|
||||
@@ -224,22 +222,30 @@ fn extract_sni(buf: &[u8]) -> Option<String> {
|
||||
let mut pos = 34; // skip version + random
|
||||
|
||||
// Session ID
|
||||
if pos >= body.len() { return None; }
|
||||
if pos >= body.len() {
|
||||
return None;
|
||||
}
|
||||
let sid_len = body[pos] as usize;
|
||||
pos += 1 + sid_len;
|
||||
|
||||
// Cipher suites
|
||||
if pos + 2 > body.len() { return None; }
|
||||
if pos + 2 > body.len() {
|
||||
return None;
|
||||
}
|
||||
let cs_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
|
||||
pos += 2 + cs_len;
|
||||
|
||||
// Compression methods
|
||||
if pos >= body.len() { return None; }
|
||||
if pos >= body.len() {
|
||||
return None;
|
||||
}
|
||||
let cm_len = body[pos] as usize;
|
||||
pos += 1 + cm_len;
|
||||
|
||||
// Extensions
|
||||
if pos + 2 > body.len() { return None; }
|
||||
if pos + 2 > body.len() {
|
||||
return None;
|
||||
}
|
||||
let ext_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
|
||||
pos += 2;
|
||||
let ext_end = pos + ext_len.min(body.len() - pos);
|
||||
@@ -304,32 +310,32 @@ async fn handle_intercepted(
|
||||
info!(domain, "MITM: intercepting TLS");
|
||||
|
||||
// Get or create server TLS config for this domain
|
||||
let server_config = ca
|
||||
.server_config_for_domain(domain)
|
||||
.await?;
|
||||
let server_config = ca.server_config_for_domain(domain).await?;
|
||||
|
||||
let acceptor = TlsAcceptor::from(server_config);
|
||||
|
||||
// Perform TLS handshake with the client (LS) — 10s timeout
|
||||
let tls_stream = match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(10),
|
||||
acceptor.accept(stream),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => {
|
||||
warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
|
||||
return Err(format!("TLS handshake with client failed for {domain}: {e}"));
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
|
||||
return Err(format!("TLS handshake timed out for {domain}"));
|
||||
}
|
||||
};
|
||||
let tls_stream =
|
||||
match tokio::time::timeout(std::time::Duration::from_secs(10), acceptor.accept(stream))
|
||||
.await
|
||||
{
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => {
|
||||
warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
|
||||
return Err(format!(
|
||||
"TLS handshake with client failed for {domain}: {e}"
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
|
||||
return Err(format!("TLS handshake timed out for {domain}"));
|
||||
}
|
||||
};
|
||||
|
||||
// Check negotiated ALPN protocol
|
||||
let alpn = tls_stream.get_ref().1
|
||||
let alpn = tls_stream
|
||||
.get_ref()
|
||||
.1
|
||||
.alpn_protocol()
|
||||
.map(|p| String::from_utf8_lossy(p).to_string());
|
||||
|
||||
@@ -339,12 +345,7 @@ async fn handle_intercepted(
|
||||
Some("h2") => {
|
||||
// HTTP/2 — use the hyper-based gRPC handler
|
||||
info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
|
||||
super::h2_handler::handle_h2_connection(
|
||||
tls_stream,
|
||||
domain.to_string(),
|
||||
store,
|
||||
)
|
||||
.await
|
||||
super::h2_handler::handle_h2_connection(tls_stream, domain.to_string(), store).await
|
||||
}
|
||||
_ => {
|
||||
// HTTP/1.1 or no ALPN — use the existing handler
|
||||
@@ -434,7 +435,10 @@ async fn handle_http_over_tls(
|
||||
.await
|
||||
{
|
||||
let out = String::from_utf8_lossy(&output.stdout);
|
||||
if let Some(ip) = out.lines().find(|l| l.parse::<std::net::Ipv4Addr>().is_ok()) {
|
||||
if let Some(ip) = out
|
||||
.lines()
|
||||
.find(|l| l.parse::<std::net::Ipv4Addr>().is_ok())
|
||||
{
|
||||
return format!("{ip}:443");
|
||||
}
|
||||
}
|
||||
@@ -458,7 +462,6 @@ async fn handle_http_over_tls(
|
||||
loop {
|
||||
// ── Read the HTTP request from the client ─────────────────────────
|
||||
let mut request_buf = Vec::with_capacity(1024 * 64);
|
||||
let mut is_our_request = false;
|
||||
|
||||
// 60s timeout on initial read (LS may open connection without sending immediately)
|
||||
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
|
||||
@@ -513,7 +516,8 @@ async fn handle_http_over_tls(
|
||||
}
|
||||
|
||||
// Parse the HTTP request to find headers and body
|
||||
let (headers_end, content_length, _is_streaming_request) = parse_http_request_meta(&request_buf);
|
||||
let (headers_end, content_length, _is_streaming_request) =
|
||||
parse_http_request_meta(&request_buf);
|
||||
|
||||
// Try to extract cascade hint from request body
|
||||
let cascade_hint = if headers_end < request_buf.len() {
|
||||
@@ -545,6 +549,27 @@ async fn handle_http_over_tls(
|
||||
"MITM: forwarding LLM request"
|
||||
);
|
||||
|
||||
// ── Block ALL requests when one is already in-flight ─────────
|
||||
// The LS opens multiple connections and sends parallel requests.
|
||||
// When custom tools are active, only the FIRST request should reach
|
||||
// Google. Block everything else with a fake response.
|
||||
if store.is_request_in_flight() {
|
||||
info!("MITM: blocking LS request — another request already in-flight");
|
||||
let fake_response = "HTTP/1.1 200 OK\r\n\
|
||||
Content-Type: text/event-stream\r\n\
|
||||
Transfer-Encoding: chunked\r\n\
|
||||
\r\n";
|
||||
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
|
||||
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
|
||||
let mut response = fake_response.as_bytes().to_vec();
|
||||
response.extend_from_slice(&chunked_body);
|
||||
if let Err(e) = client.write_all(&response).await {
|
||||
warn!(error = %e, "MITM: failed to write fake response");
|
||||
}
|
||||
let _ = client.flush().await;
|
||||
continue;
|
||||
}
|
||||
|
||||
// ── Request modification ─────────────────────────────────────
|
||||
// Dechunk body → check if agent request → modify → rechunk
|
||||
if modify_requests && body_len > 0 {
|
||||
@@ -565,7 +590,11 @@ async fn handle_http_over_tls(
|
||||
let generation_params = store.get_generation_params().await;
|
||||
let pending_image = store.take_pending_image().await;
|
||||
|
||||
let tool_ctx = if tools.is_some() || !pending_results.is_empty() || generation_params.is_some() || pending_image.is_some() {
|
||||
let tool_ctx = if tools.is_some()
|
||||
|| !pending_results.is_empty()
|
||||
|| generation_params.is_some()
|
||||
|| pending_image.is_some()
|
||||
{
|
||||
Some(super::modify::ToolContext {
|
||||
tools,
|
||||
tool_config,
|
||||
@@ -578,7 +607,9 @@ async fn handle_http_over_tls(
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(modified_body) = super::modify::modify_request(&raw_body, tool_ctx.as_ref()) {
|
||||
if let Some(modified_body) =
|
||||
super::modify::modify_request(&raw_body, tool_ctx.as_ref())
|
||||
{
|
||||
// Rebuild request_buf: headers (with updated Content-Length) + rechunked modified body
|
||||
let new_chunked = super::modify::rechunk(&modified_body);
|
||||
|
||||
@@ -588,39 +619,12 @@ async fn handle_http_over_tls(
|
||||
let mut new_buf = updated_headers.into_bytes();
|
||||
new_buf.extend_from_slice(&new_chunked);
|
||||
request_buf = new_buf;
|
||||
|
||||
// Mark this as our modified request and set in-flight flag
|
||||
is_our_request = true;
|
||||
|
||||
// Mark in-flight IMMEDIATELY — blocks all subsequent requests
|
||||
store.mark_request_in_flight();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Block ALL LS follow-up requests once first is in-flight ──
|
||||
// When custom tools are active, we only need ONE request to Google.
|
||||
// The LS tries to send multiple requests (its own agentic loop +
|
||||
// internal requests on gemini-2.5-flash-lite). Block them ALL
|
||||
// immediately — don't wait for response_complete.
|
||||
let has_tools = store.get_tools().await.is_some();
|
||||
if has_tools && store.is_request_in_flight() && !is_our_request {
|
||||
info!(
|
||||
"MITM: blocking LS follow-up — request already in-flight"
|
||||
);
|
||||
// Return a fake SSE response that makes the LS stop
|
||||
let fake_response = "HTTP/1.1 200 OK\r\n\
|
||||
Content-Type: text/event-stream\r\n\
|
||||
Transfer-Encoding: chunked\r\n\
|
||||
\r\n";
|
||||
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
|
||||
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
|
||||
let mut response = fake_response.as_bytes().to_vec();
|
||||
response.extend_from_slice(&chunked_body);
|
||||
if let Err(e) = client.write_all(&response).await {
|
||||
warn!(error = %e, "MITM: failed to write fake response");
|
||||
}
|
||||
let _ = client.flush().await;
|
||||
continue; // Skip the real upstream call
|
||||
}
|
||||
} else {
|
||||
debug!(
|
||||
domain,
|
||||
@@ -674,7 +678,10 @@ async fn handle_http_over_tls(
|
||||
};
|
||||
|
||||
let n = match tokio::time::timeout(timeout, conn.read(&mut tmp)).await {
|
||||
Ok(Ok(0)) => { upstream_ok = false; break; }
|
||||
Ok(Ok(0)) => {
|
||||
upstream_ok = false;
|
||||
break;
|
||||
}
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(e)) => {
|
||||
debug!(domain, error = %e, "MITM: upstream read ended");
|
||||
@@ -711,7 +718,9 @@ async fn handle_http_over_tls(
|
||||
if header.name.eq_ignore_ascii_case("content-type") {
|
||||
if let Ok(v) = std::str::from_utf8(header.value) {
|
||||
content_type = v.to_string();
|
||||
if v.contains("text/event-stream") { is_streaming_response = true; }
|
||||
if v.contains("text/event-stream") {
|
||||
is_streaming_response = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if header.name.eq_ignore_ascii_case("content-length") {
|
||||
@@ -721,12 +730,16 @@ async fn handle_http_over_tls(
|
||||
}
|
||||
if header.name.eq_ignore_ascii_case("connection") {
|
||||
if let Ok(v) = std::str::from_utf8(header.value) {
|
||||
if v.trim().eq_ignore_ascii_case("close") { upstream_ok = false; }
|
||||
if v.trim().eq_ignore_ascii_case("close") {
|
||||
upstream_ok = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if header.name.eq_ignore_ascii_case("transfer-encoding") {
|
||||
if let Ok(v) = std::str::from_utf8(header.value) {
|
||||
if v.trim().eq_ignore_ascii_case("chunked") { is_chunked = true; }
|
||||
if v.trim().eq_ignore_ascii_case("chunked") {
|
||||
is_chunked = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -749,22 +762,31 @@ async fn handle_http_over_tls(
|
||||
warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response");
|
||||
|
||||
// Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}}
|
||||
let (message, error_status) = serde_json::from_str::<serde_json::Value>(&body_str)
|
||||
.ok()
|
||||
.and_then(|v| {
|
||||
let err = v.get("error")?;
|
||||
let msg = err.get("message").and_then(|m| m.as_str()).map(|s| s.to_string());
|
||||
let status = err.get("status").and_then(|s| s.as_str()).map(|s| s.to_string());
|
||||
Some((msg, status))
|
||||
})
|
||||
.unwrap_or((None, None));
|
||||
let (message, error_status) =
|
||||
serde_json::from_str::<serde_json::Value>(&body_str)
|
||||
.ok()
|
||||
.and_then(|v| {
|
||||
let err = v.get("error")?;
|
||||
let msg = err
|
||||
.get("message")
|
||||
.and_then(|m| m.as_str())
|
||||
.map(|s| s.to_string());
|
||||
let status = err
|
||||
.get("status")
|
||||
.and_then(|s| s.as_str())
|
||||
.map(|s| s.to_string());
|
||||
Some((msg, status))
|
||||
})
|
||||
.unwrap_or((None, None));
|
||||
|
||||
store.set_upstream_error(super::store::UpstreamError {
|
||||
status: http_status,
|
||||
body: body_str,
|
||||
message,
|
||||
error_status,
|
||||
}).await;
|
||||
store
|
||||
.set_upstream_error(super::store::UpstreamError {
|
||||
status: http_status,
|
||||
body: body_str,
|
||||
message,
|
||||
error_status,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
// Save body for usage parsing
|
||||
@@ -779,10 +801,15 @@ async fn handle_http_over_tls(
|
||||
if !streaming_acc.function_calls.is_empty() {
|
||||
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||
for fc in &calls {
|
||||
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
||||
store
|
||||
.record_function_call(cascade_hint.as_deref(), fc.clone())
|
||||
.await;
|
||||
}
|
||||
store.set_last_function_calls(calls.clone()).await;
|
||||
info!("MITM: stored {} function call(s) from initial body", calls.len());
|
||||
info!(
|
||||
"MITM: stored {} function call(s) from initial body",
|
||||
calls.len()
|
||||
);
|
||||
}
|
||||
|
||||
// Capture response + thinking text + grounding into MitmStore
|
||||
@@ -816,7 +843,9 @@ async fn handle_http_over_tls(
|
||||
}
|
||||
|
||||
if let Some(cl) = response_content_length {
|
||||
if response_body_buf.len() >= cl { break; }
|
||||
if response_body_buf.len() >= cl {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Check chunked terminator in initial body
|
||||
if is_chunked && has_chunked_terminator(&response_body_buf) {
|
||||
@@ -837,10 +866,15 @@ async fn handle_http_over_tls(
|
||||
if !streaming_acc.function_calls.is_empty() {
|
||||
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||
for fc in &calls {
|
||||
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
||||
store
|
||||
.record_function_call(cascade_hint.as_deref(), fc.clone())
|
||||
.await;
|
||||
}
|
||||
store.set_last_function_calls(calls.clone()).await;
|
||||
info!("MITM: stored {} function call(s) from body chunk", calls.len());
|
||||
info!(
|
||||
"MITM: stored {} function call(s) from body chunk",
|
||||
calls.len()
|
||||
);
|
||||
}
|
||||
|
||||
// Capture response + thinking text + grounding into MitmStore
|
||||
@@ -875,7 +909,9 @@ async fn handle_http_over_tls(
|
||||
response_body_buf.extend_from_slice(chunk);
|
||||
|
||||
if let Some(cl) = response_content_length {
|
||||
if response_body_buf.len() >= cl { break; }
|
||||
if response_body_buf.len() >= cl {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if is_chunked && has_chunked_terminator(&response_body_buf) {
|
||||
debug!(domain, "MITM: chunked response complete");
|
||||
@@ -912,11 +948,7 @@ async fn handle_http_over_tls(
|
||||
}
|
||||
|
||||
/// Handle a passthrough connection: transparent TCP tunnel to upstream.
|
||||
async fn handle_passthrough(
|
||||
mut client: TcpStream,
|
||||
domain: &str,
|
||||
port: u16,
|
||||
) -> Result<(), String> {
|
||||
async fn handle_passthrough(mut client: TcpStream, domain: &str, port: u16) -> Result<(), String> {
|
||||
trace!(domain, port, "MITM: transparent tunnel");
|
||||
|
||||
let mut upstream = TcpStream::connect(format!("{domain}:{port}"))
|
||||
@@ -926,7 +958,12 @@ async fn handle_passthrough(
|
||||
// Bidirectional copy
|
||||
match tokio::io::copy_bidirectional(&mut client, &mut upstream).await {
|
||||
Ok((client_to_server, server_to_client)) => {
|
||||
trace!(domain, client_to_server, server_to_client, "MITM: tunnel closed");
|
||||
trace!(
|
||||
domain,
|
||||
client_to_server,
|
||||
server_to_client,
|
||||
"MITM: tunnel closed"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
|
||||
@@ -945,7 +982,11 @@ fn has_chunked_terminator(body: &[u8]) -> bool {
|
||||
return false;
|
||||
}
|
||||
// Check last 7 bytes to account for possible trailing whitespace
|
||||
let tail = if body.len() > 7 { &body[body.len() - 7..] } else { body };
|
||||
let tail = if body.len() > 7 {
|
||||
&body[body.len() - 7..]
|
||||
} else {
|
||||
body
|
||||
};
|
||||
// Look for the last-chunk marker 0\r\n\r\n anywhere in the tail (a preceding \r\n is not required by this check)
|
||||
tail.windows(5).any(|w| w == b"0\r\n\r\n")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user