refactor: endpoint parity and proxy improvements

Mixed changes from recent sessions: endpoint feature parity
improvements, proxy bug fixes, and store cleanup.
This commit is contained in:
Nikketryhard
2026-02-16 21:47:00 -06:00
parent 86675fd960
commit 637fbc0e54
5 changed files with 763 additions and 692 deletions

View File

@@ -223,7 +223,6 @@ pub(crate) async fn handle_completions(
} else {
state.mitm_store.clear_tools().await;
}
state.mitm_store.clear_active_function_call();
// ── Extract tool results from messages for MITM injection ──────────
// When OpenCode sends back tool results, the messages array contains:
@@ -340,7 +339,6 @@ pub(crate) async fn handle_completions(
state.mitm_store.set_last_function_calls(last_round.calls.clone()).await;
}
state.mitm_store.set_tool_rounds(rounds).await;
state.mitm_store.clear_awaiting_tool_result();
}
}
@@ -441,6 +439,8 @@ pub(crate) async fn handle_completions(
// Send message on primary cascade
state.mitm_store.set_active_cascade(&cascade_id).await;
// Store real user text for MITM injection — LS gets a dummy prompt
state.mitm_store.set_pending_user_text(user_text.clone()).await;
// Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image {
use base64::Engine;
@@ -452,9 +452,25 @@ pub(crate) async fn handle_completions(
})
.await;
}
// Pre-flight: install channel BEFORE send_message so the MITM proxy
// can grab it when the LS fires its API call.
// Only for streaming — sync paths use poll_for_response (legacy store).
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
let mitm_rx = if has_custom_tools && body.stream {
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
let _ = state.mitm_store.take_any_function_calls().await;
let (tx, rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(tx).await;
Some(rx)
} else {
None
};
match state
.backend
.send_message_with_image(&cascade_id, &user_text, model.model_enum, image.as_ref())
.send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref())
.await
{
Ok((200, _)) => {
@@ -465,6 +481,7 @@ pub(crate) async fn handle_completions(
});
}
Ok((status, _)) => {
state.mitm_store.drop_channel().await;
return err_response(
StatusCode::BAD_GATEWAY,
format!("Backend returned {status}"),
@@ -472,6 +489,7 @@ pub(crate) async fn handle_completions(
);
}
Err(e) => {
state.mitm_store.drop_channel().await;
return err_response(
StatusCode::BAD_GATEWAY,
format!("Send failed: {e}"),
@@ -498,6 +516,7 @@ pub(crate) async fn handle_completions(
cascade_id,
body.timeout,
include_usage,
mitm_rx,
)
.await
} else if n <= 1 {
@@ -518,7 +537,7 @@ pub(crate) async fn handle_completions(
// Send the same message on each extra cascade
match state
.backend
.send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref())
.send_message_with_image(&cid, ".", model.model_enum, image.as_ref())
.await
{
Ok((200, _)) => {
@@ -635,18 +654,18 @@ async fn chat_completions_stream(
cascade_id: String,
timeout: u64,
include_usage: bool,
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
) -> axum::response::Response {
let stream = async_stream::stream! {
let start = std::time::Instant::now();
let mut last_text = String::new();
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
let has_custom_tools = mitm_rx.is_some();
// Clear ALL stale state from previous requests
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
state.mitm_store.clear_active_function_call();
// Drain any stale function calls from previous requests
let _ = state.mitm_store.take_any_function_calls().await;
if !has_custom_tools {
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
let _ = state.mitm_store.take_any_function_calls().await;
}
// Initial role chunk
yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json(
@@ -659,7 +678,6 @@ async fn chat_completions_stream(
let mut last_thinking_len: usize = 0;
let mut complete_polls: u32 = 0;
let mut did_unblock_ls = false; // Prevents infinite unblock loops
let mut my_generation = state.mitm_store.current_generation();
// Helper: build usage JSON from MITM tokens
let build_usage = |pt: u64, ct: u64, crt: u64, tt: u64| -> serde_json::Value {
@@ -672,212 +690,219 @@ async fn chat_completions_stream(
})
};
while start.elapsed().as_secs() < timeout {
// Check for upstream errors from MITM (Google API errors)
if let Some(err) = state.mitm_store.take_upstream_error().await {
let error_msg = super::util::upstream_error_message(&err);
let error_type = super::util::upstream_error_type(&err);
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"error": {
"message": error_msg,
"type": error_type,
"code": err.status,
}
})).unwrap()));
yield Ok(Event::default().data("[DONE]".to_string()));
break;
}
// Take ownership of the pre-installed channel receiver
let mut rx_opt = mitm_rx;
// Bail if another completions handler has superseded us
if state.mitm_store.current_generation() != my_generation {
debug!("Completions: generation changed (superseded), ending stream");
while start.elapsed().as_secs() < timeout {
if let Some(ref mut rx) = rx_opt {
// ── Channel-based MITM pipeline ──
// Track accumulated text for delta computation
let mut acc_text = String::new();
let mut acc_thinking = String::new();
let mut last_usage: Option<crate::mitm::store::ApiUsage> = None;
'channel_loop: while let Some(event) = tokio::time::timeout(
std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
rx.recv(),
).await.ok().flatten() {
use crate::mitm::store::MitmEvent;
match event {
MitmEvent::ThinkingDelta(full_thinking) => {
if full_thinking.len() > acc_thinking.len() {
let delta = full_thinking[acc_thinking.len()..].to_string();
acc_thinking = full_thinking;
last_thinking_len = acc_thinking.len();
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]),
None,
)));
}
}
MitmEvent::TextDelta(full_text) => {
if full_text.len() > acc_text.len() {
let delta = full_text[acc_text.len()..].to_string();
acc_text = full_text;
last_text = acc_text.clone();
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]),
None,
)));
}
}
MitmEvent::FunctionCall(calls) => {
let mut tool_calls = Vec::new();
for (i, fc) in calls.iter().enumerate() {
let call_id = format!(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
tool_calls.push(serde_json::json!({
"index": i,
"id": call_id,
"type": "function",
"function": {
"name": fc.name,
"arguments": arguments,
},
}));
}
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]),
None,
)));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]),
None,
)));
if include_usage {
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else if let Some(ref u) = last_usage {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
}
MitmEvent::ResponseComplete => {
if !acc_text.is_empty() {
// Have response text — done
debug!("Completions: channel response complete, text_len={}, thinking_len={}",
acc_text.len(), acc_thinking.len());
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await)
.or(last_usage.take());
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
} else if !acc_thinking.is_empty() && !did_unblock_ls {
// Thinking-only response — LS needs follow-up API calls.
// Create a new channel and unblock the gate.
did_unblock_ls = true;
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(new_tx).await;
state.mitm_store.clear_request_in_flight();
let _ = state.mitm_store.take_any_function_calls().await;
*rx = new_rx;
debug!(
"Completions: thinking-only — new channel for follow-up, thinking_len={}",
acc_thinking.len()
);
continue 'channel_loop;
} else if !acc_thinking.is_empty() && did_unblock_ls {
// Already unblocked once, still thinking-only.
// Wait a bit for potential follow-up events.
complete_polls += 1;
if complete_polls >= 25 {
info!("Completions: thinking-only timeout, thinking_len={}", acc_thinking.len());
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await)
.or(last_usage.take());
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
}
// Don't break — wait for more channel events
continue 'channel_loop;
} else {
// Empty response (no text, no thinking, no tools)
complete_polls += 1;
if complete_polls >= 4 {
info!("Completions: channel response complete but empty, ending stream");
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
None,
)));
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
}
continue 'channel_loop;
}
}
MitmEvent::UpstreamError(err) => {
let error_msg = super::util::upstream_error_message(&err);
let error_type = super::util::upstream_error_type(&err);
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"error": {
"message": error_msg,
"type": error_type,
"code": err.status,
}
})).unwrap()));
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
}
MitmEvent::Usage(u) => {
last_usage = Some(u);
}
MitmEvent::Grounding(_) => {
// Grounding metadata handled by store directly
}
}
}
// Channel closed or timeout — clean up
state.mitm_store.drop_channel().await;
// If we got here from timeout with content, emit what we have
if !last_text.is_empty() || last_thinking_len > 0 {
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
None,
)));
}
yield Ok(Event::default().data("[DONE]"));
return;
}
// ── Check for MITM-captured function calls FIRST ──
// This runs independently of LS steps — the MITM captures tool calls
// at the proxy layer, so we don't need to wait for LS processing.
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
if let Some(ref calls) = captured {
if !calls.is_empty() {
let mut tool_calls = Vec::new();
for (i, fc) in calls.iter().enumerate() {
let call_id = format!(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
tool_calls.push(serde_json::json!({
"index": i,
"id": call_id,
"type": "function",
"function": {
"name": fc.name,
"arguments": arguments,
},
}));
}
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]),
None,
)));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]),
None,
)));
if include_usage {
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
// Clear in-flight flag so the next turn's requests can get through
state.mitm_store.clear_response_async().await;
return;
}
}
// ── Primary: MITM-captured response (when custom tools are active) ──
// The MITM intercepts the real Google SSE stream and captures text,
// thinking, and function calls. This is the authoritative data source.
// The LS only gets rewritten responses (function calls → text placeholders)
// so it doesn't provide useful streaming data when MITM is active.
if has_custom_tools {
// Stream thinking text as reasoning_content deltas
if let Some(tc) = state.mitm_store.peek_thinking_text().await {
if tc.len() > last_thinking_len {
let delta = &tc[last_thinking_len..];
last_thinking_len = tc.len();
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]),
None,
)));
}
}
// Stream response text as content deltas
if let Some(text) = state.mitm_store.peek_response_text().await {
if !text.is_empty() && text != last_text {
let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) {
text[last_text.len()..].to_string()
} else {
text.clone()
};
if !delta.is_empty() {
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]),
None,
)));
last_text = text;
}
}
}
// Check if MITM response is complete
if state.mitm_store.is_response_complete() {
if !last_text.is_empty() {
// Have actual response text — done
complete_polls += 1;
if complete_polls >= 2 {
debug!("Completions: MITM response complete, text_len={}, thinking_len={}", last_text.len(), last_thinking_len);
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
return;
}
} else if last_thinking_len > 0 && !did_unblock_ls {
// Thinking-only response. The LS needs follow-up API calls
// to get actual function calls or text. Unblock once.
did_unblock_ls = true;
complete_polls = 0;
// Bump generation FIRST — invalidates old MITM connection's store writes
my_generation = state.mitm_store.bump_generation();
state.mitm_store.clear_request_in_flight();
state.mitm_store.clear_response_complete();
// Drain store so leaked connections can't produce stale content
state.mitm_store.set_response_text("").await;
state.mitm_store.set_thinking_text("").await;
let _ = state.mitm_store.take_any_function_calls().await;
debug!(
"Completions: thinking-only — unblocking LS for follow-up, thinking_len={}, new_gen={}",
last_thinking_len, my_generation
);
} else if last_thinking_len > 0 && did_unblock_ls {
// Already unblocked once. Still only thinking after follow-up.
complete_polls += 1;
if complete_polls >= 25 {
info!("Completions: thinking-only timeout after ~10s, thinking_len={}", last_thinking_len);
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
return;
}
} else {
// response_complete but no text AND no thinking — might be
// a function-call-only response that was already consumed,
// or empty response. Wait a bit then give up.
complete_polls += 1;
if complete_polls >= 4 {
info!("Completions: MITM response complete but no content (text/thinking/tools all empty), ending stream");
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
None,
)));
yield Ok(Event::default().data("[DONE]"));
return;
}
}
} else {
complete_polls = 0; // Reset — not complete yet
}
} else {
// ── Fallback: LS steps (no MITM capture active) ──
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
@@ -1001,7 +1026,7 @@ async fn chat_completions_stream(
}
})).unwrap()));
// Always clear in-flight flag when stream ends
state.mitm_store.clear_response_async().await;
state.mitm_store.drop_channel().await;
yield Ok(Event::default().data("[DONE]"));
};