feat: initial commit — antigravity proxy with MITM, standalone LS, and snapshot tooling
This commit is contained in:
343
src/api/completions.rs
Normal file
343
src/api/completions.rs
Normal file
@@ -0,0 +1,343 @@
|
||||
//! OpenAI Chat Completions API (/v1/chat/completions) handler.
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{sse::Event, IntoResponse, Json, Sse},
|
||||
};
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
||||
use super::polling::{extract_response_text, is_response_done, poll_for_response};
|
||||
use super::types::*;
|
||||
use super::util::{err_response, now_unix};
|
||||
use super::AppState;
|
||||
|
||||
// ─── Input extraction ────────────────────────────────────────────────────────
|
||||
|
||||
/// Extract user text from Chat Completions messages array.
|
||||
fn extract_chat_input(messages: &[CompletionMessage]) -> String {
|
||||
let mut system_parts = Vec::new();
|
||||
let mut user_parts = Vec::new();
|
||||
|
||||
for msg in messages {
|
||||
let text = match &msg.content {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
serde_json::Value::Array(arr) => arr
|
||||
.iter()
|
||||
.filter_map(|item| item["text"].as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n"),
|
||||
_ => continue,
|
||||
};
|
||||
match msg.role.as_str() {
|
||||
"system" | "developer" => system_parts.push(text),
|
||||
"user" => user_parts.push(text),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let mut result = String::new();
|
||||
if !system_parts.is_empty() {
|
||||
result.push_str(&system_parts.join("\n"));
|
||||
result.push_str("\n\n");
|
||||
}
|
||||
// Use the last user message
|
||||
if let Some(last) = user_parts.last() {
|
||||
result.push_str(last);
|
||||
}
|
||||
result.trim().to_string()
|
||||
}
|
||||
|
||||
// ─── Handler ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// POST /v1/chat/completions — OpenAI Chat Completions API compatibility shim.
|
||||
/// Accepts standard messages format, reuses the same backend cascade, and
|
||||
/// outputs in the Chat Completions streaming/sync format.
|
||||
pub(crate) async fn handle_completions(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(body): Json<CompletionRequest>,
|
||||
) -> axum::response::Response {
|
||||
let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
|
||||
info!(
|
||||
"POST /v1/chat/completions model={} stream={}",
|
||||
model_name, body.stream
|
||||
);
|
||||
|
||||
let model = match lookup_model(model_name) {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
|
||||
return err_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Unknown model: {model_name}. Available: {names:?}"),
|
||||
"invalid_request_error",
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let token = state.backend.oauth_token().await;
|
||||
if token.is_empty() {
|
||||
return err_response(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
|
||||
"authentication_error",
|
||||
);
|
||||
}
|
||||
|
||||
let user_text = extract_chat_input(&body.messages);
|
||||
if user_text.is_empty() {
|
||||
return err_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"No user message found".to_string(),
|
||||
"invalid_request_error",
|
||||
);
|
||||
}
|
||||
|
||||
// Fresh cascade per request
|
||||
let cascade_id = match state.backend.create_cascade().await {
|
||||
Ok(cid) => cid,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("StartCascade failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Send message
|
||||
match state
|
||||
.backend
|
||||
.send_message(&cascade_id, &user_text, model.model_enum)
|
||||
.await
|
||||
{
|
||||
Ok((200, _)) => {
|
||||
let bg = Arc::clone(&state.backend);
|
||||
let cid = cascade_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = bg.update_annotations(&cid).await;
|
||||
});
|
||||
}
|
||||
Ok((status, _)) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Backend returned {status}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Send failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let completion_id = format!(
|
||||
"chatcmpl-{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||
);
|
||||
|
||||
if body.stream {
|
||||
chat_completions_stream(
|
||||
state,
|
||||
completion_id,
|
||||
model_name.to_string(),
|
||||
cascade_id,
|
||||
body.timeout,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
chat_completions_sync(
|
||||
state,
|
||||
completion_id,
|
||||
model_name.to_string(),
|
||||
cascade_id,
|
||||
body.timeout,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Streaming ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Streaming output in Chat Completions format.
|
||||
async fn chat_completions_stream(
|
||||
state: Arc<AppState>,
|
||||
completion_id: String,
|
||||
model_name: String,
|
||||
cascade_id: String,
|
||||
timeout: u64,
|
||||
) -> axum::response::Response {
|
||||
let stream = async_stream::stream! {
|
||||
let start = std::time::Instant::now();
|
||||
let mut last_text = String::new();
|
||||
|
||||
// Initial role chunk
|
||||
yield Ok::<_, std::convert::Infallible>(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": now_unix(),
|
||||
"model": model_name,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": ""},
|
||||
"finish_reason": serde_json::Value::Null,
|
||||
}],
|
||||
})).unwrap_or_default()));
|
||||
|
||||
while start.elapsed().as_secs() < timeout {
|
||||
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
|
||||
if status == 200 {
|
||||
if let Some(steps) = data["steps"].as_array() {
|
||||
let text = extract_response_text(steps);
|
||||
|
||||
if !text.is_empty() && text != last_text {
|
||||
let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) {
|
||||
&text[last_text.len()..]
|
||||
} else {
|
||||
&text
|
||||
};
|
||||
|
||||
if !delta.is_empty() {
|
||||
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": now_unix(),
|
||||
"model": model_name,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {"content": delta},
|
||||
"finish_reason": serde_json::Value::Null,
|
||||
}],
|
||||
})).unwrap_or_default()));
|
||||
last_text = text.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// Done check: need DONE status AND non-empty text
|
||||
if is_response_done(steps) && !last_text.is_empty() {
|
||||
debug!("Completions stream done, text length={}", last_text.len());
|
||||
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": now_unix(),
|
||||
"model": model_name,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
})).unwrap_or_default()));
|
||||
yield Ok(Event::default().data("[DONE]".to_string()));
|
||||
return;
|
||||
}
|
||||
|
||||
// IDLE fallback: check trajectory status periodically
|
||||
// Only check every 5th step count to reduce backend traffic
|
||||
let step_count = steps.len();
|
||||
if step_count > 4 && step_count % 5 == 0 {
|
||||
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
|
||||
if ts == 200 {
|
||||
let run_status = td["status"].as_str().unwrap_or("");
|
||||
if run_status.contains("IDLE") && !last_text.is_empty() {
|
||||
debug!("Completions IDLE, text length={}", last_text.len());
|
||||
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": now_unix(),
|
||||
"model": model_name,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
})).unwrap_or_default()));
|
||||
yield Ok(Event::default().data("[DONE]".to_string()));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let poll_ms: u64 = rand::thread_rng().gen_range(800..1200);
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
|
||||
}
|
||||
|
||||
// Timeout
|
||||
warn!("Completions stream timeout after {}s", timeout);
|
||||
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": now_unix(),
|
||||
"model": model_name,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {"content": if last_text.is_empty() { "[Timeout waiting for response]" } else { "" }},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
})).unwrap_or_default()));
|
||||
yield Ok(Event::default().data("[DONE]".to_string()));
|
||||
};
|
||||
|
||||
Sse::new(stream)
|
||||
.keep_alive(
|
||||
axum::response::sse::KeepAlive::new()
|
||||
.interval(std::time::Duration::from_secs(15))
|
||||
.text(""),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
// ─── Sync ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Sync output in Chat Completions format.
|
||||
async fn chat_completions_sync(
|
||||
state: Arc<AppState>,
|
||||
completion_id: String,
|
||||
model_name: String,
|
||||
cascade_id: String,
|
||||
timeout: u64,
|
||||
) -> axum::response::Response {
|
||||
let result = poll_for_response(&state, &cascade_id, timeout).await;
|
||||
|
||||
// Check MITM store first for real intercepted usage
|
||||
let (prompt_tokens, completion_tokens, cached_tokens) = if let Some(mitm_usage) = state.mitm_store.take_usage(&cascade_id).await {
|
||||
(mitm_usage.input_tokens, mitm_usage.output_tokens, mitm_usage.cache_read_input_tokens)
|
||||
} else if let Some(u) = &result.usage {
|
||||
(u.input_tokens, u.output_tokens, 0)
|
||||
} else {
|
||||
(0, 0, 0)
|
||||
};
|
||||
|
||||
Json(serde_json::json!({
|
||||
"id": completion_id,
|
||||
"object": "chat.completion",
|
||||
"created": now_unix(),
|
||||
"model": model_name,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": result.text,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
"prompt_tokens_details": {
|
||||
"cached_tokens": cached_tokens,
|
||||
},
|
||||
},
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
176
src/api/mod.rs
Normal file
176
src/api/mod.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
//! Axum API server — OpenAI-compatible Responses + Chat Completions endpoints.
|
||||
|
||||
mod completions;
|
||||
mod models;
|
||||
mod polling;
|
||||
mod responses;
|
||||
mod types;
|
||||
mod util;
|
||||
|
||||
use crate::constants::safe_truncate;
|
||||
use crate::session::SessionManager;
|
||||
use axum::{
|
||||
extract::{DefaultBodyLimit, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Json},
|
||||
routing::{delete, get, post},
|
||||
Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tracing::warn;
|
||||
|
||||
use self::models::MODELS;
|
||||
use self::types::TokenRequest;
|
||||
|
||||
// ─── Shared state ────────────────────────────────────────────────────────────
|
||||
|
||||
pub struct AppState {
|
||||
pub backend: Arc<crate::backend::Backend>,
|
||||
pub sessions: SessionManager,
|
||||
pub mitm_store: crate::mitm::store::MitmStore,
|
||||
pub quota_store: crate::quota::QuotaStore,
|
||||
}
|
||||
|
||||
// ─── Router ──────────────────────────────────────────────────────────────────
|
||||
|
||||
pub fn router(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/v1/responses", post(responses::handle_responses))
|
||||
.route(
|
||||
"/v1/chat/completions",
|
||||
post(completions::handle_completions),
|
||||
)
|
||||
.route("/v1/models", get(handle_models))
|
||||
.route("/v1/sessions", get(handle_list_sessions))
|
||||
.route("/v1/sessions/{id}", delete(handle_delete_session))
|
||||
.route("/v1/token", post(handle_set_token))
|
||||
.route("/v1/usage", get(handle_usage))
|
||||
.route("/v1/quota", get(handle_quota))
|
||||
.route("/health", get(handle_health))
|
||||
.route("/", get(handle_root))
|
||||
.layer(CorsLayer::permissive())
|
||||
.layer(DefaultBodyLimit::max(1_048_576)) // 1 MB
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
// ─── Simple handlers ─────────────────────────────────────────────────────────
|
||||
|
||||
async fn handle_root() -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({
|
||||
"service": "antigravity-openai-proxy",
|
||||
"version": "3.2.0",
|
||||
"runtime": "rust",
|
||||
"endpoints": [
|
||||
"/v1/chat/completions",
|
||||
"/v1/responses",
|
||||
"/v1/models",
|
||||
"/v1/sessions",
|
||||
"/v1/token",
|
||||
"/v1/usage",
|
||||
"/v1/quota",
|
||||
"/health",
|
||||
],
|
||||
}))
|
||||
}
|
||||
|
||||
async fn handle_health() -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({"status": "ok"}))
|
||||
}
|
||||
|
||||
async fn handle_models() -> Json<serde_json::Value> {
|
||||
let models: Vec<serde_json::Value> = MODELS
|
||||
.iter()
|
||||
.map(|m| {
|
||||
serde_json::json!({
|
||||
"id": m.name,
|
||||
"object": "model",
|
||||
"created": 1700000000u64,
|
||||
"owned_by": "antigravity",
|
||||
"permission": [],
|
||||
"root": m.name,
|
||||
"parent": null,
|
||||
"meta": {
|
||||
"label": m.label,
|
||||
"enum_value": m.model_enum,
|
||||
},
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
Json(serde_json::json!({"object": "list", "data": models}))
|
||||
}
|
||||
|
||||
async fn handle_list_sessions(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> Json<serde_json::Value> {
|
||||
let sessions = state.sessions.list_sessions().await;
|
||||
Json(serde_json::json!({"sessions": sessions}))
|
||||
}
|
||||
|
||||
async fn handle_delete_session(
|
||||
State(state): State<Arc<AppState>>,
|
||||
axum::extract::Path(id): axum::extract::Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
if state.sessions.delete_session(&id).await {
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({"status": "deleted", "session_id": id})),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({"error": format!("Session not found: {id}")})),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_set_token(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(body): Json<TokenRequest>,
|
||||
) -> impl IntoResponse {
|
||||
if !body.token.starts_with("ya29.") {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({"error": "Invalid token. Must start with ya29."})),
|
||||
);
|
||||
}
|
||||
state.backend.set_oauth_token(body.token.clone()).await;
|
||||
|
||||
// Also persist to file
|
||||
let token_path = crate::constants::token_file_path();
|
||||
if let Err(e) = tokio::fs::write(&token_path, &body.token).await {
|
||||
warn!("Failed to write token file: {e}");
|
||||
}
|
||||
|
||||
let preview = safe_truncate(&body.token, 20);
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({"status": "ok", "token_prefix": preview})),
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_usage(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> Json<serde_json::Value> {
|
||||
let stats = state.mitm_store.stats().await;
|
||||
Json(serde_json::json!({
|
||||
"mitm": {
|
||||
"total_requests": stats.total_requests,
|
||||
"total_input_tokens": stats.total_input_tokens,
|
||||
"total_output_tokens": stats.total_output_tokens,
|
||||
"total_cache_read_tokens": stats.total_cache_read_tokens,
|
||||
"total_cache_creation_tokens": stats.total_cache_creation_tokens,
|
||||
"total_thinking_output_tokens": stats.total_thinking_output_tokens,
|
||||
"total_response_output_tokens": stats.total_response_output_tokens,
|
||||
"total_tokens": stats.total_input_tokens + stats.total_output_tokens,
|
||||
"per_model": stats.per_model,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
async fn handle_quota(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> Json<serde_json::Value> {
|
||||
let snap = state.quota_store.snapshot().await;
|
||||
Json(serde_json::to_value(snap).unwrap_or_default())
|
||||
}
|
||||
49
src/api/models.rs
Normal file
49
src/api/models.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
//! Model definitions and lookup.
|
||||
|
||||
/// Model definition: friendly name → (antigravity_id, protobuf_enum, label).
|
||||
pub(crate) struct ModelDef {
|
||||
pub name: &'static str,
|
||||
#[allow(dead_code)]
|
||||
pub ag_id: &'static str,
|
||||
pub model_enum: u32,
|
||||
pub label: &'static str,
|
||||
}
|
||||
|
||||
pub(crate) const MODELS: &[ModelDef] = &[
|
||||
ModelDef {
|
||||
name: "opus-4.6",
|
||||
ag_id: "MODEL_PLACEHOLDER_M26",
|
||||
model_enum: 1026,
|
||||
label: "Claude Opus 4.6 (Thinking)",
|
||||
},
|
||||
ModelDef {
|
||||
name: "opus-4.5",
|
||||
ag_id: "MODEL_PLACEHOLDER_M12",
|
||||
model_enum: 1012,
|
||||
label: "Claude Opus 4.5 (Thinking)",
|
||||
},
|
||||
ModelDef {
|
||||
name: "gemini-3-pro-high",
|
||||
ag_id: "MODEL_PLACEHOLDER_M8",
|
||||
model_enum: 1008,
|
||||
label: "Gemini 3 Pro (High)",
|
||||
},
|
||||
ModelDef {
|
||||
name: "gemini-3-pro",
|
||||
ag_id: "MODEL_PLACEHOLDER_M7",
|
||||
model_enum: 1007,
|
||||
label: "Gemini 3 Pro (Low)",
|
||||
},
|
||||
ModelDef {
|
||||
name: "gemini-3-flash",
|
||||
ag_id: "MODEL_PLACEHOLDER_M18",
|
||||
model_enum: 1018,
|
||||
label: "Gemini 3 Flash",
|
||||
},
|
||||
];
|
||||
|
||||
pub(crate) const DEFAULT_MODEL: &str = "opus-4.6";
|
||||
|
||||
pub(crate) fn lookup_model(name: &str) -> Option<&'static ModelDef> {
|
||||
MODELS.iter().find(|m| m.name == name)
|
||||
}
|
||||
298
src/api/polling.rs
Normal file
298
src/api/polling.rs
Normal file
@@ -0,0 +1,298 @@
|
||||
//! Shared polling engine and step extraction helpers.
|
||||
|
||||
use rand::Rng;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::AppState;
|
||||
|
||||
/// Real token usage reported by the language server.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[allow(dead_code)]
|
||||
pub(crate) struct ModelUsage {
|
||||
pub input_tokens: u64,
|
||||
pub output_tokens: u64,
|
||||
pub api_provider: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
/// Result of polling — text + optional real usage data + thinking data.
|
||||
pub(crate) struct PollResult {
|
||||
pub text: String,
|
||||
pub usage: Option<ModelUsage>,
|
||||
/// Opaque Anthropic thinking verification signature from PLANNER_RESPONSE.
|
||||
/// Required for multi-turn thinking model chaining.
|
||||
pub thinking_signature: Option<String>,
|
||||
/// The model's internal reasoning/thinking content.
|
||||
/// Available for both Opus (Anthropic) and Gemini models.
|
||||
pub thinking: Option<String>,
|
||||
/// Time the model spent thinking, as reported by the LS (e.g. "0.041999832s").
|
||||
pub thinking_duration: Option<String>,
|
||||
}
|
||||
|
||||
/// Extract the response text from steps — scans in REVERSE to find the latest response.
|
||||
pub(crate) fn extract_response_text(steps: &[serde_json::Value]) -> String {
|
||||
for step in steps.iter().rev() {
|
||||
let step_type = step["type"].as_str().unwrap_or("");
|
||||
|
||||
if step_type.contains("PLANNER_RESPONSE") {
|
||||
let resp = &step["plannerResponse"];
|
||||
let text = resp["rawResponse"]
|
||||
.as_str()
|
||||
.or_else(|| resp["response"].as_str())
|
||||
.unwrap_or("");
|
||||
if !text.is_empty() {
|
||||
return text.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
if step_type.contains("AI_RESPONSE") || step_type.contains("MODEL_RESPONSE") {
|
||||
if let Some(text) = step["response"]
|
||||
.as_str()
|
||||
.or_else(|| step["rawResponse"].as_str())
|
||||
.or_else(|| step["text"].as_str())
|
||||
{
|
||||
if !text.is_empty() {
|
||||
return text.to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
String::new()
|
||||
}
|
||||
|
||||
/// Extract real token usage from the LS's modelUsage field.
|
||||
/// The LS reports this in CHECKPOINT steps and sometimes in retryInfos.
|
||||
/// Scans in reverse to find the latest usage data.
|
||||
pub(crate) fn extract_model_usage(steps: &[serde_json::Value]) -> Option<ModelUsage> {
|
||||
for step in steps.iter().rev() {
|
||||
if let Some(usage) = step.get("metadata").and_then(|m| m.get("modelUsage")) {
|
||||
let input = usage["inputTokens"]
|
||||
.as_str()
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
.or_else(|| usage["inputTokens"].as_u64())
|
||||
.unwrap_or(0);
|
||||
let output = usage["outputTokens"]
|
||||
.as_str()
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
.or_else(|| usage["outputTokens"].as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
if input > 0 || output > 0 {
|
||||
return Some(ModelUsage {
|
||||
input_tokens: input,
|
||||
output_tokens: output,
|
||||
api_provider: usage["apiProvider"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
model: usage["model"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Extract the thinking signature from PLANNER_RESPONSE steps.
|
||||
/// This is an opaque Base64 blob used by Anthropic for extended thinking
|
||||
/// verification. Needed to chain multi-turn conversations with thinking models.
|
||||
pub(crate) fn extract_thinking_signature(steps: &[serde_json::Value]) -> Option<String> {
|
||||
for step in steps.iter().rev() {
|
||||
let step_type = step["type"].as_str().unwrap_or("");
|
||||
if step_type.contains("PLANNER_RESPONSE") {
|
||||
if let Some(sig) = step["plannerResponse"]["thinkingSignature"].as_str() {
|
||||
if !sig.is_empty() {
|
||||
return Some(sig.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Extract the model's thinking/reasoning content from PLANNER_RESPONSE steps.
|
||||
/// This is the internal monologue produced during extended thinking.
|
||||
/// Available for ALL models (Opus, Gemini Flash, Gemini Pro).
|
||||
pub(crate) fn extract_thinking_content(steps: &[serde_json::Value]) -> Option<String> {
|
||||
for step in steps.iter().rev() {
|
||||
let step_type = step["type"].as_str().unwrap_or("");
|
||||
if step_type.contains("PLANNER_RESPONSE") {
|
||||
if let Some(thinking) = step["plannerResponse"]["thinking"].as_str() {
|
||||
if !thinking.is_empty() {
|
||||
return Some(thinking.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Extract thinking duration from PLANNER_RESPONSE steps.
|
||||
/// Returns the raw duration string as reported by the LS (e.g. "0.041999832s").
|
||||
pub(crate) fn extract_thinking_duration(steps: &[serde_json::Value]) -> Option<String> {
|
||||
for step in steps.iter().rev() {
|
||||
let step_type = step["type"].as_str().unwrap_or("");
|
||||
if step_type.contains("PLANNER_RESPONSE") {
|
||||
if let Some(dur) = step["plannerResponse"]["thinkingDuration"].as_str() {
|
||||
if !dur.is_empty() {
|
||||
return Some(dur.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if the cascade has truly finished — the last PLANNER_RESPONSE must be DONE
|
||||
/// AND the very last step must be a terminal type (CHECKPOINT or PLANNER_RESPONSE with DONE).
|
||||
/// This prevents false positives during agentic tool-call loops where intermediate
|
||||
/// PLANNER_RESPONSE steps show DONE but the cascade keeps going.
|
||||
pub(crate) fn is_response_done(steps: &[serde_json::Value]) -> bool {
|
||||
if steps.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let last = &steps[steps.len() - 1];
|
||||
let last_type = last["type"].as_str().unwrap_or("");
|
||||
let last_status = last["status"].as_str().unwrap_or("");
|
||||
|
||||
// CHECKPOINT at the end = cascade is definitely done
|
||||
if last_type.contains("CHECKPOINT") {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Last step is a PLANNER_RESPONSE with DONE = final answer (no more tool calls coming)
|
||||
if (last_type.contains("PLANNER_RESPONSE")
|
||||
|| last_type.contains("AI_RESPONSE")
|
||||
|| last_type.contains("MODEL_RESPONSE"))
|
||||
&& last_status.contains("DONE")
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Poll the backend until we get a response or timeout.
|
||||
pub(crate) async fn poll_for_response(
|
||||
state: &AppState,
|
||||
cascade_id: &str,
|
||||
timeout: u64,
|
||||
) -> PollResult {
|
||||
let start = std::time::Instant::now();
|
||||
let short_id = &cascade_id[..8.min(cascade_id.len())];
|
||||
info!("Polling for response on cascade {short_id} (timeout={timeout}s)");
|
||||
|
||||
let mut last_step_count: usize = 0;
|
||||
|
||||
while start.elapsed().as_secs() < timeout {
|
||||
if let Ok((status, data)) = state.backend.get_steps(cascade_id).await {
|
||||
if status == 200 {
|
||||
if let Some(steps) = data["steps"].as_array() {
|
||||
let step_count = steps.len();
|
||||
|
||||
// Only log when step count changes (denoised)
|
||||
if step_count != last_step_count {
|
||||
// Compact type summary: count unique types
|
||||
let mut type_counts: std::collections::BTreeMap<&str, usize> =
|
||||
std::collections::BTreeMap::new();
|
||||
for s in steps.iter() {
|
||||
let t = s["type"]
|
||||
.as_str()
|
||||
.unwrap_or("?")
|
||||
.strip_prefix("CORTEX_STEP_TYPE_")
|
||||
.unwrap_or("?");
|
||||
*type_counts.entry(t).or_insert(0) += 1;
|
||||
}
|
||||
let summary: Vec<String> = type_counts
|
||||
.iter()
|
||||
.map(|(t, c)| {
|
||||
if *c > 1 {
|
||||
format!("{t}×{c}")
|
||||
} else {
|
||||
t.to_string()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
debug!(
|
||||
"Poll {short_id}: {step_count} steps [{}]",
|
||||
summary.join(", ")
|
||||
);
|
||||
last_step_count = step_count;
|
||||
}
|
||||
|
||||
// Check if the cascade is truly done
|
||||
if is_response_done(steps) {
|
||||
let text = extract_response_text(steps);
|
||||
if !text.is_empty() {
|
||||
let usage = extract_model_usage(steps);
|
||||
let thinking_signature = extract_thinking_signature(steps);
|
||||
let thinking = extract_thinking_content(steps);
|
||||
let thinking_duration = extract_thinking_duration(steps);
|
||||
let elapsed = start.elapsed().as_secs_f32();
|
||||
if let Some(ref u) = usage {
|
||||
info!(
|
||||
"Response done ({short_id}), {:.1}s, {} chars, tokens: {}in/{}out ({}){}{}",
|
||||
elapsed, text.len(), u.input_tokens, u.output_tokens, u.model,
|
||||
if thinking.is_some() { format!(", thinking: {} chars", thinking.as_ref().unwrap().len()) } else { String::new() },
|
||||
if thinking_signature.is_some() { ", has sig" } else { "" }
|
||||
);
|
||||
} else {
|
||||
info!(
|
||||
"Response done ({short_id}), {:.1}s, {} chars (no usage){}{}",
|
||||
elapsed, text.len(),
|
||||
if thinking.is_some() { format!(", thinking: {} chars", thinking.as_ref().unwrap().len()) } else { String::new() },
|
||||
if thinking_signature.is_some() { ", has sig" } else { "" }
|
||||
);
|
||||
}
|
||||
return PollResult { text, usage, thinking_signature, thinking, thinking_duration };
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: check trajectory IDLE status (catches edge cases)
|
||||
// Only check every 5th poll to reduce network calls
|
||||
if step_count > 4 && step_count % 5 == 0 {
|
||||
if let Ok((ts, td)) = state.backend.get_trajectory(cascade_id).await
|
||||
{
|
||||
if ts == 200 {
|
||||
let run_status =
|
||||
td["status"].as_str().unwrap_or("");
|
||||
if run_status.contains("IDLE") {
|
||||
let text = extract_response_text(steps);
|
||||
if !text.is_empty() {
|
||||
let usage = extract_model_usage(steps);
|
||||
let thinking_signature = extract_thinking_signature(steps);
|
||||
let thinking = extract_thinking_content(steps);
|
||||
let thinking_duration = extract_thinking_duration(steps);
|
||||
let elapsed = start.elapsed().as_secs_f32();
|
||||
info!(
|
||||
"Trajectory IDLE ({short_id}), {:.1}s, {} chars",
|
||||
elapsed,
|
||||
text.len()
|
||||
);
|
||||
return PollResult { text, usage, thinking_signature, thinking, thinking_duration };
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let poll_ms: u64 = rand::thread_rng().gen_range(1000..1800);
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
|
||||
}
|
||||
|
||||
warn!("Timeout after {timeout}s on cascade {short_id}");
|
||||
PollResult {
|
||||
text: "[Timeout waiting for AI response]".to_string(),
|
||||
usage: None,
|
||||
thinking_signature: None,
|
||||
thinking: None,
|
||||
thinking_duration: None,
|
||||
}
|
||||
}
|
||||
686
src/api/responses.rs
Normal file
686
src/api/responses.rs
Normal file
@@ -0,0 +1,686 @@
|
||||
//! OpenAI Responses API (/v1/responses) handler.
|
||||
//!
|
||||
//! Strictly adheres to the official OpenAI Responses API protocol:
|
||||
//! https://platform.openai.com/docs/api-reference/responses
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{sse::Event, IntoResponse, Json, Sse},
|
||||
};
|
||||
use rand::Rng;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
||||
use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content, extract_thinking_duration};
|
||||
use super::types::*;
|
||||
use super::util::{err_response, now_unix, responses_sse_event};
|
||||
use super::AppState;
|
||||
|
||||
// ─── Input extraction ────────────────────────────────────────────────────────
|
||||
|
||||
/// Extract user text from Responses API `input` field.
|
||||
fn extract_responses_input(input: &serde_json::Value, instructions: Option<&str>) -> String {
|
||||
let user_text = match input {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
serde_json::Value::Array(items) => {
|
||||
items
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|item| item["role"].as_str() == Some("user"))
|
||||
.and_then(|item| match &item["content"] {
|
||||
serde_json::Value::String(s) => Some(s.clone()),
|
||||
serde_json::Value::Array(parts) => Some(
|
||||
parts
|
||||
.iter()
|
||||
.filter(|p| {
|
||||
let t = p["type"].as_str().unwrap_or("");
|
||||
t == "input_text" || t == "text"
|
||||
})
|
||||
.filter_map(|p| p["text"].as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" "),
|
||||
),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
_ => String::new(),
|
||||
};
|
||||
|
||||
match instructions {
|
||||
Some(inst) if !inst.is_empty() => format!("{inst}\n\n{user_text}"),
|
||||
_ => user_text,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract conversation/session ID from Responses API `conversation` field.
|
||||
fn extract_conversation_id(conv: &Option<serde_json::Value>) -> Option<String> {
|
||||
match conv {
|
||||
Some(serde_json::Value::String(s)) => Some(s.clone()),
|
||||
Some(obj) => obj["id"].as_str().map(|s| s.to_string()),
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a full Response object matching the official OpenAI schema.
|
||||
fn build_response_object(
|
||||
id: &str,
|
||||
model: &str,
|
||||
status: &'static str,
|
||||
created_at: u64,
|
||||
completed_at: Option<u64>,
|
||||
output: Vec<ResponseOutput>,
|
||||
usage: Option<Usage>,
|
||||
instructions: Option<&str>,
|
||||
store: bool,
|
||||
temperature: f64,
|
||||
top_p: f64,
|
||||
max_output_tokens: Option<u64>,
|
||||
previous_response_id: Option<&str>,
|
||||
user: Option<&str>,
|
||||
metadata: &serde_json::Value,
|
||||
thinking_signature: Option<String>,
|
||||
thinking: Option<String>,
|
||||
thinking_duration: Option<String>,
|
||||
) -> ResponsesResponse {
|
||||
ResponsesResponse {
|
||||
id: id.to_string(),
|
||||
object: "response",
|
||||
created_at,
|
||||
status,
|
||||
completed_at,
|
||||
error: None,
|
||||
incomplete_details: None,
|
||||
instructions: instructions.map(|s| s.to_string()),
|
||||
max_output_tokens,
|
||||
model: model.to_string(),
|
||||
output,
|
||||
parallel_tool_calls: true,
|
||||
previous_response_id: previous_response_id.map(|s| s.to_string()),
|
||||
reasoning: Reasoning::default(),
|
||||
store,
|
||||
temperature,
|
||||
text: TextFormat::default(),
|
||||
tool_choice: "auto",
|
||||
tools: vec![],
|
||||
top_p,
|
||||
truncation: "disabled",
|
||||
usage,
|
||||
user: user.map(|s| s.to_string()),
|
||||
metadata: metadata.clone(),
|
||||
thinking_signature,
|
||||
thinking,
|
||||
thinking_duration,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize a ResponsesResponse to serde_json::Value for SSE embedding.
|
||||
fn response_to_json(resp: &ResponsesResponse) -> serde_json::Value {
|
||||
serde_json::to_value(resp).unwrap_or(serde_json::json!({}))
|
||||
}
|
||||
|
||||
// ─── Handler ─────────────────────────────────────────────────────────────────
|
||||
|
||||
pub(crate) async fn handle_responses(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(body): Json<ResponsesRequest>,
|
||||
) -> axum::response::Response {
|
||||
info!(
|
||||
"POST /v1/responses model={} stream={}",
|
||||
body.model.as_deref().unwrap_or(DEFAULT_MODEL),
|
||||
body.stream
|
||||
);
|
||||
|
||||
let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
|
||||
let model = match lookup_model(model_name) {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
|
||||
return err_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Unknown model: {model_name}. Available: {names:?}"),
|
||||
"invalid_request_error",
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let token = state.backend.oauth_token().await;
|
||||
if token.is_empty() {
|
||||
return err_response(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
|
||||
"authentication_error",
|
||||
);
|
||||
}
|
||||
|
||||
let user_text = extract_responses_input(&body.input, body.instructions.as_deref());
|
||||
if user_text.is_empty() {
|
||||
return err_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"No user input found".to_string(),
|
||||
"invalid_request_error",
|
||||
);
|
||||
}
|
||||
|
||||
let response_id = format!(
|
||||
"resp_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||
);
|
||||
|
||||
// Session/conversation management
|
||||
let session_id_str = extract_conversation_id(&body.conversation);
|
||||
let cascade_id = if let Some(ref sid) = session_id_str {
|
||||
match state
|
||||
.sessions
|
||||
.get_or_create(Some(sid), || state.backend.create_cascade())
|
||||
.await
|
||||
{
|
||||
Ok(sr) => sr.cascade_id,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("StartCascade failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match state.backend.create_cascade().await {
|
||||
Ok(cid) => cid,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("StartCascade failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Send message
|
||||
match state
|
||||
.backend
|
||||
.send_message(&cascade_id, &user_text, model.model_enum)
|
||||
.await
|
||||
{
|
||||
Ok((status, _)) if status == 200 => {
|
||||
let bg = Arc::clone(&state.backend);
|
||||
let cid = cascade_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = bg.update_annotations(&cid).await;
|
||||
});
|
||||
}
|
||||
Ok((status, _)) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Antigravity returned {status}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Send message failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Capture request params for response building
|
||||
let req_params = RequestParams {
|
||||
user_text: user_text.clone(),
|
||||
instructions: body.instructions.clone(),
|
||||
store: body.store,
|
||||
temperature: body.temperature.unwrap_or(1.0),
|
||||
top_p: body.top_p.unwrap_or(1.0),
|
||||
max_output_tokens: body.max_output_tokens,
|
||||
previous_response_id: body.previous_response_id.clone(),
|
||||
user: body.user.clone(),
|
||||
metadata: body.metadata.clone().unwrap_or(serde_json::json!({})),
|
||||
};
|
||||
|
||||
if body.stream {
|
||||
handle_responses_stream(
|
||||
state, response_id, model_name.to_string(), cascade_id,
|
||||
body.timeout, req_params,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
handle_responses_sync(
|
||||
state, response_id, model_name.to_string(), cascade_id,
|
||||
body.timeout, req_params,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Captured request parameters needed to echo back in the response.
|
||||
struct RequestParams {
|
||||
user_text: String,
|
||||
instructions: Option<String>,
|
||||
store: bool,
|
||||
temperature: f64,
|
||||
top_p: f64,
|
||||
max_output_tokens: Option<u64>,
|
||||
previous_response_id: Option<String>,
|
||||
user: Option<String>,
|
||||
metadata: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Build Usage from the best available source:
|
||||
/// 1. MITM intercepted data (real API tokens, including cache stats)
|
||||
/// 2. LS trajectory data (real tokens, no cache info)
|
||||
/// 3. Estimation from text lengths (fallback)
|
||||
async fn usage_from_poll(
|
||||
mitm_store: &crate::mitm::store::MitmStore,
|
||||
cascade_id: &str,
|
||||
model_usage: &Option<super::polling::ModelUsage>,
|
||||
input_text: &str,
|
||||
output_text: &str,
|
||||
) -> Usage {
|
||||
// Priority 1: MITM intercepted data (most accurate — includes cache tokens)
|
||||
if let Some(mitm_usage) = mitm_store.take_usage(cascade_id).await {
|
||||
tracing::debug!(
|
||||
input = mitm_usage.input_tokens,
|
||||
output = mitm_usage.output_tokens,
|
||||
cache_read = mitm_usage.cache_read_input_tokens,
|
||||
cache_create = mitm_usage.cache_creation_input_tokens,
|
||||
thinking = mitm_usage.thinking_output_tokens,
|
||||
"Using MITM intercepted usage"
|
||||
);
|
||||
return Usage {
|
||||
input_tokens: mitm_usage.input_tokens,
|
||||
input_tokens_details: InputTokensDetails {
|
||||
cached_tokens: mitm_usage.cache_read_input_tokens,
|
||||
},
|
||||
output_tokens: mitm_usage.output_tokens,
|
||||
output_tokens_details: OutputTokensDetails {
|
||||
reasoning_tokens: mitm_usage.thinking_output_tokens,
|
||||
},
|
||||
total_tokens: mitm_usage.input_tokens + mitm_usage.output_tokens,
|
||||
};
|
||||
}
|
||||
|
||||
// Priority 2: LS trajectory data (from CHECKPOINT/metadata steps)
|
||||
if let Some(u) = model_usage {
|
||||
return Usage {
|
||||
input_tokens: u.input_tokens,
|
||||
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
|
||||
output_tokens: u.output_tokens,
|
||||
output_tokens_details: OutputTokensDetails { reasoning_tokens: 0 },
|
||||
total_tokens: u.input_tokens + u.output_tokens,
|
||||
};
|
||||
}
|
||||
|
||||
// Priority 3: Estimate from text lengths
|
||||
Usage::estimate(input_text, output_text)
|
||||
}
|
||||
|
||||
// ─── Sync response ───────────────────────────────────────────────────────────
|
||||
|
||||
async fn handle_responses_sync(
|
||||
state: Arc<AppState>,
|
||||
response_id: String,
|
||||
model_name: String,
|
||||
cascade_id: String,
|
||||
timeout: u64,
|
||||
params: RequestParams,
|
||||
) -> axum::response::Response {
|
||||
let created_at = now_unix();
|
||||
let poll_result = poll_for_response(&state, &cascade_id, timeout).await;
|
||||
let completed_at = now_unix();
|
||||
let msg_id = format!(
|
||||
"msg_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||
);
|
||||
|
||||
let usage = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, ¶ms.user_text, &poll_result.text).await;
|
||||
|
||||
let resp = build_response_object(
|
||||
&response_id,
|
||||
&model_name,
|
||||
"completed",
|
||||
created_at,
|
||||
Some(completed_at),
|
||||
vec![ResponseOutput {
|
||||
output_type: "message",
|
||||
id: msg_id,
|
||||
status: "completed",
|
||||
role: "assistant",
|
||||
content: vec![OutputContent {
|
||||
content_type: "output_text",
|
||||
text: poll_result.text,
|
||||
annotations: vec![],
|
||||
}],
|
||||
}],
|
||||
Some(usage),
|
||||
params.instructions.as_deref(),
|
||||
params.store,
|
||||
params.temperature,
|
||||
params.top_p,
|
||||
params.max_output_tokens,
|
||||
params.previous_response_id.as_deref(),
|
||||
params.user.as_deref(),
|
||||
¶ms.metadata,
|
||||
poll_result.thinking_signature,
|
||||
poll_result.thinking,
|
||||
poll_result.thinking_duration,
|
||||
);
|
||||
|
||||
Json(resp).into_response()
|
||||
}
|
||||
|
||||
// ─── Streaming response ─────────────────────────────────────────────────────
|
||||
|
||||
async fn handle_responses_stream(
|
||||
state: Arc<AppState>,
|
||||
response_id: String,
|
||||
model_name: String,
|
||||
cascade_id: String,
|
||||
timeout: u64,
|
||||
params: RequestParams,
|
||||
) -> axum::response::Response {
|
||||
let stream = async_stream::stream! {
|
||||
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||
let created_at = now_unix();
|
||||
let seq = AtomicU32::new(0);
|
||||
let next_seq = || seq.fetch_add(1, Ordering::Relaxed);
|
||||
const CONTENT_IDX: u32 = 0;
|
||||
const OUTPUT_IDX: u32 = 0;
|
||||
|
||||
// Build the in-progress response shell (no output yet)
|
||||
let in_progress_resp = build_response_object(
|
||||
&response_id, &model_name, "in_progress", created_at, None,
|
||||
vec![], None,
|
||||
params.instructions.as_deref(), params.store,
|
||||
params.temperature, params.top_p,
|
||||
params.max_output_tokens, params.previous_response_id.as_deref(),
|
||||
params.user.as_deref(), ¶ms.metadata,
|
||||
None, None, None,
|
||||
);
|
||||
let resp_json = response_to_json(&in_progress_resp);
|
||||
|
||||
// 1. response.created
|
||||
yield Ok::<_, std::convert::Infallible>(responses_sse_event(
|
||||
"response.created",
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"sequence_number": next_seq(),
|
||||
"response": resp_json,
|
||||
}),
|
||||
));
|
||||
|
||||
// 2. response.in_progress
|
||||
yield Ok(responses_sse_event(
|
||||
"response.in_progress",
|
||||
serde_json::json!({
|
||||
"type": "response.in_progress",
|
||||
"sequence_number": next_seq(),
|
||||
"response": resp_json,
|
||||
}),
|
||||
));
|
||||
|
||||
// 3. response.output_item.added
|
||||
yield Ok(responses_sse_event(
|
||||
"response.output_item.added",
|
||||
serde_json::json!({
|
||||
"type": "response.output_item.added",
|
||||
"sequence_number": next_seq(),
|
||||
"output_index": OUTPUT_IDX,
|
||||
"item": {
|
||||
"type": "message",
|
||||
"id": &msg_id,
|
||||
"status": "in_progress",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
}
|
||||
}),
|
||||
));
|
||||
|
||||
// 4. response.content_part.added
|
||||
yield Ok(responses_sse_event(
|
||||
"response.content_part.added",
|
||||
serde_json::json!({
|
||||
"type": "response.content_part.added",
|
||||
"sequence_number": next_seq(),
|
||||
"output_index": OUTPUT_IDX,
|
||||
"content_index": CONTENT_IDX,
|
||||
"part": {
|
||||
"type": "output_text",
|
||||
"text": "",
|
||||
"annotations": [],
|
||||
}
|
||||
}),
|
||||
));
|
||||
|
||||
// 5. Poll and emit text deltas
|
||||
let start = std::time::Instant::now();
|
||||
let mut last_text = String::new();
|
||||
|
||||
while start.elapsed().as_secs() < timeout {
|
||||
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
|
||||
if status == 200 {
|
||||
if let Some(steps) = data["steps"].as_array() {
|
||||
let text = extract_response_text(steps);
|
||||
|
||||
if !text.is_empty() && text != last_text {
|
||||
let new_content = if text.len() > last_text.len()
|
||||
&& text.starts_with(&*last_text)
|
||||
{
|
||||
&text[last_text.len()..]
|
||||
} else {
|
||||
&text
|
||||
};
|
||||
|
||||
if !new_content.is_empty() {
|
||||
yield Ok(responses_sse_event(
|
||||
"response.output_text.delta",
|
||||
serde_json::json!({
|
||||
"type": "response.output_text.delta",
|
||||
"sequence_number": next_seq(),
|
||||
"item_id": &msg_id,
|
||||
"output_index": OUTPUT_IDX,
|
||||
"content_index": CONTENT_IDX,
|
||||
"delta": new_content,
|
||||
}),
|
||||
));
|
||||
last_text = text.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// Check if response is done AND we have text
|
||||
if is_response_done(steps) && !last_text.is_empty() {
|
||||
debug!("Response done, text length={}", last_text.len());
|
||||
let mu = extract_model_usage(steps);
|
||||
let usage = usage_from_poll(&state.mitm_store, &cascade_id, &mu, ¶ms.user_text, &last_text).await;
|
||||
let ts = extract_thinking_signature(steps);
|
||||
let tc = extract_thinking_content(steps);
|
||||
let td = extract_thinking_duration(steps);
|
||||
for evt in completion_events(
|
||||
&response_id, &model_name, &msg_id,
|
||||
OUTPUT_IDX, CONTENT_IDX, &last_text, usage,
|
||||
created_at, &seq, ¶ms, ts, tc, td,
|
||||
) {
|
||||
yield Ok(evt);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// IDLE fallback: check trajectory status periodically
|
||||
let step_count = steps.len();
|
||||
if step_count > 4 && step_count % 5 == 0 {
|
||||
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
|
||||
if ts == 200 {
|
||||
let run_status = td["status"].as_str().unwrap_or("");
|
||||
if run_status.contains("IDLE") && !last_text.is_empty() {
|
||||
debug!("Trajectory IDLE, text length={}", last_text.len());
|
||||
let mu = extract_model_usage(steps);
|
||||
let usage = usage_from_poll(&state.mitm_store, &cascade_id, &mu, ¶ms.user_text, &last_text).await;
|
||||
let ts = extract_thinking_signature(steps);
|
||||
let tc = extract_thinking_content(steps);
|
||||
let td = extract_thinking_duration(steps);
|
||||
for evt in completion_events(
|
||||
&response_id, &model_name, &msg_id,
|
||||
OUTPUT_IDX, CONTENT_IDX, &last_text, usage,
|
||||
created_at, &seq, ¶ms, ts, tc, td,
|
||||
) {
|
||||
yield Ok(evt);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let poll_ms: u64 = rand::thread_rng().gen_range(800..1200);
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
|
||||
}
|
||||
|
||||
// Timeout — emit incomplete response
|
||||
let timeout_resp = build_response_object(
|
||||
&response_id, &model_name, "incomplete", created_at, None,
|
||||
vec![], Some(Usage::estimate(¶ms.user_text, "")),
|
||||
params.instructions.as_deref(), params.store,
|
||||
params.temperature, params.top_p,
|
||||
params.max_output_tokens, params.previous_response_id.as_deref(),
|
||||
params.user.as_deref(), ¶ms.metadata,
|
||||
None, None, None,
|
||||
);
|
||||
yield Ok(responses_sse_event(
|
||||
"response.completed",
|
||||
serde_json::json!({
|
||||
"type": "response.completed",
|
||||
"sequence_number": next_seq(),
|
||||
"response": response_to_json(&timeout_resp),
|
||||
}),
|
||||
));
|
||||
};
|
||||
|
||||
Sse::new(stream)
|
||||
.keep_alive(
|
||||
axum::response::sse::KeepAlive::new()
|
||||
.interval(std::time::Duration::from_secs(15))
|
||||
.text(""),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
// ─── SSE completion events ───────────────────────────────────────────────────
|
||||
|
||||
/// Build the completion SSE events sequence matching the official protocol:
|
||||
/// 1. response.output_text.done
|
||||
/// 2. response.content_part.done
|
||||
/// 3. response.output_item.done
|
||||
/// 4. response.completed
|
||||
fn completion_events(
|
||||
resp_id: &str,
|
||||
model: &str,
|
||||
msg_id: &str,
|
||||
out_idx: u32,
|
||||
content_idx: u32,
|
||||
text: &str,
|
||||
usage: Usage,
|
||||
created_at: u64,
|
||||
seq: &AtomicU32,
|
||||
params: &RequestParams,
|
||||
thinking_signature: Option<String>,
|
||||
thinking: Option<String>,
|
||||
thinking_duration: Option<String>,
|
||||
) -> Vec<Event> {
|
||||
let next_seq = || seq.fetch_add(1, Ordering::Relaxed);
|
||||
let completed_at = now_unix();
|
||||
|
||||
let output_item = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": msg_id,
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": text,
|
||||
"annotations": [],
|
||||
}],
|
||||
});
|
||||
|
||||
let completed_resp = build_response_object(
|
||||
resp_id, model, "completed", created_at, Some(completed_at),
|
||||
vec![ResponseOutput {
|
||||
output_type: "message",
|
||||
id: msg_id.to_string(),
|
||||
status: "completed",
|
||||
role: "assistant",
|
||||
content: vec![OutputContent {
|
||||
content_type: "output_text",
|
||||
text: text.to_string(),
|
||||
annotations: vec![],
|
||||
}],
|
||||
}],
|
||||
Some(usage),
|
||||
params.instructions.as_deref(),
|
||||
params.store,
|
||||
params.temperature,
|
||||
params.top_p,
|
||||
params.max_output_tokens,
|
||||
params.previous_response_id.as_deref(),
|
||||
params.user.as_deref(),
|
||||
¶ms.metadata,
|
||||
thinking_signature,
|
||||
thinking,
|
||||
thinking_duration,
|
||||
);
|
||||
|
||||
vec![
|
||||
// 1. response.output_text.done
|
||||
responses_sse_event(
|
||||
"response.output_text.done",
|
||||
serde_json::json!({
|
||||
"type": "response.output_text.done",
|
||||
"sequence_number": next_seq(),
|
||||
"item_id": msg_id,
|
||||
"output_index": out_idx,
|
||||
"content_index": content_idx,
|
||||
"text": text,
|
||||
}),
|
||||
),
|
||||
// 2. response.content_part.done
|
||||
responses_sse_event(
|
||||
"response.content_part.done",
|
||||
serde_json::json!({
|
||||
"type": "response.content_part.done",
|
||||
"sequence_number": next_seq(),
|
||||
"output_index": out_idx,
|
||||
"content_index": content_idx,
|
||||
"part": {
|
||||
"type": "output_text",
|
||||
"text": text,
|
||||
"annotations": [],
|
||||
},
|
||||
}),
|
||||
),
|
||||
// 3. response.output_item.done
|
||||
responses_sse_event(
|
||||
"response.output_item.done",
|
||||
serde_json::json!({
|
||||
"type": "response.output_item.done",
|
||||
"sequence_number": next_seq(),
|
||||
"output_index": out_idx,
|
||||
"item": output_item,
|
||||
}),
|
||||
),
|
||||
// 4. response.completed
|
||||
responses_sse_event(
|
||||
"response.completed",
|
||||
serde_json::json!({
|
||||
"type": "response.completed",
|
||||
"sequence_number": next_seq(),
|
||||
"response": response_to_json(&completed_resp),
|
||||
}),
|
||||
),
|
||||
]
|
||||
}
|
||||
241
src/api/types.rs
Normal file
241
src/api/types.rs
Normal file
@@ -0,0 +1,241 @@
|
||||
//! Request/response types for the OpenAI-compatible API.
|
||||
//!
|
||||
//! All response shapes strictly match the official OpenAI Responses API spec:
|
||||
//! https://platform.openai.com/docs/api-reference/responses
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ─── Request types ───────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct ResponsesRequest {
|
||||
pub model: Option<String>,
|
||||
pub input: serde_json::Value,
|
||||
#[serde(default)]
|
||||
pub instructions: Option<String>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
#[serde(default = "default_timeout")]
|
||||
pub timeout: u64,
|
||||
pub conversation: Option<serde_json::Value>,
|
||||
#[serde(default = "default_true")]
|
||||
pub store: bool,
|
||||
#[serde(default)]
|
||||
pub temperature: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub top_p: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub max_output_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub previous_response_id: Option<String>,
|
||||
#[serde(default)]
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
#[serde(default)]
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
/// Chat Completions request (OpenAI-compatible).
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct CompletionRequest {
|
||||
pub model: Option<String>,
|
||||
pub messages: Vec<CompletionMessage>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
#[serde(default = "default_timeout")]
|
||||
pub timeout: u64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct CompletionMessage {
|
||||
pub role: String,
|
||||
pub content: serde_json::Value,
|
||||
}
|
||||
|
||||
fn default_timeout() -> u64 {
|
||||
120
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
// ─── Response types (official OpenAI Responses API shape) ────────────────────
|
||||
|
||||
/// Top-level Response object — matches OpenAI exactly.
|
||||
#[derive(Serialize, Clone)]
|
||||
pub(crate) struct ResponsesResponse {
|
||||
pub id: String,
|
||||
pub object: &'static str,
|
||||
pub created_at: u64,
|
||||
pub status: &'static str,
|
||||
#[serde(serialize_with = "serialize_option_u64")]
|
||||
pub completed_at: Option<u64>,
|
||||
pub error: Option<serde_json::Value>,
|
||||
pub incomplete_details: Option<serde_json::Value>,
|
||||
pub instructions: Option<String>,
|
||||
#[serde(serialize_with = "serialize_option_u64")]
|
||||
pub max_output_tokens: Option<u64>,
|
||||
pub model: String,
|
||||
pub output: Vec<ResponseOutput>,
|
||||
pub parallel_tool_calls: bool,
|
||||
pub previous_response_id: Option<String>,
|
||||
pub reasoning: Reasoning,
|
||||
pub store: bool,
|
||||
pub temperature: f64,
|
||||
pub text: TextFormat,
|
||||
pub tool_choice: &'static str,
|
||||
pub tools: Vec<serde_json::Value>,
|
||||
pub top_p: f64,
|
||||
pub truncation: &'static str,
|
||||
pub usage: Option<Usage>,
|
||||
pub user: Option<String>,
|
||||
pub metadata: serde_json::Value,
|
||||
/// Proxy extension: opaque thinking verification signature.
|
||||
/// Present for all models. Required for multi-turn chaining with thinking models.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub thinking_signature: Option<String>,
|
||||
/// Proxy extension: the model's internal reasoning/thinking content.
|
||||
/// Available for all models (Opus, Gemini Flash, Gemini Pro).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub thinking: Option<String>,
|
||||
/// Proxy extension: time spent thinking (e.g. "0.041999832s").
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub thinking_duration: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub(crate) struct ResponseOutput {
|
||||
#[serde(rename = "type")]
|
||||
pub output_type: &'static str,
|
||||
pub id: String,
|
||||
pub status: &'static str,
|
||||
pub role: &'static str,
|
||||
pub content: Vec<OutputContent>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub(crate) struct OutputContent {
|
||||
#[serde(rename = "type")]
|
||||
pub content_type: &'static str,
|
||||
pub text: String,
|
||||
pub annotations: Vec<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub(crate) struct Usage {
|
||||
pub input_tokens: u64,
|
||||
pub input_tokens_details: InputTokensDetails,
|
||||
pub output_tokens: u64,
|
||||
pub output_tokens_details: OutputTokensDetails,
|
||||
pub total_tokens: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub(crate) struct InputTokensDetails {
|
||||
pub cached_tokens: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub(crate) struct OutputTokensDetails {
|
||||
pub reasoning_tokens: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub(crate) struct Reasoning {
|
||||
pub effort: Option<String>,
|
||||
pub summary: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub(crate) struct TextFormat {
|
||||
pub format: TextFormatInner,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub(crate) struct TextFormatInner {
|
||||
#[serde(rename = "type")]
|
||||
pub format_type: &'static str,
|
||||
}
|
||||
|
||||
impl Usage {
|
||||
/// Estimate token counts from actual text.
|
||||
/// Uses ~4 chars/token heuristic (standard GPT tokenizer average).
|
||||
pub fn estimate(input_text: &str, output_text: &str) -> Self {
|
||||
let input_tokens = (input_text.len() as u64 + 3) / 4;
|
||||
let output_tokens = (output_text.len() as u64 + 3) / 4;
|
||||
Self {
|
||||
input_tokens,
|
||||
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
|
||||
output_tokens,
|
||||
output_tokens_details: OutputTokensDetails {
|
||||
reasoning_tokens: 0,
|
||||
},
|
||||
total_tokens: input_tokens + output_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Usage {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
input_tokens: 0,
|
||||
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
|
||||
output_tokens: 0,
|
||||
output_tokens_details: OutputTokensDetails {
|
||||
reasoning_tokens: 0,
|
||||
},
|
||||
total_tokens: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Reasoning {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
effort: None,
|
||||
summary: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TextFormat {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
format: TextFormatInner {
|
||||
format_type: "text",
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Serialize Option<u64> as either the number or JSON null (not omitted).
|
||||
fn serialize_option_u64<S>(val: &Option<u64>, s: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match val {
|
||||
Some(v) => s.serialize_u64(*v),
|
||||
None => s.serialize_none(),
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Shared types ────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct TokenRequest {
|
||||
pub token: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub(crate) struct ErrorResponse {
|
||||
pub error: ErrorDetail,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub(crate) struct ErrorDetail {
|
||||
pub message: String,
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
}
|
||||
36
src/api/util.rs
Normal file
36
src/api/util.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
//! Shared utilities for API handlers.
|
||||
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{sse::Event, IntoResponse, Json},
|
||||
};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use super::types::{ErrorDetail, ErrorResponse};
|
||||
|
||||
pub(crate) fn err_response(
|
||||
status: StatusCode,
|
||||
message: String,
|
||||
error_type: &str,
|
||||
) -> axum::response::Response {
|
||||
let body = ErrorResponse {
|
||||
error: ErrorDetail {
|
||||
message,
|
||||
error_type: error_type.to_string(),
|
||||
},
|
||||
};
|
||||
(status, Json(body)).into_response()
|
||||
}
|
||||
|
||||
pub(crate) fn now_unix() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
pub(crate) fn responses_sse_event(event_type: &str, data: serde_json::Value) -> Event {
|
||||
Event::default()
|
||||
.event(event_type)
|
||||
.data(serde_json::to_string(&data).unwrap())
|
||||
}
|
||||
Reference in New Issue
Block a user