feat: initial commit — antigravity proxy with MITM, standalone LS, and snapshot tooling

Nikketryhard
2026-02-14 02:24:35 -06:00
commit d5e7f09225
30 changed files with 9980 additions and 0 deletions

8
.gitignore vendored Normal file

@@ -0,0 +1,8 @@
# Build
/target/
# Debug artifacts
*.log
*.txt
!README.txt
test_output.json

2470
Cargo.lock generated Normal file

File diff suppressed because it is too large.

46
Cargo.toml Normal file

@@ -0,0 +1,46 @@
[package]
name = "antigravity-proxy"
version = "3.0.0"
edition = "2021"
[dependencies]
axum = { version = "0.8", features = ["json"] }
tokio = { version = "1", features = ["full"] }
wreq = { version = "6.0.0-rc.28", features = ["json"] }
wreq-util = "3.0.0-rc.10"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
uuid = { version = "1", features = ["v4"] }
regex = "1"
async-stream = "0.3"
tower-http = { version = "0.6", features = ["cors"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
clap = { version = "4", features = ["derive"] }
rand = "0.8"
flate2 = "1"
brotli = "7"
chrono = "0.4"
# MITM proxy dependencies
rcgen = "0.13"
rustls = { version = "0.23", features = ["ring"] }
tokio-rustls = "0.26"
rustls-native-certs = "0.8"
rustls-pemfile = "2"
time = "0.3"
base64 = "0.22"
httparse = "1"
# HTTP/2 + gRPC interception
hyper = { version = "1", features = ["http2", "client", "server"] }
hyper-util = { version = "0.1", features = ["tokio"] }
http-body-util = "0.1"
http = "1"
bytes = "1"
tokio-stream = "0.1"
[profile.release]
opt-level = "z"
lto = true
strip = true

155
GEMINI.md Normal file

@@ -0,0 +1,155 @@
# Antigravity Rust Proxy
OpenAI-compatible proxy that intercepts and relays requests to Google's Antigravity language server, impersonating the real Electron webview.
## Quick Start
```bash
# Build
cargo build --release
# Run (language server must be running)
RUST_LOG=info ./target/release/antigravity-proxy
# Custom port
RUST_LOG=info ./target/release/antigravity-proxy --port 9000
```
Default port: **8741**
## Endpoints
| Method | Path | Description |
| -------- | ---------------------- | ----------------------------------------------------------- |
| `POST` | `/v1/responses` | **Responses API** (primary) — supports `stream: true/false` |
| `POST` | `/v1/chat/completions` | Chat Completions API (OpenAI compat shim) |
| `GET` | `/v1/models` | List available models |
| `GET` | `/v1/sessions` | List active sessions |
| `DELETE` | `/v1/sessions/:id` | Delete a session |
| `POST` | `/v1/token` | Set OAuth token at runtime |
| `GET` | `/v1/usage` | MITM-intercepted token usage stats |
| `GET` | `/v1/quota` | LS quota — credits, per-model rate limits, reset timers |
| `GET` | `/health` | Health check |
## Available Models
| Name | Label |
| ------------------- | ---------------------------------------- |
| `opus-4.6` | Claude Opus 4.6 (Thinking) — **default** |
| `opus-4.5` | Claude Opus 4.5 (Thinking) |
| `gemini-3-pro-high` | Gemini 3 Pro (High) |
| `gemini-3-pro` | Gemini 3 Pro (Low) |
| `gemini-3-flash` | Gemini 3 Flash |
## Example: Responses API
### Sync
```bash
curl -s http://localhost:8741/v1/responses \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-3-flash",
"input": "Say hello in exactly 3 words",
"stream": false,
"timeout": 60
}' | jq .
```
### Streaming
```bash
curl -N http://localhost:8741/v1/responses \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-3-flash",
"input": "Say hello in exactly 3 words",
"stream": true,
"timeout": 60
}'
```
### Multi-turn (session reuse)
```bash
curl -s http://localhost:8741/v1/responses \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-3-flash",
"input": "What is 2+2?",
"conversation": "my-session-1",
"stream": false
}' | jq .
# Follow-up in same cascade:
curl -s http://localhost:8741/v1/responses \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-3-flash",
"input": "Now multiply that by 10",
"conversation": "my-session-1",
"stream": false
}' | jq .
```
## Authentication
The proxy needs an OAuth token. Three ways to provide it:
1. **Environment variable**: `export ANTIGRAVITY_OAUTH_TOKEN=ya29.xxx`
2. **Token file**: `echo 'ya29.xxx' > ~/.config/antigravity-proxy-token`
3. **Runtime API**: `curl -X POST http://localhost:8741/v1/token -H "Content-Type: application/json" -d '{"token":"ya29.xxx"}'`
## Version Detection
Version strings (Antigravity, Chrome, Electron, Client) are **auto-detected** at startup from the installed Antigravity app:
- `product.json` → app version + client/IDE version
- Binary → Chrome + Electron versions via `strings`
Falls back to hardcoded values if the app isn't installed. No manual updates needed when Antigravity updates.
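The fallback behavior described above can be sketched roughly as follows. This is an illustration in Python, not the proxy's Rust implementation: the `version` key, the function name, and the `0.0.0` fallback are assumptions made for the example.

```python
import json
from pathlib import Path

# Hypothetical fallback value; the real hardcoded values live in the proxy source.
FALLBACK_APP_VERSION = "0.0.0"


def detect_app_version(product_json: Path, fallback: str = FALLBACK_APP_VERSION) -> str:
    """Return the app version from product.json, or the fallback if unavailable."""
    try:
        data = json.loads(product_json.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # Missing or unreadable file, or invalid JSON: treat the app as not installed.
        return fallback
    # Assumption: the version is stored under a top-level "version" key.
    version = data.get("version") if isinstance(data, dict) else None
    return version if isinstance(version, str) and version else fallback
```

Returning the fallback on any read or parse failure keeps startup from failing when the app is missing or its metadata is malformed.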
## Stealth Features
- **TLS fingerprint**: BoringSSL with Chrome 142 JA3/JA4 + H2 fingerprint via `wreq`
- **Protobuf**: Hand-rolled encoder producing byte-exact match to real webview traffic
- **Warmup**: Mimics real webview startup RPC calls
- **Heartbeat**: Periodic keep-alive matching real webview lifecycle
- **Jitter**: Randomized polling intervals to avoid automation fingerprint
- **Session reuse**: Cascades are reused for multi-turn, matching real webview behavior
- **MITM proxy**: TLS-intercepting proxy for real token usage capture (opt-in)
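As an illustration of the jitter idea in the list above, here is a minimal Python sketch that scales a base polling interval by a random factor. The function name, the 25% default spread, and the uniform distribution are assumptions for the example; the proxy's actual Rust code may choose intervals differently.

```python
import random
from typing import Optional


def jittered_interval(base_s: float, spread: float = 0.25,
                      rng: Optional[random.Random] = None) -> float:
    """Return base_s scaled by a random factor in [1 - spread, 1 + spread]."""
    if base_s <= 0 or not 0 <= spread < 1:
        raise ValueError("base_s must be positive and spread must be in [0, 1)")
    rng = rng or random.Random()
    return base_s * rng.uniform(1 - spread, 1 + spread)
```

Passing a seeded `random.Random` makes the sequence reproducible in tests, while the default draws fresh randomness on each call.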
## MITM Proxy
The built-in MITM proxy intercepts LS ↔ Google/Anthropic traffic to capture **real** token usage (input, output, cache read, cache creation). Pass `--no-mitm` to disable it.
### Setup
```bash
# 1. Start proxy (generates CA cert automatically)
RUST_LOG=info ./target/release/antigravity-proxy
# 2. Install wrapper (patches LS binary to route through MITM)
./scripts/mitm-wrapper.sh install
# 3. Restart Antigravity — done!
# Check status
./scripts/mitm-wrapper.sh status
# Uninstall
./scripts/mitm-wrapper.sh uninstall
```
### Usage Stats
```bash
curl -s http://localhost:8741/v1/usage | jq .
```
Returns aggregate token counts from all intercepted API calls.
### CLI Flags
- `--no-mitm`: Disable MITM proxy entirely
- `--mitm-port <PORT>`: Override MITM proxy port (default: auto-assign)

109
KNOWN_ISSUES.md Normal file

@@ -0,0 +1,109 @@
# Known Issues & Future Work
---
## Medium
### 1. Cascade Correlation Is Heuristic
**File:** `src/mitm/intercept.rs` → `extract_cascade_hint()`
The MITM proxy matches intercepted API traffic to cascade IDs by scanning for `metadata.user_id` or `workspace_id` in the request body. If neither is found, it stores under `_latest`. Since `take_usage()` no longer falls back to `_latest`, unidentified requests will have **no MITM usage data at all**.
**Fix:** Investigate the actual request body format the LS sends for better correlation keys. Alternatively, use timing-based correlation (match MITM capture timestamp to cascade polling window).
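A rough sketch of the timing-based alternative, written in Python for brevity rather than in the proxy's Rust. `PollingWindow`, `correlate_by_time`, and the 0.5 s slack are hypothetical names and values chosen for illustration; the sketch refuses to guess when more than one cascade matches.

```python
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass(frozen=True)
class PollingWindow:
    cascade_id: str
    start: float  # seconds on a shared clock, when polling for the cascade began
    end: float    # seconds on the same clock, when polling finished


def correlate_by_time(capture_ts: float, windows: Iterable[PollingWindow],
                      slack: float = 0.5) -> Optional[str]:
    """Return the single cascade whose window contains capture_ts, else None."""
    hits = {w.cascade_id for w in windows
            if w.start - slack <= capture_ts <= w.end + slack}
    # Zero matches or several overlapping matches are both treated as unknown.
    return next(iter(hits)) if len(hits) == 1 else None
```

Returning `None` for overlapping windows avoids attributing usage to the wrong cascade when several requests are in flight at once.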
---
### 2. Domain Certificate Cache Is Unbounded
**File:** `src/mitm/ca.rs` → `domain_cache`
The `domain_cache` (`HashMap<String, Arc<ServerConfig>>`) grows without bound. Each unique domain gets a cached entry containing a full `ServerConfig` with an RSA key pair. In practice only ~5-10 domains are intercepted, so this is unlikely to matter, but there is no eviction.
**Fix:** Set a max cache size (e.g., 100 entries) with LRU eviction, or use a TTL since leaf certs are generated with a 1-year validity.
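The bounded-cache idea could look roughly like this. It is a language-neutral sketch in Python, not a drop-in for the Rust `HashMap`; `LruCache` and `get_or_insert` are illustrative names, and the default of 100 entries comes from the suggestion above.

```python
from collections import OrderedDict
from typing import Callable, Generic, TypeVar

K = TypeVar("K")
V = TypeVar("V")


class LruCache(Generic[K, V]):
    """Cache that evicts the least recently used entry once max_entries is exceeded."""

    def __init__(self, max_entries: int = 100) -> None:
        if max_entries < 1:
            raise ValueError("max_entries must be at least 1")
        self._max = max_entries
        self._items: "OrderedDict[K, V]" = OrderedDict()

    def get_or_insert(self, key: K, make: Callable[[K], V]) -> V:
        if key in self._items:
            self._items.move_to_end(key)  # mark as most recently used
            return self._items[key]
        value = make(key)  # e.g. generate the leaf cert and server config
        self._items[key] = value
        if len(self._items) > self._max:
            self._items.popitem(last=False)  # drop the least recently used entry
        return value

    def __contains__(self, key: object) -> bool:
        return key in self._items

    def __len__(self) -> int:
        return len(self._items)
```

An evicted domain simply has its certificate regenerated on the next connection, so eviction costs one extra key generation rather than a failed handshake.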
---
### 3. Request Modification Not Implemented
**File:** `src/mitm/proxy.rs` → `modify_requests: false`
The `MitmConfig.modify_requests` flag exists and is plumbed through, but no actual modification logic is implemented. The flag is hardcoded to `false`.
**Fix:** When needed, implement request body mutation in `handle_http_over_tls()` — parse JSON, modify, reserialize, update `Content-Length`.
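The parse, modify, reserialize, and update `Content-Length` steps could be sketched as below. This is a Python illustration for HTTP/1.1-style bodies only; `rewrite_json_body` and its signature are hypothetical, and gRPC over HTTP/2 would need frame-aware handling instead.

```python
import json
from typing import Callable, Dict, Tuple


def rewrite_json_body(headers: Dict[str, str], body: bytes,
                      mutate: Callable[[dict], None]) -> Tuple[Dict[str, str], bytes]:
    """Apply mutate() to a JSON body and return headers with a corrected length."""
    payload = json.loads(body)
    mutate(payload)
    new_body = json.dumps(payload, separators=(",", ":")).encode("utf-8")
    # Header names are case-insensitive, so drop any existing Content-Length variant.
    new_headers = {k: v for k, v in headers.items() if k.lower() != "content-length"}
    new_headers["Content-Length"] = str(len(new_body))
    return new_headers, new_body
```

Recomputing the length from the encoded bytes, not from the string, matters whenever the body contains non-ASCII characters.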
---
### 4. `total_cost_usd` Is Dead
**File:** `src/mitm/store.rs` (line 28)
`ApiUsage.total_cost_usd` is `Option<f64>` but is **always `None`**: it is set to `None` in all 4 construction sites (`h2_handler.rs` ×2, `intercept.rs` ×2). Neither Anthropic nor Google includes cost in API responses.
**Fix:** Either remove the field (simpler), or populate it via a pricing table lookup (model → $/1K tokens) at `record_usage()` time.
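If the pricing-table route is taken, the lookup might look roughly like this Python sketch. The rates are placeholders rather than real prices, `example-model` is a made-up key, and the function name is an assumption; cache read and cache creation tokens are left out for brevity.

```python
from typing import Dict, Optional, Tuple

# PLACEHOLDER rates in USD per 1K tokens as (input, output). These are NOT real prices.
PRICING: Dict[str, Tuple[float, float]] = {
    "example-model": (1.0, 2.0),
}


def estimate_cost_usd(model: str, input_tokens: int, output_tokens: int,
                      pricing: Dict[str, Tuple[float, float]] = PRICING) -> Optional[float]:
    """Estimate cost from token counts, or return None for an unknown model."""
    rates = pricing.get(model)
    if rates is None:
        return None  # unknown model: leave the field empty rather than guess
    in_rate, out_rate = rates
    return (input_tokens / 1000) * in_rate + (output_tokens / 1000) * out_rate
```

Returning `None` for unlisted models keeps the field's current meaning ("unknown") instead of reporting a misleading zero.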
---
## 🟢 Low
### 5. Wrapper Script Fallback Paths May Be Stale
**File:** `scripts/mitm-wrapper.sh` → `LS_FALLBACK_DIRS`
The fallback glob patterns (e.g., `~/.cursor/extensions/antigravity.antigravity-*/...`) assume a specific extension naming convention. These are only used when no running LS process is found via `/proc` scanning (Method 1), which is the primary and robust discovery mechanism.
**Impact:** Only affects `install` when the LS isn't running. Low priority.
---
### 6. No Integration Tests for MITM Module
The MITM module has unit tests for protobuf decoding and intercept parsing, but no integration tests that verify:
- TLS interception end-to-end with the generated CA
- Full HTTP/1.1 request/response cycle through the proxy
- gRPC (HTTP/2) request/response cycle through `h2_handler`
- Store recording and retrieval under concurrency
- Wrapper script install/uninstall lifecycle
---
## 🔍 Investigation
### 7. LS Exposes Credit/Quota Data via `GetUserStatus`
**Confirmed via live probing.** The LS's `GetUserStatus` RPC already returns structured cost/quota data:
```json
"planStatus": {
"planInfo": {
"planName": "Pro",
"monthlyPromptCredits": 50000,
"monthlyFlowCredits": 150000,
"monthlyFlexCreditPurchaseAmount": 25000,
"canBuyMoreCredits": true
},
"availablePromptCredits": 500,
"availableFlowCredits": 100
}
```
Each model also includes **per-model quota info**:
```json
"quotaInfo": {
"remainingFraction": 0.2,
"resetTime": "2026-02-14T07:41:37Z"
}
```
**Key findings:**
- `GetUserStatus` is the single source for credit/quota data (exposed via `LanguageServerService`)
- `SeatManagementService` methods (`GetPlanStatus`, `GetTeamCreditEntries`, `GetCascadeAnalytics`, `GetUserSubscription`) are **not routed through the LS** — they're backend-only
- `PredictionService/RetrieveUserQuota` is also backend-only (not proxied by LS)
- `GetUserAnalyticsSummary` returns empty `{}` (may not be implemented or requires different context)
- `GetModelStatuses` returns empty `{}` (separate from the model configs in `GetUserStatus`)
- `userTier` field shows subscription level: `{"id": "g1-ultra-tier", "name": "Google AI Ultra"}`
**Potential use:** We could poll `GetUserStatus` periodically and expose `availablePromptCredits`, `availableFlowCredits`, and per-model `remainingFraction` via the `/v1/usage` endpoint — giving users real-time credit burn visibility without needing MITM token counting.
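A small Python sketch of how the `planStatus` payload shown above could be reduced to remaining-credit fractions. The output key names and the choice to divide available credits by the monthly allotment are assumptions for illustration; whether that ratio matches how the backend accounts for credits has not been verified.

```python
from typing import Any, Dict, Optional


def summarize_plan_status(plan_status: Dict[str, Any]) -> Dict[str, Any]:
    """Reduce a planStatus object to the plan name and remaining-credit fractions."""
    info = plan_status.get("planInfo", {})

    def remaining(available_key: str, monthly_key: str) -> Optional[float]:
        monthly = info.get(monthly_key)
        available = plan_status.get(available_key)
        if not monthly or available is None:
            return None  # missing data or zero allotment: fraction is undefined
        return available / monthly

    return {
        "plan": info.get("planName"),
        "prompt_credits_remaining": remaining("availablePromptCredits", "monthlyPromptCredits"),
        "flow_credits_remaining": remaining("availableFlowCredits", "monthlyFlowCredits"),
    }
```

With the sample payload above, the prompt-credit fraction is 500 / 50000, so an endpoint exposing it would show about 1% remaining.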

239
README.md Normal file

@@ -0,0 +1,239 @@
# Antigravity Proxy
OpenAI-compatible proxy that intercepts and relays requests to Google's Antigravity language server, impersonating the real Electron webview.
## Quick Start
```bash
# Build
cargo build --release
# Run (language server must be running)
RUST_LOG=info ./target/release/antigravity-proxy
# Custom port
RUST_LOG=info ./target/release/antigravity-proxy --port 9000
```
Default port: **8741**
## Endpoints
| Method | Path | Description |
| -------- | ---------------------- | ----------------------------------------------------------- |
| `POST` | `/v1/responses` | **Responses API** (primary) — supports `stream: true/false` |
| `POST` | `/v1/chat/completions` | Chat Completions API (OpenAI compat shim) |
| `GET` | `/v1/models` | List available models |
| `GET` | `/v1/sessions` | List active sessions |
| `DELETE` | `/v1/sessions/:id` | Delete a session |
| `POST` | `/v1/token` | Set OAuth token at runtime |
| `GET` | `/v1/usage` | MITM-intercepted token usage stats |
| `GET` | `/v1/quota` | LS quota — credits, per-model rate limits, reset timers |
| `GET` | `/health` | Health check |
## Available Models
| Name | Label |
| ------------------- | ---------------------------------------- |
| `opus-4.6` | Claude Opus 4.6 (Thinking) — **default** |
| `opus-4.5` | Claude Opus 4.5 (Thinking) |
| `gemini-3-pro-high` | Gemini 3 Pro (High) |
| `gemini-3-pro` | Gemini 3 Pro (Low) |
| `gemini-3-flash` | Gemini 3 Flash |
## Example: Responses API
### Sync
```bash
curl -s http://localhost:8741/v1/responses \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-3-flash",
"input": "Say hello in exactly 3 words",
"stream": false,
"timeout": 60
}' | jq .
```
### Streaming
```bash
curl -N http://localhost:8741/v1/responses \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-3-flash",
"input": "Say hello in exactly 3 words",
"stream": true,
"timeout": 60
}'
```
### Multi-turn (session reuse)
```bash
# First message
curl -s http://localhost:8741/v1/responses \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-3-flash",
"input": "What is 2+2?",
"conversation": "my-session-1",
"stream": false
}' | jq .
# Follow-up in same cascade
curl -s http://localhost:8741/v1/responses \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-3-flash",
"input": "Now multiply that by 10",
"conversation": "my-session-1",
"stream": false
}' | jq .
```
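The same multi-turn calls can be made from code. The sketch below uses only the Python standard library and the request fields shown in the curl examples; the helper names are illustrative, and the response is returned as parsed JSON without assuming its schema.

```python
import json
import urllib.request
from typing import Any, Optional

BASE_URL = "http://localhost:8741"  # default port from the Quick Start section


def build_responses_request(prompt: str, model: str = "gemini-3-flash",
                            conversation: Optional[str] = None,
                            base_url: str = BASE_URL) -> urllib.request.Request:
    """Build a non-streaming POST /v1/responses request."""
    payload = {"model": model, "input": prompt, "stream": False}
    if conversation is not None:
        payload["conversation"] = conversation  # reuse the same session/cascade
    return urllib.request.Request(
        f"{base_url}/v1/responses",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request: urllib.request.Request, timeout: float = 60.0) -> Any:
    """Send the request and return the decoded JSON body."""
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        return json.load(resp)
```

Calling `send(build_responses_request("What is 2+2?", conversation="my-session-1"))` and then repeating with a follow-up prompt and the same `conversation` value mirrors the two curl commands above; it requires the proxy to be running.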
## Authentication
The proxy needs an OAuth token. Three ways to provide it:
1. **Environment variable**: `export ANTIGRAVITY_OAUTH_TOKEN=ya29.xxx`
2. **Token file**: `echo 'ya29.xxx' > ~/.config/antigravity-proxy-token`
3. **Runtime API**: `curl -X POST http://localhost:8741/v1/token -H "Content-Type: application/json" -d '{"token":"ya29.xxx"}'`
## Stealth Features
- **TLS fingerprint**: BoringSSL with Chrome 142 JA3/JA4 + H2 fingerprint via `wreq`
- **Protobuf**: Hand-rolled encoder producing byte-exact match to real webview traffic
- **Warmup**: Mimics real webview startup RPC calls
- **Heartbeat**: Periodic keep-alive matching real webview lifecycle
- **Jitter**: Randomized polling intervals to avoid automation fingerprint
- **Session reuse**: Cascades are reused for multi-turn, matching real webview behavior
- **Version detection**: Auto-detects Antigravity/Chrome/Electron versions from installed app
## MITM Proxy
The built-in TLS-intercepting proxy captures real token usage from LS ↔ Google/Anthropic traffic. Pass `--no-mitm` to disable it.
### Setup
```bash
# 1. Start proxy (generates CA cert automatically)
RUST_LOG=info ./target/release/antigravity-proxy
# 2. Install wrapper (patches LS binary to route through MITM)
sudo ./scripts/mitm-wrapper.sh install
# 3. Restart Antigravity — done!
# Check status
./scripts/mitm-wrapper.sh status
# Uninstall
sudo ./scripts/mitm-wrapper.sh uninstall
```
### Usage Stats
```bash
curl -s http://localhost:8741/v1/usage | jq .
```
## Standalone Language Server
Launch an isolated LS instance for experimentation:
```bash
# Basic test (starts, checks quota, exits)
./scripts/standalone-ls.sh
# Foreground mode (stays alive)
./scripts/standalone-ls.sh --fg
# With MITM traffic interception
./scripts/standalone-ls.sh --mitm
# Capture a clean traffic snapshot
./scripts/standalone-ls.sh --snapshot
# Snapshot with custom prompt
./scripts/standalone-ls.sh --snapshot --prompt "Explain quantum computing"
```
The standalone LS shares the main Antigravity app's OAuth (via its extension server) but has its own port, data directory, and cascades.
### Traffic Snapshots
The `--snapshot` flag captures all HTTP/2 traffic and formats it into a clean, color-coded report:
```
══════════════════════════════════════════════════════════════════════
STANDALONE LS TRAFFIC SNAPSHOT
══════════════════════════════════════════════════════════════════════
▸ Outbound Connections
→ antigravity-unleash.goog (Feature Flags)
→ play.googleapis.com (Telemetry)
══════════════════════════════════════════════════════════════════════
antigravity-unleash.goog — Feature Flags
══════════════════════════════════════════════════════════════════════
→ POST /api/client/register
authorization: *:production.e4455...
unleash-appname: codeium-language-server
Body (561 bytes, JSON):
{"appName":"codeium-language-server","instanceId":"..."}
```
## Architecture
```mermaid
graph LR
A[Your App<br/>OpenAI SDK] -->|HTTP| B[Proxy<br/>:8741]
B -->|gRPC| C[Language<br/>Server]
C -->|HTTPS| D[Google /<br/>Anthropic]
E[MITM Proxy<br/>:8742] -.->|intercept| D
C -.->|routed via| E
```
## Project Structure
```
src/
├── main.rs              # Entry point, CLI args, lifecycle
├── backend.rs           # LS discovery and RPC communication
├── constants.rs         # Version detection + stealth constants
├── proto.rs             # Hand-rolled protobuf encoder
├── quota.rs             # LS quota polling and caching
├── session.rs           # Multi-turn session management
├── warmup.rs            # Startup warmup (mimics real webview)
├── api/
│   └── mod.rs           # Axum API server + route handlers
└── mitm/
    ├── mod.rs           # MITM module root
    ├── ca.rs            # Dynamic CA cert generation
    ├── proxy.rs         # TLS-intercepting proxy server
    ├── intercept.rs     # API response parser (usage extraction)
    └── store.rs         # Token usage aggregation store
scripts/
├── mitm-wrapper.sh      # Install/uninstall MITM wrapper on LS binary
├── standalone-ls.sh     # Launch isolated LS instance
└── parse-snapshot.py    # HTTP/2 traffic snapshot parser
```
## CLI Flags
```
antigravity-proxy [OPTIONS]
Options:
--port <PORT> API server port (default: 8741)
--no-mitm Disable MITM proxy
--mitm-port <PORT> Override MITM proxy port (default: auto)
```
## License
Private. Do not distribute.

331
scripts/mitm-wrapper.sh Executable file

@@ -0,0 +1,331 @@
#!/usr/bin/env bash
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ Antigravity MITM LS Wrapper ║
# ║ ║
# ║ This script installs a generated wrapper in place of the real ║
# ║ Antigravity language server binary. The wrapper injects HTTPS_PROXY, ║
# ║ SSL_CERT_FILE, and GRPC_DEFAULT_SSL_ROOTS_FILE_PATH so the MITM proxy ║
# ║ can intercept LS<->API traffic. ║
# ║ ║
# ║ Install: ./mitm-wrapper.sh install ║
# ║ Uninstall: ./mitm-wrapper.sh uninstall ║
# ║ Status: ./mitm-wrapper.sh status ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
set -euo pipefail
# ── Config ────────────────────────────────────────────────────────────────────
# Resolve the real user's home (not /root when running under sudo)
if [[ -n "${SUDO_USER:-}" ]]; then
    REAL_HOME="$(getent passwd "$SUDO_USER" | cut -d: -f6)"
else
    REAL_HOME="$HOME"
fi
MITM_PORT_FILE="${REAL_HOME}/.config/antigravity-proxy/mitm-port"
if [[ -n "${ANTIGRAVITY_MITM_PORT:-}" ]]; then
    MITM_PORT="$ANTIGRAVITY_MITM_PORT"
elif [[ -f "$MITM_PORT_FILE" ]]; then
    MITM_PORT="$(cat "$MITM_PORT_FILE" 2>/dev/null || echo 8742)"
else
    MITM_PORT="8742"
fi
CA_PATH="${REAL_HOME}/.config/antigravity-proxy/mitm-ca.pem"
# Antigravity LS: discovered dynamically from running processes.
# Hardcoded paths are only used as a fallback if no LS process is running.
LS_FALLBACK_DIRS=(
    "/usr/share/antigravity/resources/app/extensions/antigravity/bin"
    "${REAL_HOME}/.antigravity/extensions/antigravity.antigravity-*/dist/bundled/language-server/bin"
    "${REAL_HOME}/.cursor/extensions/antigravity.antigravity-*/dist/bundled/language-server/bin"
    "${REAL_HOME}/.vscode/extensions/antigravity.antigravity-*/dist/bundled/language-server/bin"
    "/opt/antigravity/language-server/bin"
)
BACKUP_SUFFIX=".real"
# ── Colors ────────────────────────────────────────────────────────────────────
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[0;33m'
CYAN='\033[0;36m'
BOLD='\033[1m'
NC='\033[0m'
# ── Find LS binary ───────────────────────────────────────────────────────────
find_ls_binary() {
    # Method 1: Find from running process via /proc
    if [[ -d /proc ]]; then
        for pid_dir in /proc/[0-9]*; do
            local exe_target
            exe_target="$(readlink "${pid_dir}/exe" 2>/dev/null)" || continue
            # Strip " (deleted)" suffix that appears when the binary was unlinked
            exe_target="${exe_target% (deleted)}"
            if [[ "$exe_target" == *language_server_linux* ]] || \
               [[ "$exe_target" == *antigravity-language-server* ]]; then
                # If the running process is the backup (.real), strip the suffix
                # so we return the path to the base binary name.
                echo "${exe_target%"$BACKUP_SUFFIX"}"
                return 0
            fi
        done
    fi
    # Method 2: Fallback, scan known directories for common binary names
    local bin_names=("language_server_linux_x64" "language_server_linux_arm64" "antigravity-language-server")
    for dir_pattern in "${LS_FALLBACK_DIRS[@]}"; do
        # Intentionally unquoted so the glob in the pattern expands
        for dir in $dir_pattern; do
            [[ -d "$dir" ]] || continue
            for name in "${bin_names[@]}"; do
                local path="${dir}/${name}"
                if [[ -f "$path" || -f "${path}${BACKUP_SUFFIX}" ]]; then
                    echo "$path"
                    return 0
                fi
            done
        done
    done
    return 1
}
# ── Install ──────────────────────────────────────────────────────────────────
cmd_install() {
    # Find the LS binary first (quiet, just to check permissions)
    local ls_path
    ls_path=$(find_ls_binary) || ls_path="${1:-}"
    # Allow override
    if [[ -n "${1:-}" ]]; then
        ls_path="$1"
    fi
    # Check permissions upfront and ask the user to rerun with sudo if needed
    if [[ -n "$ls_path" ]]; then
        local ls_dir
        ls_dir="$(dirname "$ls_path")"
        if [[ ! -w "$ls_dir" ]] && [[ "$EUID" -ne 0 ]]; then
            echo -e " ${RED}✗${NC} ${ls_dir} requires elevated permissions"
            echo -e " run: sudo $0 install ${1:-}"
            exit 1
        fi
    fi
    echo -e "${BOLD}${CYAN}Antigravity MITM Wrapper Installer${NC}"
    echo -e "───────────────────────────────────"
    echo ""
    # Report the LS binary (for real this time, with output)
    if [[ -z "$ls_path" ]]; then
        echo -e " ${RED}✗${NC} Could not find Antigravity language server binary."
        echo -e " No LS process found in /proc, and fallback paths didn't match."
        echo ""
        echo -e " Set the path manually:"
        echo -e " $0 install /path/to/language_server_linux_x64"
        exit 1
    fi
    echo -e " ${GREEN}✓${NC} Found LS: ${ls_path}"
    local real_path="${ls_path}${BACKUP_SUFFIX}"
    # Verify the binary exists and is not already wrapped
    if [[ -f "$real_path" ]]; then
        echo -e " ${YELLOW}!${NC} Already installed (backup exists at ${real_path})"
        echo -e " Run '$0 uninstall' first to reinstall."
        exit 0
    fi
    if [[ ! -f "$ls_path" ]]; then
        echo -e " ${RED}✗${NC} Binary not found: ${ls_path}"
        exit 1
    fi
    # Verify it's a real binary, not already our wrapper. The marker must be a
    # string that appears near the top of the generated wrapper below.
    if head -c 200 "$ls_path" | grep -q 'MITM LS Wrapper'; then
        echo -e " ${YELLOW}!${NC} Already wrapped (script detected). Run '$0 uninstall' first."
        exit 0
    fi
    # Check CA cert
    if [[ ! -f "$CA_PATH" ]]; then
        echo -e " ${YELLOW}!${NC} CA cert not found at ${CA_PATH}"
        echo -e " Start the proxy first to generate it."
        echo -e " Continuing install anyway..."
    else
        echo -e " ${GREEN}✓${NC} CA cert: ${CA_PATH}"
    fi
    # Back up the real binary
    cp -p "$ls_path" "$real_path"
    echo -e " ${GREEN}✓${NC} Backed up real binary to ${real_path}"
    # Remove the original before writing (avoids "Text file busy" if LS is running)
    rm -f "$ls_path"
    # Create the wrapper script in-place
    tee "$ls_path" > /dev/null << 'WRAPPER_EOF'
#!/usr/bin/env bash
# Antigravity MITM LS Wrapper (auto-generated, do not edit).
# The LS is a Go binary: it reads HTTPS_PROXY and SSL_CERT_FILE (not NODE_EXTRA_CA_CERTS).
# GRPC_DEFAULT_SSL_ROOTS_FILE_PATH is also set in case the gRPC stack consults it.
# We build a combined CA bundle (system CAs + MITM CA) and inject it.
REAL_BINARY="${BASH_SOURCE[0]}.real"
if [[ ! -f "$REAL_BINARY" ]]; then
    echo "ERROR: Real LS binary not found at $REAL_BINARY" >&2
    echo "Run 'mitm-wrapper.sh uninstall' and reinstall." >&2
    exit 1
fi
# Inject MITM proxy (don't override if already set)
export HTTPS_PROXY="${HTTPS_PROXY:-http://127.0.0.1:__MITM_PORT__}"
# Build combined CA bundle: system CAs + MITM CA
MITM_CA="__CA_PATH__"
COMBINED_CA="/tmp/antigravity-mitm-combined-ca.pem"
if [[ -f "$MITM_CA" ]]; then
    # Find system CA bundle
    SYS_CA=""
    for candidate in /etc/ssl/certs/ca-certificates.crt /etc/pki/tls/certs/ca-bundle.crt /etc/ssl/cert.pem; do
        if [[ -f "$candidate" ]]; then
            SYS_CA="$candidate"
            break
        fi
    done
    # Only point the LS at the bundle if it was actually written
    if [[ -n "$SYS_CA" ]] && cat "$SYS_CA" "$MITM_CA" > "$COMBINED_CA" 2>/dev/null; then
        export SSL_CERT_FILE="$COMBINED_CA"
        export GRPC_DEFAULT_SSL_ROOTS_FILE_PATH="$COMBINED_CA"
    fi
fi
exec "$REAL_BINARY" "$@"
WRAPPER_EOF
    # Substitute actual values
    sed -i "s|__MITM_PORT__|${MITM_PORT}|g" "$ls_path"
    sed -i "s|__CA_PATH__|${CA_PATH}|g" "$ls_path"
    # Make executable
    chmod +x "$ls_path"
    echo -e " ${GREEN}✓${NC} Wrapper installed at ${ls_path}"
    echo ""
    echo -e " ${BOLD}How it works:${NC}"
    echo -e " When Antigravity starts the LS, the wrapper will:"
    echo -e " 1. Set ${CYAN}HTTPS_PROXY${NC}=http://127.0.0.1:${MITM_PORT}"
    echo -e " 2. Build combined CA bundle (system + MITM) at /tmp/antigravity-mitm-combined-ca.pem"
    echo -e " 3. Set ${CYAN}SSL_CERT_FILE${NC} to the combined bundle"
    echo -e " 4. Exec the real LS binary with all original args"
    echo ""
    echo -e " ${YELLOW}Note:${NC} Restart Antigravity for the wrapper to take effect."
    echo ""
}
# ── Uninstall ────────────────────────────────────────────────────────────────
cmd_uninstall() {
    # Check permissions upfront
    local ls_path
    ls_path=$(find_ls_binary) || true
    if [[ -n "$ls_path" ]] && [[ ! -w "$(dirname "$ls_path")" ]] && [[ "$EUID" -ne 0 ]]; then
        echo -e " ${RED}✗${NC} $(dirname "$ls_path") requires elevated permissions"
        echo -e " run: sudo $0 uninstall"
        exit 1
    fi
    echo -e "${BOLD}${CYAN}Antigravity MITM Wrapper Uninstaller${NC}"
    echo -e "─────────────────────────────────────"
    echo ""
    if [[ -n "$ls_path" ]]; then
        local real_path="${ls_path}${BACKUP_SUFFIX}"
        if [[ -f "$real_path" ]]; then
            mv -f "$real_path" "$ls_path"
            echo -e " ${GREEN}✓${NC} Restored real binary at ${ls_path}"
        else
            echo -e " ${YELLOW}!${NC} No backup found at ${real_path}"
            echo -e " The LS binary may not be wrapped."
        fi
    else
        echo -e " ${RED}✗${NC} Could not find Antigravity language server binary."
    fi
    echo ""
    echo -e " ${YELLOW}Note:${NC} Restart Antigravity for the change to take effect."
    echo ""
}
# ── Status ───────────────────────────────────────────────────────────────────
cmd_status() {
    echo -e "${BOLD}${CYAN}Antigravity MITM Wrapper Status${NC}"
    echo -e "────────────────────────────────"
    echo ""
    local ls_path
    if ls_path=$(find_ls_binary); then
        echo -e " ${GREEN}✓${NC} LS binary: ${ls_path}"
        local real_path="${ls_path}${BACKUP_SUFFIX}"
        if [[ -f "$real_path" ]]; then
            echo -e " ${GREEN}✓${NC} Wrapper: ${BOLD}installed${NC}"
            echo -e " ${GREEN}✓${NC} Real binary: ${real_path}"
            # Check if wrapper is valid
            if head -c 200 "$ls_path" | grep -q 'MITM LS Wrapper'; then
                echo -e " ${GREEN}✓${NC} Wrapper script: valid"
            else
                echo -e " ${RED}✗${NC} Wrapper script: ${BOLD}corrupted or replaced${NC}"
            fi
        else
            echo -e " ${YELLOW}!${NC} Wrapper: ${BOLD}not installed${NC}"
        fi
    else
        echo -e " ${RED}✗${NC} LS binary: not found"
    fi
    # Check CA cert
    if [[ -f "$CA_PATH" ]]; then
        echo -e " ${GREEN}✓${NC} CA cert: ${CA_PATH}"
    else
        echo -e " ${RED}✗${NC} CA cert: not found (start proxy first)"
    fi
    # Check MITM port
    if ss -tlnp 2>/dev/null | grep -q ":${MITM_PORT} "; then
        echo -e " ${GREEN}✓${NC} MITM proxy: listening on :${MITM_PORT}"
    else
        echo -e " ${YELLOW}!${NC} MITM proxy: not running on :${MITM_PORT}"
    fi
    echo ""
}
# ── Main ─────────────────────────────────────────────────────────────────────
case "${1:-}" in
    install)
        shift
        cmd_install "${1:-}"
        ;;
    uninstall)
        cmd_uninstall
        ;;
    status)
        cmd_status
        ;;
    -h|--help)
        echo "Usage: $0 {install|uninstall|status}"
        echo ""
        echo "Commands:"
        echo " install [path] Install MITM wrapper (auto-detect or specify path)"
        echo " uninstall Restore original LS binary"
        echo " status Show wrapper installation status"
        echo ""
        echo "Environment:"
        echo " ANTIGRAVITY_MITM_PORT MITM proxy port (default: 8742)"
        ;;
    *)
        echo "Usage: $0 {install|uninstall|status}"
        exit 1
        ;;
esac

475
scripts/parse-snapshot.py Normal file

@@ -0,0 +1,475 @@
#!/usr/bin/env python3
"""
Parse Go GODEBUG=http2debug=2 output into a clean, readable snapshot.

Usage:
    python3 parse-snapshot.py < raw-http2-dump.log
    python3 parse-snapshot.py /path/to/logfile
"""
import sys
import re
import json
import gzip
from collections import defaultdict
from io import BytesIO
# ── Colors ────────────────────────────────────────────────────────────────────
BOLD = "\033[1m"
DIM = "\033[2m"
RED = "\033[91m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
CYAN = "\033[96m"
MAGENTA = "\033[95m"
NC = "\033[0m"
# ── Regexes ───────────────────────────────────────────────────────────────────
RE_ENCODING_HEADER = re.compile(
    r'http2: Transport encoding header "([^"]+)" = "([^"]*)"'
)
RE_DECODED_HEADER = re.compile(
    r'http2: decoded hpack field header field "([^"]+)" = "([^"]*)"'
)
RE_SERVER_ENCODING = re.compile(
    r'http2: server encoding header "([^"]+)" = "([^"]*)"'
)
RE_WROTE_DATA = re.compile(
    r'http2: Framer [^:]+: wrote DATA flags=(\S+) stream=(\d+) len=(\d+) data="(.*?)"'
)
RE_READ_DATA = re.compile(
    r'http2: Framer [^:]+: read DATA flags=(\S+) stream=(\d+) len=(\d+) data="(.*?)"'
)
RE_TRANSPORT_CONN = re.compile(
    r'http2: Transport creating client conn [^ ]+ to (.+)'
)
RE_SERVER_READ_DATA = re.compile(
    r'http2: server read frame DATA flags=(\S+) stream=(\d+) len=(\d+) data="(.*?)"'
)
RE_WROTE_HEADERS = re.compile(
    r'http2: Framer [^:]+: wrote HEADERS flags=(\S+) stream=(\d+)'
)
RE_TIMESTAMP = re.compile(r'^(\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2})')
RE_LS_LOG = re.compile(r'^[IWE]\d{4} ')
RE_MAXPROCS = re.compile(r'^.*maxprocs:')
RE_BYTES_OMITTED = re.compile(r'\((\d+) bytes omitted\)$')

# Known domain purposes: domain -> (short label, description)
DOMAIN_INFO = {
    "antigravity-unleash.goog": ("Feature Flags", "Unleash SDK: controls A/B tests, feature rollouts"),
    "daily-cloudcode-pa.googleapis.com": ("LLM API (gRPC)", "Primary Gemini/Claude API endpoint"),
    "cloudcode-pa.googleapis.com": ("LLM API (gRPC)", "Production Gemini/Claude API endpoint"),
    "api.anthropic.com": ("Claude API", "Direct Anthropic API calls"),
    "lh3.googleusercontent.com": ("Profile Picture", "User avatar image"),
    "play.googleapis.com": ("Telemetry", "Google Play telemetry/logging"),
    "firebaseinstallations.googleapis.com": ("Firebase", "Firebase installation tracking"),
    "oauth2.googleapis.com": ("OAuth", "Token refresh/exchange"),
    "speech.googleapis.com": ("Speech", "Voice input processing"),
    "modelarmor.googleapis.com": ("Safety", "Content safety/filtering"),
}


class Request:
    def __init__(self):
        self.method = ""
        self.path = ""
        self.authority = ""
        self.scheme = ""
        self.headers = {}
        self.data = b""
        self.data_len = 0
        self.stream_id = None
        self.timestamp = ""
        self.direction = "outgoing"  # outgoing = LS→upstream, incoming = LS←upstream
class Snapshot:
def __init__(self):
self.connections = [] # (timestamp, target)
self.requests = [] # list of Request
self.responses = defaultdict(lambda: {"headers": {}, "data": b"", "data_len": 0})
self.ls_logs = []

    def parse(self, lines):
        current_headers = {}
        current_direction = "outgoing"
        current_stream = None
        for line in lines:
            line = line.rstrip()
            # Skip empty
            if not line:
                continue
            # LS process logs
            if RE_LS_LOG.match(line) or RE_MAXPROCS.match(line):
                self.ls_logs.append(line)
                continue
            # New connection
            m = RE_TRANSPORT_CONN.search(line)
            if m:
                ts = ""
                ts_m = RE_TIMESTAMP.match(line)
                if ts_m:
                    ts = ts_m.group(1)
                self.connections.append((ts, m.group(1)))
                continue
            # Outgoing headers (Transport encoding = LS sending to upstream)
            m = RE_ENCODING_HEADER.search(line)
            if m:
                key, val = m.group(1), m.group(2)
                if key == ":method":
                    # New request starting
                    if current_headers.get(":path"):
                        self._finalize_request(current_headers, "outgoing", line)
                    current_headers = {}
                    current_direction = "outgoing"
                current_headers[key] = val
                ts_m = RE_TIMESTAMP.match(line)
                if ts_m and "timestamp" not in current_headers:
                    current_headers["timestamp"] = ts_m.group(1)
                continue
            # Incoming headers (decoded hpack = upstream responding, OR server receiving)
            m = RE_DECODED_HEADER.search(line)
            if m:
                key, val = m.group(1), m.group(2)
                if key == ":authority" and "server read frame" not in line:
                    # This is a request received by our LS
                    if current_headers.get(":path"):
                        self._finalize_request(current_headers, current_direction, line)
                    current_headers = {}
                    current_direction = "incoming"
                current_headers[key] = val
                continue
            # Server encoding (our LS responding)
            m = RE_SERVER_ENCODING.search(line)
            if m:
                continue  # Skip server response headers for now
            # Headers frame written (triggers finalization)
            m = RE_WROTE_HEADERS.search(line)
            if m:
                current_stream = m.group(2)
                if current_headers.get(":path") or current_headers.get(":method"):
                    req = self._finalize_request(current_headers, current_direction, line)
                    if req:
                        req.stream_id = current_stream
                    current_headers = {}
                continue
            # Data frames (wrote = LS sending, read = LS receiving)
            for pattern, direction in [
                (RE_WROTE_DATA, "sent"),
                (RE_READ_DATA, "received"),
                (RE_SERVER_READ_DATA, "server_received"),
            ]:
                m = pattern.search(line)
                if m:
                    flags, stream, length, data_str = (
                        m.group(1),
                        m.group(2),
                        int(m.group(3)),
                        m.group(4),
                    )
                    # Find matching request by stream
                    for req in reversed(self.requests):
                        if req.stream_id == stream:
                            raw = self._decode_data_str(data_str, line)
                            if direction == "sent" or direction == "server_received":
                                req.data += raw
                                req.data_len = max(req.data_len, length)
                            break
                    # Also check omitted bytes
                    om = RE_BYTES_OMITTED.search(line)
                    if om:
                        pass  # length already captured
                    break
        # Finalize any remaining headers
        if current_headers.get(":path") or current_headers.get(":method"):
            self._finalize_request(current_headers, current_direction, "")

    def _finalize_request(self, headers, direction, _line):
        req = Request()
        req.method = headers.pop(":method", "GET")
        req.path = headers.pop(":path", "/")
        req.authority = headers.pop(":authority", "")
        req.scheme = headers.pop(":scheme", "https")
        req.timestamp = headers.pop("timestamp", "")
        req.direction = direction
        req.headers = {k: v for k, v in headers.items() if not k.startswith(":")}
        self.requests.append(req)
        return req

    def _decode_data_str(self, s, full_line):
        """Decode escaped string from GODEBUG output back to bytes."""
        try:
            # Handle Go's escaped bytes
            result = bytearray()
            i = 0
            while i < len(s):
                if s[i] == "\\" and i + 1 < len(s):
                    if s[i + 1] == "x" and i + 3 < len(s):
                        result.append(int(s[i + 2 : i + 4], 16))
                        i += 4
                    elif s[i + 1] == "n":
                        result.append(10)
                        i += 2
                    elif s[i + 1] == "r":
                        result.append(13)
                        i += 2
                    elif s[i + 1] == "t":
                        result.append(9)
                        i += 2
                    elif s[i + 1] == "\\":
                        result.append(92)
                        i += 2
                    elif s[i + 1] == '"':
                        result.append(34)
                        i += 2
                    else:
                        result.append(ord(s[i]))
                        i += 1
                else:
                    result.append(ord(s[i]))
                    i += 1
            return bytes(result)
        except Exception:
            return s.encode("utf-8", errors="replace")

    def render(self):
        out = []
        # Header
        out.append(f"\n{BOLD}{CYAN}{'═' * 70}{NC}")
        out.append(f"{BOLD}{CYAN} STANDALONE LS TRAFFIC SNAPSHOT{NC}")
        out.append(f"{BOLD}{CYAN}{'═' * 70}{NC}\n")
        # LS Logs
        if self.ls_logs:
            out.append(f"{BOLD}▸ Language Server Logs{NC}")
            out.append(f"{DIM}{'─' * 60}{NC}")
            for log in self.ls_logs:
                out.append(f" {DIM}{log}{NC}")
            out.append("")
        # Connections
        if self.connections:
            out.append(f"{BOLD}▸ Outbound Connections{NC}")
            out.append(f"{DIM}{'─' * 60}{NC}")
            for ts, target in self.connections:
                domain = target.split(":")[0] if ":" in target else target
                info = DOMAIN_INFO.get(domain, ("Unknown", ""))
                out.append(
                    f" {GREEN}{NC} {BOLD}{target}{NC} {DIM}({info[0]}){NC}"
                )
                if info[1]:
                    out.append(f" {DIM}{info[1]}{NC}")
            out.append("")
        # Group requests by domain
        by_domain = defaultdict(list)
        for req in self.requests:
            by_domain[req.authority].append(req)
        # Render each domain's requests
        for domain, reqs in by_domain.items():
            if domain.startswith("127.0.0.1"):
                label = "Local (our requests to LS)"
                color = DIM
            else:
                info = DOMAIN_INFO.get(domain, ("External", ""))
                label = info[0]
                color = YELLOW if "API" in info[0] else CYAN
            out.append(f"{BOLD}{'─' * 70}{NC}")
            out.append(f"{BOLD}{color} {domain}{NC} {DIM}{label}{NC}")
            out.append(f"{BOLD}{'─' * 70}{NC}")
            for i, req in enumerate(reqs):
                # Arrow shows direction: LS→upstream vs. LS←upstream
                arrow = "→" if req.direction == "outgoing" else "←"
                method_color = GREEN if req.method == "GET" else YELLOW
                out.append(f"\n {BOLD}{arrow} {method_color}{req.method}{NC} {req.path}")
                # Important headers
                interesting = [
                    "authorization",
                    "content-type",
                    "user-agent",
                    "unleash-appname",
                    "unleash-instanceid",
                    "unleash-sdk",
                    "x-goog-api-key",
                    "x-goog-api-client",
                    "grpc-encoding",
                    "te",
                ]
                shown = False
                for key in interesting:
                    if key in req.headers:
                        val = req.headers[key]
                        # Mask tokens partially
                        if key == "authorization" and len(val) > 30:
                            if val.startswith("Bearer "):
                                val = f"Bearer {val[7:20]}...{val[-10:]}"
                            elif len(val) > 40:
                                val = f"{val[:30]}...{val[-10:]}"
                        out.append(f" {DIM}{key}:{NC} {val}")
                        shown = True
                # All other headers (collapsed)
                other = {
                    k: v
                    for k, v in req.headers.items()
                    if k not in interesting and not k.startswith(":")
                }
                if other:
                    if not shown:
                        out.append(f" {DIM}Headers:{NC}")
                    for k, v in other.items():
                        out.append(f" {DIM}{k}:{NC} {v}")
                # Body
                if req.data:
                    out.append(self._render_body(req.data, req.data_len))
            out.append("")
        return "\n".join(out)

    def _render_body(self, data, total_len):
        """Render body data in the most readable format possible."""
        lines = []
        # Try JSON
        try:
            text = data.decode("utf-8")
            obj = json.loads(text)
            pretty = json.dumps(obj, indent=2, ensure_ascii=False)
            lines.append(f" {BOLD}Body ({len(data)} bytes, JSON):{NC}")
            for l in pretty.split("\n")[:30]:
                lines.append(f" {GREEN}{l}{NC}")
            if len(pretty.split("\n")) > 30:
                lines.append(f" {DIM}... ({len(pretty.split(chr(10))) - 30} more lines){NC}")
            return "\n".join(lines)
        except (json.JSONDecodeError, UnicodeDecodeError):
            pass
        # Try gzip
        if data[:2] == b"\x1f\x8b":
            try:
                decompressed = gzip.decompress(data)
                try:
                    text = decompressed.decode("utf-8")
                    try:
                        obj = json.loads(text)
                        pretty = json.dumps(obj, indent=2, ensure_ascii=False)
                        lines.append(
                            f" {BOLD}Body ({len(data)} bytes gzip → {len(decompressed)} bytes, JSON):{NC}"
                        )
                        for l in pretty.split("\n")[:50]:
                            lines.append(f" {GREEN}{l}{NC}")
                        if len(pretty.split("\n")) > 50:
                            lines.append(
                                f" {DIM}... ({len(pretty.split(chr(10))) - 50} more lines){NC}"
                            )
                        return "\n".join(lines)
                    except json.JSONDecodeError:
                        lines.append(
                            f" {BOLD}Body ({len(data)} bytes gzip → {len(decompressed)} bytes, text):{NC}"
                        )
                        for l in text.split("\n")[:20]:
                            lines.append(f" {l[:200]}")
                        return "\n".join(lines)
                except UnicodeDecodeError:
                    lines.append(
                        f" {BOLD}Body ({len(data)} bytes gzip → {len(decompressed)} bytes, binary):{NC}"
                    )
                    lines.append(f" {DIM}{self._extract_strings(decompressed)}{NC}")
                    return "\n".join(lines)
            except Exception:
                pass
        # Try protobuf (extract readable strings)
        if data[:1] in (b"\x08", b"\x0a", b"\x10", b"\x12", b"\x18", b"\x1a", b"\x20", b"\x22"):
            strings = self._extract_strings(data)
            if strings:
                lines.append(f" {BOLD}Body ({total_len} bytes, protobuf):{NC}")
                lines.append(f" {DIM}Extracted strings:{NC}")
                for s in strings.split(" | ")[:20]:
                    s = s.strip()
                    if len(s) > 3:
                        lines.append(f" {MAGENTA}{s}{NC}")
                return "\n".join(lines)
        # Try plain text
        try:
            text = data.decode("utf-8")
            lines.append(f" {BOLD}Body ({len(data)} bytes, text):{NC}")
            for l in text.split("\n")[:10]:
                lines.append(f" {l[:200]}")
            return "\n".join(lines)
        except UnicodeDecodeError:
            pass
        # PNG
        if data[:4] == b"\x89PNG":
            lines.append(f" {BOLD}Body ({total_len} bytes, PNG image){NC}")
            return "\n".join(lines)
        # Binary fallback
        lines.append(f" {BOLD}Body ({total_len} bytes, binary):{NC}")
        strings = self._extract_strings(data)
        if strings:
            lines.append(f" {DIM}Extracted strings:{NC}")
            for s in strings.split(" | ")[:15]:
                s = s.strip()
                if len(s) > 3:
                    lines.append(f" {MAGENTA}{s}{NC}")
        else:
            lines.append(f" {DIM}(no readable strings){NC}")
        return "\n".join(lines)

    def _extract_strings(self, data, min_len=4):
        """Extract printable ASCII strings from binary data."""
        strings = []
        current = bytearray()
        for b in data:
            if 32 <= b <= 126:
                current.append(b)
            else:
                if len(current) >= min_len:
                    strings.append(current.decode("ascii"))
                current = bytearray()
        if len(current) >= min_len:
            strings.append(current.decode("ascii"))
        # Deduplicate while preserving order
        seen = set()
        unique = []
        for s in strings:
            if s not in seen:
                seen.add(s)
                unique.append(s)
        return " | ".join(unique[:30])
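

# A minimal, standalone sketch of the escape handling that
# Snapshot._decode_data_str performs, written so it can be checked in
# isolation. Assumptions: GODEBUG=http2debug=2 prints DATA payloads with
# Go-style escapes (\xNN, \n, \r, \t, \\, \"), and any other backslash
# sequence is passed through unchanged. The name `decode_go_escapes` is
# hypothetical; nothing else in this script calls it.
import re

_GO_ESCAPE = re.compile(r'\\(x[0-9a-fA-F]{2}|[nrt\\"])')
_SIMPLE_ESCAPES = {"n": 10, "r": 13, "t": 9, "\\": 92, '"': 34}


def decode_go_escapes(s):
    """Turn a Go-escaped string back into raw bytes (illustrative only)."""
    out = bytearray()
    pos = 0
    for m in _GO_ESCAPE.finditer(s):
        # Literal text between two escapes is kept as UTF-8.
        out += s[pos : m.start()].encode("utf-8")
        token = m.group(1)
        if token[0] == "x":
            out.append(int(token[1:], 16))  # \xNN becomes one byte
        else:
            out.append(_SIMPLE_ESCAPES[token])
        pos = m.end()
    out += s[pos:].encode("utf-8")
    return bytes(out)


# Usage, e.g. from a REPL: decode_go_escapes(r"\x1f\x8b") yields the two-byte
# gzip magic number that _render_body tests for.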


def main():
    if len(sys.argv) > 1:
        with open(sys.argv[1]) as f:
            lines = f.readlines()
    else:
        lines = sys.stdin.readlines()
    snap = Snapshot()
    snap.parse(lines)
    print(snap.render())


if __name__ == "__main__":
    main()

277
scripts/standalone-ls.sh Executable file

@@ -0,0 +1,277 @@
#!/usr/bin/env bash
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ Standalone Language Server Launcher ║
# ║ ║
# ║ Launches an isolated LS instance that: ║
# ║ - Shares OAuth via the main app's extension server ║
# ║ - Has its own HTTPS port, data dir, and cascades ║
# ║ - Optionally routes traffic through our MITM proxy ║
# ║ - Can capture a clean traffic snapshot ║
# ║ ║
# ║ Usage: ║
# ║ ./standalone-ls.sh # Launch, test, exit ║
# ║ ./standalone-ls.sh --fg # Foreground (stay alive) ║
# ║ ./standalone-ls.sh --mitm # Route through MITM proxy ║
# ║ ./standalone-ls.sh --snapshot # Capture clean traffic dump ║
# ║ ./standalone-ls.sh --snapshot --prompt "Say hello" ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# ── Defaults ──────────────────────────────────────────────────────────────────
LS_BIN="/usr/share/antigravity/resources/app/extensions/antigravity/bin/language_server_linux_x64"
HTTPS_PORT="42200"
DATA_DIR="/tmp/antigravity-standalone"
FOREGROUND=false
USE_MITM=false
SNAPSHOT=false
TIMEOUT=15
PROMPT=""
MODEL="MODEL_PLACEHOLDER_M3"
# ── Parse args ────────────────────────────────────────────────────────────────
while [[ $# -gt 0 ]]; do
case "$1" in
--port) HTTPS_PORT="$2"; shift 2 ;;
--mitm) USE_MITM=true; shift ;;
--fg) FOREGROUND=true; shift ;;
--timeout) TIMEOUT="$2"; shift 2 ;;
--snapshot) SNAPSHOT=true; TIMEOUT=30; shift ;;
--prompt) PROMPT="$2"; shift 2 ;;
--model) MODEL="$2"; shift 2 ;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo ""
echo "Options:"
echo " --port PORT HTTPS port for standalone LS (default: 42200)"
echo " --mitm Route traffic through MITM proxy"
echo " --fg Run in foreground (stay alive)"
echo " --timeout SECS Background mode timeout (default: 15)"
echo " --snapshot Capture clean traffic snapshot"
echo " --prompt TEXT Prompt to send (snapshot mode)"
echo " --model MODEL Model alias (default: MODEL_PLACEHOLDER_M3)"
exit 0 ;;
*) echo "Unknown option: $1"; exit 1 ;;
esac
done
# ── Discover main LS config ──────────────────────────────────────────────────
MAIN_PID=$(pgrep -f 'language_server_linux_x64' | head -1 || true)
if [[ -z "$MAIN_PID" ]]; then
echo "[-] No main LS process found. Main Antigravity must be running."
exit 1
fi
MAIN_CSRF=$(tr '\0' '\n' < /proc/"$MAIN_PID"/cmdline | grep -A1 'csrf_token' | tail -1)
EXT_PORT=$(tr '\0' '\n' < /proc/"$MAIN_PID"/cmdline | grep -A1 'extension_server_port' | tail -1)
echo "[*] Main LS PID: $MAIN_PID"
echo "[*] CSRF: $MAIN_CSRF"
echo "[*] Extension server: $EXT_PORT"
# ── Build protobuf metadata for stdin ─────────────────────────────────────────
TS=$(date +%s)
METADATA=$(python3 -c "
import sys
def v(n):
    r = bytearray()
    while n > 0x7f:
        r.append((n & 0x7f) | 0x80)
        n >>= 7
    r.append(n & 0x7f)
    return bytes(r)
def s(f, val):
    t = v((f << 3) | 2)
    d = val.encode()
    return t + v(len(d)) + d
buf = bytearray()
buf += s(1, 'standalone-api-key-$TS')
buf += s(3, 'antigravity')
buf += s(4, '1.15.8')
buf += s(5, '1.16.39')
buf += s(6, 'en_US')
buf += s(10, 'standalone-session-$TS')
buf += s(11, 'antigravity')
sys.stdout.buffer.write(bytes(buf))
" | base64)
# ── Setup data directory ──────────────────────────────────────────────────────
mkdir -p "$DATA_DIR/.gemini"
# ── MITM environment ─────────────────────────────────────────────────────────
MITM_ENV=()
if $USE_MITM; then
REAL_HOME="${SUDO_USER:+$(getent passwd "$SUDO_USER" | cut -d: -f6)}"
REAL_HOME="${REAL_HOME:-$HOME}"
MITM_PORT_FILE="${REAL_HOME}/.config/antigravity-proxy/mitm-port"
CA_PATH="${REAL_HOME}/.config/antigravity-proxy/mitm-ca.pem"
if [[ -f "$MITM_PORT_FILE" ]]; then
MITM_PORT=$(cat "$MITM_PORT_FILE")
else
MITM_PORT="8742"
fi
if [[ ! -f "$CA_PATH" ]]; then
echo "[-] MITM CA cert not found at $CA_PATH"
echo " Start the proxy first to generate it."
exit 1
fi
COMBINED_CA="/tmp/antigravity-mitm-combined-ca.pem"
SYS_CA=""
for candidate in /etc/ssl/certs/ca-certificates.crt /etc/pki/tls/certs/ca-bundle.crt /etc/ssl/cert.pem; do
if [[ -f "$candidate" ]]; then SYS_CA="$candidate"; break; fi
done
if [[ -n "$SYS_CA" ]]; then
cat "$SYS_CA" "$CA_PATH" > "$COMBINED_CA"
else
echo "[-] No system CA bundle found"
exit 1
fi
MITM_ENV=(
"HTTPS_PROXY=http://127.0.0.1:${MITM_PORT}"
"SSL_CERT_FILE=${COMBINED_CA}"
"GRPC_DEFAULT_SSL_ROOTS_FILE_PATH=${COMBINED_CA}"
)
echo "[*] MITM: enabled (port $MITM_PORT)"
else
echo "[*] MITM: disabled (use --mitm to enable)"
fi
# ── LS args ───────────────────────────────────────────────────────────────────
LS_ARGS=(
-enable_lsp
-extension_server_port "$EXT_PORT"
-csrf_token "$MAIN_CSRF"
-server_port "$HTTPS_PORT"
-workspace_id "standalone_$TS"
-cloud_code_endpoint "https://daily-cloudcode-pa.googleapis.com"
-app_data_dir "antigravity-standalone"
-gemini_dir "$DATA_DIR/.gemini"
)
# ── Extra env for snapshot mode ───────────────────────────────────────────────
EXTRA_ENV=()
if $SNAPSHOT; then
EXTRA_ENV=("GODEBUG=http2debug=2")
echo "[*] Snapshot: enabled (HTTP/2 debug tracing)"
fi
# ── Banner ────────────────────────────────────────────────────────────────────
echo ""
echo "========================================="
echo " Standalone LS"
echo " Port: $HTTPS_PORT (HTTPS)"
echo " Data: $DATA_DIR"
echo " Mode: $($FOREGROUND && echo "foreground" || echo "background ($TIMEOUT s)")"
echo "========================================="
echo ""
# ── Foreground mode ───────────────────────────────────────────────────────────
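if $FOREGROUND; then
# env(1) launches the LS itself. `exec` is a shell builtin, so handing it to
# env as the command would fail with "No such file or directory".
echo "$METADATA" | base64 -d | \
env "${MITM_ENV[@]+"${MITM_ENV[@]}"}" \
"${EXTRA_ENV[@]+"${EXTRA_ENV[@]}"}" \
ANTIGRAVITY_EDITOR_APP_ROOT="/usr/share/antigravity/resources/app" \
"$LS_BIN" "${LS_ARGS[@]}"
exit 0
fi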
# ── Background mode ──────────────────────────────────────────────────────────
LOG="/tmp/standalone-ls.log"
rm -f "$LOG"
echo "$METADATA" | base64 -d | \
env "${MITM_ENV[@]+"${MITM_ENV[@]}"}" \
"${EXTRA_ENV[@]+"${EXTRA_ENV[@]}"}" \
ANTIGRAVITY_EDITOR_APP_ROOT="/usr/share/antigravity/resources/app" \
timeout "$TIMEOUT" "$LS_BIN" "${LS_ARGS[@]}" \
> "$LOG" 2>&1 &
LS_PID=$!
echo "[*] PID: $LS_PID"
# Wait for init
for i in $(seq 1 5); do
sleep 1
if ! kill -0 "$LS_PID" 2>/dev/null; then
echo "[-] LS died after ${i}s"
echo "=== LOGS ==="
cat "$LOG"
exit 1
fi
done
echo "[+] LS alive and initialized"
# ── Snapshot mode: send a prompt and capture traffic ──────────────────────────
if $SNAPSHOT; then
if [[ -z "$PROMPT" ]]; then
PROMPT="Say exactly: Hello standalone world"
fi
echo ""
echo "[*] Sending cascade: \"$PROMPT\""
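# Build the JSON body with json.dumps so quotes or backslashes in the prompt
# cannot break the request payload.
PAYLOAD=$(PROMPT="$PROMPT" MODEL="$MODEL" DATA_DIR="$DATA_DIR" python3 -c '
import json, os
print(json.dumps({
    "prompt": os.environ["PROMPT"],
    "modelOrAlias": {"model": os.environ["MODEL"]},
    "workspaceRootPaths": [os.environ["DATA_DIR"]],
}))')
CASCADE_ID=$(curl -sk --max-time 10 \
"https://127.0.0.1:${HTTPS_PORT}/exa.language_server_pb.LanguageServerService/StartCascade" \
-H "Content-Type: application/json" \
-H "x-codeium-csrf-token: $MAIN_CSRF" \
-H "Origin: vscode-file://vscode-app" \
-d "$PAYLOAD" 2>/dev/null | python3 -c "import json,sys; print(json.load(sys.stdin).get('cascadeId',''))" 2>/dev/null || true)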
echo "[*] Cascade: $CASCADE_ID"
echo "[*] Waiting 15s for upstream API calls..."
sleep 15
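# Kill LS to flush logs (it may already have exited via `timeout`; with
# `set -e` a failing kill would otherwise abort before the snapshot is parsed)
kill "$LS_PID" 2>/dev/null || true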
wait "$LS_PID" 2>/dev/null || true
# Parse and display
echo ""
python3 "$SCRIPT_DIR/parse-snapshot.py" "$LOG"
# Also save raw log
SNAPSHOT_FILE="/tmp/standalone-snapshot-$(date +%Y%m%d-%H%M%S).log"
cp "$LOG" "$SNAPSHOT_FILE"
echo ""
echo "[*] Raw log saved to: $SNAPSHOT_FILE"
exit 0
fi
# ── Normal mode: test and report ──────────────────────────────────────────────
echo ""
echo "=== GetUserStatus ==="
curl -sk "https://127.0.0.1:${HTTPS_PORT}/exa.language_server_pb.LanguageServerService/GetUserStatus" \
-H "Content-Type: application/json" \
-H "x-codeium-csrf-token: $MAIN_CSRF" \
-H "Origin: vscode-file://vscode-app" \
-d '{}' 2>/dev/null | python3 -c "
import json, sys
try:
    d = json.load(sys.stdin)
    us = d.get('userStatus', {})
    ps = us.get('planStatus', {})
    pi = ps.get('planInfo', {})
    print(f'Plan: {pi.get(\"planName\",\"?\")}, Prompt: {ps.get(\"availablePromptCredits\",\"?\")}, Flow: {ps.get(\"availableFlowCredits\",\"?\")}')
    ut = us.get('userTier', {})
    print(f'Tier: {ut.get(\"name\",\"?\")}')
    models = us.get('cascadeModelConfigData', {}).get('clientModelConfigs', [])
    print(f'Models: {len(models)}')
    for m in models[:5]:
        qi = m.get('quotaInfo', {})
        print(f' - {m.get(\"label\")}: remaining={qi.get(\"remainingFraction\",\"?\")}')
except Exception as e:
    print(f'Error: {e}')
" 2>/dev/null
echo ""
kill "$LS_PID" 2>/dev/null || true
wait "$LS_PID" 2>/dev/null || true
echo "[*] Done"

343
src/api/completions.rs Normal file

@@ -0,0 +1,343 @@
//! OpenAI Chat Completions API (/v1/chat/completions) handler.
use axum::{
extract::State,
http::StatusCode,
response::{sse::Event, IntoResponse, Json, Sse},
};
use rand::Rng;
use std::sync::Arc;
use tracing::{debug, info, warn};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{extract_response_text, is_response_done, poll_for_response};
use super::types::*;
use super::util::{err_response, now_unix};
use super::AppState;
// ─── Input extraction ────────────────────────────────────────────────────────
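/// Build the prompt text from a Chat Completions `messages` array: all
/// system/developer messages (joined by newlines), a blank line, then the
/// last user message. Earlier user turns and all other roles are ignored.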
fn extract_chat_input(messages: &[CompletionMessage]) -> String {
let mut system_parts = Vec::new();
let mut user_parts = Vec::new();
for msg in messages {
let text = match &msg.content {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(arr) => arr
.iter()
.filter_map(|item| item["text"].as_str())
.collect::<Vec<_>>()
.join("\n"),
_ => continue,
};
match msg.role.as_str() {
"system" | "developer" => system_parts.push(text),
"user" => user_parts.push(text),
_ => {}
}
}
let mut result = String::new();
if !system_parts.is_empty() {
result.push_str(&system_parts.join("\n"));
result.push_str("\n\n");
}
// Use the last user message
if let Some(last) = user_parts.last() {
result.push_str(last);
}
result.trim().to_string()
}
// ─── Handler ─────────────────────────────────────────────────────────────────
/// POST /v1/chat/completions — OpenAI Chat Completions API compatibility shim.
/// Accepts standard messages format, reuses the same backend cascade, and
/// outputs in the Chat Completions streaming/sync format.
pub(crate) async fn handle_completions(
State(state): State<Arc<AppState>>,
Json(body): Json<CompletionRequest>,
) -> axum::response::Response {
let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
info!(
"POST /v1/chat/completions model={} stream={}",
model_name, body.stream
);
let model = match lookup_model(model_name) {
Some(m) => m,
None => {
let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
return err_response(
StatusCode::BAD_REQUEST,
format!("Unknown model: {model_name}. Available: {names:?}"),
"invalid_request_error",
);
}
};
let token = state.backend.oauth_token().await;
if token.is_empty() {
return err_response(
StatusCode::UNAUTHORIZED,
"No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
"authentication_error",
);
}
let user_text = extract_chat_input(&body.messages);
if user_text.is_empty() {
return err_response(
StatusCode::BAD_REQUEST,
"No user message found".to_string(),
"invalid_request_error",
);
}
// Fresh cascade per request
let cascade_id = match state.backend.create_cascade().await {
Ok(cid) => cid,
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("StartCascade failed: {e}"),
"server_error",
);
}
};
// Send message
match state
.backend
.send_message(&cascade_id, &user_text, model.model_enum)
.await
{
Ok((200, _)) => {
let bg = Arc::clone(&state.backend);
let cid = cascade_id.clone();
tokio::spawn(async move {
let _ = bg.update_annotations(&cid).await;
});
}
Ok((status, _)) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("Backend returned {status}"),
"server_error",
);
}
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("Send failed: {e}"),
"server_error",
);
}
}
let completion_id = format!(
"chatcmpl-{}",
uuid::Uuid::new_v4().to_string().replace('-', "")
);
if body.stream {
chat_completions_stream(
state,
completion_id,
model_name.to_string(),
cascade_id,
body.timeout,
)
.await
} else {
chat_completions_sync(
state,
completion_id,
model_name.to_string(),
cascade_id,
body.timeout,
)
.await
}
}
// ─── Streaming ───────────────────────────────────────────────────────────────
/// Streaming output in Chat Completions format.
async fn chat_completions_stream(
state: Arc<AppState>,
completion_id: String,
model_name: String,
cascade_id: String,
timeout: u64,
) -> axum::response::Response {
let stream = async_stream::stream! {
let start = std::time::Instant::now();
let mut last_text = String::new();
// Initial role chunk
yield Ok::<_, std::convert::Infallible>(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"role": "assistant", "content": ""},
"finish_reason": serde_json::Value::Null,
}],
})).unwrap_or_default()));
while start.elapsed().as_secs() < timeout {
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
if status == 200 {
if let Some(steps) = data["steps"].as_array() {
let text = extract_response_text(steps);
if !text.is_empty() && text != last_text {
let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) {
&text[last_text.len()..]
} else {
&text
};
if !delta.is_empty() {
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": delta},
"finish_reason": serde_json::Value::Null,
}],
})).unwrap_or_default()));
last_text = text.to_string();
}
}
// Done check: need DONE status AND non-empty text
if is_response_done(steps) && !last_text.is_empty() {
debug!("Completions stream done, text length={}", last_text.len());
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop",
}],
})).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]".to_string()));
return;
}
// IDLE fallback: check trajectory status periodically
// Only check every 5th step count to reduce backend traffic
let step_count = steps.len();
if step_count > 4 && step_count % 5 == 0 {
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
if ts == 200 {
let run_status = td["status"].as_str().unwrap_or("");
if run_status.contains("IDLE") && !last_text.is_empty() {
debug!("Completions IDLE, text length={}", last_text.len());
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop",
}],
})).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]".to_string()));
return;
}
}
}
}
}
}
}
let poll_ms: u64 = rand::thread_rng().gen_range(800..1200);
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
}
// Timeout
warn!("Completions stream timeout after {}s", timeout);
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": if last_text.is_empty() { "[Timeout waiting for response]" } else { "" }},
"finish_reason": "stop",
}],
})).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]".to_string()));
};
Sse::new(stream)
.keep_alive(
axum::response::sse::KeepAlive::new()
.interval(std::time::Duration::from_secs(15))
.text(""),
)
.into_response()
}
// ─── Sync ────────────────────────────────────────────────────────────────────
/// Sync output in Chat Completions format.
async fn chat_completions_sync(
state: Arc<AppState>,
completion_id: String,
model_name: String,
cascade_id: String,
timeout: u64,
) -> axum::response::Response {
let result = poll_for_response(&state, &cascade_id, timeout).await;
// Check MITM store first for real intercepted usage
let (prompt_tokens, completion_tokens, cached_tokens) = if let Some(mitm_usage) = state.mitm_store.take_usage(&cascade_id).await {
(mitm_usage.input_tokens, mitm_usage.output_tokens, mitm_usage.cache_read_input_tokens)
} else if let Some(u) = &result.usage {
(u.input_tokens, u.output_tokens, 0)
} else {
(0, 0, 0)
};
Json(serde_json::json!({
"id": completion_id,
"object": "chat.completion",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": result.text,
},
"finish_reason": "stop",
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
"prompt_tokens_details": {
"cached_tokens": cached_tokens,
},
},
}))
.into_response()
}

176
src/api/mod.rs Normal file

@@ -0,0 +1,176 @@
//! Axum API server — OpenAI-compatible Responses + Chat Completions endpoints.
mod completions;
mod models;
mod polling;
mod responses;
mod types;
mod util;
use crate::constants::safe_truncate;
use crate::session::SessionManager;
use axum::{
extract::{DefaultBodyLimit, State},
http::StatusCode,
response::{IntoResponse, Json},
routing::{delete, get, post},
Router,
};
use std::sync::Arc;
use tower_http::cors::CorsLayer;
use tracing::warn;
use self::models::MODELS;
use self::types::TokenRequest;
// ─── Shared state ────────────────────────────────────────────────────────────
pub struct AppState {
pub backend: Arc<crate::backend::Backend>,
pub sessions: SessionManager,
pub mitm_store: crate::mitm::store::MitmStore,
pub quota_store: crate::quota::QuotaStore,
}
// ─── Router ──────────────────────────────────────────────────────────────────
pub fn router(state: Arc<AppState>) -> Router {
Router::new()
.route("/v1/responses", post(responses::handle_responses))
.route(
"/v1/chat/completions",
post(completions::handle_completions),
)
.route("/v1/models", get(handle_models))
.route("/v1/sessions", get(handle_list_sessions))
.route("/v1/sessions/{id}", delete(handle_delete_session))
.route("/v1/token", post(handle_set_token))
.route("/v1/usage", get(handle_usage))
.route("/v1/quota", get(handle_quota))
.route("/health", get(handle_health))
.route("/", get(handle_root))
.layer(CorsLayer::permissive())
.layer(DefaultBodyLimit::max(1_048_576)) // 1 MB
.with_state(state)
}
// ─── Simple handlers ─────────────────────────────────────────────────────────
async fn handle_root() -> Json<serde_json::Value> {
Json(serde_json::json!({
"service": "antigravity-openai-proxy",
// Report the crate version from Cargo.toml so the two cannot drift apart.
"version": env!("CARGO_PKG_VERSION"),
"runtime": "rust",
"endpoints": [
"/v1/chat/completions",
"/v1/responses",
"/v1/models",
"/v1/sessions",
"/v1/token",
"/v1/usage",
"/v1/quota",
"/health",
],
}))
}
async fn handle_health() -> Json<serde_json::Value> {
Json(serde_json::json!({"status": "ok"}))
}
async fn handle_models() -> Json<serde_json::Value> {
let models: Vec<serde_json::Value> = MODELS
.iter()
.map(|m| {
serde_json::json!({
"id": m.name,
"object": "model",
"created": 1700000000u64,
"owned_by": "antigravity",
"permission": [],
"root": m.name,
"parent": null,
"meta": {
"label": m.label,
"enum_value": m.model_enum,
},
})
})
.collect();
Json(serde_json::json!({"object": "list", "data": models}))
}
async fn handle_list_sessions(
State(state): State<Arc<AppState>>,
) -> Json<serde_json::Value> {
let sessions = state.sessions.list_sessions().await;
Json(serde_json::json!({"sessions": sessions}))
}
async fn handle_delete_session(
State(state): State<Arc<AppState>>,
axum::extract::Path(id): axum::extract::Path<String>,
) -> impl IntoResponse {
if state.sessions.delete_session(&id).await {
(
StatusCode::OK,
Json(serde_json::json!({"status": "deleted", "session_id": id})),
)
} else {
(
StatusCode::NOT_FOUND,
Json(serde_json::json!({"error": format!("Session not found: {id}")})),
)
}
}
async fn handle_set_token(
State(state): State<Arc<AppState>>,
Json(body): Json<TokenRequest>,
) -> impl IntoResponse {
if !body.token.starts_with("ya29.") {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({"error": "Invalid token. Must start with ya29."})),
);
}
state.backend.set_oauth_token(body.token.clone()).await;
// Also persist to file
let token_path = crate::constants::token_file_path();
if let Err(e) = tokio::fs::write(&token_path, &body.token).await {
warn!("Failed to write token file: {e}");
}
let preview = safe_truncate(&body.token, 20);
(
StatusCode::OK,
Json(serde_json::json!({"status": "ok", "token_prefix": preview})),
)
}
async fn handle_usage(
State(state): State<Arc<AppState>>,
) -> Json<serde_json::Value> {
let stats = state.mitm_store.stats().await;
Json(serde_json::json!({
"mitm": {
"total_requests": stats.total_requests,
"total_input_tokens": stats.total_input_tokens,
"total_output_tokens": stats.total_output_tokens,
"total_cache_read_tokens": stats.total_cache_read_tokens,
"total_cache_creation_tokens": stats.total_cache_creation_tokens,
"total_thinking_output_tokens": stats.total_thinking_output_tokens,
"total_response_output_tokens": stats.total_response_output_tokens,
"total_tokens": stats.total_input_tokens + stats.total_output_tokens,
"per_model": stats.per_model,
}
}))
}
async fn handle_quota(
State(state): State<Arc<AppState>>,
) -> Json<serde_json::Value> {
let snap = state.quota_store.snapshot().await;
Json(serde_json::to_value(snap).unwrap_or_default())
}

49
src/api/models.rs Normal file

@@ -0,0 +1,49 @@
//! Model definitions and lookup.
/// Model definition: friendly name → (antigravity_id, protobuf_enum, label).
pub(crate) struct ModelDef {
pub name: &'static str,
#[allow(dead_code)]
pub ag_id: &'static str,
pub model_enum: u32,
pub label: &'static str,
}
pub(crate) const MODELS: &[ModelDef] = &[
ModelDef {
name: "opus-4.6",
ag_id: "MODEL_PLACEHOLDER_M26",
model_enum: 1026,
label: "Claude Opus 4.6 (Thinking)",
},
ModelDef {
name: "opus-4.5",
ag_id: "MODEL_PLACEHOLDER_M12",
model_enum: 1012,
label: "Claude Opus 4.5 (Thinking)",
},
ModelDef {
name: "gemini-3-pro-high",
ag_id: "MODEL_PLACEHOLDER_M8",
model_enum: 1008,
label: "Gemini 3 Pro (High)",
},
ModelDef {
name: "gemini-3-pro",
ag_id: "MODEL_PLACEHOLDER_M7",
model_enum: 1007,
label: "Gemini 3 Pro (Low)",
},
ModelDef {
name: "gemini-3-flash",
ag_id: "MODEL_PLACEHOLDER_M18",
model_enum: 1018,
label: "Gemini 3 Flash",
},
];
pub(crate) const DEFAULT_MODEL: &str = "opus-4.6";
pub(crate) fn lookup_model(name: &str) -> Option<&'static ModelDef> {
MODELS.iter().find(|m| m.name == name)
}

298
src/api/polling.rs Normal file

@@ -0,0 +1,298 @@
//! Shared polling engine and step extraction helpers.
use rand::Rng;
use tracing::{debug, info, warn};
use super::AppState;
/// Real token usage reported by the language server.
#[derive(Debug, Clone, Default)]
#[allow(dead_code)]
pub(crate) struct ModelUsage {
pub input_tokens: u64,
pub output_tokens: u64,
pub api_provider: String,
pub model: String,
}
/// Result of polling — text + optional real usage data + thinking data.
pub(crate) struct PollResult {
pub text: String,
pub usage: Option<ModelUsage>,
/// Opaque Anthropic thinking verification signature from PLANNER_RESPONSE.
/// Required for multi-turn thinking model chaining.
pub thinking_signature: Option<String>,
/// The model's internal reasoning/thinking content.
/// Available for both Opus (Anthropic) and Gemini models.
pub thinking: Option<String>,
/// Time the model spent thinking, as reported by the LS (e.g. "0.041999832s").
pub thinking_duration: Option<String>,
}
/// Extract the response text from steps — scans in REVERSE to find the latest response.
pub(crate) fn extract_response_text(steps: &[serde_json::Value]) -> String {
for step in steps.iter().rev() {
let step_type = step["type"].as_str().unwrap_or("");
if step_type.contains("PLANNER_RESPONSE") {
let resp = &step["plannerResponse"];
let text = resp["rawResponse"]
.as_str()
.or_else(|| resp["response"].as_str())
.unwrap_or("");
if !text.is_empty() {
return text.to_string();
}
}
if step_type.contains("AI_RESPONSE") || step_type.contains("MODEL_RESPONSE") {
if let Some(text) = step["response"]
.as_str()
.or_else(|| step["rawResponse"].as_str())
.or_else(|| step["text"].as_str())
{
if !text.is_empty() {
return text.to_string();
}
}
}
}
String::new()
}
/// Extract real token usage from the LS's modelUsage field.
/// The LS reports this in CHECKPOINT steps and sometimes in retryInfos.
/// Scans in reverse to find the latest usage data.
pub(crate) fn extract_model_usage(steps: &[serde_json::Value]) -> Option<ModelUsage> {
for step in steps.iter().rev() {
if let Some(usage) = step.get("metadata").and_then(|m| m.get("modelUsage")) {
let input = usage["inputTokens"]
.as_str()
.and_then(|s| s.parse::<u64>().ok())
.or_else(|| usage["inputTokens"].as_u64())
.unwrap_or(0);
let output = usage["outputTokens"]
.as_str()
.and_then(|s| s.parse::<u64>().ok())
.or_else(|| usage["outputTokens"].as_u64())
.unwrap_or(0);
if input > 0 || output > 0 {
return Some(ModelUsage {
input_tokens: input,
output_tokens: output,
api_provider: usage["apiProvider"]
.as_str()
.unwrap_or("")
.to_string(),
model: usage["model"]
.as_str()
.unwrap_or("")
.to_string(),
});
}
}
}
None
}
/// Extract the thinking signature from PLANNER_RESPONSE steps.
/// This is an opaque Base64 blob used by Anthropic for extended thinking
/// verification. Needed to chain multi-turn conversations with thinking models.
pub(crate) fn extract_thinking_signature(steps: &[serde_json::Value]) -> Option<String> {
for step in steps.iter().rev() {
let step_type = step["type"].as_str().unwrap_or("");
if step_type.contains("PLANNER_RESPONSE") {
if let Some(sig) = step["plannerResponse"]["thinkingSignature"].as_str() {
if !sig.is_empty() {
return Some(sig.to_string());
}
}
}
}
None
}
/// Extract the model's thinking/reasoning content from PLANNER_RESPONSE steps.
/// This is the internal monologue produced during extended thinking.
/// Available for ALL models (Opus, Gemini Flash, Gemini Pro).
pub(crate) fn extract_thinking_content(steps: &[serde_json::Value]) -> Option<String> {
for step in steps.iter().rev() {
let step_type = step["type"].as_str().unwrap_or("");
if step_type.contains("PLANNER_RESPONSE") {
if let Some(thinking) = step["plannerResponse"]["thinking"].as_str() {
if !thinking.is_empty() {
return Some(thinking.to_string());
}
}
}
}
None
}
/// Extract thinking duration from PLANNER_RESPONSE steps.
/// Returns the raw duration string as reported by the LS (e.g. "0.041999832s").
pub(crate) fn extract_thinking_duration(steps: &[serde_json::Value]) -> Option<String> {
for step in steps.iter().rev() {
let step_type = step["type"].as_str().unwrap_or("");
if step_type.contains("PLANNER_RESPONSE") {
if let Some(dur) = step["plannerResponse"]["thinkingDuration"].as_str() {
if !dur.is_empty() {
return Some(dur.to_string());
}
}
}
}
None
}
/// Check whether the cascade has truly finished. Only the very last step is inspected:
/// it must be a CHECKPOINT, or a response step (PLANNER_RESPONSE, AI_RESPONSE, or
/// MODEL_RESPONSE) whose status contains DONE.
/// This prevents false positives during agentic tool-call loops where intermediate
/// PLANNER_RESPONSE steps show DONE but the cascade keeps going.
pub(crate) fn is_response_done(steps: &[serde_json::Value]) -> bool {
if steps.is_empty() {
return false;
}
let last = &steps[steps.len() - 1];
let last_type = last["type"].as_str().unwrap_or("");
let last_status = last["status"].as_str().unwrap_or("");
// CHECKPOINT at the end = cascade is definitely done
if last_type.contains("CHECKPOINT") {
return true;
}
// Last step is a response step (PLANNER/AI/MODEL) with DONE = final answer (no more tool calls coming)
if (last_type.contains("PLANNER_RESPONSE")
|| last_type.contains("AI_RESPONSE")
|| last_type.contains("MODEL_RESPONSE"))
&& last_status.contains("DONE")
{
return true;
}
false
}
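/// Poll the backend until a response arrives or `timeout` seconds elapse.
/// On timeout, the returned `text` is a placeholder message rather than an error.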
pub(crate) async fn poll_for_response(
state: &AppState,
cascade_id: &str,
timeout: u64,
) -> PollResult {
let start = std::time::Instant::now();
let short_id = &cascade_id[..8.min(cascade_id.len())];
info!("Polling for response on cascade {short_id} (timeout={timeout}s)");
let mut last_step_count: usize = 0;
while start.elapsed().as_secs() < timeout {
if let Ok((status, data)) = state.backend.get_steps(cascade_id).await {
if status == 200 {
if let Some(steps) = data["steps"].as_array() {
let step_count = steps.len();
// Only log when step count changes (denoised)
if step_count != last_step_count {
// Compact type summary: count steps per type
let mut type_counts: std::collections::BTreeMap<&str, usize> =
std::collections::BTreeMap::new();
for s in steps.iter() {
let raw = s["type"].as_str().unwrap_or("?");
// Strip the common prefix, but keep the raw type when the prefix is absent.
let t = raw.strip_prefix("CORTEX_STEP_TYPE_").unwrap_or(raw);
*type_counts.entry(t).or_insert(0) += 1;
}
let summary: Vec<String> = type_counts
.iter()
.map(|(t, c)| {
if *c > 1 {
format!("{t}×{c}")
} else {
t.to_string()
}
})
.collect();
debug!(
"Poll {short_id}: {step_count} steps [{}]",
summary.join(", ")
);
last_step_count = step_count;
}
// Check if the cascade is truly done
if is_response_done(steps) {
let text = extract_response_text(steps);
if !text.is_empty() {
let usage = extract_model_usage(steps);
let thinking_signature = extract_thinking_signature(steps);
let thinking = extract_thinking_content(steps);
let thinking_duration = extract_thinking_duration(steps);
let elapsed = start.elapsed().as_secs_f32();
if let Some(ref u) = usage {
info!(
"Response done ({short_id}), {:.1}s, {} chars, tokens: {}in/{}out ({}){}{}",
elapsed, text.len(), u.input_tokens, u.output_tokens, u.model,
thinking.as_ref().map(|t| format!(", thinking: {} chars", t.len())).unwrap_or_default(),
if thinking_signature.is_some() { ", has sig" } else { "" }
);
} else {
info!(
"Response done ({short_id}), {:.1}s, {} chars (no usage){}{}",
elapsed, text.len(),
thinking.as_ref().map(|t| format!(", thinking: {} chars", t.len())).unwrap_or_default(),
if thinking_signature.is_some() { ", has sig" } else { "" }
);
}
return PollResult { text, usage, thinking_signature, thinking, thinking_duration };
}
}
// Fallback: check trajectory IDLE status (catches edge cases)
// To limit extra network calls, this only runs while the step count is a multiple of 5 (and above 4)
if step_count > 4 && step_count % 5 == 0 {
if let Ok((ts, td)) = state.backend.get_trajectory(cascade_id).await
{
if ts == 200 {
let run_status =
td["status"].as_str().unwrap_or("");
if run_status.contains("IDLE") {
let text = extract_response_text(steps);
if !text.is_empty() {
let usage = extract_model_usage(steps);
let thinking_signature = extract_thinking_signature(steps);
let thinking = extract_thinking_content(steps);
let thinking_duration = extract_thinking_duration(steps);
let elapsed = start.elapsed().as_secs_f32();
info!(
"Trajectory IDLE ({short_id}), {:.1}s, {} chars",
elapsed,
text.len()
);
return PollResult { text, usage, thinking_signature, thinking, thinking_duration };
}
}
}
}
}
}
}
}
let poll_ms: u64 = rand::thread_rng().gen_range(1000..1800);
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
}
warn!("Timeout after {timeout}s on cascade {short_id}");
PollResult {
text: "[Timeout waiting for AI response]".to_string(),
usage: None,
thinking_signature: None,
thinking: None,
thinking_duration: None,
}
}

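The completion rule in `is_response_done` can be sketched without JSON. This is an illustrative reduction, not the source implementation: `Step` is a hypothetical stand-in for one entry of the `steps` array, and only the two fields the check reads are modelled.

```rust
/// Hypothetical stand-in for one entry of the LS `steps` array
/// (the real code reads these fields from `serde_json::Value`).
pub struct Step<'a> {
    pub step_type: &'a str,
    pub status: &'a str,
}

/// Only the final step decides: a trailing CHECKPOINT, or a trailing
/// response step whose status contains DONE.
pub fn is_done(steps: &[Step]) -> bool {
    let Some(last) = steps.last() else {
        return false; // no steps yet, so nothing can be finished
    };
    if last.step_type.contains("CHECKPOINT") {
        return true;
    }
    let is_response = ["PLANNER_RESPONSE", "AI_RESPONSE", "MODEL_RESPONSE"]
        .iter()
        .any(|kind| last.step_type.contains(*kind));
    is_response && last.status.contains("DONE")
}
```

Because earlier steps are ignored, a DONE planner response followed by any further step is treated as still running, which is the behavior the doc comment describes for tool-call loops.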
686
src/api/responses.rs Normal file

@@ -0,0 +1,686 @@
//! OpenAI Responses API (/v1/responses) handler.
//!
//! Strictly adheres to the official OpenAI Responses API protocol:
//! https://platform.openai.com/docs/api-reference/responses
use axum::{
extract::State,
http::StatusCode,
response::{sse::Event, IntoResponse, Json, Sse},
};
use rand::Rng;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use tracing::{debug, info};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content, extract_thinking_duration};
use super::types::*;
use super::util::{err_response, now_unix, responses_sse_event};
use super::AppState;
// ─── Input extraction ────────────────────────────────────────────────────────
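/// Extract the prompt text from the Responses API `input` field.
/// For array input, only the most recent `user` item is used; its text parts are
/// joined with spaces. Non-empty `instructions` are prepended, separated by a blank line.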
fn extract_responses_input(input: &serde_json::Value, instructions: Option<&str>) -> String {
let user_text = match input {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(items) => {
items
.iter()
.rev()
.find(|item| item["role"].as_str() == Some("user"))
.and_then(|item| match &item["content"] {
serde_json::Value::String(s) => Some(s.clone()),
serde_json::Value::Array(parts) => Some(
parts
.iter()
.filter(|p| {
let t = p["type"].as_str().unwrap_or("");
t == "input_text" || t == "text"
})
.filter_map(|p| p["text"].as_str())
.collect::<Vec<_>>()
.join(" "),
),
_ => None,
})
.unwrap_or_default()
}
_ => String::new(),
};
match instructions {
Some(inst) if !inst.is_empty() => format!("{inst}\n\n{user_text}"),
_ => user_text,
}
}
/// Extract conversation/session ID from Responses API `conversation` field.
fn extract_conversation_id(conv: &Option<serde_json::Value>) -> Option<String> {
match conv {
Some(serde_json::Value::String(s)) => Some(s.clone()),
Some(obj) => obj["id"].as_str().map(|s| s.to_string()),
None => None,
}
}
/// Build a full Response object matching the official OpenAI schema.
fn build_response_object(
id: &str,
model: &str,
status: &'static str,
created_at: u64,
completed_at: Option<u64>,
output: Vec<ResponseOutput>,
usage: Option<Usage>,
instructions: Option<&str>,
store: bool,
temperature: f64,
top_p: f64,
max_output_tokens: Option<u64>,
previous_response_id: Option<&str>,
user: Option<&str>,
metadata: &serde_json::Value,
thinking_signature: Option<String>,
thinking: Option<String>,
thinking_duration: Option<String>,
) -> ResponsesResponse {
ResponsesResponse {
id: id.to_string(),
object: "response",
created_at,
status,
completed_at,
error: None,
incomplete_details: None,
instructions: instructions.map(|s| s.to_string()),
max_output_tokens,
model: model.to_string(),
output,
parallel_tool_calls: true,
previous_response_id: previous_response_id.map(|s| s.to_string()),
reasoning: Reasoning::default(),
store,
temperature,
text: TextFormat::default(),
tool_choice: "auto",
tools: vec![],
top_p,
truncation: "disabled",
usage,
user: user.map(|s| s.to_string()),
metadata: metadata.clone(),
thinking_signature,
thinking,
thinking_duration,
}
}
/// Serialize a ResponsesResponse to serde_json::Value for SSE embedding.
fn response_to_json(resp: &ResponsesResponse) -> serde_json::Value {
serde_json::to_value(resp).unwrap_or(serde_json::json!({}))
}
// ─── Handler ─────────────────────────────────────────────────────────────────
pub(crate) async fn handle_responses(
State(state): State<Arc<AppState>>,
Json(body): Json<ResponsesRequest>,
) -> axum::response::Response {
info!(
"POST /v1/responses model={} stream={}",
body.model.as_deref().unwrap_or(DEFAULT_MODEL),
body.stream
);
let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
let model = match lookup_model(model_name) {
Some(m) => m,
None => {
let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
return err_response(
StatusCode::BAD_REQUEST,
format!("Unknown model: {model_name}. Available: {names:?}"),
"invalid_request_error",
);
}
};
let token = state.backend.oauth_token().await;
if token.is_empty() {
return err_response(
StatusCode::UNAUTHORIZED,
"No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
"authentication_error",
);
}
let user_text = extract_responses_input(&body.input, body.instructions.as_deref());
if user_text.is_empty() {
return err_response(
StatusCode::BAD_REQUEST,
"No user input found".to_string(),
"invalid_request_error",
);
}
let response_id = format!(
"resp_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")
);
// Session/conversation management
let session_id_str = extract_conversation_id(&body.conversation);
let cascade_id = if let Some(ref sid) = session_id_str {
match state
.sessions
.get_or_create(Some(sid), || state.backend.create_cascade())
.await
{
Ok(sr) => sr.cascade_id,
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("StartCascade failed: {e}"),
"server_error",
);
}
}
} else {
match state.backend.create_cascade().await {
Ok(cid) => cid,
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("StartCascade failed: {e}"),
"server_error",
);
}
}
};
// Send message
match state
.backend
.send_message(&cascade_id, &user_text, model.model_enum)
.await
{
Ok((status, _)) if status == 200 => {
let bg = Arc::clone(&state.backend);
let cid = cascade_id.clone();
tokio::spawn(async move {
let _ = bg.update_annotations(&cid).await;
});
}
Ok((status, _)) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("Antigravity returned {status}"),
"server_error",
);
}
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("Send message failed: {e}"),
"server_error",
);
}
}
// Capture request params for response building
let req_params = RequestParams {
user_text: user_text.clone(),
instructions: body.instructions.clone(),
store: body.store,
temperature: body.temperature.unwrap_or(1.0),
top_p: body.top_p.unwrap_or(1.0),
max_output_tokens: body.max_output_tokens,
previous_response_id: body.previous_response_id.clone(),
user: body.user.clone(),
metadata: body.metadata.clone().unwrap_or(serde_json::json!({})),
};
if body.stream {
handle_responses_stream(
state, response_id, model_name.to_string(), cascade_id,
body.timeout, req_params,
)
.await
} else {
handle_responses_sync(
state, response_id, model_name.to_string(), cascade_id,
body.timeout, req_params,
)
.await
}
}
/// Captured request parameters needed to echo back in the response.
struct RequestParams {
user_text: String,
instructions: Option<String>,
store: bool,
temperature: f64,
top_p: f64,
max_output_tokens: Option<u64>,
previous_response_id: Option<String>,
user: Option<String>,
metadata: serde_json::Value,
}
/// Build Usage from the best available source:
/// 1. MITM intercepted data (real API tokens, including cache stats)
/// 2. LS trajectory data (real tokens, no cache info)
/// 3. Estimation from text lengths (fallback)
async fn usage_from_poll(
mitm_store: &crate::mitm::store::MitmStore,
cascade_id: &str,
model_usage: &Option<super::polling::ModelUsage>,
input_text: &str,
output_text: &str,
) -> Usage {
// Priority 1: MITM intercepted data (most accurate — includes cache tokens)
if let Some(mitm_usage) = mitm_store.take_usage(cascade_id).await {
tracing::debug!(
input = mitm_usage.input_tokens,
output = mitm_usage.output_tokens,
cache_read = mitm_usage.cache_read_input_tokens,
cache_create = mitm_usage.cache_creation_input_tokens,
thinking = mitm_usage.thinking_output_tokens,
"Using MITM intercepted usage"
);
return Usage {
input_tokens: mitm_usage.input_tokens,
input_tokens_details: InputTokensDetails {
cached_tokens: mitm_usage.cache_read_input_tokens,
},
output_tokens: mitm_usage.output_tokens,
output_tokens_details: OutputTokensDetails {
reasoning_tokens: mitm_usage.thinking_output_tokens,
},
total_tokens: mitm_usage.input_tokens + mitm_usage.output_tokens,
};
}
// Priority 2: LS trajectory data (from CHECKPOINT/metadata steps)
if let Some(u) = model_usage {
return Usage {
input_tokens: u.input_tokens,
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
output_tokens: u.output_tokens,
output_tokens_details: OutputTokensDetails { reasoning_tokens: 0 },
total_tokens: u.input_tokens + u.output_tokens,
};
}
// Priority 3: Estimate from text lengths
Usage::estimate(input_text, output_text)
}
// ─── Sync response ───────────────────────────────────────────────────────────
async fn handle_responses_sync(
state: Arc<AppState>,
response_id: String,
model_name: String,
cascade_id: String,
timeout: u64,
params: RequestParams,
) -> axum::response::Response {
let created_at = now_unix();
let poll_result = poll_for_response(&state, &cascade_id, timeout).await;
let completed_at = now_unix();
let msg_id = format!(
"msg_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")
);
let usage = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, &params.user_text, &poll_result.text).await;
let resp = build_response_object(
&response_id,
&model_name,
"completed",
created_at,
Some(completed_at),
vec![ResponseOutput {
output_type: "message",
id: msg_id,
status: "completed",
role: "assistant",
content: vec![OutputContent {
content_type: "output_text",
text: poll_result.text,
annotations: vec![],
}],
}],
Some(usage),
params.instructions.as_deref(),
params.store,
params.temperature,
params.top_p,
params.max_output_tokens,
params.previous_response_id.as_deref(),
params.user.as_deref(),
&params.metadata,
poll_result.thinking_signature,
poll_result.thinking,
poll_result.thinking_duration,
);
Json(resp).into_response()
}
// ─── Streaming response ─────────────────────────────────────────────────────
async fn handle_responses_stream(
state: Arc<AppState>,
response_id: String,
model_name: String,
cascade_id: String,
timeout: u64,
params: RequestParams,
) -> axum::response::Response {
let stream = async_stream::stream! {
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
let created_at = now_unix();
let seq = AtomicU32::new(0);
let next_seq = || seq.fetch_add(1, Ordering::Relaxed);
const CONTENT_IDX: u32 = 0;
const OUTPUT_IDX: u32 = 0;
// Build the in-progress response shell (no output yet)
let in_progress_resp = build_response_object(
&response_id, &model_name, "in_progress", created_at, None,
vec![], None,
params.instructions.as_deref(), params.store,
params.temperature, params.top_p,
params.max_output_tokens, params.previous_response_id.as_deref(),
params.user.as_deref(), &params.metadata,
None, None, None,
);
let resp_json = response_to_json(&in_progress_resp);
// 1. response.created
yield Ok::<_, std::convert::Infallible>(responses_sse_event(
"response.created",
serde_json::json!({
"type": "response.created",
"sequence_number": next_seq(),
"response": resp_json,
}),
));
// 2. response.in_progress
yield Ok(responses_sse_event(
"response.in_progress",
serde_json::json!({
"type": "response.in_progress",
"sequence_number": next_seq(),
"response": resp_json,
}),
));
// 3. response.output_item.added
yield Ok(responses_sse_event(
"response.output_item.added",
serde_json::json!({
"type": "response.output_item.added",
"sequence_number": next_seq(),
"output_index": OUTPUT_IDX,
"item": {
"type": "message",
"id": &msg_id,
"status": "in_progress",
"role": "assistant",
"content": [],
}
}),
));
// 4. response.content_part.added
yield Ok(responses_sse_event(
"response.content_part.added",
serde_json::json!({
"type": "response.content_part.added",
"sequence_number": next_seq(),
"output_index": OUTPUT_IDX,
"content_index": CONTENT_IDX,
"part": {
"type": "output_text",
"text": "",
"annotations": [],
}
}),
));
// 5. Poll and emit text deltas
let start = std::time::Instant::now();
let mut last_text = String::new();
while start.elapsed().as_secs() < timeout {
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
if status == 200 {
if let Some(steps) = data["steps"].as_array() {
let text = extract_response_text(steps);
if !text.is_empty() && text != last_text {
let new_content = if text.len() > last_text.len()
&& text.starts_with(&*last_text)
{
&text[last_text.len()..]
} else {
&text
};
if !new_content.is_empty() {
yield Ok(responses_sse_event(
"response.output_text.delta",
serde_json::json!({
"type": "response.output_text.delta",
"sequence_number": next_seq(),
"item_id": &msg_id,
"output_index": OUTPUT_IDX,
"content_index": CONTENT_IDX,
"delta": new_content,
}),
));
last_text = text.to_string();
}
}
// Check if response is done AND we have text
if is_response_done(steps) && !last_text.is_empty() {
debug!("Response done, text length={}", last_text.len());
let mu = extract_model_usage(steps);
let usage = usage_from_poll(&state.mitm_store, &cascade_id, &mu, &params.user_text, &last_text).await;
let ts = extract_thinking_signature(steps);
let tc = extract_thinking_content(steps);
let td = extract_thinking_duration(steps);
for evt in completion_events(
&response_id, &model_name, &msg_id,
OUTPUT_IDX, CONTENT_IDX, &last_text, usage,
created_at, &seq, &params, ts, tc, td,
) {
yield Ok(evt);
}
return;
}
// IDLE fallback: check trajectory status periodically
let step_count = steps.len();
if step_count > 4 && step_count % 5 == 0 {
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
if ts == 200 {
let run_status = td["status"].as_str().unwrap_or("");
if run_status.contains("IDLE") && !last_text.is_empty() {
debug!("Trajectory IDLE, text length={}", last_text.len());
let mu = extract_model_usage(steps);
let usage = usage_from_poll(&state.mitm_store, &cascade_id, &mu, &params.user_text, &last_text).await;
let ts = extract_thinking_signature(steps);
let tc = extract_thinking_content(steps);
let td = extract_thinking_duration(steps);
for evt in completion_events(
&response_id, &model_name, &msg_id,
OUTPUT_IDX, CONTENT_IDX, &last_text, usage,
created_at, &seq, &params, ts, tc, td,
) {
yield Ok(evt);
}
return;
}
}
}
}
}
}
}
let poll_ms: u64 = rand::thread_rng().gen_range(800..1200);
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
}
// Timeout — emit incomplete response
let timeout_resp = build_response_object(
&response_id, &model_name, "incomplete", created_at, None,
vec![], Some(Usage::estimate(&params.user_text, "")),
params.instructions.as_deref(), params.store,
params.temperature, params.top_p,
params.max_output_tokens, params.previous_response_id.as_deref(),
params.user.as_deref(), &params.metadata,
None, None, None,
);
yield Ok(responses_sse_event(
"response.incomplete",
serde_json::json!({
"type": "response.incomplete",
"sequence_number": next_seq(),
"response": response_to_json(&timeout_resp),
}),
));
};
Sse::new(stream)
.keep_alive(
axum::response::sse::KeepAlive::new()
.interval(std::time::Duration::from_secs(15))
.text(""),
)
.into_response()
}
// ─── SSE completion events ───────────────────────────────────────────────────
/// Build the completion SSE events sequence matching the official protocol:
/// 1. response.output_text.done
/// 2. response.content_part.done
/// 3. response.output_item.done
/// 4. response.completed
fn completion_events(
resp_id: &str,
model: &str,
msg_id: &str,
out_idx: u32,
content_idx: u32,
text: &str,
usage: Usage,
created_at: u64,
seq: &AtomicU32,
params: &RequestParams,
thinking_signature: Option<String>,
thinking: Option<String>,
thinking_duration: Option<String>,
) -> Vec<Event> {
let next_seq = || seq.fetch_add(1, Ordering::Relaxed);
let completed_at = now_unix();
let output_item = serde_json::json!({
"type": "message",
"id": msg_id,
"status": "completed",
"role": "assistant",
"content": [{
"type": "output_text",
"text": text,
"annotations": [],
}],
});
let completed_resp = build_response_object(
resp_id, model, "completed", created_at, Some(completed_at),
vec![ResponseOutput {
output_type: "message",
id: msg_id.to_string(),
status: "completed",
role: "assistant",
content: vec![OutputContent {
content_type: "output_text",
text: text.to_string(),
annotations: vec![],
}],
}],
Some(usage),
params.instructions.as_deref(),
params.store,
params.temperature,
params.top_p,
params.max_output_tokens,
params.previous_response_id.as_deref(),
params.user.as_deref(),
&params.metadata,
thinking_signature,
thinking,
thinking_duration,
);
vec![
// 1. response.output_text.done
responses_sse_event(
"response.output_text.done",
serde_json::json!({
"type": "response.output_text.done",
"sequence_number": next_seq(),
"item_id": msg_id,
"output_index": out_idx,
"content_index": content_idx,
"text": text,
}),
),
// 2. response.content_part.done
responses_sse_event(
"response.content_part.done",
serde_json::json!({
"type": "response.content_part.done",
"sequence_number": next_seq(),
"output_index": out_idx,
"content_index": content_idx,
"part": {
"type": "output_text",
"text": text,
"annotations": [],
},
}),
),
// 3. response.output_item.done
responses_sse_event(
"response.output_item.done",
serde_json::json!({
"type": "response.output_item.done",
"sequence_number": next_seq(),
"output_index": out_idx,
"item": output_item,
}),
),
// 4. response.completed
responses_sse_event(
"response.completed",
serde_json::json!({
"type": "response.completed",
"sequence_number": next_seq(),
"response": response_to_json(&completed_resp),
}),
),
]
}
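
The streaming handler turns successive full-text snapshots into deltas. A minimal sketch of that step follows; `next_delta` is a hypothetical name, and it uses `str::strip_prefix` in place of the explicit length and `starts_with` checks in the source.

```rust
/// Return the text to emit as the next delta, given what was already sent.
/// If `current` extends `sent`, only the new suffix is returned; otherwise
/// (the text was rewritten) the whole of `current` is sent again.
pub fn next_delta<'a>(sent: &str, current: &'a str) -> &'a str {
    match current.strip_prefix(sent) {
        Some(suffix) => suffix, // empty when nothing changed
        None => current,
    }
}
```

An empty return value means there is nothing to emit for this poll. `strip_prefix` only splits at a whole-prefix match, so the suffix always starts on a character boundary.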

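The three-level priority in `usage_from_poll` (intercepted data, then LS data, then an estimate) reduces to an `Option` fallback chain. The sketch below assumes a reduced `Tokens` record and a hypothetical `pick_usage` helper; the real function is async and also fills in cache and reasoning details.

```rust
/// Hypothetical, reduced token record; the source's `Usage` carries more fields.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Tokens {
    pub input: u64,
    pub output: u64,
}

/// Pick the first available source in priority order:
/// intercepted (MITM) data, then LS-reported data, then a lazily computed estimate.
pub fn pick_usage(
    mitm: Option<Tokens>,
    ls: Option<Tokens>,
    estimate: impl FnOnce() -> Tokens,
) -> Tokens {
    mitm.or(ls).unwrap_or_else(estimate)
}
```

Passing the estimate as a closure means it is only computed when both real sources are missing.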
241
src/api/types.rs Normal file

@@ -0,0 +1,241 @@
//! Request/response types for the OpenAI-compatible API.
//!
//! All response shapes strictly match the official OpenAI Responses API spec:
//! https://platform.openai.com/docs/api-reference/responses
use serde::{Deserialize, Serialize};
// ─── Request types ───────────────────────────────────────────────────────────
#[derive(Deserialize)]
pub(crate) struct ResponsesRequest {
pub model: Option<String>,
pub input: serde_json::Value,
#[serde(default)]
pub instructions: Option<String>,
#[serde(default)]
pub stream: bool,
#[serde(default = "default_timeout")]
pub timeout: u64,
pub conversation: Option<serde_json::Value>,
#[serde(default = "default_true")]
pub store: bool,
#[serde(default)]
pub temperature: Option<f64>,
#[serde(default)]
pub top_p: Option<f64>,
#[serde(default)]
pub max_output_tokens: Option<u64>,
#[serde(default)]
pub previous_response_id: Option<String>,
#[serde(default)]
pub metadata: Option<serde_json::Value>,
#[serde(default)]
pub user: Option<String>,
}
/// Chat Completions request (OpenAI-compatible).
#[derive(Deserialize)]
pub(crate) struct CompletionRequest {
pub model: Option<String>,
pub messages: Vec<CompletionMessage>,
#[serde(default)]
pub stream: bool,
#[serde(default = "default_timeout")]
pub timeout: u64,
}
#[derive(Deserialize)]
pub(crate) struct CompletionMessage {
pub role: String,
pub content: serde_json::Value,
}
fn default_timeout() -> u64 {
120
}
fn default_true() -> bool {
true
}
// ─── Response types (official OpenAI Responses API shape) ────────────────────
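/// Top-level Response object. Mirrors the OpenAI shape, plus three proxy-specific
/// extension fields (`thinking_signature`, `thinking`, `thinking_duration`).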
#[derive(Serialize, Clone)]
pub(crate) struct ResponsesResponse {
pub id: String,
pub object: &'static str,
pub created_at: u64,
pub status: &'static str,
#[serde(serialize_with = "serialize_option_u64")]
pub completed_at: Option<u64>,
pub error: Option<serde_json::Value>,
pub incomplete_details: Option<serde_json::Value>,
pub instructions: Option<String>,
#[serde(serialize_with = "serialize_option_u64")]
pub max_output_tokens: Option<u64>,
pub model: String,
pub output: Vec<ResponseOutput>,
pub parallel_tool_calls: bool,
pub previous_response_id: Option<String>,
pub reasoning: Reasoning,
pub store: bool,
pub temperature: f64,
pub text: TextFormat,
pub tool_choice: &'static str,
pub tools: Vec<serde_json::Value>,
pub top_p: f64,
pub truncation: &'static str,
pub usage: Option<Usage>,
pub user: Option<String>,
pub metadata: serde_json::Value,
/// Proxy extension: opaque thinking verification signature.
/// Present for all models. Required for multi-turn chaining with thinking models.
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_signature: Option<String>,
/// Proxy extension: the model's internal reasoning/thinking content.
/// Available for all models (Opus, Gemini Flash, Gemini Pro).
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking: Option<String>,
/// Proxy extension: time spent thinking (e.g. "0.041999832s").
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_duration: Option<String>,
}
#[derive(Serialize, Clone)]
pub(crate) struct ResponseOutput {
#[serde(rename = "type")]
pub output_type: &'static str,
pub id: String,
pub status: &'static str,
pub role: &'static str,
pub content: Vec<OutputContent>,
}
#[derive(Serialize, Clone)]
pub(crate) struct OutputContent {
#[serde(rename = "type")]
pub content_type: &'static str,
pub text: String,
pub annotations: Vec<serde_json::Value>,
}
#[derive(Serialize, Clone)]
pub(crate) struct Usage {
pub input_tokens: u64,
pub input_tokens_details: InputTokensDetails,
pub output_tokens: u64,
pub output_tokens_details: OutputTokensDetails,
pub total_tokens: u64,
}
#[derive(Serialize, Clone)]
pub(crate) struct InputTokensDetails {
pub cached_tokens: u64,
}
#[derive(Serialize, Clone)]
pub(crate) struct OutputTokensDetails {
pub reasoning_tokens: u64,
}
#[derive(Serialize, Clone)]
pub(crate) struct Reasoning {
pub effort: Option<String>,
pub summary: Option<String>,
}
#[derive(Serialize, Clone)]
pub(crate) struct TextFormat {
pub format: TextFormatInner,
}
#[derive(Serialize, Clone)]
pub(crate) struct TextFormatInner {
#[serde(rename = "type")]
pub format_type: &'static str,
}
impl Usage {
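/// Estimate token counts from the text itself.
/// Uses a rough 4-bytes-per-token heuristic, rounded up (a common rule of thumb for
/// English text with GPT-style tokenizers). `str::len` counts UTF-8 bytes, not
/// characters, so non-ASCII text yields a higher estimate.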
pub fn estimate(input_text: &str, output_text: &str) -> Self {
let input_tokens = (input_text.len() as u64 + 3) / 4;
let output_tokens = (output_text.len() as u64 + 3) / 4;
Self {
input_tokens,
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
output_tokens,
output_tokens_details: OutputTokensDetails {
reasoning_tokens: 0,
},
total_tokens: input_tokens + output_tokens,
}
}
}
impl Default for Usage {
fn default() -> Self {
Self {
input_tokens: 0,
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
output_tokens: 0,
output_tokens_details: OutputTokensDetails {
reasoning_tokens: 0,
},
total_tokens: 0,
}
}
}
impl Default for Reasoning {
fn default() -> Self {
Self {
effort: None,
summary: None,
}
}
}
impl Default for TextFormat {
fn default() -> Self {
Self {
format: TextFormatInner {
format_type: "text",
},
}
}
}
// ─── Helpers ─────────────────────────────────────────────────────────────────
/// Serialize Option<u64> as either the number or JSON null (not omitted).
fn serialize_option_u64<S>(val: &Option<u64>, s: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match val {
Some(v) => s.serialize_u64(*v),
None => s.serialize_none(),
}
}
// ─── Shared types ────────────────────────────────────────────────────────────
#[derive(Deserialize)]
pub(crate) struct TokenRequest {
pub token: String,
}
#[derive(Serialize)]
pub(crate) struct ErrorResponse {
pub error: ErrorDetail,
}
#[derive(Serialize)]
pub(crate) struct ErrorDetail {
pub message: String,
#[serde(rename = "type")]
pub error_type: String,
}

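The arithmetic behind `Usage::estimate` is a ceiling division of the byte length by four. A standalone sketch, with `estimate_tokens` as a hypothetical free function extracted for illustration:

```rust
/// Ceiling division of the UTF-8 byte length by 4.
/// Adding 3 before dividing rounds any partial group of bytes up to one token.
pub fn estimate_tokens(text: &str) -> u64 {
    (text.len() as u64 + 3) / 4
}
```

For example, `"abcd"` (4 bytes) counts as one token and `"abcde"` (5 bytes) as two. A string such as `"ééé"` has three characters but six bytes, so it also counts as two tokens.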
36
src/api/util.rs Normal file

@@ -0,0 +1,36 @@
//! Shared utilities for API handlers.
use axum::{
http::StatusCode,
response::{sse::Event, IntoResponse, Json},
};
use std::time::{SystemTime, UNIX_EPOCH};
use super::types::{ErrorDetail, ErrorResponse};
pub(crate) fn err_response(
status: StatusCode,
message: String,
error_type: &str,
) -> axum::response::Response {
let body = ErrorResponse {
error: ErrorDetail {
message,
error_type: error_type.to_string(),
},
};
(status, Json(body)).into_response()
}
pub(crate) fn now_unix() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
pub(crate) fn responses_sse_event(event_type: &str, data: serde_json::Value) -> Event {
Event::default()
.event(event_type)
.data(serde_json::to_string(&data).unwrap())
}

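For orientation, this is roughly what one event built by `responses_sse_event` looks like on the wire. The sketch renders a Server-Sent Events frame by hand using only the standard library; `sse_frame` is a hypothetical helper, and the exact bytes produced by axum's `Event` may differ in details not shown here.

```rust
/// Render one Server-Sent Events frame: an `event:` line, one `data:` line
/// per payload line, and a blank line that terminates the frame.
pub fn sse_frame(event_type: &str, payload: &str) -> String {
    let mut frame = format!("event: {event_type}\n");
    for line in payload.split('\n') {
        frame.push_str("data: ");
        frame.push_str(line);
        frame.push('\n');
    }
    frame.push('\n'); // empty line marks the end of the event
    frame
}
```

Compact JSON contains no raw newlines, so a serialized payload normally occupies a single `data:` line.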
462
src/backend.rs Normal file

@@ -0,0 +1,462 @@
//! Backend: discovery of the local Antigravity language server and HTTP client.
//!
//! Uses wreq (BoringSSL) to impersonate Chrome's TLS + HTTP/2 fingerprint,
//! making our requests indistinguishable from the real Electron webview.
use crate::constants::*;
use flate2::read::{DeflateDecoder, GzDecoder};
use std::fs;
use std::io::Read;
use std::process::Command;
use std::sync::LazyLock;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use wreq::header::{HeaderMap, HeaderName, HeaderValue};
/// Connection details for the local language server.
pub struct Backend {
inner: RwLock<BackendInner>,
client: wreq::Client,
}
struct BackendInner {
pid: String,
csrf: String,
https_port: String,
oauth_token: String,
}
/// Static headers that never change — built once, in Chrome's exact emission order.
///
/// Order matters: wreq preserves insertion order in HTTP/2 HEADERS frames.
/// This matches the order captured from Chrome DevTools on the real webview.
static STATIC_HEADERS: LazyLock<HeaderMap> = LazyLock::new(|| {
let mut h = HeaderMap::with_capacity(14);
// Chrome order: Origin → UA → Accept → Accept-Encoding → Accept-Language
// → sec-ch-ua → sec-ch-ua-mobile → sec-ch-ua-platform
// → Sec-Fetch-Dest → Sec-Fetch-Mode → Sec-Fetch-Site
// → Priority → Connect-Protocol-Version
// (No Referer is inserted below, so it is not part of this list.)
h.insert("Origin", hv("vscode-file://vscode-app"));
h.insert("User-Agent", hv(&USER_AGENT));
h.insert("Accept", hv("*/*"));
h.insert("Accept-Encoding", hv("gzip, deflate, br, zstd"));
h.insert("Accept-Language", hv("en-US"));
h.insert(
HeaderName::from_static("sec-ch-ua"),
hv(&format!(
"\"Not_A Brand\";v=\"99\", \"Chromium\";v=\"{}\"",
*CHROME_MAJOR,
)),
);
h.insert(
HeaderName::from_static("sec-ch-ua-mobile"),
hv("?0"),
);
h.insert(
HeaderName::from_static("sec-ch-ua-platform"),
hv("\"Linux\""),
);
h.insert("Sec-Fetch-Dest", hv("empty"));
h.insert("Sec-Fetch-Mode", hv("cors"));
h.insert("Sec-Fetch-Site", hv("cross-site"));
h.insert("Priority", hv("u=1, i"));
h.insert("Connect-Protocol-Version", hv("1"));
h
});
impl Backend {
/// Discover the running language server and build a BoringSSL-backed connection.
pub fn new() -> Result<Self, String> {
let inner = discover()?;
// wreq with Chrome impersonation: BoringSSL + Chrome JA3/JA4 + H2 fingerprint
let client = wreq::Client::builder()
.emulation(wreq_util::Emulation::Chrome142)
.cert_verification(false) // LS uses self-signed cert
.verify_hostname(false)
.build()
.map_err(|e| format!("wreq client build failed: {e}"))?;
Ok(Self {
inner: RwLock::new(inner),
client,
})
}
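/// Re-discover language server connection details.
/// Runs blocking I/O on a spawn_blocking thread to avoid starving tokio.
///
/// A token set at runtime (see `set_oauth_token`) is kept when discovery
/// finds no token of its own; otherwise the periodic refresh would wipe it.
pub async fn refresh(&self) -> Result<(), String> {
let mut new_inner = tokio::task::spawn_blocking(discover)
.await
.map_err(|e| format!("spawn_blocking failed: {e}"))??;
let mut guard = self.inner.write().await;
if new_inner.oauth_token.is_empty() {
new_inner.oauth_token = std::mem::take(&mut guard.oauth_token);
}
*guard = new_inner;
Ok(())
}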
/// Get current connection info (for startup banner).
pub async fn info(&self) -> (String, String, String, String) {
let guard = self.inner.read().await;
let token_preview = if guard.oauth_token.is_empty() {
"NOT SET".to_string()
} else {
safe_truncate(&guard.oauth_token, 20)
};
let csrf_preview = safe_truncate(&guard.csrf, 8);
(
guard.pid.clone(),
guard.https_port.clone(),
csrf_preview,
token_preview,
)
}
/// Get current OAuth token.
///
/// Priority: token file > env var > cached value.
/// Uses async I/O for file reads. Single write-lock acquisition
/// eliminates the TOCTOU race of read-check-then-write.
pub async fn oauth_token(&self) -> String {
// Check file first (async I/O — won't block tokio)
let token_path = token_file_path();
if let Ok(contents) = tokio::fs::read_to_string(&token_path).await {
let token = contents.trim().to_string();
if !token.is_empty() && token.starts_with("ya29.") {
// Single lock: compare-and-set atomically
let mut guard = self.inner.write().await;
if guard.oauth_token != token {
info!("Token updated from file");
guard.oauth_token = token.clone();
}
return token;
}
}
// Then env var
if let Ok(env_token) = std::env::var("ANTIGRAVITY_OAUTH_TOKEN") {
if !env_token.is_empty() {
let mut guard = self.inner.write().await;
if guard.oauth_token != env_token {
info!("Token updated from env var");
guard.oauth_token = env_token.clone();
}
return env_token;
}
}
self.inner.read().await.oauth_token.clone()
}
/// Fire-and-forget: update conversation annotations alongside SendUserCascadeMessage.
///
/// The real webview calls this after every message to track lastUserViewTime.
/// Without it, the LS sees messages without annotation updates — a fingerprint.
pub async fn update_annotations(&self, cascade_id: &str) -> Result<(), String> {
let now = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true);
let body = serde_json::json!({
"cascadeId": cascade_id,
"annotations": {
"lastUserViewTime": now
},
"mergeAnnotations": true
});
match self.call_json("UpdateConversationAnnotations", &body).await {
Ok((status, _)) => {
debug!("UpdateConversationAnnotations: {status}");
Ok(())
}
Err(e) => {
warn!("UpdateConversationAnnotations failed: {e}");
Err(e)
}
}
}
/// Set OAuth token at runtime.
pub async fn set_oauth_token(&self, token: String) {
let mut guard = self.inner.write().await;
guard.oauth_token = token;
}
// ─── RPC calls ──────────────────────────────────────────────────────
/// Common headers: clone cached static + insert per-request CSRF.
fn common_headers(csrf: &str) -> HeaderMap {
let mut h = STATIC_HEADERS.clone();
if let Ok(val) = HeaderValue::from_str(csrf) {
h.insert(
HeaderName::from_static("x-codeium-csrf-token"),
val,
);
} else {
warn!("CSRF token contains invalid header characters, omitting");
}
h
}
/// Call a JSON RPC method on the language server.
pub async fn call_json(
&self,
method: &str,
body: &serde_json::Value,
) -> Result<(u16, serde_json::Value), String> {
let (base, csrf) = {
let guard = self.inner.read().await;
(
format!("https://127.0.0.1:{}", guard.https_port),
guard.csrf.clone(),
)
};
let url = format!("{base}/{LS_SERVICE}/{method}");
let mut headers = Self::common_headers(&csrf);
headers.insert("Content-Type", HeaderValue::from_static("application/json"));
let body_bytes = serde_json::to_vec(body)
.map_err(|e| format!("JSON serialize error: {e}"))?;
let resp = self
.client
.post(&url)
.headers(headers)
.body(body_bytes)
.send()
.await
.map_err(|e| format!("HTTP error: {e}"))?;
let status = resp.status().as_u16();
let encoding = resp
.headers()
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let raw = resp.bytes().await
.map_err(|e| format!("Read body error: {e}"))?;
let resp_bytes = decompress(method, &raw, &encoding);
tracing::debug!(
"{method} response ({status}, {} bytes, enc={encoding})",
resp_bytes.len(),
);
tracing::trace!(
"{method} body: {}",
String::from_utf8_lossy(&resp_bytes[..resp_bytes.len().min(200)])
);
let data: serde_json::Value = match serde_json::from_slice(&resp_bytes) {
Ok(v) => v,
Err(e) => {
tracing::warn!("{method} response is not valid JSON: {e}");
serde_json::Value::Object(serde_json::Map::new())
}
};
Ok((status, data))
}
/// Call a binary protobuf RPC method.
pub async fn call_proto(
&self,
method: &str,
body: Vec<u8>,
) -> Result<(u16, Vec<u8>), String> {
let (base, csrf) = {
let guard = self.inner.read().await;
(
format!("https://127.0.0.1:{}", guard.https_port),
guard.csrf.clone(),
)
};
let url = format!("{base}/{LS_SERVICE}/{method}");
let mut headers = Self::common_headers(&csrf);
headers.insert("Content-Type", HeaderValue::from_static("application/proto"));
let resp = self
.client
.post(&url)
.headers(headers)
.body(body)
.send()
.await
.map_err(|e| format!("HTTP error: {e}"))?;
let status = resp.status().as_u16();
let encoding = resp
.headers()
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let raw = resp
.bytes()
.await
.map_err(|e| format!("Read body error: {e}"))?;
let decompressed = decompress(method, &raw, &encoding);
Ok((status, decompressed))
}
/// StartCascade → returns cascade_id.
pub async fn create_cascade(&self) -> Result<String, String> {
let body = serde_json::json!({"prompt": "new chat"});
let (status, data) = self.call_json("StartCascade", &body).await?;
if status != 200 {
return Err(format!("StartCascade failed: {status}: {data}"));
}
tracing::debug!("StartCascade response: {data}");
data["cascadeId"]
.as_str()
.map(|s| s.to_string())
.ok_or_else(|| format!("Missing cascadeId in response: {data}"))
}
/// SendUserCascadeMessage with binary protobuf body.
pub async fn send_message(
&self,
cascade_id: &str,
text: &str,
model_enum: u32,
) -> Result<(u16, Vec<u8>), String> {
let token = self.oauth_token().await;
if token.is_empty() {
return Err("No OAuth token available".to_string());
}
let proto = crate::proto::build_request(cascade_id, text, &token, model_enum);
self.call_proto("SendUserCascadeMessage", proto).await
}
/// GetCascadeTrajectorySteps → JSON with steps array.
pub async fn get_steps(
&self,
cascade_id: &str,
) -> Result<(u16, serde_json::Value), String> {
let body = serde_json::json!({"cascadeId": cascade_id});
self.call_json("GetCascadeTrajectorySteps", &body).await
}
/// GetCascadeTrajectory → JSON with trajectory status.
pub async fn get_trajectory(
&self,
cascade_id: &str,
) -> Result<(u16, serde_json::Value), String> {
let body = serde_json::json!({"cascadeId": cascade_id});
self.call_json("GetCascadeTrajectory", &body).await
}
}
// ─── Discovery helpers ───────────────────────────────────────────────────────
fn discover() -> Result<BackendInner, String> {
let pid_output = Command::new("sh")
// The "[x]" keeps `pgrep -f` from matching this `sh -c` wrapper's own command line.
.args(["-c", "pgrep -f 'language_server_linu[x]' | head -1"])
.output()
.map_err(|e| format!("pgrep failed: {e}"))?;
let pid = String::from_utf8_lossy(&pid_output.stdout)
.trim()
.to_string();
if pid.is_empty() {
return Err("Language server not running".to_string());
}
let cmdline = fs::read(format!("/proc/{pid}/cmdline"))
.map_err(|e| format!("Can't read cmdline for PID {pid}: {e}"))?;
let args: Vec<&[u8]> = cmdline.split(|&b| b == 0).collect();
let mut csrf = String::new();
for (i, arg) in args.iter().enumerate() {
if let Ok(s) = std::str::from_utf8(arg) {
if s == "--csrf_token" {
if let Some(next) = args.get(i + 1) {
if let Ok(token) = std::str::from_utf8(next) {
csrf = token.to_string();
}
}
}
}
}
let csrf_preview = safe_truncate(&csrf, 8);
debug!("Discovered LS PID={pid}, CSRF={csrf_preview}");
let log_base = log_base();
let mut https_port = String::new();
if let Ok(mut entries) = fs::read_dir(&log_base) {
let mut dirs: Vec<String> = Vec::new();
while let Some(Ok(entry)) = entries.next() {
let name = entry.file_name().to_string_lossy().to_string();
if name.starts_with("202") {
dirs.push(name);
}
}
dirs.sort_unstable_by(|a, b| b.cmp(a));
static PORT_RE: LazyLock<regex::Regex> =
LazyLock::new(|| regex::Regex::new(r"port at (\d+) for HTTPS").unwrap());
for d in &dirs {
let log_path = format!(
"{log_base}/{d}/window1/exthost/google.antigravity/Antigravity.log"
);
if let Ok(contents) = fs::read_to_string(&log_path) {
for line in contents.lines() {
if line.contains(&pid) && line.contains("listening") && line.contains("HTTPS") {
if let Some(caps) = PORT_RE.captures(line) {
https_port = caps[1].to_string();
}
}
}
if !https_port.is_empty() {
break;
}
}
}
}
if https_port.is_empty() {
warn!("Could not find HTTPS port in logs, defaulting to 3100");
https_port = "3100".to_string();
}
// Same priority as Backend::oauth_token(): token file first, then env var.
let oauth_token = fs::read_to_string(token_file_path())
.ok()
.map(|s| s.trim().to_string())
.filter(|s| s.starts_with("ya29."))
.or_else(|| {
std::env::var("ANTIGRAVITY_OAUTH_TOKEN")
.ok()
.filter(|s| !s.is_empty())
})
.unwrap_or_default();
Ok(BackendInner {
pid,
csrf,
https_port,
oauth_token,
})
}
/// Shorthand for HeaderValue (panics on invalid — only for known-safe static values).
fn hv(s: &str) -> HeaderValue {
HeaderValue::from_str(s).expect("invalid header value in static constant")
}
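/// Decompress response bytes based on Content-Encoding header.
///
/// Handles gzip, deflate, and br. Any other encoding (including zstd, which
/// STATIC_HEADERS advertises) is passed through unchanged with a warning.
fn decompress(method: &str, data: &[u8], encoding: &str) -> Vec<u8> {
let mut out = Vec::new();
let res = match encoding {
"gzip" => GzDecoder::new(data).read_to_end(&mut out),
// HTTP "deflate" is zlib-wrapped (RFC 9110); fall back to raw deflate
// for servers that omit the wrapper.
"deflate" => flate2::read::ZlibDecoder::new(data)
.read_to_end(&mut out)
.or_else(|_| {
out.clear();
DeflateDecoder::new(data).read_to_end(&mut out)
}),
"br" => brotli::Decompressor::new(data, 4096).read_to_end(&mut out),
"" | "identity" => return data.to_vec(),
other => {
warn!("{method}: unsupported Content-Encoding {other:?}, passing body through");
return data.to_vec();
}
};
match res {
Ok(_) => out,
Err(e) => {
let preview = String::from_utf8_lossy(&data[..data.len().min(100)]);
warn!("{method}: {encoding} decompress failed ({} bytes): {e}. Raw: {}", data.len(), preview);
data.to_vec()
}
}
}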
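
The CSRF lookup in `discover()` walks the NUL-separated `/proc/<pid>/cmdline` buffer and takes the argument that follows `--csrf_token`. A minimal standalone sketch of that step is below. The helper name `flag_value` is hypothetical and not part of this commit, and unlike the loop above it returns the first occurrence and treats an empty value as missing.

```rust
/// Hypothetical helper (not in the commit): return the argument that follows
/// `flag` in a NUL-separated argv buffer such as /proc/<pid>/cmdline.
fn flag_value(cmdline: &[u8], flag: &str) -> Option<String> {
    let mut args = cmdline.split(|&b| b == 0);
    while let Some(arg) = args.next() {
        if arg == flag.as_bytes() {
            // The value is the next argument. Skip it when it is not valid
            // UTF-8 or when it is empty (the flag was the last argument).
            return args
                .next()
                .and_then(|v| std::str::from_utf8(v).ok())
                .filter(|v| !v.is_empty())
                .map(str::to_string);
        }
    }
    None
}
```

Called as `flag_value(&cmdline, "--csrf_token")`, this would yield `Some(token)` for a command line like `language_server --csrf_token abc123`, and `None` when the flag is absent.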

217
src/constants.rs Normal file
View File

@@ -0,0 +1,217 @@
//! Shared constants — auto-detected from the installed Antigravity binary at startup.
//!
//! On first access, we locate the Antigravity installation (via the running
//! language server PID or well-known paths), parse `product.json` for version
//! strings, and extract Chrome/Electron versions from the binary. If detection
//! fails, we fall back to hardcoded values.
use std::fs;
use std::process::Command;
use std::sync::LazyLock;
/// Auto-detected version info from the installed Antigravity app.
struct DetectedVersions {
antigravity: String,
chrome: String,
electron: String,
client: String,
}
/// Locate the Antigravity install directory by tracing the language server PID
/// back to its binary, then walking up to the app root. Falls back to
/// well-known install paths.
fn find_install_dir() -> Option<String> {
// 1. Try tracing the running language server → /usr/share/antigravity/resources/app/extensions/...
if let Ok(output) = Command::new("sh")
// The "[x]" keeps `pgrep -f` from matching this `sh -c` wrapper's own command line.
.args(["-c", "pgrep -f 'language_server_linu[x]' | head -1"])
.output()
{
let pid = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !pid.is_empty() {
if let Ok(exe) = fs::read_link(format!("/proc/{pid}/exe")) {
let exe_str = exe.to_string_lossy().to_string();
// exe is like: /usr/share/antigravity/resources/app/extensions/antigravity/bin/language_server_linux_x64
// We want: /usr/share/antigravity
if let Some(idx) = exe_str.find("/resources/") {
return Some(exe_str[..idx].to_string());
}
}
}
}
// 2. Fall back to well-known install paths
for path in &["/usr/share/antigravity", "/opt/Antigravity"] {
if fs::metadata(format!("{path}/resources/app/product.json")).is_ok() {
return Some(path.to_string());
}
}
None
}
/// Read `product.json` from the install dir and extract version fields.
fn read_product_json(install_dir: &str) -> (Option<String>, Option<String>) {
let path = format!("{install_dir}/resources/app/product.json");
let Ok(contents) = fs::read_to_string(&path) else {
return (None, None);
};
let Ok(json) = serde_json::from_str::<serde_json::Value>(&contents) else {
return (None, None);
};
let version = json["version"].as_str().map(|s| s.to_string());
let ide_version = json["ideVersion"].as_str().map(|s| s.to_string());
(version, ide_version)
}
/// Extract Chrome and Electron versions from the main binary via `strings`.
/// Pattern: "Chrome/142.0.7444.175", "Electron/39.2.3".
fn extract_binary_versions(install_dir: &str) -> (Option<String>, Option<String>) {
let binary = format!("{install_dir}/antigravity");
if fs::metadata(&binary).is_err() {
return (None, None);
}
// Use grep -oP on the binary to avoid loading the whole thing into memory
let chrome = Command::new("sh")
.args([
"-c",
&format!(
"strings '{}' | grep -oP 'Chrome/[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+' | head -1",
binary
),
])
.output()
.ok()
.and_then(|o| {
let s = String::from_utf8_lossy(&o.stdout).trim().to_string();
s.strip_prefix("Chrome/").map(|v| v.to_string())
});
let electron = Command::new("sh")
.args([
"-c",
&format!(
"strings '{}' | grep -oP 'Electron/[0-9]+\\.[0-9]+\\.[0-9]+' | head -1",
binary
),
])
.output()
.ok()
.and_then(|o| {
let s = String::from_utf8_lossy(&o.stdout).trim().to_string();
s.strip_prefix("Electron/").map(|v| v.to_string())
});
(chrome, electron)
}
/// Detect all versions from the installed Antigravity app.
fn detect_versions() -> DetectedVersions {
// Hardcoded fallbacks — last known good values
const FALLBACK_ANTIGRAVITY: &str = "1.107.0";
const FALLBACK_CHROME: &str = "142.0.7444.175";
const FALLBACK_ELECTRON: &str = "39.2.3";
const FALLBACK_CLIENT: &str = "1.16.5";
let Some(install_dir) = find_install_dir() else {
eprintln!(
"[constants] ⚠ Could not find Antigravity install — using fallback versions"
);
return DetectedVersions {
antigravity: FALLBACK_ANTIGRAVITY.to_string(),
chrome: FALLBACK_CHROME.to_string(),
electron: FALLBACK_ELECTRON.to_string(),
client: FALLBACK_CLIENT.to_string(),
};
};
// product.json → antigravity version + client/IDE version
let (ag_ver, client_ver) = read_product_json(&install_dir);
// Binary → Chrome + Electron versions
let (chrome_ver, electron_ver) = extract_binary_versions(&install_dir);
let versions = DetectedVersions {
antigravity: ag_ver.unwrap_or_else(|| FALLBACK_ANTIGRAVITY.to_string()),
chrome: chrome_ver.unwrap_or_else(|| FALLBACK_CHROME.to_string()),
electron: electron_ver.unwrap_or_else(|| FALLBACK_ELECTRON.to_string()),
client: client_ver.unwrap_or_else(|| FALLBACK_CLIENT.to_string()),
};
eprintln!(
"[constants] ✓ Detected versions: Antigravity={}, Chrome={}, Electron={}, Client={}",
versions.antigravity, versions.chrome, versions.electron, versions.client
);
versions
}
// ─── Public API ──────────────────────────────────────────────────────────────
/// All detected versions — computed once on first access.
static VERSIONS: LazyLock<DetectedVersions> = LazyLock::new(detect_versions);
/// Antigravity app version (e.g. "1.107.0").
pub fn antigravity_version() -> &'static str {
&VERSIONS.antigravity
}
/// Chrome version bundled with Electron (e.g. "142.0.7444.175").
pub fn chrome_version() -> &'static str {
&VERSIONS.chrome
}
/// Electron version (e.g. "39.2.3").
pub fn electron_version() -> &'static str {
&VERSIONS.electron
}
/// Client/IDE version from product.json (e.g. "1.16.5").
pub fn client_version() -> &'static str {
&VERSIONS.client
}
pub const CLIENT_NAME: &str = "antigravity";
pub const LS_SERVICE: &str = "exa.language_server_pb.LanguageServerService";
/// Log base directory for Antigravity.
pub fn log_base() -> String {
let home = std::env::var("HOME").unwrap_or_else(|_| "/root".to_string());
format!("{home}/.config/Antigravity/logs")
}
/// Token file path.
pub fn token_file_path() -> String {
let home = std::env::var("HOME").unwrap_or_else(|_| "/root".to_string());
format!("{home}/.config/antigravity-proxy-token")
}
/// User-Agent string matching the Electron webview — computed once.
pub static USER_AGENT: LazyLock<String> = LazyLock::new(|| {
format!(
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 \
(KHTML, like Gecko) Antigravity/{} \
Chrome/{} Electron/{} Safari/537.36",
antigravity_version(),
chrome_version(),
electron_version()
)
});
/// Chrome major version for sec-ch-ua header — computed once.
pub static CHROME_MAJOR: LazyLock<String> = LazyLock::new(|| {
chrome_version()
.split('.')
.next()
.unwrap_or("142")
.to_string()
});
/// Safely truncate a string to at most `max` characters (not bytes),
/// appending "..." when anything was cut off.
pub fn safe_truncate(s: &str, max: usize) -> String {
match s.char_indices().nth(max) {
None => s.to_string(),
Some((idx, _)) => format!("{}...", &s[..idx]),
}
}
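
`safe_truncate` cuts on character boundaries because slicing a `&str` at an arbitrary byte offset panics when the offset lands inside a multi-byte character. The sketch below repeats the function so it compiles on its own and sets it beside a byte-based variant. `byte_truncate` is a hypothetical name used only for comparison, not something this commit defines.

```rust
/// Copy of `safe_truncate` from constants.rs, repeated so this compiles alone.
fn safe_truncate(s: &str, max: usize) -> String {
    match s.char_indices().nth(max) {
        None => s.to_string(),
        Some((idx, _)) => format!("{}...", &s[..idx]),
    }
}

/// Hypothetical byte-based variant for comparison: `str::get` returns None
/// when `max` is out of range or not on a char boundary, where `&s[..max]`
/// would panic.
fn byte_truncate(s: &str, max: usize) -> Option<&str> {
    s.get(..max)
}
```

With the input `"héllo"`, where `é` occupies two bytes, `safe_truncate` at 2 keeps two characters and appends the ellipsis, while `byte_truncate` at 2 lands inside `é` and returns `None`.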

332
src/main.rs Normal file
View File

@@ -0,0 +1,332 @@
//! Antigravity OpenAI Proxy — Rust edition v3 (stealth hardened).
//!
//! Single-binary replacement for server.py. BoringSSL TLS impersonation,
//! byte-exact protobuf encoding, Chrome header fingerprinting, cascade
//! session management, warmup + heartbeat lifecycle mimicry.
mod api;
mod backend;
mod constants;
mod mitm;
mod proto;
mod quota;
mod session;
mod warmup;
use api::AppState;
use backend::Backend;
use clap::Parser;
use session::SessionManager;
use std::sync::Arc;
use tracing::{info, warn};
use mitm::store::MitmStore;
#[derive(Parser)]
#[command(name = "antigravity-proxy", about = "Antigravity OpenAI Proxy (stealth)")]
struct Cli {
/// Port to listen on
#[arg(long, default_value_t = 8741)]
port: u16,
/// Enable info-level logging (-v)
#[arg(short, long)]
verbose: bool,
/// Enable debug-level logging (-d)
#[arg(short, long)]
debug: bool,
/// Disable the MITM proxy (no API interception)
#[arg(long)]
no_mitm: bool,
/// MITM proxy port (default: 8742, matches wrapper script)
#[arg(long, default_value_t = 8742)]
mitm_port: u16,
}
#[tokio::main]
async fn main() {
let cli = Cli::parse();
// Flag > env var > default (warn)
let log_level = if cli.debug {
"debug"
} else if cli.verbose {
"info"
} else {
// Fall back to RUST_LOG env, or warn-only
""
};
let filter = if log_level.is_empty() {
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "warn".into())
} else {
tracing_subscriber::EnvFilter::new(log_level)
};
tracing_subscriber::fmt()
.with_env_filter(filter)
.init();
// ── Step 1: Bind main port FIRST (fail fast, before spawning anything) ────
let addr = format!("127.0.0.1:{}", cli.port);
let listener = match tokio::net::TcpListener::bind(&addr).await {
Ok(l) => l,
Err(e) => {
eprintln!("Fatal: cannot bind to {addr}: {e}");
eprintln!("Hint: kill $(lsof -ti:{}) 2>/dev/null", cli.port);
std::process::exit(1);
}
};
// ── Step 2: Backend discovery ─────────────────────────────────────────────
let backend = Arc::new(match Backend::new() {
Ok(b) => b,
Err(e) => {
eprintln!("Fatal: {e}");
std::process::exit(1);
}
});
let (pid, https_port, csrf, token) = backend.info().await;
// ── Step 3: MITM proxy (after port is secured) ────────────────────────────
let mitm_store = MitmStore::new();
let (mitm_port_actual, mitm_handle) = if !cli.no_mitm {
let data_dir = dirs_data_dir();
match mitm::ca::MitmCa::load_or_generate(&data_dir) {
Ok(ca) => {
let ca = Arc::new(ca);
let ca_pem = ca.ca_pem_path.display().to_string();
let config = mitm::proxy::MitmConfig {
port: cli.mitm_port,
modify_requests: false,
};
match mitm::proxy::run(ca, mitm_store.clone(), config).await {
Ok((port, handle)) => {
info!(port, ca = %ca_pem, "MITM proxy started");
// Write actual port to file for wrapper script discovery
let port_file = data_dir.join("mitm-port");
if let Err(e) = std::fs::write(&port_file, port.to_string()) {
warn!("Failed to write MITM port file: {e}");
}
(Some((port, ca_pem)), Some(handle))
}
Err(e) => {
warn!("MITM proxy failed to start: {e}");
(None, None)
}
}
}
Err(e) => {
warn!("MITM CA generation failed: {e}");
(None, None)
}
}
} else {
info!("MITM proxy disabled (--no-mitm)");
(None, None)
};
// ── Step 4: Warmup + heartbeat ────────────────────────────────────────────
warmup::warmup_sequence(&backend).await;
let heartbeat_handle = warmup::start_heartbeat(Arc::clone(&backend));
// ── Step 4b: Quota monitor ────────────────────────────────────────────────
let quota_store = quota::QuotaStore::new();
quota_store.clone().start_polling(Arc::clone(&backend));
info!("Quota monitor started (polling every 60s)");
let state = Arc::new(AppState {
backend,
sessions: SessionManager::new(),
mitm_store,
quota_store,
});
// Periodic backend refresh — keeps LS connection details fresh
let refresh_backend = Arc::clone(&state.backend);
let refresh_handle = tokio::spawn(async move {
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
if let Err(e) = refresh_backend.refresh().await {
warn!("Periodic refresh failed: {e}");
}
}
});
// ── Step 5: Start serving ─────────────────────────────────────────────────
let app = api::router(state.clone());
print_banner(cli.port, &pid, &https_port, &csrf, &token, &mitm_port_actual);
info!("Listening on http://{addr}");
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await
.expect("server error");
// ── Cleanup: abort all background tasks ───────────────────────────────────
heartbeat_handle.abort();
refresh_handle.abort();
if let Some(h) = mitm_handle {
h.abort();
}
// Remove stale MITM port file
let _ = std::fs::remove_file(dirs_data_dir().join("mitm-port"));
info!("Server shutdown complete");
}
/// Wait for SIGINT (Ctrl+C) or SIGTERM for graceful shutdown.
async fn shutdown_signal() {
let ctrl_c = async {
tokio::signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to install SIGTERM handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => info!("Received SIGINT, shutting down..."),
_ = terminate => info!("Received SIGTERM, shutting down..."),
}
}
fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str, mitm: &Option<(u16, String)>) {
let chrome_major = &*constants::CHROME_MAJOR;
// The banner names the proxy, so show the proxy's own crate version.
let ver = env!("CARGO_PKG_VERSION");
println!();
println!(" \x1b[1;35m>> antigravity-proxy\x1b[0m \x1b[2mv{ver}\x1b[0m");
println!(" \x1b[2m────────────────────────────────────────────────\x1b[0m");
println!();
println!(" \x1b[1mcore\x1b[0m");
println!(" \x1b[36m tls\x1b[0m BoringSSL (Chrome {chrome_major})");
println!(" \x1b[36m listen\x1b[0m http://127.0.0.1:{port}");
println!(" \x1b[36m ls pid\x1b[0m {pid}");
println!(" \x1b[36m https\x1b[0m :{https_port}");
println!(" \x1b[36m csrf\x1b[0m {csrf}");
println!(" \x1b[36m oauth\x1b[0m {token}");
println!();
// MITM section
if let Some((mitm_port, ca_path)) = mitm {
println!(" \x1b[1mmitm\x1b[0m");
println!(" \x1b[36m proxy\x1b[0m 127.0.0.1:{mitm_port}");
println!(" \x1b[36m ca cert\x1b[0m {ca_path}");
// Check if wrapper is installed
let wrapper_installed = check_wrapper_installed();
if wrapper_installed {
println!(" \x1b[36m wrapper\x1b[0m \x1b[32minstalled\x1b[0m");
} else {
println!(" \x1b[36m wrapper\x1b[0m \x1b[33mnot installed\x1b[0m");
}
println!();
} else {
println!(" \x1b[1mmitm\x1b[0m \x1b[33mdisabled\x1b[0m");
println!();
}
// Routes
println!(" \x1b[1mroutes\x1b[0m");
println!(" \x1b[33m POST\x1b[0m /v1/responses");
println!(" \x1b[33m POST\x1b[0m /v1/chat/completions");
println!(" \x1b[32m GET \x1b[0m /v1/models");
println!(" \x1b[32m GET \x1b[0m /v1/sessions");
println!(" \x1b[31m DEL \x1b[0m /v1/sessions/:id");
println!(" \x1b[33m POST\x1b[0m /v1/token");
println!(" \x1b[32m GET \x1b[0m /v1/usage");
println!(" \x1b[32m GET \x1b[0m /v1/quota");
println!(" \x1b[32m GET \x1b[0m /health");
println!();
// Status line
let mitm_tag = if mitm.is_some() { "\x1b[32mmitm\x1b[0m" } else { "\x1b[31mmitm\x1b[0m" };
println!(" \x1b[2mstealth:\x1b[0m \x1b[32mwarmup\x1b[0m \x1b[32mheartbeat\x1b[0m \x1b[32mjitter\x1b[0m {mitm_tag}");
println!();
// Setup hints
if let Some((mitm_port, ca_path)) = mitm {
if !check_wrapper_installed() {
println!(" \x1b[1;33m[!]\x1b[0m mitm wrapper not installed");
println!(" \x1b[2mrun:\x1b[0m ./scripts/mitm-wrapper.sh install");
println!(" \x1b[2mor:\x1b[0m HTTPS_PROXY=http://127.0.0.1:{mitm_port}");
println!(" NODE_EXTRA_CA_CERTS={ca_path}");
println!();
}
}
if token == "NOT SET" {
println!(" \x1b[1;33m[!]\x1b[0m no oauth token");
println!(" export ANTIGRAVITY_OAUTH_TOKEN=ya29.xxx");
println!(" curl -X POST http://127.0.0.1:{port}/v1/token -H 'Content-Type: application/json' -d '{{\"token\":\"ya29.xxx\"}}'");
println!(" echo 'ya29.xxx' > ~/.config/antigravity-proxy-token");
println!();
}
}
/// Check if the MITM wrapper is installed by looking for the .real backup file
/// next to the LS binary. Uses /proc to find the real LS path dynamically.
fn check_wrapper_installed() -> bool {
// Find the LS binary path from known PID or by scanning /proc
if let Some(ls_path) = find_ls_binary_path() {
let real_path = format!("{ls_path}.real");
return std::path::Path::new(&real_path).exists();
}
false
}
/// Find the LS binary path by reading /proc/<pid>/exe for known language server processes.
fn find_ls_binary_path() -> Option<String> {
// Try all running processes, look for ones that look like the LS
let proc = std::path::Path::new("/proc");
if !proc.exists() {
return None;
}
if let Ok(entries) = std::fs::read_dir(proc) {
for entry in entries.flatten() {
let name = entry.file_name();
let name_str = name.to_string_lossy();
// Only look at numeric dirs (PIDs)
if !name_str.chars().all(|c| c.is_ascii_digit()) {
continue;
}
let exe_link = entry.path().join("exe");
if let Ok(target) = std::fs::read_link(&exe_link) {
let target_str = target.to_string_lossy();
// Strip " (deleted)" suffix from unlinked binaries
let target_clean = target_str.trim_end_matches(" (deleted)");
// Match any binary that looks like the Antigravity LS
if target_clean.contains("language_server_linux")
|| target_clean.contains("antigravity-language-server")
{
// Strip .real suffix — if the wrapper exec'd the backup, we want the base name
let path = target_clean.trim_end_matches(".real");
return Some(path.to_string());
}
}
}
}
None
}
/// Get the data directory for storing MITM CA cert/key.
fn dirs_data_dir() -> std::path::PathBuf {
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
std::path::PathBuf::from(home).join(".config").join("antigravity-proxy")
}
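
The log-level selection in `main` follows the precedence "flag > env var > default (warn)". A rough pure-function sketch of that rule is below. The name `pick_log_filter` is hypothetical, and it only approximates `EnvFilter::try_from_default_env`, which also falls back to the default when `RUST_LOG` fails to parse.

```rust
/// Hypothetical pure version of the selection in `main`: CLI flags win,
/// then a non-empty RUST_LOG value, then the "warn" default.
fn pick_log_filter(debug: bool, verbose: bool, rust_log: Option<&str>) -> String {
    if debug {
        "debug".to_string()
    } else if verbose {
        "info".to_string()
    } else {
        rust_log
            .filter(|v| !v.is_empty())
            .unwrap_or("warn")
            .to_string()
    }
}
```

Keeping the rule in a function with no I/O would make the precedence easy to unit-test: `--debug` always wins over `--verbose`, and `RUST_LOG` only matters when neither flag is given.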

218
src/mitm/ca.rs Normal file
View File

@@ -0,0 +1,218 @@
//! Certificate Authority for MITM proxy.
//!
//! Generates a self-signed root CA at first run and caches it to disk.
//! Dynamically generates per-domain leaf certificates signed by this CA.
use rcgen::{
BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose,
IsCa, KeyPair, KeyUsagePurpose, SanType,
};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::info;
/// MITM Certificate Authority.
pub struct MitmCa {
/// Root CA certificate (DER-encoded for rustls).
ca_cert_der: CertificateDer<'static>,
/// Root CA private key.
ca_key: KeyPair,
/// Signed root CA cert (needed by rcgen to sign leaf certs).
ca_signed: rcgen::Certificate,
/// Cache of per-domain TLS configs.
domain_cache: Arc<RwLock<HashMap<String, Arc<rustls::ServerConfig>>>>,
/// Path to the CA PEM file (for SSL_CERT_FILE combined bundle).
pub ca_pem_path: PathBuf,
}
impl MitmCa {
/// Load or generate the MITM CA.
///
/// The CA cert/key are stored at:
/// `<data_dir>/mitm-ca.pem` (cert, for NODE_EXTRA_CA_CERTS)
/// `<data_dir>/mitm-ca.key` (private key)
pub fn load_or_generate(data_dir: &Path) -> Result<Self, String> {
let cert_path = data_dir.join("mitm-ca.pem");
let key_path = data_dir.join("mitm-ca.key");
if cert_path.exists() && key_path.exists() {
info!("Loading existing MITM CA from {}", cert_path.display());
let cert_pem = std::fs::read_to_string(&cert_path)
.map_err(|e| format!("Failed to read CA cert: {e}"))?;
let key_pem = std::fs::read_to_string(&key_path)
.map_err(|e| format!("Failed to read CA key: {e}"))?;
let ca_key = KeyPair::from_pem(&key_pem)
.map_err(|e| format!("Failed to parse CA key: {e}"))?;
// Re-create params and self-sign to get the rcgen Certificate object
// (needed for signing leaf certs — rcgen 0.13 doesn't have from_ca_cert_pem).
// The re-signed cert will have a different serial/notBefore, but that's fine
// because we only use it for the rcgen signing API, NOT for the on-disk PEM.
let params = Self::ca_params();
let ca_signed = params.self_signed(&ca_key)
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
// Use the ORIGINAL on-disk PEM cert for DER — this is what the LS trusts
// (via the combined CA bundle built by the wrapper script). Writing the
// re-signed cert back would desync the LS's trust anchor.
let ca_cert_der = Self::pem_to_der(&cert_pem)
.unwrap_or_else(|| CertificateDer::from(ca_signed.der().to_vec()));
Ok(Self {
ca_cert_der,
ca_key,
ca_signed,
domain_cache: Arc::new(RwLock::new(HashMap::new())),
ca_pem_path: cert_path,
})
} else {
info!("Generating new MITM CA at {}", cert_path.display());
// Ensure data dir exists
std::fs::create_dir_all(data_dir)
.map_err(|e| format!("Failed to create data dir: {e}"))?;
let ca_key = KeyPair::generate()
.map_err(|e| format!("Failed to generate CA key: {e}"))?;
let params = Self::ca_params();
let ca_signed = params.self_signed(&ca_key)
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
// Write cert and key to disk
std::fs::write(&cert_path, ca_signed.pem())
.map_err(|e| format!("Failed to write CA cert: {e}"))?;
std::fs::write(&key_path, ca_key.serialize_pem())
.map_err(|e| format!("Failed to write CA key: {e}"))?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
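// The CA private key should not stay world-readable; report a failed chmod.
if let Err(e) = std::fs::set_permissions(&key_path, std::fs::Permissions::from_mode(0o600)) {
tracing::warn!("Failed to restrict CA key permissions: {e}");
}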
}
let ca_cert_der = CertificateDer::from(ca_signed.der().to_vec());
Ok(Self {
ca_cert_der,
ca_key,
ca_signed,
domain_cache: Arc::new(RwLock::new(HashMap::new())),
ca_pem_path: cert_path,
})
}
}
/// Build the CA certificate parameters (reusable for both generate and load).
fn ca_params() -> CertificateParams {
let mut params = CertificateParams::default();
let mut dn = DistinguishedName::new();
dn.push(DnType::CommonName, "Antigravity MITM CA");
dn.push(DnType::OrganizationName, "Antigravity Proxy");
params.distinguished_name = dn;
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
params.key_usages = vec![
KeyUsagePurpose::KeyCertSign,
KeyUsagePurpose::CrlSign,
];
// Valid for 10 years
let now = time::OffsetDateTime::now_utc();
params.not_before = now;
params.not_after = now + time::Duration::days(3650);
params
}
/// Parse a PEM certificate into a DER-encoded CertificateDer.
fn pem_to_der(pem: &str) -> Option<CertificateDer<'static>> {
// Extract base64 content between BEGIN/END markers
let mut in_cert = false;
let mut b64 = String::new();
for line in pem.lines() {
if line.contains("BEGIN CERTIFICATE") {
in_cert = true;
continue;
}
if line.contains("END CERTIFICATE") {
break;
}
if in_cert {
b64.push_str(line.trim());
}
}
if b64.is_empty() {
return None;
}
use base64::Engine;
let der = base64::engine::general_purpose::STANDARD.decode(&b64).ok()?;
Some(CertificateDer::from(der))
}
/// Get or create a TLS ServerConfig for the given domain.
pub async fn server_config_for_domain(&self, domain: &str) -> Result<Arc<rustls::ServerConfig>, String> {
// Check cache first
{
let cache = self.domain_cache.read().await;
if let Some(config) = cache.get(domain) {
return Ok(config.clone());
}
}
// Generate leaf cert for this domain
let mut params = CertificateParams::default();
let mut dn = DistinguishedName::new();
dn.push(DnType::CommonName, domain);
params.distinguished_name = dn;
params.subject_alt_names = vec![SanType::DnsName(domain.try_into().map_err(|e| format!("Invalid domain: {e}"))?)];
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
params.key_usages = vec![
KeyUsagePurpose::DigitalSignature,
KeyUsagePurpose::KeyEncipherment,
];
// Valid for 1 year
let now = time::OffsetDateTime::now_utc();
params.not_before = now;
params.not_after = now + time::Duration::days(365);
let leaf_key = KeyPair::generate()
.map_err(|e| format!("Failed to generate leaf key: {e}"))?;
let leaf_cert = params.signed_by(&leaf_key, &self.ca_signed, &self.ca_key)
.map_err(|e| format!("Failed to sign leaf cert for {domain}: {e}"))?;
// Build rustls ServerConfig
let leaf_cert_der = CertificateDer::from(leaf_cert.der().to_vec());
let leaf_key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(leaf_key.serialize_der()));
let mut config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(
vec![leaf_cert_der, self.ca_cert_der.clone()],
leaf_key_der,
)
.map_err(|e| format!("Failed to build ServerConfig for {domain}: {e}"))?;
// Advertise both h2 and http/1.1 so gRPC clients can negotiate HTTP/2
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let config = Arc::new(config);
// Cache it
{
let mut cache = self.domain_cache.write().await;
cache.insert(domain.to_string(), config.clone());
}
Ok(config)
}
}
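
`server_config_for_domain` checks the cache under a read lock and inserts under a separate write lock, so two connections that miss at the same moment can each mint a certificate for the same domain. That is harmless but wasteful. A minimal sketch of one way to close the gap, by re-checking under the write lock, might look like the following. `DomainCache` and `get_or_try_insert` are hypothetical names, `V` stands in for `rustls::ServerConfig`, and the sketch uses `std::sync::RwLock` rather than the async lock used above.

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Hypothetical per-domain cache; `V` stands in for `rustls::ServerConfig`.
pub struct DomainCache<V> {
    inner: RwLock<HashMap<String, Arc<V>>>,
}

impl<V> DomainCache<V> {
    pub fn new() -> Self {
        Self { inner: RwLock::new(HashMap::new()) }
    }

    /// Return the cached value for `domain`, building it with `make` on a miss.
    pub fn get_or_try_insert<E, F>(&self, domain: &str, make: F) -> Result<Arc<V>, E>
    where
        F: FnOnce() -> Result<V, E>,
    {
        // Fast path: a shared read lock is enough for a cache hit.
        {
            let map = self.inner.read().unwrap();
            if let Some(v) = map.get(domain) {
                return Ok(Arc::clone(v));
            }
        }
        // Slow path: take the write lock and check again, because another
        // caller may have inserted the entry while we were waiting.
        let mut map = self.inner.write().unwrap();
        if let Some(v) = map.get(domain) {
            return Ok(Arc::clone(v));
        }
        let v = Arc::new(make()?);
        map.insert(domain.to_string(), Arc::clone(&v));
        Ok(v)
    }
}
```

The trade-off is that `make` runs while the write lock is held, so certificate generation for different domains is serialized. For a proxy that sees only a handful of upstream domains this is probably acceptable.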

512
src/mitm/h2_handler.rs Normal file

@@ -0,0 +1,512 @@
//! HTTP/2 handler for gRPC traffic interception.
//!
//! When the LS negotiates HTTP/2 via ALPN (which all gRPC connections do),
//! this module handles the bidirectional HTTP/2 connection:
//! 1. Accepts HTTP/2 frames from the client (LS)
//! 2. Connects to the real upstream via TLS + HTTP/2 (single connection reused)
//! 3. Forwards each request stream to upstream
//! 4. For non-streaming: buffers response, extracts usage, forwards
//! 5. For streaming: forwards response body chunks in real-time, tees to a
//! side buffer for usage extraction after stream completes
//!
//! ## Streaming vs Non-streaming
//!
//! gRPC has both unary (non-streaming) and server-streaming RPCs.
//! The LS uses server-streaming for methods like `StreamGenerateContent`.
//! We MUST forward streaming responses immediately — buffering would break
//! the LS's perception of real-time generation.
//!
//! For usage extraction: ModelUsageStats is typically in the LAST message
//! of a streaming response, so we tee the data and parse after stream ends.
use crate::mitm::proto::parse_grpc_response_for_usage;
use crate::mitm::store::{ApiUsage, MitmStore};
use bytes::Bytes;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::{Frame, Incoming};
use hyper::server::conn::http2::Builder as H2ServerBuilder;
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::rt::TokioExecutor;
use hyper_util::rt::TokioIo;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tracing::{debug, info, trace, warn};
/// A lazily-initialized, shared HTTP/2 connection to the upstream server.
///
/// gRPC multiplexes many requests over a single HTTP/2 connection.
/// We mirror this by maintaining a single upstream connection per domain.
struct UpstreamPool {
domain: String,
tls_config: Arc<rustls::ClientConfig>,
sender: Mutex<Option<hyper::client::conn::http2::SendRequest<Full<Bytes>>>>,
}
impl UpstreamPool {
fn new(domain: String, tls_config: Arc<rustls::ClientConfig>) -> Self {
Self {
domain,
tls_config,
sender: Mutex::new(None),
}
}
/// Get or create the upstream HTTP/2 sender.
async fn get_sender(
&self,
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
let mut guard = self.sender.lock().await;
// Check if existing sender is still usable
if let Some(ref sender) = *guard {
if !sender.is_closed() {
return Ok(sender.clone());
}
debug!(domain = %self.domain, "MITM H2: upstream connection closed, reconnecting");
}
// Create new connection
let sender = self.connect().await?;
*guard = Some(sender.clone());
Ok(sender)
}
async fn connect(
&self,
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
let upstream_tcp = TcpStream::connect(format!("{}:443", self.domain))
.await
.map_err(|e| format!("upstream TCP connect to {} failed: {e}", self.domain))?;
let connector = tokio_rustls::TlsConnector::from(self.tls_config.clone());
let server_name = rustls::pki_types::ServerName::try_from(self.domain.clone())
.map_err(|e| format!("invalid domain {}: {e}", self.domain))?;
let upstream_tls = connector
.connect(server_name, upstream_tcp)
.await
.map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?;
let upstream_io = TokioIo::new(upstream_tls);
let (sender, conn) =
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.handshake(upstream_io)
.await
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
let domain = self.domain.clone();
tokio::spawn(async move {
if let Err(e) = conn.await {
debug!(domain = %domain, error = %e, "MITM H2: upstream connection driver ended");
}
});
info!(domain = %self.domain, "MITM H2: established upstream HTTP/2 connection");
Ok(sender)
}
}
/// gRPC methods that carry ModelUsageStats in their responses.
const USAGE_METHODS: &[&str] = &[
// Unary methods
"GenerateContent",
"AsyncGenerateContent",
"GenerateChat",
"GenerateCode",
"CompleteCode",
"InternalAtomicAgenticChat",
"Predict",
"DirectPredict",
// Streaming methods
"StreamGenerateContent",
"StreamAsyncGenerateContent",
"StreamGenerateChat",
];
/// Handle an HTTP/2 connection from the LS after TLS termination.
///
/// Uses hyper's HTTP/2 server to accept requests and a shared upstream
/// HTTP/2 connection to forward them.
pub async fn handle_h2_connection<S>(
tls_stream: S,
domain: String,
store: MitmStore,
) -> Result<(), String>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
info!(domain = %domain, "MITM H2: handling HTTP/2 connection");
// Build TLS config for upstream connections
let mut root_store = rustls::RootCertStore::empty();
let native_certs = rustls_native_certs::load_native_certs();
for cert in native_certs.certs {
let _ = root_store.add(cert);
}
let mut upstream_tls_config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
upstream_tls_config.alpn_protocols = vec![b"h2".to_vec()];
// Shared upstream connection pool (single connection, multiplexed)
let pool = Arc::new(UpstreamPool::new(
domain.clone(),
Arc::new(upstream_tls_config),
));
let io = TokioIo::new(tls_stream);
let domain = Arc::new(domain);
let result = H2ServerBuilder::new(TokioExecutor::new())
.serve_connection(
io,
service_fn(move |req: Request<Incoming>| {
let domain = domain.clone();
let store = store.clone();
let pool = pool.clone();
async move { handle_h2_request(req, &domain, store, pool).await }
}),
)
.await;
match result {
Ok(()) => {
debug!("MITM H2: connection closed cleanly");
Ok(())
}
Err(e) => {
// Errors here usually just mean the peer dropped the connection, so log at
// debug level and treat them as non-fatal
debug!(error = %e, "MITM H2: connection ended");
Ok(())
}
}
}
/// Response body type — either buffered or streaming.
type BoxBody = http_body_util::Either<
Full<Bytes>,
StreamBody<tokio_stream::wrappers::ReceiverStream<Result<Frame<Bytes>, hyper::Error>>>,
>;
/// Handle a single HTTP/2 request: forward to upstream, capture usage.
///
/// For streaming responses, forwards chunks in real-time while teeing
/// data to a side buffer for post-stream usage extraction.
async fn handle_h2_request(
req: Request<Incoming>,
domain: &str,
store: MitmStore,
pool: Arc<UpstreamPool>,
) -> Result<Response<BoxBody>, hyper::Error> {
let method = req.method().clone();
let uri = req.uri().clone();
let path = uri.path().to_string();
// Identify gRPC method
let is_grpc = req
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.map(|ct| ct.starts_with("application/grpc"))
.unwrap_or(false);
// Check if this method carries usage data
let is_usage_method = is_grpc
&& USAGE_METHODS.iter().any(|m| path.contains(m));
// Heuristic: treat any gRPC path containing "Stream" or "stream" as server-streaming
let is_streaming = is_grpc
&& (path.contains("Stream") || path.contains("stream"));
debug!(
domain,
%method,
path = %path,
grpc = is_grpc,
usage_method = is_usage_method,
streaming = is_streaming,
"MITM H2: forwarding request"
);
// Collect request body (we need it for cascade ID extraction)
let (parts, body) = req.into_parts();
let request_body = match body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
warn!(error = %e, "MITM H2: failed to collect request body");
Bytes::new()
}
};
// Get upstream sender from pool
let mut upstream_sender = match pool.get_sender().await {
Ok(s) => s,
Err(e) => {
warn!(error = %e, domain, "MITM H2: upstream connect failed");
let resp = Response::builder()
.status(502)
.body(http_body_util::Either::Left(Full::new(
Bytes::from(format!("upstream connect failed: {e}")),
)))
.unwrap();
return Ok(resp);
}
};
// Build the upstream request with proper authority
let upstream_uri = http::Uri::builder()
.scheme("https")
.authority(domain)
.path_and_query(
uri.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/"),
)
.build()
.unwrap_or(uri);
let mut upstream_req = Request::builder()
.method(parts.method)
.uri(upstream_uri);
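// Copy headers, skipping `host` (the authority is set on the URI) and
// connection-specific headers, which are not valid in HTTP/2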
for (name, value) in &parts.headers {
let n = name.as_str();
if n == "host" || n == "connection" || n == "transfer-encoding" {
continue;
}
upstream_req = upstream_req.header(name, value);
}
let upstream_req = match upstream_req.body(Full::new(request_body.clone())) {
Ok(r) => r,
Err(e) => {
let resp = Response::builder()
.status(502)
.body(http_body_util::Either::Left(Full::new(
Bytes::from(format!("build request failed: {e}")),
)))
.unwrap();
return Ok(resp);
}
};
// Send to upstream
let upstream_resp = match upstream_sender.send_request(upstream_req).await {
Ok(r) => r,
Err(e) => {
warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed");
let resp = Response::builder()
.status(502)
.body(http_body_util::Either::Left(Full::new(
Bytes::from(format!("upstream request failed: {e}")),
)))
.unwrap();
return Ok(resp);
}
};
let (resp_parts, resp_body) = upstream_resp.into_parts();
let status = resp_parts.status;
// ──────────────────────────────────────────────────────────────────
// Streaming path: forward chunks immediately, tee for usage parsing
// ──────────────────────────────────────────────────────────────────
if is_streaming && status.is_success() {
let should_track_usage = is_usage_method;
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, hyper::Error>>(32);
let store_clone = store.clone();
let path_clone = path.clone();
let request_body_clone = request_body.clone();
// Spawn a task to forward body chunks and tee for usage extraction
tokio::spawn(async move {
let mut tee_buffer = if should_track_usage { Some(Vec::new()) } else { None };
let mut body = resp_body;
loop {
match body.frame().await {
Some(Ok(frame)) => {
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) {
buf.extend_from_slice(data);
}
if tx.send(Ok(frame)).await.is_err() {
break; // client disconnected
}
}
Some(Err(e)) => {
warn!(error = %e, path = %path_clone, "MITM H2: streaming error");
let _ = tx.send(Err(e)).await;
break;
}
None => break, // stream ended
}
}
// Stream completed — parse the tee buffer for usage
if let Some(tee_buffer) = tee_buffer {
if !tee_buffer.is_empty() {
if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) {
let usage = ApiUsage {
input_tokens: grpc_usage.input_tokens,
output_tokens: grpc_usage.output_tokens,
thinking_output_tokens: grpc_usage.thinking_output_tokens,
response_output_tokens: grpc_usage.response_output_tokens,
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
cache_read_input_tokens: grpc_usage.cache_read_tokens,
model: grpc_usage.model,
api_provider: grpc_usage.api_provider,
grpc_method: Some(path_clone.clone()),
stop_reason: None,
total_cost_usd: None,
captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone);
store_clone.record_usage(cascade_hint.as_deref(), usage).await;
}
}
}
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let stream_body = StreamBody::new(stream);
let mut client_resp = Response::builder().status(resp_parts.status);
for (name, value) in &resp_parts.headers {
client_resp = client_resp.header(name, value);
}
let client_resp = client_resp
.body(http_body_util::Either::Right(stream_body))
.unwrap_or_else(|_| {
Response::builder()
.status(500)
.body(http_body_util::Either::Left(Full::new(Bytes::from(
"internal error",
))))
.unwrap()
});
return Ok(client_resp);
}
// ──────────────────────────────────────────────────────────────────
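// Non-streaming path: buffer full response, extract usage, forward
// NOTE: `collect()` gathers trailers too, but `to_bytes()` keeps only the data and
// `Full` cannot emit trailers, so gRPC status trailers (`grpc-status`, `grpc-message`)
// are not forwarded on this path unless upstream sent a Trailers-Only response.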
// ──────────────────────────────────────────────────────────────────
let response_body = match resp_body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
warn!(error = %e, "MITM H2: failed to collect response body");
Bytes::new()
}
};
trace!(
domain,
path = %path,
status = %status,
body_len = response_body.len(),
"MITM H2: got upstream response"
);
// Extract usage data from usage-carrying gRPC methods
if is_usage_method && !response_body.is_empty() && status.is_success() {
if let Some(grpc_usage) = parse_grpc_response_for_usage(&response_body) {
let usage = ApiUsage {
input_tokens: grpc_usage.input_tokens,
output_tokens: grpc_usage.output_tokens,
thinking_output_tokens: grpc_usage.thinking_output_tokens,
response_output_tokens: grpc_usage.response_output_tokens,
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
cache_read_input_tokens: grpc_usage.cache_read_tokens,
model: grpc_usage.model,
api_provider: grpc_usage.api_provider,
grpc_method: Some(path.clone()),
stop_reason: None,
total_cost_usd: None,
captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
let cascade_hint = extract_cascade_from_grpc_request(&request_body);
store.record_usage(cascade_hint.as_deref(), usage).await;
}
}
// Build response for the client
let mut client_resp = Response::builder().status(resp_parts.status);
for (name, value) in &resp_parts.headers {
client_resp = client_resp.header(name, value);
}
let client_resp = client_resp
.body(http_body_util::Either::Left(Full::new(response_body)))
.unwrap_or_else(|_| {
Response::builder()
.status(500)
.body(http_body_util::Either::Left(Full::new(Bytes::from(
"internal error",
))))
.unwrap()
});
Ok(client_resp)
}
/// Try to extract a cascade ID from a gRPC request body.
///
/// Looks for UUID-formatted strings in the protobuf fields.
fn extract_cascade_from_grpc_request(body: &[u8]) -> Option<String> {
use crate::mitm::proto::{decode_proto, extract_grpc_messages};
let messages = extract_grpc_messages(body);
for msg in &messages {
let fields = decode_proto(msg);
for field in &fields {
if let Some(id) = extract_uuid_from_field(field) {
return Some(id);
}
}
}
None
}
fn extract_uuid_from_field(field: &crate::mitm::proto::ProtoField) -> Option<String> {
use crate::mitm::proto::ProtoValue;
match &field.value {
ProtoValue::Bytes(b) => {
if let Ok(s) = std::str::from_utf8(b) {
if is_uuid(s) {
return Some(s.to_string());
}
}
}
ProtoValue::Message(nested) => {
for nf in nested {
if let Some(id) = extract_uuid_from_field(nf) {
return Some(id);
}
}
}
_ => {}
}
None
}
/// Returns true if `s` has the canonical 8-4-4-4-12 hyphenated hex UUID layout.
fn is_uuid(s: &str) -> bool {
s.len() == 36
&& s.bytes().enumerate().all(|(i, b)| match i {
8 | 13 | 18 | 23 => b == b'-',
_ => b.is_ascii_hexdigit(),
})
}
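
`handle_h2_request` classifies requests with `path.contains(...)`, which also matches substrings of the service name or of longer method names. gRPC request paths have the shape `/package.Service/Method`, so a stricter check can split the path and compare the method segment exactly. The sketch below illustrates that idea; `split_grpc_path`, `is_usage_method` and `is_streaming_method` are hypothetical helpers, and the `Stream` prefix test remains a naming heuristic rather than something the protocol guarantees.

```rust
/// Split a gRPC request path of the form `/package.Service/Method`
/// into `(service, method)`. Returns `None` for any other shape.
pub fn split_grpc_path(path: &str) -> Option<(&str, &str)> {
    let rest = path.strip_prefix('/')?;
    let (service, method) = rest.split_once('/')?;
    if service.is_empty() || method.is_empty() || method.contains('/') {
        return None;
    }
    Some((service, method))
}

/// True if the method segment exactly equals one of `usage_methods`.
pub fn is_usage_method(path: &str, usage_methods: &[&str]) -> bool {
    match split_grpc_path(path) {
        Some((_, method)) => usage_methods.iter().any(|m| *m == method),
        None => false,
    }
}

/// Heuristic: server-streaming methods are assumed to be named `Stream...`.
pub fn is_streaming_method(path: &str) -> bool {
    matches!(split_grpc_path(path), Some((_, method)) if method.starts_with("Stream"))
}
```

With exact matching, a list such as `USAGE_METHODS` has to name every streaming variant explicitly, which the list above already does.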

271
src/mitm/intercept.rs Normal file

@@ -0,0 +1,271 @@
//! API response interceptor: parses Anthropic/Google API responses to extract usage data.
//!
//! Handles both streaming (SSE) and non-streaming (JSON) responses.
use super::store::ApiUsage;
use serde_json::Value;
use tracing::{debug, trace};
/// Parse a complete (non-streaming) Anthropic Messages API response body.
///
/// Response format:
/// ```json
/// {
/// "id": "msg_...",
/// "type": "message",
/// "model": "claude-sonnet-4-20250514",
/// "usage": {
/// "input_tokens": 1234,
/// "output_tokens": 567,
/// "cache_creation_input_tokens": 0,
/// "cache_read_input_tokens": 890
/// },
/// "stop_reason": "end_turn"
/// }
/// ```
pub fn parse_non_streaming_response(body: &[u8]) -> Option<ApiUsage> {
let json: Value = serde_json::from_slice(body).ok()?;
extract_usage_from_message(&json)
}
/// Parse SSE events from a streaming Anthropic response body chunk.
///
/// Events of interest:
/// - `message_start` — contains `message.usage.input_tokens` + cache tokens
/// - `message_delta` — contains `usage.output_tokens`
/// - `message_stop` — marks end (no usage data)
///
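/// Updates `accumulator` in place and returns nothing. Each chunk is parsed on its own,
/// so a `data:` line split across two chunks fails to parse and is silently skipped.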
pub fn parse_streaming_chunk(chunk: &str, accumulator: &mut StreamingAccumulator) {
for line in chunk.lines() {
if let Some(data) = line.strip_prefix("data: ") {
if data.trim() == "[DONE]" {
continue;
}
if let Ok(event) = serde_json::from_str::<Value>(data) {
accumulator.process_event(&event);
}
}
}
}
/// Accumulates usage data across streaming SSE events.
#[derive(Debug, Default)]
pub struct StreamingAccumulator {
pub input_tokens: u64,
pub output_tokens: u64,
pub cache_creation_input_tokens: u64,
pub cache_read_input_tokens: u64,
pub model: Option<String>,
pub stop_reason: Option<String>,
pub is_complete: bool,
}
impl StreamingAccumulator {
pub fn new() -> Self {
Self::default()
}
/// Process a single SSE event.
pub fn process_event(&mut self, event: &Value) {
let event_type = event["type"].as_str().unwrap_or("");
match event_type {
"message_start" => {
// message_start contains the initial usage (input tokens + cache)
if let Some(usage) = event.get("message").and_then(|m| m.get("usage")) {
self.input_tokens = usage["input_tokens"].as_u64().unwrap_or(0);
self.cache_creation_input_tokens = usage["cache_creation_input_tokens"].as_u64().unwrap_or(0);
self.cache_read_input_tokens = usage["cache_read_input_tokens"].as_u64().unwrap_or(0);
}
if let Some(model) = event.get("message").and_then(|m| m["model"].as_str()) {
self.model = Some(model.to_string());
}
trace!(
input = self.input_tokens,
cache_read = self.cache_read_input_tokens,
cache_create = self.cache_creation_input_tokens,
"SSE message_start: captured input usage"
);
}
"message_delta" => {
// message_delta contains the output usage
if let Some(usage) = event.get("usage") {
self.output_tokens = usage["output_tokens"].as_u64().unwrap_or(self.output_tokens);
}
if let Some(reason) = event["delta"]["stop_reason"].as_str() {
self.stop_reason = Some(reason.to_string());
}
trace!(output = self.output_tokens, "SSE message_delta: updated output tokens");
}
"message_stop" => {
self.is_complete = true;
debug!(
input = self.input_tokens,
output = self.output_tokens,
cache_read = self.cache_read_input_tokens,
model = ?self.model,
"SSE message_stop: stream complete"
);
}
"content_block_start" | "content_block_delta" | "content_block_stop" | "ping" => {
// Content events — no usage data, just pass through
}
_ => {
trace!(event_type, "SSE: unknown event type");
}
}
}
/// Convert accumulated data to an ApiUsage.
pub fn into_usage(self) -> ApiUsage {
ApiUsage {
input_tokens: self.input_tokens,
output_tokens: self.output_tokens,
cache_creation_input_tokens: self.cache_creation_input_tokens,
cache_read_input_tokens: self.cache_read_input_tokens,
thinking_output_tokens: 0,
response_output_tokens: 0,
total_cost_usd: None,
model: self.model,
stop_reason: self.stop_reason,
api_provider: Some("anthropic".to_string()),
grpc_method: None,
captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
}
}
}
/// Extract usage from a complete Message JSON object.
fn extract_usage_from_message(msg: &Value) -> Option<ApiUsage> {
let usage = msg.get("usage")?;
Some(ApiUsage {
input_tokens: usage["input_tokens"].as_u64().unwrap_or(0),
output_tokens: usage["output_tokens"].as_u64().unwrap_or(0),
cache_creation_input_tokens: usage["cache_creation_input_tokens"].as_u64().unwrap_or(0),
cache_read_input_tokens: usage["cache_read_input_tokens"].as_u64().unwrap_or(0),
thinking_output_tokens: 0,
response_output_tokens: 0,
total_cost_usd: None,
model: msg["model"].as_str().map(|s| s.to_string()),
stop_reason: msg["stop_reason"].as_str().map(|s| s.to_string()),
api_provider: Some("anthropic".to_string()),
grpc_method: None,
captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
})
}
/// Try to identify a cascade ID from the request body.
///
/// The LS includes cascade-related metadata in its API requests (as part of
/// the system prompt or metadata field). We try to find it.
pub fn extract_cascade_hint(request_body: &[u8]) -> Option<String> {
let json: Value = serde_json::from_slice(request_body).ok()?;
// Check for metadata field (some API configurations include it)
if let Some(metadata) = json.get("metadata") {
if let Some(user_id) = metadata["user_id"].as_str() {
// The LS often sets user_id to the cascadeId
return Some(user_id.to_string());
}
}
// Check system prompt for cascade/workspace markers
if let Some(system) = json.get("system") {
let system_str = match system {
Value::String(s) => s.clone(),
Value::Array(arr) => {
// Array of content blocks
arr.iter()
.filter_map(|b| b["text"].as_str())
.collect::<Vec<_>>()
.join(" ")
}
_ => return None,
};
// Look for a `workspace_id` marker (a `cascade_id` marker is not searched for yet)
if let Some(pos) = system_str.find("workspace_id") {
let rest = &system_str[pos..];
// Extract the value after workspace_id
if let Some(val) = rest.split_whitespace().nth(1) {
return Some(val.to_string());
}
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_non_streaming() {
let body = r#"{
"id": "msg_123",
"type": "message",
"model": "claude-sonnet-4-20250514",
"usage": {
"input_tokens": 100,
"output_tokens": 50,
"cache_creation_input_tokens": 10,
"cache_read_input_tokens": 30
},
"stop_reason": "end_turn"
}"#;
let usage = parse_non_streaming_response(body.as_bytes()).unwrap();
assert_eq!(usage.input_tokens, 100);
assert_eq!(usage.output_tokens, 50);
assert_eq!(usage.cache_creation_input_tokens, 10);
assert_eq!(usage.cache_read_input_tokens, 30);
assert_eq!(usage.model.as_deref(), Some("claude-sonnet-4-20250514"));
}
#[test]
fn test_streaming_accumulator() {
let mut acc = StreamingAccumulator::new();
// message_start
let start = serde_json::json!({
"type": "message_start",
"message": {
"model": "claude-sonnet-4-20250514",
"usage": {
"input_tokens": 200,
"cache_creation_input_tokens": 5,
"cache_read_input_tokens": 50
}
}
});
acc.process_event(&start);
assert_eq!(acc.input_tokens, 200);
assert_eq!(acc.cache_read_input_tokens, 50);
// message_delta
let delta = serde_json::json!({
"type": "message_delta",
"delta": { "stop_reason": "end_turn" },
"usage": { "output_tokens": 75 }
});
acc.process_event(&delta);
assert_eq!(acc.output_tokens, 75);
// message_stop
let stop = serde_json::json!({ "type": "message_stop" });
acc.process_event(&stop);
assert!(acc.is_complete);
let usage = acc.into_usage();
assert_eq!(usage.input_tokens, 200);
assert_eq!(usage.output_tokens, 75);
}
}
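
`parse_streaming_chunk` treats each chunk as a set of complete lines, but network reads can end in the middle of an SSE line. One possible way to handle that is to buffer input until a newline arrives and only then hand the `data:` payload to the JSON parser. The sketch below assumes a hypothetical `SseLineBuffer` type that is not part of this commit.

```rust
/// Hypothetical helper that reassembles SSE lines split across chunks.
#[derive(Debug, Default)]
pub struct SseLineBuffer {
    pending: String,
}

impl SseLineBuffer {
    /// Feed one chunk and return the payload of every complete `data:` line.
    /// An incomplete trailing line stays buffered until the next call.
    pub fn push(&mut self, chunk: &str) -> Vec<String> {
        self.pending.push_str(chunk);
        let mut payloads = Vec::new();
        while let Some(pos) = self.pending.find('\n') {
            // Remove the line, including its terminator, from the buffer.
            let raw: String = self.pending.drain(..=pos).collect();
            let line = raw.trim_end_matches(|c: char| c == '\n' || c == '\r');
            if let Some(data) = line.strip_prefix("data:") {
                // SSE allows a single optional space after the colon.
                let data = data.strip_prefix(' ').unwrap_or(data);
                if data != "[DONE]" {
                    payloads.push(data.to_string());
                }
            }
        }
        payloads
    }
}
```

Each returned payload could then be passed to `serde_json::from_str` and on to `StreamingAccumulator::process_event`, as the existing function does.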

19
src/mitm/mod.rs Normal file

@@ -0,0 +1,19 @@
//! MITM proxy module: intercepts LS ↔ Google/Anthropic API traffic.
//!
//! The LS (Go binary with BoringCrypto) respects `HTTPS_PROXY` and `SSL_CERT_FILE`.
//! By setting these env vars via the wrapper script, we route all outbound HTTPS
//! traffic through our local MITM proxy, which:
//!
//! 1. Terminates TLS using dynamically-generated per-domain certificates
//! 2. Detects protocol: HTTP/1.1 (REST) or HTTP/2 (gRPC)
//! 3. For HTTP/1.1: parses JSON/SSE responses (Anthropic format)
//! 4. For HTTP/2: decodes gRPC protobuf responses (Google format)
//! 5. Captures token usage data (input, output, thinking, cache)
//! 6. Forwards everything transparently to real upstream servers
pub mod ca;
pub mod h2_handler;
pub mod intercept;
pub mod proto;
pub mod proxy;
pub mod store;
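
Step 4 of the overview depends on two layers of framing: protobuf's varint-tagged fields, wrapped in gRPC's 5-byte length prefix. As a rough illustration of what the decoder in `proto.rs` has to undo, the sketch below encodes a single varint field and frames it. The function names are hypothetical, and field 2 is used because the `ModelUsageStats` schema documented in `proto.rs` lists it as `input_tokens`.

```rust
/// Protobuf wire type for varint-encoded values.
const WIRE_VARINT: u64 = 0;

/// Append `value` as a base-128 varint (7 payload bits per byte,
/// high bit set on every byte except the last).
pub fn write_varint(mut value: u64, out: &mut Vec<u8>) {
    while value >= 0x80 {
        out.push(((value as u8) & 0x7F) | 0x80);
        value >>= 7;
    }
    out.push(value as u8);
}

/// Append one varint field: tag `(field_number << 3) | wire_type`, then the value.
pub fn write_varint_field(field_number: u32, value: u64, out: &mut Vec<u8>) {
    write_varint(((field_number as u64) << 3) | WIRE_VARINT, out);
    write_varint(value, out);
}

/// Wrap a protobuf message in an uncompressed gRPC frame:
/// 1-byte compression flag, 4-byte big-endian length, then the message.
pub fn grpc_frame(message: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(5 + message.len());
    out.push(0);
    out.extend_from_slice(&(message.len() as u32).to_be_bytes());
    out.extend_from_slice(message);
    out
}
```

Encoding field 2 with the value 300 yields the bytes `0x10 0xAC 0x02`: the tag is `(2 << 3) | 0 = 0x10`, and 300 splits into the 7-bit groups `0x2C` (with the continuation bit set, `0xAC`) and `0x02`.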

584
src/mitm/proto.rs Normal file

@@ -0,0 +1,584 @@
//! Raw protobuf decoder for extracting ModelUsageStats from gRPC responses.
//!
//! We don't have the .proto schema, so we decode protobuf messages generically
//! and search for usage-like structures by matching field patterns.
//!
//! gRPC wire format:
//! - 1 byte: compression flag (0 = uncompressed, 1 = compressed)
//! - 4 bytes: message length (big-endian u32)
//! - N bytes: protobuf message
//!
//! Protobuf wire format:
//! - Each field: (field_number << 3 | wire_type) as varint, then value
//! - Wire type 0: varint
//! - Wire type 1: 64-bit fixed
//! - Wire type 2: length-delimited (string, bytes, embedded message)
//! - Wire type 5: 32-bit fixed
//!
//! ## ModelUsageStats schema (reverse-engineered from LS binary):
//!
//! ```protobuf
//! message ModelUsageStats {
//! Model model = 1; // enum (varint)
//! uint64 input_tokens = 2;
//! uint64 output_tokens = 3;
//! uint64 cache_write_tokens = 4;
//! uint64 cache_read_tokens = 5;
//! APIProvider api_provider = 6; // enum (varint)
//! string message_id = 7;
//! map<string,string> response_header = 8; // repeated message
//! uint64 thinking_output_tokens = 9;
//! uint64 response_output_tokens = 10;
//! string response_id = 11;
//! }
//! ```
use flate2::read::GzDecoder;
use std::io::Read;
use tracing::{debug, trace, warn};
/// A decoded protobuf field.
#[derive(Debug, Clone)]
pub enum ProtoValue {
Varint(u64),
#[allow(dead_code)]
Fixed64(u64),
#[allow(dead_code)]
Fixed32(u32),
Bytes(Vec<u8>),
/// Nested message (parsed recursively)
Message(Vec<ProtoField>),
}
/// A single protobuf field with its number and value.
#[derive(Debug, Clone)]
pub struct ProtoField {
pub number: u32,
pub value: ProtoValue,
}
/// Extracted usage data from a gRPC response.
#[derive(Debug, Default)]
pub struct GrpcUsage {
pub input_tokens: u64,
pub output_tokens: u64,
pub thinking_output_tokens: u64,
pub response_output_tokens: u64,
pub cache_read_tokens: u64,
pub cache_write_tokens: u64,
pub model: Option<String>,
pub api_provider: Option<String>,
pub message_id: Option<String>,
pub response_id: Option<String>,
}
/// Extract gRPC message frames from a buffer.
///
/// A gRPC message is:
/// [1 byte compressed flag] [4 bytes length BE] [N bytes protobuf]
///
/// Multiple messages can be concatenated in a single buffer.
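/// If the compressed flag is 1, the payload is assumed to be gzip (the `grpc-encoding`
/// header is not consulted). Frames that fail to decompress are skipped, and a
/// truncated trailing frame is dropped.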
pub fn extract_grpc_messages(data: &[u8]) -> Vec<Vec<u8>> {
let mut messages = Vec::new();
let mut offset = 0;
while offset + 5 <= data.len() {
let compressed = data[offset];
let length = u32::from_be_bytes([
data[offset + 1],
data[offset + 2],
data[offset + 3],
data[offset + 4],
]) as usize;
offset += 5;
if offset + length > data.len() {
break;
}
let payload = &data[offset..offset + length];
if compressed == 1 {
// gzip-compressed frame
let mut decoder = GzDecoder::new(payload);
let mut decompressed = Vec::new();
match decoder.read_to_end(&mut decompressed) {
Ok(_) => messages.push(decompressed),
Err(e) => {
warn!(error = %e, "Proto: failed to decompress gRPC frame");
}
}
} else {
messages.push(payload.to_vec());
}
offset += length;
}
messages
}
/// Decode a protobuf message into a list of fields.
///
/// This is a best-effort decoder that handles the common wire types.
/// Length-delimited fields (wire type 2) are speculatively parsed as nested messages
/// and kept as raw bytes when the result does not look like a valid message.
pub fn decode_proto(data: &[u8]) -> Vec<ProtoField> {
let mut fields = Vec::new();
let mut offset = 0;
while offset < data.len() {
// Read tag (varint)
let (tag, bytes_read) = match read_varint(&data[offset..]) {
Some(v) => v,
None => break,
};
offset += bytes_read;
let field_number = (tag >> 3) as u32;
let wire_type = (tag & 0x07) as u8;
if field_number == 0 {
break; // invalid
}
let value = match wire_type {
0 => {
// Varint
let (val, bytes_read) = match read_varint(&data[offset..]) {
Some(v) => v,
None => break,
};
offset += bytes_read;
ProtoValue::Varint(val)
}
1 => {
// 64-bit fixed
if offset + 8 > data.len() {
break;
}
let val = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
offset += 8;
ProtoValue::Fixed64(val)
}
2 => {
// Length-delimited
let (len, bytes_read) = match read_varint(&data[offset..]) {
Some(v) => v,
None => break,
};
offset += bytes_read;
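// Compare against the remaining length so that a huge declared size
// cannot overflow `offset + len`.
let len = match usize::try_from(len) {
Ok(l) if l <= data.len() - offset => l,
_ => break,
};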
let payload = &data[offset..offset + len];
offset += len;
// Try to parse as a nested message
let nested = decode_proto(payload);
if !nested.is_empty() && looks_like_valid_message(&nested, payload.len()) {
ProtoValue::Message(nested)
} else {
ProtoValue::Bytes(payload.to_vec())
}
}
5 => {
// 32-bit fixed
if offset + 4 > data.len() {
break;
}
let val = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap());
offset += 4;
ProtoValue::Fixed32(val)
}
_ => {
// Unknown wire type — stop parsing
break;
}
};
fields.push(ProtoField {
number: field_number,
value,
});
}
fields
}
/// Heuristic: does this list of fields look like a valid protobuf message?
/// (vs. a random string that happened to partially decode)
fn looks_like_valid_message(fields: &[ProtoField], original_len: usize) -> bool {
if fields.is_empty() {
return false;
}
// Check that field numbers are reasonable (< 10000)
let valid_numbers = fields.iter().all(|f| f.number < 10000);
if !valid_numbers {
return false;
}
// If we have very few fields relative to the data size, it's probably not a message
// (e.g., a long string that happened to have a valid first-field prefix)
if fields.len() == 1 && original_len > 100 {
// Single-field messages of >100 bytes are suspicious unless the field is bytes/message
matches!(fields[0].value, ProtoValue::Bytes(_) | ProtoValue::Message(_))
} else {
true
}
}
/// Read a varint from a byte slice. Returns (value, bytes_consumed).
pub fn read_varint(data: &[u8]) -> Option<(u64, usize)> {
let mut result: u64 = 0;
let mut shift = 0;
for (i, &byte) in data.iter().enumerate() {
if i >= 10 {
return None; // Too many bytes for a varint
}
result |= ((byte & 0x7F) as u64) << shift;
shift += 7;
if byte & 0x80 == 0 {
return Some((result, i + 1));
}
}
None
}
/// Search a decoded protobuf message tree for usage-like structures.
///
/// Uses the exact field numbers from the reverse-engineered ModelUsageStats schema:
///
/// field 1: model (enum/varint)
/// field 2: input_tokens (uint64)
/// field 3: output_tokens (uint64)
/// field 4: cache_write_tokens (uint64)
/// field 5: cache_read_tokens (uint64)
/// field 6: api_provider (enum/varint)
/// field 7: message_id (string)
/// field 8: response_header (map, repeated message)
/// field 9: thinking_output_tokens (uint64)
/// field 10: response_output_tokens (uint64)
/// field 11: response_id (string)
pub fn extract_usage_from_proto(fields: &[ProtoField]) -> Option<GrpcUsage> {
// Strategy: recursively search for any sub-message that looks like usage data
// Try this level first
if let Some(usage) = try_extract_usage(fields) {
return Some(usage);
}
// Recurse into nested messages
for field in fields {
if let ProtoValue::Message(ref nested) = field.value {
if let Some(usage) = extract_usage_from_proto(nested) {
return Some(usage);
}
}
}
None
}
/// Try to extract usage from this specific set of fields.
///
/// Uses verified field numbers from the binary's embedded proto descriptor.
fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
// We need:
// - At least 2 varint fields with values in token range
// - Ideally field 2 (input_tokens) or field 3 (output_tokens) present
let varint_fields: Vec<_> = fields
.iter()
.filter(|f| matches!(f.value, ProtoValue::Varint(_)))
.collect();
let string_fields: Vec<_> = fields
.iter()
.filter_map(|f| {
if let ProtoValue::Bytes(ref b) = f.value {
std::str::from_utf8(b).ok().map(|s| (f.number, s.to_string()))
} else {
None
}
})
.collect();
// Need at least 2 varint fields to be a candidate
if varint_fields.len() < 2 {
return None;
}
// Check if the varint values make sense as token counts
let plausible_token_count = |v: u64| v <= 10_000_000;
let plausible_varints = varint_fields
.iter()
.filter(|f| {
if let ProtoValue::Varint(v) = f.value {
plausible_token_count(v) && v > 0
} else {
false
}
})
.count();
// Need at least 2 non-zero plausible values
if plausible_varints < 2 {
return None;
}
// Check whether any UTF-8 string field looks like a model name (for example in
// field 7 = message_id or field 11 = response_id). This is a loose heuristic:
// short substrings such as "pro" also match unrelated text like "provider".
let has_model_string = string_fields.iter().any(|(_, s)| {
s.contains("claude") || s.contains("gemini") || s.contains("gpt")
|| s.starts_with("models/") || s.contains("sonnet") || s.contains("opus")
|| s.contains("flash") || s.contains("pro")
});
// Check for fields at the known ModelUsageStats field numbers
let has_field_2 = fields.iter().any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_)));
let has_field_3 = fields.iter().any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_)));
// Strong signal: has both input and output token fields
let is_likely_usage = (has_field_2 && has_field_3) || has_model_string;
if !is_likely_usage && varint_fields.len() < 3 {
// Without strong signal, need more fields
return None;
}
// Build usage from exact field numbers (verified from binary)
let mut usage = GrpcUsage::default();
for field in fields {
match &field.value {
ProtoValue::Varint(v) => {
let v = *v;
if !plausible_token_count(v) {
continue;
}
match field.number {
// field 1 = model enum (varint, not string!)
2 => usage.input_tokens = v,
3 => usage.output_tokens = v,
4 => usage.cache_write_tokens = v, // VERIFIED: field 4
5 => usage.cache_read_tokens = v, // VERIFIED: field 5
// field 6 = api_provider enum (varint)
9 => usage.thinking_output_tokens = v, // VERIFIED: field 9
10 => usage.response_output_tokens = v, // VERIFIED: field 10
_ => {}
}
}
ProtoValue::Bytes(ref b) => {
if let Ok(s) = std::str::from_utf8(b) {
match field.number {
7 => usage.message_id = Some(s.to_string()),
11 => usage.response_id = Some(s.to_string()),
_ => {}
}
}
}
_ => {}
}
}
// Model and api_provider are enums (varints), not strings
// We can map known enum values later if needed
// For now, extract the enum value as a string representation
for field in fields {
if let ProtoValue::Varint(v) = &field.value {
match field.number {
1 => {
// Model enum — we don't have the mapping, store as number
usage.model = Some(format!("model_enum_{v}"));
}
6 => {
// APIProvider enum
usage.api_provider = Some(match *v {
0 => "unknown".to_string(),
1 => "google".to_string(),
2 => "anthropic".to_string(),
_ => format!("provider_{v}"),
});
}
_ => {}
}
}
}
// Validate — we should have at least input OR output tokens
if usage.input_tokens == 0 && usage.output_tokens == 0 {
return None;
}
debug!(
input = usage.input_tokens,
output = usage.output_tokens,
thinking = usage.thinking_output_tokens,
response = usage.response_output_tokens,
cache_read = usage.cache_read_tokens,
cache_write = usage.cache_write_tokens,
model = ?usage.model,
api_provider = ?usage.api_provider,
"Proto: extracted ModelUsageStats from protobuf"
);
Some(usage)
}
/// Parse a gRPC response body (may contain multiple messages) for usage data.
///
/// Handles both compressed and uncompressed gRPC frames.
pub fn parse_grpc_response_for_usage(body: &[u8]) -> Option<GrpcUsage> {
let messages = extract_grpc_messages(body);
trace!(count = messages.len(), "Proto: extracted gRPC messages");
// Check each message for usage data (last message usually has it)
for msg in messages.iter().rev() {
let fields = decode_proto(msg);
if let Some(usage) = extract_usage_from_proto(&fields) {
return Some(usage);
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_read_varint() {
assert_eq!(read_varint(&[0x00]), Some((0, 1)));
assert_eq!(read_varint(&[0x01]), Some((1, 1)));
assert_eq!(read_varint(&[0x96, 0x01]), Some((150, 2)));
assert_eq!(read_varint(&[0xAC, 0x02]), Some((300, 2)));
}
#[test]
fn test_extract_grpc_messages_uncompressed() {
// Construct a test gRPC frame: [0x00] [0x00, 0x00, 0x00, 0x05] [5 bytes data]
let mut buf = vec![0u8]; // not compressed
buf.extend_from_slice(&5u32.to_be_bytes());
buf.extend_from_slice(&[0x08, 0x96, 0x01, 0x10, 0x42]); // field 1 varint 150, field 2 varint 66
let messages = extract_grpc_messages(&buf);
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].len(), 5);
}
#[test]
fn test_extract_grpc_messages_compressed() {
use flate2::write::GzEncoder;
use flate2::Compression;
use std::io::Write;
// Create a payload
let payload = vec![0x08, 0x96, 0x01, 0x10, 0x42];
// Compress it
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(&payload).unwrap();
let compressed = encoder.finish().unwrap();
// Build gRPC frame with compressed flag
let mut buf = vec![1u8]; // compressed
buf.extend_from_slice(&(compressed.len() as u32).to_be_bytes());
buf.extend_from_slice(&compressed);
let messages = extract_grpc_messages(&buf);
assert_eq!(messages.len(), 1);
assert_eq!(messages[0], payload);
}
#[test]
fn test_decode_proto_varints() {
// field 1 = 150, field 2 = 66
let data = [0x08, 0x96, 0x01, 0x10, 0x42];
let fields = decode_proto(&data);
assert_eq!(fields.len(), 2);
assert_eq!(fields[0].number, 1);
assert!(matches!(fields[0].value, ProtoValue::Varint(150)));
assert_eq!(fields[1].number, 2);
assert!(matches!(fields[1].value, ProtoValue::Varint(66)));
}
#[test]
fn test_decode_proto_with_string() {
// field 1 = "hello" (string), field 2 = varint 42
let mut data = Vec::new();
// field 1, wire type 2 (length-delimited)
data.push(0x0A); // (1 << 3) | 2
data.push(0x05); // length 5
data.extend_from_slice(b"hello");
// field 2, wire type 0 (varint)
data.push(0x10); // (2 << 3) | 0
data.push(0x2A); // 42
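let fields = decode_proto(&data);
// Field 1 is only checked by number: a short string such as "hello" can itself pass
// the nested-message heuristic, so it may decode as either `Bytes` or `Message`.
assert!(fields.len() >= 2);
assert_eq!(fields[0].number, 1);
assert_eq!(fields[1].number, 2);
assert!(matches!(fields[1].value, ProtoValue::Varint(42)));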
}
#[test]
fn test_extract_usage_correct_field_numbers() {
// Build a mock ModelUsageStats with the correct field numbers:
// field 1 (model enum) = 5 (some model)
// field 2 (input_tokens) = 1000
// field 3 (output_tokens) = 500
// field 4 (cache_write_tokens) = 100
// field 5 (cache_read_tokens) = 200
// field 9 (thinking_output_tokens) = 300
// field 10 (response_output_tokens) = 200
let mut data = Vec::new();
// Helper: encode varint field
fn encode_varint_field(data: &mut Vec<u8>, field_num: u32, value: u64) {
// Tag
let tag = (field_num << 3) | 0; // wire type 0
let mut t = tag;
while t >= 0x80 {
data.push((t as u8) | 0x80);
t >>= 7;
}
data.push(t as u8);
// Value
let mut v = value;
while v >= 0x80 {
data.push((v as u8) | 0x80);
v >>= 7;
}
data.push(v as u8);
}
encode_varint_field(&mut data, 1, 5); // model enum
encode_varint_field(&mut data, 2, 1000); // input_tokens
encode_varint_field(&mut data, 3, 500); // output_tokens
encode_varint_field(&mut data, 4, 100); // cache_write_tokens
encode_varint_field(&mut data, 5, 200); // cache_read_tokens
encode_varint_field(&mut data, 9, 300); // thinking_output_tokens
encode_varint_field(&mut data, 10, 200); // response_output_tokens
let fields = decode_proto(&data);
let usage = try_extract_usage(&fields).expect("should extract usage");
assert_eq!(usage.input_tokens, 1000);
assert_eq!(usage.output_tokens, 500);
assert_eq!(usage.cache_write_tokens, 100);
assert_eq!(usage.cache_read_tokens, 200);
assert_eq!(usage.thinking_output_tokens, 300);
assert_eq!(usage.response_output_tokens, 200);
}
}
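
The decoder above relies on two wire formats: gRPC length-prefixed framing (a 1-byte compressed flag, a 4-byte big-endian length, then the payload) and protobuf varint fields. The following is a minimal, std-only sketch of both, not the code from this commit. The names `split_grpc_frames` and `varint_fields` are hypothetical, compressed frames are simply skipped, and only wire type 0 is handled.

```rust
/// Read a base-128 varint. Returns (value, bytes_consumed).
fn read_varint(data: &[u8]) -> Option<(u64, usize)> {
    let mut result = 0u64;
    // A 64-bit varint never needs more than 10 bytes.
    for (i, &byte) in data.iter().enumerate().take(10) {
        result |= u64::from(byte & 0x7F) << (7 * i);
        if byte & 0x80 == 0 {
            return Some((result, i + 1));
        }
    }
    None
}

/// Split a gRPC body into message payloads (hypothetical helper).
/// Compressed frames are skipped in this sketch; truncated input stops the scan.
fn split_grpc_frames(mut body: &[u8]) -> Vec<&[u8]> {
    let mut frames = Vec::new();
    while body.len() >= 5 {
        let compressed = body[0] != 0;
        let len = u32::from_be_bytes([body[1], body[2], body[3], body[4]]) as usize;
        let rest = &body[5..];
        if len > rest.len() {
            break;
        }
        if !compressed {
            frames.push(&rest[..len]);
        }
        body = &rest[len..];
    }
    frames
}

/// Collect (field_number, value) pairs from a flat message made of varint fields.
/// Stops at the first field that is not wire type 0.
fn varint_fields(mut data: &[u8]) -> Vec<(u64, u64)> {
    let mut out = Vec::new();
    while !data.is_empty() {
        let Some((tag, n)) = read_varint(data) else { break };
        data = &data[n..];
        if tag & 7 != 0 {
            break;
        }
        let Some((value, n)) = read_varint(data) else { break };
        data = &data[n..];
        out.push((tag >> 3, value));
    }
    out
}
```

With the field numbers documented above, a payload carrying field 2 = 1000 and field 3 = 500 would come back from `varint_fields` as input and output token counts.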

591
src/mitm/proxy.rs Normal file

@@ -0,0 +1,591 @@
//! MITM proxy server: handles CONNECT tunnels and TLS interception.
//!
//! Listens on a local port for HTTP CONNECT requests from the LS.
//! For intercepted domains, it terminates TLS with our CA-signed cert,
//! reads/modifies the request, forwards to the real upstream, and captures
//! the response (especially usage data).
//!
//! For non-intercepted domains, it acts as a transparent TCP tunnel.
use super::ca::MitmCa;
use super::intercept::{
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk,
StreamingAccumulator,
};
use super::store::MitmStore;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, trace, warn};
/// Domains we intercept (terminate TLS and inspect traffic).
/// This includes exact matches and suffix matches for regional endpoints
/// (e.g., us-central1-aiplatform.googleapis.com).
const INTERCEPT_DOMAINS: &[&str] = &[
"cloudcode-pa.googleapis.com",
"aiplatform.googleapis.com",
"api.anthropic.com",
"speech.googleapis.com",
"modelarmor.googleapis.com",
];
/// Domains we NEVER intercept (transparent tunnel).
const PASSTHROUGH_DOMAINS: &[&str] = &[
"oauth2.googleapis.com",
"accounts.google.com",
"storage.googleapis.com",
"www.googleapis.com",
"firebaseinstallations.googleapis.com",
"crashlyticsreports-pa.googleapis.com",
"play.googleapis.com",
"update.googleapis.com",
"dl.google.com",
];
/// Configuration for the MITM proxy.
///
/// The derived default is port 0 (auto-assign) with request modification disabled.
#[derive(Default)]
pub struct MitmConfig {
/// Port to listen on (0 = auto-assign).
pub port: u16,
/// Whether to enable request modification.
pub modify_requests: bool,
}
/// Run the MITM proxy server.
///
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
pub async fn run(
ca: Arc<MitmCa>,
store: MitmStore,
config: MitmConfig,
) -> Result<(u16, tokio::task::JoinHandle<()>), String> {
let listener = TcpListener::bind(format!("127.0.0.1:{}", config.port))
.await
.map_err(|e| format!("MITM bind failed: {e}"))?;
let port = listener
.local_addr()
.map_err(|e| format!("MITM local_addr failed: {e}"))?
.port();
info!(port, "MITM proxy listening");
let modify_requests = config.modify_requests;
let handle = tokio::spawn(async move {
loop {
match listener.accept().await {
Ok((stream, addr)) => {
trace!(?addr, "MITM: new connection");
let ca = ca.clone();
let store = store.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await {
debug!(error = %e, "MITM connection error");
}
});
}
Err(e) => {
error!(error = %e, "MITM accept error");
}
}
}
});
Ok((port, handle))
}
/// Handle a single incoming connection from the LS.
///
/// The LS sends an HTTP CONNECT request to establish a tunnel.
/// We then decide whether to intercept or passthrough.
async fn handle_connection(
mut stream: TcpStream,
ca: Arc<MitmCa>,
store: MitmStore,
modify_requests: bool,
) -> Result<(), String> {
// Read the CONNECT request
let mut buf = vec![0u8; 8192];
let n = stream
.read(&mut buf)
.await
.map_err(|e| format!("Read CONNECT: {e}"))?;
if n == 0 {
return Ok(());
}
let request = String::from_utf8_lossy(&buf[..n]);
let first_line = request.lines().next().unwrap_or("");
// Parse "CONNECT host:port HTTP/1.1"
let parts: Vec<&str> = first_line.split_whitespace().collect();
if parts.len() < 3 || parts[0] != "CONNECT" {
// Not a CONNECT request — return 400
let resp = "HTTP/1.1 400 Bad Request\r\n\r\n";
let _ = stream.write_all(resp.as_bytes()).await;
return Ok(());
}
let host_port = parts[1];
// `port` is used for the passthrough tunnel, so it must not carry the unused-variable prefix
let (domain, port) = match host_port.rsplit_once(':') {
Some((h, p)) => (h, p.parse::<u16>().unwrap_or(443)),
None => (host_port, 443),
};
debug!(domain, "MITM: CONNECT request");
// Decide: intercept or passthrough
let should_intercept = should_intercept_domain(domain);
// Send 200 Connection Established
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
stream
.write_all(response.as_bytes())
.await
.map_err(|e| format!("Write 200: {e}"))?;
if should_intercept {
handle_intercepted(stream, domain, ca, store, modify_requests).await
} else {
handle_passthrough(stream, domain, port).await
}
}
/// Check if a domain should be intercepted.
fn should_intercept_domain(domain: &str) -> bool {
// Never intercept passthrough domains
for &pt in PASSTHROUGH_DOMAINS {
if domain == pt {
return false;
}
}
// Intercept known API domains (exact match, subdomain, or regional prefix)
for &intercept in INTERCEPT_DOMAINS {
if domain == intercept
|| domain.ends_with(&format!(".{intercept}"))
|| domain.ends_with(&format!("-{intercept}"))
{
return true;
}
}
// Default: passthrough
false
}
/// Handle an intercepted connection: terminate TLS, inspect traffic.
///
/// After TLS termination, checks the negotiated ALPN protocol:
/// - `h2` → HTTP/2 handler (for gRPC traffic to Google APIs)
/// - `http/1.1` or none → HTTP/1.1 handler (for REST/SSE traffic)
async fn handle_intercepted(
stream: TcpStream,
domain: &str,
ca: Arc<MitmCa>,
store: MitmStore,
modify_requests: bool,
) -> Result<(), String> {
info!(domain, "MITM: intercepting TLS");
// Get or create server TLS config for this domain
let server_config = ca
.server_config_for_domain(domain)
.await?;
let acceptor = TlsAcceptor::from(server_config);
// Perform TLS handshake with the client (LS)
let tls_stream = acceptor
.accept(stream)
.await
.map_err(|e| format!("TLS handshake with client failed for {domain}: {e}"))?;
// Check negotiated ALPN protocol
let alpn = tls_stream.get_ref().1
.alpn_protocol()
.map(|p| String::from_utf8_lossy(p).to_string());
debug!(domain, alpn = ?alpn, "MITM: TLS handshake successful");
match alpn.as_deref() {
Some("h2") => {
// HTTP/2 — use the hyper-based gRPC handler
info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
super::h2_handler::handle_h2_connection(
tls_stream,
domain.to_string(),
store,
)
.await
}
_ => {
// HTTP/1.1 or no ALPN — use the existing handler
debug!(domain, "MITM: routing to HTTP/1.1 handler");
handle_http_over_tls(tls_stream, domain, store, modify_requests).await
}
}
}
/// Handle HTTP traffic over the decrypted TLS connection.
///
/// Loops to handle multiple requests on the same connection (HTTP keep-alive).
/// Reads full request, connects to upstream, forwards request, streams response
/// back to client while capturing usage data.
async fn handle_http_over_tls(
mut client: tokio_rustls::server::TlsStream<TcpStream>,
domain: &str,
store: MitmStore,
_modify_requests: bool,
) -> Result<(), String> {
let mut tmp = vec![0u8; 32768];
// Build upstream TLS connector once for this connection
let mut root_store = rustls::RootCertStore::empty();
let native_certs = rustls_native_certs::load_native_certs();
for cert in native_certs.certs {
let _ = root_store.add(cert);
}
let upstream_config = Arc::new(
rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth(),
);
// Reusable upstream connection — created lazily, reconnected if stale
let mut upstream: Option<tokio_rustls::client::TlsStream<TcpStream>> = None;
/// Connect (or reconnect) to the real upstream via TLS.
async fn connect_upstream(
domain: &str,
config: &Arc<rustls::ClientConfig>,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> {
let connector = tokio_rustls::TlsConnector::from(config.clone());
let tcp = TcpStream::connect(format!("{domain}:443"))
.await
.map_err(|e| format!("Connect to upstream {domain}: {e}"))?;
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
.map_err(|e| format!("Invalid server name: {e}"))?;
connector
.connect(server_name, tcp)
.await
.map_err(|e| format!("TLS connect to upstream {domain}: {e}"))
}
// Keep-alive loop: handle multiple requests on this connection
loop {
// ── Read the HTTP request from the client ─────────────────────────
let mut request_buf = Vec::with_capacity(1024 * 64);
loop {
let n = match client.read(&mut tmp).await {
Ok(0) => return Ok(()), // Client closed connection cleanly
Ok(n) => n,
Err(e) => {
// Connection reset / broken pipe is normal for keep-alive end
debug!(domain, error = %e, "MITM: client read finished");
return Ok(());
}
};
request_buf.extend_from_slice(&tmp[..n]);
// Check if we have the full request (headers + body)
if has_complete_http_request(&request_buf) {
break;
}
}
if request_buf.is_empty() {
return Ok(());
}
// Parse the HTTP request to find headers and body
let (headers_end, content_length, is_streaming_request) = parse_http_request_meta(&request_buf);
// Try to extract cascade hint from request body
let cascade_hint = if headers_end < request_buf.len() {
extract_cascade_hint(&request_buf[headers_end..])
} else {
None
};
debug!(
domain,
content_length,
streaming = is_streaming_request,
cascade = ?cascade_hint,
"MITM: forwarding request to upstream"
);
// ── Ensure upstream connection is alive ──────────────────────────────
// Lazily connect on first request, or reconnect if the previous connection died
let conn = match upstream.as_mut() {
Some(c) => c,
None => {
let c = connect_upstream(domain, &upstream_config).await?;
upstream.insert(c)
}
};
// Forward the request — if write fails, reconnect and retry once
if let Err(e) = conn.write_all(&request_buf).await {
debug!(domain, error = %e, "MITM: upstream write failed, reconnecting");
let c = connect_upstream(domain, &upstream_config).await?;
let conn = upstream.insert(c);
conn.write_all(&request_buf)
.await
.map_err(|e| format!("Write to upstream (retry): {e}"))?;
}
let conn = upstream.as_mut().unwrap();
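// ── Stream response back to client ──────────────────────────────────
let mut streaming_acc = StreamingAccumulator::new();
let mut is_streaming_response = false;
let mut headers_parsed = false;
// Only buffer response body for non-streaming (for usage parsing)
let mut non_streaming_buf: Option<Vec<u8>> = None;
// Track if upstream connection is still usable after this response
let mut upstream_ok = true;
// Response framing: lets us stop at the end of the body on a keep-alive upstream
// instead of waiting for EOF or the read timeout
let mut expected_body_len: Option<usize> = None; // Content-Length, if present
let mut is_chunked = false; // Transfer-Encoding: chunked
let mut body_received = 0usize;
let mut body_tail: Vec<u8> = Vec::with_capacity(8);
/// Count newly received body bytes and keep the last 5 (length of the chunked terminator).
fn track_body(received: &mut usize, tail: &mut Vec<u8>, data: &[u8]) {
*received += data.len();
tail.extend_from_slice(data);
let excess = tail.len().saturating_sub(5);
tail.drain(..excess);
}
// Per-request timeout: 5 minutes (covers large context API calls)
const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
loop {
let n = match tokio::time::timeout(READ_TIMEOUT, conn.read(&mut tmp)).await {
Ok(Ok(0)) => {
// Upstream closed — connection is no longer reusable
upstream_ok = false;
break;
}
Ok(Ok(n)) => n,
Ok(Err(e)) => {
debug!(domain, error = %e, "MITM: upstream read finished");
upstream_ok = false;
break;
}
Err(_) => {
warn!(domain, "MITM: upstream read timed out after 5 minutes");
upstream_ok = false;
break;
}
};
let chunk = &tmp[..n];
if !headers_parsed {
// Buffer until we see the end of the response headers
let buf = non_streaming_buf.get_or_insert_with(|| Vec::with_capacity(1024 * 64));
buf.extend_from_slice(chunk);
let Some(fallback_end) = find_headers_end(buf) else {
continue;
};
// Use httparse for response header parsing
let mut resp_headers = [httparse::EMPTY_HEADER; 64];
let mut resp = httparse::Response::new(&mut resp_headers);
let hdr_end = match resp.parse(buf) {
Ok(httparse::Status::Complete(n)) => n,
_ => fallback_end, // Fallback to manual detection
};
// Detect content type, framing, and connection handling from parsed headers
for header in resp.headers.iter() {
let Ok(val) = std::str::from_utf8(header.value) else {
continue;
};
if header.name.eq_ignore_ascii_case("content-type") {
if val.contains("text/event-stream") {
is_streaming_response = true;
}
} else if header.name.eq_ignore_ascii_case("content-length") {
expected_body_len = val.trim().parse().ok();
} else if header.name.eq_ignore_ascii_case("transfer-encoding") {
if val.to_ascii_lowercase().contains("chunked") {
is_chunked = true;
}
} else if header.name.eq_ignore_ascii_case("connection") {
if val.trim().eq_ignore_ascii_case("close") {
upstream_ok = false;
}
}
}
// 204 and 304 responses never carry a body, whatever the headers say
if matches!(resp.code, Some(204 | 304)) {
expected_body_len = Some(0);
is_chunked = false;
}
headers_parsed = true;
track_body(&mut body_received, &mut body_tail, &buf[hdr_end..]);
if is_streaming_response {
// For streaming, parse any SSE data already in the buffer
let body_so_far = String::from_utf8_lossy(&buf[hdr_end..]);
if !body_so_far.is_empty() {
parse_streaming_chunk(&body_so_far, &mut streaming_acc);
}
// Forward the accumulated buffer to client
if let Err(e) = client.write_all(buf).await {
warn!(error = %e, "MITM: write to client failed");
// The rest of the response is still unread, so do not reuse this upstream
upstream_ok = false;
break;
}
non_streaming_buf = None;
}
// Non-streaming: the body stays buffered for usage parsing
} else if is_streaming_response {
// Streaming: parse SSE events and forward immediately
let chunk_str = String::from_utf8_lossy(chunk);
parse_streaming_chunk(&chunk_str, &mut streaming_acc);
if let Err(e) = client.write_all(chunk).await {
warn!(error = %e, "MITM: write to client failed (client disconnected?)");
upstream_ok = false;
break;
}
track_body(&mut body_received, &mut body_tail, chunk);
} else {
// Non-streaming: keep accumulating to parse usage at the end
if let Some(ref mut buf) = non_streaming_buf {
buf.extend_from_slice(chunk);
}
track_body(&mut body_received, &mut body_tail, chunk);
}
// Stop once the body is complete. The chunked check is the same end-of-buffer
// heuristic used for requests; with neither header we still read until EOF.
let body_complete = if is_chunked {
body_tail.ends_with(b"0\r\n\r\n")
} else {
expected_body_len.is_some_and(|len| body_received >= len)
};
if body_complete {
break;
}
}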
// Forward non-streaming response all at once
if !is_streaming_response {
if let Some(ref buf) = non_streaming_buf {
if let Err(e) = client.write_all(buf).await {
warn!(error = %e, "MITM: write to client failed");
}
}
}
// Capture usage data
if is_streaming_response {
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
let usage = streaming_acc.into_usage();
store.record_usage(cascade_hint.as_deref(), usage).await;
}
} else if let Some(ref buf) = non_streaming_buf {
if let Some(body_start) = find_headers_end(buf) {
let body = &buf[body_start..];
if let Some(usage) = parse_non_streaming_response(body) {
store.record_usage(cascade_hint.as_deref(), usage).await;
}
}
}
// If upstream closed, drop the connection so next iteration reconnects
if !upstream_ok {
upstream = None;
}
} // end keep-alive loop
}
/// Handle a passthrough connection: transparent TCP tunnel to upstream.
async fn handle_passthrough(
mut client: TcpStream,
domain: &str,
port: u16,
) -> Result<(), String> {
trace!(domain, port, "MITM: transparent tunnel");
let mut upstream = TcpStream::connect(format!("{domain}:{port}"))
.await
.map_err(|e| format!("Connect to {domain}:{port}: {e}"))?;
// Bidirectional copy
match tokio::io::copy_bidirectional(&mut client, &mut upstream).await {
Ok((client_to_server, server_to_client)) => {
trace!(domain, client_to_server, server_to_client, "MITM: tunnel closed");
}
Err(e) => {
trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
}
}
Ok(())
}
/// Check if buffer contains a complete HTTP request (headers + full body).
/// Uses `httparse` for zero-copy, case-insensitive header parsing.
fn has_complete_http_request(buf: &[u8]) -> bool {
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut req = httparse::Request::new(&mut headers);
let headers_end = match req.parse(buf) {
Ok(httparse::Status::Complete(n)) => n,
_ => return false, // Incomplete or parse error — need more data
};
// Look for Content-Length
for header in req.headers.iter() {
if header.name.eq_ignore_ascii_case("content-length") {
if let Ok(val) = std::str::from_utf8(header.value) {
if let Ok(len) = val.trim().parse::<usize>() {
// Subtract instead of adding so an oversized Content-Length cannot overflow
return buf.len() - headers_end >= len;
}
}
}
if header.name.eq_ignore_ascii_case("transfer-encoding") {
if let Ok(val) = std::str::from_utf8(header.value) {
if val.trim().eq_ignore_ascii_case("chunked") {
let body = &buf[headers_end..];
return body.len() >= 5 && body.ends_with(b"0\r\n\r\n");
}
}
}
}
// No Content-Length or Transfer-Encoding — no body expected (e.g., GET)
true
}
/// Find the end of HTTP headers (position after \r\n\r\n).
fn find_headers_end(buf: &[u8]) -> Option<usize> {
buf.windows(4)
.position(|w| w == b"\r\n\r\n")
.map(|pos| pos + 4)
}
/// Parse HTTP request metadata from raw bytes using `httparse`.
/// Returns (headers_end, content_length, is_streaming_request).
fn parse_http_request_meta(buf: &[u8]) -> (usize, usize, bool) {
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut req = httparse::Request::new(&mut headers);
let headers_end = match req.parse(buf) {
Ok(httparse::Status::Complete(n)) => n,
_ => {
// Fallback if httparse can't parse
return (find_headers_end(buf).unwrap_or(buf.len()), 0, false);
}
};
let mut content_length = 0usize;
for header in req.headers.iter() {
if header.name.eq_ignore_ascii_case("content-length") {
if let Ok(val) = std::str::from_utf8(header.value) {
content_length = val.trim().parse().unwrap_or(0);
}
}
}
// Check if request body asks for streaming
let is_streaming = if headers_end < buf.len() {
let body_str = String::from_utf8_lossy(&buf[headers_end..]);
body_str.contains("\"stream\":true") || body_str.contains("\"stream\": true")
} else {
false
};
(headers_end, content_length, is_streaming)
}
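
The CONNECT handling and the intercept decision above can be illustrated in isolation. This is a simplified, std-only sketch under stated assumptions: `parse_connect_target` and `should_intercept` are hypothetical names, and the two domain lists are shortened stand-ins for `INTERCEPT_DOMAINS` and `PASSTHROUGH_DOMAINS`. It applies the same three matching rules: exact match, subdomain (`.` separator), and regional prefix (`-` separator).

```rust
// Shortened, illustrative lists (assumption: the real lists are longer).
const INTERCEPT: &[&str] = &["aiplatform.googleapis.com", "api.anthropic.com"];
const PASSTHROUGH: &[&str] = &["oauth2.googleapis.com"];

/// Parse `CONNECT host:port HTTP/1.1` into (host, port). The port defaults to 443.
fn parse_connect_target(line: &str) -> Option<(&str, u16)> {
    let mut parts = line.split_whitespace();
    if parts.next()? != "CONNECT" {
        return None;
    }
    let target = parts.next()?;
    parts.next()?; // the HTTP version token must be present
    match target.rsplit_once(':') {
        Some((host, port)) => Some((host, port.parse().unwrap_or(443))),
        None => Some((target, 443)),
    }
}

/// Passthrough entries win; otherwise intercept on exact, `.`-suffix, or `-`-suffix match.
fn should_intercept(domain: &str) -> bool {
    if PASSTHROUGH.iter().any(|p| *p == domain) {
        return false;
    }
    INTERCEPT.iter().any(|d| {
        domain == *d
            || domain
                .strip_suffix(*d)
                .is_some_and(|prefix| prefix.ends_with('.') || prefix.ends_with('-'))
    })
}
```

Anything that matches neither list falls through to a transparent tunnel, which keeps the proxy from terminating TLS for hosts it was never meant to inspect.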

163
src/mitm/store.rs Normal file

@@ -0,0 +1,163 @@
//! Shared store for intercepted API usage data.
//!
//! The MITM proxy writes usage data here; the API handlers read from it.
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};
use tracing::debug;
/// Token usage from an intercepted API response.
///
/// Covers both Anthropic JSON/SSE responses and Google gRPC protobuf responses.
/// Fields map to the superset of Anthropic's `usage` object and Google's `ModelUsageStats` proto.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ApiUsage {
pub input_tokens: u64,
pub output_tokens: u64,
/// Anthropic: cache_creation_input_tokens / Google: cache_write_tokens
pub cache_creation_input_tokens: u64,
/// Anthropic: cache_read_input_tokens / Google: cache_read_tokens
pub cache_read_input_tokens: u64,
/// Google-specific: thinking/reasoning output tokens (extended thinking)
pub thinking_output_tokens: u64,
/// Google-specific: response output tokens (non-thinking portion)
pub response_output_tokens: u64,
/// Total cost in USD (if provided by the API).
pub total_cost_usd: Option<f64>,
/// The actual model that served the request.
pub model: Option<String>,
/// Stop reason / finish reason from the API.
pub stop_reason: Option<String>,
/// API provider (e.g. "anthropic", "google")
pub api_provider: Option<String>,
/// gRPC method path (e.g. "/google.internal.cloud.code.v1internal.PredictionService/GenerateContent")
pub grpc_method: Option<String>,
/// Timestamp when this usage was captured.
pub captured_at: u64,
}
/// Thread-safe store for intercepted data.
///
/// Keyed by cascade ID. Usage from requests whose cascade could not be identified
/// is stored under the `_latest` key.
#[derive(Clone)]
pub struct MitmStore {
/// Most recent usage per cascade ID.
latest_usage: Arc<RwLock<HashMap<String, ApiUsage>>>,
/// Global aggregate stats.
stats: Arc<RwLock<MitmStats>>,
}
/// Aggregate statistics across all intercepted traffic.
#[derive(Debug, Clone, Default, Serialize)]
pub struct MitmStats {
pub total_requests: u64,
pub total_input_tokens: u64,
pub total_output_tokens: u64,
pub total_cache_read_tokens: u64,
pub total_cache_creation_tokens: u64,
pub total_thinking_output_tokens: u64,
pub total_response_output_tokens: u64,
/// Per-model usage breakdown (model name → stats).
pub per_model: HashMap<String, ModelStats>,
}
/// Per-model usage counters.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ModelStats {
pub requests: u64,
pub input_tokens: u64,
pub output_tokens: u64,
pub cache_read_tokens: u64,
pub cache_creation_tokens: u64,
}
impl MitmStore {
pub fn new() -> Self {
Self {
latest_usage: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(MitmStats::default())),
}
}
/// Record a completed API exchange with usage data.
pub async fn record_usage(&self, cascade_id: Option<&str>, usage: ApiUsage) {
debug!(
input = usage.input_tokens,
output = usage.output_tokens,
cache_read = usage.cache_read_input_tokens,
cache_create = usage.cache_creation_input_tokens,
thinking = usage.thinking_output_tokens,
response = usage.response_output_tokens,
model = ?usage.model,
provider = ?usage.api_provider,
grpc = ?usage.grpc_method,
"MITM captured API usage"
);
// Update aggregate stats
{
let mut stats = self.stats.write().await;
stats.total_requests += 1;
stats.total_input_tokens += usage.input_tokens;
stats.total_output_tokens += usage.output_tokens;
stats.total_cache_read_tokens += usage.cache_read_input_tokens;
stats.total_cache_creation_tokens += usage.cache_creation_input_tokens;
stats.total_thinking_output_tokens += usage.thinking_output_tokens;
stats.total_response_output_tokens += usage.response_output_tokens;
// Per-model breakdown
if let Some(ref model_name) = usage.model {
let model_stats = stats.per_model.entry(model_name.clone()).or_default();
model_stats.requests += 1;
model_stats.input_tokens += usage.input_tokens;
model_stats.output_tokens += usage.output_tokens;
model_stats.cache_read_tokens += usage.cache_read_input_tokens;
model_stats.cache_creation_tokens += usage.cache_creation_input_tokens;
}
}
// Store latest usage for the cascade (if we can identify it)
let key = cascade_id.map(|s| s.to_string()).unwrap_or_else(|| "_latest".to_string());
let mut latest = self.latest_usage.write().await;
latest.insert(key, usage);
// Evict old entries to prevent unbounded memory growth
const MAX_ENTRIES: usize = 500;
if latest.len() > MAX_ENTRIES {
// Find the oldest entry by captured_at and remove it
let oldest_key = latest
.iter()
.min_by_key(|(_, v)| v.captured_at)
.map(|(k, _)| k.clone());
if let Some(key) = oldest_key {
latest.remove(&key);
}
}
}
/// Get the latest usage for a cascade, consuming it (one-shot read).
///
/// Only returns exact cascade_id matches — no cross-cascade fallback.
/// The `_latest` key is only consumed when the caller explicitly requests it
/// (i.e., when the MITM couldn't identify the cascade).
pub async fn take_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
let mut latest = self.latest_usage.write().await;
latest.remove(cascade_id)
}
/// Peek at the latest usage without consuming it.
#[allow(dead_code)]
pub async fn peek_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
let latest = self.latest_usage.read().await;
latest.get(cascade_id)
.cloned()
}
/// Get aggregate stats.
pub async fn stats(&self) -> MitmStats {
self.stats.read().await.clone()
}
}
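
The store's record, evict, and one-shot take behaviour can be sketched without the async locks. This is a simplified synchronous illustration, not the `MitmStore` API: `UsageStore` and `Usage` are hypothetical types, and the entry cap is a constructor argument rather than the fixed `MAX_ENTRIES` constant.

```rust
use std::collections::HashMap;

/// Trimmed-down usage record (assumption: only the fields needed for the demo).
#[derive(Debug, Clone, Default, PartialEq)]
struct Usage {
    input_tokens: u64,
    output_tokens: u64,
    captured_at: u64,
}

/// Latest usage per cascade ID, capped at `max_entries`.
struct UsageStore {
    latest: HashMap<String, Usage>,
    max_entries: usize,
}

impl UsageStore {
    fn new(max_entries: usize) -> Self {
        Self { latest: HashMap::new(), max_entries }
    }

    /// Insert under the cascade ID (or `_latest`), then evict the oldest entry if over the cap.
    fn record(&mut self, cascade_id: Option<&str>, usage: Usage) {
        let key = cascade_id.unwrap_or("_latest").to_string();
        self.latest.insert(key, usage);
        if self.latest.len() > self.max_entries {
            let oldest = self
                .latest
                .iter()
                .min_by_key(|(_, u)| u.captured_at)
                .map(|(k, _)| k.clone());
            if let Some(k) = oldest {
                self.latest.remove(&k);
            }
        }
    }

    /// One-shot read: returns the entry and removes it.
    fn take(&mut self, cascade_id: &str) -> Option<Usage> {
        self.latest.remove(cascade_id)
    }
}
```

Eviction depends on `captured_at` being populated by the producer. If it is left at its default of 0, the entry that was just inserted can be the one chosen for removal.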

233
src/proto.rs Normal file

@@ -0,0 +1,233 @@
//! Protobuf wire-format encoder — byte-exact match to the real Antigravity webview.
//!
//! This is a minimal, hand-rolled encoder. We do NOT use prost or any codegen
//! because we need precise control over field ordering and encoding to produce
//! byte-identical output to the captured webview traffic.
use crate::constants::{client_version, CLIENT_NAME};
// ─── Wire primitives ────────────────────────────────────────────────────────
/// Encode a varint (base-128, little-endian, MSB continuation).
pub fn varint(mut val: u64) -> Vec<u8> {
if val == 0 {
return vec![0x00];
}
let mut out = Vec::with_capacity(10);
while val > 0x7F {
out.push(((val & 0x7F) | 0x80) as u8);
val >>= 7;
}
out.push((val & 0x7F) as u8);
out
}
/// Encode a field tag (field_number << 3 | wire_type).
pub fn tag(field: u32, wire: u8) -> Vec<u8> {
varint(((field as u64) << 3) | (wire as u64))
}
/// Wire type 2: length-delimited string/bytes field.
pub fn proto_string(field: u32, val: &[u8]) -> Vec<u8> {
let mut out = tag(field, 2);
out.extend(varint(val.len() as u64));
out.extend_from_slice(val);
out
}
/// Wire type 2: length-delimited sub-message field.
pub fn proto_message(field: u32, inner: &[u8]) -> Vec<u8> {
let mut out = tag(field, 2);
out.extend(varint(inner.len() as u64));
out.extend_from_slice(inner);
out
}
/// Wire type 0: boolean field (varint 0 or 1).
pub fn bool_field(field: u32, val: bool) -> Vec<u8> {
let mut out = tag(field, 0);
out.extend(varint(if val { 1 } else { 0 }));
out
}
/// Wire type 0: varint field.
pub fn varint_field(field: u32, val: u64) -> Vec<u8> {
let mut out = tag(field, 0);
out.extend(varint(val));
out
}
// ─── SendUserCascadeMessageRequest builder ───────────────────────────────────
/// Build the `SendUserCascadeMessageRequest` protobuf binary.
///
/// Produces a byte-exact match to real Antigravity webview traffic.
/// Verified against Chrome DevTools network capture 2026-02-12.
///
/// Field layout:
/// 1: cascade_id (string)
/// 2: { 1: text } (message)
/// 3: metadata { 1: client_name, 3: oauth_token, 4: "en", 7: version, 12: client_name }
/// 5: PlannerConfig { 1: inner_config, 7: { 1: 1 } }
/// inner_config contains: f2 (conv mode), f13 (tool config), f15 (model), f21 (ephemeral), f32 (knowledge)
/// 11: conversation_history = true
pub fn build_request(cascade_id: &str, text: &str, oauth_token: &str, model_enum: u32) -> Vec<u8> {
let mut msg = Vec::with_capacity(256);
// Field 1: cascade_id
msg.extend(proto_string(1, cascade_id.as_bytes()));
// Field 2: { field 1: text }
msg.extend(proto_message(2, &proto_string(1, text.as_bytes())));
// Field 3: Metadata (Auth + Client ID)
let mut meta = Vec::new();
meta.extend(proto_string(1, CLIENT_NAME.as_bytes()));
meta.extend(proto_string(3, oauth_token.as_bytes()));
meta.extend(proto_string(4, b"en"));
meta.extend(proto_string(7, client_version().as_bytes()));
meta.extend(proto_string(12, CLIENT_NAME.as_bytes()));
msg.extend(proto_message(3, &meta));
// Field 5: PlannerConfig
let mut inner = Vec::new();
// field 2: conversational mode { f4: 1, f14: 0 }
let conv_mode = [varint_field(4, 1), varint_field(14, 0)].concat();
inner.extend(proto_message(2, &conv_mode));
// field 13: toolConfig
// field 8 (runCommand): field 3 (autoCommandConfig) -> field 6 (policy) = 3 (EAGER)
// field 33 (artifactReviewPolicy): field 1 = 2 (TURBO)
let run_cmd = proto_message(3, &varint_field(6, 3));
let tool_config = [
proto_message(8, &run_cmd),
proto_message(33, &varint_field(1, 2)),
]
.concat();
inner.extend(proto_message(13, &tool_config));
// field 15: requested model { f1: model_enum }
inner.extend(proto_message(15, &varint_field(1, model_enum as u64)));
// field 21: ephemeral messages config { f1: 1 }
inner.extend(proto_message(21, &varint_field(1, 1)));
// field 32: knowledge config { f1: true }
inner.extend(proto_message(32, &bool_field(1, true)));
// Field 5 wraps: field 1 (inner config) + field 7 { f1: 1 }
let f5_payload = [
proto_message(1, &inner),
proto_message(7, &varint_field(1, 1)),
]
.concat();
msg.extend(proto_message(5, &f5_payload));
// Field 11: conversation history flag
msg.extend(bool_field(11, true));
msg
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_varint_zero() {
assert_eq!(varint(0), vec![0x00]);
}
#[test]
fn test_varint_small() {
assert_eq!(varint(1), vec![0x01]);
assert_eq!(varint(127), vec![0x7F]);
}
#[test]
fn test_varint_multibyte() {
assert_eq!(varint(128), vec![0x80, 0x01]);
assert_eq!(varint(300), vec![0xAC, 0x02]);
}
#[test]
fn test_varint_1026() {
// model_enum 1026 = 0x402 → varint [0x82, 0x08]
assert_eq!(varint(1026), vec![0x82, 0x08]);
}
#[test]
fn test_tag() {
// field 1, wire type 2 (LEN) = (1 << 3) | 2 = 0x0A
assert_eq!(tag(1, 2), vec![0x0A]);
// field 3, wire type 0 (VARINT) = (3 << 3) | 0 = 0x18
assert_eq!(tag(3, 0), vec![0x18]);
// field 33, wire type 2 = (33 << 3) | 2 = 266 → varint [0x8A, 0x02]
assert_eq!(tag(33, 2), vec![0x8A, 0x02]);
}
#[test]
fn test_proto_string() {
let result = proto_string(1, b"hi");
// tag(1,2) = 0x0A, len=2, 'h'=0x68, 'i'=0x69
assert_eq!(result, vec![0x0A, 0x02, 0x68, 0x69]);
}
#[test]
fn test_build_request_deterministic() {
let a = build_request("cid", "hello", "ya29.tok", 1026);
let b = build_request("cid", "hello", "ya29.tok", 1026);
assert_eq!(a, b, "build_request must be deterministic");
}
#[test]
fn test_build_request_structure() {
let msg = build_request("test-cascade-id", "hello", "ya29.test-token", 1026);
assert_eq!(msg[0], 0x0A, "first byte must be field 1 tag");
let cascade_bytes = b"test-cascade-id";
assert!(
msg.windows(cascade_bytes.len())
.any(|w| w == cascade_bytes),
"cascade_id must appear in output"
);
assert!(
msg.windows(5).any(|w| w == b"hello"),
"text must appear in output"
);
let token_bytes = b"ya29.test-token";
assert!(
msg.windows(token_bytes.len()).any(|w| w == token_bytes),
"oauth token must appear in output"
);
// model enum 1026 varint [0x82, 0x08]
assert!(
msg.windows(2).any(|w| w == [0x82, 0x08]),
"model enum 1026 varint must appear in output"
);
// field 11 bool true at end: tag(11,0)=0x58, varint(1)=0x01
let len = msg.len();
assert_eq!(msg[len - 2], 0x58);
assert_eq!(msg[len - 1], 0x01);
}
/// Cross-verified against Python output: 127/127 bytes identical.
#[test]
fn test_byte_exact_match_with_python() {
let msg = build_request("test-cascade-id", "hello", "ya29.test-token", 1026);
let hex: String = msg.iter().map(|b| format!("{:02x}", b)).collect();
let expected = "0a0f746573742d636173636164652d696412070a0568656c6c6f\
1a370a0b616e7469677261766974791a0f796132392e746573742d746f6b656e\
2202656e3a06312e31362e35620b616e7469677261766974792a280a22120420\
0170006a0b42041a0230038a020208027a03088208aa010208018202020801\
3a0208015801";
assert_eq!(hex, expected, "must be byte-exact match with Python");
assert_eq!(msg.len(), 127);
}
}
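
When debugging captured traffic it can help to read varints back as well as write them. The sketch below pairs an encoder equivalent to `varint` in `src/proto.rs` with a decoder; `encode_varint` and `decode_varint` are hypothetical names and the decoder is not part of this commit.

```rust
/// Encoder equivalent to `varint` above: 7 payload bits per byte,
/// least-significant group first, high bit set on every byte but the last.
pub fn encode_varint(mut val: u64) -> Vec<u8> {
    let mut out = Vec::with_capacity(10);
    while val > 0x7F {
        out.push(((val & 0x7F) | 0x80) as u8);
        val >>= 7;
    }
    out.push(val as u8);
    out
}

/// Hypothetical decoder: returns `(value, bytes_consumed)`, or `None` if the
/// input ends before a terminating byte or runs past 10 bytes.
pub fn decode_varint(buf: &[u8]) -> Option<(u64, usize)> {
    let mut val: u64 = 0;
    for (i, &b) in buf.iter().enumerate().take(10) {
        // Each byte contributes its low 7 bits at position 7 * i.
        val |= ((b & 0x7F) as u64) << (7 * i);
        if b & 0x80 == 0 {
            return Some((val, i + 1));
        }
    }
    None
}
```

For instance, `decode_varint(&[0x82, 0x08])` yields `Some((1026, 2))`, the inverse of the model-enum case checked in `test_varint_1026`.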

218
src/quota.rs Normal file

@@ -0,0 +1,218 @@
//! Quota monitor — polls the local LS `GetUserStatus` to track
//! prompt/flow credits and per-model rate limits without touching Google servers.
use serde::Serialize;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, warn};
/// How often to poll the LS for fresh quota data (seconds).
const POLL_INTERVAL_SECS: u64 = 60;
// ─── Public types ────────────────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Default)]
pub struct QuotaSnapshot {
/// When this snapshot was last refreshed (ISO-8601 UTC).
pub last_updated: String,
/// Overall plan info.
pub plan: PlanInfo,
/// Monthly credit balances.
pub credits: CreditInfo,
/// Per-model rate limits.
pub models: Vec<ModelQuota>,
}
#[derive(Debug, Clone, Serialize, Default)]
pub struct PlanInfo {
pub plan_name: String,
pub tier_id: String,
pub tier_name: String,
}
#[derive(Debug, Clone, Serialize, Default)]
pub struct CreditInfo {
pub prompt_available: i64,
pub prompt_total: i64,
pub prompt_used_pct: f64,
pub flow_available: i64,
pub flow_total: i64,
pub flow_used_pct: f64,
pub flex_purchasable: i64,
pub can_buy_more: bool,
}
#[derive(Debug, Clone, Serialize, Default)]
pub struct ModelQuota {
pub label: String,
pub model_id: String,
/// 0.0-1.0 remaining fraction (1.0 = full quota).
pub remaining_fraction: f64,
/// Percentage remaining (0-100).
pub remaining_pct: f64,
/// ISO-8601 UTC reset time.
pub reset_time: String,
/// Seconds until reset (negative = already reset; 0 if the reset time is missing or unparsable).
pub reset_in_secs: i64,
/// Human-readable countdown.
pub reset_in_human: String,
}
// ─── Quota Store ─────────────────────────────────────────────────────────────
#[derive(Clone)]
pub struct QuotaStore {
inner: Arc<RwLock<QuotaSnapshot>>,
}
impl QuotaStore {
pub fn new() -> Self {
Self {
inner: Arc::new(RwLock::new(QuotaSnapshot::default())),
}
}
/// Get the latest cached snapshot.
pub async fn snapshot(&self) -> QuotaSnapshot {
self.inner.read().await.clone()
}
/// Start the background polling loop. Call once at startup.
pub fn start_polling(self, backend: Arc<crate::backend::Backend>) {
tokio::spawn(async move {
// Initial poll immediately.
self.poll_once(&backend).await;
let mut interval = tokio::time::interval(
std::time::Duration::from_secs(POLL_INTERVAL_SECS),
);
interval.tick().await; // consume the first immediate tick
loop {
interval.tick().await;
self.poll_once(&backend).await;
}
});
}
async fn poll_once(&self, backend: &crate::backend::Backend) {
match backend
.call_json("GetUserStatus", &serde_json::json!({}))
.await
{
Ok((200, data)) => {
let snapshot = parse_user_status(&data);
debug!(
"Quota poll: prompt {}/{} flow {}/{}",
snapshot.credits.prompt_available,
snapshot.credits.prompt_total,
snapshot.credits.flow_available,
snapshot.credits.flow_total,
);
*self.inner.write().await = snapshot;
}
Ok((status, data)) => {
warn!("GetUserStatus returned {status}: {data}");
}
Err(e) => {
warn!("GetUserStatus poll failed: {e}");
}
}
}
}
// ─── Parsing ─────────────────────────────────────────────────────────────────
fn parse_user_status(data: &serde_json::Value) -> QuotaSnapshot {
let now = chrono::Utc::now();
let us = &data["userStatus"];
let ps = &us["planStatus"];
let pi = &ps["planInfo"];
let ut = &us["userTier"];
let prompt_total = pi["monthlyPromptCredits"].as_i64().unwrap_or(0);
let prompt_avail = ps["availablePromptCredits"].as_i64().unwrap_or(0);
let flow_total = pi["monthlyFlowCredits"].as_i64().unwrap_or(0);
let flow_avail = ps["availableFlowCredits"].as_i64().unwrap_or(0);
let prompt_used_pct = if prompt_total > 0 {
((prompt_total - prompt_avail) as f64 / prompt_total as f64) * 100.0
} else {
0.0
};
let flow_used_pct = if flow_total > 0 {
((flow_total - flow_avail) as f64 / flow_total as f64) * 100.0
} else {
0.0
};
let models = us["cascadeModelConfigData"]["clientModelConfigs"]
.as_array()
.map(|arr| {
arr.iter()
.map(|m| {
let label = m["label"].as_str().unwrap_or("").to_string();
let model_id = m["modelOrAlias"]["model"]
.as_str()
.unwrap_or("")
.to_string();
let frac = m["quotaInfo"]["remainingFraction"]
.as_f64()
.unwrap_or(0.0);
let reset_str = m["quotaInfo"]["resetTime"]
.as_str()
.unwrap_or("")
.to_string();
let reset_in_secs = if !reset_str.is_empty() {
chrono::DateTime::parse_from_rfc3339(&reset_str)
.map(|dt| (dt.with_timezone(&chrono::Utc) - now).num_seconds())
.unwrap_or(0)
} else {
0
};
let reset_in_human = if reset_in_secs > 0 {
let h = reset_in_secs / 3600;
let m = (reset_in_secs % 3600) / 60;
format!("{h}h {m}m")
} else {
"available".to_string()
};
ModelQuota {
label,
model_id,
remaining_fraction: frac,
remaining_pct: frac * 100.0,
reset_time: reset_str,
reset_in_secs,
reset_in_human,
}
})
.collect()
})
.unwrap_or_default();
QuotaSnapshot {
last_updated: now.to_rfc3339(),
plan: PlanInfo {
plan_name: pi["planName"].as_str().unwrap_or("").to_string(),
tier_id: ut["id"].as_str().unwrap_or("").to_string(),
tier_name: ut["name"].as_str().unwrap_or("").to_string(),
},
credits: CreditInfo {
prompt_available: prompt_avail,
prompt_total,
prompt_used_pct,
flow_available: flow_avail,
flow_total,
flow_used_pct,
flex_purchasable: pi["monthlyFlexCreditPurchaseAmount"]
.as_i64()
.unwrap_or(0),
can_buy_more: pi["canBuyMoreCredits"].as_bool().unwrap_or(false),
},
models,
}
}
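
The credit and countdown arithmetic in `parse_user_status` can be checked by hand with two small pure functions. This is a sketch under the assumption that the logic is lifted out unchanged; `used_pct` and `countdown` are hypothetical helper names that do not exist in the commit.

```rust
/// Percentage of credits consumed; 0.0 when no total is reported,
/// which avoids dividing by zero.
pub fn used_pct(total: i64, available: i64) -> f64 {
    if total > 0 {
        ((total - available) as f64 / total as f64) * 100.0
    } else {
        0.0
    }
}

/// Human-readable countdown in whole hours and minutes.
/// Non-positive values mean the quota has already reset.
pub fn countdown(reset_in_secs: i64) -> String {
    if reset_in_secs > 0 {
        let h = reset_in_secs / 3600;
        let m = (reset_in_secs % 3600) / 60;
        format!("{h}h {m}m")
    } else {
        "available".to_string()
    }
}
```

With 500 monthly credits and 125 remaining, `used_pct(500, 125)` is 75.0, and `countdown(3725)` gives `"1h 2m"`. Note that any positive value under a minute renders as `"0h 0m"`, since seconds are truncated.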

152
src/session.rs Normal file

@@ -0,0 +1,152 @@
//! Cascade session manager — maps session IDs to cascade IDs for reuse.
//!
//! Mimics real webview behavior: one chat tab = one cascade with many messages.
//! Without this, every API call creates a new cascade — an obvious automation
//! fingerprint (100 calls = 100 single-message cascades).
use std::collections::HashMap;
use std::time::Instant;
use tokio::sync::RwLock;
const DEFAULT_SESSION: &str = "__default__";
const SESSION_TTL_SECS: u64 = 3600 * 4; // 4 hours
#[derive(Clone)]
struct Session {
cascade_id: String,
created: Instant,
last_used: Instant,
msg_count: u64,
}
pub struct SessionManager {
sessions: RwLock<HashMap<String, Session>>,
}
/// Result of session resolution.
pub struct SessionResult {
pub cascade_id: String,
}
impl SessionManager {
pub fn new() -> Self {
Self {
sessions: RwLock::new(HashMap::new()),
}
}
/// Get existing cascade for session, or create a new one.
///
/// - `session_id = None` → use default session
/// - `session_id = Some("new")` → always create fresh cascade
/// - `session_id = Some("my-task")` → reuse cascade for that task
///
/// Uses double-checked locking to avoid TOCTOU races: after creating a cascade,
/// it re-acquires the lock and checks whether another request won the race.
pub async fn get_or_create<F, Fut>(
&self,
session_id: Option<&str>,
create_fn: F,
) -> Result<SessionResult, String>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = Result<String, String>>,
{
// "new" always creates a fresh cascade
if session_id == Some("new") {
let cascade_id = create_fn().await?;
let new_sid = format!("s-{}", &uuid::Uuid::new_v4().to_string()[..8]);
let mut sessions = self.sessions.write().await;
sessions.insert(
new_sid,
Session {
cascade_id: cascade_id.clone(),
created: Instant::now(),
last_used: Instant::now(),
msg_count: 1, // counts this request's message, matching the non-"new" path
},
);
return Ok(SessionResult {
cascade_id,
});
}
let sid = session_id.unwrap_or(DEFAULT_SESSION).to_string();
// Check existing — only need write lock for cleanup + mutation
{
let mut sessions = self.sessions.write().await;
cleanup_expired(&mut sessions);
if let Some(sess) = sessions.get_mut(&sid) {
sess.last_used = Instant::now();
sess.msg_count += 1;
return Ok(SessionResult {
cascade_id: sess.cascade_id.clone(),
});
}
}
// Lock released before async create_fn
// Create new cascade (this may take a while — lock is NOT held)
let cascade_id = create_fn().await?;
// Double-check: another request may have raced us and created the same session
{
let mut sessions = self.sessions.write().await;
if let Some(existing) = sessions.get_mut(&sid) {
// Another request won the race — use their cascade, discard ours
existing.last_used = Instant::now();
existing.msg_count += 1;
return Ok(SessionResult {
cascade_id: existing.cascade_id.clone(),
});
}
sessions.insert(
sid.clone(),
Session {
cascade_id: cascade_id.clone(),
created: Instant::now(),
last_used: Instant::now(),
msg_count: 1,
},
);
}
Ok(SessionResult {
cascade_id,
})
}
/// List all active sessions.
pub async fn list_sessions(&self) -> serde_json::Value {
let mut sessions = self.sessions.write().await;
cleanup_expired(&mut sessions);
let now = Instant::now();
let mut map = serde_json::Map::new();
for (sid, sess) in sessions.iter() {
map.insert(
sid.clone(),
serde_json::json!({
"cascade_id": sess.cascade_id,
"msg_count": sess.msg_count,
"age_seconds": now.duration_since(sess.created).as_secs(),
"idle_seconds": now.duration_since(sess.last_used).as_secs(),
}),
);
}
serde_json::Value::Object(map)
}
/// Delete a session. Returns true if it existed.
pub async fn delete_session(&self, session_id: &str) -> bool {
let mut sessions = self.sessions.write().await;
sessions.remove(session_id).is_some()
}
}
fn cleanup_expired(sessions: &mut HashMap<String, Session>) {
let now = Instant::now();
sessions.retain(|_, s| {
now.duration_since(s.last_used).as_secs() < SESSION_TTL_SECS
});
}
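
The check, create, re-check pattern in `get_or_create` can be shown without an async runtime. The following is a simplified synchronous sketch, assuming a `std::sync::Mutex` in place of tokio's `RwLock` and a plain `String` in place of `Session`; `Registry` is a hypothetical type, and the second check is expressed through the map's entry API rather than an explicit `get_mut`.

```rust
use std::collections::HashMap;
use std::sync::Mutex;

/// Hypothetical, simplified stand-in for `SessionManager`.
pub struct Registry {
    sessions: Mutex<HashMap<String, String>>,
}

impl Registry {
    pub fn new() -> Self {
        Self { sessions: Mutex::new(HashMap::new()) }
    }

    /// Return the cascade ID for `sid`, creating one with `create` if absent.
    pub fn get_or_create<F>(&self, sid: &str, create: F) -> Result<String, String>
    where
        F: FnOnce() -> Result<String, String>,
    {
        // First check: return early if the session already exists.
        {
            let sessions = self.sessions.lock().unwrap();
            if let Some(id) = sessions.get(sid) {
                return Ok(id.clone());
            }
        } // Lock released here, before the potentially slow create call.

        let fresh = create()?;

        // Second check: if another caller inserted in the meantime,
        // keep their value and discard ours.
        let mut sessions = self.sessions.lock().unwrap();
        Ok(sessions.entry(sid.to_string()).or_insert(fresh).clone())
    }
}
```

Releasing the lock around `create` keeps slow backend calls from blocking other sessions. The cost is that two racing callers may both create a cascade, and the loser's is simply dropped.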

69
src/warmup.rs Normal file

@@ -0,0 +1,69 @@
//! Startup warmup and periodic heartbeat — mimics real webview lifecycle.
//!
//! The real Electron webview calls these methods on startup and then sends
//! Heartbeat every ~30 seconds. Without this, the LS sees a "user" that
//! never initializes and never heartbeats — an obvious bot fingerprint.
use crate::backend::Backend;
use rand::Rng;
use std::sync::Arc;
use std::time::Duration;
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};
/// Run the exact startup sequence the real webview performs on load.
///
/// Called BEFORE accepting any API requests. Each call is awaited in order but is
/// best-effort: failures are logged and ignored (the LS might not support all methods).
pub async fn warmup_sequence(backend: &Backend) {
info!("Running webview warmup sequence...");
let calls: &[(&str, serde_json::Value)] = &[
("GetStatus", serde_json::json!({})),
("Heartbeat", serde_json::json!({})),
("GetUserStatus", serde_json::json!({})),
("GetCascadeModelConfigs", serde_json::json!({})),
("GetCascadeModelConfigData", serde_json::json!({})),
("GetWorkspaceInfos", serde_json::json!({})),
("GetWorkingDirectories", serde_json::json!({})),
("GetAllCascadeTrajectories", serde_json::json!({})),
("GetMcpServerStates", serde_json::json!({})),
("GetWebDocsOptions", serde_json::json!({})),
("GetRepoInfos", serde_json::json!({})),
("GetAllSkills", serde_json::json!({})),
("InitializeCascadePanelState", serde_json::json!({})),
];
for (method, body) in calls {
match backend.call_json(method, body).await {
Ok((status, _)) => debug!("Warmup {method}: {status}"),
Err(e) => warn!("Warmup {method} failed: {e}"),
}
// Small delay between calls — real webview doesn't blast them instantly
let delay = rand::thread_rng().gen_range(50..200);
tokio::time::sleep(Duration::from_millis(delay)).await;
}
info!("Warmup complete");
}
/// Spawn a background task that sends Heartbeat every ~30s ± jitter.
///
/// Returns the task's JoinHandle; the task loops until it is aborted (on shutdown).
pub fn start_heartbeat(backend: Arc<Backend>) -> JoinHandle<()> {
tokio::spawn(async move {
loop {
// ~30s interval (± 500ms) — matches real setInterval(30000) precision
let interval_ms = rand::thread_rng().gen_range(29_500..30_500);
tokio::time::sleep(Duration::from_millis(interval_ms)).await;
match backend
.call_json("Heartbeat", &serde_json::json!({}))
.await
{
Ok((status, _)) => debug!("Heartbeat: {status}"),
Err(e) => warn!("Heartbeat failed: {e}"),
}
}
})
}