feat: initial commit — antigravity proxy with MITM, standalone LS, and snapshot tooling

This commit is contained in:
Nikketryhard
2026-02-14 02:24:35 -06:00
commit d5e7f09225
30 changed files with 9980 additions and 0 deletions

686
src/api/responses.rs Normal file
View File

@@ -0,0 +1,686 @@
//! OpenAI Responses API (/v1/responses) handler.
//!
//! Strictly adheres to the official OpenAI Responses API protocol:
//! https://platform.openai.com/docs/api-reference/responses
use axum::{
extract::State,
http::StatusCode,
response::{sse::Event, IntoResponse, Json, Sse},
};
use rand::Rng;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use tracing::{debug, info};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content, extract_thinking_duration};
use super::types::*;
use super::util::{err_response, now_unix, responses_sse_event};
use super::AppState;
// ─── Input extraction ────────────────────────────────────────────────────────
/// Extract user text from Responses API `input` field.
fn extract_responses_input(input: &serde_json::Value, instructions: Option<&str>) -> String {
let user_text = match input {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(items) => {
items
.iter()
.rev()
.find(|item| item["role"].as_str() == Some("user"))
.and_then(|item| match &item["content"] {
serde_json::Value::String(s) => Some(s.clone()),
serde_json::Value::Array(parts) => Some(
parts
.iter()
.filter(|p| {
let t = p["type"].as_str().unwrap_or("");
t == "input_text" || t == "text"
})
.filter_map(|p| p["text"].as_str())
.collect::<Vec<_>>()
.join(" "),
),
_ => None,
})
.unwrap_or_default()
}
_ => String::new(),
};
match instructions {
Some(inst) if !inst.is_empty() => format!("{inst}\n\n{user_text}"),
_ => user_text,
}
}
/// Extract conversation/session ID from Responses API `conversation` field.
fn extract_conversation_id(conv: &Option<serde_json::Value>) -> Option<String> {
match conv {
Some(serde_json::Value::String(s)) => Some(s.clone()),
Some(obj) => obj["id"].as_str().map(|s| s.to_string()),
None => None,
}
}
/// Build a full Response object matching the official OpenAI schema.
fn build_response_object(
id: &str,
model: &str,
status: &'static str,
created_at: u64,
completed_at: Option<u64>,
output: Vec<ResponseOutput>,
usage: Option<Usage>,
instructions: Option<&str>,
store: bool,
temperature: f64,
top_p: f64,
max_output_tokens: Option<u64>,
previous_response_id: Option<&str>,
user: Option<&str>,
metadata: &serde_json::Value,
thinking_signature: Option<String>,
thinking: Option<String>,
thinking_duration: Option<String>,
) -> ResponsesResponse {
ResponsesResponse {
id: id.to_string(),
object: "response",
created_at,
status,
completed_at,
error: None,
incomplete_details: None,
instructions: instructions.map(|s| s.to_string()),
max_output_tokens,
model: model.to_string(),
output,
parallel_tool_calls: true,
previous_response_id: previous_response_id.map(|s| s.to_string()),
reasoning: Reasoning::default(),
store,
temperature,
text: TextFormat::default(),
tool_choice: "auto",
tools: vec![],
top_p,
truncation: "disabled",
usage,
user: user.map(|s| s.to_string()),
metadata: metadata.clone(),
thinking_signature,
thinking,
thinking_duration,
}
}
/// Serialize a ResponsesResponse to serde_json::Value for SSE embedding.
fn response_to_json(resp: &ResponsesResponse) -> serde_json::Value {
serde_json::to_value(resp).unwrap_or(serde_json::json!({}))
}
// ─── Handler ─────────────────────────────────────────────────────────────────
pub(crate) async fn handle_responses(
State(state): State<Arc<AppState>>,
Json(body): Json<ResponsesRequest>,
) -> axum::response::Response {
info!(
"POST /v1/responses model={} stream={}",
body.model.as_deref().unwrap_or(DEFAULT_MODEL),
body.stream
);
let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
let model = match lookup_model(model_name) {
Some(m) => m,
None => {
let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
return err_response(
StatusCode::BAD_REQUEST,
format!("Unknown model: {model_name}. Available: {names:?}"),
"invalid_request_error",
);
}
};
let token = state.backend.oauth_token().await;
if token.is_empty() {
return err_response(
StatusCode::UNAUTHORIZED,
"No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
"authentication_error",
);
}
let user_text = extract_responses_input(&body.input, body.instructions.as_deref());
if user_text.is_empty() {
return err_response(
StatusCode::BAD_REQUEST,
"No user input found".to_string(),
"invalid_request_error",
);
}
let response_id = format!(
"resp_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")
);
// Session/conversation management
let session_id_str = extract_conversation_id(&body.conversation);
let cascade_id = if let Some(ref sid) = session_id_str {
match state
.sessions
.get_or_create(Some(sid), || state.backend.create_cascade())
.await
{
Ok(sr) => sr.cascade_id,
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("StartCascade failed: {e}"),
"server_error",
);
}
}
} else {
match state.backend.create_cascade().await {
Ok(cid) => cid,
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("StartCascade failed: {e}"),
"server_error",
);
}
}
};
// Send message
match state
.backend
.send_message(&cascade_id, &user_text, model.model_enum)
.await
{
Ok((status, _)) if status == 200 => {
let bg = Arc::clone(&state.backend);
let cid = cascade_id.clone();
tokio::spawn(async move {
let _ = bg.update_annotations(&cid).await;
});
}
Ok((status, _)) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("Antigravity returned {status}"),
"server_error",
);
}
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("Send message failed: {e}"),
"server_error",
);
}
}
// Capture request params for response building
let req_params = RequestParams {
user_text: user_text.clone(),
instructions: body.instructions.clone(),
store: body.store,
temperature: body.temperature.unwrap_or(1.0),
top_p: body.top_p.unwrap_or(1.0),
max_output_tokens: body.max_output_tokens,
previous_response_id: body.previous_response_id.clone(),
user: body.user.clone(),
metadata: body.metadata.clone().unwrap_or(serde_json::json!({})),
};
if body.stream {
handle_responses_stream(
state, response_id, model_name.to_string(), cascade_id,
body.timeout, req_params,
)
.await
} else {
handle_responses_sync(
state, response_id, model_name.to_string(), cascade_id,
body.timeout, req_params,
)
.await
}
}
/// Captured request parameters needed to echo back in the response.
struct RequestParams {
user_text: String,
instructions: Option<String>,
store: bool,
temperature: f64,
top_p: f64,
max_output_tokens: Option<u64>,
previous_response_id: Option<String>,
user: Option<String>,
metadata: serde_json::Value,
}
/// Build Usage from the best available source:
/// 1. MITM intercepted data (real API tokens, including cache stats)
/// 2. LS trajectory data (real tokens, no cache info)
/// 3. Estimation from text lengths (fallback)
async fn usage_from_poll(
mitm_store: &crate::mitm::store::MitmStore,
cascade_id: &str,
model_usage: &Option<super::polling::ModelUsage>,
input_text: &str,
output_text: &str,
) -> Usage {
// Priority 1: MITM intercepted data (most accurate — includes cache tokens)
if let Some(mitm_usage) = mitm_store.take_usage(cascade_id).await {
tracing::debug!(
input = mitm_usage.input_tokens,
output = mitm_usage.output_tokens,
cache_read = mitm_usage.cache_read_input_tokens,
cache_create = mitm_usage.cache_creation_input_tokens,
thinking = mitm_usage.thinking_output_tokens,
"Using MITM intercepted usage"
);
return Usage {
input_tokens: mitm_usage.input_tokens,
input_tokens_details: InputTokensDetails {
cached_tokens: mitm_usage.cache_read_input_tokens,
},
output_tokens: mitm_usage.output_tokens,
output_tokens_details: OutputTokensDetails {
reasoning_tokens: mitm_usage.thinking_output_tokens,
},
total_tokens: mitm_usage.input_tokens + mitm_usage.output_tokens,
};
}
// Priority 2: LS trajectory data (from CHECKPOINT/metadata steps)
if let Some(u) = model_usage {
return Usage {
input_tokens: u.input_tokens,
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
output_tokens: u.output_tokens,
output_tokens_details: OutputTokensDetails { reasoning_tokens: 0 },
total_tokens: u.input_tokens + u.output_tokens,
};
}
// Priority 3: Estimate from text lengths
Usage::estimate(input_text, output_text)
}
// ─── Sync response ───────────────────────────────────────────────────────────
async fn handle_responses_sync(
state: Arc<AppState>,
response_id: String,
model_name: String,
cascade_id: String,
timeout: u64,
params: RequestParams,
) -> axum::response::Response {
let created_at = now_unix();
let poll_result = poll_for_response(&state, &cascade_id, timeout).await;
let completed_at = now_unix();
let msg_id = format!(
"msg_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")
);
let usage = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, &params.user_text, &poll_result.text).await;
let resp = build_response_object(
&response_id,
&model_name,
"completed",
created_at,
Some(completed_at),
vec![ResponseOutput {
output_type: "message",
id: msg_id,
status: "completed",
role: "assistant",
content: vec![OutputContent {
content_type: "output_text",
text: poll_result.text,
annotations: vec![],
}],
}],
Some(usage),
params.instructions.as_deref(),
params.store,
params.temperature,
params.top_p,
params.max_output_tokens,
params.previous_response_id.as_deref(),
params.user.as_deref(),
&params.metadata,
poll_result.thinking_signature,
poll_result.thinking,
poll_result.thinking_duration,
);
Json(resp).into_response()
}
// ─── Streaming response ─────────────────────────────────────────────────────
async fn handle_responses_stream(
state: Arc<AppState>,
response_id: String,
model_name: String,
cascade_id: String,
timeout: u64,
params: RequestParams,
) -> axum::response::Response {
let stream = async_stream::stream! {
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
let created_at = now_unix();
let seq = AtomicU32::new(0);
let next_seq = || seq.fetch_add(1, Ordering::Relaxed);
const CONTENT_IDX: u32 = 0;
const OUTPUT_IDX: u32 = 0;
// Build the in-progress response shell (no output yet)
let in_progress_resp = build_response_object(
&response_id, &model_name, "in_progress", created_at, None,
vec![], None,
params.instructions.as_deref(), params.store,
params.temperature, params.top_p,
params.max_output_tokens, params.previous_response_id.as_deref(),
params.user.as_deref(), &params.metadata,
None, None, None,
);
let resp_json = response_to_json(&in_progress_resp);
// 1. response.created
yield Ok::<_, std::convert::Infallible>(responses_sse_event(
"response.created",
serde_json::json!({
"type": "response.created",
"sequence_number": next_seq(),
"response": resp_json,
}),
));
// 2. response.in_progress
yield Ok(responses_sse_event(
"response.in_progress",
serde_json::json!({
"type": "response.in_progress",
"sequence_number": next_seq(),
"response": resp_json,
}),
));
// 3. response.output_item.added
yield Ok(responses_sse_event(
"response.output_item.added",
serde_json::json!({
"type": "response.output_item.added",
"sequence_number": next_seq(),
"output_index": OUTPUT_IDX,
"item": {
"type": "message",
"id": &msg_id,
"status": "in_progress",
"role": "assistant",
"content": [],
}
}),
));
// 4. response.content_part.added
yield Ok(responses_sse_event(
"response.content_part.added",
serde_json::json!({
"type": "response.content_part.added",
"sequence_number": next_seq(),
"output_index": OUTPUT_IDX,
"content_index": CONTENT_IDX,
"part": {
"type": "output_text",
"text": "",
"annotations": [],
}
}),
));
// 5. Poll and emit text deltas
let start = std::time::Instant::now();
let mut last_text = String::new();
while start.elapsed().as_secs() < timeout {
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
if status == 200 {
if let Some(steps) = data["steps"].as_array() {
let text = extract_response_text(steps);
if !text.is_empty() && text != last_text {
let new_content = if text.len() > last_text.len()
&& text.starts_with(&*last_text)
{
&text[last_text.len()..]
} else {
&text
};
if !new_content.is_empty() {
yield Ok(responses_sse_event(
"response.output_text.delta",
serde_json::json!({
"type": "response.output_text.delta",
"sequence_number": next_seq(),
"item_id": &msg_id,
"output_index": OUTPUT_IDX,
"content_index": CONTENT_IDX,
"delta": new_content,
}),
));
last_text = text.to_string();
}
}
// Check if response is done AND we have text
if is_response_done(steps) && !last_text.is_empty() {
debug!("Response done, text length={}", last_text.len());
let mu = extract_model_usage(steps);
let usage = usage_from_poll(&state.mitm_store, &cascade_id, &mu, &params.user_text, &last_text).await;
let ts = extract_thinking_signature(steps);
let tc = extract_thinking_content(steps);
let td = extract_thinking_duration(steps);
for evt in completion_events(
&response_id, &model_name, &msg_id,
OUTPUT_IDX, CONTENT_IDX, &last_text, usage,
created_at, &seq, &params, ts, tc, td,
) {
yield Ok(evt);
}
return;
}
// IDLE fallback: check trajectory status periodically
let step_count = steps.len();
if step_count > 4 && step_count % 5 == 0 {
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
if ts == 200 {
let run_status = td["status"].as_str().unwrap_or("");
if run_status.contains("IDLE") && !last_text.is_empty() {
debug!("Trajectory IDLE, text length={}", last_text.len());
let mu = extract_model_usage(steps);
let usage = usage_from_poll(&state.mitm_store, &cascade_id, &mu, &params.user_text, &last_text).await;
let ts = extract_thinking_signature(steps);
let tc = extract_thinking_content(steps);
let td = extract_thinking_duration(steps);
for evt in completion_events(
&response_id, &model_name, &msg_id,
OUTPUT_IDX, CONTENT_IDX, &last_text, usage,
created_at, &seq, &params, ts, tc, td,
) {
yield Ok(evt);
}
return;
}
}
}
}
}
}
}
let poll_ms: u64 = rand::thread_rng().gen_range(800..1200);
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
}
// Timeout — emit incomplete response
let timeout_resp = build_response_object(
&response_id, &model_name, "incomplete", created_at, None,
vec![], Some(Usage::estimate(&params.user_text, "")),
params.instructions.as_deref(), params.store,
params.temperature, params.top_p,
params.max_output_tokens, params.previous_response_id.as_deref(),
params.user.as_deref(), &params.metadata,
None, None, None,
);
yield Ok(responses_sse_event(
"response.completed",
serde_json::json!({
"type": "response.completed",
"sequence_number": next_seq(),
"response": response_to_json(&timeout_resp),
}),
));
};
Sse::new(stream)
.keep_alive(
axum::response::sse::KeepAlive::new()
.interval(std::time::Duration::from_secs(15))
.text(""),
)
.into_response()
}
// ─── SSE completion events ───────────────────────────────────────────────────
/// Build the completion SSE events sequence matching the official protocol:
/// 1. response.output_text.done
/// 2. response.content_part.done
/// 3. response.output_item.done
/// 4. response.completed
fn completion_events(
resp_id: &str,
model: &str,
msg_id: &str,
out_idx: u32,
content_idx: u32,
text: &str,
usage: Usage,
created_at: u64,
seq: &AtomicU32,
params: &RequestParams,
thinking_signature: Option<String>,
thinking: Option<String>,
thinking_duration: Option<String>,
) -> Vec<Event> {
let next_seq = || seq.fetch_add(1, Ordering::Relaxed);
let completed_at = now_unix();
let output_item = serde_json::json!({
"type": "message",
"id": msg_id,
"status": "completed",
"role": "assistant",
"content": [{
"type": "output_text",
"text": text,
"annotations": [],
}],
});
let completed_resp = build_response_object(
resp_id, model, "completed", created_at, Some(completed_at),
vec![ResponseOutput {
output_type: "message",
id: msg_id.to_string(),
status: "completed",
role: "assistant",
content: vec![OutputContent {
content_type: "output_text",
text: text.to_string(),
annotations: vec![],
}],
}],
Some(usage),
params.instructions.as_deref(),
params.store,
params.temperature,
params.top_p,
params.max_output_tokens,
params.previous_response_id.as_deref(),
params.user.as_deref(),
&params.metadata,
thinking_signature,
thinking,
thinking_duration,
);
vec![
// 1. response.output_text.done
responses_sse_event(
"response.output_text.done",
serde_json::json!({
"type": "response.output_text.done",
"sequence_number": next_seq(),
"item_id": msg_id,
"output_index": out_idx,
"content_index": content_idx,
"text": text,
}),
),
// 2. response.content_part.done
responses_sse_event(
"response.content_part.done",
serde_json::json!({
"type": "response.content_part.done",
"sequence_number": next_seq(),
"output_index": out_idx,
"content_index": content_idx,
"part": {
"type": "output_text",
"text": text,
"annotations": [],
},
}),
),
// 3. response.output_item.done
responses_sse_event(
"response.output_item.done",
serde_json::json!({
"type": "response.output_item.done",
"sequence_number": next_seq(),
"output_index": out_idx,
"item": output_item,
}),
),
// 4. response.completed
responses_sse_event(
"response.completed",
serde_json::json!({
"type": "response.completed",
"sequence_number": next_seq(),
"response": response_to_json(&completed_resp),
}),
),
]
}