fix: tool call race conditions and missing completions tool result extraction
- store.rs: record_function_call now falls back to active_cascade_id (matching record_usage behavior) instead of blind _latest fallback - store.rs: add cascade-aware take_function_calls(cascade_id) method with priority: exact match → active cascade → _latest → any key - completions.rs: extract tool_calls from assistant messages and tool results from tool messages, storing them for MITM injection. This was the ROOT CAUSE — the completions handler stored tool definitions but never extracted tool results, so modify_request couldn't rewrite the LS conversation history with proper functionCall/functionResponse - responses.rs: use cascade-aware take_function_calls for consistency
This commit is contained in:
@@ -16,6 +16,7 @@ use super::polling::{
|
|||||||
use super::types::*;
|
use super::types::*;
|
||||||
use super::util::{err_response, now_unix, upstream_err_response};
|
use super::util::{err_response, now_unix, upstream_err_response};
|
||||||
use super::AppState;
|
use super::AppState;
|
||||||
|
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult};
|
||||||
|
|
||||||
/// Extract a conversation/session ID from a flexible JSON value.
|
/// Extract a conversation/session ID from a flexible JSON value.
|
||||||
/// Accepts a plain string or an object with an "id" field.
|
/// Accepts a plain string or an object with an "id" field.
|
||||||
@@ -224,6 +225,100 @@ pub(crate) async fn handle_completions(
|
|||||||
}
|
}
|
||||||
state.mitm_store.clear_active_function_call();
|
state.mitm_store.clear_active_function_call();
|
||||||
|
|
||||||
|
// ── Extract tool results from messages for MITM injection ──────────
|
||||||
|
// When OpenCode sends back tool results, the messages array contains:
|
||||||
|
// 1. assistant message with tool_calls (the model's previous function calls)
|
||||||
|
// 2. tool messages with results (the executed tool outputs)
|
||||||
|
// We need to store these so modify_request can rewrite the LS's
|
||||||
|
// conversation history with proper functionCall/functionResponse parts
|
||||||
|
// instead of the placeholder "Tool call completed" text.
|
||||||
|
{
|
||||||
|
let mut last_calls: Vec<CapturedFunctionCall> = Vec::new();
|
||||||
|
let mut pending_results: Vec<PendingToolResult> = Vec::new();
|
||||||
|
|
||||||
|
for msg in &body.messages {
|
||||||
|
match msg.role.as_str() {
|
||||||
|
"assistant" => {
|
||||||
|
// Extract function calls from assistant's tool_calls
|
||||||
|
if let Some(ref tool_calls) = msg.tool_calls {
|
||||||
|
for tc in tool_calls {
|
||||||
|
if let Some(func) = tc.get("function") {
|
||||||
|
let name = func["name"].as_str().unwrap_or("unknown").to_string();
|
||||||
|
let args_str = func["arguments"].as_str().unwrap_or("{}");
|
||||||
|
let args = serde_json::from_str::<serde_json::Value>(args_str)
|
||||||
|
.unwrap_or(serde_json::json!({}));
|
||||||
|
let call_id = tc["id"].as_str().unwrap_or("").to_string();
|
||||||
|
|
||||||
|
// Register call_id → name for lookup
|
||||||
|
if !call_id.is_empty() {
|
||||||
|
state.mitm_store.register_call_id(call_id, name.clone()).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
last_calls.push(CapturedFunctionCall {
|
||||||
|
name,
|
||||||
|
args,
|
||||||
|
captured_at: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"tool" => {
|
||||||
|
// Extract tool results
|
||||||
|
let text = extract_message_text(&msg.content);
|
||||||
|
if let Some(ref call_id) = msg.tool_call_id {
|
||||||
|
// Look up function name from call_id
|
||||||
|
let name = state
|
||||||
|
.mitm_store
|
||||||
|
.lookup_call_id(call_id)
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
// Fallback: try to find the name from last_calls by position
|
||||||
|
last_calls
|
||||||
|
.first()
|
||||||
|
.map(|fc| fc.name.clone())
|
||||||
|
.unwrap_or_else(|| "unknown_function".to_string())
|
||||||
|
});
|
||||||
|
|
||||||
|
let result_value = serde_json::from_str::<serde_json::Value>(&text)
|
||||||
|
.unwrap_or_else(|_| serde_json::json!({"result": text}));
|
||||||
|
|
||||||
|
pending_results.push(PendingToolResult {
|
||||||
|
name,
|
||||||
|
result: result_value,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !last_calls.is_empty() {
|
||||||
|
info!(
|
||||||
|
count = last_calls.len(),
|
||||||
|
tools = ?last_calls.iter().map(|c| &c.name).collect::<Vec<_>>(),
|
||||||
|
"Completions: stored last function calls for MITM history rewrite"
|
||||||
|
);
|
||||||
|
state.mitm_store.set_last_function_calls(last_calls).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pending_results.is_empty() {
|
||||||
|
info!(
|
||||||
|
count = pending_results.len(),
|
||||||
|
tools = ?pending_results.iter().map(|r| &r.name).collect::<Vec<_>>(),
|
||||||
|
"Completions: stored tool results for MITM injection"
|
||||||
|
);
|
||||||
|
for result in pending_results {
|
||||||
|
state.mitm_store.add_tool_result(result).await;
|
||||||
|
}
|
||||||
|
// Clear awaiting flag — we have the results now
|
||||||
|
state.mitm_store.clear_awaiting_tool_result();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Store generation parameters for MITM injection
|
// Store generation parameters for MITM injection
|
||||||
{
|
{
|
||||||
use crate::mitm::store::GenerationParams;
|
use crate::mitm::store::GenerationParams;
|
||||||
@@ -584,7 +679,7 @@ async fn chat_completions_stream(
|
|||||||
// ── Check for MITM-captured function calls FIRST ──
|
// ── Check for MITM-captured function calls FIRST ──
|
||||||
// This runs independently of LS steps — the MITM captures tool calls
|
// This runs independently of LS steps — the MITM captures tool calls
|
||||||
// at the proxy layer, so we don't need to wait for LS processing.
|
// at the proxy layer, so we don't need to wait for LS processing.
|
||||||
let captured = state.mitm_store.take_any_function_calls().await;
|
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
|
||||||
if let Some(ref calls) = captured {
|
if let Some(ref calls) = captured {
|
||||||
if !calls.is_empty() {
|
if !calls.is_empty() {
|
||||||
let mut tool_calls = Vec::new();
|
let mut tool_calls = Vec::new();
|
||||||
|
|||||||
@@ -599,7 +599,7 @@ async fn handle_responses_sync(
|
|||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
while start.elapsed().as_secs() < timeout {
|
while start.elapsed().as_secs() < timeout {
|
||||||
// Check for function calls
|
// Check for function calls
|
||||||
let captured = state.mitm_store.take_any_function_calls().await;
|
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
|
||||||
if let Some(ref raw_calls) = captured {
|
if let Some(ref raw_calls) = captured {
|
||||||
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
|
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
|
||||||
raw_calls.iter().take(max as usize).collect()
|
raw_calls.iter().take(max as usize).collect()
|
||||||
@@ -715,7 +715,7 @@ async fn handle_responses_sync(
|
|||||||
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||||
|
|
||||||
// Check for captured function calls from MITM (clears the active flag)
|
// Check for captured function calls from MITM (clears the active flag)
|
||||||
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
|
let captured_tool_calls = state.mitm_store.take_function_calls(&cascade_id).await;
|
||||||
|
|
||||||
// Enforce max_tool_calls limit
|
// Enforce max_tool_calls limit
|
||||||
let captured_tool_calls = captured_tool_calls.map(|mut calls| {
|
let captured_tool_calls = captured_tool_calls.map(|mut calls| {
|
||||||
@@ -909,7 +909,7 @@ async fn handle_responses_stream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for function calls first
|
// Check for function calls first
|
||||||
let captured = state.mitm_store.take_any_function_calls().await;
|
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
|
||||||
if let Some(ref raw_calls) = captured {
|
if let Some(ref raw_calls) = captured {
|
||||||
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
|
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
|
||||||
raw_calls.iter().take(max as usize).collect()
|
raw_calls.iter().take(max as usize).collect()
|
||||||
|
|||||||
@@ -345,10 +345,18 @@ impl MitmStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Record a captured function call from Google's response.
|
/// Record a captured function call from Google's response.
|
||||||
|
///
|
||||||
|
/// Falls back to `active_cascade_id` (set by the API handler) when no
|
||||||
|
/// cascade hint is available from the request body, matching
|
||||||
|
/// `record_usage`'s fallback behavior for consistent correlation.
|
||||||
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
|
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
|
||||||
let key = cascade_id
|
let key = if let Some(cid) = cascade_id {
|
||||||
.map(|s| s.to_string())
|
cid.to_string()
|
||||||
.unwrap_or_else(|| "_latest".to_string());
|
} else if let Some(active) = self.active_cascade_id.read().await.as_ref() {
|
||||||
|
active.clone()
|
||||||
|
} else {
|
||||||
|
"_latest".to_string()
|
||||||
|
};
|
||||||
info!(
|
info!(
|
||||||
cascade = %key,
|
cascade = %key,
|
||||||
tool = %fc.name,
|
tool = %fc.name,
|
||||||
@@ -383,7 +391,50 @@ impl MitmStore {
|
|||||||
self.awaiting_tool_result.store(false, Ordering::SeqCst);
|
self.awaiting_tool_result.store(false, Ordering::SeqCst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Take pending function calls for a specific cascade.
|
||||||
|
///
|
||||||
|
/// Priority: exact cascade_id → active_cascade_id → `_latest` → any key.
|
||||||
|
/// This prevents cross-cascade contamination when multiple requests are
|
||||||
|
/// in-flight simultaneously.
|
||||||
|
pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> {
|
||||||
|
let mut pending = self.pending_function_calls.write().await;
|
||||||
|
|
||||||
|
// 1. Exact cascade match
|
||||||
|
if let Some(result) = pending.remove(cascade_id) {
|
||||||
|
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||||
|
return Some(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Active cascade (set by API handler)
|
||||||
|
if let Some(active) = self.active_cascade_id.read().await.as_ref() {
|
||||||
|
if active != cascade_id {
|
||||||
|
if let Some(result) = pending.remove(active.as_str()) {
|
||||||
|
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||||
|
return Some(result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Fallback to _latest
|
||||||
|
if let Some(result) = pending.remove("_latest") {
|
||||||
|
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||||
|
return Some(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Last resort: any key
|
||||||
|
if let Some(key) = pending.keys().next().cloned() {
|
||||||
|
let result = pending.remove(&key);
|
||||||
|
if result.is_some() {
|
||||||
|
self.has_active_function_call.store(false, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
/// Take any pending function calls (ignoring cascade ID).
|
/// Take any pending function calls (ignoring cascade ID).
|
||||||
|
/// Legacy method — prefer `take_function_calls(cascade_id)` for proper correlation.
|
||||||
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
|
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
|
||||||
let mut pending = self.pending_function_calls.write().await;
|
let mut pending = self.pending_function_calls.write().await;
|
||||||
let result = pending.remove("_latest");
|
let result = pending.remove("_latest");
|
||||||
|
|||||||
Reference in New Issue
Block a user