From 39381a4dfe7ea46246f723d95683b56fc8a1ad9d Mon Sep 17 00:00:00 2001 From: Nikketryhard Date: Mon, 16 Feb 2026 19:05:37 -0600 Subject: [PATCH] fix: multi-round tool history rewrite and finishReason handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add ToolRound struct to pair function calls with results per-round - Replace single-match history rewrite (broke after first round) with multi-round loop that rewrites ALL placeholder model turns - Fix tool result name fallback: use positional index instead of always picking the first call - Set is_complete for any finishReason (FUNCTION_CALL, MAX_TOKENS, etc.) not just STOP — prevents response_complete flag from never being set - Legacy fallback: responses.rs path (single-round via last_calls + pending_results) still works when tool_rounds is empty - Add tests: multi-round rewrite, single-round legacy, no-op, and FUNCTION_CALL/MAX_TOKENS finishReason handling --- src/api/completions.rs | 93 +++++++---- src/mitm/intercept.rs | 32 +++- src/mitm/modify.rs | 344 +++++++++++++++++++++++++++++++++++------ src/mitm/proxy.rs | 3 + src/mitm/store.rs | 26 ++++ 5 files changed, 410 insertions(+), 88 deletions(-) diff --git a/src/api/completions.rs b/src/api/completions.rs index 27f4b8d..cad1a45 100644 --- a/src/api/completions.rs +++ b/src/api/completions.rs @@ -16,7 +16,7 @@ use super::polling::{ use super::types::*; use super::util::{err_response, now_unix, upstream_err_response}; use super::AppState; -use crate::mitm::store::{CapturedFunctionCall, PendingToolResult}; +use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound}; /// Extract a conversation/session ID from a flexible JSON value. /// Accepts a plain string or an object with an "id" field. @@ -229,18 +229,25 @@ pub(crate) async fn handle_completions( // When OpenCode sends back tool results, the messages array contains: // 1. 
assistant message with tool_calls (the model's previous function calls) // 2. tool messages with results (the executed tool outputs) - // We need to store these so modify_request can rewrite the LS's - // conversation history with proper functionCall/functionResponse parts - // instead of the placeholder "Tool call completed" text. + // We build ToolRounds: each round pairs one assistant's tool_calls with + // the subsequent tool result messages. This enables correct per-turn + // history rewriting for multi-step tool use. { - let mut last_calls: Vec = Vec::new(); - let mut pending_results: Vec = Vec::new(); + let mut rounds: Vec = Vec::new(); + let mut current_round: Option = None; for msg in &body.messages { match msg.role.as_str() { "assistant" => { - // Extract function calls from assistant's tool_calls + // Finalize any open round + if let Some(round) = current_round.take() { + if !round.calls.is_empty() { + rounds.push(round); + } + } + // Start new round if this assistant has tool_calls if let Some(ref tool_calls) = msg.tool_calls { + let mut calls = Vec::new(); for tc in tool_calls { if let Some(func) = tc.get("function") { let name = func["name"].as_str().unwrap_or("unknown").to_string(); @@ -254,7 +261,7 @@ pub(crate) async fn handle_completions( state.mitm_store.register_call_id(call_id, name.clone()).await; } - last_calls.push(CapturedFunctionCall { + calls.push(CapturedFunctionCall { name, args, captured_at: std::time::SystemTime::now() @@ -264,21 +271,31 @@ pub(crate) async fn handle_completions( }); } } + if !calls.is_empty() { + current_round = Some(ToolRound { + calls, + results: Vec::new(), + }); + } } } "tool" => { - // Extract tool results let text = extract_message_text(&msg.content); if let Some(ref call_id) = msg.tool_call_id { - // Look up function name from call_id + // Look up function name from call_id, fall back to + // positional index within the current round's calls + let result_index = current_round + .as_ref() + .map(|r| 
r.results.len()) + .unwrap_or(0); let name = state .mitm_store .lookup_call_id(call_id) .await .unwrap_or_else(|| { - // Fallback: try to find the name from last_calls by position - last_calls - .first() + current_round + .as_ref() + .and_then(|r| r.calls.get(result_index)) .map(|fc| fc.name.clone()) .unwrap_or_else(|| "unknown_function".to_string()) }); @@ -286,35 +303,43 @@ pub(crate) async fn handle_completions( let result_value = serde_json::from_str::(&text) .unwrap_or_else(|_| serde_json::json!({"result": text})); - pending_results.push(PendingToolResult { - name, - result: result_value, - }); + if let Some(ref mut round) = current_round { + round.results.push(PendingToolResult { + name, + result: result_value, + }); + } } } - _ => {} + _ => { + // Any other role (user, system) finalizes the current round + if let Some(round) = current_round.take() { + if !round.calls.is_empty() { + rounds.push(round); + } + } + } + } + } + // Finalize last round + if let Some(round) = current_round.take() { + if !round.calls.is_empty() { + rounds.push(round); } } - if !last_calls.is_empty() { + if !rounds.is_empty() { info!( - count = last_calls.len(), - tools = ?last_calls.iter().map(|c| &c.name).collect::>(), - "Completions: stored last function calls for MITM history rewrite" + round_count = rounds.len(), + calls = ?rounds.iter().map(|r| r.calls.iter().map(|c| &c.name).collect::>()).collect::>(), + "Completions: stored {} tool round(s) for MITM history rewrite", + rounds.len(), ); - state.mitm_store.set_last_function_calls(last_calls).await; - } - - if !pending_results.is_empty() { - info!( - count = pending_results.len(), - tools = ?pending_results.iter().map(|r| &r.name).collect::>(), - "Completions: stored tool results for MITM injection" - ); - for result in pending_results { - state.mitm_store.add_tool_result(result).await; + // Also set last_function_calls from the latest round for proxy.rs recording compat + if let Some(last_round) = rounds.last() { + 
state.mitm_store.set_last_function_calls(last_round.calls.clone()).await; } - // Clear awaiting flag — we have the results now + state.mitm_store.set_tool_rounds(rounds).await; state.mitm_store.clear_awaiting_tool_result(); } } diff --git a/src/mitm/intercept.rs b/src/mitm/intercept.rs index bfc3e1d..1132e46 100644 --- a/src/mitm/intercept.rs +++ b/src/mitm/intercept.rs @@ -153,12 +153,10 @@ impl StreamingAccumulator { } } } - // Check for completion + // Check for completion — any finishReason means response is done if let Some(reason) = candidate["finishReason"].as_str() { self.stop_reason = Some(reason.to_string()); - if reason == "STOP" { - self.is_complete = true; - } + self.is_complete = true; // Log non-STOP finish reasons if reason != "STOP" { info!(finish_reason = reason, "MITM: non-STOP finish reason"); @@ -589,4 +587,30 @@ data: {"response": {"candidates": [{"content": {"role": "model","parts": [{"text ); assert_eq!(acc.stop_reason, Some("STOP".to_string())); } + + #[test] + fn test_function_call_finish_reason_sets_complete() { + let mut acc = StreamingAccumulator::new(); + + let event = "data: {\"response\": {\"candidates\": [{\"content\": {\"role\": \"model\", \"parts\": [{\"functionCall\": {\"name\": \"read_file\", \"args\": {\"path\": \"/foo\"}}}]}, \"finishReason\": \"FUNCTION_CALL\"}], \"usageMetadata\": {\"promptTokenCount\": 50, \"candidatesTokenCount\": 5, \"totalTokenCount\": 55}, \"modelVersion\": \"gemini-3-flash\"}}\n"; + parse_streaming_chunk(event, &mut acc); + + assert!(acc.is_complete, "FUNCTION_CALL finishReason should set is_complete"); + assert_eq!(acc.stop_reason, Some("FUNCTION_CALL".to_string())); + assert_eq!(acc.function_calls.len(), 1); + assert_eq!(acc.function_calls[0].name, "read_file"); + assert_eq!(acc.output_tokens, 5); + } + + #[test] + fn test_max_tokens_finish_reason_sets_complete() { + let mut acc = StreamingAccumulator::new(); + + let event = "data: {\"response\": {\"candidates\": [{\"content\": {\"role\": \"model\", 
\"parts\": [{\"text\": \"truncated...\"}]}, \"finishReason\": \"MAX_TOKENS\"}], \"usageMetadata\": {\"promptTokenCount\": 50, \"candidatesTokenCount\": 100, \"totalTokenCount\": 150}}}\n"; + parse_streaming_chunk(event, &mut acc); + + assert!(acc.is_complete, "MAX_TOKENS finishReason should set is_complete"); + assert_eq!(acc.stop_reason, Some("MAX_TOKENS".to_string())); + assert_eq!(acc.response_text, "truncated..."); + } } diff --git a/src/mitm/modify.rs b/src/mitm/modify.rs index d07441a..f8ee15f 100644 --- a/src/mitm/modify.rs +++ b/src/mitm/modify.rs @@ -8,7 +8,7 @@ use regex::Regex; use serde_json::Value; use tracing::info; -use super::store::{CapturedFunctionCall, PendingImage, PendingToolResult}; +use super::store::{CapturedFunctionCall, PendingImage, PendingToolResult, ToolRound}; /// Strip ALL tool definitions. /// Must be true: with tools present, the LS enters full agentic mode @@ -30,6 +30,9 @@ pub struct ToolContext { pub generation_params: Option, /// Pending image to inject as inlineData in the user message. pub pending_image: Option, + /// Multi-round tool call history. Each entry is a (calls, results) pair + /// from one round of tool use. Preferred over last_calls/pending_results. + pub tool_rounds: Vec, } /// Modify a streamGenerateContent request body in-place. 
@@ -355,71 +358,96 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option = Vec::new(); // (content_index, round_index) + let mut round_idx = 0; + for (i, msg) in contents.iter().enumerate() { + if round_idx >= rounds.len() { + break; + } if msg["role"].as_str() == Some("model") { if let Some(text) = msg["parts"][0]["text"].as_str() { if text.contains("Tool call completed") || text.contains("Awaiting external tool result") { - // Replace with functionCall parts - let fc_parts: Vec = ctx - .last_calls - .iter() - .map(|fc| { - serde_json::json!({ - "functionCall": { - "name": fc.name, - "args": fc.args, - } - }) - }) - .collect(); - msg["parts"] = Value::Array(fc_parts); - changes.push("rewrite model turn with functionCall".to_string()); - break; + rewrites.push((i, round_idx)); + round_idx += 1; } } } } - // Add functionResponse as a user turn before the last user message - let fn_response_parts: Vec = ctx - .pending_results - .iter() - .map(|r| { - serde_json::json!({ - "functionResponse": { - "name": r.name, - "response": r.result, - } - }) - }) - .collect(); - let fn_response_turn = serde_json::json!({ - "role": "user", - "parts": fn_response_parts, - }); + // Phase 2: apply rewrites in forward order; insert_offset shifts later indices past turns already injected + let mut insert_offset = 0; + for (content_idx, round_idx) in &rewrites { + let actual_idx = *content_idx + insert_offset; + let round = &rounds[*round_idx]; - // Insert before the last user message - let last_user_idx = contents - .iter() - .rposition(|msg| msg["role"].as_str() == Some("user")); - if let Some(idx) = last_user_idx { - contents.insert(idx, fn_response_turn); - } else { - contents.push(fn_response_turn); + // Replace model turn with functionCall parts + let fc_parts: Vec = round + .calls + .iter() + .map(|fc| { + serde_json::json!({ + "functionCall": { + "name": fc.name, + "args": fc.args, + } + }) + }) + .collect(); + contents[actual_idx]["parts"] = Value::Array(fc_parts); + + // Inject
functionResponse user turn right after + if !round.results.is_empty() { + let fr_parts: Vec = round + .results + .iter() + .map(|r| { + serde_json::json!({ + "functionResponse": { + "name": r.name, + "response": r.result, + } + }) + }) + .collect(); + contents.insert( + actual_idx + 1, + serde_json::json!({ + "role": "user", + "parts": fr_parts, + }), + ); + insert_offset += 1; + } + } + + if !rewrites.is_empty() { + changes.push(format!( + "rewrite {} tool round(s) in history", + rewrites.len() + )); } - changes.push(format!( - "inject {} functionResponse(s)", - ctx.pending_results.len() - )); } } } @@ -992,6 +1020,222 @@ mod tests { .unwrap(); assert_eq!(result, "keep this and this"); } + + #[test] + fn test_multi_round_history_rewrite() { + use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound}; + + // Simulate 2 rounds of tool use in LS history: + // user → model("Tool call completed") → user(text) → model("Tool call completed") → user(text) + let body = serde_json::json!({ + "project": "test", + "requestId": "test/1", + "request": { + "contents": [ + {"role": "user", "parts": [{"text": "Read foo and write to bar"}]}, + {"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]}, + {"role": "user", "parts": [{"text": "[Tool result: file contents here]"}]}, + {"role": "model", "parts": [{"text": "Tool call completed. 
Awaiting external tool result."}]}, + {"role": "user", "parts": [{"text": "[Tool result: write success]"}]}, + ], + "tools": [], + "generationConfig": {} + }, + "model": "test" + }); + + let tool_ctx = ToolContext { + tools: Some(vec![serde_json::json!({ + "functionDeclarations": [{ + "name": "read_file", + "description": "Read a file", + "parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}}} + }, { + "name": "write_file", + "description": "Write a file", + "parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}, "content": {"type": "STRING"}}} + }] + })]), + tool_config: None, + pending_results: vec![], + last_calls: vec![], + generation_params: None, + pending_image: None, + tool_rounds: vec![ + ToolRound { + calls: vec![CapturedFunctionCall { + name: "read_file".to_string(), + args: serde_json::json!({"path": "/foo"}), + captured_at: 0, + }], + results: vec![PendingToolResult { + name: "read_file".to_string(), + result: serde_json::json!({"content": "file contents here"}), + }], + }, + ToolRound { + calls: vec![CapturedFunctionCall { + name: "write_file".to_string(), + args: serde_json::json!({"path": "/bar", "content": "data"}), + captured_at: 0, + }], + results: vec![PendingToolResult { + name: "write_file".to_string(), + result: serde_json::json!({"result": "ok"}), + }], + }, + ], + }; + + let bytes = serde_json::to_vec(&body).unwrap(); + let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap(); + let result: Value = serde_json::from_slice(&modified).unwrap(); + let contents = result["request"]["contents"].as_array().unwrap(); + + // Expected layout after rewrite: + // [0] user: "Read foo..." 
+ // [1] model: functionCall(read_file) (was "Tool call completed") + // [2] user: functionResponse(read_file) (injected) + // [3] user: "[Tool result: file contents]" (original LS turn) + // [4] model: functionCall(write_file) (was "Tool call completed") + // [5] user: functionResponse(write_file) (injected) + // [6] user: "[Tool result: write success]" (original LS turn) + assert_eq!(contents.len(), 7, "should have 7 turns (5 original + 2 injected)"); + + // Check round 1: model turn rewritten to functionCall + assert_eq!( + contents[1]["parts"][0]["functionCall"]["name"].as_str().unwrap(), + "read_file" + ); + assert_eq!( + contents[1]["parts"][0]["functionCall"]["args"]["path"].as_str().unwrap(), + "/foo" + ); + // Check round 1: functionResponse injected + assert_eq!( + contents[2]["role"].as_str().unwrap(), + "user" + ); + assert_eq!( + contents[2]["parts"][0]["functionResponse"]["name"].as_str().unwrap(), + "read_file" + ); + + // Check round 2: model turn rewritten to functionCall + assert_eq!( + contents[4]["parts"][0]["functionCall"]["name"].as_str().unwrap(), + "write_file" + ); + // Check round 2: functionResponse injected + assert_eq!( + contents[5]["parts"][0]["functionResponse"]["name"].as_str().unwrap(), + "write_file" + ); + } + + #[test] + fn test_single_round_legacy_fallback() { + use crate::mitm::store::{CapturedFunctionCall, PendingToolResult}; + + // Simulate single round using legacy last_calls/pending_results (no tool_rounds). + // This is the path used by responses.rs. + let body = serde_json::json!({ + "project": "test", + "requestId": "test/1", + "request": { + "contents": [ + {"role": "user", "parts": [{"text": "Search for X"}]}, + {"role": "model", "parts": [{"text": "Tool call completed. 
Awaiting external tool result."}]}, + {"role": "user", "parts": [{"text": "[Tool result: found X]"}]}, + ], + "tools": [], + "generationConfig": {} + }, + "model": "test" + }); + + let tool_ctx = ToolContext { + tools: Some(vec![serde_json::json!({ + "functionDeclarations": [{ + "name": "search", + "description": "Search", + "parameters": {"type": "OBJECT", "properties": {"q": {"type": "STRING"}}} + }] + })]), + tool_config: None, + pending_results: vec![PendingToolResult { + name: "search".to_string(), + result: serde_json::json!({"results": ["x"]}), + }], + last_calls: vec![CapturedFunctionCall { + name: "search".to_string(), + args: serde_json::json!({"q": "X"}), + captured_at: 0, + }], + generation_params: None, + pending_image: None, + tool_rounds: vec![], // Empty — forces legacy fallback + }; + + let bytes = serde_json::to_vec(&body).unwrap(); + let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap(); + let result: Value = serde_json::from_slice(&modified).unwrap(); + let contents = result["request"]["contents"].as_array().unwrap(); + + // Should still work: model turn rewritten + functionResponse injected + assert_eq!(contents.len(), 4, "should have 4 turns (3 original + 1 injected)"); + assert_eq!( + contents[1]["parts"][0]["functionCall"]["name"].as_str().unwrap(), + "search" + ); + assert_eq!( + contents[2]["parts"][0]["functionResponse"]["name"].as_str().unwrap(), + "search" + ); + } + + #[test] + fn test_no_tool_rounds_no_rewrite() { + // No tool rounds, no legacy data — no rewriting should happen + let body = serde_json::json!({ + "project": "test", + "requestId": "test/1", + "request": { + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]}, + {"role": "model", "parts": [{"text": "Hi there!"}]}, + ], + "tools": [], + "generationConfig": {} + }, + "model": "test" + }); + + let tool_ctx = ToolContext { + tools: Some(vec![serde_json::json!({ + "functionDeclarations": [{ + "name": "noop", + "description": "Does nothing", + 
"parameters": {"type": "OBJECT", "properties": {}} + }] + })]), + tool_config: None, + pending_results: vec![], + last_calls: vec![], + generation_params: None, + pending_image: None, + tool_rounds: vec![], + }; + + let bytes = serde_json::to_vec(&body).unwrap(); + let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap(); + let result: Value = serde_json::from_slice(&modified).unwrap(); + let contents = result["request"]["contents"].as_array().unwrap(); + + // No rewriting — same number of turns + assert_eq!(contents.len(), 2); + assert_eq!(contents[1]["parts"][0]["text"].as_str().unwrap(), "Hi there!"); + } } // ─── Response modification ────────────────────────────────────────────────── diff --git a/src/mitm/proxy.rs b/src/mitm/proxy.rs index 37f6eaf..93bcedf 100644 --- a/src/mitm/proxy.rs +++ b/src/mitm/proxy.rs @@ -602,9 +602,11 @@ async fn handle_http_over_tls( let last_calls = store.get_last_function_calls().await; let generation_params = store.get_generation_params().await; let pending_image = store.take_pending_image().await; + let tool_rounds = store.take_tool_rounds().await; let tool_ctx = if tools.is_some() || !pending_results.is_empty() + || !tool_rounds.is_empty() || generation_params.is_some() || pending_image.is_some() { @@ -615,6 +617,7 @@ async fn handle_http_over_tls( last_calls, generation_params, pending_image, + tool_rounds, }) } else { None diff --git a/src/mitm/store.rs b/src/mitm/store.rs index d28d8c9..467fd15 100644 --- a/src/mitm/store.rs +++ b/src/mitm/store.rs @@ -60,6 +60,18 @@ pub struct PendingToolResult { pub result: serde_json::Value, } +/// A single round of tool calling: the model's function calls paired with +/// the client's execution results. +/// +/// In multi-step tool use, each round has its own calls and results. +/// This preserves per-turn data so history rewriting can map each +/// "Tool call completed" model turn to the correct functionCall/functionResponse. 
+#[derive(Debug, Clone)] +pub struct ToolRound { + pub calls: Vec, + pub results: Vec, +} + /// An upstream error captured from Google's API response. /// Stored by the MITM proxy so API handlers can return it to the client /// instead of hanging forever waiting for a response that won't come. @@ -152,6 +164,9 @@ pub struct MitmStore { call_id_to_name: Arc>>, /// Last captured function calls (for conversation history rewriting). last_function_calls: Arc>>, + /// Multi-round tool call history for correct per-turn history rewriting. + /// Set by completions/responses handler, consumed by modify_request. + tool_rounds: Arc>>, // ── Cascade correlation ────────────────────────────────────────────── /// Active cascade ID set by the API layer before sending a message. @@ -223,6 +238,7 @@ impl MitmStore { pending_tool_results: Arc::new(RwLock::new(Vec::new())), call_id_to_name: Arc::new(RwLock::new(HashMap::new())), last_function_calls: Arc::new(RwLock::new(Vec::new())), + tool_rounds: Arc::new(RwLock::new(Vec::new())), active_cascade_id: Arc::new(RwLock::new(None)), captured_response_text: Arc::new(RwLock::new(None)), captured_thinking_text: Arc::new(RwLock::new(None)), @@ -511,6 +527,16 @@ impl MitmStore { self.last_function_calls.read().await.clone() } + /// Store multi-round tool call history for correct per-turn history rewriting. + pub async fn set_tool_rounds(&self, rounds: Vec) { + *self.tool_rounds.write().await = rounds; + } + + /// Take (consume) multi-round tool call history. + pub async fn take_tool_rounds(&self) -> Vec { + std::mem::take(&mut *self.tool_rounds.write().await) + } + // ── Direct response capture (bypass LS) ────────────────────────────── /// Set (replace) the captured response text.