diff --git a/src/api/completions.rs b/src/api/completions.rs index be233f4..27f4b8d 100644 --- a/src/api/completions.rs +++ b/src/api/completions.rs @@ -16,6 +16,7 @@ use super::polling::{ use super::types::*; use super::util::{err_response, now_unix, upstream_err_response}; use super::AppState; +use crate::mitm::store::{CapturedFunctionCall, PendingToolResult}; /// Extract a conversation/session ID from a flexible JSON value. /// Accepts a plain string or an object with an "id" field. @@ -224,6 +225,100 @@ pub(crate) async fn handle_completions( } state.mitm_store.clear_active_function_call(); + // ── Extract tool results from messages for MITM injection ────────── + // When OpenCode sends back tool results, the messages array contains: + // 1. assistant message with tool_calls (the model's previous function calls) + // 2. tool messages with results (the executed tool outputs) + // We need to store these so modify_request can rewrite the LS's + // conversation history with proper functionCall/functionResponse parts + // instead of the placeholder "Tool call completed" text. + { + let mut last_calls: Vec = Vec::new(); + let mut pending_results: Vec = Vec::new(); + + for msg in &body.messages { + match msg.role.as_str() { + "assistant" => { + // Extract function calls from assistant's tool_calls + if let Some(ref tool_calls) = msg.tool_calls { + for tc in tool_calls { + if let Some(func) = tc.get("function") { + let name = func["name"].as_str().unwrap_or("unknown").to_string(); + let args_str = func["arguments"].as_str().unwrap_or("{}"); + let args = serde_json::from_str::(args_str) + .unwrap_or(serde_json::json!({})); + let call_id = tc["id"].as_str().unwrap_or("").to_string(); + + // Register call_id → name for lookup + if !call_id.is_empty() { + state.mitm_store.register_call_id(call_id, name.clone()).await; + } + + last_calls.push(CapturedFunctionCall { + name, + args, + captured_at: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }); + } + } + } + } + "tool" => { + // Extract tool results + let text = extract_message_text(&msg.content); + if let Some(ref call_id) = msg.tool_call_id { + // Look up function name from call_id + let name = state + .mitm_store + .lookup_call_id(call_id) + .await + .unwrap_or_else(|| { + // Fallback: try to find the name from last_calls by position + last_calls + .first() + .map(|fc| fc.name.clone()) + .unwrap_or_else(|| "unknown_function".to_string()) + }); + + let result_value = serde_json::from_str::(&text) + .unwrap_or_else(|_| serde_json::json!({"result": text})); + + pending_results.push(PendingToolResult { + name, + result: result_value, + }); + } + } + _ => {} + } + } + + if !last_calls.is_empty() { + info!( + count = last_calls.len(), + tools = ?last_calls.iter().map(|c| &c.name).collect::>(), + "Completions: stored last function calls for MITM history rewrite" + ); + state.mitm_store.set_last_function_calls(last_calls).await; + } + + if !pending_results.is_empty() { + info!( + count = pending_results.len(), + tools = ?pending_results.iter().map(|r| &r.name).collect::>(), + "Completions: stored tool results for MITM injection" + ); + for result in pending_results { + state.mitm_store.add_tool_result(result).await; + } + // Clear awaiting flag — we have the results now + state.mitm_store.clear_awaiting_tool_result(); + } + } + // Store generation parameters for MITM injection { use crate::mitm::store::GenerationParams; @@ -584,7 +679,7 @@ async fn chat_completions_stream( // ── Check for MITM-captured function calls FIRST ── // This runs independently of LS steps — the MITM captures tool calls // at the proxy layer, so we don't need to wait for LS processing. - let captured = state.mitm_store.take_any_function_calls().await; + let captured = state.mitm_store.take_function_calls(&cascade_id).await; if let Some(ref calls) = captured { if !calls.is_empty() { let mut tool_calls = Vec::new(); diff --git a/src/api/responses.rs b/src/api/responses.rs index 38bf973..ac62828 100644 --- a/src/api/responses.rs +++ b/src/api/responses.rs @@ -599,7 +599,7 @@ async fn handle_responses_sync( let start = std::time::Instant::now(); while start.elapsed().as_secs() < timeout { // Check for function calls - let captured = state.mitm_store.take_any_function_calls().await; + let captured = state.mitm_store.take_function_calls(&cascade_id).await; if let Some(ref raw_calls) = captured { let calls: Vec<_> = if let Some(max) = params.max_tool_calls { raw_calls.iter().take(max as usize).collect() @@ -715,7 +715,7 @@ async fn handle_responses_sync( let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); // Check for captured function calls from MITM (clears the active flag) - let captured_tool_calls = state.mitm_store.take_any_function_calls().await; + let captured_tool_calls = state.mitm_store.take_function_calls(&cascade_id).await; // Enforce max_tool_calls limit let captured_tool_calls = captured_tool_calls.map(|mut calls| { @@ -909,7 +909,7 @@ async fn handle_responses_stream( } // Check for function calls first - let captured = state.mitm_store.take_any_function_calls().await; + let captured = state.mitm_store.take_function_calls(&cascade_id).await; if let Some(ref raw_calls) = captured { let calls: Vec<_> = if let Some(max) = params.max_tool_calls { raw_calls.iter().take(max as usize).collect() diff --git a/src/mitm/store.rs b/src/mitm/store.rs index 5cd6b74..d28d8c9 100644 --- a/src/mitm/store.rs +++ b/src/mitm/store.rs @@ -345,10 +345,18 @@ impl MitmStore { } /// Record a captured function call from Google's response. + /// + /// Falls back to `active_cascade_id` (set by the API handler) when no + /// cascade hint is available from the request body, matching + /// `record_usage`'s fallback behavior for consistent correlation. pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) { - let key = cascade_id - .map(|s| s.to_string()) - .unwrap_or_else(|| "_latest".to_string()); + let key = if let Some(cid) = cascade_id { + cid.to_string() + } else if let Some(active) = self.active_cascade_id.read().await.as_ref() { + active.clone() + } else { + "_latest".to_string() + }; info!( cascade = %key, tool = %fc.name, @@ -383,7 +391,50 @@ impl MitmStore { self.awaiting_tool_result.store(false, Ordering::SeqCst); } + /// Take pending function calls for a specific cascade. + /// + /// Priority: exact cascade_id → active_cascade_id → `_latest` → any key. + /// This prevents cross-cascade contamination when multiple requests are + /// in-flight simultaneously. + pub async fn take_function_calls(&self, cascade_id: &str) -> Option> { + let mut pending = self.pending_function_calls.write().await; + + // 1. Exact cascade match + if let Some(result) = pending.remove(cascade_id) { + self.has_active_function_call.store(false, Ordering::SeqCst); + return Some(result); + } + + // 2. Active cascade (set by API handler) + if let Some(active) = self.active_cascade_id.read().await.as_ref() { + if active != cascade_id { + if let Some(result) = pending.remove(active.as_str()) { + self.has_active_function_call.store(false, Ordering::SeqCst); + return Some(result); + } + } + } + + // 3. Fallback to _latest + if let Some(result) = pending.remove("_latest") { + self.has_active_function_call.store(false, Ordering::SeqCst); + return Some(result); + } + + // 4. Last resort: any key + if let Some(key) = pending.keys().next().cloned() { + let result = pending.remove(&key); + if result.is_some() { + self.has_active_function_call.store(false, Ordering::SeqCst); + } + return result; + } + + None + } + /// Take any pending function calls (ignoring cascade ID). + /// Legacy method — prefer `take_function_calls(cascade_id)` for proper correlation. pub async fn take_any_function_calls(&self) -> Option> { let mut pending = self.pending_function_calls.write().await; let result = pending.remove("_latest");