- Add GenerationParams struct to MitmStore for temperature, top_p, top_k, max_output_tokens, stop_sequences, frequency/presence_penalty - MITM modify_request injects params into request.generationConfig - All 3 endpoints (Completions, Responses, Gemini) store client params - Add usageMetadata to Gemini sync responses (promptTokenCount, candidatesTokenCount, totalTokenCount, thoughtsTokenCount) - Add generation param fields to GeminiRequest (temperature, topP, etc.) - Completions stream_options.include_usage emits final usage chunk - Completions reasoning_tokens in completion_tokens_details - Update endpoint gap analysis doc (all high-priority gaps resolved)
988 lines
38 KiB
Rust
988 lines
38 KiB
Rust
//! Request body modification for intercepted LLM API calls.
|
|
//!
|
|
//! Aggressively strips everything except identity and actual conversation
|
|
//! from the Gemini API request. No integrity checks exist on the request
|
|
//! body — Google validates OAuth, project, model, and JSON structure only.
|
|
|
|
use regex::Regex;
|
|
use serde_json::Value;
|
|
use tracing::info;
|
|
|
|
use super::store::{CapturedFunctionCall, PendingToolResult};
|
|
|
|
/// Whether every tool definition sent by the LS is removed from the request.
///
/// Keep this `true`: when tools are present the LS switches into full agentic
/// mode (multi-turn tool calls, file searches, etc.), which burns quota.
const STRIP_ALL_TOOLS: bool = true;
|
|
|
|
/// Context for tool injection during request modification.
|
|
/// Built from MitmStore data before calling modify_request.
|
|
pub struct ToolContext {
|
|
/// Gemini-format tool declarations (functionDeclarations).
|
|
pub tools: Option<Vec<Value>>,
|
|
/// Gemini-format toolConfig.
|
|
pub tool_config: Option<Value>,
|
|
/// Pending tool results to inject as functionResponse.
|
|
pub pending_results: Vec<PendingToolResult>,
|
|
/// Last captured function calls for history rewriting.
|
|
pub last_calls: Vec<CapturedFunctionCall>,
|
|
/// Client-specified generation parameters (temperature, top_p, etc.).
|
|
pub generation_params: Option<super::store::GenerationParams>,
|
|
}
|
|
|
|
/// Modify a streamGenerateContent request body in-place.
|
|
/// Returns the modified JSON bytes, or None if modification wasn't possible.
|
|
pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec<u8>> {
|
|
let mut json: Value = serde_json::from_slice(body).ok()?;
|
|
|
|
let original_size = body.len();
|
|
let mut changes: Vec<String> = Vec::new();
|
|
|
|
// ── 1. System instruction: keep ONLY <identity>, nuke everything else ──
|
|
if let Some(sys) = json
|
|
.pointer_mut("/request/systemInstruction/parts/0/text")
|
|
.and_then(|v| v.as_str())
|
|
.map(|s| s.to_string())
|
|
{
|
|
let original_len = sys.len();
|
|
|
|
// Extract <identity>...</identity> block
|
|
let identity = extract_xml_section(&sys, "identity");
|
|
|
|
if let Some(identity_text) = identity {
|
|
let mut new_sys = format!("<identity>\n{}\n</identity>", identity_text.trim());
|
|
|
|
// When no tools are available, explicitly tell the model not to attempt
|
|
// function calls. Without this, the model's training causes it to try
|
|
// calling tools from its identity context, resulting in MALFORMED_FUNCTION_CALL.
|
|
let has_tools = tool_ctx.as_ref().map_or(false, |ctx| ctx.tools.is_some());
|
|
if !has_tools {
|
|
new_sys.push_str("\n\nIMPORTANT: You have NO tools available. Do not attempt to call any functions or tools. Respond with text only.");
|
|
}
|
|
|
|
let stripped = original_len - new_sys.len();
|
|
if stripped > 0 {
|
|
changes.push(format!(
|
|
"system instruction: keep <identity> only ({original_len} → {} chars, -{stripped})",
|
|
new_sys.len()
|
|
));
|
|
json["request"]["systemInstruction"]["parts"][0]["text"] =
|
|
Value::String(new_sys);
|
|
}
|
|
} else {
|
|
// No identity tag found — clear the whole thing
|
|
changes.push(format!("system instruction: cleared ({original_len} chars)"));
|
|
json["request"]["systemInstruction"]["parts"][0]["text"] =
|
|
Value::String(String::new());
|
|
}
|
|
}
|
|
|
|
// ── 2. Content messages: keep only actual conversation turns ───────────
|
|
if let Some(contents) = json
|
|
.pointer_mut("/request/contents")
|
|
.and_then(|v| v.as_array_mut())
|
|
{
|
|
let before = contents.len();
|
|
|
|
// Remove messages that are pure Antigravity context injection
|
|
contents.retain(|msg| {
|
|
if let Some(text) = msg["parts"][0]["text"].as_str() {
|
|
// Strip user_information (OS, workspace paths)
|
|
if text.starts_with("<user_information>") {
|
|
return false;
|
|
}
|
|
// Strip user_rules / MEMORY blocks
|
|
if text.starts_with("<user_rules>") {
|
|
return false;
|
|
}
|
|
// Strip workflows
|
|
if text.starts_with("<workflows>") {
|
|
return false;
|
|
}
|
|
// Strip MCP servers block
|
|
if text.starts_with("<mcp_servers>") {
|
|
return false;
|
|
}
|
|
}
|
|
true
|
|
});
|
|
|
|
// For remaining messages, strip embedded metadata
|
|
for msg in contents.iter_mut() {
|
|
if let Some(text) = msg["parts"][0]["text"].as_str().map(|s| s.to_string()) {
|
|
let mut modified = text.clone();
|
|
|
|
// Strip conversation summaries block
|
|
if let Some(cleaned) = strip_between(&modified, "# Conversation History\n", "</conversation_summaries>") {
|
|
modified = cleaned;
|
|
}
|
|
|
|
// Strip <ADDITIONAL_METADATA> blocks (cursor pos, open files, etc.)
|
|
if let Some(cleaned) = strip_xml_section(&modified, "ADDITIONAL_METADATA") {
|
|
modified = cleaned;
|
|
}
|
|
|
|
// Strip <EPHEMERAL_MESSAGE> blocks
|
|
if let Some(cleaned) = strip_xml_section(&modified, "EPHEMERAL_MESSAGE") {
|
|
modified = cleaned;
|
|
}
|
|
|
|
// Strip "Step Id: N\n" prefixes
|
|
if modified.starts_with("Step Id:") {
|
|
if let Some(newline_pos) = modified.find('\n') {
|
|
modified = modified[newline_pos + 1..].to_string();
|
|
}
|
|
}
|
|
|
|
// Strip knowledge item blocks
|
|
if let Some(cleaned) = strip_between(&modified, "Here are the ", "</knowledge_item>") {
|
|
// Only strip if it's about knowledge items
|
|
if cleaned.len() < modified.len() && modified.contains("knowledge item") {
|
|
modified = cleaned;
|
|
}
|
|
}
|
|
|
|
// Clean up excessive whitespace from stripping
|
|
let modified = collapse_newlines(&modified);
|
|
|
|
if modified.len() < text.len() {
|
|
msg["parts"][0]["text"] = Value::String(modified);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Remove now-empty messages
|
|
contents.retain(|msg| {
|
|
if let Some(text) = msg["parts"][0]["text"].as_str() {
|
|
!text.trim().is_empty()
|
|
} else {
|
|
true
|
|
}
|
|
});
|
|
|
|
let removed = before - contents.len();
|
|
if removed > 0 {
|
|
changes.push(format!("remove {removed}/{before} content messages"));
|
|
}
|
|
}
|
|
|
|
// ── 3. Strip LS tools, inject client tools ─────────────────────────────
|
|
let mut has_custom_tools = false;
|
|
if STRIP_ALL_TOOLS {
|
|
if let Some(tools) = json
|
|
.pointer_mut("/request/tools")
|
|
.and_then(|v| v.as_array_mut())
|
|
{
|
|
let count = tools.len();
|
|
if count > 0 {
|
|
tools.clear();
|
|
changes.push(format!("strip all {count} LS tools"));
|
|
}
|
|
|
|
// Inject client-provided tools from ToolContext
|
|
if let Some(ref ctx) = tool_ctx {
|
|
if let Some(ref custom_tools) = ctx.tools {
|
|
for tool in custom_tools {
|
|
tools.push(tool.clone());
|
|
}
|
|
has_custom_tools = true;
|
|
changes.push(format!("inject {} custom tool group(s)", custom_tools.len()));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ── 3a. When no tools remain, clean up all tool-related config ────────
|
|
// The LS sets toolConfig.functionCallingConfig.mode = "VALIDATED" which
|
|
// forces Google to attempt function calls even with an empty tools array,
|
|
// causing MALFORMED_FUNCTION_CALL in an infinite retry loop.
|
|
if STRIP_ALL_TOOLS && !has_custom_tools {
|
|
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
|
// Remove the empty tools array entirely
|
|
if req.get("tools").and_then(|v| v.as_array()).map_or(false, |a| a.is_empty()) {
|
|
req.remove("tools");
|
|
changes.push("remove empty tools array".to_string());
|
|
}
|
|
// Remove toolConfig (VALIDATED mode with no tools = MALFORMED_FUNCTION_CALL)
|
|
if req.remove("toolConfig").is_some() {
|
|
changes.push("remove toolConfig (no tools)".to_string());
|
|
}
|
|
}
|
|
}
|
|
|
|
// ── 3b. ALWAYS strip old functionCall/functionResponse from history ───
|
|
// Even when custom tools are injected, the LS history contains function
|
|
// call parts for LS-internal tools we stripped. Google rejects these as
|
|
// MALFORMED_FUNCTION_CALL because the referenced tools don't exist.
|
|
if STRIP_ALL_TOOLS {
|
|
if let Some(contents) = json
|
|
.pointer_mut("/request/contents")
|
|
.and_then(|v| v.as_array_mut())
|
|
{
|
|
let mut stripped_fc = 0usize;
|
|
for msg in contents.iter_mut() {
|
|
if let Some(parts) = msg.get_mut("parts").and_then(|v| v.as_array_mut()) {
|
|
let before = parts.len();
|
|
parts.retain(|part| {
|
|
!part.get("functionCall").is_some()
|
|
&& !part.get("functionResponse").is_some()
|
|
});
|
|
stripped_fc += before - parts.len();
|
|
}
|
|
}
|
|
// Remove messages that became empty after stripping function parts
|
|
contents.retain(|msg| {
|
|
msg.get("parts")
|
|
.and_then(|v| v.as_array())
|
|
.map_or(true, |parts| !parts.is_empty())
|
|
});
|
|
if stripped_fc > 0 {
|
|
changes.push(format!("strip {stripped_fc} functionCall/Response parts from history"));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Inject toolConfig if provided
|
|
if let Some(ref ctx) = tool_ctx {
|
|
if let Some(ref config) = ctx.tool_config {
|
|
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
|
req.insert("toolConfig".to_string(), config.clone());
|
|
changes.push("inject toolConfig".to_string());
|
|
}
|
|
}
|
|
}
|
|
|
|
// ── 3b. Rewrite conversation history for tool results ────────────
|
|
if let Some(ref ctx) = tool_ctx {
|
|
if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() {
|
|
if let Some(contents) = json
|
|
.pointer_mut("/request/contents")
|
|
.and_then(|v| v.as_array_mut())
|
|
{
|
|
// Find the model turn with our fake "Tool call completed" text and replace it
|
|
// with the actual functionCall parts
|
|
for msg in contents.iter_mut() {
|
|
if msg["role"].as_str() == Some("model") {
|
|
if let Some(text) = msg["parts"][0]["text"].as_str() {
|
|
if text.contains("Tool call completed") || text.contains("Awaiting external tool result") {
|
|
// Replace with functionCall parts
|
|
let fc_parts: Vec<Value> = ctx.last_calls.iter().map(|fc| {
|
|
serde_json::json!({
|
|
"functionCall": {
|
|
"name": fc.name,
|
|
"args": fc.args,
|
|
}
|
|
})
|
|
}).collect();
|
|
msg["parts"] = Value::Array(fc_parts);
|
|
changes.push("rewrite model turn with functionCall".to_string());
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add functionResponse as a user turn before the last user message
|
|
let fn_response_parts: Vec<Value> = ctx.pending_results.iter().map(|r| {
|
|
serde_json::json!({
|
|
"functionResponse": {
|
|
"name": r.name,
|
|
"response": r.result,
|
|
}
|
|
})
|
|
}).collect();
|
|
let fn_response_turn = serde_json::json!({
|
|
"role": "user",
|
|
"parts": fn_response_parts,
|
|
});
|
|
|
|
// Insert before the last user message
|
|
let last_user_idx = contents.iter().rposition(|msg| {
|
|
msg["role"].as_str() == Some("user")
|
|
});
|
|
if let Some(idx) = last_user_idx {
|
|
contents.insert(idx, fn_response_turn);
|
|
} else {
|
|
contents.push(fn_response_turn);
|
|
}
|
|
changes.push(format!("inject {} functionResponse(s)", ctx.pending_results.len()));
|
|
}
|
|
}
|
|
}
|
|
|
|
// ── 4. Inject includeThoughts to capture thinking text ───────────────
|
|
// Without this flag, Google only reports thinking token counts
|
|
// but doesn't send the thinking text in SSE parts.
|
|
{
|
|
// Ensure request.generationConfig.thinkingConfig.includeThoughts = true
|
|
let request = json.get_mut("request").and_then(|v| v.as_object_mut());
|
|
if let Some(req) = request {
|
|
let gen_config = req
|
|
.entry("generationConfig")
|
|
.or_insert_with(|| serde_json::json!({}));
|
|
if let Some(gc) = gen_config.as_object_mut() {
|
|
let thinking_config = gc
|
|
.entry("thinkingConfig")
|
|
.or_insert_with(|| serde_json::json!({}));
|
|
if let Some(tc) = thinking_config.as_object_mut() {
|
|
if !tc.contains_key("includeThoughts") {
|
|
tc.insert("includeThoughts".to_string(), Value::Bool(true));
|
|
changes.push("inject includeThoughts".to_string());
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
// Not wrapped in request — try top-level (public API format)
|
|
let gen_config = json.as_object_mut().and_then(|o| {
|
|
Some(o.entry("generationConfig")
|
|
.or_insert_with(|| serde_json::json!({})))
|
|
});
|
|
if let Some(gc) = gen_config.and_then(|v| v.as_object_mut()) {
|
|
let thinking_config = gc
|
|
.entry("thinkingConfig")
|
|
.or_insert_with(|| serde_json::json!({}));
|
|
if let Some(tc) = thinking_config.as_object_mut() {
|
|
if !tc.contains_key("includeThoughts") {
|
|
tc.insert("includeThoughts".to_string(), Value::Bool(true));
|
|
changes.push("inject includeThoughts (top-level)".to_string());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ── 5. Inject client-specified generation parameters ──────────────────
|
|
// These override the LS defaults (which are typically absent or conservative).
|
|
// Google generationConfig fields: temperature, topP, topK, maxOutputTokens,
|
|
// stopSequences, frequencyPenalty, presencePenalty.
|
|
if let Some(ref ctx) = tool_ctx {
|
|
if let Some(ref gp) = ctx.generation_params {
|
|
// Find or create generationConfig (same path as above)
|
|
let gc = if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
|
Some(req.entry("generationConfig")
|
|
.or_insert_with(|| serde_json::json!({})))
|
|
} else {
|
|
json.as_object_mut().map(|o| {
|
|
o.entry("generationConfig")
|
|
.or_insert_with(|| serde_json::json!({}))
|
|
})
|
|
};
|
|
|
|
if let Some(gc) = gc.and_then(|v| v.as_object_mut()) {
|
|
let mut injected: Vec<String> = Vec::new();
|
|
|
|
if let Some(t) = gp.temperature {
|
|
gc.insert("temperature".to_string(), serde_json::json!(t));
|
|
injected.push(format!("temperature={t}"));
|
|
}
|
|
if let Some(p) = gp.top_p {
|
|
gc.insert("topP".to_string(), serde_json::json!(p));
|
|
injected.push(format!("topP={p}"));
|
|
}
|
|
if let Some(k) = gp.top_k {
|
|
gc.insert("topK".to_string(), serde_json::json!(k));
|
|
injected.push(format!("topK={k}"));
|
|
}
|
|
if let Some(m) = gp.max_output_tokens {
|
|
gc.insert("maxOutputTokens".to_string(), serde_json::json!(m));
|
|
injected.push(format!("maxOutputTokens={m}"));
|
|
}
|
|
if let Some(ref seqs) = gp.stop_sequences {
|
|
gc.insert("stopSequences".to_string(), serde_json::json!(seqs));
|
|
injected.push(format!("stopSequences({})", seqs.len()));
|
|
}
|
|
if let Some(fp) = gp.frequency_penalty {
|
|
gc.insert("frequencyPenalty".to_string(), serde_json::json!(fp));
|
|
injected.push(format!("frequencyPenalty={fp}"));
|
|
}
|
|
if let Some(pp) = gp.presence_penalty {
|
|
gc.insert("presencePenalty".to_string(), serde_json::json!(pp));
|
|
injected.push(format!("presencePenalty={pp}"));
|
|
}
|
|
|
|
if !injected.is_empty() {
|
|
changes.push(format!("inject generationConfig: {}", injected.join(", ")));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if changes.is_empty() {
|
|
return None; // Nothing modified
|
|
}
|
|
|
|
let modified_bytes = serde_json::to_vec(&json).ok()?;
|
|
let saved = original_size as i64 - modified_bytes.len() as i64;
|
|
let pct = if original_size > 0 {
|
|
(saved as f64 / original_size as f64 * 100.0) as i32
|
|
} else {
|
|
0
|
|
};
|
|
|
|
info!(
|
|
original = original_size,
|
|
modified = modified_bytes.len(),
|
|
saved_bytes = saved,
|
|
saved_pct = pct,
|
|
"MITM: request modified [{}]",
|
|
changes.join(", ")
|
|
);
|
|
|
|
Some(modified_bytes)
|
|
}
|
|
|
|
/// Extract the inner text of the first `<tag>...</tag>` section.
///
/// The closing tag is searched for only AFTER the opening tag, so a stray
/// `</tag>` earlier in the text does not hide a valid section.
///
/// Returns `None` when the opening tag is missing, when no closing tag
/// follows it, or when the section is empty.
fn extract_xml_section(text: &str, tag: &str) -> Option<String> {
    let open = format!("<{tag}>");
    let close = format!("</{tag}>");

    let inner_start = text.find(&open)? + open.len();
    let inner_len = text[inner_start..].find(&close)?;
    if inner_len == 0 {
        return None;
    }
    Some(text[inner_start..inner_start + inner_len].to_string())
}
|
|
|
|
/// Remove the first `<tag>...</tag>` section (tags included) from `text`.
///
/// The closing tag is searched for only AFTER the opening tag. Searching from
/// the start of the text would, for a stray `</tag>` before the opener,
/// duplicate part of the input instead of removing the section.
///
/// Returns `None` when the opening tag is missing or no closing tag follows it.
fn strip_xml_section(text: &str, tag: &str) -> Option<String> {
    let open = format!("<{tag}>");
    let close = format!("</{tag}>");

    let start = text.find(&open)?;
    let body_start = start + open.len();
    let end_pos = body_start + text[body_start..].find(&close)? + close.len();

    Some(format!("{}{}", &text[..start], &text[end_pos..]))
}
|
|
|
|
/// Strip everything between two markers (inclusive of markers).
///
/// The end marker is searched for only AFTER the start marker, so an earlier
/// stray end marker cannot make the result duplicate input text. Whitespace
/// directly after the end marker is dropped as well.
///
/// Returns `None` when the start marker is missing or no end marker follows it.
fn strip_between(text: &str, start_marker: &str, end_marker: &str) -> Option<String> {
    let start = text.find(start_marker)?;
    let search_from = start + start_marker.len();
    let end_pos = search_from + text[search_from..].find(end_marker)? + end_marker.len();

    // Skip any trailing whitespace after end marker
    let rest = text[end_pos..].trim_start();
    Some(format!("{}{}", &text[..start], rest))
}
|
|
|
|
/// Collapse 3+ consecutive newlines into 2.
|
|
fn collapse_newlines(text: &str) -> String {
|
|
let re = Regex::new(r"\n{3,}").unwrap();
|
|
re.replace_all(text, "\n\n").to_string()
|
|
}
|
|
|
|
/// Dechunk an HTTP chunked-encoded body into raw bytes.
///
/// Each chunk is `<hex-size>[;extension]\r\n<data>\r\n`; extensions are
/// ignored. Decoding stops at the terminal zero-size chunk, at a size line
/// that is not valid hex, or when no further CRLF is found. If the final chunk
/// is truncated, the bytes that are present are still returned.
///
/// The declared chunk size comes from the wire, so the end offset is computed
/// with saturating arithmetic: an absurdly large size must not overflow.
pub fn dechunk(data: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(data.len());
    let mut pos = 0;

    while pos < data.len() {
        let line_end = match data[pos..].windows(2).position(|w| w == b"\r\n") {
            Some(p) => pos + p,
            None => break,
        };

        let size_str = std::str::from_utf8(&data[pos..line_end])
            .unwrap_or("")
            .split(';')
            .next()
            .unwrap_or("")
            .trim();

        let chunk_size = match usize::from_str_radix(size_str, 16) {
            Ok(0) | Err(_) => break,
            Ok(n) => n,
        };

        // `line_end + 2 <= data.len()` because the CRLF was found inside `data`.
        let data_start = line_end + 2;
        let data_end = data_start.saturating_add(chunk_size).min(data.len());
        result.extend_from_slice(&data[data_start..data_end]);

        // Skip the CRLF that terminates the chunk data.
        pos = data_end + 2;
    }

    result
}
|
|
|
|
/// Re-encode data as a single HTTP chunk + terminal chunk.
///
/// Empty input produces only the terminal chunk. A zero-size data chunk IS the
/// terminator in chunked encoding, so emitting one followed by a second
/// terminator would leave stray bytes on the connection.
pub fn rechunk(data: &[u8]) -> Vec<u8> {
    const TERMINAL_CHUNK: &[u8] = b"0\r\n\r\n";

    if data.is_empty() {
        return TERMINAL_CHUNK.to_vec();
    }

    let size_line = format!("{:x}\r\n", data.len());
    let mut result =
        Vec::with_capacity(size_line.len() + data.len() + 2 + TERMINAL_CHUNK.len());
    result.extend_from_slice(size_line.as_bytes());
    result.extend_from_slice(data);
    result.extend_from_slice(b"\r\n");
    result.extend_from_slice(TERMINAL_CHUNK);
    result
}
|
|
|
|
// ── OpenAI → Gemini format conversion ────────────────────────────────────────
|
|
|
|
/// Convert OpenAI tool definitions to Gemini functionDeclarations format.
|
|
///
|
|
/// OpenAI: `[{"type":"function","function":{"name":"X","description":"Y","parameters":{...}}}]`
|
|
/// Gemini: `[{"functionDeclarations":[{"name":"X","description":"Y","parameters":{...}}]}]`
|
|
pub fn openai_tools_to_gemini(tools: &[Value]) -> Vec<Value> {
|
|
let declarations: Vec<Value> = tools
|
|
.iter()
|
|
.filter(|t| t["type"].as_str() == Some("function"))
|
|
.filter_map(|t| {
|
|
let func = t.get("function")?;
|
|
let mut decl = serde_json::json!({
|
|
"name": func["name"],
|
|
"description": func["description"],
|
|
});
|
|
if let Some(params) = func.get("parameters") {
|
|
let cleaned = clean_schema_for_gemini(uppercase_types(params.clone()));
|
|
decl["parameters"] = cleaned;
|
|
}
|
|
Some(decl)
|
|
})
|
|
.collect();
|
|
|
|
if declarations.is_empty() {
|
|
return vec![];
|
|
}
|
|
|
|
vec![serde_json::json!({"functionDeclarations": declarations})]
|
|
}
|
|
|
|
/// Recursively strip JSON Schema fields that Google's Gemini API doesn't accept.
|
|
/// Known unsupported: $schema, additionalProperties, $ref, $defs, default, examples
|
|
fn clean_schema_for_gemini(mut val: Value) -> Value {
|
|
const STRIP_KEYS: &[&str] = &[
|
|
"$schema",
|
|
"additionalProperties",
|
|
"$ref",
|
|
"$defs",
|
|
"default",
|
|
"examples",
|
|
"title",
|
|
];
|
|
|
|
match &mut val {
|
|
Value::Object(map) => {
|
|
for key in STRIP_KEYS {
|
|
map.remove(*key);
|
|
}
|
|
let keys: Vec<String> = map.keys().cloned().collect();
|
|
for key in keys {
|
|
if let Some(v) = map.remove(&key) {
|
|
map.insert(key, clean_schema_for_gemini(v));
|
|
}
|
|
}
|
|
}
|
|
Value::Array(arr) => {
|
|
for v in arr.iter_mut() {
|
|
*v = clean_schema_for_gemini(std::mem::take(v));
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
val
|
|
}
|
|
|
|
/// Convert OpenAI tool_choice to Gemini toolConfig format.
|
|
///
|
|
/// OpenAI: "auto" | "required" | "none" | {"type":"function","function":{"name":"X"}}
|
|
/// Gemini: {"functionCallingConfig":{"mode":"AUTO|ANY|NONE","allowedFunctionNames":[...]}}
|
|
pub fn openai_tool_choice_to_gemini(choice: &Value) -> Value {
|
|
match choice {
|
|
Value::String(s) => match s.as_str() {
|
|
"auto" => serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}),
|
|
"required" => serde_json::json!({"functionCallingConfig": {"mode": "ANY"}}),
|
|
"none" => serde_json::json!({"functionCallingConfig": {"mode": "NONE"}}),
|
|
_ => serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}),
|
|
},
|
|
Value::Object(obj) => {
|
|
if let Some(name) = obj.get("function").and_then(|f| f["name"].as_str()) {
|
|
serde_json::json!({
|
|
"functionCallingConfig": {
|
|
"mode": "ANY",
|
|
"allowedFunctionNames": [name]
|
|
}
|
|
})
|
|
} else {
|
|
serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}})
|
|
}
|
|
}
|
|
_ => serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}}),
|
|
}
|
|
}
|
|
|
|
/// Recursively convert JSON Schema type strings to uppercase (Gemini format).
|
|
/// "object" → "OBJECT", "string" → "STRING", etc.
|
|
fn uppercase_types(mut val: Value) -> Value {
|
|
match &mut val {
|
|
Value::Object(map) => {
|
|
if let Some(t) = map
|
|
.get("type")
|
|
.and_then(|v| v.as_str())
|
|
.map(|s| s.to_uppercase())
|
|
{
|
|
map.insert("type".to_string(), Value::String(t));
|
|
}
|
|
let keys: Vec<String> = map.keys().cloned().collect();
|
|
for key in keys {
|
|
if let Some(v) = map.remove(&key) {
|
|
map.insert(key, uppercase_types(v));
|
|
}
|
|
}
|
|
}
|
|
Value::Array(arr) => {
|
|
for v in arr.iter_mut() {
|
|
*v = uppercase_types(std::mem::take(v));
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
val
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `modify_request` without a tool context and parse the result.
    fn modify_without_tools(body: &Value) -> Value {
        let bytes = serde_json::to_vec(body).unwrap();
        let modified = modify_request(&bytes, None).unwrap();
        serde_json::from_slice(&modified).unwrap()
    }

    #[test]
    fn test_dechunk_basic() {
        let chunked = b"5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
        let result = dechunk(chunked);
        assert_eq!(result, b"hello world");
    }

    #[test]
    fn test_dechunk_single() {
        let chunked = b"b\r\nhello world\r\n0\r\n\r\n";
        let result = dechunk(chunked);
        assert_eq!(result, b"hello world");
    }

    #[test]
    fn test_rechunk() {
        let data = b"hello world";
        let chunked = rechunk(data);
        let expected = b"b\r\nhello world\r\n0\r\n\r\n";
        assert_eq!(chunked, expected);
    }

    #[test]
    fn test_dechunk_rechunk_roundtrip() {
        let original = b"5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
        let data = dechunk(original);
        let rechunked = rechunk(&data);
        let data2 = dechunk(&rechunked);
        assert_eq!(data, data2);
    }

    #[test]
    fn test_modify_strips_all_tools() {
        let body = serde_json::json!({
            "project": "test",
            "requestId": "test/1",
            "request": {
                "contents": [{"role": "user", "parts": [{"text": "hello"}]}],
                "tools": [
                    {"functionDeclarations": [{"name": "view_file", "description": "view", "parameters": {}}]},
                    {"functionDeclarations": [{"name": "browser_subagent", "description": "browse", "parameters": {}}]},
                ],
                "generationConfig": {}
            },
            "model": "test"
        });

        let result = modify_without_tools(&body);

        // With no ToolContext, tools should be removed entirely
        assert!(
            result["request"]["tools"].is_null() || result.pointer("/request/tools").is_none(),
            "tools should be removed when no custom tools provided"
        );
    }

    #[test]
    fn test_modify_keeps_only_identity() {
        let sys_text = "<identity>\nYou are a helpful AI.\n</identity>\n\n<tool_calling>\nUse absolute paths.\n</tool_calling>\n<web_application_development>\nlots of web dev stuff\n</web_application_development>\n<communication_style>\nbe helpful\n</communication_style>";
        let body = serde_json::json!({
            "project": "test",
            "requestId": "test/1",
            "request": {
                "contents": [{"role": "user", "parts": [{"text": "hello"}]}],
                "systemInstruction": {"parts": [{"text": sys_text}]},
                "tools": [],
                "generationConfig": {}
            },
            "model": "test"
        });

        let result = modify_without_tools(&body);

        let new_sys = result["request"]["systemInstruction"]["parts"][0]["text"]
            .as_str()
            .unwrap();

        assert!(new_sys.contains("<identity>"));
        assert!(new_sys.contains("You are a helpful AI."));
        assert!(!new_sys.contains("tool_calling"));
        assert!(!new_sys.contains("web_application_development"));
        assert!(!new_sys.contains("communication_style"));
    }

    #[test]
    fn test_modify_strips_context_messages() {
        let body = serde_json::json!({
            "project": "test",
            "requestId": "test/1",
            "request": {
                "contents": [
                    {"role": "user", "parts": [{"text": "<user_information>\nLinux\n</user_information>"}]},
                    {"role": "user", "parts": [{"text": "<user_rules>\nno rules\n</user_rules>"}]},
                    {"role": "user", "parts": [{"text": "<workflows>\nsome workflows\n</workflows>"}]},
                    {"role": "user", "parts": [{"text": "Step Id: 0\n\n<USER_REQUEST>\nSay hello\n</USER_REQUEST>\n<ADDITIONAL_METADATA>\ncursor stuff\n</ADDITIONAL_METADATA>"}]},
                    {"role": "model", "parts": [{"text": "Hello!"}]},
                ],
                "tools": [],
                "generationConfig": {}
            },
            "model": "test"
        });

        let result = modify_without_tools(&body);

        let contents = result["request"]["contents"].as_array().unwrap();
        // Should have removed user_information, user_rules, workflows (3 messages)
        // Kept: USER_REQUEST message (with ADDITIONAL_METADATA stripped) + model response
        assert_eq!(contents.len(), 2, "should keep only user request + model response");

        // Check USER_REQUEST message had metadata stripped
        let user_msg = contents[0]["parts"][0]["text"].as_str().unwrap();
        assert!(user_msg.contains("Say hello"), "should keep user request");
        assert!(!user_msg.contains("ADDITIONAL_METADATA"), "should strip metadata");
        assert!(!user_msg.contains("cursor stuff"), "should strip cursor info");
        assert!(!user_msg.starts_with("Step Id:"), "should strip step id");

        // Model response kept intact
        assert_eq!(contents[1]["parts"][0]["text"].as_str().unwrap(), "Hello!");
    }

    #[test]
    fn test_extract_xml_section() {
        let text = "before <identity>\nI am AI\n</identity> after";
        let result = extract_xml_section(text, "identity").unwrap();
        assert_eq!(result, "\nI am AI\n");
    }

    #[test]
    fn test_strip_xml_section() {
        let text = "before <META>\nstuff\n</META> after";
        let result = strip_xml_section(text, "META").unwrap();
        // Only the section itself is removed; the space on each side of it stays.
        assert_eq!(result, "before  after");
    }

    #[test]
    fn test_strip_between() {
        let text = "keep this # Conversation History\nlots of stuff\n</conversation_summaries>\nand this";
        let result =
            strip_between(text, "# Conversation History\n", "</conversation_summaries>").unwrap();
        assert_eq!(result, "keep this and this");
    }
}
|
|
|
|
// ─── Response modification ──────────────────────────────────────────────────
|
|
|
|
/// Rewrite an SSE response chunk to replace `functionCall` parts with text,
|
|
#[allow(dead_code)]
|
|
/// so the LS doesn't see tool calls for tools it doesn't manage.
|
|
///
|
|
/// The MITM intercept layer has already captured the function call data
|
|
/// (via `parse_streaming_chunk`) before this function runs, so we're not
|
|
/// losing any information — just hiding it from the LS.
|
|
///
|
|
/// Handles HTTP chunked transfer encoding framing (size\r\n...data...\r\n).
|
|
///
|
|
/// Returns `Some(modified_bytes)` if the chunk was rewritten, `None` if no
|
|
/// change was needed.
|
|
pub fn modify_response_chunk(chunk: &[u8]) -> Option<Vec<u8>> {
|
|
let text = std::str::from_utf8(chunk).ok()?;
|
|
|
|
// Quick check — no point parsing if no functionCall present
|
|
if !text.contains("functionCall") {
|
|
return None;
|
|
}
|
|
|
|
// Strategy: find each `data: {json}` SSE event in the raw text (which may
|
|
// be wrapped in chunked encoding). Parse the JSON, rewrite functionCall
|
|
// parts, and rebuild the chunked frame with updated sizes.
|
|
|
|
// First, dechunk: extract SSE data lines from chunked encoding
|
|
// Chunked format: <hex-size>\r\n<data>\r\n
|
|
// We'll work on the whole text, finding "data: " prefixed JSON objects
|
|
let mut result = text.to_string();
|
|
let mut changed = false;
|
|
|
|
// Find all `data: {...}` patterns (SSE events with JSON)
|
|
// Use a simple approach: find "data: {" and match to the end of JSON
|
|
let mut search_from = 0;
|
|
while let Some(data_pos) = result[search_from..].find("data: {") {
|
|
let abs_pos = search_from + data_pos;
|
|
let json_start = abs_pos + 6; // skip "data: "
|
|
|
|
// Find the end of this JSON object by finding the matching closing brace
|
|
if let Some(json_end) = find_json_end(&result[json_start..]) {
|
|
let json_str = &result[json_start..json_start + json_end];
|
|
|
|
if json_str.contains("functionCall") {
|
|
if let Ok(mut json) = serde_json::from_str::<Value>(json_str) {
|
|
if rewrite_function_calls_in_response(&mut json) {
|
|
if let Ok(new_json) = serde_json::to_string(&json) {
|
|
// Replace the JSON in the result string
|
|
result.replace_range(json_start..json_start + json_end, &new_json);
|
|
changed = true;
|
|
info!("MITM: rewrote functionCall in response → text placeholder for LS");
|
|
search_from = json_start + new_json.len();
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
search_from = json_start + json_end;
|
|
} else {
|
|
search_from = json_start;
|
|
}
|
|
}
|
|
|
|
if !changed {
|
|
return None;
|
|
}
|
|
|
|
// Rechunk: if the original was chunked, we need to recalculate chunk sizes
|
|
// The format is: <hex-size>\r\n<payload>\r\n
|
|
// We'll rebuild the chunked encoding from scratch
|
|
if text.contains("\r\n") && text.chars().next().map_or(false, |c| c.is_ascii_hexdigit()) {
|
|
// This looks like chunked encoding — rebuild it
|
|
// Extract the payload (everything between first \r\n and last \r\n)
|
|
let rechunked = rechunk_response(&result);
|
|
Some(rechunked.into_bytes())
|
|
} else {
|
|
Some(result.into_bytes())
|
|
}
|
|
}
|
|
|
|
/// Locate the end of the first balanced `{...}` object in `s`.
///
/// Braces inside string literals are ignored, and a backslash inside a string
/// escapes the character after it. Returns the byte index just past the
/// closing brace that brings the nesting depth back to zero, or `None` if no
/// such brace exists.
#[allow(dead_code)]
fn find_json_end(s: &str) -> Option<usize> {
    let mut nesting = 0i32;
    let mut inside_string = false;
    let mut skip_next = false;

    for (idx, ch) in s.char_indices() {
        if skip_next {
            // Character following a backslash inside a string: never structural.
            skip_next = false;
            continue;
        }
        match ch {
            '\\' if inside_string => skip_next = true,
            '"' => inside_string = !inside_string,
            _ if inside_string => {}
            '{' => nesting += 1,
            '}' => {
                nesting -= 1;
                if nesting == 0 {
                    return Some(idx + 1);
                }
            }
            _ => {}
        }
    }
    None
}
|
|
|
|
/// Rebuild chunked encoding from a modified response body.
///
/// `text` still contains the OLD chunk-size lines, which no longer match the
/// rewritten payload. A line made up only of hex digits is taken to be a size
/// line; the text between two size lines is one frame, and its final CRLF is
/// the frame trailer rather than payload. All frame payloads are merged into a
/// single chunk with a correct size.
///
/// Blank lines inside a frame are kept, because they delimit SSE events. If
/// the input contained the terminal zero-size chunk, it is emitted again so
/// the receiver still sees the end of the body.
#[allow(dead_code)]
fn rechunk_response(text: &str) -> String {
    /// Append the current frame to `payload`, without the frame's trailing CRLF.
    fn flush_frame(frame: &mut String, payload: &mut String) {
        let data = frame.strip_suffix("\r\n").unwrap_or(frame.as_str());
        payload.push_str(data);
        frame.clear();
    }

    let mut payload = String::with_capacity(text.len());
    let mut frame = String::new();
    let mut saw_terminator = false;

    for line in text.split_inclusive('\n') {
        let bare = line.trim_end_matches('\n').trim_end_matches('\r');
        let is_size_line = !bare.is_empty() && bare.chars().all(|c| c.is_ascii_hexdigit());

        if !is_size_line {
            frame.push_str(line);
            continue;
        }

        // A size line closes the previous frame.
        flush_frame(&mut frame, &mut payload);
        if bare.chars().all(|c| c == '0') {
            // Terminal chunk: nothing after it belongs to the body.
            saw_terminator = true;
            break;
        }
    }
    flush_frame(&mut frame, &mut payload);

    // Wrap in a single chunk
    let mut rebuilt = String::with_capacity(payload.len() + 32);
    if !payload.is_empty() {
        rebuilt.push_str(&format!("{:x}\r\n", payload.len()));
        rebuilt.push_str(&payload);
        rebuilt.push_str("\r\n");
    }
    if saw_terminator {
        rebuilt.push_str("0\r\n\r\n");
    }
    rebuilt
}
|
|
|
|
/// Rewrite a parsed SSE JSON object: replace `functionCall` parts with
|
|
/// text placeholder and change `finishReason` from `MALFORMED_FUNCTION_CALL`
|
|
/// or any non-STOP reason to `STOP`.
|
|
///
|
|
/// Handles both Gemini public API format (`{"candidates":[...]}`) and
|
|
/// internal LS format (`{"response":{"candidates":[...]}}`).
|
|
#[allow(dead_code)]
|
|
fn rewrite_function_calls_in_response(json: &mut Value) -> bool {
|
|
let mut changed = false;
|
|
|
|
// Helper to rewrite candidates array in-place
|
|
fn rewrite_candidates(candidates: &mut Vec<Value>) -> bool {
|
|
let mut changed = false;
|
|
for candidate in candidates.iter_mut() {
|
|
if let Some(parts) = candidate
|
|
.pointer_mut("/content/parts")
|
|
.and_then(|v| v.as_array_mut())
|
|
{
|
|
for part in parts.iter_mut() {
|
|
if part.get("functionCall").is_some() {
|
|
*part = serde_json::json!({
|
|
"text": "Tool call completed. Awaiting external tool result."
|
|
});
|
|
changed = true;
|
|
}
|
|
}
|
|
}
|
|
if let Some(reason) = candidate.get("finishReason").and_then(|v| v.as_str()) {
|
|
if reason != "STOP" {
|
|
candidate["finishReason"] = Value::String("STOP".to_string());
|
|
changed = true;
|
|
}
|
|
}
|
|
}
|
|
changed
|
|
}
|
|
|
|
// Try direct "candidates" first
|
|
if let Some(candidates) = json.get_mut("candidates").and_then(|v| v.as_array_mut()) {
|
|
changed |= rewrite_candidates(candidates);
|
|
}
|
|
|
|
// Try nested "response.candidates"
|
|
if let Some(candidates) = json.pointer_mut("/response/candidates").and_then(|v| v.as_array_mut()) {
|
|
changed |= rewrite_candidates(candidates);
|
|
}
|
|
|
|
changed
|
|
}
|