fix: multi-round tool history rewrite and finishReason handling

- Add ToolRound struct to pair function calls with results per-round
- Replace single-match history rewrite (broke after first round) with
  multi-round loop that rewrites ALL placeholder model turns
- Fix tool result name fallback: use positional index instead of always
  picking the first call
- Set is_complete for any finishReason (FUNCTION_CALL, MAX_TOKENS, etc.)
  not just STOP — prevents response_complete flag from never being set
- Legacy fallback: responses.rs path (single-round via last_calls +
  pending_results) still works when tool_rounds is empty
- Add tests: multi-round rewrite, single-round legacy, no-op, and
  FUNCTION_CALL/MAX_TOKENS finishReason handling
This commit is contained in:
Nikketryhard
2026-02-16 19:05:37 -06:00
parent 6bda2ecafa
commit 39381a4dfe
5 changed files with 410 additions and 88 deletions

View File

@@ -8,7 +8,7 @@ use regex::Regex;
use serde_json::Value;
use tracing::info;
use super::store::{CapturedFunctionCall, PendingImage, PendingToolResult};
use super::store::{CapturedFunctionCall, PendingImage, PendingToolResult, ToolRound};
/// Strip ALL tool definitions.
/// Must be true: with tools present, the LS enters full agentic mode
@@ -30,6 +30,9 @@ pub struct ToolContext {
pub generation_params: Option<super::store::GenerationParams>,
/// Pending image to inject as inlineData in the user message.
pub pending_image: Option<PendingImage>,
/// Multi-round tool call history. Each entry is a (calls, results) pair
/// from one round of tool use. Preferred over last_calls/pending_results.
pub tool_rounds: Vec<ToolRound>,
}
/// Modify a streamGenerateContent request body in-place.
@@ -355,71 +358,96 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
// ── 3b. Rewrite conversation history for tool results ────────────
if let Some(ref ctx) = tool_ctx {
if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() {
// Prefer multi-round tool_rounds (set by completions handler) over
// legacy last_calls/pending_results (set by responses handler).
let rounds = if !ctx.tool_rounds.is_empty() {
ctx.tool_rounds.clone()
} else if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() {
// Legacy single-round: wrap in one ToolRound
vec![ToolRound {
calls: ctx.last_calls.clone(),
results: ctx.pending_results.clone(),
}]
} else {
vec![]
};
if !rounds.is_empty() {
if let Some(contents) = json
.pointer_mut("/request/contents")
.and_then(|v| v.as_array_mut())
{
// Find the model turn with our fake "Tool call completed" text and replace it
// with the actual functionCall parts
for msg in contents.iter_mut() {
// Phase 1: find ALL model turns with placeholder text, pair with rounds
let mut rewrites: Vec<(usize, usize)> = Vec::new(); // (content_index, round_index)
let mut round_idx = 0;
for (i, msg) in contents.iter().enumerate() {
if round_idx >= rounds.len() {
break;
}
if msg["role"].as_str() == Some("model") {
if let Some(text) = msg["parts"][0]["text"].as_str() {
if text.contains("Tool call completed")
|| text.contains("Awaiting external tool result")
{
// Replace with functionCall parts
let fc_parts: Vec<Value> = ctx
.last_calls
.iter()
.map(|fc| {
serde_json::json!({
"functionCall": {
"name": fc.name,
"args": fc.args,
}
})
})
.collect();
msg["parts"] = Value::Array(fc_parts);
changes.push("rewrite model turn with functionCall".to_string());
break;
rewrites.push((i, round_idx));
round_idx += 1;
}
}
}
}
// Add functionResponse as a user turn before the last user message
let fn_response_parts: Vec<Value> = ctx
.pending_results
.iter()
.map(|r| {
serde_json::json!({
"functionResponse": {
"name": r.name,
"response": r.result,
}
})
})
.collect();
let fn_response_turn = serde_json::json!({
"role": "user",
"parts": fn_response_parts,
});
// Phase 2: apply rewrites (reverse order for stable indices during insertion)
let mut insert_offset = 0;
for (content_idx, round_idx) in &rewrites {
let actual_idx = *content_idx + insert_offset;
let round = &rounds[*round_idx];
// Insert before the last user message
let last_user_idx = contents
.iter()
.rposition(|msg| msg["role"].as_str() == Some("user"));
if let Some(idx) = last_user_idx {
contents.insert(idx, fn_response_turn);
} else {
contents.push(fn_response_turn);
// Replace model turn with functionCall parts
let fc_parts: Vec<Value> = round
.calls
.iter()
.map(|fc| {
serde_json::json!({
"functionCall": {
"name": fc.name,
"args": fc.args,
}
})
})
.collect();
contents[actual_idx]["parts"] = Value::Array(fc_parts);
// Inject functionResponse user turn right after
if !round.results.is_empty() {
let fr_parts: Vec<Value> = round
.results
.iter()
.map(|r| {
serde_json::json!({
"functionResponse": {
"name": r.name,
"response": r.result,
}
})
})
.collect();
contents.insert(
actual_idx + 1,
serde_json::json!({
"role": "user",
"parts": fr_parts,
}),
);
insert_offset += 1;
}
}
if !rewrites.is_empty() {
changes.push(format!(
"rewrite {} tool round(s) in history",
rewrites.len()
));
}
changes.push(format!(
"inject {} functionResponse(s)",
ctx.pending_results.len()
));
}
}
}
@@ -992,6 +1020,222 @@ mod tests {
.unwrap();
assert_eq!(result, "keep this and this");
}
#[test]
fn test_multi_round_history_rewrite() {
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};
// Simulate 2 rounds of tool use in LS history:
// user → model("Tool call completed") → user(text) → model("Tool call completed") → user(text)
let body = serde_json::json!({
"project": "test",
"requestId": "test/1",
"request": {
"contents": [
{"role": "user", "parts": [{"text": "Read foo and write to bar"}]},
{"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]},
{"role": "user", "parts": [{"text": "[Tool result: file contents here]"}]},
{"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]},
{"role": "user", "parts": [{"text": "[Tool result: write success]"}]},
],
"tools": [],
"generationConfig": {}
},
"model": "test"
});
let tool_ctx = ToolContext {
tools: Some(vec![serde_json::json!({
"functionDeclarations": [{
"name": "read_file",
"description": "Read a file",
"parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}}}
}, {
"name": "write_file",
"description": "Write a file",
"parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}, "content": {"type": "STRING"}}}
}]
})]),
tool_config: None,
pending_results: vec![],
last_calls: vec![],
generation_params: None,
pending_image: None,
tool_rounds: vec![
ToolRound {
calls: vec![CapturedFunctionCall {
name: "read_file".to_string(),
args: serde_json::json!({"path": "/foo"}),
captured_at: 0,
}],
results: vec![PendingToolResult {
name: "read_file".to_string(),
result: serde_json::json!({"content": "file contents here"}),
}],
},
ToolRound {
calls: vec![CapturedFunctionCall {
name: "write_file".to_string(),
args: serde_json::json!({"path": "/bar", "content": "data"}),
captured_at: 0,
}],
results: vec![PendingToolResult {
name: "write_file".to_string(),
result: serde_json::json!({"result": "ok"}),
}],
},
],
};
let bytes = serde_json::to_vec(&body).unwrap();
let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap();
let result: Value = serde_json::from_slice(&modified).unwrap();
let contents = result["request"]["contents"].as_array().unwrap();
// Expected layout after rewrite:
// [0] user: "Read foo..."
// [1] model: functionCall(read_file) (was "Tool call completed")
// [2] user: functionResponse(read_file) (injected)
// [3] user: "[Tool result: file contents]" (original LS turn)
// [4] model: functionCall(write_file) (was "Tool call completed")
// [5] user: functionResponse(write_file) (injected)
// [6] user: "[Tool result: write success]" (original LS turn)
assert_eq!(contents.len(), 7, "should have 7 turns (5 original + 2 injected)");
// Check round 1: model turn rewritten to functionCall
assert_eq!(
contents[1]["parts"][0]["functionCall"]["name"].as_str().unwrap(),
"read_file"
);
assert_eq!(
contents[1]["parts"][0]["functionCall"]["args"]["path"].as_str().unwrap(),
"/foo"
);
// Check round 1: functionResponse injected
assert_eq!(
contents[2]["role"].as_str().unwrap(),
"user"
);
assert_eq!(
contents[2]["parts"][0]["functionResponse"]["name"].as_str().unwrap(),
"read_file"
);
// Check round 2: model turn rewritten to functionCall
assert_eq!(
contents[4]["parts"][0]["functionCall"]["name"].as_str().unwrap(),
"write_file"
);
// Check round 2: functionResponse injected
assert_eq!(
contents[5]["parts"][0]["functionResponse"]["name"].as_str().unwrap(),
"write_file"
);
}
#[test]
fn test_single_round_legacy_fallback() {
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult};
// Simulate single round using legacy last_calls/pending_results (no tool_rounds).
// This is the path used by responses.rs.
let body = serde_json::json!({
"project": "test",
"requestId": "test/1",
"request": {
"contents": [
{"role": "user", "parts": [{"text": "Search for X"}]},
{"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]},
{"role": "user", "parts": [{"text": "[Tool result: found X]"}]},
],
"tools": [],
"generationConfig": {}
},
"model": "test"
});
let tool_ctx = ToolContext {
tools: Some(vec![serde_json::json!({
"functionDeclarations": [{
"name": "search",
"description": "Search",
"parameters": {"type": "OBJECT", "properties": {"q": {"type": "STRING"}}}
}]
})]),
tool_config: None,
pending_results: vec![PendingToolResult {
name: "search".to_string(),
result: serde_json::json!({"results": ["x"]}),
}],
last_calls: vec![CapturedFunctionCall {
name: "search".to_string(),
args: serde_json::json!({"q": "X"}),
captured_at: 0,
}],
generation_params: None,
pending_image: None,
tool_rounds: vec![], // Empty — forces legacy fallback
};
let bytes = serde_json::to_vec(&body).unwrap();
let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap();
let result: Value = serde_json::from_slice(&modified).unwrap();
let contents = result["request"]["contents"].as_array().unwrap();
// Should still work: model turn rewritten + functionResponse injected
assert_eq!(contents.len(), 4, "should have 4 turns (3 original + 1 injected)");
assert_eq!(
contents[1]["parts"][0]["functionCall"]["name"].as_str().unwrap(),
"search"
);
assert_eq!(
contents[2]["parts"][0]["functionResponse"]["name"].as_str().unwrap(),
"search"
);
}
#[test]
fn test_no_tool_rounds_no_rewrite() {
// No tool rounds, no legacy data — no rewriting should happen
let body = serde_json::json!({
"project": "test",
"requestId": "test/1",
"request": {
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]},
{"role": "model", "parts": [{"text": "Hi there!"}]},
],
"tools": [],
"generationConfig": {}
},
"model": "test"
});
let tool_ctx = ToolContext {
tools: Some(vec![serde_json::json!({
"functionDeclarations": [{
"name": "noop",
"description": "Does nothing",
"parameters": {"type": "OBJECT", "properties": {}}
}]
})]),
tool_config: None,
pending_results: vec![],
last_calls: vec![],
generation_params: None,
pending_image: None,
tool_rounds: vec![],
};
let bytes = serde_json::to_vec(&body).unwrap();
let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap();
let result: Value = serde_json::from_slice(&modified).unwrap();
let contents = result["request"]["contents"].as_array().unwrap();
// No rewriting — same number of turns
assert_eq!(contents.len(), 2);
assert_eq!(contents[1]["parts"][0]["text"].as_str().unwrap(), "Hi there!");
}
}
// ─── Response modification ──────────────────────────────────────────────────