//! Axum API server — OpenAI-compatible Responses + Chat Completions endpoints. mod completions; mod gemini; mod models; mod polling; mod responses; mod search; mod types; mod util; use crate::constants::safe_truncate; use crate::session::SessionManager; use axum::{ extract::{DefaultBodyLimit, State}, http::StatusCode, response::{IntoResponse, Json}, routing::{delete, get, post}, Router, }; use std::sync::Arc; use tower_http::cors::CorsLayer; use tracing::warn; use self::models::MODELS; use self::types::TokenRequest; // ─── Shared state ──────────────────────────────────────────────────────────── pub struct AppState { pub backend: Arc, pub sessions: SessionManager, pub mitm_store: crate::mitm::store::MitmStore, pub quota_store: crate::quota::QuotaStore, /// Whether the MITM proxy is active (false when --no-mitm). pub mitm_enabled: bool, /// Per-call debug trace collector. pub trace: crate::trace::TraceCollector, } // ─── Router ────────────────────────────────────────────────────────────────── pub fn router(state: Arc) -> Router { Router::new() .route("/v1/responses", post(responses::handle_responses)) .route( "/v1/chat/completions", post(completions::handle_completions), ) .route("/v1beta/{*path}", post(gemini::handle_gemini_v1beta)) .route("/v1/models", get(handle_models)) .route("/v1/search", get(search::handle_search_get)) .route("/v1/search", post(search::handle_search_post)) .route("/v1/sessions", get(handle_list_sessions)) .route("/v1/sessions/{id}", delete(handle_delete_session)) .route("/v1/token", post(handle_set_token)) .route("/v1/usage", get(handle_usage)) .route("/v1/quota", get(handle_quota)) .route("/health", get(handle_health)) .route("/", get(handle_root)) .layer(CorsLayer::permissive()) .layer(DefaultBodyLimit::max(1_048_576)) // 1 MB .with_state(state) } // ─── Simple handlers ───────────────────────────────────────────────────────── async fn handle_root() -> Json { Json(serde_json::json!({ "service": "zerogravity", "version": "3.3.0", "runtime": "rust", "endpoints": [ "/v1/chat/completions", "/v1/responses", "/v1beta/models/{model}:generateContent", "/v1beta/models/{model}:streamGenerateContent", "/v1/models", "/v1/sessions", "/v1/token", "/v1/usage", "/v1/quota", "/health", ], })) } async fn handle_health() -> Json { Json(serde_json::json!({"status": "ok"})) } async fn handle_models() -> Json { let models: Vec = MODELS .iter() .map(|m| { serde_json::json!({ "id": m.name, "object": "model", "created": 1700000000u64, "owned_by": "antigravity", "permission": [], "root": m.name, "parent": null, "meta": { "label": m.label, "enum_value": m.model_enum, }, }) }) .collect(); Json(serde_json::json!({"object": "list", "data": models})) } async fn handle_list_sessions(State(state): State>) -> Json { let sessions = state.sessions.list_sessions().await; Json(serde_json::json!({"sessions": sessions})) } async fn handle_delete_session( State(state): State>, axum::extract::Path(id): axum::extract::Path, ) -> impl IntoResponse { if state.sessions.delete_session(&id).await { ( StatusCode::OK, Json(serde_json::json!({"status": "deleted", "session_id": id})), ) } else { ( StatusCode::NOT_FOUND, Json(serde_json::json!({"error": format!("Session not found: {id}")})), ) } } async fn handle_set_token( State(state): State>, Json(body): Json, ) -> impl IntoResponse { if !body.token.starts_with("ya29.") { return ( StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": "Invalid token. Must start with ya29."})), ); } state.backend.set_oauth_token(body.token.clone()).await; // Also persist to file let token_path = crate::constants::token_file_path(); if let Err(e) = tokio::fs::write(&token_path, &body.token).await { warn!("Failed to write token file: {e}"); } let preview = safe_truncate(&body.token, 20); ( StatusCode::OK, Json(serde_json::json!({"status": "ok", "token_prefix": preview})), ) } async fn handle_usage(State(state): State>) -> Json { let stats = state.mitm_store.stats().await; Json(serde_json::json!({ "mitm": { "total_requests": stats.total_requests, "total_input_tokens": stats.total_input_tokens, "total_output_tokens": stats.total_output_tokens, "total_cache_read_tokens": stats.total_cache_read_tokens, "total_cache_creation_tokens": stats.total_cache_creation_tokens, "total_thinking_output_tokens": stats.total_thinking_output_tokens, "total_response_output_tokens": stats.total_response_output_tokens, "total_tokens": stats.total_input_tokens + stats.total_output_tokens, "per_model": stats.per_model, } })) } async fn handle_quota(State(state): State>) -> Json { let snap = state.quota_store.snapshot().await; Json(serde_json::to_value(snap).unwrap_or_default()) }