fix: multi-round tool history rewrite and finishReason handling
- Add ToolRound struct to pair function calls with results per-round - Replace single-match history rewrite (broke after first round) with multi-round loop that rewrites ALL placeholder model turns - Fix tool result name fallback: use positional index instead of always picking the first call - Set is_complete for any finishReason (FUNCTION_CALL, MAX_TOKENS, etc.) not just STOP — prevents response_complete flag from never being set - Legacy fallback: responses.rs path (single-round via last_calls + pending_results) still works when tool_rounds is empty - Add tests: multi-round rewrite, single-round legacy, no-op, and FUNCTION_CALL/MAX_TOKENS finishReason handling
This commit is contained in:
@@ -16,7 +16,7 @@ use super::polling::{
|
|||||||
use super::types::*;
|
use super::types::*;
|
||||||
use super::util::{err_response, now_unix, upstream_err_response};
|
use super::util::{err_response, now_unix, upstream_err_response};
|
||||||
use super::AppState;
|
use super::AppState;
|
||||||
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult};
|
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};
|
||||||
|
|
||||||
/// Extract a conversation/session ID from a flexible JSON value.
|
/// Extract a conversation/session ID from a flexible JSON value.
|
||||||
/// Accepts a plain string or an object with an "id" field.
|
/// Accepts a plain string or an object with an "id" field.
|
||||||
@@ -229,18 +229,25 @@ pub(crate) async fn handle_completions(
|
|||||||
// When OpenCode sends back tool results, the messages array contains:
|
// When OpenCode sends back tool results, the messages array contains:
|
||||||
// 1. assistant message with tool_calls (the model's previous function calls)
|
// 1. assistant message with tool_calls (the model's previous function calls)
|
||||||
// 2. tool messages with results (the executed tool outputs)
|
// 2. tool messages with results (the executed tool outputs)
|
||||||
// We need to store these so modify_request can rewrite the LS's
|
// We build ToolRounds: each round pairs one assistant's tool_calls with
|
||||||
// conversation history with proper functionCall/functionResponse parts
|
// the subsequent tool result messages. This enables correct per-turn
|
||||||
// instead of the placeholder "Tool call completed" text.
|
// history rewriting for multi-step tool use.
|
||||||
{
|
{
|
||||||
let mut last_calls: Vec<CapturedFunctionCall> = Vec::new();
|
let mut rounds: Vec<ToolRound> = Vec::new();
|
||||||
let mut pending_results: Vec<PendingToolResult> = Vec::new();
|
let mut current_round: Option<ToolRound> = None;
|
||||||
|
|
||||||
for msg in &body.messages {
|
for msg in &body.messages {
|
||||||
match msg.role.as_str() {
|
match msg.role.as_str() {
|
||||||
"assistant" => {
|
"assistant" => {
|
||||||
// Extract function calls from assistant's tool_calls
|
// Finalize any open round
|
||||||
|
if let Some(round) = current_round.take() {
|
||||||
|
if !round.calls.is_empty() {
|
||||||
|
rounds.push(round);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Start new round if this assistant has tool_calls
|
||||||
if let Some(ref tool_calls) = msg.tool_calls {
|
if let Some(ref tool_calls) = msg.tool_calls {
|
||||||
|
let mut calls = Vec::new();
|
||||||
for tc in tool_calls {
|
for tc in tool_calls {
|
||||||
if let Some(func) = tc.get("function") {
|
if let Some(func) = tc.get("function") {
|
||||||
let name = func["name"].as_str().unwrap_or("unknown").to_string();
|
let name = func["name"].as_str().unwrap_or("unknown").to_string();
|
||||||
@@ -254,7 +261,7 @@ pub(crate) async fn handle_completions(
|
|||||||
state.mitm_store.register_call_id(call_id, name.clone()).await;
|
state.mitm_store.register_call_id(call_id, name.clone()).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
last_calls.push(CapturedFunctionCall {
|
calls.push(CapturedFunctionCall {
|
||||||
name,
|
name,
|
||||||
args,
|
args,
|
||||||
captured_at: std::time::SystemTime::now()
|
captured_at: std::time::SystemTime::now()
|
||||||
@@ -264,21 +271,31 @@ pub(crate) async fn handle_completions(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !calls.is_empty() {
|
||||||
|
current_round = Some(ToolRound {
|
||||||
|
calls,
|
||||||
|
results: Vec::new(),
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"tool" => {
|
"tool" => {
|
||||||
// Extract tool results
|
|
||||||
let text = extract_message_text(&msg.content);
|
let text = extract_message_text(&msg.content);
|
||||||
if let Some(ref call_id) = msg.tool_call_id {
|
if let Some(ref call_id) = msg.tool_call_id {
|
||||||
// Look up function name from call_id
|
// Look up function name from call_id, fall back to
|
||||||
|
// positional index within the current round's calls
|
||||||
|
let result_index = current_round
|
||||||
|
.as_ref()
|
||||||
|
.map(|r| r.results.len())
|
||||||
|
.unwrap_or(0);
|
||||||
let name = state
|
let name = state
|
||||||
.mitm_store
|
.mitm_store
|
||||||
.lookup_call_id(call_id)
|
.lookup_call_id(call_id)
|
||||||
.await
|
.await
|
||||||
.unwrap_or_else(|| {
|
.unwrap_or_else(|| {
|
||||||
// Fallback: try to find the name from last_calls by position
|
current_round
|
||||||
last_calls
|
.as_ref()
|
||||||
.first()
|
.and_then(|r| r.calls.get(result_index))
|
||||||
.map(|fc| fc.name.clone())
|
.map(|fc| fc.name.clone())
|
||||||
.unwrap_or_else(|| "unknown_function".to_string())
|
.unwrap_or_else(|| "unknown_function".to_string())
|
||||||
});
|
});
|
||||||
@@ -286,35 +303,43 @@ pub(crate) async fn handle_completions(
|
|||||||
let result_value = serde_json::from_str::<serde_json::Value>(&text)
|
let result_value = serde_json::from_str::<serde_json::Value>(&text)
|
||||||
.unwrap_or_else(|_| serde_json::json!({"result": text}));
|
.unwrap_or_else(|_| serde_json::json!({"result": text}));
|
||||||
|
|
||||||
pending_results.push(PendingToolResult {
|
if let Some(ref mut round) = current_round {
|
||||||
|
round.results.push(PendingToolResult {
|
||||||
name,
|
name,
|
||||||
result: result_value,
|
result: result_value,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {}
|
}
|
||||||
|
_ => {
|
||||||
|
// Any other role (user, system) finalizes the current round
|
||||||
|
if let Some(round) = current_round.take() {
|
||||||
|
if !round.calls.is_empty() {
|
||||||
|
rounds.push(round);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Finalize last round
|
||||||
|
if let Some(round) = current_round.take() {
|
||||||
|
if !round.calls.is_empty() {
|
||||||
|
rounds.push(round);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !last_calls.is_empty() {
|
if !rounds.is_empty() {
|
||||||
info!(
|
info!(
|
||||||
count = last_calls.len(),
|
round_count = rounds.len(),
|
||||||
tools = ?last_calls.iter().map(|c| &c.name).collect::<Vec<_>>(),
|
calls = ?rounds.iter().map(|r| r.calls.iter().map(|c| &c.name).collect::<Vec<_>>()).collect::<Vec<_>>(),
|
||||||
"Completions: stored last function calls for MITM history rewrite"
|
"Completions: stored {} tool round(s) for MITM history rewrite",
|
||||||
|
rounds.len(),
|
||||||
);
|
);
|
||||||
state.mitm_store.set_last_function_calls(last_calls).await;
|
// Also set last_function_calls from the latest round for proxy.rs recording compat
|
||||||
|
if let Some(last_round) = rounds.last() {
|
||||||
|
state.mitm_store.set_last_function_calls(last_round.calls.clone()).await;
|
||||||
}
|
}
|
||||||
|
state.mitm_store.set_tool_rounds(rounds).await;
|
||||||
if !pending_results.is_empty() {
|
|
||||||
info!(
|
|
||||||
count = pending_results.len(),
|
|
||||||
tools = ?pending_results.iter().map(|r| &r.name).collect::<Vec<_>>(),
|
|
||||||
"Completions: stored tool results for MITM injection"
|
|
||||||
);
|
|
||||||
for result in pending_results {
|
|
||||||
state.mitm_store.add_tool_result(result).await;
|
|
||||||
}
|
|
||||||
// Clear awaiting flag — we have the results now
|
|
||||||
state.mitm_store.clear_awaiting_tool_result();
|
state.mitm_store.clear_awaiting_tool_result();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -153,12 +153,10 @@ impl StreamingAccumulator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Check for completion
|
// Check for completion — any finishReason means response is done
|
||||||
if let Some(reason) = candidate["finishReason"].as_str() {
|
if let Some(reason) = candidate["finishReason"].as_str() {
|
||||||
self.stop_reason = Some(reason.to_string());
|
self.stop_reason = Some(reason.to_string());
|
||||||
if reason == "STOP" {
|
|
||||||
self.is_complete = true;
|
self.is_complete = true;
|
||||||
}
|
|
||||||
// Log non-STOP finish reasons
|
// Log non-STOP finish reasons
|
||||||
if reason != "STOP" {
|
if reason != "STOP" {
|
||||||
info!(finish_reason = reason, "MITM: non-STOP finish reason");
|
info!(finish_reason = reason, "MITM: non-STOP finish reason");
|
||||||
@@ -589,4 +587,30 @@ data: {"response": {"candidates": [{"content": {"role": "model","parts": [{"text
|
|||||||
);
|
);
|
||||||
assert_eq!(acc.stop_reason, Some("STOP".to_string()));
|
assert_eq!(acc.stop_reason, Some("STOP".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_function_call_finish_reason_sets_complete() {
|
||||||
|
let mut acc = StreamingAccumulator::new();
|
||||||
|
|
||||||
|
let event = "data: {\"response\": {\"candidates\": [{\"content\": {\"role\": \"model\", \"parts\": [{\"functionCall\": {\"name\": \"read_file\", \"args\": {\"path\": \"/foo\"}}}]}, \"finishReason\": \"FUNCTION_CALL\"}], \"usageMetadata\": {\"promptTokenCount\": 50, \"candidatesTokenCount\": 5, \"totalTokenCount\": 55}, \"modelVersion\": \"gemini-3-flash\"}}\n";
|
||||||
|
parse_streaming_chunk(event, &mut acc);
|
||||||
|
|
||||||
|
assert!(acc.is_complete, "FUNCTION_CALL finishReason should set is_complete");
|
||||||
|
assert_eq!(acc.stop_reason, Some("FUNCTION_CALL".to_string()));
|
||||||
|
assert_eq!(acc.function_calls.len(), 1);
|
||||||
|
assert_eq!(acc.function_calls[0].name, "read_file");
|
||||||
|
assert_eq!(acc.output_tokens, 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_max_tokens_finish_reason_sets_complete() {
|
||||||
|
let mut acc = StreamingAccumulator::new();
|
||||||
|
|
||||||
|
let event = "data: {\"response\": {\"candidates\": [{\"content\": {\"role\": \"model\", \"parts\": [{\"text\": \"truncated...\"}]}, \"finishReason\": \"MAX_TOKENS\"}], \"usageMetadata\": {\"promptTokenCount\": 50, \"candidatesTokenCount\": 100, \"totalTokenCount\": 150}}}\n";
|
||||||
|
parse_streaming_chunk(event, &mut acc);
|
||||||
|
|
||||||
|
assert!(acc.is_complete, "MAX_TOKENS finishReason should set is_complete");
|
||||||
|
assert_eq!(acc.stop_reason, Some("MAX_TOKENS".to_string()));
|
||||||
|
assert_eq!(acc.response_text, "truncated...");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ use regex::Regex;
|
|||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use super::store::{CapturedFunctionCall, PendingImage, PendingToolResult};
|
use super::store::{CapturedFunctionCall, PendingImage, PendingToolResult, ToolRound};
|
||||||
|
|
||||||
/// Strip ALL tool definitions.
|
/// Strip ALL tool definitions.
|
||||||
/// Must be true: with tools present, the LS enters full agentic mode
|
/// Must be true: with tools present, the LS enters full agentic mode
|
||||||
@@ -30,6 +30,9 @@ pub struct ToolContext {
|
|||||||
pub generation_params: Option<super::store::GenerationParams>,
|
pub generation_params: Option<super::store::GenerationParams>,
|
||||||
/// Pending image to inject as inlineData in the user message.
|
/// Pending image to inject as inlineData in the user message.
|
||||||
pub pending_image: Option<PendingImage>,
|
pub pending_image: Option<PendingImage>,
|
||||||
|
/// Multi-round tool call history. Each entry is a (calls, results) pair
|
||||||
|
/// from one round of tool use. Preferred over last_calls/pending_results.
|
||||||
|
pub tool_rounds: Vec<ToolRound>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Modify a streamGenerateContent request body in-place.
|
/// Modify a streamGenerateContent request body in-place.
|
||||||
@@ -355,22 +358,53 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
|
|
||||||
// ── 3b. Rewrite conversation history for tool results ────────────
|
// ── 3b. Rewrite conversation history for tool results ────────────
|
||||||
if let Some(ref ctx) = tool_ctx {
|
if let Some(ref ctx) = tool_ctx {
|
||||||
if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() {
|
// Prefer multi-round tool_rounds (set by completions handler) over
|
||||||
|
// legacy last_calls/pending_results (set by responses handler).
|
||||||
|
let rounds = if !ctx.tool_rounds.is_empty() {
|
||||||
|
ctx.tool_rounds.clone()
|
||||||
|
} else if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() {
|
||||||
|
// Legacy single-round: wrap in one ToolRound
|
||||||
|
vec![ToolRound {
|
||||||
|
calls: ctx.last_calls.clone(),
|
||||||
|
results: ctx.pending_results.clone(),
|
||||||
|
}]
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
|
||||||
|
if !rounds.is_empty() {
|
||||||
if let Some(contents) = json
|
if let Some(contents) = json
|
||||||
.pointer_mut("/request/contents")
|
.pointer_mut("/request/contents")
|
||||||
.and_then(|v| v.as_array_mut())
|
.and_then(|v| v.as_array_mut())
|
||||||
{
|
{
|
||||||
// Find the model turn with our fake "Tool call completed" text and replace it
|
// Phase 1: find ALL model turns with placeholder text, pair with rounds
|
||||||
// with the actual functionCall parts
|
let mut rewrites: Vec<(usize, usize)> = Vec::new(); // (content_index, round_index)
|
||||||
for msg in contents.iter_mut() {
|
let mut round_idx = 0;
|
||||||
|
for (i, msg) in contents.iter().enumerate() {
|
||||||
|
if round_idx >= rounds.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
if msg["role"].as_str() == Some("model") {
|
if msg["role"].as_str() == Some("model") {
|
||||||
if let Some(text) = msg["parts"][0]["text"].as_str() {
|
if let Some(text) = msg["parts"][0]["text"].as_str() {
|
||||||
if text.contains("Tool call completed")
|
if text.contains("Tool call completed")
|
||||||
|| text.contains("Awaiting external tool result")
|
|| text.contains("Awaiting external tool result")
|
||||||
{
|
{
|
||||||
// Replace with functionCall parts
|
rewrites.push((i, round_idx));
|
||||||
let fc_parts: Vec<Value> = ctx
|
round_idx += 1;
|
||||||
.last_calls
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 2: apply rewrites (reverse order for stable indices during insertion)
|
||||||
|
let mut insert_offset = 0;
|
||||||
|
for (content_idx, round_idx) in &rewrites {
|
||||||
|
let actual_idx = *content_idx + insert_offset;
|
||||||
|
let round = &rounds[*round_idx];
|
||||||
|
|
||||||
|
// Replace model turn with functionCall parts
|
||||||
|
let fc_parts: Vec<Value> = round
|
||||||
|
.calls
|
||||||
.iter()
|
.iter()
|
||||||
.map(|fc| {
|
.map(|fc| {
|
||||||
serde_json::json!({
|
serde_json::json!({
|
||||||
@@ -381,17 +415,12 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
msg["parts"] = Value::Array(fc_parts);
|
contents[actual_idx]["parts"] = Value::Array(fc_parts);
|
||||||
changes.push("rewrite model turn with functionCall".to_string());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add functionResponse as a user turn before the last user message
|
// Inject functionResponse user turn right after
|
||||||
let fn_response_parts: Vec<Value> = ctx
|
if !round.results.is_empty() {
|
||||||
.pending_results
|
let fr_parts: Vec<Value> = round
|
||||||
|
.results
|
||||||
.iter()
|
.iter()
|
||||||
.map(|r| {
|
.map(|r| {
|
||||||
serde_json::json!({
|
serde_json::json!({
|
||||||
@@ -402,27 +431,26 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let fn_response_turn = serde_json::json!({
|
contents.insert(
|
||||||
|
actual_idx + 1,
|
||||||
|
serde_json::json!({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"parts": fn_response_parts,
|
"parts": fr_parts,
|
||||||
});
|
}),
|
||||||
|
);
|
||||||
// Insert before the last user message
|
insert_offset += 1;
|
||||||
let last_user_idx = contents
|
|
||||||
.iter()
|
|
||||||
.rposition(|msg| msg["role"].as_str() == Some("user"));
|
|
||||||
if let Some(idx) = last_user_idx {
|
|
||||||
contents.insert(idx, fn_response_turn);
|
|
||||||
} else {
|
|
||||||
contents.push(fn_response_turn);
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rewrites.is_empty() {
|
||||||
changes.push(format!(
|
changes.push(format!(
|
||||||
"inject {} functionResponse(s)",
|
"rewrite {} tool round(s) in history",
|
||||||
ctx.pending_results.len()
|
rewrites.len()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ── 4. Inject includeThoughts to capture thinking text ───────────────
|
// ── 4. Inject includeThoughts to capture thinking text ───────────────
|
||||||
// Without this flag, Google only reports thinking token counts
|
// Without this flag, Google only reports thinking token counts
|
||||||
@@ -992,6 +1020,222 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(result, "keep this and this");
|
assert_eq!(result, "keep this and this");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multi_round_history_rewrite() {
|
||||||
|
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};
|
||||||
|
|
||||||
|
// Simulate 2 rounds of tool use in LS history:
|
||||||
|
// user → model("Tool call completed") → user(text) → model("Tool call completed") → user(text)
|
||||||
|
let body = serde_json::json!({
|
||||||
|
"project": "test",
|
||||||
|
"requestId": "test/1",
|
||||||
|
"request": {
|
||||||
|
"contents": [
|
||||||
|
{"role": "user", "parts": [{"text": "Read foo and write to bar"}]},
|
||||||
|
{"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]},
|
||||||
|
{"role": "user", "parts": [{"text": "[Tool result: file contents here]"}]},
|
||||||
|
{"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]},
|
||||||
|
{"role": "user", "parts": [{"text": "[Tool result: write success]"}]},
|
||||||
|
],
|
||||||
|
"tools": [],
|
||||||
|
"generationConfig": {}
|
||||||
|
},
|
||||||
|
"model": "test"
|
||||||
|
});
|
||||||
|
|
||||||
|
let tool_ctx = ToolContext {
|
||||||
|
tools: Some(vec![serde_json::json!({
|
||||||
|
"functionDeclarations": [{
|
||||||
|
"name": "read_file",
|
||||||
|
"description": "Read a file",
|
||||||
|
"parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}}}
|
||||||
|
}, {
|
||||||
|
"name": "write_file",
|
||||||
|
"description": "Write a file",
|
||||||
|
"parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}, "content": {"type": "STRING"}}}
|
||||||
|
}]
|
||||||
|
})]),
|
||||||
|
tool_config: None,
|
||||||
|
pending_results: vec![],
|
||||||
|
last_calls: vec![],
|
||||||
|
generation_params: None,
|
||||||
|
pending_image: None,
|
||||||
|
tool_rounds: vec![
|
||||||
|
ToolRound {
|
||||||
|
calls: vec![CapturedFunctionCall {
|
||||||
|
name: "read_file".to_string(),
|
||||||
|
args: serde_json::json!({"path": "/foo"}),
|
||||||
|
captured_at: 0,
|
||||||
|
}],
|
||||||
|
results: vec![PendingToolResult {
|
||||||
|
name: "read_file".to_string(),
|
||||||
|
result: serde_json::json!({"content": "file contents here"}),
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
ToolRound {
|
||||||
|
calls: vec![CapturedFunctionCall {
|
||||||
|
name: "write_file".to_string(),
|
||||||
|
args: serde_json::json!({"path": "/bar", "content": "data"}),
|
||||||
|
captured_at: 0,
|
||||||
|
}],
|
||||||
|
results: vec![PendingToolResult {
|
||||||
|
name: "write_file".to_string(),
|
||||||
|
result: serde_json::json!({"result": "ok"}),
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
let bytes = serde_json::to_vec(&body).unwrap();
|
||||||
|
let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap();
|
||||||
|
let result: Value = serde_json::from_slice(&modified).unwrap();
|
||||||
|
let contents = result["request"]["contents"].as_array().unwrap();
|
||||||
|
|
||||||
|
// Expected layout after rewrite:
|
||||||
|
// [0] user: "Read foo..."
|
||||||
|
// [1] model: functionCall(read_file) (was "Tool call completed")
|
||||||
|
// [2] user: functionResponse(read_file) (injected)
|
||||||
|
// [3] user: "[Tool result: file contents]" (original LS turn)
|
||||||
|
// [4] model: functionCall(write_file) (was "Tool call completed")
|
||||||
|
// [5] user: functionResponse(write_file) (injected)
|
||||||
|
// [6] user: "[Tool result: write success]" (original LS turn)
|
||||||
|
assert_eq!(contents.len(), 7, "should have 7 turns (5 original + 2 injected)");
|
||||||
|
|
||||||
|
// Check round 1: model turn rewritten to functionCall
|
||||||
|
assert_eq!(
|
||||||
|
contents[1]["parts"][0]["functionCall"]["name"].as_str().unwrap(),
|
||||||
|
"read_file"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
contents[1]["parts"][0]["functionCall"]["args"]["path"].as_str().unwrap(),
|
||||||
|
"/foo"
|
||||||
|
);
|
||||||
|
// Check round 1: functionResponse injected
|
||||||
|
assert_eq!(
|
||||||
|
contents[2]["role"].as_str().unwrap(),
|
||||||
|
"user"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
contents[2]["parts"][0]["functionResponse"]["name"].as_str().unwrap(),
|
||||||
|
"read_file"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check round 2: model turn rewritten to functionCall
|
||||||
|
assert_eq!(
|
||||||
|
contents[4]["parts"][0]["functionCall"]["name"].as_str().unwrap(),
|
||||||
|
"write_file"
|
||||||
|
);
|
||||||
|
// Check round 2: functionResponse injected
|
||||||
|
assert_eq!(
|
||||||
|
contents[5]["parts"][0]["functionResponse"]["name"].as_str().unwrap(),
|
||||||
|
"write_file"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_single_round_legacy_fallback() {
|
||||||
|
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult};
|
||||||
|
|
||||||
|
// Simulate single round using legacy last_calls/pending_results (no tool_rounds).
|
||||||
|
// This is the path used by responses.rs.
|
||||||
|
let body = serde_json::json!({
|
||||||
|
"project": "test",
|
||||||
|
"requestId": "test/1",
|
||||||
|
"request": {
|
||||||
|
"contents": [
|
||||||
|
{"role": "user", "parts": [{"text": "Search for X"}]},
|
||||||
|
{"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]},
|
||||||
|
{"role": "user", "parts": [{"text": "[Tool result: found X]"}]},
|
||||||
|
],
|
||||||
|
"tools": [],
|
||||||
|
"generationConfig": {}
|
||||||
|
},
|
||||||
|
"model": "test"
|
||||||
|
});
|
||||||
|
|
||||||
|
let tool_ctx = ToolContext {
|
||||||
|
tools: Some(vec![serde_json::json!({
|
||||||
|
"functionDeclarations": [{
|
||||||
|
"name": "search",
|
||||||
|
"description": "Search",
|
||||||
|
"parameters": {"type": "OBJECT", "properties": {"q": {"type": "STRING"}}}
|
||||||
|
}]
|
||||||
|
})]),
|
||||||
|
tool_config: None,
|
||||||
|
pending_results: vec![PendingToolResult {
|
||||||
|
name: "search".to_string(),
|
||||||
|
result: serde_json::json!({"results": ["x"]}),
|
||||||
|
}],
|
||||||
|
last_calls: vec![CapturedFunctionCall {
|
||||||
|
name: "search".to_string(),
|
||||||
|
args: serde_json::json!({"q": "X"}),
|
||||||
|
captured_at: 0,
|
||||||
|
}],
|
||||||
|
generation_params: None,
|
||||||
|
pending_image: None,
|
||||||
|
tool_rounds: vec![], // Empty — forces legacy fallback
|
||||||
|
};
|
||||||
|
|
||||||
|
let bytes = serde_json::to_vec(&body).unwrap();
|
||||||
|
let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap();
|
||||||
|
let result: Value = serde_json::from_slice(&modified).unwrap();
|
||||||
|
let contents = result["request"]["contents"].as_array().unwrap();
|
||||||
|
|
||||||
|
// Should still work: model turn rewritten + functionResponse injected
|
||||||
|
assert_eq!(contents.len(), 4, "should have 4 turns (3 original + 1 injected)");
|
||||||
|
assert_eq!(
|
||||||
|
contents[1]["parts"][0]["functionCall"]["name"].as_str().unwrap(),
|
||||||
|
"search"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
contents[2]["parts"][0]["functionResponse"]["name"].as_str().unwrap(),
|
||||||
|
"search"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_no_tool_rounds_no_rewrite() {
|
||||||
|
// No tool rounds, no legacy data — no rewriting should happen
|
||||||
|
let body = serde_json::json!({
|
||||||
|
"project": "test",
|
||||||
|
"requestId": "test/1",
|
||||||
|
"request": {
|
||||||
|
"contents": [
|
||||||
|
{"role": "user", "parts": [{"text": "Hello"}]},
|
||||||
|
{"role": "model", "parts": [{"text": "Hi there!"}]},
|
||||||
|
],
|
||||||
|
"tools": [],
|
||||||
|
"generationConfig": {}
|
||||||
|
},
|
||||||
|
"model": "test"
|
||||||
|
});
|
||||||
|
|
||||||
|
let tool_ctx = ToolContext {
|
||||||
|
tools: Some(vec![serde_json::json!({
|
||||||
|
"functionDeclarations": [{
|
||||||
|
"name": "noop",
|
||||||
|
"description": "Does nothing",
|
||||||
|
"parameters": {"type": "OBJECT", "properties": {}}
|
||||||
|
}]
|
||||||
|
})]),
|
||||||
|
tool_config: None,
|
||||||
|
pending_results: vec![],
|
||||||
|
last_calls: vec![],
|
||||||
|
generation_params: None,
|
||||||
|
pending_image: None,
|
||||||
|
tool_rounds: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
let bytes = serde_json::to_vec(&body).unwrap();
|
||||||
|
let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap();
|
||||||
|
let result: Value = serde_json::from_slice(&modified).unwrap();
|
||||||
|
let contents = result["request"]["contents"].as_array().unwrap();
|
||||||
|
|
||||||
|
// No rewriting — same number of turns
|
||||||
|
assert_eq!(contents.len(), 2);
|
||||||
|
assert_eq!(contents[1]["parts"][0]["text"].as_str().unwrap(), "Hi there!");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── Response modification ──────────────────────────────────────────────────
|
// ─── Response modification ──────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -602,9 +602,11 @@ async fn handle_http_over_tls(
|
|||||||
let last_calls = store.get_last_function_calls().await;
|
let last_calls = store.get_last_function_calls().await;
|
||||||
let generation_params = store.get_generation_params().await;
|
let generation_params = store.get_generation_params().await;
|
||||||
let pending_image = store.take_pending_image().await;
|
let pending_image = store.take_pending_image().await;
|
||||||
|
let tool_rounds = store.take_tool_rounds().await;
|
||||||
|
|
||||||
let tool_ctx = if tools.is_some()
|
let tool_ctx = if tools.is_some()
|
||||||
|| !pending_results.is_empty()
|
|| !pending_results.is_empty()
|
||||||
|
|| !tool_rounds.is_empty()
|
||||||
|| generation_params.is_some()
|
|| generation_params.is_some()
|
||||||
|| pending_image.is_some()
|
|| pending_image.is_some()
|
||||||
{
|
{
|
||||||
@@ -615,6 +617,7 @@ async fn handle_http_over_tls(
|
|||||||
last_calls,
|
last_calls,
|
||||||
generation_params,
|
generation_params,
|
||||||
pending_image,
|
pending_image,
|
||||||
|
tool_rounds,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
|||||||
@@ -60,6 +60,18 @@ pub struct PendingToolResult {
|
|||||||
pub result: serde_json::Value,
|
pub result: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A single round of tool calling: the model's function calls paired with
|
||||||
|
/// the client's execution results.
|
||||||
|
///
|
||||||
|
/// In multi-step tool use, each round has its own calls and results.
|
||||||
|
/// This preserves per-turn data so history rewriting can map each
|
||||||
|
/// "Tool call completed" model turn to the correct functionCall/functionResponse.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ToolRound {
|
||||||
|
pub calls: Vec<CapturedFunctionCall>,
|
||||||
|
pub results: Vec<PendingToolResult>,
|
||||||
|
}
|
||||||
|
|
||||||
/// An upstream error captured from Google's API response.
|
/// An upstream error captured from Google's API response.
|
||||||
/// Stored by the MITM proxy so API handlers can return it to the client
|
/// Stored by the MITM proxy so API handlers can return it to the client
|
||||||
/// instead of hanging forever waiting for a response that won't come.
|
/// instead of hanging forever waiting for a response that won't come.
|
||||||
@@ -152,6 +164,9 @@ pub struct MitmStore {
|
|||||||
call_id_to_name: Arc<RwLock<HashMap<String, String>>>,
|
call_id_to_name: Arc<RwLock<HashMap<String, String>>>,
|
||||||
/// Last captured function calls (for conversation history rewriting).
|
/// Last captured function calls (for conversation history rewriting).
|
||||||
last_function_calls: Arc<RwLock<Vec<CapturedFunctionCall>>>,
|
last_function_calls: Arc<RwLock<Vec<CapturedFunctionCall>>>,
|
||||||
|
/// Multi-round tool call history for correct per-turn history rewriting.
|
||||||
|
/// Set by completions/responses handler, consumed by modify_request.
|
||||||
|
tool_rounds: Arc<RwLock<Vec<ToolRound>>>,
|
||||||
|
|
||||||
// ── Cascade correlation ──────────────────────────────────────────────
|
// ── Cascade correlation ──────────────────────────────────────────────
|
||||||
/// Active cascade ID set by the API layer before sending a message.
|
/// Active cascade ID set by the API layer before sending a message.
|
||||||
@@ -223,6 +238,7 @@ impl MitmStore {
|
|||||||
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
|
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
|
||||||
call_id_to_name: Arc::new(RwLock::new(HashMap::new())),
|
call_id_to_name: Arc::new(RwLock::new(HashMap::new())),
|
||||||
last_function_calls: Arc::new(RwLock::new(Vec::new())),
|
last_function_calls: Arc::new(RwLock::new(Vec::new())),
|
||||||
|
tool_rounds: Arc::new(RwLock::new(Vec::new())),
|
||||||
active_cascade_id: Arc::new(RwLock::new(None)),
|
active_cascade_id: Arc::new(RwLock::new(None)),
|
||||||
captured_response_text: Arc::new(RwLock::new(None)),
|
captured_response_text: Arc::new(RwLock::new(None)),
|
||||||
captured_thinking_text: Arc::new(RwLock::new(None)),
|
captured_thinking_text: Arc::new(RwLock::new(None)),
|
||||||
@@ -511,6 +527,16 @@ impl MitmStore {
|
|||||||
self.last_function_calls.read().await.clone()
|
self.last_function_calls.read().await.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Store multi-round tool call history for correct per-turn history rewriting.
|
||||||
|
pub async fn set_tool_rounds(&self, rounds: Vec<ToolRound>) {
|
||||||
|
*self.tool_rounds.write().await = rounds;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Take (consume) multi-round tool call history.
|
||||||
|
pub async fn take_tool_rounds(&self) -> Vec<ToolRound> {
|
||||||
|
std::mem::take(&mut *self.tool_rounds.write().await)
|
||||||
|
}
|
||||||
|
|
||||||
// ── Direct response capture (bypass LS) ──────────────────────────────
|
// ── Direct response capture (bypass LS) ──────────────────────────────
|
||||||
|
|
||||||
/// Set (replace) the captured response text.
|
/// Set (replace) the captured response text.
|
||||||
|
|||||||
Reference in New Issue
Block a user