fix: block ALL LS follow-up requests across connections

Move the in-flight blocking check to the top of the LLM request flow,
BEFORE request modification. This catches follow-ups on ALL connections
(the LS opens multiple parallel TLS connections). Only the very first
modified request reaches Google — all others get fake STOP responses.

Previously, each new connection independently allowed one request
through before blocking, letting 4-5 requests leak per turn.
This commit is contained in:
Nikketryhard
2026-02-16 00:57:33 -06:00
parent a8f3c8915f
commit 3fdd0368a0
23 changed files with 992 additions and 568 deletions

View File

@@ -11,8 +11,7 @@
use super::ca::MitmCa;
use super::intercept::{
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk,
StreamingAccumulator,
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk, StreamingAccumulator,
};
use super::store::MitmStore;
use std::sync::Arc;
@@ -54,7 +53,6 @@ pub struct MitmConfig {
pub modify_requests: bool,
}
/// Run the MITM proxy server.
///
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
@@ -84,7 +82,8 @@ pub async fn run(
let ca = ca.clone();
let store = store.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await {
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await
{
warn!(error = %e, "MITM connection error");
}
});
@@ -131,8 +130,7 @@ async fn handle_connection(
.await
.map_err(|e| format!("Peek ClientHello: {e}"))?;
let domain = extract_sni(&hello_buf[..n])
.unwrap_or_else(|| "unknown".to_string());
let domain = extract_sni(&hello_buf[..n]).unwrap_or_else(|| "unknown".to_string());
info!(domain, "MITM: transparent redirect (iptables)");
@@ -224,22 +222,30 @@ fn extract_sni(buf: &[u8]) -> Option<String> {
let mut pos = 34; // skip version + random
// Session ID
if pos >= body.len() { return None; }
if pos >= body.len() {
return None;
}
let sid_len = body[pos] as usize;
pos += 1 + sid_len;
// Cipher suites
if pos + 2 > body.len() { return None; }
if pos + 2 > body.len() {
return None;
}
let cs_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
pos += 2 + cs_len;
// Compression methods
if pos >= body.len() { return None; }
if pos >= body.len() {
return None;
}
let cm_len = body[pos] as usize;
pos += 1 + cm_len;
// Extensions
if pos + 2 > body.len() { return None; }
if pos + 2 > body.len() {
return None;
}
let ext_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
pos += 2;
let ext_end = pos + ext_len.min(body.len() - pos);
@@ -304,32 +310,32 @@ async fn handle_intercepted(
info!(domain, "MITM: intercepting TLS");
// Get or create server TLS config for this domain
let server_config = ca
.server_config_for_domain(domain)
.await?;
let server_config = ca.server_config_for_domain(domain).await?;
let acceptor = TlsAcceptor::from(server_config);
// Perform TLS handshake with the client (LS) — 10s timeout
let tls_stream = match tokio::time::timeout(
std::time::Duration::from_secs(10),
acceptor.accept(stream),
)
.await
{
Ok(Ok(s)) => s,
Ok(Err(e)) => {
warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
return Err(format!("TLS handshake with client failed for {domain}: {e}"));
}
Err(_) => {
warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
return Err(format!("TLS handshake timed out for {domain}"));
}
};
let tls_stream =
match tokio::time::timeout(std::time::Duration::from_secs(10), acceptor.accept(stream))
.await
{
Ok(Ok(s)) => s,
Ok(Err(e)) => {
warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
return Err(format!(
"TLS handshake with client failed for {domain}: {e}"
));
}
Err(_) => {
warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
return Err(format!("TLS handshake timed out for {domain}"));
}
};
// Check negotiated ALPN protocol
let alpn = tls_stream.get_ref().1
let alpn = tls_stream
.get_ref()
.1
.alpn_protocol()
.map(|p| String::from_utf8_lossy(p).to_string());
@@ -339,12 +345,7 @@ async fn handle_intercepted(
Some("h2") => {
// HTTP/2 — use the hyper-based gRPC handler
info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
super::h2_handler::handle_h2_connection(
tls_stream,
domain.to_string(),
store,
)
.await
super::h2_handler::handle_h2_connection(tls_stream, domain.to_string(), store).await
}
_ => {
// HTTP/1.1 or no ALPN — use the existing handler
@@ -434,7 +435,10 @@ async fn handle_http_over_tls(
.await
{
let out = String::from_utf8_lossy(&output.stdout);
if let Some(ip) = out.lines().find(|l| l.parse::<std::net::Ipv4Addr>().is_ok()) {
if let Some(ip) = out
.lines()
.find(|l| l.parse::<std::net::Ipv4Addr>().is_ok())
{
return format!("{ip}:443");
}
}
@@ -458,7 +462,6 @@ async fn handle_http_over_tls(
loop {
// ── Read the HTTP request from the client ─────────────────────────
let mut request_buf = Vec::with_capacity(1024 * 64);
let mut is_our_request = false;
// 60s timeout on initial read (LS may open connection without sending immediately)
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
@@ -513,7 +516,8 @@ async fn handle_http_over_tls(
}
// Parse the HTTP request to find headers and body
let (headers_end, content_length, _is_streaming_request) = parse_http_request_meta(&request_buf);
let (headers_end, content_length, _is_streaming_request) =
parse_http_request_meta(&request_buf);
// Try to extract cascade hint from request body
let cascade_hint = if headers_end < request_buf.len() {
@@ -545,6 +549,27 @@ async fn handle_http_over_tls(
"MITM: forwarding LLM request"
);
// ── Block ALL requests when one is already in-flight ─────────
// The LS opens multiple connections and sends parallel requests.
// When custom tools are active, only the FIRST request should reach
// Google. Block everything else with a fake response.
if store.is_request_in_flight() {
info!("MITM: blocking LS request — another request already in-flight");
let fake_response = "HTTP/1.1 200 OK\r\n\
Content-Type: text/event-stream\r\n\
Transfer-Encoding: chunked\r\n\
\r\n";
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
let mut response = fake_response.as_bytes().to_vec();
response.extend_from_slice(&chunked_body);
if let Err(e) = client.write_all(&response).await {
warn!(error = %e, "MITM: failed to write fake response");
}
let _ = client.flush().await;
continue;
}
// ── Request modification ─────────────────────────────────────
// Dechunk body → check if agent request → modify → rechunk
if modify_requests && body_len > 0 {
@@ -565,7 +590,11 @@ async fn handle_http_over_tls(
let generation_params = store.get_generation_params().await;
let pending_image = store.take_pending_image().await;
let tool_ctx = if tools.is_some() || !pending_results.is_empty() || generation_params.is_some() || pending_image.is_some() {
let tool_ctx = if tools.is_some()
|| !pending_results.is_empty()
|| generation_params.is_some()
|| pending_image.is_some()
{
Some(super::modify::ToolContext {
tools,
tool_config,
@@ -578,7 +607,9 @@ async fn handle_http_over_tls(
None
};
if let Some(modified_body) = super::modify::modify_request(&raw_body, tool_ctx.as_ref()) {
if let Some(modified_body) =
super::modify::modify_request(&raw_body, tool_ctx.as_ref())
{
// Rebuild request_buf: headers (with updated Content-Length) + rechunked modified body
let new_chunked = super::modify::rechunk(&modified_body);
@@ -588,39 +619,12 @@ async fn handle_http_over_tls(
let mut new_buf = updated_headers.into_bytes();
new_buf.extend_from_slice(&new_chunked);
request_buf = new_buf;
// Mark this as our modified request and set in-flight flag
is_our_request = true;
// Mark in-flight IMMEDIATELY — blocks all subsequent requests
store.mark_request_in_flight();
}
}
}
// ── Block ALL LS follow-up requests once first is in-flight ──
// When custom tools are active, we only need ONE request to Google.
// The LS tries to send multiple requests (its own agentic loop +
// internal requests on gemini-2.5-flash-lite). Block them ALL
// immediately — don't wait for response_complete.
let has_tools = store.get_tools().await.is_some();
if has_tools && store.is_request_in_flight() && !is_our_request {
info!(
"MITM: blocking LS follow-up — request already in-flight"
);
// Return a fake SSE response that makes the LS stop
let fake_response = "HTTP/1.1 200 OK\r\n\
Content-Type: text/event-stream\r\n\
Transfer-Encoding: chunked\r\n\
\r\n";
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
let mut response = fake_response.as_bytes().to_vec();
response.extend_from_slice(&chunked_body);
if let Err(e) = client.write_all(&response).await {
warn!(error = %e, "MITM: failed to write fake response");
}
let _ = client.flush().await;
continue; // Skip the real upstream call
}
} else {
debug!(
domain,
@@ -674,7 +678,10 @@ async fn handle_http_over_tls(
};
let n = match tokio::time::timeout(timeout, conn.read(&mut tmp)).await {
Ok(Ok(0)) => { upstream_ok = false; break; }
Ok(Ok(0)) => {
upstream_ok = false;
break;
}
Ok(Ok(n)) => n,
Ok(Err(e)) => {
debug!(domain, error = %e, "MITM: upstream read ended");
@@ -711,7 +718,9 @@ async fn handle_http_over_tls(
if header.name.eq_ignore_ascii_case("content-type") {
if let Ok(v) = std::str::from_utf8(header.value) {
content_type = v.to_string();
if v.contains("text/event-stream") { is_streaming_response = true; }
if v.contains("text/event-stream") {
is_streaming_response = true;
}
}
}
if header.name.eq_ignore_ascii_case("content-length") {
@@ -721,12 +730,16 @@ async fn handle_http_over_tls(
}
if header.name.eq_ignore_ascii_case("connection") {
if let Ok(v) = std::str::from_utf8(header.value) {
if v.trim().eq_ignore_ascii_case("close") { upstream_ok = false; }
if v.trim().eq_ignore_ascii_case("close") {
upstream_ok = false;
}
}
}
if header.name.eq_ignore_ascii_case("transfer-encoding") {
if let Ok(v) = std::str::from_utf8(header.value) {
if v.trim().eq_ignore_ascii_case("chunked") { is_chunked = true; }
if v.trim().eq_ignore_ascii_case("chunked") {
is_chunked = true;
}
}
}
}
@@ -749,22 +762,31 @@ async fn handle_http_over_tls(
warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response");
// Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}}
let (message, error_status) = serde_json::from_str::<serde_json::Value>(&body_str)
.ok()
.and_then(|v| {
let err = v.get("error")?;
let msg = err.get("message").and_then(|m| m.as_str()).map(|s| s.to_string());
let status = err.get("status").and_then(|s| s.as_str()).map(|s| s.to_string());
Some((msg, status))
})
.unwrap_or((None, None));
let (message, error_status) =
serde_json::from_str::<serde_json::Value>(&body_str)
.ok()
.and_then(|v| {
let err = v.get("error")?;
let msg = err
.get("message")
.and_then(|m| m.as_str())
.map(|s| s.to_string());
let status = err
.get("status")
.and_then(|s| s.as_str())
.map(|s| s.to_string());
Some((msg, status))
})
.unwrap_or((None, None));
store.set_upstream_error(super::store::UpstreamError {
status: http_status,
body: body_str,
message,
error_status,
}).await;
store
.set_upstream_error(super::store::UpstreamError {
status: http_status,
body: body_str,
message,
error_status,
})
.await;
}
// Save body for usage parsing
@@ -779,10 +801,15 @@ async fn handle_http_over_tls(
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
store
.record_function_call(cascade_hint.as_deref(), fc.clone())
.await;
}
store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from initial body", calls.len());
info!(
"MITM: stored {} function call(s) from initial body",
calls.len()
);
}
// Capture response + thinking text + grounding into MitmStore
@@ -816,7 +843,9 @@ async fn handle_http_over_tls(
}
if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl { break; }
if response_body_buf.len() >= cl {
break;
}
}
// Check chunked terminator in initial body
if is_chunked && has_chunked_terminator(&response_body_buf) {
@@ -837,10 +866,15 @@ async fn handle_http_over_tls(
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
store
.record_function_call(cascade_hint.as_deref(), fc.clone())
.await;
}
store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from body chunk", calls.len());
info!(
"MITM: stored {} function call(s) from body chunk",
calls.len()
);
}
// Capture response + thinking text + grounding into MitmStore
@@ -875,7 +909,9 @@ async fn handle_http_over_tls(
response_body_buf.extend_from_slice(chunk);
if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl { break; }
if response_body_buf.len() >= cl {
break;
}
}
if is_chunked && has_chunked_terminator(&response_body_buf) {
debug!(domain, "MITM: chunked response complete");
@@ -912,11 +948,7 @@ async fn handle_http_over_tls(
}
/// Handle a passthrough connection: transparent TCP tunnel to upstream.
async fn handle_passthrough(
mut client: TcpStream,
domain: &str,
port: u16,
) -> Result<(), String> {
async fn handle_passthrough(mut client: TcpStream, domain: &str, port: u16) -> Result<(), String> {
trace!(domain, port, "MITM: transparent tunnel");
let mut upstream = TcpStream::connect(format!("{domain}:{port}"))
@@ -926,7 +958,12 @@ async fn handle_passthrough(
// Bidirectional copy
match tokio::io::copy_bidirectional(&mut client, &mut upstream).await {
Ok((client_to_server, server_to_client)) => {
trace!(domain, client_to_server, server_to_client, "MITM: tunnel closed");
trace!(
domain,
client_to_server,
server_to_client,
"MITM: tunnel closed"
);
}
Err(e) => {
trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
@@ -945,7 +982,11 @@ fn has_chunked_terminator(body: &[u8]) -> bool {
return false;
}
// Check last 7 bytes to account for possible trailing whitespace
let tail = if body.len() > 7 { &body[body.len() - 7..] } else { body };
let tail = if body.len() > 7 {
&body[body.len() - 7..]
} else {
body
};
// Look for the 5-byte sequence 0\r\n\r\n (last-chunk + final CRLF) anywhere in the tail
tail.windows(5).any(|w| w == b"0\r\n\r\n")
}