Files
zerogravity/src/mitm/modify.rs

1570 lines
57 KiB
Rust

//! Request body modification for intercepted LLM API calls.
//!
//! Aggressively strips everything except identity and actual conversation
//! from the Gemini API request. No integrity checks exist on the request
//! body — Google validates OAuth, project, model, and JSON structure only.
use regex::Regex;
use serde_json::Value;
use tracing::info;
use super::store::{CapturedFunctionCall, PendingImage, PendingToolResult, ToolRound};
/// Strip ALL tool definitions that the LS attaches to the request.
///
/// Must be true: with LS tools present, the LS enters full agentic mode
/// (multi-turn tool calls, file searches, etc.), burning quota. Client tools
/// supplied through `ToolContext::tools` are injected separately after the
/// LS tools have been cleared.
const STRIP_ALL_TOOLS: bool = true;
/// Everything `modify_request` needs to inject client state into an LS request.
///
/// Assembled from `MitmStore` data before `modify_request` is called.
pub struct ToolContext {
    // ── Prompt ───────────────────────────────────────────────────────────
    /// Actual user text that replaces the "." placeholder prompt the LS sends.
    pub pending_user_text: String,
    /// Image to append as `inlineData` to the most recent user message.
    pub pending_image: Option<PendingImage>,

    // ── Tools ────────────────────────────────────────────────────────────
    /// Client tool groups, already in Gemini `functionDeclarations` form.
    pub tools: Option<Vec<Value>>,
    /// Client `toolConfig`, already in Gemini form.
    pub tool_config: Option<Value>,

    // ── Tool-call history ────────────────────────────────────────────────
    /// One `(calls, results)` entry per round of tool use. Takes precedence
    /// over the single-round `last_calls` / `pending_results` pair.
    pub tool_rounds: Vec<ToolRound>,
    /// Most recently captured function calls (single-round fallback).
    pub last_calls: Vec<CapturedFunctionCall>,
    /// Tool results to inject as `functionResponse` parts (single-round fallback).
    pub pending_results: Vec<PendingToolResult>,

    // ── Sampling ─────────────────────────────────────────────────────────
    /// Client-specified generation parameters (temperature, top_p, ...).
    pub generation_params: Option<super::store::GenerationParams>,
}
/// Build a functionCall part JSON, including `thoughtSignature` as a sibling.
/// Google's part structure: `{functionCall: {name, args}, thoughtSignature: "..."}`
/// NOT nested inside functionCall.
fn build_function_call_part(fc: &super::store::CapturedFunctionCall) -> Value {
let mut part = serde_json::json!({
"functionCall": {
"name": fc.name,
"args": fc.args,
}
});
if let Some(ref sig) = fc.thought_signature {
part["thoughtSignature"] = Value::String(sig.clone());
}
part
}
/// Rewrite an intercepted `streamGenerateContent` request body.
///
/// Parses `body` as JSON and runs each rewrite pass in a fixed order. Every
/// pass edits the document in place and appends a description of what it did
/// to a change log.
///
/// Returns the re-serialized body. Returns `None` if the body is not valid
/// JSON, if no pass reported a change, or if re-serialization fails.
pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec<u8>> {
    let original_size = body.len();
    let mut json: Value = serde_json::from_slice(body).ok()?;
    let mut changes: Vec<String> = Vec::new();

    // Later passes observe the edits made by earlier ones.
    rewrite_system_instruction(&mut json, &mut changes);
    strip_context_messages(&mut json, &mut changes);
    replace_dummy_prompt(&mut json, tool_ctx, &mut changes);
    manage_tools_and_history(&mut json, tool_ctx, &mut changes);
    inject_thinking_config(&mut json, tool_ctx, &mut changes);
    inject_generation_params(&mut json, tool_ctx, &mut changes);
    inject_pending_image(&mut json, tool_ctx, &mut changes);

    if changes.is_empty() {
        return None;
    }

    let modified_bytes = serde_json::to_vec(&json).ok()?;
    let modified_size = modified_bytes.len();
    // Negative when injection made the body larger than the original.
    let saved = original_size as i64 - modified_size as i64;
    let pct = match original_size {
        0 => 0,
        total => (saved as f64 / total as f64 * 100.0) as i32,
    };

    info!(
        original = original_size,
        modified = modified_size,
        saved_bytes = saved,
        saved_pct = pct,
        "MITM: request modified [{}]",
        changes.join(", ")
    );
    Some(modified_bytes)
}
// ─── modify_request sub-functions ────────────────────────────────────────────
/// Rewrite systemInstruction to CLIProxyAPI-style multi-part format.
///
/// When the first part contains an `<identity>` block, the parts become:
///   part[0]  = trimmed identity text
///   part[1]  = "Please ignore following [ignore]<identity>[/ignore]"
///   part[2..] = the original parts after the first one, unchanged
///
/// When there is no (non-empty) identity block, the first part's text is
/// blanked. Every actual rewrite is recorded in `changes`, whether or not it
/// shrinks the text, so `modify_request` never discards an edited document.
fn rewrite_system_instruction(json: &mut Value, changes: &mut Vec<String>) {
    let Some(sys) = json
        .pointer("/request/systemInstruction/parts/0/text")
        .and_then(Value::as_str)
        .map(str::to_owned)
    else {
        return;
    };
    let original_len = sys.len();

    let Some(identity_text) = extract_xml_section(&sys, "identity") else {
        // Nothing worth keeping: blank the first part. Skip when it is already
        // empty so an untouched request is not reported as modified.
        if !sys.is_empty() {
            json["request"]["systemInstruction"]["parts"][0]["text"] =
                Value::String(String::new());
            changes.push(format!(
                "system instruction: cleared ({original_len} chars)"
            ));
        }
        return;
    };

    let identity = identity_text.trim();
    let ignore_wrapper = format!("Please ignore following [ignore]{identity}[/ignore]");
    let new_len = identity.len() + ignore_wrapper.len();

    // Parts after the first one are carried over as-is.
    let trailing: Vec<Value> = json
        .pointer("/request/systemInstruction/parts")
        .and_then(Value::as_array)
        .map(|parts| parts.iter().skip(1).cloned().collect())
        .unwrap_or_default();

    let mut new_parts = Vec::with_capacity(trailing.len() + 2);
    new_parts.push(serde_json::json!({"text": identity}));
    new_parts.push(serde_json::json!({"text": ignore_wrapper}));
    new_parts.extend(trailing);
    json["request"]["systemInstruction"]["parts"] = Value::Array(new_parts);

    changes.push(format!(
        "system instruction: CLIProxyAPI-style rewrite ({original_len} → {new_len} chars)"
    ));
}
/// Strip Antigravity-injected context messages and inline metadata.
///
/// The function works in three phases:
/// 1. Remove whole messages whose first text part starts with a pure-context
///    tag (`<user_information>`, `<user_rules>`, `<workflows>`, `<mcp_servers>`).
/// 2. Strip embedded metadata from the remaining messages (see
///    `strip_inline_metadata`).
/// 3. Drop messages whose first text part became blank.
///
/// Messages carrying `inlineData` parts are always kept. Both the
/// inline-stripping and the removals are recorded in `changes`.
fn strip_context_messages(json: &mut Value, changes: &mut Vec<String>) {
    const CONTEXT_PREFIXES: [&str; 4] = [
        "<user_information>",
        "<user_rules>",
        "<workflows>",
        "<mcp_servers>",
    ];

    let Some(contents) = json
        .pointer_mut("/request/contents")
        .and_then(Value::as_array_mut)
    else {
        return;
    };
    let before = contents.len();

    // Image/binary payloads must never be dropped.
    let has_inline_data = |msg: &Value| {
        msg["parts"]
            .as_array()
            .is_some_and(|parts| parts.iter().any(|p| p.get("inlineData").is_some()))
    };

    // Phase 1: remove whole messages that are pure context injection.
    contents.retain(|msg| {
        has_inline_data(msg)
            || !msg["parts"][0]["text"]
                .as_str()
                .is_some_and(|text| CONTEXT_PREFIXES.iter().any(|p| text.starts_with(*p)))
    });

    // Phase 2: strip embedded metadata from the remaining messages.
    let mut cleaned = 0usize;
    for msg in contents.iter_mut() {
        let Some(text) = msg["parts"][0]["text"].as_str() else {
            continue;
        };
        let stripped = strip_inline_metadata(text);
        if stripped.len() < text.len() {
            msg["parts"][0]["text"] = Value::String(stripped);
            cleaned += 1;
        }
    }

    // Phase 3: remove now-empty messages (image-bearing ones are preserved).
    contents.retain(|msg| {
        has_inline_data(msg)
            || msg["parts"][0]["text"]
                .as_str()
                .is_none_or(|t| !t.trim().is_empty())
    });

    if cleaned > 0 {
        changes.push(format!("strip inline metadata from {cleaned} message(s)"));
    }
    let removed = before - contents.len();
    if removed > 0 {
        changes.push(format!("remove {removed}/{before} content messages"));
    }
}

/// Remove Antigravity metadata embedded in a single message's text.
///
/// Removes, in order:
/// - conversation summaries;
/// - `<ADDITIONAL_METADATA>` and `<EPHEMERAL_MESSAGE>` sections;
/// - `<cid:...>` markers;
/// - a leading "Step Id: N" line;
/// - knowledge-item blocks.
///
/// Finally, runs of 3+ newlines are collapsed.
fn strip_inline_metadata(text: &str) -> String {
    let mut m = text.to_string();

    if let Some(c) = strip_between(&m, "# Conversation History\n", "</conversation_summaries>") {
        m = c;
    }
    for tag in ["ADDITIONAL_METADATA", "EPHEMERAL_MESSAGE"] {
        if let Some(c) = strip_xml_section(&m, tag) {
            m = c;
        }
    }

    // <cid:UUID> markers, removed in place. An unterminated marker stops the scan.
    while let Some(start) = m.find("<cid:") {
        let Some(len) = m[start..].find('>') else {
            break;
        };
        m.replace_range(start..=start + len, "");
    }

    // "Step Id: N\n" prefix, including its newline.
    if m.starts_with("Step Id:") {
        if let Some(nl) = m.find('\n') {
            m.replace_range(..=nl, "");
        }
    }

    // Knowledge item blocks; only applied when it actually shrinks the text.
    if let Some(c) = strip_between(&m, "Here are the ", "</knowledge_item>") {
        if c.len() < m.len() && m.contains("knowledge item") {
            m = c;
        }
    }

    collapse_newlines(&m)
}
/// Swap the LS placeholder prompt for the real user text from `tool_ctx`.
///
/// The LS is driven with "." (optionally followed by a `<cid:...>` marker) as
/// its prompt. The prompt may arrive wrapped as `<USER_REQUEST>.</USER_REQUEST>`
/// or bare. A wrapped dot is replaced by a fresh `USER_REQUEST` wrapper around
/// the real text; a bare dot is replaced by the text alone.
///
/// Messages without a `role` are treated as user messages. Only the first
/// matching message, in document order, is replaced.
fn replace_dummy_prompt(
    json: &mut Value,
    tool_ctx: Option<&ToolContext>,
    changes: &mut Vec<String>,
) {
    let Some(ctx) = tool_ctx.filter(|c| !c.pending_user_text.is_empty()) else {
        return;
    };
    let Some(contents) = json
        .pointer_mut("/request/contents")
        .and_then(Value::as_array_mut)
    else {
        return;
    };
    let user_text = &ctx.pending_user_text;

    for msg in contents.iter_mut() {
        let from_user = match msg.get("role").and_then(Value::as_str) {
            None => true,
            Some(role) => role == "user",
        };
        if !from_user {
            continue;
        }
        let Some(slot) = msg.pointer_mut("/parts/0/text") else {
            continue;
        };

        let current = slot.as_str().unwrap_or("");
        let wrapped_dot = current.contains("<USER_REQUEST>")
            && extract_xml_section(current, "USER_REQUEST").is_some_and(|inner| {
                let inner = inner.trim();
                inner == "." || inner.starts_with(".<cid:")
            });
        let trimmed = current.trim();
        let bare_dot =
            trimmed == "." || (trimmed.starts_with(".<cid:") && trimmed.ends_with(">"));

        if wrapped_dot {
            *slot = Value::String(format!(
                "\n<USER_REQUEST>\n{}\n</USER_REQUEST>\n",
                user_text
            ));
            changes.push(format!(
                "replace dummy prompt in USER_REQUEST wrapper ({} chars)",
                user_text.len()
            ));
            return;
        }
        if bare_dot {
            *slot = Value::String(user_text.clone());
            changes.push(format!(
                "replace bare dummy prompt ({} chars)",
                user_text.len()
            ));
            return;
        }
    }
}
/// Strip LS tools, inject client tools, clean up functionCall history, and
/// rewrite conversation history with tool call/response pairs.
fn manage_tools_and_history(
json: &mut Value,
tool_ctx: Option<&ToolContext>,
changes: &mut Vec<String>,
) {
let mut has_custom_tools = false;
// ── Strip LS tools, inject client tools ──────────────────────────────
if STRIP_ALL_TOOLS {
if let Some(tools) = json
.pointer_mut("/request/tools")
.and_then(|v| v.as_array_mut())
{
let count = tools.len();
if count > 0 {
tools.clear();
changes.push(format!("strip all {count} LS tools"));
}
if let Some(ctx) = tool_ctx {
if let Some(ref custom_tools) = ctx.tools {
for tool in custom_tools {
tools.push(tool.clone());
}
has_custom_tools = true;
changes.push(format!(
"inject {} custom tool group(s)",
custom_tools.len()
));
// Override VALIDATED → AUTO for custom tools
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
let has_validated = req
.get("toolConfig")
.and_then(|tc| tc.pointer("/functionCallingConfig/mode"))
.and_then(|m| m.as_str())
== Some("VALIDATED");
if has_validated {
req.insert(
"toolConfig".to_string(),
serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}),
);
changes.push("override toolConfig VALIDATED → AUTO".to_string());
}
}
}
}
}
}
// ── Clean up when no tools remain ────────────────────────────────────
if STRIP_ALL_TOOLS && !has_custom_tools {
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
if req
.get("tools")
.and_then(|v| v.as_array())
.is_some_and(|a| a.is_empty())
{
req.remove("tools");
changes.push("remove empty tools array".to_string());
}
if req.remove("toolConfig").is_some() {
changes.push("remove toolConfig (no tools)".to_string());
}
}
}
// ── Strip old functionCall/functionResponse from history ──────────────
if STRIP_ALL_TOOLS {
let custom_tool_names: std::collections::HashSet<String> = tool_ctx
.as_ref()
.and_then(|ctx| ctx.tools.as_ref())
.map(|tools| {
tools
.iter()
.filter_map(|t| t["functionDeclarations"].as_array())
.flatten()
.filter_map(|decl| decl["name"].as_str().map(|s| s.to_string()))
.collect()
})
.unwrap_or_default();
if let Some(contents) = json
.pointer_mut("/request/contents")
.and_then(|v| v.as_array_mut())
{
let mut stripped_fc = 0usize;
for msg in contents.iter_mut() {
if let Some(parts) = msg.get_mut("parts").and_then(|v| v.as_array_mut()) {
let before = parts.len();
parts.retain(|part| {
if let Some(fc) = part.get("functionCall") {
return fc
.get("name")
.and_then(|v| v.as_str())
.is_some_and(|n| custom_tool_names.contains(n));
}
if let Some(fr) = part.get("functionResponse") {
return fr
.get("name")
.and_then(|v| v.as_str())
.is_some_and(|n| custom_tool_names.contains(n));
}
true
});
stripped_fc += before - parts.len();
}
}
contents.retain(|msg| {
msg.get("parts")
.and_then(|v| v.as_array())
.is_none_or(|p| !p.is_empty())
});
if stripped_fc > 0 {
changes.push(format!(
"strip {stripped_fc} functionCall/Response parts from history"
));
}
}
}
// ── Inject toolConfig if provided ────────────────────────────────────
if let Some(ctx) = tool_ctx {
if let Some(ref config) = ctx.tool_config {
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
req.insert("toolConfig".to_string(), config.clone());
changes.push("inject toolConfig".to_string());
}
}
}
// ── Rewrite conversation history for tool results ────────────────────
rewrite_tool_rounds(json, tool_ctx, changes);
}
/// Rewrite conversation history: replace placeholder model turns with real
/// functionCall parts and inject functionResponse user turns.
fn rewrite_tool_rounds(
json: &mut Value,
tool_ctx: Option<&ToolContext>,
changes: &mut Vec<String>,
) {
let ctx = match tool_ctx {
Some(c) => c,
None => return,
};
let rounds = if !ctx.tool_rounds.is_empty() {
ctx.tool_rounds.clone()
} else if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() {
vec![ToolRound {
calls: ctx.last_calls.clone(),
results: ctx.pending_results.clone(),
}]
} else {
return;
};
let contents = match json
.pointer_mut("/request/contents")
.and_then(|v| v.as_array_mut())
{
Some(c) => c,
None => return,
};
// Phase 1: find model turns with placeholder text
let mut rewrites: Vec<(usize, usize)> = Vec::new();
let mut round_idx = 0;
for (i, msg) in contents.iter().enumerate() {
if round_idx >= rounds.len() {
break;
}
if msg["role"].as_str() == Some("model") {
if let Some(text) = msg["parts"][0]["text"].as_str() {
if text.contains("Tool call completed")
|| text.contains("Awaiting external tool result")
{
rewrites.push((i, round_idx));
round_idx += 1;
}
}
}
}
// Phase 2: apply rewrites
let mut insert_offset = 0;
for (content_idx, round_idx) in &rewrites {
let actual_idx = *content_idx + insert_offset;
let round = &rounds[*round_idx];
let fc_parts: Vec<Value> = round.calls.iter().map(build_function_call_part).collect();
contents[actual_idx]["parts"] = Value::Array(fc_parts);
if !round.results.is_empty() {
let fr_parts: Vec<Value> = round.results.iter()
.map(|r| serde_json::json!({"functionResponse": {"name": r.name, "response": r.result}}))
.collect();
contents.insert(
actual_idx + 1,
serde_json::json!({"role": "user", "parts": fr_parts}),
);
insert_offset += 1;
}
}
if !rewrites.is_empty() {
changes.push(format!(
"rewrite {} tool round(s) in history",
rewrites.len()
));
} else {
// Append as new messages (no existing model turns to rewrite)
let insert_pos = contents.len();
let mut offset = 0;
for round in &rounds {
let fc_parts: Vec<Value> = round.calls.iter().map(build_function_call_part).collect();
contents.insert(
insert_pos + offset,
serde_json::json!({"role": "model", "parts": fc_parts}),
);
offset += 1;
if !round.results.is_empty() {
let fr_parts: Vec<Value> = round.results.iter()
.map(|r| serde_json::json!({"functionResponse": {"name": r.name, "response": r.result}}))
.collect();
contents.insert(
insert_pos + offset,
serde_json::json!({"role": "user", "parts": fr_parts}),
);
offset += 1;
}
}
changes.push(format!(
"append {} tool round(s) as functionCall/Response pairs (no model turns found)",
rounds.len()
));
}
}
/// Inject `includeThoughts` and `thinkingLevel` into generationConfig.
fn inject_thinking_config(
json: &mut Value,
tool_ctx: Option<&ToolContext>,
changes: &mut Vec<String>,
) {
let reasoning_effort = tool_ctx
.and_then(|ctx| ctx.generation_params.as_ref())
.and_then(|gp| gp.reasoning_effort.clone());
// Helper: inject into a thinkingConfig object
let inject =
|tc: &mut serde_json::Map<String, Value>, changes: &mut Vec<String>, suffix: &str| {
if !tc.contains_key("includeThoughts") {
tc.insert("includeThoughts".to_string(), Value::Bool(true));
changes.push(format!("inject includeThoughts{suffix}"));
}
if let Some(ref effort) = reasoning_effort {
tc.insert("thinkingLevel".to_string(), Value::String(effort.clone()));
changes.push(format!("inject thinkingLevel={effort}{suffix}"));
}
};
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
let gc = req
.entry("generationConfig")
.or_insert_with(|| serde_json::json!({}));
if let Some(gc) = gc.as_object_mut() {
let tc = gc
.entry("thinkingConfig")
.or_insert_with(|| serde_json::json!({}));
if let Some(tc) = tc.as_object_mut() {
inject(tc, changes, "");
}
}
} else if let Some(o) = json.as_object_mut() {
let gc = o
.entry("generationConfig")
.or_insert_with(|| serde_json::json!({}));
if let Some(gc) = gc.as_object_mut() {
let tc = gc
.entry("thinkingConfig")
.or_insert_with(|| serde_json::json!({}));
if let Some(tc) = tc.as_object_mut() {
inject(tc, changes, " (top-level)");
}
}
}
}
/// Inject client-specified generation parameters (temperature, topP, etc.).
fn inject_generation_params(
json: &mut Value,
tool_ctx: Option<&ToolContext>,
changes: &mut Vec<String>,
) {
let gp = match tool_ctx.and_then(|ctx| ctx.generation_params.as_ref()) {
Some(gp) => gp,
None => return,
};
let gc = if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
Some(
req.entry("generationConfig")
.or_insert_with(|| serde_json::json!({})),
)
} else {
json.as_object_mut().map(|o| {
o.entry("generationConfig")
.or_insert_with(|| serde_json::json!({}))
})
};
let gc = match gc.and_then(|v| v.as_object_mut()) {
Some(gc) => gc,
None => return,
};
let mut injected: Vec<String> = Vec::new();
if let Some(t) = gp.temperature {
gc.insert("temperature".into(), serde_json::json!(t));
injected.push(format!("temperature={t}"));
}
if let Some(p) = gp.top_p {
gc.insert("topP".into(), serde_json::json!(p));
injected.push(format!("topP={p}"));
}
if let Some(k) = gp.top_k {
gc.insert("topK".into(), serde_json::json!(k));
injected.push(format!("topK={k}"));
}
if let Some(m) = gp.max_output_tokens {
gc.insert("maxOutputTokens".into(), serde_json::json!(m));
injected.push(format!("maxOutputTokens={m}"));
}
if let Some(ref seqs) = gp.stop_sequences {
gc.insert("stopSequences".into(), serde_json::json!(seqs));
injected.push(format!("stopSequences({})", seqs.len()));
}
if let Some(fp) = gp.frequency_penalty {
gc.insert("frequencyPenalty".into(), serde_json::json!(fp));
injected.push(format!("frequencyPenalty={fp}"));
}
if let Some(pp) = gp.presence_penalty {
gc.insert("presencePenalty".into(), serde_json::json!(pp));
injected.push(format!("presencePenalty={pp}"));
}
if let Some(ref mime) = gp.response_mime_type {
gc.insert("responseMimeType".into(), serde_json::json!(mime));
injected.push(format!("responseMimeType={mime}"));
}
if let Some(ref schema) = gp.response_schema {
gc.insert("responseSchema".into(), schema.clone());
injected.push("responseSchema=<schema>".to_string());
}
if !injected.is_empty() {
changes.push(format!("inject generationConfig: {}", injected.join(", ")));
}
}
/// Inject a pending image as inlineData into the last user message.
fn inject_pending_image(
json: &mut Value,
tool_ctx: Option<&ToolContext>,
changes: &mut Vec<String>,
) {
let img = match tool_ctx.and_then(|ctx| ctx.pending_image.as_ref()) {
Some(img) => img,
None => return,
};
let contents = match json
.pointer_mut("/request/contents")
.and_then(|v| v.as_array_mut())
{
Some(c) => c,
None => return,
};
for msg in contents.iter_mut().rev() {
if msg["role"].as_str() != Some("user") {
continue;
}
if let Some(parts) = msg.get_mut("parts").and_then(|v| v.as_array_mut()) {
parts.push(serde_json::json!({
"inlineData": { "mimeType": img.mime_type, "data": img.base64_data }
}));
changes.push(format!(
"inject image ({}; {} bytes base64)",
img.mime_type,
img.base64_data.len()
));
return;
}
}
tracing::warn!("MITM: pending image but no user message found to inject into");
}
/// Extract the inner text of the first `<tag>...</tag>` section.
///
/// The closing tag is searched for only after the opening tag. A stray
/// `</tag>` earlier in the text therefore cannot hide a valid section.
///
/// Returns `None` if either tag is missing or the section is empty.
fn extract_xml_section(text: &str, tag: &str) -> Option<String> {
    let open = format!("<{tag}>");
    let close = format!("</{tag}>");
    let inner_start = text.find(&open)? + open.len();
    let inner_len = text[inner_start..].find(&close)?;
    if inner_len == 0 {
        return None;
    }
    Some(text[inner_start..inner_start + inner_len].to_string())
}
/// Remove the first `<tag>...</tag>` section (tags included) from `text`.
///
/// The closing tag is searched for only after the opening tag. A stray
/// `</tag>` earlier in the text can therefore no longer cause duplicated
/// output. Whitespace around the removed section is preserved.
///
/// Returns `None` if either tag is missing.
fn strip_xml_section(text: &str, tag: &str) -> Option<String> {
    let open = format!("<{tag}>");
    let close = format!("</{tag}>");
    let start = text.find(&open)?;
    let after_open = start + open.len();
    let end = after_open + text[after_open..].find(&close)? + close.len();
    Some(format!("{}{}", &text[..start], &text[end..]))
}
/// Remove everything from `start_marker` through `end_marker`, both included.
///
/// `end_marker` is searched for only after `start_marker`, so an earlier
/// occurrence of the end marker cannot duplicate text. Whitespace directly
/// after the end marker is also removed.
///
/// Returns `None` if either marker is missing.
fn strip_between(text: &str, start_marker: &str, end_marker: &str) -> Option<String> {
    let start = text.find(start_marker)?;
    let search_from = start + start_marker.len();
    let end = search_from + text[search_from..].find(end_marker)? + end_marker.len();
    let rest = text[end..].trim_start();
    Some(format!("{}{}", &text[..start], rest))
}
/// Collapse every run of three or more `\n` characters into exactly two.
///
/// Shorter runs are left alone. The pattern is compiled once, on first use.
fn collapse_newlines(text: &str) -> String {
    static NEWLINE_RUN: std::sync::LazyLock<Regex> = std::sync::LazyLock::new(|| {
        Regex::new(r"\n{3,}").expect("newline-run pattern is a valid regex")
    });
    NEWLINE_RUN.replace_all(text, "\n\n").into_owned()
}
/// Decode an HTTP/1.1 chunked-encoded body into its raw payload bytes.
///
/// Decoding is best-effort:
/// - Chunk extensions (`;...`) after the size are ignored.
/// - Decoding stops at the terminal `0` chunk, at a malformed size line, or
///   when no further CRLF is found.
/// - A truncated final chunk contributes whatever bytes are present.
///
/// The chunk size comes off the wire, so the end offset is computed with
/// saturating arithmetic. An oversized length cannot overflow or cause an
/// out-of-range slice.
pub fn dechunk(data: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(data.len());
    let mut pos = 0;
    while pos < data.len() {
        let Some(line_len) = data[pos..].windows(2).position(|w| w == b"\r\n") else {
            break;
        };
        let line_end = pos + line_len;
        let size_str = std::str::from_utf8(&data[pos..line_end])
            .unwrap_or("")
            .split(';')
            .next()
            .unwrap_or("")
            .trim();
        let chunk_size = match usize::from_str_radix(size_str, 16) {
            Ok(0) | Err(_) => break,
            Ok(n) => n,
        };
        let data_start = line_end + 2;
        let data_end = data_start.saturating_add(chunk_size).min(data.len());
        result.extend_from_slice(&data[data_start..data_end]);
        // Skip the CRLF that terminates the chunk data.
        pos = data_end + 2;
    }
    result
}
/// Re-encode `data` as one HTTP/1.1 chunk followed by the terminal chunk.
///
/// An empty payload is encoded as just the terminal chunk (`0\r\n\r\n`).
/// Emitting a zero-length data chunk first would itself be a terminator and
/// leave stray bytes on the connection.
pub fn rechunk(data: &[u8]) -> Vec<u8> {
    const TERMINATOR: &[u8] = b"0\r\n\r\n";
    if data.is_empty() {
        return TERMINATOR.to_vec();
    }
    let hex_size = format!("{:x}", data.len());
    let mut result =
        Vec::with_capacity(hex_size.len() + 2 + data.len() + 2 + TERMINATOR.len());
    result.extend_from_slice(hex_size.as_bytes());
    result.extend_from_slice(b"\r\n");
    result.extend_from_slice(data);
    result.extend_from_slice(b"\r\n");
    result.extend_from_slice(TERMINATOR);
    result
}
// ── OpenAI → Gemini format conversion ────────────────────────────────────────
/// Convert OpenAI tool definitions to Gemini functionDeclarations format.
///
/// OpenAI: `[{"type":"function","function":{"name":"X","description":"Y","parameters":{...}}}]`
/// Gemini: `[{"functionDeclarations":[{"name":"X","description":"Y","parameters":{...}}]}]`
pub fn openai_tools_to_gemini(tools: &[Value]) -> Vec<Value> {
let declarations: Vec<Value> = tools
.iter()
.filter(|t| t["type"].as_str() == Some("function"))
.filter_map(|t| {
let func = t.get("function")?;
let mut decl = serde_json::json!({
"name": func["name"],
"description": func["description"],
});
if let Some(params) = func.get("parameters") {
let cleaned = clean_schema_for_gemini(uppercase_types(params.clone()));
decl["parameters"] = cleaned;
}
Some(decl)
})
.collect();
if declarations.is_empty() {
return vec![];
}
vec![serde_json::json!({"functionDeclarations": declarations})]
}
/// Recursively strip JSON Schema fields that Google's Gemini API doesn't accept.
///
/// Removed keywords: `$schema`, `additionalProperties`, `$ref`, `$defs`,
/// `default`, `examples`, `title`.
///
/// Keys of a `properties` map are parameter names, not schema keywords. They
/// are kept as-is, because a parameter may legitimately be called `title` or
/// `default`; only the schemas underneath them are cleaned.
fn clean_schema_for_gemini(mut val: Value) -> Value {
    const STRIP_KEYS: &[&str] = &[
        "$schema",
        "additionalProperties",
        "$ref",
        "$defs",
        "default",
        "examples",
        "title",
    ];
    match &mut val {
        Value::Object(map) => {
            for key in STRIP_KEYS {
                map.remove(*key);
            }
            for (key, child) in map.iter_mut() {
                *child = match std::mem::take(child) {
                    Value::Object(props) if key == "properties" => Value::Object(
                        props
                            .into_iter()
                            .map(|(name, schema)| (name, clean_schema_for_gemini(schema)))
                            .collect(),
                    ),
                    other => clean_schema_for_gemini(other),
                };
            }
        }
        Value::Array(arr) => {
            for v in arr.iter_mut() {
                *v = clean_schema_for_gemini(std::mem::take(v));
            }
        }
        _ => {}
    }
    val
}
/// Convert OpenAI tool_choice to Gemini toolConfig format.
///
/// OpenAI: "auto" | "required" | "none" | {"type":"function","function":{"name":"X"}}
/// Gemini: {"functionCallingConfig":{"mode":"AUTO|ANY|NONE","allowedFunctionNames":[...]}}
///
/// A forced function maps to mode `ANY`, restricted to that function name.
/// "required" maps to `ANY` and "none" maps to `NONE`. Anything else,
/// including unknown strings and objects without a function name, maps to
/// `AUTO`.
pub fn openai_tool_choice_to_gemini(choice: &Value) -> Value {
    if let Some(name) = choice.get("function").and_then(|f| f["name"].as_str()) {
        return serde_json::json!({
            "functionCallingConfig": {
                "mode": "ANY",
                "allowedFunctionNames": [name]
            }
        });
    }
    let mode = match choice.as_str() {
        Some("required") => "ANY",
        Some("none") => "NONE",
        _ => "AUTO",
    };
    serde_json::json!({"functionCallingConfig": {"mode": mode}})
}
/// Recursively convert JSON Schema type strings to uppercase (Gemini format).
/// "object" → "OBJECT", "string" → "STRING", etc.
///
/// Only string-valued `type` keys are rewritten; non-string `type` values
/// (for example a schema for a parameter named "type") are left alone.
/// Children are updated in place, so object key order is preserved and no
/// keys are cloned.
fn uppercase_types(mut val: Value) -> Value {
    match &mut val {
        Value::Object(map) => {
            if let Some(Value::String(ty)) = map.get_mut("type") {
                *ty = ty.to_uppercase();
            }
            for child in map.values_mut() {
                *child = uppercase_types(std::mem::take(child));
            }
        }
        Value::Array(items) => {
            for item in items.iter_mut() {
                *item = uppercase_types(std::mem::take(item));
            }
        }
        _ => {}
    }
    val
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dechunk_basic() {
let chunked = b"5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
let result = dechunk(chunked);
assert_eq!(result, b"hello world");
}
#[test]
fn test_dechunk_single() {
let chunked = b"b\r\nhello world\r\n0\r\n\r\n";
let result = dechunk(chunked);
assert_eq!(result, b"hello world");
}
#[test]
fn test_rechunk() {
let data = b"hello world";
let chunked = rechunk(data);
let expected = b"b\r\nhello world\r\n0\r\n\r\n";
assert_eq!(chunked, expected);
}
#[test]
fn test_dechunk_rechunk_roundtrip() {
let original = b"5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
let data = dechunk(original);
let rechunked = rechunk(&data);
let data2 = dechunk(&rechunked);
assert_eq!(data, data2);
}
#[test]
fn test_modify_strips_all_tools() {
let body = serde_json::json!({
"project": "test",
"requestId": "test/1",
"request": {
"contents": [{"role": "user", "parts": [{"text": "hello"}]}],
"tools": [
{"functionDeclarations": [{"name": "view_file", "description": "view", "parameters": {}}]},
{"functionDeclarations": [{"name": "browser_subagent", "description": "browse", "parameters": {}}]},
],
"generationConfig": {}
},
"model": "test"
});
let bytes = serde_json::to_vec(&body).unwrap();
let modified = modify_request(&bytes, None).unwrap();
let result: Value = serde_json::from_slice(&modified).unwrap();
// With no ToolContext, tools should be removed entirely
assert!(
result["request"]["tools"].is_null() || result.pointer("/request/tools").is_none(),
"tools should be removed when no custom tools provided"
);
}
#[test]
fn test_modify_keeps_only_identity() {
let sys_text = "<identity>\nYou are a helpful AI.\n</identity>\n\n<tool_calling>\nUse absolute paths.\n</tool_calling>\n<web_application_development>\nlots of web dev stuff\n</web_application_development>\n<communication_style>\nbe helpful\n</communication_style>";
let body = serde_json::json!({
"project": "test",
"requestId": "test/1",
"request": {
"contents": [{"role": "user", "parts": [{"text": "hello"}]}],
"systemInstruction": {"parts": [{"text": sys_text}]},
"tools": [],
"generationConfig": {}
},
"model": "test"
});
let bytes = serde_json::to_vec(&body).unwrap();
let modified = modify_request(&bytes, None).unwrap();
let result: Value = serde_json::from_slice(&modified).unwrap();
// Rewrite extracts identity content (without tags) into a 2-part system instruction
let sys_parts = result["request"]["systemInstruction"]["parts"]
.as_array()
.unwrap();
assert_eq!(sys_parts.len(), 2, "should have identity + ignore wrapper");
let part0 = sys_parts[0]["text"].as_str().unwrap();
let part1 = sys_parts[1]["text"].as_str().unwrap();
assert!(part0.contains("You are a helpful AI."));
assert!(part1.contains("[ignore]"));
assert!(!part0.contains("tool_calling"));
assert!(!part0.contains("web_application_development"));
assert!(!part0.contains("communication_style"));
}
#[test]
fn test_modify_strips_context_messages() {
let body = serde_json::json!({
"project": "test",
"requestId": "test/1",
"request": {
"contents": [
{"role": "user", "parts": [{"text": "<user_information>\nLinux\n</user_information>"}]},
{"role": "user", "parts": [{"text": "<user_rules>\nno rules\n</user_rules>"}]},
{"role": "user", "parts": [{"text": "<workflows>\nsome workflows\n</workflows>"}]},
{"role": "user", "parts": [{"text": "Step Id: 0\n\n<USER_REQUEST>\nSay hello\n</USER_REQUEST>\n<ADDITIONAL_METADATA>\ncursor stuff\n</ADDITIONAL_METADATA>"}]},
{"role": "model", "parts": [{"text": "Hello!"}]},
],
"tools": [],
"generationConfig": {}
},
"model": "test"
});
let bytes = serde_json::to_vec(&body).unwrap();
let modified = modify_request(&bytes, None).unwrap();
let result: Value = serde_json::from_slice(&modified).unwrap();
let contents = result["request"]["contents"].as_array().unwrap();
// Should have removed user_information, user_rules, workflows (3 messages)
// Kept: USER_REQUEST message (with ADDITIONAL_METADATA stripped) + model response
assert_eq!(
contents.len(),
2,
"should keep only user request + model response"
);
// Check USER_REQUEST message had metadata stripped
let user_msg = contents[0]["parts"][0]["text"].as_str().unwrap();
assert!(user_msg.contains("Say hello"), "should keep user request");
assert!(
!user_msg.contains("ADDITIONAL_METADATA"),
"should strip metadata"
);
assert!(
!user_msg.contains("cursor stuff"),
"should strip cursor info"
);
assert!(!user_msg.starts_with("Step Id:"), "should strip step id");
// Model response kept intact
assert_eq!(contents[1]["parts"][0]["text"].as_str().unwrap(), "Hello!");
}
#[test]
fn test_extract_xml_section() {
let text = "before <identity>\nI am AI\n</identity> after";
let result = extract_xml_section(text, "identity").unwrap();
assert_eq!(result, "\nI am AI\n");
}
#[test]
fn test_strip_xml_section() {
let text = "before <META>\nstuff\n</META> after";
let result = strip_xml_section(text, "META").unwrap();
assert_eq!(result, "before after");
}
#[test]
fn test_strip_between() {
let text =
"keep this # Conversation History\nlots of stuff\n</conversation_summaries>\nand this";
let result = strip_between(
text,
"# Conversation History\n",
"</conversation_summaries>",
)
.unwrap();
assert_eq!(result, "keep this and this");
}
#[test]
fn test_multi_round_history_rewrite() {
    use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};

    /// Text the LS sees in place of a hidden functionCall.
    const PLACEHOLDER: &str = "Tool call completed. Awaiting external tool result.";

    /// A captured call with no thought signature, as recorded by the MITM.
    fn captured_call(name: &str, args: Value) -> CapturedFunctionCall {
        CapturedFunctionCall {
            name: name.to_string(),
            args,
            thought_signature: None,
            captured_at: 0,
        }
    }

    /// A client-supplied result for the tool called `name`.
    fn tool_result(name: &str, result: Value) -> PendingToolResult {
        PendingToolResult {
            name: name.to_string(),
            result,
        }
    }

    /// Name of the functionCall in the first part of `turn`.
    fn call_name(turn: &Value) -> &str {
        turn["parts"][0]["functionCall"]["name"].as_str().unwrap()
    }

    /// Name of the functionResponse in the first part of `turn`.
    fn response_name(turn: &Value) -> &str {
        turn["parts"][0]["functionResponse"]["name"].as_str().unwrap()
    }

    // Two rounds of tool use as the LS records them:
    // user → model(placeholder) → user(text) → model(placeholder) → user(text)
    let history = [
        ("user", "Read foo and write to bar"),
        ("model", PLACEHOLDER),
        ("user", "[Tool result: file contents here]"),
        ("model", PLACEHOLDER),
        ("user", "[Tool result: write success]"),
    ];
    let ls_contents: Vec<Value> = history
        .iter()
        .map(|&(role, text)| serde_json::json!({"role": role, "parts": [{"text": text}]}))
        .collect();
    let body = serde_json::json!({
        "project": "test",
        "requestId": "test/1",
        "request": {
            "contents": ls_contents,
            "tools": [],
            "generationConfig": {}
        },
        "model": "test"
    });

    let declarations = serde_json::json!({
        "functionDeclarations": [{
            "name": "read_file",
            "description": "Read a file",
            "parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}}}
        }, {
            "name": "write_file",
            "description": "Write a file",
            "parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}, "content": {"type": "STRING"}}}
        }]
    });
    let read_round = ToolRound {
        calls: vec![captured_call("read_file", serde_json::json!({"path": "/foo"}))],
        results: vec![tool_result(
            "read_file",
            serde_json::json!({"content": "file contents here"}),
        )],
    };
    let write_round = ToolRound {
        calls: vec![captured_call(
            "write_file",
            serde_json::json!({"path": "/bar", "content": "data"}),
        )],
        results: vec![tool_result("write_file", serde_json::json!({"result": "ok"}))],
    };
    let tool_ctx = ToolContext {
        pending_user_text: String::new(),
        tools: Some(vec![declarations]),
        tool_config: None,
        pending_results: Vec::new(),
        last_calls: Vec::new(),
        generation_params: None,
        pending_image: None,
        tool_rounds: vec![read_round, write_round],
    };

    let request_bytes = serde_json::to_vec(&body).unwrap();
    let modified = modify_request(&request_bytes, Some(&tool_ctx)).unwrap();
    let parsed: Value = serde_json::from_slice(&modified).unwrap();
    let turns = parsed["request"]["contents"].as_array().unwrap();

    // Expected layout after rewrite:
    //   [0] user:  "Read foo..."
    //   [1] model: functionCall(read_file)       (was the placeholder)
    //   [2] user:  functionResponse(read_file)   (injected)
    //   [3] user:  "[Tool result: file contents]" (original LS turn)
    //   [4] model: functionCall(write_file)      (was the placeholder)
    //   [5] user:  functionResponse(write_file)  (injected)
    //   [6] user:  "[Tool result: write success]" (original LS turn)
    assert_eq!(
        turns.len(),
        7,
        "should have 7 turns (5 original + 2 injected)"
    );

    // Round 1: placeholder became a functionCall, response injected after it.
    assert_eq!(call_name(&turns[1]), "read_file");
    assert_eq!(
        turns[1]["parts"][0]["functionCall"]["args"]["path"]
            .as_str()
            .unwrap(),
        "/foo"
    );
    assert_eq!(turns[2]["role"].as_str().unwrap(), "user");
    assert_eq!(response_name(&turns[2]), "read_file");

    // Round 2: same shape, shifted by the first injection.
    assert_eq!(call_name(&turns[4]), "write_file");
    assert_eq!(response_name(&turns[5]), "write_file");
}
#[test]
fn test_single_round_legacy_fallback() {
    use crate::mitm::store::{CapturedFunctionCall, PendingToolResult};

    // A single round supplied only through the legacy `last_calls` /
    // `pending_results` fields (`tool_rounds` empty). This is the path used by
    // responses.rs.
    let placeholder = "Tool call completed. Awaiting external tool result.";
    let body = serde_json::json!({
        "project": "test",
        "requestId": "test/1",
        "request": {
            "contents": [
                {"role": "user", "parts": [{"text": "Search for X"}]},
                {"role": "model", "parts": [{"text": placeholder}]},
                {"role": "user", "parts": [{"text": "[Tool result: found X]"}]},
            ],
            "tools": [],
            "generationConfig": {}
        },
        "model": "test"
    });

    let search_decl = serde_json::json!({
        "name": "search",
        "description": "Search",
        "parameters": {"type": "OBJECT", "properties": {"q": {"type": "STRING"}}}
    });
    let captured = CapturedFunctionCall {
        name: "search".to_string(),
        args: serde_json::json!({"q": "X"}),
        thought_signature: None,
        captured_at: 0,
    };
    let outcome = PendingToolResult {
        name: "search".to_string(),
        result: serde_json::json!({"results": ["x"]}),
    };
    let tool_ctx = ToolContext {
        pending_user_text: String::new(),
        tools: Some(vec![serde_json::json!({ "functionDeclarations": [search_decl] })]),
        tool_config: None,
        pending_results: vec![outcome],
        last_calls: vec![captured],
        generation_params: None,
        pending_image: None,
        tool_rounds: Vec::new(), // Deliberately empty: forces the legacy fallback.
    };

    let encoded = serde_json::to_vec(&body).unwrap();
    let rewritten = modify_request(&encoded, Some(&tool_ctx)).unwrap();
    let parsed: Value = serde_json::from_slice(&rewritten).unwrap();
    let turns = parsed["request"]["contents"].as_array().unwrap();

    // The model turn becomes a functionCall and one functionResponse is injected.
    assert_eq!(
        turns.len(),
        4,
        "should have 4 turns (3 original + 1 injected)"
    );
    let call_part = &turns[1]["parts"][0];
    let response_part = &turns[2]["parts"][0];
    assert_eq!(call_part["functionCall"]["name"].as_str().unwrap(), "search");
    assert_eq!(
        response_part["functionResponse"]["name"].as_str().unwrap(),
        "search"
    );
}
#[test]
fn test_no_tool_rounds_no_rewrite() {
    // Tools are declared, but there is neither tool_rounds data nor legacy
    // call/result data, so the history must pass through unrewritten.
    let noop_tool = serde_json::json!({
        "functionDeclarations": [{
            "name": "noop",
            "description": "Does nothing",
            "parameters": {"type": "OBJECT", "properties": {}}
        }]
    });
    let tool_ctx = ToolContext {
        pending_user_text: String::new(),
        tools: Some(vec![noop_tool]),
        tool_config: None,
        pending_results: Vec::new(),
        last_calls: Vec::new(),
        generation_params: None,
        pending_image: None,
        tool_rounds: Vec::new(),
    };
    let body = serde_json::json!({
        "project": "test",
        "requestId": "test/1",
        "request": {
            "contents": [
                {"role": "user", "parts": [{"text": "Hello"}]},
                {"role": "model", "parts": [{"text": "Hi there!"}]},
            ],
            "tools": [],
            "generationConfig": {}
        },
        "model": "test"
    });

    let encoded = serde_json::to_vec(&body).unwrap();
    let output = modify_request(&encoded, Some(&tool_ctx)).unwrap();
    let parsed: Value = serde_json::from_slice(&output).unwrap();
    let turns = parsed["request"]["contents"].as_array().unwrap();

    // Same number of turns, and the model's text is untouched.
    assert_eq!(turns.len(), 2);
    let model_text = turns[1]["parts"][0]["text"].as_str().unwrap();
    assert_eq!(model_text, "Hi there!");
}
#[test]
fn test_tool_rounds_append_when_no_model_turns() {
    use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};

    // Real-world shape: the LS sends a cascade containing ONLY user messages.
    // There is no model turn to rewrite, so the functionCall/functionResponse
    // pair must be APPENDED after the user message instead.
    let body = serde_json::json!({
        "project": "test",
        "requestId": "test/1",
        "request": {
            "contents": [
                {"role": "user", "parts": [{"text": "hello"}]},
            ],
            "tools": [],
            "generationConfig": {}
        },
        "model": "test"
    });

    let search_round = ToolRound {
        calls: vec![CapturedFunctionCall {
            name: "web_search".to_string(),
            args: serde_json::json!({"query": "rust news"}),
            thought_signature: None,
            captured_at: 0,
        }],
        results: vec![PendingToolResult {
            name: "web_search".to_string(),
            result: serde_json::json!({"results": "some results"}),
        }],
    };
    let declarations = serde_json::json!({
        "functionDeclarations": [{
            "name": "web_search",
            "description": "Search the web",
            "parameters": {"type": "OBJECT", "properties": {"query": {"type": "STRING"}}}
        }]
    });
    let tool_ctx = ToolContext {
        pending_user_text: String::new(),
        tools: Some(vec![declarations]),
        tool_config: None,
        pending_results: Vec::new(),
        last_calls: Vec::new(),
        generation_params: None,
        pending_image: None,
        tool_rounds: vec![search_round],
    };

    let encoded = serde_json::to_vec(&body).unwrap();
    let output = modify_request(&encoded, Some(&tool_ctx)).unwrap();
    let parsed: Value = serde_json::from_slice(&output).unwrap();
    let turns = parsed["request"]["contents"].as_array().unwrap();

    // Expected layout (tool round appended AFTER the user message):
    //   [0] user:  "hello"                     ← original
    //   [1] model: functionCall(web_search)     ← appended
    //   [2] user:  functionResponse(web_search) ← appended
    assert_eq!(turns.len(), 3, "should have 3 turns: user + fc + fr");
    let (user_turn, call_turn, response_turn) = (&turns[0], &turns[1], &turns[2]);

    assert_eq!(user_turn["role"].as_str().unwrap(), "user");
    assert!(user_turn["parts"][0]["text"]
        .as_str()
        .unwrap()
        .contains("hello"));

    assert_eq!(call_turn["role"].as_str().unwrap(), "model");
    assert_eq!(
        call_turn["parts"][0]["functionCall"]["name"].as_str().unwrap(),
        "web_search"
    );

    assert_eq!(response_turn["role"].as_str().unwrap(), "user");
    assert_eq!(
        response_turn["parts"][0]["functionResponse"]["name"]
            .as_str()
            .unwrap(),
        "web_search"
    );
}
}
// ─── Response modification ──────────────────────────────────────────────────
/// Hide tool calls from the LS inside one parsed SSE JSON event.
///
/// Candidates are looked up both at the top level (`candidates`) and under
/// `response.candidates`. In each candidate, every part under
/// `content.parts` that has a `functionCall` key is replaced by a
/// `{"text": "Tool call completed. Awaiting external tool result."}`
/// placeholder. Any string `finishReason` other than `"STOP"` is overwritten
/// with `"STOP"`.
///
/// Returns `true` if anything in `json` was modified.
///
/// Used by `ResponseRewriter` to hide tool calls from the LS.
fn rewrite_function_calls_in_response(json: &mut Value) -> bool {
    /// Apply the rewrite to a single candidate; reports whether it changed.
    fn rewrite_candidate(candidate: &mut Value) -> bool {
        let mut touched = false;

        let parts = candidate
            .pointer_mut("/content/parts")
            .and_then(Value::as_array_mut);
        for part in parts.into_iter().flatten() {
            if part.get("functionCall").is_none() {
                continue;
            }
            *part = serde_json::json!({
                "text": "Tool call completed. Awaiting external tool result."
            });
            touched = true;
        }

        let needs_stop = matches!(
            candidate.get("finishReason"),
            Some(Value::String(reason)) if reason != "STOP"
        );
        if needs_stop {
            candidate["finishReason"] = Value::String("STOP".to_string());
            touched = true;
        }

        touched
    }

    let mut changed = false;
    for location in ["/candidates", "/response/candidates"] {
        if let Some(candidates) = json.pointer_mut(location).and_then(Value::as_array_mut) {
            // Non-short-circuiting `|` so every candidate is rewritten.
            changed |= candidates
                .iter_mut()
                .fold(false, |acc, candidate| rewrite_candidate(candidate) | acc);
        }
    }
    changed
}
// ─── ResponseRewriter ────────────────────────────────────────────────────────
/// Stateful line-buffered response rewriter.
///
/// `modify_response_chunk` is stateless per-TCP-chunk — if a `functionCall`
/// JSON event spans two reads, the quick `contains("functionCall")` check
/// fails and the raw bytes leak to the LS. This struct solves that by
/// accumulating raw response bytes and only forwarding complete
/// newline-terminated SSE lines, rewriting any that contain `functionCall`.
///
/// This mirrors exactly how `parse_streaming_chunk` / `StreamingAccumulator`
/// handles cross-chunk JSON reassembly.
#[derive(Debug, Default)]
pub struct ResponseRewriter {
/// Buffered data waiting for a complete `\n`-terminated line.
pending: String,
}
impl ResponseRewriter {
pub fn new() -> Self {
Self::default()
}
/// Feed raw response bytes, get back bytes safe to forward to the LS.
///
/// Complete lines are rewritten if they contain `functionCall`, then
/// returned. Partial lines stay buffered until the next `feed()` call.
pub fn feed(&mut self, chunk: &[u8]) -> Vec<u8> {
let text = String::from_utf8_lossy(chunk);
self.pending.push_str(&text);
let mut output = String::new();
// Extract all complete lines (terminated by \n)
while let Some(pos) = self.pending.find('\n') {
// Include the \n in the extracted line
let line = self.pending[..=pos].to_string();
self.pending = self.pending[pos + 1..].to_string();
// Check if this is a `data: {JSON}` SSE line containing functionCall
let trimmed = line.trim();
if trimmed.starts_with("data: {") && trimmed.contains("functionCall") {
// Extract JSON, rewrite, and rebuild the line
if let Some(data_start) = line.find("data: {") {
let json_start = data_start + 6; // skip "data: "
let json_str = line[json_start..].trim_end();
if let Ok(mut json) = serde_json::from_str::<serde_json::Value>(json_str) {
if rewrite_function_calls_in_response(&mut json) {
if let Ok(new_json) = serde_json::to_string(&json) {
let rewritten =
format!("{}data: {}\n", &line[..data_start], new_json);
info!("MITM: rewrote functionCall in response → text placeholder for LS (buffered)");
output.push_str(&rewritten);
continue;
}
}
}
}
// Couldn't parse/rewrite — forward as-is
output.push_str(&line);
} else {
// Not a functionCall line — forward as-is
output.push_str(&line);
}
}
output.into_bytes()
}
/// Flush any remaining buffered data (call at end of response).
/// Rewrites if possible, otherwise forwards raw.
pub fn flush(&mut self) -> Vec<u8> {
if self.pending.is_empty() {
return vec![];
}
let remaining = std::mem::take(&mut self.pending);
// Try to rewrite if it contains functionCall
if remaining.contains("functionCall") {
if let Some(data_start) = remaining.find("data: {") {
let json_start = data_start + 6;
let json_str = remaining[json_start..].trim_end();
if let Ok(mut json) = serde_json::from_str::<serde_json::Value>(json_str) {
if rewrite_function_calls_in_response(&mut json) {
if let Ok(new_json) = serde_json::to_string(&json) {
let rewritten =
format!("{}data: {}", &remaining[..data_start], new_json);
info!("MITM: rewrote functionCall in flush → text placeholder for LS");
return rewritten.into_bytes();
}
}
}
}
}
remaining.into_bytes()
}
}