fix: multi-round tool history rewrite and finishReason handling

- Add ToolRound struct to pair function calls with results per-round
- Replace single-match history rewrite (broke after first round) with
  multi-round loop that rewrites ALL placeholder model turns
- Fix tool result name fallback: use positional index instead of always
  picking the first call
- Set is_complete for any finishReason (FUNCTION_CALL, MAX_TOKENS, etc.)
  not just STOP — prevents response_complete flag from never being set
- Legacy fallback: responses.rs path (single-round via last_calls +
  pending_results) still works when tool_rounds is empty
- Add tests: multi-round rewrite, single-round legacy, no-op, and
  FUNCTION_CALL/MAX_TOKENS finishReason handling
This commit is contained in:
Nikketryhard
2026-02-16 19:05:37 -06:00
parent 6bda2ecafa
commit 39381a4dfe
5 changed files with 410 additions and 88 deletions

View File

@@ -60,6 +60,18 @@ pub struct PendingToolResult {
pub result: serde_json::Value,
}
/// A single round of tool calling: the model's function calls paired with
/// the client's execution results.
///
/// In multi-step tool use, each round has its own calls and results.
/// This preserves per-turn data so history rewriting can map each
/// "Tool call completed" model turn to the correct functionCall/functionResponse.
#[derive(Debug, Clone)]
pub struct ToolRound {
pub calls: Vec<CapturedFunctionCall>,
pub results: Vec<PendingToolResult>,
}
/// An upstream error captured from Google's API response.
/// Stored by the MITM proxy so API handlers can return it to the client
/// instead of hanging forever waiting for a response that won't come.
@@ -152,6 +164,9 @@ pub struct MitmStore {
call_id_to_name: Arc<RwLock<HashMap<String, String>>>,
/// Last captured function calls (for conversation history rewriting).
last_function_calls: Arc<RwLock<Vec<CapturedFunctionCall>>>,
/// Multi-round tool call history for correct per-turn history rewriting.
/// Set by completions/responses handler, consumed by modify_request.
tool_rounds: Arc<RwLock<Vec<ToolRound>>>,
// ── Cascade correlation ──────────────────────────────────────────────
/// Active cascade ID set by the API layer before sending a message.
@@ -223,6 +238,7 @@ impl MitmStore {
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
call_id_to_name: Arc::new(RwLock::new(HashMap::new())),
last_function_calls: Arc::new(RwLock::new(Vec::new())),
tool_rounds: Arc::new(RwLock::new(Vec::new())),
active_cascade_id: Arc::new(RwLock::new(None)),
captured_response_text: Arc::new(RwLock::new(None)),
captured_thinking_text: Arc::new(RwLock::new(None)),
@@ -511,6 +527,16 @@ impl MitmStore {
self.last_function_calls.read().await.clone()
}
/// Store multi-round tool call history for correct per-turn history rewriting.
pub async fn set_tool_rounds(&self, rounds: Vec<ToolRound>) {
*self.tool_rounds.write().await = rounds;
}
/// Take (consume) multi-round tool call history.
pub async fn take_tool_rounds(&self) -> Vec<ToolRound> {
std::mem::take(&mut *self.tool_rounds.write().await)
}
// ── Direct response capture (bypass LS) ──────────────────────────────
/// Set (replace) the captured response text.