fix: multi-round tool history rewrite and finishReason handling
- Add ToolRound struct to pair function calls with results per-round - Replace single-match history rewrite (broke after first round) with multi-round loop that rewrites ALL placeholder model turns - Fix tool result name fallback: use positional index instead of always picking the first call - Set is_complete for any finishReason (FUNCTION_CALL, MAX_TOKENS, etc.) not just STOP — prevents response_complete flag from never being set - Legacy fallback: responses.rs path (single-round via last_calls + pending_results) still works when tool_rounds is empty - Add tests: multi-round rewrite, single-round legacy, no-op, and FUNCTION_CALL/MAX_TOKENS finishReason handling
This commit is contained in:
@@ -8,7 +8,7 @@ use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use tracing::info;
|
||||
|
||||
use super::store::{CapturedFunctionCall, PendingImage, PendingToolResult};
|
||||
use super::store::{CapturedFunctionCall, PendingImage, PendingToolResult, ToolRound};
|
||||
|
||||
/// Strip ALL tool definitions.
|
||||
/// Must be true: with tools present, the LS enters full agentic mode
|
||||
@@ -30,6 +30,9 @@ pub struct ToolContext {
|
||||
pub generation_params: Option<super::store::GenerationParams>,
|
||||
/// Pending image to inject as inlineData in the user message.
|
||||
pub pending_image: Option<PendingImage>,
|
||||
/// Multi-round tool call history. Each entry is a (calls, results) pair
|
||||
/// from one round of tool use. Preferred over last_calls/pending_results.
|
||||
pub tool_rounds: Vec<ToolRound>,
|
||||
}
|
||||
|
||||
/// Modify a streamGenerateContent request body in-place.
|
||||
@@ -355,71 +358,96 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
|
||||
// ── 3b. Rewrite conversation history for tool results ────────────
|
||||
if let Some(ref ctx) = tool_ctx {
|
||||
if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() {
|
||||
// Prefer multi-round tool_rounds (set by completions handler) over
|
||||
// legacy last_calls/pending_results (set by responses handler).
|
||||
let rounds = if !ctx.tool_rounds.is_empty() {
|
||||
ctx.tool_rounds.clone()
|
||||
} else if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() {
|
||||
// Legacy single-round: wrap in one ToolRound
|
||||
vec![ToolRound {
|
||||
calls: ctx.last_calls.clone(),
|
||||
results: ctx.pending_results.clone(),
|
||||
}]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
if !rounds.is_empty() {
|
||||
if let Some(contents) = json
|
||||
.pointer_mut("/request/contents")
|
||||
.and_then(|v| v.as_array_mut())
|
||||
{
|
||||
// Find the model turn with our fake "Tool call completed" text and replace it
|
||||
// with the actual functionCall parts
|
||||
for msg in contents.iter_mut() {
|
||||
// Phase 1: find ALL model turns with placeholder text, pair with rounds
|
||||
let mut rewrites: Vec<(usize, usize)> = Vec::new(); // (content_index, round_index)
|
||||
let mut round_idx = 0;
|
||||
for (i, msg) in contents.iter().enumerate() {
|
||||
if round_idx >= rounds.len() {
|
||||
break;
|
||||
}
|
||||
if msg["role"].as_str() == Some("model") {
|
||||
if let Some(text) = msg["parts"][0]["text"].as_str() {
|
||||
if text.contains("Tool call completed")
|
||||
|| text.contains("Awaiting external tool result")
|
||||
{
|
||||
// Replace with functionCall parts
|
||||
let fc_parts: Vec<Value> = ctx
|
||||
.last_calls
|
||||
.iter()
|
||||
.map(|fc| {
|
||||
serde_json::json!({
|
||||
"functionCall": {
|
||||
"name": fc.name,
|
||||
"args": fc.args,
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
msg["parts"] = Value::Array(fc_parts);
|
||||
changes.push("rewrite model turn with functionCall".to_string());
|
||||
break;
|
||||
rewrites.push((i, round_idx));
|
||||
round_idx += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add functionResponse as a user turn before the last user message
|
||||
let fn_response_parts: Vec<Value> = ctx
|
||||
.pending_results
|
||||
.iter()
|
||||
.map(|r| {
|
||||
serde_json::json!({
|
||||
"functionResponse": {
|
||||
"name": r.name,
|
||||
"response": r.result,
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let fn_response_turn = serde_json::json!({
|
||||
"role": "user",
|
||||
"parts": fn_response_parts,
|
||||
});
|
||||
// Phase 2: apply rewrites (reverse order for stable indices during insertion)
|
||||
let mut insert_offset = 0;
|
||||
for (content_idx, round_idx) in &rewrites {
|
||||
let actual_idx = *content_idx + insert_offset;
|
||||
let round = &rounds[*round_idx];
|
||||
|
||||
// Insert before the last user message
|
||||
let last_user_idx = contents
|
||||
.iter()
|
||||
.rposition(|msg| msg["role"].as_str() == Some("user"));
|
||||
if let Some(idx) = last_user_idx {
|
||||
contents.insert(idx, fn_response_turn);
|
||||
} else {
|
||||
contents.push(fn_response_turn);
|
||||
// Replace model turn with functionCall parts
|
||||
let fc_parts: Vec<Value> = round
|
||||
.calls
|
||||
.iter()
|
||||
.map(|fc| {
|
||||
serde_json::json!({
|
||||
"functionCall": {
|
||||
"name": fc.name,
|
||||
"args": fc.args,
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
contents[actual_idx]["parts"] = Value::Array(fc_parts);
|
||||
|
||||
// Inject functionResponse user turn right after
|
||||
if !round.results.is_empty() {
|
||||
let fr_parts: Vec<Value> = round
|
||||
.results
|
||||
.iter()
|
||||
.map(|r| {
|
||||
serde_json::json!({
|
||||
"functionResponse": {
|
||||
"name": r.name,
|
||||
"response": r.result,
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
contents.insert(
|
||||
actual_idx + 1,
|
||||
serde_json::json!({
|
||||
"role": "user",
|
||||
"parts": fr_parts,
|
||||
}),
|
||||
);
|
||||
insert_offset += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if !rewrites.is_empty() {
|
||||
changes.push(format!(
|
||||
"rewrite {} tool round(s) in history",
|
||||
rewrites.len()
|
||||
));
|
||||
}
|
||||
changes.push(format!(
|
||||
"inject {} functionResponse(s)",
|
||||
ctx.pending_results.len()
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -992,6 +1020,222 @@ mod tests {
|
||||
.unwrap();
|
||||
assert_eq!(result, "keep this and this");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multi_round_history_rewrite() {
|
||||
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};
|
||||
|
||||
// Simulate 2 rounds of tool use in LS history:
|
||||
// user → model("Tool call completed") → user(text) → model("Tool call completed") → user(text)
|
||||
let body = serde_json::json!({
|
||||
"project": "test",
|
||||
"requestId": "test/1",
|
||||
"request": {
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": "Read foo and write to bar"}]},
|
||||
{"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]},
|
||||
{"role": "user", "parts": [{"text": "[Tool result: file contents here]"}]},
|
||||
{"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]},
|
||||
{"role": "user", "parts": [{"text": "[Tool result: write success]"}]},
|
||||
],
|
||||
"tools": [],
|
||||
"generationConfig": {}
|
||||
},
|
||||
"model": "test"
|
||||
});
|
||||
|
||||
let tool_ctx = ToolContext {
|
||||
tools: Some(vec![serde_json::json!({
|
||||
"functionDeclarations": [{
|
||||
"name": "read_file",
|
||||
"description": "Read a file",
|
||||
"parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}}}
|
||||
}, {
|
||||
"name": "write_file",
|
||||
"description": "Write a file",
|
||||
"parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}, "content": {"type": "STRING"}}}
|
||||
}]
|
||||
})]),
|
||||
tool_config: None,
|
||||
pending_results: vec![],
|
||||
last_calls: vec![],
|
||||
generation_params: None,
|
||||
pending_image: None,
|
||||
tool_rounds: vec![
|
||||
ToolRound {
|
||||
calls: vec![CapturedFunctionCall {
|
||||
name: "read_file".to_string(),
|
||||
args: serde_json::json!({"path": "/foo"}),
|
||||
captured_at: 0,
|
||||
}],
|
||||
results: vec![PendingToolResult {
|
||||
name: "read_file".to_string(),
|
||||
result: serde_json::json!({"content": "file contents here"}),
|
||||
}],
|
||||
},
|
||||
ToolRound {
|
||||
calls: vec![CapturedFunctionCall {
|
||||
name: "write_file".to_string(),
|
||||
args: serde_json::json!({"path": "/bar", "content": "data"}),
|
||||
captured_at: 0,
|
||||
}],
|
||||
results: vec![PendingToolResult {
|
||||
name: "write_file".to_string(),
|
||||
result: serde_json::json!({"result": "ok"}),
|
||||
}],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let bytes = serde_json::to_vec(&body).unwrap();
|
||||
let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap();
|
||||
let result: Value = serde_json::from_slice(&modified).unwrap();
|
||||
let contents = result["request"]["contents"].as_array().unwrap();
|
||||
|
||||
// Expected layout after rewrite:
|
||||
// [0] user: "Read foo..."
|
||||
// [1] model: functionCall(read_file) (was "Tool call completed")
|
||||
// [2] user: functionResponse(read_file) (injected)
|
||||
// [3] user: "[Tool result: file contents]" (original LS turn)
|
||||
// [4] model: functionCall(write_file) (was "Tool call completed")
|
||||
// [5] user: functionResponse(write_file) (injected)
|
||||
// [6] user: "[Tool result: write success]" (original LS turn)
|
||||
assert_eq!(contents.len(), 7, "should have 7 turns (5 original + 2 injected)");
|
||||
|
||||
// Check round 1: model turn rewritten to functionCall
|
||||
assert_eq!(
|
||||
contents[1]["parts"][0]["functionCall"]["name"].as_str().unwrap(),
|
||||
"read_file"
|
||||
);
|
||||
assert_eq!(
|
||||
contents[1]["parts"][0]["functionCall"]["args"]["path"].as_str().unwrap(),
|
||||
"/foo"
|
||||
);
|
||||
// Check round 1: functionResponse injected
|
||||
assert_eq!(
|
||||
contents[2]["role"].as_str().unwrap(),
|
||||
"user"
|
||||
);
|
||||
assert_eq!(
|
||||
contents[2]["parts"][0]["functionResponse"]["name"].as_str().unwrap(),
|
||||
"read_file"
|
||||
);
|
||||
|
||||
// Check round 2: model turn rewritten to functionCall
|
||||
assert_eq!(
|
||||
contents[4]["parts"][0]["functionCall"]["name"].as_str().unwrap(),
|
||||
"write_file"
|
||||
);
|
||||
// Check round 2: functionResponse injected
|
||||
assert_eq!(
|
||||
contents[5]["parts"][0]["functionResponse"]["name"].as_str().unwrap(),
|
||||
"write_file"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_round_legacy_fallback() {
|
||||
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult};
|
||||
|
||||
// Simulate single round using legacy last_calls/pending_results (no tool_rounds).
|
||||
// This is the path used by responses.rs.
|
||||
let body = serde_json::json!({
|
||||
"project": "test",
|
||||
"requestId": "test/1",
|
||||
"request": {
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": "Search for X"}]},
|
||||
{"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]},
|
||||
{"role": "user", "parts": [{"text": "[Tool result: found X]"}]},
|
||||
],
|
||||
"tools": [],
|
||||
"generationConfig": {}
|
||||
},
|
||||
"model": "test"
|
||||
});
|
||||
|
||||
let tool_ctx = ToolContext {
|
||||
tools: Some(vec![serde_json::json!({
|
||||
"functionDeclarations": [{
|
||||
"name": "search",
|
||||
"description": "Search",
|
||||
"parameters": {"type": "OBJECT", "properties": {"q": {"type": "STRING"}}}
|
||||
}]
|
||||
})]),
|
||||
tool_config: None,
|
||||
pending_results: vec![PendingToolResult {
|
||||
name: "search".to_string(),
|
||||
result: serde_json::json!({"results": ["x"]}),
|
||||
}],
|
||||
last_calls: vec![CapturedFunctionCall {
|
||||
name: "search".to_string(),
|
||||
args: serde_json::json!({"q": "X"}),
|
||||
captured_at: 0,
|
||||
}],
|
||||
generation_params: None,
|
||||
pending_image: None,
|
||||
tool_rounds: vec![], // Empty — forces legacy fallback
|
||||
};
|
||||
|
||||
let bytes = serde_json::to_vec(&body).unwrap();
|
||||
let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap();
|
||||
let result: Value = serde_json::from_slice(&modified).unwrap();
|
||||
let contents = result["request"]["contents"].as_array().unwrap();
|
||||
|
||||
// Should still work: model turn rewritten + functionResponse injected
|
||||
assert_eq!(contents.len(), 4, "should have 4 turns (3 original + 1 injected)");
|
||||
assert_eq!(
|
||||
contents[1]["parts"][0]["functionCall"]["name"].as_str().unwrap(),
|
||||
"search"
|
||||
);
|
||||
assert_eq!(
|
||||
contents[2]["parts"][0]["functionResponse"]["name"].as_str().unwrap(),
|
||||
"search"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_tool_rounds_no_rewrite() {
|
||||
// No tool rounds, no legacy data — no rewriting should happen
|
||||
let body = serde_json::json!({
|
||||
"project": "test",
|
||||
"requestId": "test/1",
|
||||
"request": {
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": "Hello"}]},
|
||||
{"role": "model", "parts": [{"text": "Hi there!"}]},
|
||||
],
|
||||
"tools": [],
|
||||
"generationConfig": {}
|
||||
},
|
||||
"model": "test"
|
||||
});
|
||||
|
||||
let tool_ctx = ToolContext {
|
||||
tools: Some(vec![serde_json::json!({
|
||||
"functionDeclarations": [{
|
||||
"name": "noop",
|
||||
"description": "Does nothing",
|
||||
"parameters": {"type": "OBJECT", "properties": {}}
|
||||
}]
|
||||
})]),
|
||||
tool_config: None,
|
||||
pending_results: vec![],
|
||||
last_calls: vec![],
|
||||
generation_params: None,
|
||||
pending_image: None,
|
||||
tool_rounds: vec![],
|
||||
};
|
||||
|
||||
let bytes = serde_json::to_vec(&body).unwrap();
|
||||
let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap();
|
||||
let result: Value = serde_json::from_slice(&modified).unwrap();
|
||||
let contents = result["request"]["contents"].as_array().unwrap();
|
||||
|
||||
// No rewriting — same number of turns
|
||||
assert_eq!(contents.len(), 2);
|
||||
assert_eq!(contents[1]["parts"][0]["text"].as_str().unwrap(), "Hi there!");
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Response modification ──────────────────────────────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user