Files
zerogravity/src/api/completions.rs
Nikketryhard ec1c0c700d fix: decouple function call detection from LS step polling
Move MitmStore function call check outside get_steps() block so tool
calls are detected immediately when captured by MITM, regardless of
LS processing state. Also reduce poll interval to 300ms.

The LS can take 20-30s for its internal multi-turn loop. Previously,
function call checks were nested inside the steps block and required
LS to have produced steps. Now the MITM capture is picked up within
300ms of detection.
2026-02-15 00:48:14 -06:00

506 lines
19 KiB
Rust

//! OpenAI Chat Completions API (/v1/chat/completions) handler.
use axum::{
extract::State,
http::StatusCode,
response::{sse::Event, IntoResponse, Json, Sse},
};
use rand::Rng;
use std::sync::Arc;
use tracing::{debug, info, warn};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{extract_response_text, is_response_done, poll_for_response};
use super::types::*;
use super::util::{err_response, now_unix};
use super::AppState;
// ─── Input extraction ────────────────────────────────────────────────────────
/// Flatten a Chat Completions `messages` array into the prompt text sent to the backend.
///
/// If any message has the `tool` role, the whole conversation (including tool
/// calls and their results) is rendered via `build_conversation_with_tools`.
/// Otherwise the prompt is every non-empty `system`/`developer` text joined by
/// newlines, a blank line, and then the last non-empty `user` text. The result
/// is trimmed; it is empty when nothing usable was found.
fn extract_chat_input(messages: &[CompletionMessage]) -> String {
    if messages.iter().any(|m| m.role == "tool") {
        // Tool results present: the model needs the full exchange to continue.
        return build_conversation_with_tools(messages);
    }

    // Plain path: collect instructions and remember only the most recent user turn.
    let mut instructions: Vec<String> = Vec::new();
    let mut latest_user: Option<String> = None;
    for message in messages {
        let text = extract_message_text(&message.content);
        if text.is_empty() {
            continue;
        }
        match message.role.as_str() {
            "system" | "developer" => instructions.push(text),
            "user" => latest_user = Some(text),
            _ => {}
        }
    }

    let mut prompt = String::new();
    if !instructions.is_empty() {
        prompt.push_str(&instructions.join("\n"));
        prompt.push_str("\n\n");
    }
    if let Some(user_text) = latest_user {
        prompt.push_str(&user_text);
    }
    prompt.trim().to_string()
}
/// Extract text content from a message's content field (string or array).
fn extract_message_text(content: &serde_json::Value) -> String {
match content {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(arr) => arr
.iter()
.filter_map(|item| item["text"].as_str())
.collect::<Vec<_>>()
.join("\n"),
_ => String::new(),
}
}
/// Render a conversation that contains tool traffic as a single text prompt.
///
/// Messages are emitted in order, separated by blank lines:
/// - `system` / `developer` / `user`: the message text, if non-empty
/// - `assistant`: the message text (if non-empty), then one
///   `[Tool call: name(args)]` line per entry in `tool_calls` that has a
///   `function` object
/// - `tool`: `[Tool result (id)]:` followed by the result text, if non-empty
///
/// Messages with any other role are skipped.
fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String {
    let mut segments: Vec<String> = Vec::new();

    for message in messages {
        let role = message.role.as_str();
        if !matches!(role, "system" | "developer" | "user" | "assistant" | "tool") {
            continue;
        }

        let body = extract_message_text(&message.content);

        if role == "tool" {
            // Tool results are labelled with the id of the call they answer.
            if !body.is_empty() {
                let call_id = message.tool_call_id.as_deref().unwrap_or("unknown");
                segments.push(format!(
                    "[Tool result ({})]:\n{}",
                    call_id, body
                ));
            }
            continue;
        }

        if !body.is_empty() {
            segments.push(body);
        }

        if role == "assistant" {
            // Replay the assistant's earlier tool invocations as context.
            if let Some(calls) = message.tool_calls.as_ref() {
                for function in calls.iter().filter_map(|call| call.get("function")) {
                    let name = function["name"].as_str().unwrap_or("unknown");
                    let args = function["arguments"].as_str().unwrap_or("{}");
                    segments.push(format!(
                        "[Tool call: {}({})]",
                        name, args
                    ));
                }
            }
        }
    }

    segments.join("\n\n")
}
// ─── Handler ─────────────────────────────────────────────────────────────────
/// POST /v1/chat/completions — OpenAI Chat Completions compatibility handler.
///
/// Steps, in order:
/// 1. Resolve the requested model (falling back to `DEFAULT_MODEL`); unknown
///    names are rejected with 400.
/// 2. Store the request's tools (and tool choice, if given) in the MITM store,
///    or clear any previously stored tools when none are usable, then clear
///    the active function call.
/// 3. Require a non-empty OAuth token (401 otherwise) and a non-empty prompt
///    (400 otherwise).
/// 4. Create a fresh cascade and send the prompt; any failure is reported as 502.
/// 5. Hand off to the streaming or sync responder depending on `body.stream`.
pub(crate) async fn handle_completions(
    State(state): State<Arc<AppState>>,
    Json(body): Json<CompletionRequest>,
) -> axum::response::Response {
    let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
    info!(
        "POST /v1/chat/completions model={} stream={}",
        model_name, body.stream
    );

    let model = if let Some(found) = lookup_model(model_name) {
        found
    } else {
        let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
        return err_response(
            StatusCode::BAD_REQUEST,
            format!("Unknown model: {model_name}. Available: {names:?}"),
            "invalid_request_error",
        );
    };

    // Store client tools from this request, or clear stale ones left by
    // earlier requests when this one carries no usable tools.
    let converted = body
        .tools
        .as_ref()
        .map(|tools| (tools.len(), crate::mitm::modify::openai_tools_to_gemini(tools)));
    match converted {
        Some((count, gemini_tools)) if !gemini_tools.is_empty() => {
            state.mitm_store.set_tools(gemini_tools).await;
            if let Some(choice) = body.tool_choice.as_ref() {
                let gemini_config = crate::mitm::modify::openai_tool_choice_to_gemini(choice);
                state.mitm_store.set_tool_config(gemini_config).await;
            }
            info!(count = count, "Completions: stored client tools for MITM injection");
        }
        _ => {
            state.mitm_store.clear_tools().await;
        }
    }
    state.mitm_store.clear_active_function_call();

    if state.backend.oauth_token().await.is_empty() {
        return err_response(
            StatusCode::UNAUTHORIZED,
            "No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
            "authentication_error",
        );
    }

    let user_text = extract_chat_input(&body.messages);
    if user_text.is_empty() {
        return err_response(
            StatusCode::BAD_REQUEST,
            "No user message found".to_string(),
            "invalid_request_error",
        );
    }

    // Every request gets its own cascade.
    let cascade_id = match state.backend.create_cascade().await {
        Err(e) => {
            return err_response(
                StatusCode::BAD_GATEWAY,
                format!("StartCascade failed: {e}"),
                "server_error",
            );
        }
        Ok(cid) => cid,
    };

    // Send the prompt; only a 200 from the backend lets the request proceed.
    let sent = state
        .backend
        .send_message(&cascade_id, &user_text, model.model_enum)
        .await;
    match sent {
        Err(e) => {
            return err_response(
                StatusCode::BAD_GATEWAY,
                format!("Send failed: {e}"),
                "server_error",
            );
        }
        Ok((status, _)) if status != 200 => {
            return err_response(
                StatusCode::BAD_GATEWAY,
                format!("Backend returned {status}"),
                "server_error",
            );
        }
        Ok(_) => {
            // Annotation update runs detached; its result is intentionally ignored.
            let bg = Arc::clone(&state.backend);
            let cid = cascade_id.clone();
            tokio::spawn(async move {
                let _ = bg.update_annotations(&cid).await;
            });
        }
    }

    let completion_id = format!(
        "chatcmpl-{}",
        uuid::Uuid::new_v4().to_string().replace('-', "")
    );
    let model_label = model_name.to_string();

    if body.stream {
        chat_completions_stream(state, completion_id, model_label, cascade_id, body.timeout).await
    } else {
        chat_completions_sync(state, completion_id, model_label, cascade_id, body.timeout).await
    }
}
// ─── Streaming ───────────────────────────────────────────────────────────────
/// Streaming output in Chat Completions format.
async fn chat_completions_stream(
state: Arc<AppState>,
completion_id: String,
model_name: String,
cascade_id: String,
timeout: u64,
) -> axum::response::Response {
let stream = async_stream::stream! {
let start = std::time::Instant::now();
let mut last_text = String::new();
// Initial role chunk
yield Ok::<_, std::convert::Infallible>(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"role": "assistant", "content": ""},
"finish_reason": serde_json::Value::Null,
}],
})).unwrap_or_default()));
let mut keepalive_counter: u64 = 0;
while start.elapsed().as_secs() < timeout {
// ── Check for MITM-captured function calls FIRST ──
// This runs independently of LS steps — the MITM captures tool calls
// at the proxy layer, so we don't need to wait for LS processing.
let captured = state.mitm_store.take_any_function_calls().await;
if let Some(ref calls) = captured {
if !calls.is_empty() {
let mut tool_calls = Vec::new();
for (i, fc) in calls.iter().enumerate() {
let call_id = format!(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
tool_calls.push(serde_json::json!({
"index": i,
"id": call_id,
"type": "function",
"function": {
"name": fc.name,
"arguments": arguments,
},
}));
}
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"tool_calls": tool_calls},
"finish_reason": serde_json::Value::Null,
}],
})).unwrap_or_default()));
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "tool_calls",
}],
})).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]"));
return;
}
}
// ── Check LS steps for text streaming ──
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
if status == 200 {
if let Some(steps) = data["steps"].as_array() {
let text = extract_response_text(steps);
if !text.is_empty() && text != last_text {
let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) {
&text[last_text.len()..]
} else {
&text
};
if !delta.is_empty() {
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": delta},
"finish_reason": serde_json::Value::Null,
}],
})).unwrap_or_default()));
last_text = text.to_string();
}
}
// Done check: need DONE status AND non-empty text
if is_response_done(steps) && !last_text.is_empty() {
debug!("Completions stream done, text length={}", last_text.len());
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop",
}],
})).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]"));
return;
}
// IDLE fallback
let step_count = steps.len();
if step_count > 4 && step_count % 5 == 0 {
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
if ts == 200 {
let run_status = td["status"].as_str().unwrap_or("");
if run_status.contains("IDLE") && !last_text.is_empty() {
debug!("Completions IDLE, text length={}", last_text.len());
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop",
}],
})).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]"));
return;
}
}
}
}
}
}
}
// Keep-alive comment every ~5 iterations
keepalive_counter += 1;
if keepalive_counter % 5 == 0 {
yield Ok(Event::default().comment("keepalive"));
}
// Fast poll — 300ms so we pick up MITM captures quickly
let poll_ms: u64 = rand::thread_rng().gen_range(250..400);
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
}
// Timeout
warn!("Completions stream timeout after {}s", timeout);
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": if last_text.is_empty() { "[Timeout waiting for response]" } else { "" }},
"finish_reason": "stop",
}],
})).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]"));
};
Sse::new(stream)
.keep_alive(
axum::response::sse::KeepAlive::new()
.interval(std::time::Duration::from_secs(15))
.text(""),
)
.into_response()
}
// ─── Sync ────────────────────────────────────────────────────────────────────
/// Non-streaming output in Chat Completions format.
///
/// Waits for `poll_for_response` to return, then builds a single
/// `chat.completion` JSON body. Token counts come from the MITM store (keyed by
/// the cascade id, then by `"_latest"`); if neither is available they fall back
/// to the usage reported with the polled result, and finally to zeros.
async fn chat_completions_sync(
    state: Arc<AppState>,
    completion_id: String,
    model_name: String,
    cascade_id: String,
    timeout: u64,
) -> axum::response::Response {
    let result = poll_for_response(&state, &cascade_id, timeout).await;

    // Prefer usage intercepted for this cascade; otherwise try the "_latest" slot.
    let mut intercepted = state.mitm_store.take_usage(&cascade_id).await;
    if intercepted.is_none() {
        intercepted = state.mitm_store.take_usage("_latest").await;
    }

    let (prompt_tokens, completion_tokens, cached_tokens) = match (intercepted, &result.usage) {
        (Some(captured), _) => (
            captured.input_tokens,
            captured.output_tokens,
            captured.cache_read_input_tokens,
        ),
        (None, Some(polled)) => (polled.input_tokens, polled.output_tokens, 0),
        (None, None) => (0, 0, 0),
    };

    let payload = serde_json::json!({
        "id": completion_id,
        "object": "chat.completion",
        "created": now_unix(),
        "model": model_name,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": result.text,
            },
            "finish_reason": "stop",
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
            "prompt_tokens_details": {
                "cached_tokens": cached_tokens,
            },
        },
    });
    Json(payload).into_response()
}