diff --git a/src/mitm/modify.rs b/src/mitm/modify.rs index 94e25bc..0462dec 100644 --- a/src/mitm/modify.rs +++ b/src/mitm/modify.rs @@ -694,3 +694,196 @@ mod tests { assert_eq!(result, "keep this and this"); } } + +// ─── Response modification ────────────────────────────────────────────────── + +/// Rewrite an SSE response chunk to replace `functionCall` parts with text, +/// so the LS doesn't see tool calls for tools it doesn't manage. +/// +/// The MITM intercept layer has already captured the function call data +/// (via `parse_streaming_chunk`) before this function runs, so we're not +/// losing any information — just hiding it from the LS. +/// +/// Handles HTTP chunked transfer encoding framing (size\r\n...data...\r\n). +/// +/// Returns `Some(modified_bytes)` if the chunk was rewritten, `None` if no +/// change was needed. +pub fn modify_response_chunk(chunk: &[u8]) -> Option> { + let text = std::str::from_utf8(chunk).ok()?; + + // Quick check — no point parsing if no functionCall present + if !text.contains("functionCall") { + return None; + } + + // Strategy: find each `data: {json}` SSE event in the raw text (which may + // be wrapped in chunked encoding). Parse the JSON, rewrite functionCall + // parts, and rebuild the chunked frame with updated sizes. + + // First, dechunk: extract SSE data lines from chunked encoding + // Chunked format: \r\n\r\n + // We'll work on the whole text, finding "data: " prefixed JSON objects + let mut result = text.to_string(); + let mut changed = false; + + // Find all `data: {...}` patterns (SSE events with JSON) + // Use a simple approach: find "data: {" and match to the end of JSON + let mut search_from = 0; + while let Some(data_pos) = result[search_from..].find("data: {") { + let abs_pos = search_from + data_pos; + let json_start = abs_pos + 6; // skip "data: " + + // Find the end of this JSON object by finding the matching closing brace + if let Some(json_end) = find_json_end(&result[json_start..]) { + let json_str = &result[json_start..json_start + json_end]; + + if json_str.contains("functionCall") { + if let Ok(mut json) = serde_json::from_str::(json_str) { + if rewrite_function_calls_in_response(&mut json) { + if let Ok(new_json) = serde_json::to_string(&json) { + // Replace the JSON in the result string + result.replace_range(json_start..json_start + json_end, &new_json); + changed = true; + info!("MITM: rewrote functionCall in response → text placeholder for LS"); + search_from = json_start + new_json.len(); + continue; + } + } + } + } + search_from = json_start + json_end; + } else { + search_from = json_start; + } + } + + if !changed { + return None; + } + + // Rechunk: if the original was chunked, we need to recalculate chunk sizes + // The format is: \r\n\r\n + // We'll rebuild the chunked encoding from scratch + if text.contains("\r\n") && text.chars().next().map_or(false, |c| c.is_ascii_hexdigit()) { + // This looks like chunked encoding — rebuild it + // Extract the payload (everything between first \r\n and last \r\n) + let rechunked = rechunk_response(&result); + Some(rechunked.into_bytes()) + } else { + Some(result.into_bytes()) + } +} + +/// Find the end of a JSON object starting at the given string. +/// Returns the index past the closing brace. +fn find_json_end(s: &str) -> Option { + let mut depth = 0i32; + let mut in_string = false; + let mut escape = false; + + for (i, c) in s.char_indices() { + if escape { + escape = false; + continue; + } + if c == '\\' && in_string { + escape = true; + continue; + } + if c == '"' { + in_string = !in_string; + continue; + } + if in_string { + continue; + } + if c == '{' { + depth += 1; + } else if c == '}' { + depth -= 1; + if depth == 0 { + return Some(i + 1); + } + } + } + None +} + +/// Rebuild chunked encoding from a modified response body. +/// Takes the full text (which contains old chunk sizes) and rebuilds +/// with correct sizes. +fn rechunk_response(text: &str) -> String { + // Extract the actual SSE data lines (skip chunk size lines) + let mut payload = String::new(); + for line in text.split('\n') { + let trimmed = line.trim_end_matches('\r'); + // Skip lines that are purely hex chunk sizes + if trimmed.is_empty() { + continue; + } + if trimmed.chars().all(|c| c.is_ascii_hexdigit()) && !trimmed.is_empty() { + continue; + } + // Skip "0" (chunked terminator) + if trimmed == "0" { + continue; + } + payload.push_str(line); + if !line.ends_with('\n') { + payload.push('\n'); + } + } + + // Wrap in a single chunk + let payload_bytes = payload.as_bytes(); + format!("{:x}\r\n{}\r\n", payload_bytes.len(), payload) +} + +/// Rewrite a parsed SSE JSON object: replace `functionCall` parts with +/// text placeholder and change `finishReason` from `MALFORMED_FUNCTION_CALL` +/// or any non-STOP reason to `STOP`. +/// +/// Handles both Gemini public API format (`{"candidates":[...]}`) and +/// internal LS format (`{"response":{"candidates":[...]}}`). +fn rewrite_function_calls_in_response(json: &mut Value) -> bool { + let mut changed = false; + + // Helper to rewrite candidates array in-place + fn rewrite_candidates(candidates: &mut Vec) -> bool { + let mut changed = false; + for candidate in candidates.iter_mut() { + if let Some(parts) = candidate + .pointer_mut("/content/parts") + .and_then(|v| v.as_array_mut()) + { + for part in parts.iter_mut() { + if part.get("functionCall").is_some() { + *part = serde_json::json!({ + "text": "Tool call completed. Awaiting external tool result." + }); + changed = true; + } + } + } + if let Some(reason) = candidate.get("finishReason").and_then(|v| v.as_str()) { + if reason != "STOP" { + candidate["finishReason"] = Value::String("STOP".to_string()); + changed = true; + } + } + } + changed + } + + // Try direct "candidates" first + if let Some(candidates) = json.get_mut("candidates").and_then(|v| v.as_array_mut()) { + changed |= rewrite_candidates(candidates); + } + + // Try nested "response.candidates" + if let Some(candidates) = json.pointer_mut("/response/candidates").and_then(|v| v.as_array_mut()) { + changed |= rewrite_candidates(candidates); + } + + changed +} diff --git a/src/mitm/proxy.rs b/src/mitm/proxy.rs index b1aec73..3f8ef90 100644 --- a/src/mitm/proxy.rs +++ b/src/mitm/proxy.rs @@ -735,15 +735,54 @@ async fn handle_http_over_tls( // Save body for usage parsing response_body_buf.extend_from_slice(&header_buf[hdr_end..]); - // Forward to client immediately - if let Err(e) = client.write_all(&header_buf).await { - warn!(error = %e, "MITM: write to client failed"); - break; - } - + // Parse ORIGINAL initial body for MITM interception + let mut has_function_call = false; if is_streaming_response && hdr_end < header_buf.len() { let body = String::from_utf8_lossy(&header_buf[hdr_end..]); parse_streaming_chunk(&body, &mut streaming_acc); + has_function_call = body.contains("functionCall"); + } + + // If we detected a functionCall AND custom tools are active, + // forge a dummy "STOP" response for the LS so it doesn't + // freak out and retry. The real function call data is already + // captured in MitmStore. + if has_function_call && modify_requests && store.get_tools().await.is_some() { + info!("MITM: functionCall detected → sending dummy STOP response to LS"); + + // Build a clean SSE response the LS will accept + let dummy_json = serde_json::json!({ + "response": { + "candidates": [{ + "content": { + "role": "model", + "parts": [{"text": "Tool call completed. Awaiting external tool result."}] + }, + "finishReason": "STOP" + }], + "modelVersion": "gemini-3-flash" + }, + "metadata": {} + }); + let dummy_data = format!("data: {}\r\n\r\n", serde_json::to_string(&dummy_json).unwrap()); + let dummy_chunk = format!("{:x}\r\n{}\r\n0\r\n\r\n", dummy_data.len(), dummy_data); + + // Send headers (from original response) + dummy body + let headers_only = &header_buf[..hdr_end]; + if let Err(e) = client.write_all(headers_only).await { + warn!(error = %e, "MITM: write headers failed"); + } + if let Err(e) = client.write_all(dummy_chunk.as_bytes()).await { + warn!(error = %e, "MITM: write dummy body failed"); + } + // Done — don't forward the real response + break; + } + + // Normal path: forward headers+body as-is + if let Err(e) = client.write_all(&header_buf).await { + warn!(error = %e, "MITM: write to client failed"); + break; } if let Some(cl) = response_content_length { @@ -759,17 +798,30 @@ async fn handle_http_over_tls( continue; } - // Forward to client immediately + // ── Response body interception ──────────────────────────────── + // Parse ORIGINAL chunk for MITM interception (captures functionCalls) + let mut chunk_has_fc = false; + if is_streaming_response { + let s = String::from_utf8_lossy(chunk); + parse_streaming_chunk(&s, &mut streaming_acc); + chunk_has_fc = s.contains("functionCall"); + } + + // If functionCall in body chunk + custom tools → send dummy + stop + if chunk_has_fc && modify_requests && store.get_tools().await.is_some() { + info!("MITM: functionCall in body chunk → sending chunked terminator to LS"); + // Send the chunked terminator to end the stream + let _ = client.write_all(b"0\r\n\r\n").await; + break; + } + + // Normal path: forward chunk to client (LS) if let Err(e) = client.write_all(chunk).await { warn!(error = %e, "MITM: write to client failed"); break; } response_body_buf.extend_from_slice(chunk); - if is_streaming_response { - let s = String::from_utf8_lossy(chunk); - parse_streaming_chunk(&s, &mut streaming_acc); - } if let Some(cl) = response_content_length { if response_body_buf.len() >= cl { break; } }