fix: tool call race conditions and missing completions tool result extraction

- store.rs: record_function_call now falls back to active_cascade_id
  (matching record_usage behavior) instead of blind _latest fallback
- store.rs: add cascade-aware take_function_calls(cascade_id) method
  with priority: exact match → active cascade → _latest → any key
- completions.rs: extract tool_calls from assistant messages and tool
  results from tool messages, storing them for MITM injection. This was
  the ROOT CAUSE — the completions handler stored tool definitions but
  never extracted tool results, so modify_request couldn't rewrite the
  LS conversation history with proper functionCall/functionResponse
- responses.rs: use cascade-aware take_function_calls for consistency
This commit is contained in:
Nikketryhard
2026-02-16 18:43:16 -06:00
parent 38b4130c55
commit 6bda2ecafa
3 changed files with 153 additions and 7 deletions

View File

@@ -16,6 +16,7 @@ use super::polling::{
use super::types::*; use super::types::*;
use super::util::{err_response, now_unix, upstream_err_response}; use super::util::{err_response, now_unix, upstream_err_response};
use super::AppState; use super::AppState;
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult};
/// Extract a conversation/session ID from a flexible JSON value. /// Extract a conversation/session ID from a flexible JSON value.
/// Accepts a plain string or an object with an "id" field. /// Accepts a plain string or an object with an "id" field.
@@ -224,6 +225,100 @@ pub(crate) async fn handle_completions(
} }
state.mitm_store.clear_active_function_call(); state.mitm_store.clear_active_function_call();
// ── Extract tool results from messages for MITM injection ──────────
// When OpenCode sends back tool results, the messages array contains:
// 1. assistant message with tool_calls (the model's previous function calls)
// 2. tool messages with results (the executed tool outputs)
// We need to store these so modify_request can rewrite the LS's
// conversation history with proper functionCall/functionResponse parts
// instead of the placeholder "Tool call completed" text.
{
let mut last_calls: Vec<CapturedFunctionCall> = Vec::new();
let mut pending_results: Vec<PendingToolResult> = Vec::new();
for msg in &body.messages {
match msg.role.as_str() {
"assistant" => {
// Extract function calls from assistant's tool_calls
if let Some(ref tool_calls) = msg.tool_calls {
for tc in tool_calls {
if let Some(func) = tc.get("function") {
let name = func["name"].as_str().unwrap_or("unknown").to_string();
let args_str = func["arguments"].as_str().unwrap_or("{}");
let args = serde_json::from_str::<serde_json::Value>(args_str)
.unwrap_or(serde_json::json!({}));
let call_id = tc["id"].as_str().unwrap_or("").to_string();
// Register call_id → name for lookup
if !call_id.is_empty() {
state.mitm_store.register_call_id(call_id, name.clone()).await;
}
last_calls.push(CapturedFunctionCall {
name,
args,
captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
});
}
}
}
}
"tool" => {
// Extract tool results
let text = extract_message_text(&msg.content);
if let Some(ref call_id) = msg.tool_call_id {
// Look up function name from call_id
let name = state
.mitm_store
.lookup_call_id(call_id)
.await
.unwrap_or_else(|| {
// Fallback: try to find the name from last_calls by position
last_calls
.first()
.map(|fc| fc.name.clone())
.unwrap_or_else(|| "unknown_function".to_string())
});
let result_value = serde_json::from_str::<serde_json::Value>(&text)
.unwrap_or_else(|_| serde_json::json!({"result": text}));
pending_results.push(PendingToolResult {
name,
result: result_value,
});
}
}
_ => {}
}
}
if !last_calls.is_empty() {
info!(
count = last_calls.len(),
tools = ?last_calls.iter().map(|c| &c.name).collect::<Vec<_>>(),
"Completions: stored last function calls for MITM history rewrite"
);
state.mitm_store.set_last_function_calls(last_calls).await;
}
if !pending_results.is_empty() {
info!(
count = pending_results.len(),
tools = ?pending_results.iter().map(|r| &r.name).collect::<Vec<_>>(),
"Completions: stored tool results for MITM injection"
);
for result in pending_results {
state.mitm_store.add_tool_result(result).await;
}
// Clear awaiting flag — we have the results now
state.mitm_store.clear_awaiting_tool_result();
}
}
// Store generation parameters for MITM injection // Store generation parameters for MITM injection
{ {
use crate::mitm::store::GenerationParams; use crate::mitm::store::GenerationParams;
@@ -584,7 +679,7 @@ async fn chat_completions_stream(
// ── Check for MITM-captured function calls FIRST ── // ── Check for MITM-captured function calls FIRST ──
// This runs independently of LS steps — the MITM captures tool calls // This runs independently of LS steps — the MITM captures tool calls
// at the proxy layer, so we don't need to wait for LS processing. // at the proxy layer, so we don't need to wait for LS processing.
let captured = state.mitm_store.take_any_function_calls().await; let captured = state.mitm_store.take_function_calls(&cascade_id).await;
if let Some(ref calls) = captured { if let Some(ref calls) = captured {
if !calls.is_empty() { if !calls.is_empty() {
let mut tool_calls = Vec::new(); let mut tool_calls = Vec::new();

View File

@@ -599,7 +599,7 @@ async fn handle_responses_sync(
let start = std::time::Instant::now(); let start = std::time::Instant::now();
while start.elapsed().as_secs() < timeout { while start.elapsed().as_secs() < timeout {
// Check for function calls // Check for function calls
let captured = state.mitm_store.take_any_function_calls().await; let captured = state.mitm_store.take_function_calls(&cascade_id).await;
if let Some(ref raw_calls) = captured { if let Some(ref raw_calls) = captured {
let calls: Vec<_> = if let Some(max) = params.max_tool_calls { let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
raw_calls.iter().take(max as usize).collect() raw_calls.iter().take(max as usize).collect()
@@ -715,7 +715,7 @@ async fn handle_responses_sync(
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
// Check for captured function calls from MITM (clears the active flag) // Check for captured function calls from MITM (clears the active flag)
let captured_tool_calls = state.mitm_store.take_any_function_calls().await; let captured_tool_calls = state.mitm_store.take_function_calls(&cascade_id).await;
// Enforce max_tool_calls limit // Enforce max_tool_calls limit
let captured_tool_calls = captured_tool_calls.map(|mut calls| { let captured_tool_calls = captured_tool_calls.map(|mut calls| {
@@ -909,7 +909,7 @@ async fn handle_responses_stream(
} }
// Check for function calls first // Check for function calls first
let captured = state.mitm_store.take_any_function_calls().await; let captured = state.mitm_store.take_function_calls(&cascade_id).await;
if let Some(ref raw_calls) = captured { if let Some(ref raw_calls) = captured {
let calls: Vec<_> = if let Some(max) = params.max_tool_calls { let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
raw_calls.iter().take(max as usize).collect() raw_calls.iter().take(max as usize).collect()

View File

@@ -345,10 +345,18 @@ impl MitmStore {
} }
/// Record a captured function call from Google's response. /// Record a captured function call from Google's response.
///
/// Falls back to `active_cascade_id` (set by the API handler) when no
/// cascade hint is available from the request body, matching
/// `record_usage`'s fallback behavior for consistent correlation.
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) { pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
let key = cascade_id let key = if let Some(cid) = cascade_id {
.map(|s| s.to_string()) cid.to_string()
.unwrap_or_else(|| "_latest".to_string()); } else if let Some(active) = self.active_cascade_id.read().await.as_ref() {
active.clone()
} else {
"_latest".to_string()
};
info!( info!(
cascade = %key, cascade = %key,
tool = %fc.name, tool = %fc.name,
@@ -383,7 +391,50 @@ impl MitmStore {
self.awaiting_tool_result.store(false, Ordering::SeqCst); self.awaiting_tool_result.store(false, Ordering::SeqCst);
} }
/// Take pending function calls for a specific cascade.
///
/// Priority: exact cascade_id → active_cascade_id → `_latest` → any key.
/// This prevents cross-cascade contamination when multiple requests are
/// in-flight simultaneously.
pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> {
let mut pending = self.pending_function_calls.write().await;
// 1. Exact cascade match
if let Some(result) = pending.remove(cascade_id) {
self.has_active_function_call.store(false, Ordering::SeqCst);
return Some(result);
}
// 2. Active cascade (set by API handler)
if let Some(active) = self.active_cascade_id.read().await.as_ref() {
if active != cascade_id {
if let Some(result) = pending.remove(active.as_str()) {
self.has_active_function_call.store(false, Ordering::SeqCst);
return Some(result);
}
}
}
// 3. Fallback to _latest
if let Some(result) = pending.remove("_latest") {
self.has_active_function_call.store(false, Ordering::SeqCst);
return Some(result);
}
// 4. Last resort: any key
if let Some(key) = pending.keys().next().cloned() {
let result = pending.remove(&key);
if result.is_some() {
self.has_active_function_call.store(false, Ordering::SeqCst);
}
return result;
}
None
}
/// Take any pending function calls (ignoring cascade ID). /// Take any pending function calls (ignoring cascade ID).
/// Legacy method — prefer `take_function_calls(cascade_id)` for proper correlation.
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> { pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
let mut pending = self.pending_function_calls.write().await; let mut pending = self.pending_function_calls.write().await;
let result = pending.remove("_latest"); let result = pending.remove("_latest");