fix: tool call race conditions and missing completions tool result extraction
- store.rs: record_function_call now falls back to active_cascade_id (matching record_usage behavior) instead of blind _latest fallback - store.rs: add cascade-aware take_function_calls(cascade_id) method with priority: exact match → active cascade → _latest → any key - completions.rs: extract tool_calls from assistant messages and tool results from tool messages, storing them for MITM injection. This was the ROOT CAUSE — the completions handler stored tool definitions but never extracted tool results, so modify_request couldn't rewrite the LS conversation history with proper functionCall/functionResponse - responses.rs: use cascade-aware take_function_calls for consistency
This commit is contained in:
@@ -16,6 +16,7 @@ use super::polling::{
|
||||
use super::types::*;
|
||||
use super::util::{err_response, now_unix, upstream_err_response};
|
||||
use super::AppState;
|
||||
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult};
|
||||
|
||||
/// Extract a conversation/session ID from a flexible JSON value.
|
||||
/// Accepts a plain string or an object with an "id" field.
|
||||
@@ -224,6 +225,100 @@ pub(crate) async fn handle_completions(
|
||||
}
|
||||
state.mitm_store.clear_active_function_call();
|
||||
|
||||
// ── Extract tool results from messages for MITM injection ──────────
|
||||
// When OpenCode sends back tool results, the messages array contains:
|
||||
// 1. assistant message with tool_calls (the model's previous function calls)
|
||||
// 2. tool messages with results (the executed tool outputs)
|
||||
// We need to store these so modify_request can rewrite the LS's
|
||||
// conversation history with proper functionCall/functionResponse parts
|
||||
// instead of the placeholder "Tool call completed" text.
|
||||
{
|
||||
let mut last_calls: Vec<CapturedFunctionCall> = Vec::new();
|
||||
let mut pending_results: Vec<PendingToolResult> = Vec::new();
|
||||
|
||||
for msg in &body.messages {
|
||||
match msg.role.as_str() {
|
||||
"assistant" => {
|
||||
// Extract function calls from assistant's tool_calls
|
||||
if let Some(ref tool_calls) = msg.tool_calls {
|
||||
for tc in tool_calls {
|
||||
if let Some(func) = tc.get("function") {
|
||||
let name = func["name"].as_str().unwrap_or("unknown").to_string();
|
||||
let args_str = func["arguments"].as_str().unwrap_or("{}");
|
||||
let args = serde_json::from_str::<serde_json::Value>(args_str)
|
||||
.unwrap_or(serde_json::json!({}));
|
||||
let call_id = tc["id"].as_str().unwrap_or("").to_string();
|
||||
|
||||
// Register call_id → name for lookup
|
||||
if !call_id.is_empty() {
|
||||
state.mitm_store.register_call_id(call_id, name.clone()).await;
|
||||
}
|
||||
|
||||
last_calls.push(CapturedFunctionCall {
|
||||
name,
|
||||
args,
|
||||
captured_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"tool" => {
|
||||
// Extract tool results
|
||||
let text = extract_message_text(&msg.content);
|
||||
if let Some(ref call_id) = msg.tool_call_id {
|
||||
// Look up function name from call_id
|
||||
let name = state
|
||||
.mitm_store
|
||||
.lookup_call_id(call_id)
|
||||
.await
|
||||
.unwrap_or_else(|| {
|
||||
// Fallback: try to find the name from last_calls by position
|
||||
last_calls
|
||||
.first()
|
||||
.map(|fc| fc.name.clone())
|
||||
.unwrap_or_else(|| "unknown_function".to_string())
|
||||
});
|
||||
|
||||
let result_value = serde_json::from_str::<serde_json::Value>(&text)
|
||||
.unwrap_or_else(|_| serde_json::json!({"result": text}));
|
||||
|
||||
pending_results.push(PendingToolResult {
|
||||
name,
|
||||
result: result_value,
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if !last_calls.is_empty() {
|
||||
info!(
|
||||
count = last_calls.len(),
|
||||
tools = ?last_calls.iter().map(|c| &c.name).collect::<Vec<_>>(),
|
||||
"Completions: stored last function calls for MITM history rewrite"
|
||||
);
|
||||
state.mitm_store.set_last_function_calls(last_calls).await;
|
||||
}
|
||||
|
||||
if !pending_results.is_empty() {
|
||||
info!(
|
||||
count = pending_results.len(),
|
||||
tools = ?pending_results.iter().map(|r| &r.name).collect::<Vec<_>>(),
|
||||
"Completions: stored tool results for MITM injection"
|
||||
);
|
||||
for result in pending_results {
|
||||
state.mitm_store.add_tool_result(result).await;
|
||||
}
|
||||
// Clear awaiting flag — we have the results now
|
||||
state.mitm_store.clear_awaiting_tool_result();
|
||||
}
|
||||
}
|
||||
|
||||
// Store generation parameters for MITM injection
|
||||
{
|
||||
use crate::mitm::store::GenerationParams;
|
||||
@@ -584,7 +679,7 @@ async fn chat_completions_stream(
|
||||
// ── Check for MITM-captured function calls FIRST ──
|
||||
// This runs independently of LS steps — the MITM captures tool calls
|
||||
// at the proxy layer, so we don't need to wait for LS processing.
|
||||
let captured = state.mitm_store.take_any_function_calls().await;
|
||||
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
|
||||
if let Some(ref calls) = captured {
|
||||
if !calls.is_empty() {
|
||||
let mut tool_calls = Vec::new();
|
||||
|
||||
@@ -599,7 +599,7 @@ async fn handle_responses_sync(
|
||||
let start = std::time::Instant::now();
|
||||
while start.elapsed().as_secs() < timeout {
|
||||
// Check for function calls
|
||||
let captured = state.mitm_store.take_any_function_calls().await;
|
||||
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
|
||||
if let Some(ref raw_calls) = captured {
|
||||
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
|
||||
raw_calls.iter().take(max as usize).collect()
|
||||
@@ -715,7 +715,7 @@ async fn handle_responses_sync(
|
||||
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||
|
||||
// Check for captured function calls from MITM (clears the active flag)
|
||||
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
|
||||
let captured_tool_calls = state.mitm_store.take_function_calls(&cascade_id).await;
|
||||
|
||||
// Enforce max_tool_calls limit
|
||||
let captured_tool_calls = captured_tool_calls.map(|mut calls| {
|
||||
@@ -909,7 +909,7 @@ async fn handle_responses_stream(
|
||||
}
|
||||
|
||||
// Check for function calls first
|
||||
let captured = state.mitm_store.take_any_function_calls().await;
|
||||
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
|
||||
if let Some(ref raw_calls) = captured {
|
||||
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
|
||||
raw_calls.iter().take(max as usize).collect()
|
||||
|
||||
@@ -345,10 +345,18 @@ impl MitmStore {
|
||||
}
|
||||
|
||||
/// Record a captured function call from Google's response.
|
||||
///
|
||||
/// Falls back to `active_cascade_id` (set by the API handler) when no
|
||||
/// cascade hint is available from the request body, matching
|
||||
/// `record_usage`'s fallback behavior for consistent correlation.
|
||||
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
|
||||
let key = cascade_id
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| "_latest".to_string());
|
||||
let key = if let Some(cid) = cascade_id {
|
||||
cid.to_string()
|
||||
} else if let Some(active) = self.active_cascade_id.read().await.as_ref() {
|
||||
active.clone()
|
||||
} else {
|
||||
"_latest".to_string()
|
||||
};
|
||||
info!(
|
||||
cascade = %key,
|
||||
tool = %fc.name,
|
||||
@@ -383,7 +391,50 @@ impl MitmStore {
|
||||
self.awaiting_tool_result.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Take pending function calls for a specific cascade.
|
||||
///
|
||||
/// Priority: exact cascade_id → active_cascade_id → `_latest` → any key.
|
||||
/// This prevents cross-cascade contamination when multiple requests are
|
||||
/// in-flight simultaneously.
|
||||
pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> {
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
|
||||
// 1. Exact cascade match
|
||||
if let Some(result) = pending.remove(cascade_id) {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
return Some(result);
|
||||
}
|
||||
|
||||
// 2. Active cascade (set by API handler)
|
||||
if let Some(active) = self.active_cascade_id.read().await.as_ref() {
|
||||
if active != cascade_id {
|
||||
if let Some(result) = pending.remove(active.as_str()) {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
return Some(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Fallback to _latest
|
||||
if let Some(result) = pending.remove("_latest") {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
return Some(result);
|
||||
}
|
||||
|
||||
// 4. Last resort: any key
|
||||
if let Some(key) = pending.keys().next().cloned() {
|
||||
let result = pending.remove(&key);
|
||||
if result.is_some() {
|
||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Take any pending function calls (ignoring cascade ID).
|
||||
/// Legacy method — prefer `take_function_calls(cascade_id)` for proper correlation.
|
||||
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
let result = pending.remove("_latest");
|
||||
|
||||
Reference in New Issue
Block a user