fix: multi-round tool history rewrite and finishReason handling

- Add ToolRound struct to pair function calls with results per-round
- Replace single-match history rewrite (broke after first round) with
  multi-round loop that rewrites ALL placeholder model turns
- Fix tool result name fallback: use positional index instead of always
  picking the first call
- Set is_complete for any finishReason (FUNCTION_CALL, MAX_TOKENS, etc.)
  not just STOP — prevents response_complete flag from never being set
- Legacy fallback: responses.rs path (single-round via last_calls +
  pending_results) still works when tool_rounds is empty
- Add tests: multi-round rewrite, single-round legacy, no-op, and
  FUNCTION_CALL/MAX_TOKENS finishReason handling
This commit is contained in:
Nikketryhard
2026-02-16 19:05:37 -06:00
parent 6bda2ecafa
commit 39381a4dfe
5 changed files with 410 additions and 88 deletions

View File

@@ -16,7 +16,7 @@ use super::polling::{
use super::types::*; use super::types::*;
use super::util::{err_response, now_unix, upstream_err_response}; use super::util::{err_response, now_unix, upstream_err_response};
use super::AppState; use super::AppState;
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult}; use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};
/// Extract a conversation/session ID from a flexible JSON value. /// Extract a conversation/session ID from a flexible JSON value.
/// Accepts a plain string or an object with an "id" field. /// Accepts a plain string or an object with an "id" field.
@@ -229,18 +229,25 @@ pub(crate) async fn handle_completions(
// When OpenCode sends back tool results, the messages array contains: // When OpenCode sends back tool results, the messages array contains:
// 1. assistant message with tool_calls (the model's previous function calls) // 1. assistant message with tool_calls (the model's previous function calls)
// 2. tool messages with results (the executed tool outputs) // 2. tool messages with results (the executed tool outputs)
// We need to store these so modify_request can rewrite the LS's // We build ToolRounds: each round pairs one assistant's tool_calls with
// conversation history with proper functionCall/functionResponse parts // the subsequent tool result messages. This enables correct per-turn
// instead of the placeholder "Tool call completed" text. // history rewriting for multi-step tool use.
{ {
let mut last_calls: Vec<CapturedFunctionCall> = Vec::new(); let mut rounds: Vec<ToolRound> = Vec::new();
let mut pending_results: Vec<PendingToolResult> = Vec::new(); let mut current_round: Option<ToolRound> = None;
for msg in &body.messages { for msg in &body.messages {
match msg.role.as_str() { match msg.role.as_str() {
"assistant" => { "assistant" => {
// Extract function calls from assistant's tool_calls // Finalize any open round
if let Some(round) = current_round.take() {
if !round.calls.is_empty() {
rounds.push(round);
}
}
// Start new round if this assistant has tool_calls
if let Some(ref tool_calls) = msg.tool_calls { if let Some(ref tool_calls) = msg.tool_calls {
let mut calls = Vec::new();
for tc in tool_calls { for tc in tool_calls {
if let Some(func) = tc.get("function") { if let Some(func) = tc.get("function") {
let name = func["name"].as_str().unwrap_or("unknown").to_string(); let name = func["name"].as_str().unwrap_or("unknown").to_string();
@@ -254,7 +261,7 @@ pub(crate) async fn handle_completions(
state.mitm_store.register_call_id(call_id, name.clone()).await; state.mitm_store.register_call_id(call_id, name.clone()).await;
} }
last_calls.push(CapturedFunctionCall { calls.push(CapturedFunctionCall {
name, name,
args, args,
captured_at: std::time::SystemTime::now() captured_at: std::time::SystemTime::now()
@@ -264,21 +271,31 @@ pub(crate) async fn handle_completions(
}); });
} }
} }
if !calls.is_empty() {
current_round = Some(ToolRound {
calls,
results: Vec::new(),
});
}
} }
} }
"tool" => { "tool" => {
// Extract tool results
let text = extract_message_text(&msg.content); let text = extract_message_text(&msg.content);
if let Some(ref call_id) = msg.tool_call_id { if let Some(ref call_id) = msg.tool_call_id {
// Look up function name from call_id // Look up function name from call_id, fall back to
// positional index within the current round's calls
let result_index = current_round
.as_ref()
.map(|r| r.results.len())
.unwrap_or(0);
let name = state let name = state
.mitm_store .mitm_store
.lookup_call_id(call_id) .lookup_call_id(call_id)
.await .await
.unwrap_or_else(|| { .unwrap_or_else(|| {
// Fallback: try to find the name from last_calls by position current_round
last_calls .as_ref()
.first() .and_then(|r| r.calls.get(result_index))
.map(|fc| fc.name.clone()) .map(|fc| fc.name.clone())
.unwrap_or_else(|| "unknown_function".to_string()) .unwrap_or_else(|| "unknown_function".to_string())
}); });
@@ -286,35 +303,43 @@ pub(crate) async fn handle_completions(
let result_value = serde_json::from_str::<serde_json::Value>(&text) let result_value = serde_json::from_str::<serde_json::Value>(&text)
.unwrap_or_else(|_| serde_json::json!({"result": text})); .unwrap_or_else(|_| serde_json::json!({"result": text}));
pending_results.push(PendingToolResult { if let Some(ref mut round) = current_round {
round.results.push(PendingToolResult {
name, name,
result: result_value, result: result_value,
}); });
} }
} }
_ => {} }
_ => {
// Any other role (user, system) finalizes the current round
if let Some(round) = current_round.take() {
if !round.calls.is_empty() {
rounds.push(round);
}
}
}
}
}
// Finalize last round
if let Some(round) = current_round.take() {
if !round.calls.is_empty() {
rounds.push(round);
} }
} }
if !last_calls.is_empty() { if !rounds.is_empty() {
info!( info!(
count = last_calls.len(), round_count = rounds.len(),
tools = ?last_calls.iter().map(|c| &c.name).collect::<Vec<_>>(), calls = ?rounds.iter().map(|r| r.calls.iter().map(|c| &c.name).collect::<Vec<_>>()).collect::<Vec<_>>(),
"Completions: stored last function calls for MITM history rewrite" "Completions: stored {} tool round(s) for MITM history rewrite",
rounds.len(),
); );
state.mitm_store.set_last_function_calls(last_calls).await; // Also set last_function_calls from the latest round for proxy.rs recording compat
if let Some(last_round) = rounds.last() {
state.mitm_store.set_last_function_calls(last_round.calls.clone()).await;
} }
state.mitm_store.set_tool_rounds(rounds).await;
if !pending_results.is_empty() {
info!(
count = pending_results.len(),
tools = ?pending_results.iter().map(|r| &r.name).collect::<Vec<_>>(),
"Completions: stored tool results for MITM injection"
);
for result in pending_results {
state.mitm_store.add_tool_result(result).await;
}
// Clear awaiting flag — we have the results now
state.mitm_store.clear_awaiting_tool_result(); state.mitm_store.clear_awaiting_tool_result();
} }
} }

View File

@@ -153,12 +153,10 @@ impl StreamingAccumulator {
} }
} }
} }
// Check for completion // Check for completion — any finishReason means response is done
if let Some(reason) = candidate["finishReason"].as_str() { if let Some(reason) = candidate["finishReason"].as_str() {
self.stop_reason = Some(reason.to_string()); self.stop_reason = Some(reason.to_string());
if reason == "STOP" {
self.is_complete = true; self.is_complete = true;
}
// Log non-STOP finish reasons // Log non-STOP finish reasons
if reason != "STOP" { if reason != "STOP" {
info!(finish_reason = reason, "MITM: non-STOP finish reason"); info!(finish_reason = reason, "MITM: non-STOP finish reason");
@@ -589,4 +587,30 @@ data: {"response": {"candidates": [{"content": {"role": "model","parts": [{"text
); );
assert_eq!(acc.stop_reason, Some("STOP".to_string())); assert_eq!(acc.stop_reason, Some("STOP".to_string()));
} }
#[test]
fn test_function_call_finish_reason_sets_complete() {
let mut acc = StreamingAccumulator::new();
let event = "data: {\"response\": {\"candidates\": [{\"content\": {\"role\": \"model\", \"parts\": [{\"functionCall\": {\"name\": \"read_file\", \"args\": {\"path\": \"/foo\"}}}]}, \"finishReason\": \"FUNCTION_CALL\"}], \"usageMetadata\": {\"promptTokenCount\": 50, \"candidatesTokenCount\": 5, \"totalTokenCount\": 55}, \"modelVersion\": \"gemini-3-flash\"}}\n";
parse_streaming_chunk(event, &mut acc);
assert!(acc.is_complete, "FUNCTION_CALL finishReason should set is_complete");
assert_eq!(acc.stop_reason, Some("FUNCTION_CALL".to_string()));
assert_eq!(acc.function_calls.len(), 1);
assert_eq!(acc.function_calls[0].name, "read_file");
assert_eq!(acc.output_tokens, 5);
}
#[test]
fn test_max_tokens_finish_reason_sets_complete() {
let mut acc = StreamingAccumulator::new();
let event = "data: {\"response\": {\"candidates\": [{\"content\": {\"role\": \"model\", \"parts\": [{\"text\": \"truncated...\"}]}, \"finishReason\": \"MAX_TOKENS\"}], \"usageMetadata\": {\"promptTokenCount\": 50, \"candidatesTokenCount\": 100, \"totalTokenCount\": 150}}}\n";
parse_streaming_chunk(event, &mut acc);
assert!(acc.is_complete, "MAX_TOKENS finishReason should set is_complete");
assert_eq!(acc.stop_reason, Some("MAX_TOKENS".to_string()));
assert_eq!(acc.response_text, "truncated...");
}
} }

View File

@@ -8,7 +8,7 @@ use regex::Regex;
use serde_json::Value; use serde_json::Value;
use tracing::info; use tracing::info;
use super::store::{CapturedFunctionCall, PendingImage, PendingToolResult}; use super::store::{CapturedFunctionCall, PendingImage, PendingToolResult, ToolRound};
/// Strip ALL tool definitions. /// Strip ALL tool definitions.
/// Must be true: with tools present, the LS enters full agentic mode /// Must be true: with tools present, the LS enters full agentic mode
@@ -30,6 +30,9 @@ pub struct ToolContext {
pub generation_params: Option<super::store::GenerationParams>, pub generation_params: Option<super::store::GenerationParams>,
/// Pending image to inject as inlineData in the user message. /// Pending image to inject as inlineData in the user message.
pub pending_image: Option<PendingImage>, pub pending_image: Option<PendingImage>,
/// Multi-round tool call history. Each entry is a (calls, results) pair
/// from one round of tool use. Preferred over last_calls/pending_results.
pub tool_rounds: Vec<ToolRound>,
} }
/// Modify a streamGenerateContent request body in-place. /// Modify a streamGenerateContent request body in-place.
@@ -355,22 +358,53 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
// ── 3b. Rewrite conversation history for tool results ──────────── // ── 3b. Rewrite conversation history for tool results ────────────
if let Some(ref ctx) = tool_ctx { if let Some(ref ctx) = tool_ctx {
if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() { // Prefer multi-round tool_rounds (set by completions handler) over
// legacy last_calls/pending_results (set by responses handler).
let rounds = if !ctx.tool_rounds.is_empty() {
ctx.tool_rounds.clone()
} else if !ctx.pending_results.is_empty() && !ctx.last_calls.is_empty() {
// Legacy single-round: wrap in one ToolRound
vec![ToolRound {
calls: ctx.last_calls.clone(),
results: ctx.pending_results.clone(),
}]
} else {
vec![]
};
if !rounds.is_empty() {
if let Some(contents) = json if let Some(contents) = json
.pointer_mut("/request/contents") .pointer_mut("/request/contents")
.and_then(|v| v.as_array_mut()) .and_then(|v| v.as_array_mut())
{ {
// Find the model turn with our fake "Tool call completed" text and replace it // Phase 1: find ALL model turns with placeholder text, pair with rounds
// with the actual functionCall parts let mut rewrites: Vec<(usize, usize)> = Vec::new(); // (content_index, round_index)
for msg in contents.iter_mut() { let mut round_idx = 0;
for (i, msg) in contents.iter().enumerate() {
if round_idx >= rounds.len() {
break;
}
if msg["role"].as_str() == Some("model") { if msg["role"].as_str() == Some("model") {
if let Some(text) = msg["parts"][0]["text"].as_str() { if let Some(text) = msg["parts"][0]["text"].as_str() {
if text.contains("Tool call completed") if text.contains("Tool call completed")
|| text.contains("Awaiting external tool result") || text.contains("Awaiting external tool result")
{ {
// Replace with functionCall parts rewrites.push((i, round_idx));
let fc_parts: Vec<Value> = ctx round_idx += 1;
.last_calls }
}
}
}
// Phase 2: apply rewrites (reverse order for stable indices during insertion)
let mut insert_offset = 0;
for (content_idx, round_idx) in &rewrites {
let actual_idx = *content_idx + insert_offset;
let round = &rounds[*round_idx];
// Replace model turn with functionCall parts
let fc_parts: Vec<Value> = round
.calls
.iter() .iter()
.map(|fc| { .map(|fc| {
serde_json::json!({ serde_json::json!({
@@ -381,17 +415,12 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
}) })
}) })
.collect(); .collect();
msg["parts"] = Value::Array(fc_parts); contents[actual_idx]["parts"] = Value::Array(fc_parts);
changes.push("rewrite model turn with functionCall".to_string());
break;
}
}
}
}
// Add functionResponse as a user turn before the last user message // Inject functionResponse user turn right after
let fn_response_parts: Vec<Value> = ctx if !round.results.is_empty() {
.pending_results let fr_parts: Vec<Value> = round
.results
.iter() .iter()
.map(|r| { .map(|r| {
serde_json::json!({ serde_json::json!({
@@ -402,27 +431,26 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
}) })
}) })
.collect(); .collect();
let fn_response_turn = serde_json::json!({ contents.insert(
actual_idx + 1,
serde_json::json!({
"role": "user", "role": "user",
"parts": fn_response_parts, "parts": fr_parts,
}); }),
);
// Insert before the last user message insert_offset += 1;
let last_user_idx = contents
.iter()
.rposition(|msg| msg["role"].as_str() == Some("user"));
if let Some(idx) = last_user_idx {
contents.insert(idx, fn_response_turn);
} else {
contents.push(fn_response_turn);
} }
}
if !rewrites.is_empty() {
changes.push(format!( changes.push(format!(
"inject {} functionResponse(s)", "rewrite {} tool round(s) in history",
ctx.pending_results.len() rewrites.len()
)); ));
} }
} }
} }
}
// ── 4. Inject includeThoughts to capture thinking text ─────────────── // ── 4. Inject includeThoughts to capture thinking text ───────────────
// Without this flag, Google only reports thinking token counts // Without this flag, Google only reports thinking token counts
@@ -992,6 +1020,222 @@ mod tests {
.unwrap(); .unwrap();
assert_eq!(result, "keep this and this"); assert_eq!(result, "keep this and this");
} }
#[test]
fn test_multi_round_history_rewrite() {
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};
// Simulate 2 rounds of tool use in LS history:
// user → model("Tool call completed") → user(text) → model("Tool call completed") → user(text)
let body = serde_json::json!({
"project": "test",
"requestId": "test/1",
"request": {
"contents": [
{"role": "user", "parts": [{"text": "Read foo and write to bar"}]},
{"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]},
{"role": "user", "parts": [{"text": "[Tool result: file contents here]"}]},
{"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]},
{"role": "user", "parts": [{"text": "[Tool result: write success]"}]},
],
"tools": [],
"generationConfig": {}
},
"model": "test"
});
let tool_ctx = ToolContext {
tools: Some(vec![serde_json::json!({
"functionDeclarations": [{
"name": "read_file",
"description": "Read a file",
"parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}}}
}, {
"name": "write_file",
"description": "Write a file",
"parameters": {"type": "OBJECT", "properties": {"path": {"type": "STRING"}, "content": {"type": "STRING"}}}
}]
})]),
tool_config: None,
pending_results: vec![],
last_calls: vec![],
generation_params: None,
pending_image: None,
tool_rounds: vec![
ToolRound {
calls: vec![CapturedFunctionCall {
name: "read_file".to_string(),
args: serde_json::json!({"path": "/foo"}),
captured_at: 0,
}],
results: vec![PendingToolResult {
name: "read_file".to_string(),
result: serde_json::json!({"content": "file contents here"}),
}],
},
ToolRound {
calls: vec![CapturedFunctionCall {
name: "write_file".to_string(),
args: serde_json::json!({"path": "/bar", "content": "data"}),
captured_at: 0,
}],
results: vec![PendingToolResult {
name: "write_file".to_string(),
result: serde_json::json!({"result": "ok"}),
}],
},
],
};
let bytes = serde_json::to_vec(&body).unwrap();
let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap();
let result: Value = serde_json::from_slice(&modified).unwrap();
let contents = result["request"]["contents"].as_array().unwrap();
// Expected layout after rewrite:
// [0] user: "Read foo..."
// [1] model: functionCall(read_file) (was "Tool call completed")
// [2] user: functionResponse(read_file) (injected)
// [3] user: "[Tool result: file contents]" (original LS turn)
// [4] model: functionCall(write_file) (was "Tool call completed")
// [5] user: functionResponse(write_file) (injected)
// [6] user: "[Tool result: write success]" (original LS turn)
assert_eq!(contents.len(), 7, "should have 7 turns (5 original + 2 injected)");
// Check round 1: model turn rewritten to functionCall
assert_eq!(
contents[1]["parts"][0]["functionCall"]["name"].as_str().unwrap(),
"read_file"
);
assert_eq!(
contents[1]["parts"][0]["functionCall"]["args"]["path"].as_str().unwrap(),
"/foo"
);
// Check round 1: functionResponse injected
assert_eq!(
contents[2]["role"].as_str().unwrap(),
"user"
);
assert_eq!(
contents[2]["parts"][0]["functionResponse"]["name"].as_str().unwrap(),
"read_file"
);
// Check round 2: model turn rewritten to functionCall
assert_eq!(
contents[4]["parts"][0]["functionCall"]["name"].as_str().unwrap(),
"write_file"
);
// Check round 2: functionResponse injected
assert_eq!(
contents[5]["parts"][0]["functionResponse"]["name"].as_str().unwrap(),
"write_file"
);
}
#[test]
fn test_single_round_legacy_fallback() {
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult};
// Simulate single round using legacy last_calls/pending_results (no tool_rounds).
// This is the path used by responses.rs.
let body = serde_json::json!({
"project": "test",
"requestId": "test/1",
"request": {
"contents": [
{"role": "user", "parts": [{"text": "Search for X"}]},
{"role": "model", "parts": [{"text": "Tool call completed. Awaiting external tool result."}]},
{"role": "user", "parts": [{"text": "[Tool result: found X]"}]},
],
"tools": [],
"generationConfig": {}
},
"model": "test"
});
let tool_ctx = ToolContext {
tools: Some(vec![serde_json::json!({
"functionDeclarations": [{
"name": "search",
"description": "Search",
"parameters": {"type": "OBJECT", "properties": {"q": {"type": "STRING"}}}
}]
})]),
tool_config: None,
pending_results: vec![PendingToolResult {
name: "search".to_string(),
result: serde_json::json!({"results": ["x"]}),
}],
last_calls: vec![CapturedFunctionCall {
name: "search".to_string(),
args: serde_json::json!({"q": "X"}),
captured_at: 0,
}],
generation_params: None,
pending_image: None,
tool_rounds: vec![], // Empty — forces legacy fallback
};
let bytes = serde_json::to_vec(&body).unwrap();
let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap();
let result: Value = serde_json::from_slice(&modified).unwrap();
let contents = result["request"]["contents"].as_array().unwrap();
// Should still work: model turn rewritten + functionResponse injected
assert_eq!(contents.len(), 4, "should have 4 turns (3 original + 1 injected)");
assert_eq!(
contents[1]["parts"][0]["functionCall"]["name"].as_str().unwrap(),
"search"
);
assert_eq!(
contents[2]["parts"][0]["functionResponse"]["name"].as_str().unwrap(),
"search"
);
}
#[test]
fn test_no_tool_rounds_no_rewrite() {
// No tool rounds, no legacy data — no rewriting should happen
let body = serde_json::json!({
"project": "test",
"requestId": "test/1",
"request": {
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]},
{"role": "model", "parts": [{"text": "Hi there!"}]},
],
"tools": [],
"generationConfig": {}
},
"model": "test"
});
let tool_ctx = ToolContext {
tools: Some(vec![serde_json::json!({
"functionDeclarations": [{
"name": "noop",
"description": "Does nothing",
"parameters": {"type": "OBJECT", "properties": {}}
}]
})]),
tool_config: None,
pending_results: vec![],
last_calls: vec![],
generation_params: None,
pending_image: None,
tool_rounds: vec![],
};
let bytes = serde_json::to_vec(&body).unwrap();
let modified = modify_request(&bytes, Some(&tool_ctx)).unwrap();
let result: Value = serde_json::from_slice(&modified).unwrap();
let contents = result["request"]["contents"].as_array().unwrap();
// No rewriting — same number of turns
assert_eq!(contents.len(), 2);
assert_eq!(contents[1]["parts"][0]["text"].as_str().unwrap(), "Hi there!");
}
} }
// ─── Response modification ────────────────────────────────────────────────── // ─── Response modification ──────────────────────────────────────────────────

View File

@@ -602,9 +602,11 @@ async fn handle_http_over_tls(
let last_calls = store.get_last_function_calls().await; let last_calls = store.get_last_function_calls().await;
let generation_params = store.get_generation_params().await; let generation_params = store.get_generation_params().await;
let pending_image = store.take_pending_image().await; let pending_image = store.take_pending_image().await;
let tool_rounds = store.take_tool_rounds().await;
let tool_ctx = if tools.is_some() let tool_ctx = if tools.is_some()
|| !pending_results.is_empty() || !pending_results.is_empty()
|| !tool_rounds.is_empty()
|| generation_params.is_some() || generation_params.is_some()
|| pending_image.is_some() || pending_image.is_some()
{ {
@@ -615,6 +617,7 @@ async fn handle_http_over_tls(
last_calls, last_calls,
generation_params, generation_params,
pending_image, pending_image,
tool_rounds,
}) })
} else { } else {
None None

View File

@@ -60,6 +60,18 @@ pub struct PendingToolResult {
pub result: serde_json::Value, pub result: serde_json::Value,
} }
/// A single round of tool calling: the model's function calls paired with
/// the client's execution results.
///
/// In multi-step tool use, each round has its own calls and results.
/// This preserves per-turn data so history rewriting can map each
/// "Tool call completed" model turn to the correct functionCall/functionResponse.
#[derive(Debug, Clone)]
pub struct ToolRound {
pub calls: Vec<CapturedFunctionCall>,
pub results: Vec<PendingToolResult>,
}
/// An upstream error captured from Google's API response. /// An upstream error captured from Google's API response.
/// Stored by the MITM proxy so API handlers can return it to the client /// Stored by the MITM proxy so API handlers can return it to the client
/// instead of hanging forever waiting for a response that won't come. /// instead of hanging forever waiting for a response that won't come.
@@ -152,6 +164,9 @@ pub struct MitmStore {
call_id_to_name: Arc<RwLock<HashMap<String, String>>>, call_id_to_name: Arc<RwLock<HashMap<String, String>>>,
/// Last captured function calls (for conversation history rewriting). /// Last captured function calls (for conversation history rewriting).
last_function_calls: Arc<RwLock<Vec<CapturedFunctionCall>>>, last_function_calls: Arc<RwLock<Vec<CapturedFunctionCall>>>,
/// Multi-round tool call history for correct per-turn history rewriting.
/// Set by completions/responses handler, consumed by modify_request.
tool_rounds: Arc<RwLock<Vec<ToolRound>>>,
// ── Cascade correlation ────────────────────────────────────────────── // ── Cascade correlation ──────────────────────────────────────────────
/// Active cascade ID set by the API layer before sending a message. /// Active cascade ID set by the API layer before sending a message.
@@ -223,6 +238,7 @@ impl MitmStore {
pending_tool_results: Arc::new(RwLock::new(Vec::new())), pending_tool_results: Arc::new(RwLock::new(Vec::new())),
call_id_to_name: Arc::new(RwLock::new(HashMap::new())), call_id_to_name: Arc::new(RwLock::new(HashMap::new())),
last_function_calls: Arc::new(RwLock::new(Vec::new())), last_function_calls: Arc::new(RwLock::new(Vec::new())),
tool_rounds: Arc::new(RwLock::new(Vec::new())),
active_cascade_id: Arc::new(RwLock::new(None)), active_cascade_id: Arc::new(RwLock::new(None)),
captured_response_text: Arc::new(RwLock::new(None)), captured_response_text: Arc::new(RwLock::new(None)),
captured_thinking_text: Arc::new(RwLock::new(None)), captured_thinking_text: Arc::new(RwLock::new(None)),
@@ -511,6 +527,16 @@ impl MitmStore {
self.last_function_calls.read().await.clone() self.last_function_calls.read().await.clone()
} }
/// Store multi-round tool call history for correct per-turn history rewriting.
pub async fn set_tool_rounds(&self, rounds: Vec<ToolRound>) {
*self.tool_rounds.write().await = rounds;
}
/// Take (consume) multi-round tool call history.
pub async fn take_tool_rounds(&self) -> Vec<ToolRound> {
std::mem::take(&mut *self.tool_rounds.write().await)
}
// ── Direct response capture (bypass LS) ────────────────────────────── // ── Direct response capture (bypass LS) ──────────────────────────────
/// Set (replace) the captured response text. /// Set (replace) the captured response text.