1570 lines
57 KiB
Rust
1570 lines
57 KiB
Rust
//! Request body modification for intercepted LLM API calls.
|
|
//!
|
|
//! Aggressively strips everything except identity and actual conversation
|
|
//! from the Gemini API request. No integrity checks exist on the request
|
|
//! body — Google validates OAuth, project, model, and JSON structure only.
|
|
|
|
use regex::Regex;
|
|
use serde_json::Value;
|
|
use tracing::info;
|
|
|
|
use super::store::{CapturedFunctionCall, PendingImage, PendingToolResult, ToolRound};
|
|
|
|
/// Strip ALL tool definitions.
|
|
/// Must be true: with tools present, the LS enters full agentic mode
|
|
/// (multi-turn tool calls, file searches, etc.) burning quota.
|
|
const STRIP_ALL_TOOLS: bool = true;
|
|
|
|
/// Context for tool injection during request modification.
|
|
/// Built from MitmStore data before calling modify_request.
|
|
pub struct ToolContext {
|
|
/// Real user text to replace the "." dot prompt sent to LS.
|
|
pub pending_user_text: String,
|
|
/// Gemini-format tool declarations (functionDeclarations).
|
|
pub tools: Option<Vec<Value>>,
|
|
/// Gemini-format toolConfig.
|
|
pub tool_config: Option<Value>,
|
|
/// Pending tool results to inject as functionResponse.
|
|
pub pending_results: Vec<PendingToolResult>,
|
|
/// Last captured function calls for history rewriting.
|
|
pub last_calls: Vec<CapturedFunctionCall>,
|
|
/// Client-specified generation parameters (temperature, top_p, etc.).
|
|
pub generation_params: Option<super::store::GenerationParams>,
|
|
/// Pending image to inject as inlineData in the user message.
|
|
pub pending_image: Option<PendingImage>,
|
|
/// Multi-round tool call history. Each entry is a (calls, results) pair
|
|
/// from one round of tool use. Preferred over last_calls/pending_results.
|
|
pub tool_rounds: Vec<ToolRound>,
|
|
}
|
|
|
|
/// Build a functionCall part JSON, including `thoughtSignature` as a sibling.
|
|
/// Google's part structure: `{functionCall: {name, args}, thoughtSignature: "..."}`
|
|
/// NOT nested inside functionCall.
|
|
fn build_function_call_part(fc: &super::store::CapturedFunctionCall) -> Value {
|
|
let mut part = serde_json::json!({
|
|
"functionCall": {
|
|
"name": fc.name,
|
|
"args": fc.args,
|
|
}
|
|
});
|
|
if let Some(ref sig) = fc.thought_signature {
|
|
part["thoughtSignature"] = Value::String(sig.clone());
|
|
}
|
|
part
|
|
}
|
|
|
|
/// Modify a streamGenerateContent request body in-place.
|
|
/// Returns the modified JSON bytes, or None if modification wasn't possible.
|
|
pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec<u8>> {
|
|
let mut json: Value = serde_json::from_slice(body).ok()?;
|
|
let original_size = body.len();
|
|
let mut changes: Vec<String> = Vec::new();
|
|
|
|
// Each phase mutates `json` in place and appends to `changes`.
|
|
rewrite_system_instruction(&mut json, &mut changes);
|
|
strip_context_messages(&mut json, &mut changes);
|
|
replace_dummy_prompt(&mut json, tool_ctx, &mut changes);
|
|
manage_tools_and_history(&mut json, tool_ctx, &mut changes);
|
|
inject_thinking_config(&mut json, tool_ctx, &mut changes);
|
|
inject_generation_params(&mut json, tool_ctx, &mut changes);
|
|
inject_pending_image(&mut json, tool_ctx, &mut changes);
|
|
|
|
if changes.is_empty() {
|
|
return None;
|
|
}
|
|
|
|
let modified_bytes = serde_json::to_vec(&json).ok()?;
|
|
let saved = original_size as i64 - modified_bytes.len() as i64;
|
|
let pct = if original_size > 0 {
|
|
(saved as f64 / original_size as f64 * 100.0) as i32
|
|
} else {
|
|
0
|
|
};
|
|
info!(
|
|
original = original_size,
|
|
modified = modified_bytes.len(),
|
|
saved_bytes = saved,
|
|
saved_pct = pct,
|
|
"MITM: request modified [{}]",
|
|
changes.join(", ")
|
|
);
|
|
Some(modified_bytes)
|
|
}
|
|
|
|
// ─── modify_request sub-functions ────────────────────────────────────────────
|
|
|
|
/// Rewrite systemInstruction to CLIProxyAPI-style multi-part format.
|
|
///
|
|
/// Extracts `<identity>` block, builds:
|
|
/// part[0] = identity text
|
|
/// part[1] = "Please ignore following [ignore]<identity>[/ignore]"
|
|
/// part[2..] = remaining original parts
|
|
fn rewrite_system_instruction(json: &mut Value, changes: &mut Vec<String>) {
|
|
let sys = match json
|
|
.pointer_mut("/request/systemInstruction/parts/0/text")
|
|
.and_then(|v| v.as_str())
|
|
.map(|s| s.to_string())
|
|
{
|
|
Some(s) => s,
|
|
None => return,
|
|
};
|
|
let original_len = sys.len();
|
|
|
|
if let Some(identity_text) = extract_xml_section(&sys, "identity") {
|
|
let identity_clean = identity_text.trim().to_string();
|
|
let part0 = identity_clean.clone();
|
|
let part1 = format!(
|
|
"Please ignore following [ignore]{}[/ignore]",
|
|
identity_clean
|
|
);
|
|
|
|
let mut extra_parts: Vec<Value> = json
|
|
.pointer("/request/systemInstruction/parts")
|
|
.and_then(|v| v.as_array())
|
|
.map(|parts| parts.iter().skip(1).cloned().collect())
|
|
.unwrap_or_default();
|
|
|
|
let mut new_parts = vec![
|
|
serde_json::json!({"text": part0}),
|
|
serde_json::json!({"text": part1}),
|
|
];
|
|
new_parts.append(&mut extra_parts);
|
|
json["request"]["systemInstruction"]["parts"] = Value::Array(new_parts);
|
|
|
|
let new_len = part0.len() + part1.len();
|
|
if original_len > new_len {
|
|
changes.push(format!(
|
|
"system instruction: CLIProxyAPI-style rewrite ({original_len} → {new_len} chars)"
|
|
));
|
|
}
|
|
} else {
|
|
changes.push(format!(
|
|
"system instruction: cleared ({original_len} chars)"
|
|
));
|
|
json["request"]["systemInstruction"]["parts"][0]["text"] = Value::String(String::new());
|
|
}
|
|
}
|
|
|
|
/// Strip Antigravity-injected context messages and inline metadata.
|
|
///
|
|
/// Removes entire messages that are pure context (user_information, user_rules,
|
|
/// workflows, mcp_servers) and strips embedded metadata from remaining messages
|
|
/// (conversation summaries, ADDITIONAL_METADATA, EPHEMERAL_MESSAGE, cid markers,
|
|
/// Step Id prefixes, knowledge items). Also collapses excessive newlines.
|
|
fn strip_context_messages(json: &mut Value, changes: &mut Vec<String>) {
|
|
let contents = match json
|
|
.pointer_mut("/request/contents")
|
|
.and_then(|v| v.as_array_mut())
|
|
{
|
|
Some(c) => c,
|
|
None => return,
|
|
};
|
|
let before = contents.len();
|
|
|
|
// Phase 1: Remove whole messages that are pure context injection
|
|
contents.retain(|msg| {
|
|
// Always keep messages with image/binary data
|
|
if let Some(parts) = msg["parts"].as_array() {
|
|
if parts.iter().any(|p| p.get("inlineData").is_some()) {
|
|
return true;
|
|
}
|
|
}
|
|
if let Some(text) = msg["parts"][0]["text"].as_str() {
|
|
if text.starts_with("<user_information>")
|
|
|| text.starts_with("<user_rules>")
|
|
|| text.starts_with("<workflows>")
|
|
|| text.starts_with("<mcp_servers>")
|
|
{
|
|
return false;
|
|
}
|
|
}
|
|
true
|
|
});
|
|
|
|
// Phase 2: Strip embedded metadata from remaining messages
|
|
for msg in contents.iter_mut() {
|
|
let text = match msg["parts"][0]["text"].as_str().map(|s| s.to_string()) {
|
|
Some(t) => t,
|
|
None => continue,
|
|
};
|
|
let mut m = text.clone();
|
|
|
|
// Conversation summaries
|
|
if let Some(c) = strip_between(&m, "# Conversation History\n", "</conversation_summaries>")
|
|
{
|
|
m = c;
|
|
}
|
|
// <ADDITIONAL_METADATA> and <EPHEMERAL_MESSAGE>
|
|
if let Some(c) = strip_xml_section(&m, "ADDITIONAL_METADATA") {
|
|
m = c;
|
|
}
|
|
if let Some(c) = strip_xml_section(&m, "EPHEMERAL_MESSAGE") {
|
|
m = c;
|
|
}
|
|
|
|
// <cid:UUID> markers
|
|
while let Some(start) = m.find("<cid:") {
|
|
if let Some(end) = m[start..].find('>') {
|
|
m = format!("{}{}", &m[..start], &m[start + end + 1..]);
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
|
|
// "Step Id: N\n" prefixes
|
|
if m.starts_with("Step Id:") {
|
|
if let Some(nl) = m.find('\n') {
|
|
m = m[nl + 1..].to_string();
|
|
}
|
|
}
|
|
|
|
// Knowledge item blocks
|
|
if let Some(c) = strip_between(&m, "Here are the ", "</knowledge_item>") {
|
|
if c.len() < m.len() && m.contains("knowledge item") {
|
|
m = c;
|
|
}
|
|
}
|
|
|
|
let m = collapse_newlines(&m);
|
|
if m.len() < text.len() {
|
|
msg["parts"][0]["text"] = Value::String(m);
|
|
}
|
|
}
|
|
|
|
// Phase 3: Remove now-empty messages (preserve image parts)
|
|
contents.retain(|msg| {
|
|
if let Some(parts) = msg["parts"].as_array() {
|
|
if parts.iter().any(|p| p.get("inlineData").is_some()) {
|
|
return true;
|
|
}
|
|
}
|
|
msg["parts"][0]["text"]
|
|
.as_str()
|
|
.is_none_or(|t| !t.trim().is_empty())
|
|
});
|
|
|
|
let removed = before - contents.len();
|
|
if removed > 0 {
|
|
changes.push(format!("remove {removed}/{before} content messages"));
|
|
}
|
|
}
|
|
|
|
/// Replace dummy ".<cid:UUID>" prompt with real user text from the ToolContext.
|
|
///
|
|
/// The LS receives "." as the user prompt. Antigravity wraps it in
|
|
/// `<USER_REQUEST>...</USER_REQUEST>` tags. This function swaps the dot for the
|
|
/// actual user text before sending to Google.
|
|
fn replace_dummy_prompt(
|
|
json: &mut Value,
|
|
tool_ctx: Option<&ToolContext>,
|
|
changes: &mut Vec<String>,
|
|
) {
|
|
let ctx = match tool_ctx {
|
|
Some(c) if !c.pending_user_text.is_empty() => c,
|
|
_ => return,
|
|
};
|
|
let contents = match json
|
|
.pointer_mut("/request/contents")
|
|
.and_then(|v| v.as_array_mut())
|
|
{
|
|
Some(c) => c,
|
|
None => return,
|
|
};
|
|
|
|
for msg in contents.iter_mut() {
|
|
let is_user = msg
|
|
.get("role")
|
|
.and_then(|r| r.as_str())
|
|
.is_none_or(|r| r == "user");
|
|
if !is_user {
|
|
continue;
|
|
}
|
|
|
|
let text_val = match msg.pointer_mut("/parts/0/text") {
|
|
Some(v) => v,
|
|
None => continue,
|
|
};
|
|
let old = text_val.as_str().unwrap_or("");
|
|
|
|
let is_dot_in_wrapper = old.contains("<USER_REQUEST>")
|
|
&& extract_xml_section(old, "USER_REQUEST").is_some_and(|inner| {
|
|
let t = inner.trim();
|
|
t == "." || t.starts_with(".<cid:")
|
|
});
|
|
let is_bare_dot =
|
|
old.trim() == "." || (old.trim().starts_with(".<cid:") && old.trim().ends_with(">"));
|
|
|
|
if is_dot_in_wrapper {
|
|
*text_val = Value::String(format!(
|
|
"\n<USER_REQUEST>\n{}\n</USER_REQUEST>\n",
|
|
ctx.pending_user_text
|
|
));
|
|
changes.push(format!(
|
|
"replace dummy prompt in USER_REQUEST wrapper ({} chars)",
|
|
ctx.pending_user_text.len()
|
|
));
|
|
return;
|
|
} else if is_bare_dot {
|
|
*text_val = Value::String(ctx.pending_user_text.clone());
|
|
changes.push(format!(
|
|
"replace bare dummy prompt ({} chars)",
|
|
ctx.pending_user_text.len()
|
|
));
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Strip LS tools, inject client tools, clean up functionCall history, and
|
|
/// rewrite conversation history with tool call/response pairs.
|
|
fn manage_tools_and_history(
|
|
json: &mut Value,
|
|
tool_ctx: Option<&ToolContext>,
|
|
changes: &mut Vec<String>,
|
|
) {
|
|
let mut has_custom_tools = false;
|
|
|
|
// ── Strip LS tools, inject client tools ──────────────────────────────
|
|
if STRIP_ALL_TOOLS {
|
|
if let Some(tools) = json
|
|
.pointer_mut("/request/tools")
|
|
.and_then(|v| v.as_array_mut())
|
|
{
|
|
let count = tools.len();
|
|
if count > 0 {
|
|
tools.clear();
|
|
changes.push(format!("strip all {count} LS tools"));
|
|
}
|
|
|
|
if let Some(ctx) = tool_ctx {
|
|
if let Some(ref custom_tools) = ctx.tools {
|
|
for tool in custom_tools {
|
|
tools.push(tool.clone());
|
|
}
|
|
has_custom_tools = true;
|
|
changes.push(format!(
|
|
"inject {} custom tool group(s)",
|
|
custom_tools.len()
|
|
));
|
|
|
|
// Override VALIDATED → AUTO for custom tools
|
|
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
|
let has_validated = req
|
|
.get("toolConfig")
|
|
.and_then(|tc| tc.pointer("/functionCallingConfig/mode"))
|
|
.and_then(|m| m.as_str())
|
|
== Some("VALIDATED");
|
|
if has_validated {
|
|
req.insert(
|
|
"toolConfig".to_string(),
|
|
serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}),
|
|
);
|
|
changes.push("override toolConfig VALIDATED → AUTO".to_string());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ── Clean up when no tools remain ────────────────────────────────────
|
|
if STRIP_ALL_TOOLS && !has_custom_tools {
|
|
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
|
if req
|
|
.get("tools")
|
|
.and_then(|v| v.as_array())
|
|
.is_some_and(|a| a.is_empty())
|
|
{
|
|
req.remove("tools");
|
|
changes.push("remove empty tools array".to_string());
|
|
}
|
|
if req.remove("toolConfig").is_some() {
|
|
changes.push("remove toolConfig (no tools)".to_string());
|
|
}
|
|
}
|
|
}
|
|
|
|
// ── Strip old functionCall/functionResponse from history ──────────────
|
|
if STRIP_ALL_TOOLS {
|
|
let custom_tool_names: std::collections::HashSet<String> = tool_ctx
|
|
.as_ref()
|
|
.and_then(|ctx| ctx.tools.as_ref())
|
|
.map(|tools| {
|
|
tools
|
|
.iter()
|
|
.filter_map(|t| t["functionDeclarations"].as_array())
|
|
.flatten()
|
|
.filter_map(|decl| decl["name"].as_str().map(|s| s.to_string()))
|
|
.collect()
|
|
})
|
|
.unwrap_or_default();
|
|
|
|
if let Some(contents) = json
|
|
.pointer_mut("/request/contents")
|
|
.and_then(|v| v.as_array_mut())
|
|
{
|
|
let mut stripped_fc = 0usize;
|
|
for msg in contents.iter_mut() {
|
|
if let Some(parts) = msg.get_mut("parts").and_then(|v| v.as_array_mut()) {
|
|
let before = parts.len();
|
|
parts.retain(|part| {
|
|
if let Some(fc) = part.get("functionCall") {
|
|
return fc
|
|
.get("name")
|
|
.and_then(|v| v.as_str())
|
|
.is_some_and(|n| custom_tool_names.contains(n));
|
|
}
|
|
if let Some(fr) = part.get("functionResponse") {
|
|
return fr
|
|
.get("name")
|
|
.and_then(|v| v.as_str())
|
|
.is_some_and(|n| custom_tool_names.contains(n));
|
|
}
|
|
true
|
|
});
|
|
stripped_fc += before - parts.len();
|
|
}
|
|
}
|
|
contents.retain(|msg| {
|
|
msg.get("parts")
|
|
.and_then(|v| v.as_array())
|
|
.is_none_or(|p| !p.is_empty())
|
|
});
|
|
if stripped_fc > 0 {
|
|
changes.push(format!(
|
|
"strip {stripped_fc} functionCall/Response parts from history"
|
|
));
|
|
}
|
|
}
|
|
}
|
|
|
|
// ── Inject toolConfig if provided ────────────────────────────────────
|
|
if let Some(ctx) = tool_ctx {
|
|
if let Some(ref config) = ctx.tool_config {
|
|
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
|
req.insert("toolConfig".to_string(), config.clone());
|
|
changes.push("inject toolConfig".to_string());
|
|
}
|
|
}
|
|
}
|
|
|
|
// ── Rewrite conversation history for tool results ────────────────────
|
|
rewrite_tool_rounds(json, tool_ctx, changes);
|
|
}
|
|
|
|
/// Rewrite conversation history: replace placeholder model turns with real
|
|
/// functionCall parts and inject functionResponse user turns.
|
|
fn rewrite_tool_rounds(
|
|
json: &mut Value,
|
|
tool_ctx: Option<&ToolContext>,
|
|
changes: &mut Vec<String>,
|
|
) {
|
|
let ctx = match tool_ctx {
|
|
Some(c) => c,
|
|
None => return,
|
|
};
|
|
|
|
let rounds = if !ctx.tool_rounds.is_empty() {
|
|
ctx.tool_rounds.clone()
|
|
} else if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() {
|
|
vec![ToolRound {
|
|
calls: ctx.last_calls.clone(),
|
|
results: ctx.pending_results.clone(),
|
|
}]
|
|
} else {
|
|
return;
|
|
};
|
|
|
|
let contents = match json
|
|
.pointer_mut("/request/contents")
|
|
.and_then(|v| v.as_array_mut())
|
|
{
|
|
Some(c) => c,
|
|
None => return,
|
|
};
|
|
|
|
// Phase 1: find model turns with placeholder text
|
|
let mut rewrites: Vec<(usize, usize)> = Vec::new();
|
|
let mut round_idx = 0;
|
|
for (i, msg) in contents.iter().enumerate() {
|
|
if round_idx >= rounds.len() {
|
|
break;
|
|
}
|
|
if msg["role"].as_str() == Some("model") {
|
|
if let Some(text) = msg["parts"][0]["text"].as_str() {
|
|
if text.contains("Tool call completed")
|
|
|| text.contains("Awaiting external tool result")
|
|
{
|
|
rewrites.push((i, round_idx));
|
|
round_idx += 1;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Phase 2: apply rewrites
|
|
let mut insert_offset = 0;
|
|
for (content_idx, round_idx) in &rewrites {
|
|
let actual_idx = *content_idx + insert_offset;
|
|
let round = &rounds[*round_idx];
|
|
|
|
let fc_parts: Vec<Value> = round.calls.iter().map(build_function_call_part).collect();
|
|
contents[actual_idx]["parts"] = Value::Array(fc_parts);
|
|
|
|
if !round.results.is_empty() {
|
|
let fr_parts: Vec<Value> = round.results.iter()
|
|
.map(|r| serde_json::json!({"functionResponse": {"name": r.name, "response": r.result}}))
|
|
.collect();
|
|
contents.insert(
|
|
actual_idx + 1,
|
|
serde_json::json!({"role": "user", "parts": fr_parts}),
|
|
);
|
|
insert_offset += 1;
|
|
}
|
|
}
|
|
|
|
if !rewrites.is_empty() {
|
|
changes.push(format!(
|
|
"rewrite {} tool round(s) in history",
|
|
rewrites.len()
|
|
));
|
|
} else {
|
|
// Append as new messages (no existing model turns to rewrite)
|
|
let insert_pos = contents.len();
|
|
let mut offset = 0;
|
|
for round in &rounds {
|
|
let fc_parts: Vec<Value> = round.calls.iter().map(build_function_call_part).collect();
|
|
contents.insert(
|
|
insert_pos + offset,
|
|
serde_json::json!({"role": "model", "parts": fc_parts}),
|
|
);
|
|
offset += 1;
|
|
|
|
if !round.results.is_empty() {
|
|
let fr_parts: Vec<Value> = round.results.iter()
|
|
.map(|r| serde_json::json!({"functionResponse": {"name": r.name, "response": r.result}}))
|
|
.collect();
|
|
contents.insert(
|
|
insert_pos + offset,
|
|
serde_json::json!({"role": "user", "parts": fr_parts}),
|
|
);
|
|
offset += 1;
|
|
}
|
|
}
|
|
changes.push(format!(
|
|
"append {} tool round(s) as functionCall/Response pairs (no model turns found)",
|
|
rounds.len()
|
|
));
|
|
}
|
|
}
|
|
|
|
/// Inject `includeThoughts` and `thinkingLevel` into generationConfig.
|
|
fn inject_thinking_config(
|
|
json: &mut Value,
|
|
tool_ctx: Option<&ToolContext>,
|
|
changes: &mut Vec<String>,
|
|
) {
|
|
let reasoning_effort = tool_ctx
|
|
.and_then(|ctx| ctx.generation_params.as_ref())
|
|
.and_then(|gp| gp.reasoning_effort.clone());
|
|
|
|
// Helper: inject into a thinkingConfig object
|
|
let inject =
|
|
|tc: &mut serde_json::Map<String, Value>, changes: &mut Vec<String>, suffix: &str| {
|
|
if !tc.contains_key("includeThoughts") {
|
|
tc.insert("includeThoughts".to_string(), Value::Bool(true));
|
|
changes.push(format!("inject includeThoughts{suffix}"));
|
|
}
|
|
if let Some(ref effort) = reasoning_effort {
|
|
tc.insert("thinkingLevel".to_string(), Value::String(effort.clone()));
|
|
changes.push(format!("inject thinkingLevel={effort}{suffix}"));
|
|
}
|
|
};
|
|
|
|
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
|
let gc = req
|
|
.entry("generationConfig")
|
|
.or_insert_with(|| serde_json::json!({}));
|
|
if let Some(gc) = gc.as_object_mut() {
|
|
let tc = gc
|
|
.entry("thinkingConfig")
|
|
.or_insert_with(|| serde_json::json!({}));
|
|
if let Some(tc) = tc.as_object_mut() {
|
|
inject(tc, changes, "");
|
|
}
|
|
}
|
|
} else if let Some(o) = json.as_object_mut() {
|
|
let gc = o
|
|
.entry("generationConfig")
|
|
.or_insert_with(|| serde_json::json!({}));
|
|
if let Some(gc) = gc.as_object_mut() {
|
|
let tc = gc
|
|
.entry("thinkingConfig")
|
|
.or_insert_with(|| serde_json::json!({}));
|
|
if let Some(tc) = tc.as_object_mut() {
|
|
inject(tc, changes, " (top-level)");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Inject client-specified generation parameters (temperature, topP, etc.).
|
|
fn inject_generation_params(
|
|
json: &mut Value,
|
|
tool_ctx: Option<&ToolContext>,
|
|
changes: &mut Vec<String>,
|
|
) {
|
|
let gp = match tool_ctx.and_then(|ctx| ctx.generation_params.as_ref()) {
|
|
Some(gp) => gp,
|
|
None => return,
|
|
};
|
|
|
|
let gc = if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
|
Some(
|
|
req.entry("generationConfig")
|
|
.or_insert_with(|| serde_json::json!({})),
|
|
)
|
|
} else {
|
|
json.as_object_mut().map(|o| {
|
|
o.entry("generationConfig")
|
|
.or_insert_with(|| serde_json::json!({}))
|
|
})
|
|
};
|
|
|
|
let gc = match gc.and_then(|v| v.as_object_mut()) {
|
|
Some(gc) => gc,
|
|
None => return,
|
|
};
|
|
|
|
let mut injected: Vec<String> = Vec::new();
|
|
if let Some(t) = gp.temperature {
|
|
gc.insert("temperature".into(), serde_json::json!(t));
|
|
injected.push(format!("temperature={t}"));
|
|
}
|
|
if let Some(p) = gp.top_p {
|
|
gc.insert("topP".into(), serde_json::json!(p));
|
|
injected.push(format!("topP={p}"));
|
|
}
|
|
if let Some(k) = gp.top_k {
|
|
gc.insert("topK".into(), serde_json::json!(k));
|
|
injected.push(format!("topK={k}"));
|
|
}
|
|
if let Some(m) = gp.max_output_tokens {
|
|
gc.insert("maxOutputTokens".into(), serde_json::json!(m));
|
|
injected.push(format!("maxOutputTokens={m}"));
|
|
}
|
|
if let Some(ref seqs) = gp.stop_sequences {
|
|
gc.insert("stopSequences".into(), serde_json::json!(seqs));
|
|
injected.push(format!("stopSequences({})", seqs.len()));
|
|
}
|
|
if let Some(fp) = gp.frequency_penalty {
|
|
gc.insert("frequencyPenalty".into(), serde_json::json!(fp));
|
|
injected.push(format!("frequencyPenalty={fp}"));
|
|
}
|
|
if let Some(pp) = gp.presence_penalty {
|
|
gc.insert("presencePenalty".into(), serde_json::json!(pp));
|
|
injected.push(format!("presencePenalty={pp}"));
|
|
}
|
|
if let Some(ref mime) = gp.response_mime_type {
|
|
gc.insert("responseMimeType".into(), serde_json::json!(mime));
|
|
injected.push(format!("responseMimeType={mime}"));
|
|
}
|
|
if let Some(ref schema) = gp.response_schema {
|
|
gc.insert("responseSchema".into(), schema.clone());
|
|
injected.push("responseSchema=<schema>".to_string());
|
|
}
|
|
|
|
if !injected.is_empty() {
|
|
changes.push(format!("inject generationConfig: {}", injected.join(", ")));
|
|
}
|
|
}
|
|
|
|
/// Inject a pending image as inlineData into the last user message.
|
|
fn inject_pending_image(
|
|
json: &mut Value,
|
|
tool_ctx: Option<&ToolContext>,
|
|
changes: &mut Vec<String>,
|
|
) {
|
|
let img = match tool_ctx.and_then(|ctx| ctx.pending_image.as_ref()) {
|
|
Some(img) => img,
|
|
None => return,
|
|
};
|
|
let contents = match json
|
|
.pointer_mut("/request/contents")
|
|
.and_then(|v| v.as_array_mut())
|
|
{
|
|
Some(c) => c,
|
|
None => return,
|
|
};
|
|
|
|
for msg in contents.iter_mut().rev() {
|
|
if msg["role"].as_str() != Some("user") {
|
|
continue;
|
|
}
|
|
if let Some(parts) = msg.get_mut("parts").and_then(|v| v.as_array_mut()) {
|
|
parts.push(serde_json::json!({
|
|
"inlineData": { "mimeType": img.mime_type, "data": img.base64_data }
|
|
}));
|
|
changes.push(format!(
|
|
"inject image ({}; {} bytes base64)",
|
|
img.mime_type,
|
|
img.base64_data.len()
|
|
));
|
|
return;
|
|
}
|
|
}
|
|
tracing::warn!("MITM: pending image but no user message found to inject into");
|
|
}
|
|
|
|
/// Extract the inner text of an XML-style section.
|
|
fn extract_xml_section(text: &str, tag: &str) -> Option<String> {
|
|
let open = format!("<{tag}>");
|
|
let close = format!("</{tag}>");
|
|
let start = text.find(&open)?;
|
|
let end = text.find(&close)?;
|
|
let inner_start = start + open.len();
|
|
if inner_start >= end {
|
|
return None;
|
|
}
|
|
Some(text[inner_start..end].to_string())
|
|
}
|
|
|
|
/// Strip an XML-style section and return the modified text.
|
|
fn strip_xml_section(text: &str, tag: &str) -> Option<String> {
|
|
let open = format!("<{tag}>");
|
|
let close = format!("</{tag}>");
|
|
let start = text.find(&open)?;
|
|
let end = text.find(&close)?;
|
|
let end_pos = end + close.len();
|
|
Some(format!("{}{}", &text[..start], &text[end_pos..]))
|
|
}
|
|
|
|
/// Strip everything between two markers (inclusive of markers).
|
|
fn strip_between(text: &str, start_marker: &str, end_marker: &str) -> Option<String> {
|
|
let start = text.find(start_marker)?;
|
|
let end = text.find(end_marker)?;
|
|
let end_pos = end + end_marker.len();
|
|
// Skip any trailing whitespace after end marker
|
|
let rest = text[end_pos..].trim_start();
|
|
Some(format!("{}{}", &text[..start], rest))
|
|
}
|
|
|
|
/// Collapse 3+ consecutive newlines into 2.
|
|
fn collapse_newlines(text: &str) -> String {
|
|
static RE: std::sync::LazyLock<Regex> =
|
|
std::sync::LazyLock::new(|| Regex::new(r"\n{3,}").unwrap());
|
|
RE.replace_all(text, "\n\n").to_string()
|
|
}
|
|
|
|
/// Dechunk an HTTP chunked-encoded body into raw bytes.
|
|
pub fn dechunk(data: &[u8]) -> Vec<u8> {
|
|
let mut result = Vec::with_capacity(data.len());
|
|
let mut pos = 0;
|
|
|
|
while pos < data.len() {
|
|
let line_end = match data[pos..].windows(2).position(|w| w == b"\r\n") {
|
|
Some(p) => pos + p,
|
|
None => break,
|
|
};
|
|
|
|
let size_str = std::str::from_utf8(&data[pos..line_end])
|
|
.unwrap_or("")
|
|
.split(';')
|
|
.next()
|
|
.unwrap_or("")
|
|
.trim();
|
|
|
|
let chunk_size = match usize::from_str_radix(size_str, 16) {
|
|
Ok(0) => break,
|
|
Ok(n) => n,
|
|
Err(_) => break,
|
|
};
|
|
|
|
let data_start = line_end + 2;
|
|
let data_end = (data_start + chunk_size).min(data.len());
|
|
result.extend_from_slice(&data[data_start..data_end]);
|
|
|
|
pos = data_end + 2;
|
|
}
|
|
|
|
result
|
|
}
|
|
|
|
/// Re-encode data as a single HTTP chunk + terminal chunk.
|
|
pub fn rechunk(data: &[u8]) -> Vec<u8> {
|
|
let hex_size = format!("{:x}", data.len());
|
|
let mut result = Vec::with_capacity(hex_size.len() + 2 + data.len() + 2 + 5);
|
|
result.extend_from_slice(hex_size.as_bytes());
|
|
result.extend_from_slice(b"\r\n");
|
|
result.extend_from_slice(data);
|
|
result.extend_from_slice(b"\r\n0\r\n\r\n");
|
|
result
|
|
}
|
|
|
|
// ── OpenAI → Gemini format conversion ────────────────────────────────────────
|
|
|
|
/// Convert OpenAI tool definitions to Gemini functionDeclarations format.
|
|
///
|
|
/// OpenAI: `[{"type":"function","function":{"name":"X","description":"Y","parameters":{...}}}]`
|
|
/// Gemini: `[{"functionDeclarations":[{"name":"X","description":"Y","parameters":{...}}]}]`
|
|
pub fn openai_tools_to_gemini(tools: &[Value]) -> Vec<Value> {
|
|
let declarations: Vec<Value> = tools
|
|
.iter()
|
|
.filter(|t| t["type"].as_str() == Some("function"))
|
|
.filter_map(|t| {
|
|
let func = t.get("function")?;
|
|
let mut decl = serde_json::json!({
|
|
"name": func["name"],
|
|
"description": func["description"],
|
|
});
|
|
if let Some(params) = func.get("parameters") {
|
|
let cleaned = clean_schema_for_gemini(uppercase_types(params.clone()));
|
|
decl["parameters"] = cleaned;
|
|
}
|
|
Some(decl)
|
|
})
|
|
.collect();
|
|
|
|
if declarations.is_empty() {
|
|
return vec![];
|
|
}
|
|
|
|
vec![serde_json::json!({"functionDeclarations": declarations})]
|
|
}
|
|
|
|
/// Recursively strip JSON Schema fields that Google's Gemini API doesn't accept.
|
|
/// Known unsupported: $schema, additionalProperties, $ref, $defs, default, examples
|
|
fn clean_schema_for_gemini(mut val: Value) -> Value {
|
|
const STRIP_KEYS: &[&str] = &[
|
|
"$schema",
|
|
"additionalProperties",
|
|
"$ref",
|
|
"$defs",
|
|
"default",
|
|
"examples",
|
|
"title",
|
|
];
|
|
|
|
match &mut val {
|
|
Value::Object(map) => {
|
|
for key in STRIP_KEYS {
|
|
map.remove(*key);
|
|
}
|
|
let keys: Vec<String> = map.keys().cloned().collect();
|
|
for key in keys {
|
|
if let Some(v) = map.remove(&key) {
|
|
map.insert(key, clean_schema_for_gemini(v));
|
|
}
|
|
}
|
|
}
|
|
Value::Array(arr) => {
|
|
for v in arr.iter_mut() {
|
|
*v = clean_schema_for_gemini(std::mem::take(v));
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
val
|
|
}
|
|
|
|
/// Convert OpenAI tool_choice to Gemini toolConfig format.
|
|
///
|
|
/// OpenAI: "auto" | "required" | "none" | {"type":"function","function":{"name":"X"}}
|
|
/// Gemini: {"functionCallingConfig":{"mode":"AUTO|ANY|NONE","allowedFunctionNames":[...]}}
|
|
pub fn openai_tool_choice_to_gemini(choice: &Value) -> Value {
|
|
match choice {
|
|
Value::String(s) => match s.as_str() {
|
|
"auto" => serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}),
|
|
"required" => serde_json::json!({"functionCallingConfig": {"mode": "ANY"}}),
|
|
"none" => serde_json::json!({"functionCallingConfig": {"mode": "NONE"}}),
|
|
_ => serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}),
|
|
},
|
|
Value::Object(obj) => {
|
|
if let Some(name) = obj.get("function").and_then(|f| f["name"].as_str()) {
|
|
serde_json::json!({
|
|
"functionCallingConfig": {
|
|
"mode": "ANY",
|
|
"allowedFunctionNames": [name]
|
|
}
|
|
})
|
|
} else {
|
|
serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}})
|
|
}
|
|
}
|
|
_ => serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}),
|
|
}
|
|
}
|
|
|
|
/// Recursively convert JSON Schema type strings to uppercase (Gemini format).
|
|
/// "object" → "OBJECT", "string" → "STRING", etc.
|
|
fn uppercase_types(mut val: Value) -> Value {
|
|
match &mut val {
|
|
Value::Object(map) => {
|
|
if let Some(t) = map
|
|
.get("type")
|
|
.and_then(|v| v.as_str())
|
|
.map(|s| s.to_uppercase())
|
|
{
|
|
map.insert("type".to_string(), Value::String(t));
|
|
}
|
|
let keys: Vec<String> = map.keys().cloned().collect();
|
|
for key in keys {
|
|
if let Some(v) = map.remove(&key) {
|
|
map.insert(key, uppercase_types(v));
|
|
}
|
|
}
|
|
}
|
|
Value::Array(arr) => {
|
|
for v in arr.iter_mut() {
|
|
*v = uppercase_types(std::mem::take(v));
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
val
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_dechunk_basic() {
|
|
let chunked = b"5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
|
|
let result = dechunk(chunked);
|
|
assert_eq!(result, b"hello world");
|
|
}
|
|
|
|
#[test]
|
|
fn test_dechunk_single() {
|
|
let chunked = b"b\r\nhello world\r\n0\r\n\r\n";
|
|
let result = dechunk(chunked);
|
|
assert_eq!(result, b"hello world");
|
|
}
|
|
|
|
#[test]
|
|
fn test_rechunk() {
|
|
let data = b"hello world";
|
|
let chunked = rechunk(data);
|
|
let expected = b"b\r\nhello world\r\n0\r\n\r\n";
|
|
assert_eq!(chunked, expected);
|
|
}
|
|
|
|
#[test]
|
|
fn test_dechunk_rechunk_roundtrip() {
|
|
let original = b"5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
|
|
let data = dechunk(original);
|
|
let rechunked = rechunk(&data);
|
|
let data2 = dechunk(&rechunked);
|
|
assert_eq!(data, data2);
|
|
}
|
|
|
|
#[test]
|
|
fn test_modify_strips_all_tools() {
|
|
let body = serde_json::json!({
|
|
"project": "test",
|
|
"requestId": "test/1",
|
|
"request": {
|
|
"contents": [{"role": "user", "parts": [{"text": "hello"}]}],
|
|
"tools": [
|
|
{"functionDeclarations": [{"name": "view_file", "description": "view", "parameters": {}}]},
|
|
{"functionDeclarations": [{"name": "browser_subagent", "description": "browse", "parameters": {}}]},
|
|
],
|
|
"generationConfig": {}
|
|
},
|
|
"model": "test"
|
|
});
|
|
|
|
let bytes = serde_json::to_vec(&body).unwrap();
|
|
let modified = modify_request(&bytes, None).unwrap();
|
|
let result: Value = serde_json::from_slice(&modified).unwrap();
|
|
|
|
// With no ToolContext, tools should be removed entirely
|
|
assert!(
|
|
result["request"]["tools"].is_null() || result.pointer("/request/tools").is_none(),
|
|
"tools should be removed when no custom tools provided"
|
|
);
|
|
}
|
|
|
|
#[test]
fn test_modify_keeps_only_identity() {
    // System prompt with the <identity> section followed by sections the
    // rewrite is expected to discard.
    let sys_text = concat!(
        "<identity>\nYou are a helpful AI.\n</identity>\n\n",
        "<tool_calling>\nUse absolute paths.\n</tool_calling>\n",
        "<web_application_development>\nlots of web dev stuff\n</web_application_development>\n",
        "<communication_style>\nbe helpful\n</communication_style>",
    );
    let body = serde_json::json!({
        "project": "test",
        "requestId": "test/1",
        "request": {
            "contents": [{"role": "user", "parts": [{"text": "hello"}]}],
            "systemInstruction": {"parts": [{"text": sys_text}]},
            "tools": [],
            "generationConfig": {}
        },
        "model": "test"
    });

    let modified = modify_request(&serde_json::to_vec(&body).unwrap(), None).unwrap();
    let result: Value = serde_json::from_slice(&modified).unwrap();

    // The rewrite yields a two-part system instruction: the identity text
    // (tags removed) followed by an "[ignore]" wrapper.
    let sys_parts = result
        .pointer("/request/systemInstruction/parts")
        .and_then(Value::as_array)
        .unwrap();
    assert_eq!(sys_parts.len(), 2, "should have identity + ignore wrapper");

    let texts: Vec<&str> = sys_parts
        .iter()
        .map(|part| part["text"].as_str().unwrap())
        .collect();
    let (identity, wrapper) = (texts[0], texts[1]);

    assert!(identity.contains("You are a helpful AI."));
    assert!(wrapper.contains("[ignore]"));
    for dropped in ["tool_calling", "web_application_development", "communication_style"] {
        assert!(!identity.contains(dropped));
    }
}
|
|
|
|
#[test]
fn test_modify_strips_context_messages() {
    let user_turn =
        |text: &str| serde_json::json!({"role": "user", "parts": [{"text": text}]});

    // Three pure-context user turns, the real request (wrapped in LS
    // scaffolding), and a model reply.
    let conversation = vec![
        user_turn("<user_information>\nLinux\n</user_information>"),
        user_turn("<user_rules>\nno rules\n</user_rules>"),
        user_turn("<workflows>\nsome workflows\n</workflows>"),
        user_turn("Step Id: 0\n\n<USER_REQUEST>\nSay hello\n</USER_REQUEST>\n<ADDITIONAL_METADATA>\ncursor stuff\n</ADDITIONAL_METADATA>"),
        serde_json::json!({"role": "model", "parts": [{"text": "Hello!"}]}),
    ];
    let body = serde_json::json!({
        "project": "test",
        "requestId": "test/1",
        "request": {
            "contents": conversation,
            "tools": [],
            "generationConfig": {}
        },
        "model": "test"
    });

    let modified = modify_request(&serde_json::to_vec(&body).unwrap(), None).unwrap();
    let result: Value = serde_json::from_slice(&modified).unwrap();

    // The three context turns are dropped; the request turn (with its
    // ADDITIONAL_METADATA stripped) and the model reply remain.
    let kept = result
        .pointer("/request/contents")
        .and_then(Value::as_array)
        .unwrap();
    assert_eq!(
        kept.len(),
        2,
        "should keep only user request + model response"
    );

    let user_msg = kept[0].pointer("/parts/0/text").and_then(Value::as_str).unwrap();
    assert!(user_msg.contains("Say hello"), "should keep user request");
    assert!(!user_msg.contains("ADDITIONAL_METADATA"), "should strip metadata");
    assert!(!user_msg.contains("cursor stuff"), "should strip cursor info");
    assert!(!user_msg.starts_with("Step Id:"), "should strip step id");

    // The model reply is passed through intact.
    let model_msg = kept[1].pointer("/parts/0/text").and_then(Value::as_str).unwrap();
    assert_eq!(model_msg, "Hello!");
}
|
|
|
|
#[test]
fn test_extract_xml_section() {
    // Only the text between the tags is returned, surrounding newlines included.
    let extracted =
        extract_xml_section("before <identity>\nI am AI\n</identity> after", "identity");
    assert_eq!(extracted.unwrap(), "\nI am AI\n");
}
|
|
|
|
#[test]
fn test_strip_xml_section() {
    // The whole tagged section (tags and contents) is removed.
    let stripped = strip_xml_section("before <META>\nstuff\n</META> after", "META");
    assert_eq!(stripped.unwrap(), "before after");
}
|
|
|
|
#[test]
fn test_strip_between() {
    let start_marker = "# Conversation History\n";
    let end_marker = "</conversation_summaries>";
    let text = concat!(
        "keep this ",
        "# Conversation History\n",
        "lots of stuff\n",
        "</conversation_summaries>",
        "\nand this"
    );

    // Everything from the start marker through the end marker is removed.
    let stripped = strip_between(text, start_marker, end_marker);
    assert_eq!(stripped.unwrap(), "keep this and this");
}
|
|
|
|
#[test]
fn test_multi_round_history_rewrite() {
    use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};

    /// Name stored under `parts[0].<kind>.name` of a conversation turn.
    fn first_part_name<'a>(turn: &'a Value, kind: &str) -> &'a str {
        turn["parts"][0][kind]["name"].as_str().unwrap()
    }

    /// One tool round: a single call to `name` with `args`, answered by `result`.
    fn round(name: &str, args: Value, result: Value) -> ToolRound {
        ToolRound {
            calls: vec![CapturedFunctionCall {
                name: name.to_string(),
                args,
                thought_signature: None,
                captured_at: 0,
            }],
            results: vec![PendingToolResult {
                name: name.to_string(),
                result,
            }],
        }
    }

    // Two rounds of tool use as the LS records them:
    // user → model("Tool call completed") → user(text) → model("Tool call completed") → user(text)
    let placeholder = "Tool call completed. Awaiting external tool result.";
    let body = serde_json::json!({
        "project": "test",
        "requestId": "test/1",
        "request": {
            "contents": [
                {"role": "user", "parts": [{"text": "Read foo and write to bar"}]},
                {"role": "model", "parts": [{"text": placeholder}]},
                {"role": "user", "parts": [{"text": "[Tool result: file contents here]"}]},
                {"role": "model", "parts": [{"text": placeholder}]},
                {"role": "user", "parts": [{"text": "[Tool result: write success]"}]},
            ],
            "tools": [],
            "generationConfig": {}
        },
        "model": "test"
    });

    let declarations = serde_json::json!({
        "functionDeclarations": [{
            "name": "read_file",
            "description": "Read a file",
            "parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}}}
        }, {
            "name": "write_file",
            "description": "Write a file",
            "parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}, "content": {"type": "STRING"}}}
        }]
    });
    let tool_ctx = ToolContext {
        pending_user_text: String::new(),
        tools: Some(vec![declarations]),
        tool_config: None,
        pending_results: vec![],
        last_calls: vec![],
        generation_params: None,
        pending_image: None,
        tool_rounds: vec![
            round(
                "read_file",
                serde_json::json!({"path": "/foo"}),
                serde_json::json!({"content": "file contents here"}),
            ),
            round(
                "write_file",
                serde_json::json!({"path": "/bar", "content": "data"}),
                serde_json::json!({"result": "ok"}),
            ),
        ],
    };

    let modified =
        modify_request(&serde_json::to_vec(&body).unwrap(), Some(&tool_ctx)).unwrap();
    let result: Value = serde_json::from_slice(&modified).unwrap();
    let contents = result
        .pointer("/request/contents")
        .and_then(Value::as_array)
        .unwrap();

    // Expected layout after rewrite:
    // [0] user: "Read foo..."
    // [1] model: functionCall(read_file)     (was "Tool call completed")
    // [2] user: functionResponse(read_file)  (injected)
    // [3] user: "[Tool result: file contents]" (original LS turn)
    // [4] model: functionCall(write_file)    (was "Tool call completed")
    // [5] user: functionResponse(write_file) (injected)
    // [6] user: "[Tool result: write success]" (original LS turn)
    assert_eq!(
        contents.len(),
        7,
        "should have 7 turns (5 original + 2 injected)"
    );

    // Round 1: placeholder turn became a functionCall, followed by an
    // injected user functionResponse.
    assert_eq!(first_part_name(&contents[1], "functionCall"), "read_file");
    assert_eq!(
        contents[1]
            .pointer("/parts/0/functionCall/args/path")
            .and_then(Value::as_str)
            .unwrap(),
        "/foo"
    );
    assert_eq!(contents[2]["role"].as_str().unwrap(), "user");
    assert_eq!(first_part_name(&contents[2], "functionResponse"), "read_file");

    // Round 2: same shape for write_file.
    assert_eq!(first_part_name(&contents[4], "functionCall"), "write_file");
    assert_eq!(first_part_name(&contents[5], "functionResponse"), "write_file");
}
|
|
|
|
#[test]
fn test_single_round_legacy_fallback() {
    use crate::mitm::store::{CapturedFunctionCall, PendingToolResult};

    /// Name stored under `parts[0].<kind>.name` of a conversation turn.
    fn first_part_name<'a>(turn: &'a Value, kind: &str) -> &'a str {
        turn["parts"][0][kind]["name"].as_str().unwrap()
    }

    // One round of tool use expressed only through the legacy
    // `last_calls` / `pending_results` fields (no `tool_rounds`).
    // This is the path used by responses.rs.
    let body = serde_json::json!({
        "project": "test",
        "requestId": "test/1",
        "request": {
            "contents": [
                {"role": "user", "parts": [{"text": "Search for X"}]},
                {"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]},
                {"role": "user", "parts": [{"text": "[Tool result: found X]"}]},
            ],
            "tools": [],
            "generationConfig": {}
        },
        "model": "test"
    });

    let search_decl = serde_json::json!({
        "functionDeclarations": [{
            "name": "search",
            "description": "Search",
            "parameters": {"type": "OBJECT", "properties": {"q": {"type": "STRING"}}}
        }]
    });
    let legacy_call = CapturedFunctionCall {
        name: "search".to_string(),
        args: serde_json::json!({"q": "X"}),
        thought_signature: None,
        captured_at: 0,
    };
    let legacy_result = PendingToolResult {
        name: "search".to_string(),
        result: serde_json::json!({"results": ["x"]}),
    };
    let tool_ctx = ToolContext {
        pending_user_text: String::new(),
        tools: Some(vec![search_decl]),
        tool_config: None,
        pending_results: vec![legacy_result],
        last_calls: vec![legacy_call],
        generation_params: None,
        pending_image: None,
        tool_rounds: vec![], // Empty — forces legacy fallback
    };

    let modified =
        modify_request(&serde_json::to_vec(&body).unwrap(), Some(&tool_ctx)).unwrap();
    let result: Value = serde_json::from_slice(&modified).unwrap();
    let contents = result
        .pointer("/request/contents")
        .and_then(Value::as_array)
        .unwrap();

    // Legacy data still drives the rewrite: the placeholder model turn becomes
    // a functionCall and a functionResponse turn is injected after it.
    assert_eq!(
        contents.len(),
        4,
        "should have 4 turns (3 original + 1 injected)"
    );
    assert_eq!(first_part_name(&contents[1], "functionCall"), "search");
    assert_eq!(first_part_name(&contents[2], "functionResponse"), "search");
}
|
|
|
|
#[test]
fn test_no_tool_rounds_no_rewrite() {
    // Tools are declared, but there are no tool rounds and no legacy
    // calls/results, so the conversation must pass through without rewriting.
    let body = serde_json::json!({
        "project": "test",
        "requestId": "test/1",
        "request": {
            "contents": [
                {"role": "user", "parts": [{"text": "Hello"}]},
                {"role": "model", "parts": [{"text": "Hi there!"}]},
            ],
            "tools": [],
            "generationConfig": {}
        },
        "model": "test"
    });

    let noop_decl = serde_json::json!({
        "functionDeclarations": [{
            "name": "noop",
            "description": "Does nothing",
            "parameters": {"type": "OBJECT", "properties": {}}
        }]
    });
    let tool_ctx = ToolContext {
        pending_user_text: String::new(),
        tools: Some(vec![noop_decl]),
        tool_config: None,
        pending_results: vec![],
        last_calls: vec![],
        generation_params: None,
        pending_image: None,
        tool_rounds: vec![],
    };

    let modified =
        modify_request(&serde_json::to_vec(&body).unwrap(), Some(&tool_ctx)).unwrap();
    let result: Value = serde_json::from_slice(&modified).unwrap();
    let contents = result
        .pointer("/request/contents")
        .and_then(Value::as_array)
        .unwrap();

    // Same number of turns, and the model reply is untouched.
    assert_eq!(contents.len(), 2);
    let reply = contents[1].pointer("/parts/0/text").and_then(Value::as_str).unwrap();
    assert_eq!(reply, "Hi there!");
}
|
|
|
|
#[test]
fn test_tool_rounds_append_when_no_model_turns() {
    use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};

    /// Name stored under `parts[0].<kind>.name` of a conversation turn.
    fn first_part_name<'a>(turn: &'a Value, kind: &str) -> &'a str {
        turn["parts"][0][kind]["name"].as_str().unwrap()
    }

    // Real-world case: the LS sends cascades containing ONLY user messages.
    // With no model turn to rewrite, the fallback must APPEND the
    // functionCall/functionResponse pair instead.
    let body = serde_json::json!({
        "project": "test",
        "requestId": "test/1",
        "request": {
            "contents": [
                {"role": "user", "parts": [{"text": "hello"}]},
            ],
            "tools": [],
            "generationConfig": {}
        },
        "model": "test"
    });

    let search_round = ToolRound {
        calls: vec![CapturedFunctionCall {
            name: "web_search".to_string(),
            args: serde_json::json!({"query": "rust news"}),
            thought_signature: None,
            captured_at: 0,
        }],
        results: vec![PendingToolResult {
            name: "web_search".to_string(),
            result: serde_json::json!({"results": "some results"}),
        }],
    };
    let search_decl = serde_json::json!({
        "functionDeclarations": [{
            "name": "web_search",
            "description": "Search the web",
            "parameters": {"type": "OBJECT", "properties": {"query": {"type": "STRING"}}}
        }]
    });
    let tool_ctx = ToolContext {
        pending_user_text: String::new(),
        tools: Some(vec![search_decl]),
        tool_config: None,
        pending_results: vec![],
        last_calls: vec![],
        generation_params: None,
        pending_image: None,
        tool_rounds: vec![search_round],
    };

    let modified =
        modify_request(&serde_json::to_vec(&body).unwrap(), Some(&tool_ctx)).unwrap();
    let result: Value = serde_json::from_slice(&modified).unwrap();
    let contents = result
        .pointer("/request/contents")
        .and_then(Value::as_array)
        .unwrap();

    // Expected layout (tool round appended AFTER the user message):
    // [0] user: "hello"                     ← original
    // [1] model: functionCall(web_search)   ← appended after user
    // [2] user: functionResponse(web_search) ← appended after functionCall
    assert_eq!(contents.len(), 3, "should have 3 turns: user + fc + fr");

    let roles: Vec<&str> = contents
        .iter()
        .map(|turn| turn["role"].as_str().unwrap())
        .collect();

    assert_eq!(roles[0], "user");
    let first_text = contents[0].pointer("/parts/0/text").and_then(Value::as_str).unwrap();
    assert!(first_text.contains("hello"));

    assert_eq!(roles[1], "model");
    assert_eq!(first_part_name(&contents[1], "functionCall"), "web_search");

    assert_eq!(roles[2], "user");
    assert_eq!(first_part_name(&contents[2], "functionResponse"), "web_search");
}
|
|
}
|
|
|
|
// ─── Response modification ──────────────────────────────────────────────────
|
|
|
|
/// Rewrite a parsed SSE JSON object: replace `functionCall` parts with text
|
|
/// placeholder and normalize `finishReason` to `STOP`.
|
|
///
|
|
/// Used by `ResponseRewriter` to hide tool calls from the LS.
|
|
fn rewrite_function_calls_in_response(json: &mut Value) -> bool {
|
|
let mut changed = false;
|
|
|
|
fn rewrite_candidates(candidates: &mut [Value]) -> bool {
|
|
let mut changed = false;
|
|
for candidate in candidates.iter_mut() {
|
|
if let Some(parts) = candidate
|
|
.pointer_mut("/content/parts")
|
|
.and_then(|v| v.as_array_mut())
|
|
{
|
|
for part in parts.iter_mut() {
|
|
if part.get("functionCall").is_some() {
|
|
*part = serde_json::json!({
|
|
"text": "Tool call completed. Awaiting external tool result."
|
|
});
|
|
changed = true;
|
|
}
|
|
}
|
|
}
|
|
if let Some(reason) = candidate.get("finishReason").and_then(|v| v.as_str()) {
|
|
if reason != "STOP" {
|
|
candidate["finishReason"] = Value::String("STOP".to_string());
|
|
changed = true;
|
|
}
|
|
}
|
|
}
|
|
changed
|
|
}
|
|
|
|
if let Some(candidates) = json.get_mut("candidates").and_then(|v| v.as_array_mut()) {
|
|
changed |= rewrite_candidates(candidates);
|
|
}
|
|
if let Some(candidates) = json
|
|
.pointer_mut("/response/candidates")
|
|
.and_then(|v| v.as_array_mut())
|
|
{
|
|
changed |= rewrite_candidates(candidates);
|
|
}
|
|
|
|
changed
|
|
}
|
|
|
|
// ─── ResponseRewriter ────────────────────────────────────────────────────────
|
|
|
|
/// Stateful line-buffered response rewriter.
|
|
///
|
|
/// `modify_response_chunk` is stateless per-TCP-chunk — if a `functionCall`
|
|
/// JSON event spans two reads, the quick `contains("functionCall")` check
|
|
/// fails and the raw bytes leak to the LS. This struct solves that by
|
|
/// accumulating raw response bytes and only forwarding complete
|
|
/// newline-terminated SSE lines, rewriting any that contain `functionCall`.
|
|
///
|
|
/// This mirrors exactly how `parse_streaming_chunk` / `StreamingAccumulator`
|
|
/// handles cross-chunk JSON reassembly.
|
|
#[derive(Debug, Default)]
|
|
pub struct ResponseRewriter {
|
|
/// Buffered data waiting for a complete `\n`-terminated line.
|
|
pending: String,
|
|
}
|
|
|
|
impl ResponseRewriter {
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
/// Feed raw response bytes, get back bytes safe to forward to the LS.
|
|
///
|
|
/// Complete lines are rewritten if they contain `functionCall`, then
|
|
/// returned. Partial lines stay buffered until the next `feed()` call.
|
|
pub fn feed(&mut self, chunk: &[u8]) -> Vec<u8> {
|
|
let text = String::from_utf8_lossy(chunk);
|
|
self.pending.push_str(&text);
|
|
|
|
let mut output = String::new();
|
|
|
|
// Extract all complete lines (terminated by \n)
|
|
while let Some(pos) = self.pending.find('\n') {
|
|
// Include the \n in the extracted line
|
|
let line = self.pending[..=pos].to_string();
|
|
self.pending = self.pending[pos + 1..].to_string();
|
|
|
|
// Check if this is a `data: {JSON}` SSE line containing functionCall
|
|
let trimmed = line.trim();
|
|
if trimmed.starts_with("data: {") && trimmed.contains("functionCall") {
|
|
// Extract JSON, rewrite, and rebuild the line
|
|
if let Some(data_start) = line.find("data: {") {
|
|
let json_start = data_start + 6; // skip "data: "
|
|
let json_str = line[json_start..].trim_end();
|
|
if let Ok(mut json) = serde_json::from_str::<serde_json::Value>(json_str) {
|
|
if rewrite_function_calls_in_response(&mut json) {
|
|
if let Ok(new_json) = serde_json::to_string(&json) {
|
|
let rewritten =
|
|
format!("{}data: {}\n", &line[..data_start], new_json);
|
|
info!("MITM: rewrote functionCall in response → text placeholder for LS (buffered)");
|
|
output.push_str(&rewritten);
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// Couldn't parse/rewrite — forward as-is
|
|
output.push_str(&line);
|
|
} else {
|
|
// Not a functionCall line — forward as-is
|
|
output.push_str(&line);
|
|
}
|
|
}
|
|
|
|
output.into_bytes()
|
|
}
|
|
|
|
/// Flush any remaining buffered data (call at end of response).
|
|
/// Rewrites if possible, otherwise forwards raw.
|
|
pub fn flush(&mut self) -> Vec<u8> {
|
|
if self.pending.is_empty() {
|
|
return vec![];
|
|
}
|
|
let remaining = std::mem::take(&mut self.pending);
|
|
|
|
// Try to rewrite if it contains functionCall
|
|
if remaining.contains("functionCall") {
|
|
if let Some(data_start) = remaining.find("data: {") {
|
|
let json_start = data_start + 6;
|
|
let json_str = remaining[json_start..].trim_end();
|
|
if let Ok(mut json) = serde_json::from_str::<serde_json::Value>(json_str) {
|
|
if rewrite_function_calls_in_response(&mut json) {
|
|
if let Ok(new_json) = serde_json::to_string(&json) {
|
|
let rewritten =
|
|
format!("{}data: {}", &remaining[..data_start], new_json);
|
|
info!("MITM: rewrote functionCall in flush → text placeholder for LS");
|
|
return rewritten.into_bytes();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
remaining.into_bytes()
|
|
}
|
|
}
|