Files
zerogravity/src/mitm/modify.rs
Nikketryhard b1bd57ab5e feat: forward generation params via MITM + add usageMetadata to Gemini
- Add GenerationParams struct to MitmStore for temperature, top_p,
  top_k, max_output_tokens, stop_sequences, frequency/presence_penalty
- MITM modify_request injects params into request.generationConfig
- All 3 endpoints (Completions, Responses, Gemini) store client params
- Add usageMetadata to Gemini sync responses (promptTokenCount,
  candidatesTokenCount, totalTokenCount, thoughtsTokenCount)
- Add generation param fields to GeminiRequest (temperature, topP, etc.)
- Completions stream_options.include_usage emits final usage chunk
- Completions reasoning_tokens in completion_tokens_details
- Update endpoint gap analysis doc (all high-priority gaps resolved)
2026-02-15 14:23:05 -06:00

988 lines
38 KiB
Rust

//! Request body modification for intercepted LLM API calls.
//!
//! Aggressively strips everything except identity and actual conversation
//! from the Gemini API request. No integrity checks exist on the request
//! body — Google validates OAuth, project, model, and JSON structure only.
use regex::Regex;
use serde_json::Value;
use tracing::info;
use super::store::{CapturedFunctionCall, PendingToolResult};
/// Strip ALL tool definitions.
///
/// Must be true: with tools present, the LS enters full agentic mode
/// (multi-turn tool calls, file searches, etc.) burning quota.
///
/// When set, `modify_request` clears the `request.tools` array, removes
/// `toolConfig` if no client tools were injected, and drops every
/// `functionCall` / `functionResponse` part from the conversation history.
const STRIP_ALL_TOOLS: bool = true;
/// Context for tool injection during request modification.
/// Built from MitmStore data before calling modify_request.
///
/// Every field is optional or may be empty; `modify_request` only acts on the
/// parts that are actually present.
pub struct ToolContext {
/// Gemini-format tool declarations (functionDeclarations).
/// When `Some`, the entries are appended to `request.tools` after the LS
/// tools have been cleared, and the "no tools available" notice is not added
/// to the system instruction.
pub tools: Option<Vec<Value>>,
/// Gemini-format toolConfig.
/// When `Some`, it is inserted as `request.toolConfig`, replacing any existing value.
pub tool_config: Option<Value>,
/// Pending tool results to inject as functionResponse.
/// Only used when `last_calls` is non-empty as well.
pub pending_results: Vec<PendingToolResult>,
/// Last captured function calls for history rewriting.
/// Their `name` / `args` replace the placeholder text of the first matching model turn.
pub last_calls: Vec<CapturedFunctionCall>,
/// Client-specified generation parameters (temperature, top_p, etc.).
/// Each parameter that is set is written into `generationConfig`, overriding
/// whatever the request already contained.
pub generation_params: Option<super::store::GenerationParams>,
}
/// Modify a streamGenerateContent request body in-place.
///
/// The body is parsed as JSON and rewritten in these steps:
/// 1. The system instruction is reduced to its `<identity>` block (or cleared
///    when no such block exists).
/// 2. Context-injection messages and embedded metadata are removed from
///    `request.contents`.
/// 3. LS tools are stripped; client tools, `toolConfig` and pending tool
///    results from `tool_ctx` are injected.
/// 4. `thinkingConfig.includeThoughts` is enabled unless already specified.
/// 5. Client generation parameters are written into `generationConfig`.
///
/// Returns the modified JSON bytes, or `None` if the body is not valid JSON,
/// if no step changed anything, or if re-serialization failed.
pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec<u8>> {
    let mut json: Value = serde_json::from_slice(body).ok()?;
    let original_size = body.len();
    // Human-readable list of applied modifications; logged once at the end.
    let mut changes: Vec<String> = Vec::new();

    // ── 1. System instruction: keep ONLY <identity>, nuke everything else ──
    if let Some(sys) = json
        .pointer("/request/systemInstruction/parts/0/text")
        .and_then(Value::as_str)
        .map(str::to_owned)
    {
        let original_len = sys.len();

        // Extract <identity>...</identity> block
        if let Some(identity_text) = extract_xml_section(&sys, "identity") {
            let mut new_sys = format!("<identity>\n{}\n</identity>", identity_text.trim());

            // When no tools are available, explicitly tell the model not to attempt
            // function calls. Without this, the model's training causes it to try
            // calling tools from its identity context, resulting in MALFORMED_FUNCTION_CALL.
            let has_tools = tool_ctx.map_or(false, |ctx| ctx.tools.is_some());
            if !has_tools {
                new_sys.push_str("\n\nIMPORTANT: You have NO tools available. Do not attempt to call any functions or tools. Respond with text only.");
            }

            // Compare contents rather than subtracting lengths: the rewritten
            // prompt can be LONGER than the original (identity-only prompt plus
            // the notice above), and `original_len - new_sys.len()` would then
            // underflow `usize`.
            if new_sys != sys {
                changes.push(format!(
                    "system instruction: keep <identity> only ({original_len} → {} chars)",
                    new_sys.len()
                ));
                json["request"]["systemInstruction"]["parts"][0]["text"] =
                    Value::String(new_sys);
            }
        } else {
            // No identity tag found — clear the whole thing
            changes.push(format!("system instruction: cleared ({original_len} chars)"));
            json["request"]["systemInstruction"]["parts"][0]["text"] =
                Value::String(String::new());
        }
    }

    // ── 2. Content messages: keep only actual conversation turns ───────────
    if let Some(contents) = json
        .pointer_mut("/request/contents")
        .and_then(|v| v.as_array_mut())
    {
        let before = contents.len();

        // Messages that are pure Antigravity context injection start with one of
        // these tags: user_information (OS, workspace paths), user_rules / MEMORY
        // blocks, workflows, and the MCP servers block.
        const CONTEXT_PREFIXES: [&str; 4] = [
            "<user_information>",
            "<user_rules>",
            "<workflows>",
            "<mcp_servers>",
        ];
        contents.retain(|msg| {
            msg["parts"][0]["text"].as_str().map_or(true, |text| {
                !CONTEXT_PREFIXES
                    .iter()
                    .any(|prefix| text.starts_with(*prefix))
            })
        });

        // For remaining messages, strip embedded metadata
        for msg in contents.iter_mut() {
            if let Some(text) = msg["parts"][0]["text"].as_str().map(|s| s.to_string()) {
                let mut modified = text.clone();

                // Strip conversation summaries block
                if let Some(cleaned) = strip_between(
                    &modified,
                    "# Conversation History\n",
                    "</conversation_summaries>",
                ) {
                    modified = cleaned;
                }

                // Strip <ADDITIONAL_METADATA> blocks (cursor pos, open files, etc.)
                if let Some(cleaned) = strip_xml_section(&modified, "ADDITIONAL_METADATA") {
                    modified = cleaned;
                }

                // Strip <EPHEMERAL_MESSAGE> blocks
                if let Some(cleaned) = strip_xml_section(&modified, "EPHEMERAL_MESSAGE") {
                    modified = cleaned;
                }

                // Strip "Step Id: N\n" prefixes
                if modified.starts_with("Step Id:") {
                    if let Some(newline_pos) = modified.find('\n') {
                        modified = modified[newline_pos + 1..].to_string();
                    }
                }

                // Strip knowledge item blocks
                if let Some(cleaned) =
                    strip_between(&modified, "Here are the ", "</knowledge_item>")
                {
                    // Only strip if it's about knowledge items
                    if cleaned.len() < modified.len() && modified.contains("knowledge item") {
                        modified = cleaned;
                    }
                }

                // Clean up excessive whitespace from stripping
                let modified = collapse_newlines(&modified);

                // Only write back when something was actually removed.
                if modified.len() < text.len() {
                    msg["parts"][0]["text"] = Value::String(modified);
                }
            }
        }

        // Remove now-empty messages (messages without a text part are kept)
        contents.retain(|msg| {
            msg["parts"][0]["text"]
                .as_str()
                .map_or(true, |text| !text.trim().is_empty())
        });

        // `retain` can only shrink the array, so this subtraction cannot underflow.
        let removed = before - contents.len();
        if removed > 0 {
            changes.push(format!("remove {removed}/{before} content messages"));
        }
    }

    // ── 3. Strip LS tools, inject client tools ─────────────────────────────
    let mut has_custom_tools = false;
    if STRIP_ALL_TOOLS {
        // NOTE(review): client tools are only injected when the LS request already
        // carries a `tools` array; a request without one gets no client tools even
        // if `tool_ctx.tools` is set — confirm this is intended.
        if let Some(tools) = json
            .pointer_mut("/request/tools")
            .and_then(|v| v.as_array_mut())
        {
            let count = tools.len();
            if count > 0 {
                tools.clear();
                changes.push(format!("strip all {count} LS tools"));
            }

            // Inject client-provided tools from ToolContext
            if let Some(ctx) = tool_ctx {
                if let Some(custom_tools) = &ctx.tools {
                    tools.extend(custom_tools.iter().cloned());
                    has_custom_tools = true;
                    changes.push(format!("inject {} custom tool group(s)", custom_tools.len()));
                }
            }
        }
    }

    // ── 3a. When no tools remain, clean up all tool-related config ────────
    // The LS sets toolConfig.functionCallingConfig.mode = "VALIDATED" which
    // forces Google to attempt function calls even with an empty tools array,
    // causing MALFORMED_FUNCTION_CALL in an infinite retry loop.
    if STRIP_ALL_TOOLS && !has_custom_tools {
        if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
            // Remove the empty tools array entirely
            if req
                .get("tools")
                .and_then(|v| v.as_array())
                .map_or(false, |a| a.is_empty())
            {
                req.remove("tools");
                changes.push("remove empty tools array".to_string());
            }
            // Remove toolConfig (VALIDATED mode with no tools = MALFORMED_FUNCTION_CALL)
            if req.remove("toolConfig").is_some() {
                changes.push("remove toolConfig (no tools)".to_string());
            }
        }
    }

    // ── 3b. ALWAYS strip old functionCall/functionResponse from history ───
    // Even when custom tools are injected, the LS history contains function
    // call parts for LS-internal tools we stripped. Google rejects these as
    // MALFORMED_FUNCTION_CALL because the referenced tools don't exist.
    if STRIP_ALL_TOOLS {
        if let Some(contents) = json
            .pointer_mut("/request/contents")
            .and_then(|v| v.as_array_mut())
        {
            let mut stripped_fc = 0usize;
            for msg in contents.iter_mut() {
                if let Some(parts) = msg.get_mut("parts").and_then(|v| v.as_array_mut()) {
                    let before = parts.len();
                    parts.retain(|part| {
                        part.get("functionCall").is_none()
                            && part.get("functionResponse").is_none()
                    });
                    stripped_fc += before - parts.len();
                }
            }

            // Remove messages that became empty after stripping function parts
            contents.retain(|msg| {
                msg.get("parts")
                    .and_then(|v| v.as_array())
                    .map_or(true, |parts| !parts.is_empty())
            });

            if stripped_fc > 0 {
                changes.push(format!(
                    "strip {stripped_fc} functionCall/Response parts from history"
                ));
            }
        }
    }

    // ── 3c. Inject toolConfig if provided ───────────────────────────────────
    if let Some(ctx) = tool_ctx {
        if let Some(config) = &ctx.tool_config {
            if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
                req.insert("toolConfig".to_string(), config.clone());
                changes.push("inject toolConfig".to_string());
            }
        }
    }

    // ── 3d. Rewrite conversation history for tool results ───────────────────
    if let Some(ctx) = tool_ctx {
        if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() {
            if let Some(contents) = json
                .pointer_mut("/request/contents")
                .and_then(|v| v.as_array_mut())
            {
                // Find the model turn with our fake "Tool call completed" text and replace it
                // with the actual functionCall parts. Only the first match is rewritten.
                for msg in contents.iter_mut() {
                    if msg["role"].as_str() == Some("model") {
                        if let Some(text) = msg["parts"][0]["text"].as_str() {
                            if text.contains("Tool call completed")
                                || text.contains("Awaiting external tool result")
                            {
                                // Replace with functionCall parts
                                let fc_parts: Vec<Value> = ctx
                                    .last_calls
                                    .iter()
                                    .map(|fc| {
                                        serde_json::json!({
                                            "functionCall": {
                                                "name": fc.name,
                                                "args": fc.args,
                                            }
                                        })
                                    })
                                    .collect();
                                msg["parts"] = Value::Array(fc_parts);
                                changes.push("rewrite model turn with functionCall".to_string());
                                break;
                            }
                        }
                    }
                }

                // Add functionResponse as a user turn before the last user message
                let fn_response_parts: Vec<Value> = ctx
                    .pending_results
                    .iter()
                    .map(|r| {
                        serde_json::json!({
                            "functionResponse": {
                                "name": r.name,
                                "response": r.result,
                            }
                        })
                    })
                    .collect();
                let fn_response_turn = serde_json::json!({
                    "role": "user",
                    "parts": fn_response_parts,
                });

                // Insert before the last user message, or append when there is none
                let last_user_idx = contents
                    .iter()
                    .rposition(|msg| msg["role"].as_str() == Some("user"));
                if let Some(idx) = last_user_idx {
                    contents.insert(idx, fn_response_turn);
                } else {
                    contents.push(fn_response_turn);
                }
                changes.push(format!(
                    "inject {} functionResponse(s)",
                    ctx.pending_results.len()
                ));
            }
        }
    }

    // ── 4. Inject includeThoughts to capture thinking text ───────────────
    // Without this flag, Google only reports thinking token counts
    // but doesn't send the thinking text in SSE parts.
    if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
        if ensure_include_thoughts(req) {
            changes.push("inject includeThoughts".to_string());
        }
    } else if let Some(root) = json.as_object_mut() {
        // Not wrapped in request — try top-level (public API format)
        if ensure_include_thoughts(root) {
            changes.push("inject includeThoughts (top-level)".to_string());
        }
    }

    // ── 5. Inject client-specified generation parameters ──────────────────
    // These override the LS defaults (which are typically absent or conservative).
    // Google generationConfig fields: temperature, topP, topK, maxOutputTokens,
    // stopSequences, frequencyPenalty, presencePenalty.
    if let Some(ctx) = tool_ctx {
        if let Some(gp) = &ctx.generation_params {
            // Find or create generationConfig (same location as in step 4)
            let gc = if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
                Some(
                    req.entry("generationConfig")
                        .or_insert_with(|| serde_json::json!({})),
                )
            } else {
                json.as_object_mut().map(|o| {
                    o.entry("generationConfig")
                        .or_insert_with(|| serde_json::json!({}))
                })
            };

            if let Some(gc) = gc.and_then(|v| v.as_object_mut()) {
                let mut injected: Vec<String> = Vec::new();
                if let Some(t) = gp.temperature {
                    gc.insert("temperature".to_string(), serde_json::json!(t));
                    injected.push(format!("temperature={t}"));
                }
                if let Some(p) = gp.top_p {
                    gc.insert("topP".to_string(), serde_json::json!(p));
                    injected.push(format!("topP={p}"));
                }
                if let Some(k) = gp.top_k {
                    gc.insert("topK".to_string(), serde_json::json!(k));
                    injected.push(format!("topK={k}"));
                }
                if let Some(m) = gp.max_output_tokens {
                    gc.insert("maxOutputTokens".to_string(), serde_json::json!(m));
                    injected.push(format!("maxOutputTokens={m}"));
                }
                if let Some(seqs) = &gp.stop_sequences {
                    gc.insert("stopSequences".to_string(), serde_json::json!(seqs));
                    injected.push(format!("stopSequences({})", seqs.len()));
                }
                if let Some(fp) = gp.frequency_penalty {
                    gc.insert("frequencyPenalty".to_string(), serde_json::json!(fp));
                    injected.push(format!("frequencyPenalty={fp}"));
                }
                if let Some(pp) = gp.presence_penalty {
                    gc.insert("presencePenalty".to_string(), serde_json::json!(pp));
                    injected.push(format!("presencePenalty={pp}"));
                }
                if !injected.is_empty() {
                    changes.push(format!("inject generationConfig: {}", injected.join(", ")));
                }
            }
        }
    }

    if changes.is_empty() {
        return None; // Nothing modified
    }

    let modified_bytes = serde_json::to_vec(&json).ok()?;
    // Signed: injections can make the body larger than the original.
    let saved = original_size as i64 - modified_bytes.len() as i64;
    let pct = if original_size > 0 {
        (saved as f64 / original_size as f64 * 100.0) as i32
    } else {
        0
    };
    info!(
        original = original_size,
        modified = modified_bytes.len(),
        saved_bytes = saved,
        saved_pct = pct,
        "MITM: request modified [{}]",
        changes.join(", ")
    );
    Some(modified_bytes)
}

/// Make sure `parent.generationConfig.thinkingConfig.includeThoughts` is set.
///
/// Missing `generationConfig` / `thinkingConfig` objects are created. An
/// existing `includeThoughts` value (true or false) is left untouched.
/// Returns `true` only when the flag was newly inserted.
fn ensure_include_thoughts(parent: &mut serde_json::Map<String, Value>) -> bool {
    let gen_config = parent
        .entry("generationConfig")
        .or_insert_with(|| serde_json::json!({}));
    if let Some(gc) = gen_config.as_object_mut() {
        let thinking_config = gc
            .entry("thinkingConfig")
            .or_insert_with(|| serde_json::json!({}));
        if let Some(tc) = thinking_config.as_object_mut() {
            if !tc.contains_key("includeThoughts") {
                tc.insert("includeThoughts".to_string(), Value::Bool(true));
                return true;
            }
        }
    }
    false
}
/// Extract the inner text of the first `<tag>…</tag>` section in `text`.
///
/// The closing tag is searched for *after* the opening tag, so a stray
/// `</tag>` earlier in the text does not hide a valid section.
///
/// Returns `None` when the opening tag is missing, when no closing tag
/// follows it, or when the section is empty (`<tag></tag>`).
fn extract_xml_section(text: &str, tag: &str) -> Option<String> {
    let open = format!("<{tag}>");
    let close = format!("</{tag}>");
    let inner_start = text.find(&open)? + open.len();
    let after_open = &text[inner_start..];
    let inner_len = after_open.find(&close)?;
    if inner_len == 0 {
        // Empty sections are reported as "not found", as before.
        return None;
    }
    Some(after_open[..inner_len].to_string())
}
/// Strip the first `<tag>…</tag>` section (tags included) and return the
/// remaining text.
///
/// The closing tag is searched for *after* the opening tag; searching from the
/// start of the text would splice overlapping ranges together whenever a stray
/// `</tag>` precedes the section.
///
/// Returns `None` when the opening tag is missing or no closing tag follows it.
/// Surrounding whitespace is left untouched.
fn strip_xml_section(text: &str, tag: &str) -> Option<String> {
    let open = format!("<{tag}>");
    let close = format!("</{tag}>");
    let start = text.find(&open)?;
    let after_open = start + open.len();
    let end_pos = after_open + text[after_open..].find(&close)? + close.len();
    Some(format!("{}{}", &text[..start], &text[end_pos..]))
}
/// Strip everything between two markers (inclusive of markers).
///
/// The first occurrence of `start_marker` opens the region; the region ends
/// with the first `end_marker` found *after* it. Whitespace directly following
/// the end marker is removed as well.
///
/// Returns `None` when the start marker is missing or no end marker follows it.
fn strip_between(text: &str, start_marker: &str, end_marker: &str) -> Option<String> {
    let start = text.find(start_marker)?;
    // Search past the start marker so an earlier end marker cannot produce an
    // end position that lies before `start`.
    let search_from = start + start_marker.len();
    let end_pos = search_from + text[search_from..].find(end_marker)? + end_marker.len();
    // Skip any trailing whitespace after end marker
    let rest = text[end_pos..].trim_start();
    Some(format!("{}{}", &text[..start], rest))
}
/// Collapse 3+ consecutive newlines into 2.
fn collapse_newlines(text: &str) -> String {
let re = Regex::new(r"\n{3,}").unwrap();
re.replace_all(text, "\n\n").to_string()
}
/// Dechunk an HTTP chunked-encoded body into raw bytes.
///
/// Parsing is best-effort and never fails: it stops at the terminal
/// zero-size chunk, at the first size line that is missing or not valid hex,
/// or at the end of the input. Chunk extensions (`;name=value` after the
/// size) are ignored. A chunk whose declared size exceeds the remaining input
/// is truncated to the bytes that are actually available.
pub fn dechunk(data: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(data.len());
    let mut pos = 0;
    while pos < data.len() {
        // The size line runs up to the next CRLF.
        let line_end = match data[pos..].windows(2).position(|w| w == b"\r\n") {
            Some(offset) => pos + offset,
            None => break,
        };
        let size_str = std::str::from_utf8(&data[pos..line_end])
            .unwrap_or("")
            .split(';')
            .next()
            .unwrap_or("")
            .trim();
        let chunk_size = match usize::from_str_radix(size_str, 16) {
            Ok(n) if n > 0 => n,
            // Terminal chunk or malformed size line: stop decoding.
            _ => break,
        };
        let data_start = line_end + 2;
        // The declared size comes from the wire and may be absurdly large;
        // saturate instead of overflowing, then clamp to the available bytes.
        let data_end = data_start.saturating_add(chunk_size).min(data.len());
        result.extend_from_slice(&data[data_start..data_end]);
        // Skip the CRLF that terminates the chunk data.
        pos = data_end + 2;
    }
    result
}
/// Re-encode data as a single HTTP chunk + terminal chunk.
///
/// Output layout: `<hex-size>\r\n<data>\r\n0\r\n\r\n`.
///
/// Empty input produces only the terminal chunk. A zero-size chunk *is* the
/// terminator in chunked encoding, so emitting one for the empty payload and
/// then the terminator again would leave stray bytes after the end of the body.
pub fn rechunk(data: &[u8]) -> Vec<u8> {
    const TERMINAL_CHUNK: &[u8] = b"0\r\n\r\n";
    if data.is_empty() {
        return TERMINAL_CHUNK.to_vec();
    }
    let hex_size = format!("{:x}", data.len());
    let mut result =
        Vec::with_capacity(hex_size.len() + 2 + data.len() + 2 + TERMINAL_CHUNK.len());
    result.extend_from_slice(hex_size.as_bytes());
    result.extend_from_slice(b"\r\n");
    result.extend_from_slice(data);
    result.extend_from_slice(b"\r\n");
    result.extend_from_slice(TERMINAL_CHUNK);
    result
}
// ── OpenAI → Gemini format conversion ────────────────────────────────────────
/// Convert OpenAI tool definitions to Gemini functionDeclarations format.
///
/// OpenAI: `[{"type":"function","function":{"name":"X","description":"Y","parameters":{...}}}]`
/// Gemini: `[{"functionDeclarations":[{"name":"X","description":"Y","parameters":{...}}]}]`
///
/// Entries whose `type` is not `"function"` or that lack a `function` object
/// are skipped. All surviving declarations are grouped into one tool entry;
/// an empty vector is returned when nothing survives.
pub fn openai_tools_to_gemini(tools: &[Value]) -> Vec<Value> {
    let mut declarations: Vec<Value> = Vec::new();

    for tool in tools {
        if tool["type"].as_str() != Some("function") {
            continue;
        }
        if let Some(func) = tool.get("function") {
            // Missing name/description index to JSON null and are copied as such.
            let mut declaration = serde_json::json!({
                "name": func["name"],
                "description": func["description"],
            });
            if let Some(schema) = func.get("parameters") {
                // Uppercase the type names first, then drop unsupported keywords.
                let upper = uppercase_types(schema.clone());
                declaration["parameters"] = clean_schema_for_gemini(upper);
            }
            declarations.push(declaration);
        }
    }

    if declarations.is_empty() {
        Vec::new()
    } else {
        vec![serde_json::json!({ "functionDeclarations": declarations })]
    }
}
/// Recursively strip JSON Schema fields that Google's Gemini API doesn't accept.
/// Stripped keywords: $schema, additionalProperties, $ref, $defs, default,
/// examples, title.
///
/// The keys of a `properties` object are property *names*, not schema
/// keywords, so they are never stripped: a tool parameter that happens to be
/// called `title` or `default` must survive (it may also be listed in
/// `required`). Only the schema of each property is cleaned.
fn clean_schema_for_gemini(mut val: Value) -> Value {
    const STRIP_KEYS: &[&str] = &[
        "$schema",
        "additionalProperties",
        "$ref",
        "$defs",
        "default",
        "examples",
        "title",
    ];
    match &mut val {
        Value::Object(map) => {
            for key in STRIP_KEYS {
                map.remove(*key);
            }
            for (key, child) in map.iter_mut() {
                let taken = std::mem::take(child);
                *child = match taken {
                    // Property map: keep every name, clean each property schema.
                    Value::Object(mut props) if key.as_str() == "properties" => {
                        for (_, schema) in props.iter_mut() {
                            *schema = clean_schema_for_gemini(std::mem::take(schema));
                        }
                        Value::Object(props)
                    }
                    other => clean_schema_for_gemini(other),
                };
            }
        }
        Value::Array(arr) => {
            for v in arr.iter_mut() {
                *v = clean_schema_for_gemini(std::mem::take(v));
            }
        }
        _ => {}
    }
    val
}
/// Convert OpenAI tool_choice to Gemini toolConfig format.
///
/// OpenAI: "auto" | "required" | "none" | {"type":"function","function":{"name":"X"}}
/// Gemini: {"functionCallingConfig":{"mode":"AUTO|ANY|NONE","allowedFunctionNames":[...]}}
///
/// Anything that is not recognised (unknown strings, objects without a
/// function name, other JSON types) maps to mode `AUTO`.
pub fn openai_tool_choice_to_gemini(choice: &Value) -> Value {
    // Object form naming one specific function: force exactly that function.
    let forced_function = choice
        .as_object()
        .and_then(|obj| obj.get("function"))
        .and_then(|func| func["name"].as_str());
    if let Some(name) = forced_function {
        return serde_json::json!({
            "functionCallingConfig": {
                "mode": "ANY",
                "allowedFunctionNames": [name]
            }
        });
    }

    // String form; every other value falls back to AUTO.
    let mode = match choice.as_str() {
        Some("required") => "ANY",
        Some("none") => "NONE",
        _ => "AUTO",
    };
    serde_json::json!({"functionCallingConfig": {"mode": mode}})
}
/// Recursively convert JSON Schema type strings to uppercase (Gemini format).
/// "object" → "OBJECT", "string" → "STRING", etc.
///
/// Only a string stored under the key `type` is changed; every nested object
/// and array element is visited.
fn uppercase_types(mut val: Value) -> Value {
    match &mut val {
        Value::Object(fields) => {
            // Uppercase this level's own `type`, if it is a string.
            if let Some(Value::String(type_name)) = fields.get_mut("type") {
                *type_name = type_name.to_uppercase();
            }
            // Then descend into every member.
            for (_, member) in fields.iter_mut() {
                *member = uppercase_types(std::mem::take(member));
            }
        }
        Value::Array(items) => {
            for item in items.iter_mut() {
                *item = uppercase_types(std::mem::take(item));
            }
        }
        _ => {}
    }
    val
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two regular chunks followed by the terminal chunk are concatenated.
    #[test]
    fn test_dechunk_basic() {
        let chunked = b"5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
        let result = dechunk(chunked);
        assert_eq!(result, b"hello world");
    }

    /// A single chunk with a hexadecimal size (`b` = 11 bytes).
    #[test]
    fn test_dechunk_single() {
        let chunked = b"b\r\nhello world\r\n0\r\n\r\n";
        let result = dechunk(chunked);
        assert_eq!(result, b"hello world");
    }

    /// Re-encoding yields one data chunk plus the terminal chunk.
    #[test]
    fn test_rechunk() {
        let data = b"hello world";
        let chunked = rechunk(data);
        let expected = b"b\r\nhello world\r\n0\r\n\r\n";
        assert_eq!(chunked, expected);
    }

    /// dechunk → rechunk → dechunk returns the same payload.
    #[test]
    fn test_dechunk_rechunk_roundtrip() {
        let original = b"5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
        let data = dechunk(original);
        let rechunked = rechunk(&data);
        let data2 = dechunk(&rechunked);
        assert_eq!(data, data2);
    }

    /// Without a ToolContext every LS tool is stripped and the `tools` key removed.
    #[test]
    fn test_modify_strips_all_tools() {
        let body = serde_json::json!({
            "project": "test",
            "requestId": "test/1",
            "request": {
                "contents": [{"role": "user", "parts": [{"text": "hello"}]}],
                "tools": [
                    {"functionDeclarations": [{"name": "view_file", "description": "view", "parameters": {}}]},
                    {"functionDeclarations": [{"name": "browser_subagent", "description": "browse", "parameters": {}}]},
                ],
                "generationConfig": {}
            },
            "model": "test"
        });
        let bytes = serde_json::to_vec(&body).unwrap();
        let modified = modify_request(&bytes, None).unwrap();
        let result: Value = serde_json::from_slice(&modified).unwrap();
        // With no ToolContext, tools should be removed entirely
        assert!(result["request"]["tools"].is_null() || result.pointer("/request/tools").is_none(),
            "tools should be removed when no custom tools provided");
    }

    /// Only the `<identity>` block of the system instruction survives.
    #[test]
    fn test_modify_keeps_only_identity() {
        let sys_text = "<identity>\nYou are a helpful AI.\n</identity>\n\n<tool_calling>\nUse absolute paths.\n</tool_calling>\n<web_application_development>\nlots of web dev stuff\n</web_application_development>\n<communication_style>\nbe helpful\n</communication_style>";
        let body = serde_json::json!({
            "project": "test",
            "requestId": "test/1",
            "request": {
                "contents": [{"role": "user", "parts": [{"text": "hello"}]}],
                "systemInstruction": {"parts": [{"text": sys_text}]},
                "tools": [],
                "generationConfig": {}
            },
            "model": "test"
        });
        let bytes = serde_json::to_vec(&body).unwrap();
        let modified = modify_request(&bytes, None).unwrap();
        let result: Value = serde_json::from_slice(&modified).unwrap();
        let new_sys = result["request"]["systemInstruction"]["parts"][0]["text"]
            .as_str()
            .unwrap();
        assert!(new_sys.contains("<identity>"));
        assert!(new_sys.contains("You are a helpful AI."));
        assert!(!new_sys.contains("tool_calling"));
        assert!(!new_sys.contains("web_application_development"));
        assert!(!new_sys.contains("communication_style"));
    }

    /// Context-injection messages are dropped and embedded metadata is stripped.
    #[test]
    fn test_modify_strips_context_messages() {
        let body = serde_json::json!({
            "project": "test",
            "requestId": "test/1",
            "request": {
                "contents": [
                    {"role": "user", "parts": [{"text": "<user_information>\nLinux\n</user_information>"}]},
                    {"role": "user", "parts": [{"text": "<user_rules>\nno rules\n</user_rules>"}]},
                    {"role": "user", "parts": [{"text": "<workflows>\nsome workflows\n</workflows>"}]},
                    {"role": "user", "parts": [{"text": "Step Id: 0\n\n<USER_REQUEST>\nSay hello\n</USER_REQUEST>\n<ADDITIONAL_METADATA>\ncursor stuff\n</ADDITIONAL_METADATA>"}]},
                    {"role": "model", "parts": [{"text": "Hello!"}]},
                ],
                "tools": [],
                "generationConfig": {}
            },
            "model": "test"
        });
        let bytes = serde_json::to_vec(&body).unwrap();
        let modified = modify_request(&bytes, None).unwrap();
        let result: Value = serde_json::from_slice(&modified).unwrap();
        let contents = result["request"]["contents"].as_array().unwrap();
        // Should have removed user_information, user_rules, workflows (3 messages)
        // Kept: USER_REQUEST message (with ADDITIONAL_METADATA stripped) + model response
        assert_eq!(contents.len(), 2, "should keep only user request + model response");
        // Check USER_REQUEST message had metadata stripped
        let user_msg = contents[0]["parts"][0]["text"].as_str().unwrap();
        assert!(user_msg.contains("Say hello"), "should keep user request");
        assert!(!user_msg.contains("ADDITIONAL_METADATA"), "should strip metadata");
        assert!(!user_msg.contains("cursor stuff"), "should strip cursor info");
        assert!(!user_msg.starts_with("Step Id:"), "should strip step id");
        // Model response kept intact
        assert_eq!(contents[1]["parts"][0]["text"].as_str().unwrap(), "Hello!");
    }

    /// The inner text is returned verbatim, including surrounding newlines.
    #[test]
    fn test_extract_xml_section() {
        let text = "before <identity>\nI am AI\n</identity> after";
        let result = extract_xml_section(text, "identity").unwrap();
        assert_eq!(result, "\nI am AI\n");
    }

    /// Only the section itself is removed: the space before the opening tag
    /// and the space after the closing tag both remain, hence two spaces.
    #[test]
    fn test_strip_xml_section() {
        let text = "before <META>\nstuff\n</META> after";
        let result = strip_xml_section(text, "META").unwrap();
        assert_eq!(result, "before  after");
    }

    /// Whitespace after the end marker is trimmed, so the halves join cleanly.
    #[test]
    fn test_strip_between() {
        let text = "keep this # Conversation History\nlots of stuff\n</conversation_summaries>\nand this";
        let result = strip_between(text, "# Conversation History\n", "</conversation_summaries>").unwrap();
        assert_eq!(result, "keep this and this");
    }
}
// ─── Response modification ──────────────────────────────────────────────────
/// Rewrite an SSE response chunk to replace `functionCall` parts with text,
/// so the LS doesn't see tool calls for tools it doesn't manage.
///
/// The MITM intercept layer has already captured the function call data
/// (via `parse_streaming_chunk`) before this function runs, so we're not
/// losing any information — just hiding it from the LS.
///
/// Handles HTTP chunked transfer encoding framing (size\r\n...data...\r\n):
/// when the input starts with a chunk-size line, the output is re-framed as a
/// single chunk with a recomputed size.
///
/// Returns `Some(modified_bytes)` if the chunk was rewritten, `None` if no
/// change was needed or the chunk is not valid UTF-8.
#[allow(dead_code)]
pub fn modify_response_chunk(chunk: &[u8]) -> Option<Vec<u8>> {
    let text = std::str::from_utf8(chunk).ok()?;

    // Quick check — no point parsing if no functionCall present
    if !text.contains("functionCall") {
        return None;
    }

    // Strategy: find each `data: {json}` SSE event in the raw text (which may
    // be wrapped in chunked encoding). Parse the JSON, rewrite functionCall
    // parts, and rebuild the chunked frame with updated sizes.
    let mut result = text.to_string();
    let mut changed = false;

    // Find all `data: {...}` patterns (SSE events with JSON)
    let mut search_from = 0;
    while let Some(data_pos) = result[search_from..].find("data: {") {
        let abs_pos = search_from + data_pos;
        let json_start = abs_pos + 6; // skip "data: "

        // Find the end of this JSON object by finding the matching closing brace
        if let Some(json_end) = find_json_end(&result[json_start..]) {
            let json_str = &result[json_start..json_start + json_end];
            if json_str.contains("functionCall") {
                if let Ok(mut json) = serde_json::from_str::<Value>(json_str) {
                    if rewrite_function_calls_in_response(&mut json) {
                        if let Ok(new_json) = serde_json::to_string(&json) {
                            // Replace the JSON in the result string and resume
                            // scanning right after the replacement.
                            result.replace_range(json_start..json_start + json_end, &new_json);
                            changed = true;
                            info!("MITM: rewrote functionCall in response → text placeholder for LS");
                            search_from = json_start + new_json.len();
                            continue;
                        }
                    }
                }
            }
            search_from = json_start + json_end;
        } else {
            // Unterminated JSON: skip this marker and keep looking.
            search_from = json_start;
        }
    }

    if !changed {
        return None;
    }

    // The input is treated as chunked only when its first line is a bare hex
    // chunk size. Checking just the first character is not enough: the `d` of
    // an unframed `data: ...` SSE line is itself a hex digit.
    let starts_with_chunk_size = text.split_once("\r\n").map_or(false, |(first_line, _)| {
        !first_line.is_empty() && first_line.chars().all(|c| c.is_ascii_hexdigit())
    });

    if starts_with_chunk_size {
        // The embedded chunk sizes are stale after the rewrite — rebuild the framing.
        Some(rechunk_response(&result).into_bytes())
    } else {
        Some(result.into_bytes())
    }
}
/// Locate the end of the JSON object that begins at the start of `s`.
///
/// Braces inside string literals are ignored and backslash escapes within
/// strings are honoured. Returns the byte index just past the brace that
/// closes the outermost object, or `None` if the object never closes.
#[allow(dead_code)]
fn find_json_end(s: &str) -> Option<usize> {
    let mut nesting = 0i32;
    let mut inside_string = false;
    let mut skip_next = false;

    for (offset, ch) in s.char_indices() {
        // The character following a backslash inside a string is taken literally.
        if skip_next {
            skip_next = false;
            continue;
        }
        match ch {
            '\\' if inside_string => skip_next = true,
            '"' => inside_string = !inside_string,
            _ if inside_string => {}
            '{' => nesting += 1,
            '}' => {
                nesting -= 1;
                if nesting == 0 {
                    // `}` is one byte wide, so the end is right behind it.
                    return Some(offset + 1);
                }
            }
            _ => {}
        }
    }
    None
}
/// Rebuild chunked encoding from a modified response body.
///
/// `text` still contains the old (now stale) chunk-size lines. Every line that
/// is empty or consists solely of hex digits (after trimming trailing `\r`) is
/// treated as framing and dropped; the remaining lines are re-joined with `\n`
/// and wrapped in one chunk: `<hex-size>\r\n<payload>\r\n`. No terminal chunk
/// is appended.
///
/// NOTE(review): blank lines are dropped too, and a payload line made up only
/// of hex digits would be mistaken for a size line — confirm the consumer
/// tolerates both.
#[allow(dead_code)]
fn rechunk_response(text: &str) -> String {
    // Empty lines pass the `all` test vacuously, so one predicate covers blank
    // lines, chunk sizes and the terminating "0".
    let is_framing = |line: &str| {
        line.trim_end_matches('\r')
            .chars()
            .all(|c| c.is_ascii_hexdigit())
    };

    let payload = text
        .split('\n')
        .filter(|line| !is_framing(line))
        .fold(String::new(), |mut acc, line| {
            acc.push_str(line);
            acc.push('\n');
            acc
        });

    // Wrap in a single chunk
    format!("{:x}\r\n{}\r\n", payload.len(), payload)
}
/// Rewrite a parsed SSE JSON object: every `functionCall` part becomes a text
/// placeholder, and any `finishReason` other than `STOP` (for example
/// `MALFORMED_FUNCTION_CALL`) is set to `STOP`.
///
/// Handles both Gemini public API format (`{"candidates":[...]}`) and
/// internal LS format (`{"response":{"candidates":[...]}}`).
///
/// Returns `true` when at least one value was changed.
#[allow(dead_code)]
fn rewrite_function_calls_in_response(json: &mut Value) -> bool {
    // Rewrites a single candidate in place and reports whether it was touched.
    fn rewrite_candidate(candidate: &mut Value) -> bool {
        let mut touched = false;

        if let Some(parts) = candidate
            .pointer_mut("/content/parts")
            .and_then(Value::as_array_mut)
        {
            for part in parts
                .iter_mut()
                .filter(|part| part.get("functionCall").is_some())
            {
                *part = serde_json::json!({
                    "text": "Tool call completed. Awaiting external tool result."
                });
                touched = true;
            }
        }

        let needs_stop = candidate
            .get("finishReason")
            .and_then(Value::as_str)
            .map_or(false, |reason| reason != "STOP");
        if needs_stop {
            candidate["finishReason"] = Value::String("STOP".to_string());
            touched = true;
        }

        touched
    }

    let mut changed = false;
    // Direct "candidates" first, then the nested "response.candidates".
    for path in ["/candidates", "/response/candidates"] {
        if let Some(candidates) = json.pointer_mut(path).and_then(Value::as_array_mut) {
            for candidate in candidates.iter_mut() {
                changed |= rewrite_candidate(candidate);
            }
        }
    }
    changed
}