feat: initial commit — antigravity proxy with MITM, standalone LS, and snapshot tooling
This commit is contained in:
8
.gitignore
vendored
Normal file
8
.gitignore
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# Build
|
||||||
|
/target/
|
||||||
|
|
||||||
|
# Debug artifacts
|
||||||
|
*.log
|
||||||
|
*.txt
|
||||||
|
!README.txt
|
||||||
|
test_output.json
|
||||||
2470
Cargo.lock
generated
Normal file
2470
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
46
Cargo.toml
Normal file
46
Cargo.toml
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
[package]
|
||||||
|
name = "antigravity-proxy"
|
||||||
|
version = "3.0.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
axum = { version = "0.8", features = ["json"] }
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
wreq = { version = "6.0.0-rc.28", features = ["json"] }
|
||||||
|
wreq-util = "3.0.0-rc.10"
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
uuid = { version = "1", features = ["v4"] }
|
||||||
|
regex = "1"
|
||||||
|
async-stream = "0.3"
|
||||||
|
tower-http = { version = "0.6", features = ["cors"] }
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
clap = { version = "4", features = ["derive"] }
|
||||||
|
rand = "0.8"
|
||||||
|
flate2 = "1"
|
||||||
|
brotli = "7"
|
||||||
|
chrono = "0.4"
|
||||||
|
|
||||||
|
# MITM proxy dependencies
|
||||||
|
rcgen = "0.13"
|
||||||
|
rustls = { version = "0.23", features = ["ring"] }
|
||||||
|
tokio-rustls = "0.26"
|
||||||
|
rustls-native-certs = "0.8"
|
||||||
|
rustls-pemfile = "2"
|
||||||
|
time = "0.3"
|
||||||
|
base64 = "0.22"
|
||||||
|
httparse = "1"
|
||||||
|
|
||||||
|
# HTTP/2 + gRPC interception
|
||||||
|
hyper = { version = "1", features = ["http2", "client", "server"] }
|
||||||
|
hyper-util = { version = "0.1", features = ["tokio"] }
|
||||||
|
http-body-util = "0.1"
|
||||||
|
http = "1"
|
||||||
|
bytes = "1"
|
||||||
|
tokio-stream = "0.1"
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
opt-level = "z"
|
||||||
|
lto = true
|
||||||
|
strip = true
|
||||||
155
GEMINI.md
Normal file
155
GEMINI.md
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
# Antigravity Rust Proxy
|
||||||
|
|
||||||
|
OpenAI-compatible proxy that intercepts and relays requests to Google's Antigravity language server, impersonating the real Electron webview.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build
|
||||||
|
cargo build --release
|
||||||
|
|
||||||
|
# Run (language server must be running)
|
||||||
|
RUST_LOG=info ./target/release/antigravity-proxy
|
||||||
|
|
||||||
|
# Custom port
|
||||||
|
RUST_LOG=info ./target/release/antigravity-proxy --port 9000
|
||||||
|
```
|
||||||
|
|
||||||
|
Default port: **8741**
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
|
||||||
|
| Method | Path | Description |
|
||||||
|
| -------- | ---------------------- | ----------------------------------------------------------- |
|
||||||
|
| `POST` | `/v1/responses` | **Responses API** (primary) — supports `stream: true/false` |
|
||||||
|
| `POST` | `/v1/chat/completions` | Chat Completions API (OpenAI compat shim) |
|
||||||
|
| `GET` | `/v1/models` | List available models |
|
||||||
|
| `GET` | `/v1/sessions` | List active sessions |
|
||||||
|
| `DELETE` | `/v1/sessions/:id` | Delete a session |
|
||||||
|
| `POST` | `/v1/token` | Set OAuth token at runtime |
|
||||||
|
| `GET` | `/v1/usage` | MITM-intercepted token usage stats |
|
||||||
|
| `GET` | `/v1/quota` | LS quota — credits, per-model rate limits, reset timers |
|
||||||
|
| `GET` | `/health` | Health check |
|
||||||
|
|
||||||
|
## Available Models
|
||||||
|
|
||||||
|
| Name | Label |
|
||||||
|
| ------------------- | ---------------------------------------- |
|
||||||
|
| `opus-4.6` | Claude Opus 4.6 (Thinking) — **default** |
|
||||||
|
| `opus-4.5` | Claude Opus 4.5 (Thinking) |
|
||||||
|
| `gemini-3-pro-high` | Gemini 3 Pro (High) |
|
||||||
|
| `gemini-3-pro` | Gemini 3 Pro (Low) |
|
||||||
|
| `gemini-3-flash` | Gemini 3 Flash |
|
||||||
|
|
||||||
|
## Example: Responses API
|
||||||
|
|
||||||
|
### Sync
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -s http://localhost:8741/v1/responses \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gemini-3-flash",
|
||||||
|
"input": "Say hello in exactly 3 words",
|
||||||
|
"stream": false,
|
||||||
|
"timeout": 60
|
||||||
|
}' | jq .
|
||||||
|
```
|
||||||
|
|
||||||
|
### Streaming
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -N http://localhost:8741/v1/responses \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gemini-3-flash",
|
||||||
|
"input": "Say hello in exactly 3 words",
|
||||||
|
"stream": true,
|
||||||
|
"timeout": 60
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multi-turn (session reuse)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -s http://localhost:8741/v1/responses \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gemini-3-flash",
|
||||||
|
"input": "What is 2+2?",
|
||||||
|
"conversation": "my-session-1",
|
||||||
|
"stream": false
|
||||||
|
}' | jq .
|
||||||
|
|
||||||
|
# Follow-up in same cascade:
|
||||||
|
curl -s http://localhost:8741/v1/responses \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gemini-3-flash",
|
||||||
|
"input": "Now multiply that by 10",
|
||||||
|
"conversation": "my-session-1",
|
||||||
|
"stream": false
|
||||||
|
}' | jq .
|
||||||
|
```
|
||||||
|
|
||||||
|
## Authentication
|
||||||
|
|
||||||
|
The proxy needs an OAuth token. Three ways to provide it:
|
||||||
|
|
||||||
|
1. **Environment variable**: `export ANTIGRAVITY_OAUTH_TOKEN=ya29.xxx`
|
||||||
|
2. **Token file**: `echo 'ya29.xxx' > ~/.config/antigravity-proxy-token`
|
||||||
|
3. **Runtime API**: `curl -X POST http://localhost:8741/v1/token -H "Content-Type: application/json" -d '{"token":"ya29.xxx"}'`
|
||||||
|
|
||||||
|
## Version Detection
|
||||||
|
|
||||||
|
Version strings (Antigravity, Chrome, Electron, Client) are **auto-detected** at startup from the installed Antigravity app:
|
||||||
|
|
||||||
|
- `product.json` → app version + client/IDE version
|
||||||
|
- Binary → Chrome + Electron versions via `strings`
|
||||||
|
|
||||||
|
Falls back to hardcoded values if the app isn't installed. No manual updates needed when Antigravity updates.
|
||||||
|
|
||||||
|
## Stealth Features
|
||||||
|
|
||||||
|
- **TLS fingerprint**: BoringSSL with Chrome 142 JA3/JA4 + H2 fingerprint via `wreq`
|
||||||
|
- **Protobuf**: Hand-rolled encoder producing byte-exact match to real webview traffic
|
||||||
|
- **Warmup**: Mimics real webview startup RPC calls
|
||||||
|
- **Heartbeat**: Periodic keep-alive matching real webview lifecycle
|
||||||
|
- **Jitter**: Randomized polling intervals to avoid automation fingerprint
|
||||||
|
- **Session reuse**: Cascades are reused for multi-turn, matching real webview behavior
|
||||||
|
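- **MITM proxy**: TLS-intercepting proxy for real token usage capture (runs by default, disable with `--no-mitm`; interception only happens once the LS wrapper is installed — see "MITM Proxy" below)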
|
||||||
|
|
||||||
|
## MITM Proxy
|
||||||
|
|
||||||
|
Built-in MITM proxy intercepts LS ↔ Google/Anthropic traffic to capture **real** token usage (input, output, cache read, cache creation). Disabled with `--no-mitm`.
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. Start proxy (generates CA cert automatically)
|
||||||
|
RUST_LOG=info ./target/release/antigravity-proxy
|
||||||
|
|
||||||
|
# 2. Install wrapper (patches LS binary to route through MITM)
|
||||||
|
./scripts/mitm-wrapper.sh install
|
||||||
|
|
||||||
|
# 3. Restart Antigravity — done!
|
||||||
|
|
||||||
|
# Check status
|
||||||
|
./scripts/mitm-wrapper.sh status
|
||||||
|
|
||||||
|
# Uninstall
|
||||||
|
./scripts/mitm-wrapper.sh uninstall
|
||||||
|
```
|
||||||
|
|
||||||
|
### Usage Stats
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -s http://localhost:8741/v1/usage | jq .
|
||||||
|
```
|
||||||
|
|
||||||
|
Returns aggregate token counts from all intercepted API calls.
|
||||||
|
|
||||||
|
### CLI Flags
|
||||||
|
|
||||||
|
- `--no-mitm`: Disable MITM proxy entirely
|
||||||
|
- `--mitm-port <PORT>`: Override MITM proxy port (default: auto-assign)
|
||||||
109
KNOWN_ISSUES.md
Normal file
109
KNOWN_ISSUES.md
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# Known Issues & Future Work
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🟡 Medium
|
||||||
|
|
||||||
|
### 1. Cascade Correlation Is Heuristic
|
||||||
|
|
||||||
|
**File:** `src/mitm/intercept.rs` — `extract_cascade_hint()`
|
||||||
|
|
||||||
|
The MITM proxy matches intercepted API traffic to cascade IDs by scanning for `metadata.user_id` or `workspace_id` in the request body. If neither is found, it stores under `_latest`. Since `take_usage()` no longer falls back to `_latest`, unidentified requests will have **no MITM usage data at all**.
|
||||||
|
|
||||||
|
**Fix:** Investigate the actual request body format the LS sends for better correlation keys. Alternatively, use timing-based correlation (match MITM capture timestamp to cascade polling window).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. Domain Certificate Cache Is Unbounded
|
||||||
|
|
||||||
|
**File:** `src/mitm/ca.rs` — `domain_cache`
|
||||||
|
|
||||||
|
The `domain_cache` (`HashMap<String, Arc<ServerConfig>>`) grows without bound. Each unique domain gets a cached entry containing a full `ServerConfig` with an RSA key pair. In practice, only ~5–10 domains are intercepted so this is unlikely to matter, but there's no eviction.
|
||||||
|
|
||||||
|
**Fix:** Set a max cache size (e.g., 100 entries) with LRU eviction, or use a TTL since leaf certs are generated with a 1-year validity.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. Request Modification Not Implemented
|
||||||
|
|
||||||
|
**File:** `src/mitm/proxy.rs` — `modify_requests: false`
|
||||||
|
|
||||||
|
The `MitmConfig.modify_requests` flag exists and is plumbed through, but no actual modification logic is implemented. The flag is hardcoded to `false`.
|
||||||
|
|
||||||
|
**Fix:** When needed, implement request body mutation in `handle_http_over_tls()` — parse JSON, modify, reserialize, update `Content-Length`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4. `total_cost_usd` Is Dead
|
||||||
|
|
||||||
|
**File:** `src/mitm/store.rs` (line 28)
|
||||||
|
|
||||||
|
`ApiUsage.total_cost_usd` is `Option<f64>` but is **always `None`** — it is set to `None` at all 4 construction sites (`h2_handler.rs` ×2, `intercept.rs` ×2). Neither Anthropic nor Google includes cost in its API responses.
|
||||||
|
|
||||||
|
**Fix:** Either remove the field (simpler), or populate it via a pricing table lookup (model → $/1K tokens) at `record_usage()` time.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🟢 Low
|
||||||
|
|
||||||
|
### 5. Wrapper Script Fallback Paths May Be Stale
|
||||||
|
|
||||||
|
**File:** `scripts/mitm-wrapper.sh` — `LS_FALLBACK_DIRS`
|
||||||
|
|
||||||
|
The fallback glob patterns (e.g., `~/.cursor/extensions/antigravity.antigravity-*/...`) assume a specific extension naming convention. These are only used when no running LS process is found via `/proc` scanning (Method 1), which is the primary and robust discovery mechanism.
|
||||||
|
|
||||||
|
**Impact:** Only affects `install` when the LS isn't running. Low priority.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 6. No Integration Tests for MITM Module
|
||||||
|
|
||||||
|
The MITM module has unit tests for protobuf decoding and intercept parsing, but no integration tests that verify:
|
||||||
|
|
||||||
|
- TLS interception end-to-end with the generated CA
|
||||||
|
- Full HTTP/1.1 request/response cycle through the proxy
|
||||||
|
- gRPC (HTTP/2) request/response cycle through `h2_handler`
|
||||||
|
- Store recording and retrieval under concurrency
|
||||||
|
- Wrapper script install/uninstall lifecycle
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🔍 Investigation
|
||||||
|
|
||||||
|
### 7. LS Exposes Credit/Quota Data via `GetUserStatus`
|
||||||
|
|
||||||
|
**Confirmed via live probing.** The LS's `GetUserStatus` RPC already returns structured cost/quota data:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"planStatus": {
|
||||||
|
"planInfo": {
|
||||||
|
"planName": "Pro",
|
||||||
|
"monthlyPromptCredits": 50000,
|
||||||
|
"monthlyFlowCredits": 150000,
|
||||||
|
"monthlyFlexCreditPurchaseAmount": 25000,
|
||||||
|
"canBuyMoreCredits": true
|
||||||
|
},
|
||||||
|
"availablePromptCredits": 500,
|
||||||
|
"availableFlowCredits": 100
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Each model also includes **per-model quota info**:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"quotaInfo": {
|
||||||
|
"remainingFraction": 0.2,
|
||||||
|
"resetTime": "2026-02-14T07:41:37Z"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key findings:**
|
||||||
|
|
||||||
|
- `GetUserStatus` is the single source for credit/quota data (exposed via `LanguageServerService`)
|
||||||
|
- `SeatManagementService` methods (`GetPlanStatus`, `GetTeamCreditEntries`, `GetCascadeAnalytics`, `GetUserSubscription`) are **not routed through the LS** — they're backend-only
|
||||||
|
- `PredictionService/RetrieveUserQuota` is also backend-only (not proxied by LS)
|
||||||
|
- `GetUserAnalyticsSummary` returns empty `{}` (may not be implemented or requires different context)
|
||||||
|
- `GetModelStatuses` returns empty `{}` (separate from the model configs in `GetUserStatus`)
|
||||||
|
- `userTier` field shows subscription level: `{"id": "g1-ultra-tier", "name": "Google AI Ultra"}`
|
||||||
|
|
||||||
|
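**Potential use:** We could poll `GetUserStatus` periodically and expose `availablePromptCredits`, `availableFlowCredits`, and per-model `remainingFraction` via the `/v1/usage` endpoint — giving users real-time credit burn visibility without needing MITM token counting. Note: the README's endpoint table already lists `GET /v1/quota` ("LS quota — credits, per-model rate limits, reset timers"), so confirm whether this is already covered there before extending `/v1/usage`.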
|
||||||
239
README.md
Normal file
239
README.md
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
# Antigravity Proxy
|
||||||
|
|
||||||
|
OpenAI-compatible proxy that intercepts and relays requests to Google's Antigravity language server, impersonating the real Electron webview.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build
|
||||||
|
cargo build --release
|
||||||
|
|
||||||
|
# Run (language server must be running)
|
||||||
|
RUST_LOG=info ./target/release/antigravity-proxy
|
||||||
|
|
||||||
|
# Custom port
|
||||||
|
RUST_LOG=info ./target/release/antigravity-proxy --port 9000
|
||||||
|
```
|
||||||
|
|
||||||
|
Default port: **8741**
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
|
||||||
|
| Method | Path | Description |
|
||||||
|
| -------- | ---------------------- | ----------------------------------------------------------- |
|
||||||
|
| `POST` | `/v1/responses` | **Responses API** (primary) — supports `stream: true/false` |
|
||||||
|
| `POST` | `/v1/chat/completions` | Chat Completions API (OpenAI compat shim) |
|
||||||
|
| `GET` | `/v1/models` | List available models |
|
||||||
|
| `GET` | `/v1/sessions` | List active sessions |
|
||||||
|
| `DELETE` | `/v1/sessions/:id` | Delete a session |
|
||||||
|
| `POST` | `/v1/token` | Set OAuth token at runtime |
|
||||||
|
| `GET` | `/v1/usage` | MITM-intercepted token usage stats |
|
||||||
|
| `GET` | `/v1/quota` | LS quota — credits, per-model rate limits, reset timers |
|
||||||
|
| `GET` | `/health` | Health check |
|
||||||
|
|
||||||
|
## Available Models
|
||||||
|
|
||||||
|
| Name | Label |
|
||||||
|
| ------------------- | ---------------------------------------- |
|
||||||
|
| `opus-4.6` | Claude Opus 4.6 (Thinking) — **default** |
|
||||||
|
| `opus-4.5` | Claude Opus 4.5 (Thinking) |
|
||||||
|
| `gemini-3-pro-high` | Gemini 3 Pro (High) |
|
||||||
|
| `gemini-3-pro` | Gemini 3 Pro (Low) |
|
||||||
|
| `gemini-3-flash` | Gemini 3 Flash |
|
||||||
|
|
||||||
|
## Example: Responses API
|
||||||
|
|
||||||
|
### Sync
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -s http://localhost:8741/v1/responses \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gemini-3-flash",
|
||||||
|
"input": "Say hello in exactly 3 words",
|
||||||
|
"stream": false,
|
||||||
|
"timeout": 60
|
||||||
|
}' | jq .
|
||||||
|
```
|
||||||
|
|
||||||
|
### Streaming
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -N http://localhost:8741/v1/responses \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gemini-3-flash",
|
||||||
|
"input": "Say hello in exactly 3 words",
|
||||||
|
"stream": true,
|
||||||
|
"timeout": 60
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multi-turn (session reuse)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# First message
|
||||||
|
curl -s http://localhost:8741/v1/responses \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gemini-3-flash",
|
||||||
|
"input": "What is 2+2?",
|
||||||
|
"conversation": "my-session-1",
|
||||||
|
"stream": false
|
||||||
|
}' | jq .
|
||||||
|
|
||||||
|
# Follow-up in same cascade
|
||||||
|
curl -s http://localhost:8741/v1/responses \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gemini-3-flash",
|
||||||
|
"input": "Now multiply that by 10",
|
||||||
|
"conversation": "my-session-1",
|
||||||
|
"stream": false
|
||||||
|
}' | jq .
|
||||||
|
```
|
||||||
|
|
||||||
|
## Authentication
|
||||||
|
|
||||||
|
The proxy needs an OAuth token. Three ways to provide it:
|
||||||
|
|
||||||
|
1. **Environment variable**: `export ANTIGRAVITY_OAUTH_TOKEN=ya29.xxx`
|
||||||
|
2. **Token file**: `echo 'ya29.xxx' > ~/.config/antigravity-proxy-token`
|
||||||
|
3. **Runtime API**: `curl -X POST http://localhost:8741/v1/token -H "Content-Type: application/json" -d '{"token":"ya29.xxx"}'`
|
||||||
|
|
||||||
|
## Stealth Features
|
||||||
|
|
||||||
|
- **TLS fingerprint**: BoringSSL with Chrome 142 JA3/JA4 + H2 fingerprint via `wreq`
|
||||||
|
- **Protobuf**: Hand-rolled encoder producing byte-exact match to real webview traffic
|
||||||
|
- **Warmup**: Mimics real webview startup RPC calls
|
||||||
|
- **Heartbeat**: Periodic keep-alive matching real webview lifecycle
|
||||||
|
- **Jitter**: Randomized polling intervals to avoid automation fingerprint
|
||||||
|
- **Session reuse**: Cascades are reused for multi-turn, matching real webview behavior
|
||||||
|
- **Version detection**: Auto-detects Antigravity/Chrome/Electron versions from installed app
|
||||||
|
|
||||||
|
## MITM Proxy
|
||||||
|
|
||||||
|
Built-in TLS-intercepting proxy captures real token usage from LS ↔ Google/Anthropic traffic. Disabled with `--no-mitm`.
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. Start proxy (generates CA cert automatically)
|
||||||
|
RUST_LOG=info ./target/release/antigravity-proxy
|
||||||
|
|
||||||
|
# 2. Install wrapper (patches LS binary to route through MITM)
|
||||||
|
sudo ./scripts/mitm-wrapper.sh install
|
||||||
|
|
||||||
|
# 3. Restart Antigravity — done!
|
||||||
|
|
||||||
|
# Check status
|
||||||
|
./scripts/mitm-wrapper.sh status
|
||||||
|
|
||||||
|
# Uninstall
|
||||||
|
sudo ./scripts/mitm-wrapper.sh uninstall
|
||||||
|
```
|
||||||
|
|
||||||
|
### Usage Stats
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -s http://localhost:8741/v1/usage | jq .
|
||||||
|
```
|
||||||
|
|
||||||
|
## Standalone Language Server
|
||||||
|
|
||||||
|
Launch an isolated LS instance for experimentation:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Basic test (starts, checks quota, exits)
|
||||||
|
./scripts/standalone-ls.sh
|
||||||
|
|
||||||
|
# Foreground mode (stays alive)
|
||||||
|
./scripts/standalone-ls.sh --fg
|
||||||
|
|
||||||
|
# With MITM traffic interception
|
||||||
|
./scripts/standalone-ls.sh --mitm
|
||||||
|
|
||||||
|
# Capture a clean traffic snapshot
|
||||||
|
./scripts/standalone-ls.sh --snapshot
|
||||||
|
|
||||||
|
# Snapshot with custom prompt
|
||||||
|
./scripts/standalone-ls.sh --snapshot --prompt "Explain quantum computing"
|
||||||
|
```
|
||||||
|
|
||||||
|
The standalone LS shares the main Antigravity app's OAuth (via its extension server) but has its own port, data directory, and cascades.
|
||||||
|
|
||||||
|
### Traffic Snapshots
|
||||||
|
|
||||||
|
The `--snapshot` flag captures all HTTP/2 traffic and formats it into a clean, color-coded report:
|
||||||
|
|
||||||
|
```
|
||||||
|
══════════════════════════════════════════════════════════════════════
|
||||||
|
STANDALONE LS TRAFFIC SNAPSHOT
|
||||||
|
══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
▸ Outbound Connections
|
||||||
|
→ antigravity-unleash.goog (Feature Flags)
|
||||||
|
→ play.googleapis.com (Telemetry)
|
||||||
|
|
||||||
|
══════════════════════════════════════════════════════════════════════
|
||||||
|
antigravity-unleash.goog — Feature Flags
|
||||||
|
══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
→ POST /api/client/register
|
||||||
|
authorization: *:production.e4455...
|
||||||
|
unleash-appname: codeium-language-server
|
||||||
|
Body (561 bytes, JSON):
|
||||||
|
{"appName":"codeium-language-server","instanceId":"..."}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph LR
|
||||||
|
A[Your App<br/>OpenAI SDK] -->|HTTP| B[Proxy<br/>:8741]
|
||||||
|
B -->|gRPC| C[Language<br/>Server]
|
||||||
|
C -->|HTTPS| D[Google /<br/>Anthropic]
|
||||||
|
E[MITM Proxy<br/>:8742] -.->|intercept| D
|
||||||
|
C -.->|routed via| E
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
src/
|
||||||
|
├── main.rs # Entry point, CLI args, lifecycle
|
||||||
|
├── backend.rs # LS discovery and RPC communication
|
||||||
|
├── constants.rs # Version detection + stealth constants
|
||||||
|
├── proto.rs # Hand-rolled protobuf encoder
|
||||||
|
├── quota.rs # LS quota polling and caching
|
||||||
|
├── session.rs # Multi-turn session management
|
||||||
|
├── warmup.rs # Startup warmup (mimics real webview)
|
||||||
|
├── api/
|
||||||
|
│ └── mod.rs # Axum API server + route handlers
|
||||||
|
└── mitm/
|
||||||
|
├── mod.rs # MITM module root
|
||||||
|
├── ca.rs # Dynamic CA cert generation
|
||||||
|
├── proxy.rs # TLS-intercepting proxy server
|
||||||
|
├── intercept.rs # API response parser (usage extraction)
|
||||||
|
└── store.rs # Token usage aggregation store
|
||||||
|
|
||||||
|
scripts/
|
||||||
|
├── mitm-wrapper.sh # Install/uninstall MITM wrapper on LS binary
|
||||||
|
├── standalone-ls.sh # Launch isolated LS instance
|
||||||
|
└── parse-snapshot.py # HTTP/2 traffic snapshot parser
|
||||||
|
```
|
||||||
|
|
||||||
|
## CLI Flags
|
||||||
|
|
||||||
|
```
|
||||||
|
antigravity-proxy [OPTIONS]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--port <PORT> API server port (default: 8741)
|
||||||
|
--no-mitm Disable MITM proxy
|
||||||
|
--mitm-port <PORT> Override MITM proxy port (default: auto)
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Private. Do not distribute.
|
||||||
331
scripts/mitm-wrapper.sh
Executable file
331
scripts/mitm-wrapper.sh
Executable file
@@ -0,0 +1,331 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# ╔═══════════════════════════════════════════════════════════════════════════╗
|
||||||
|
# ║ Antigravity MITM LS Wrapper ║
|
||||||
|
# ║ ║
|
||||||
|
# ║ This script replaces the real Antigravity language server binary. ║
|
||||||
|
# ║ It injects HTTPS_PROXY and SSL_CERT_FILE environment variables ║
|
||||||
|
# ║ so the MITM proxy can intercept LS<->API traffic. ║
|
||||||
|
# ║ ║
|
||||||
|
# ║ Install: ./mitm-wrapper.sh install ║
|
||||||
|
# ║ Uninstall: ./mitm-wrapper.sh uninstall ║
|
||||||
|
# ║ (No args = act as wrapper, exec the real binary with injected env) ║
|
||||||
|
# ╚═══════════════════════════════════════════════════════════════════════════╝
|
||||||
|
set -euo pipefail

# ── Config ────────────────────────────────────────────────────────────────────
# Resolve the real user's home (not /root when running under sudo).
# If the passwd lookup fails or returns nothing, fall back to $HOME instead of
# aborting (pipefail) or building paths rooted at "/".
if [[ -n "${SUDO_USER:-}" ]]; then
    REAL_HOME="$(getent passwd "$SUDO_USER" 2>/dev/null | cut -d: -f6 || true)"
    [[ -n "$REAL_HOME" ]] || REAL_HOME="$HOME"
else
    REAL_HOME="$HOME"
fi

# MITM proxy port, in priority order:
#   1. ANTIGRAVITY_MITM_PORT environment variable
#   2. contents of the port file under ~/.config/antigravity-proxy
#   3. built-in default
DEFAULT_MITM_PORT="8742"
MITM_PORT_FILE="${REAL_HOME}/.config/antigravity-proxy/mitm-port"
if [[ -n "${ANTIGRAVITY_MITM_PORT:-}" ]]; then
    MITM_PORT="$ANTIGRAVITY_MITM_PORT"
elif [[ -f "$MITM_PORT_FILE" ]]; then
    # Strip all whitespace (trailing newline, CR, stray blanks) from the file.
    MITM_PORT="$(cat "$MITM_PORT_FILE" 2>/dev/null | tr -d '[:space:]' || true)"
else
    MITM_PORT="$DEFAULT_MITM_PORT"
fi
# The port ends up inside an HTTPS_PROXY URL in the generated wrapper, so an
# empty or non-numeric value must never get through.
if [[ ! "$MITM_PORT" =~ ^[0-9]+$ ]]; then
    echo "warning: invalid MITM port '${MITM_PORT}', falling back to ${DEFAULT_MITM_PORT}" >&2
    MITM_PORT="$DEFAULT_MITM_PORT"
fi
CA_PATH="${REAL_HOME}/.config/antigravity-proxy/mitm-ca.pem"

# Antigravity LS — discovered dynamically from running processes.
# Hardcoded paths are only used as a fallback if no LS process is running.
# Entries may contain glob patterns; they are expanded when scanned.
LS_FALLBACK_DIRS=(
    "/usr/share/antigravity/resources/app/extensions/antigravity/bin"
    "${REAL_HOME}/.antigravity/extensions/antigravity.antigravity-*/dist/bundled/language-server/bin"
    "${REAL_HOME}/.cursor/extensions/antigravity.antigravity-*/dist/bundled/language-server/bin"
    "${REAL_HOME}/.vscode/extensions/antigravity.antigravity-*/dist/bundled/language-server/bin"
    "/opt/antigravity/language-server/bin"
)

# Suffix appended to the original LS binary when it is moved aside.
BACKUP_SUFFIX=".real"

# ── Colors ────────────────────────────────────────────────────────────────────
# ANSI escape sequences, interpreted by `echo -e`.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[0;33m'
CYAN='\033[0;36m'
BOLD='\033[1m'
NC='\033[0m'
|
||||||
|
|
||||||
|
# ── Find LS binary ───────────────────────────────────────────────────────────
|
||||||
|
# Locate the Antigravity language-server binary and print its path on stdout.
#
# Method 1 scans /proc/<pid>/exe for a process whose executable path contains
# "language_server_linux" or "antigravity-language-server".
# Method 2 expands each glob pattern in LS_FALLBACK_DIRS and looks for one of
# the known binary names, or its BACKUP_SUFFIX copy, inside each directory.
#
# The printed path is always the base binary path: a trailing BACKUP_SUFFIX is
# removed in Method 1, and Method 2 prints the un-suffixed name.
#
# Returns 0 when a path was printed, 1 when nothing was found.
find_ls_binary() {
    local pid_dir exe_target

    # Method 1: Find from running process via /proc
    if [[ -d /proc ]]; then
        for pid_dir in /proc/[0-9]*; do
            # Unreadable links (other users' processes, exited PIDs) are skipped.
            exe_target="$(readlink "${pid_dir}/exe" 2>/dev/null)" || continue
            # Strip " (deleted)" suffix that appears when the binary was unlinked
            exe_target="${exe_target% (deleted)}"
            if [[ "$exe_target" == *language_server_linux* ]] || \
               [[ "$exe_target" == *antigravity-language-server* ]]; then
                # If the running process is the backup (.real), strip the suffix
                # so we return the path to the base binary name. The suffix is
                # quoted so it is matched literally rather than as a glob.
                echo "${exe_target%"$BACKUP_SUFFIX"}"
                return 0
            fi
        done
    fi

    # Method 2: Fallback — scan known directories for common binary names
    local bin_names=("language_server_linux_x64" "language_server_linux_arm64" "antigravity-language-server")
    local dir_pattern dir name path
    # An empty IFS disables word splitting for the unquoted expansion below
    # while still allowing glob expansion, so directories whose path contains
    # spaces (e.g. a home directory with a space) are handled as one word.
    local IFS=''
    for dir_pattern in "${LS_FALLBACK_DIRS[@]}"; do
        # Intentionally unquoted: the pattern must be glob-expanded. A pattern
        # that matches nothing stays literal and is rejected by the -d test.
        for dir in $dir_pattern; do
            [[ -d "$dir" ]] || continue
            for name in "${bin_names[@]}"; do
                path="${dir}/${name}"
                if [[ -f "$path" || -f "${path}${BACKUP_SUFFIX}" ]]; then
                    echo "$path"
                    return 0
                fi
            done
        done
    done
    return 1
}
|
||||||
|
|
||||||
|
# ── Install ──────────────────────────────────────────────────────────────────
|
||||||
|
# Install the MITM wrapper in place of the Antigravity language-server binary.
#
# Arguments:
#   $1  Optional explicit path to the LS binary. When given, it overrides the
#       path discovered by find_ls_binary.
#
# Reads the globals BACKUP_SUFFIX, MITM_PORT, CA_PATH and the colour variables.
# On success the original binary is kept as "<path>${BACKUP_SUFFIX}" and
# "<path>" becomes a generated bash wrapper that sets HTTPS_PROXY and a
# combined CA bundle before exec'ing the real binary.
#
# Exits 1 on failure, and 0 (without changes) when a wrapper is already in place.
cmd_install() {
    # Find the LS binary first (quiet, just to check permissions)
    local ls_path
    ls_path=$(find_ls_binary) || ls_path="${1:-}"

    # Allow override
    if [[ -n "${1:-}" ]]; then
        ls_path="$1"
    fi

    # Check permissions upfront — tell the user to re-run with sudo before
    # printing anything else or touching any file.
    if [[ -n "$ls_path" ]]; then
        local ls_dir
        ls_dir="$(dirname "$ls_path")"
        if [[ ! -w "$ls_dir" ]] && [[ "$EUID" -ne 0 ]]; then
            echo -e " ${RED}✗${NC} ${ls_dir} requires elevated permissions"
            echo -e " run: sudo $0 install ${1:-}"
            exit 1
        fi
    fi

    echo -e "${BOLD}${CYAN}Antigravity MITM Wrapper Installer${NC}"
    echo -e "───────────────────────────────────"
    echo ""

    # Report the discovery result (for real this time, with output)
    if [[ -z "$ls_path" ]]; then
        echo -e " ${RED}✗${NC} Could not find Antigravity language server binary."
        echo -e " No LS process found in /proc, and fallback paths didn't match."
        echo ""
        echo -e " Set the path manually:"
        echo -e " $0 install /path/to/language_server_linux_x64"
        exit 1
    fi
    echo -e " ${GREEN}✓${NC} Found LS: ${ls_path}"

    local real_path="${ls_path}${BACKUP_SUFFIX}"

    # Verify the binary exists and is not already wrapped
    if [[ -f "$real_path" ]]; then
        echo -e " ${YELLOW}!${NC} Already installed (backup exists at ${real_path})"
        echo -e " Run '$0 uninstall' first to reinstall."
        exit 0
    fi

    if [[ ! -f "$ls_path" ]]; then
        echo -e " ${RED}✗${NC} Binary not found: ${ls_path}"
        exit 1
    fi

    # Verify it's a real binary, not already our wrapper. The generated wrapper
    # starts with the "Antigravity MITM LS Wrapper" banner; the legacy marker is
    # still accepted. The bytes are captured into a variable (NULs removed)
    # instead of piping into `grep -q`, which can trip pipefail via SIGPIPE.
    local head_bytes
    head_bytes="$(head -c 256 "$ls_path" 2>/dev/null | tr -d '\0' || true)"
    if [[ "$head_bytes" == *'Antigravity MITM LS Wrapper'* ]] || \
       [[ "$head_bytes" == *'ANTIGRAVITY_MITM_PORT'* ]]; then
        echo -e " ${YELLOW}!${NC} Already wrapped (script detected). Run '$0 uninstall' first."
        exit 0
    fi

    # Check CA cert
    if [[ ! -f "$CA_PATH" ]]; then
        echo -e " ${YELLOW}!${NC} CA cert not found at ${CA_PATH}"
        echo -e " Start the proxy first to generate it."
        echo -e " Continuing install anyway..."
    else
        echo -e " ${GREEN}✓${NC} CA cert: ${CA_PATH}"
    fi

    # Back up the real binary
    cp -p "$ls_path" "$real_path"
    echo -e " ${GREEN}✓${NC} Backed up real binary to ${real_path}"

    # Remove the original before writing (avoids "Text file busy" if LS is running)
    rm -f "$ls_path"

    # Create the wrapper script in-place. The heredoc is quoted, so nothing in
    # it is expanded here; the __MITM_PORT__ / __CA_PATH__ placeholders are
    # filled in by sed below.
    tee "$ls_path" > /dev/null << 'WRAPPER_EOF'
#!/usr/bin/env bash
# Antigravity MITM LS Wrapper — auto-generated, do not edit.
# The LS is a Go binary — it reads HTTPS_PROXY and SSL_CERT_FILE (not NODE_EXTRA_CA_CERTS).
# Go's gRPC library also reads GRPC_DEFAULT_SSL_ROOTS_FILE_PATH for root certs.
# We build a combined CA bundle (system CAs + MITM CA) and inject it.

REAL_BINARY="${BASH_SOURCE[0]}.real"

if [[ ! -f "$REAL_BINARY" ]]; then
    echo "ERROR: Real LS binary not found at $REAL_BINARY" >&2
    echo "Run 'mitm-wrapper.sh uninstall' and reinstall." >&2
    exit 1
fi

# Inject MITM proxy (don't override if already set)
export HTTPS_PROXY="${HTTPS_PROXY:-http://127.0.0.1:__MITM_PORT__}"

# Build combined CA bundle: system CAs + MITM CA
MITM_CA="__CA_PATH__"
COMBINED_CA="/tmp/antigravity-mitm-combined-ca.pem"
if [[ -f "$MITM_CA" ]]; then
    # Find system CA bundle
    SYS_CA=""
    for candidate in /etc/ssl/certs/ca-certificates.crt /etc/pki/tls/certs/ca-bundle.crt /etc/ssl/cert.pem; do
        if [[ -f "$candidate" ]]; then
            SYS_CA="$candidate"
            break
        fi
    done
    if [[ -n "$SYS_CA" ]]; then
        cat "$SYS_CA" "$MITM_CA" > "$COMBINED_CA" 2>/dev/null
        export SSL_CERT_FILE="$COMBINED_CA"
        # Go's gRPC library may use this instead of SSL_CERT_FILE
        export GRPC_DEFAULT_SSL_ROOTS_FILE_PATH="$COMBINED_CA"
    fi
fi

exec "$REAL_BINARY" "$@"
WRAPPER_EOF

    # Substitute actual values. Escape the characters that are special on the
    # replacement side of `s|…|…|` (backslash, '&', and the '|' delimiter) so a
    # path such as ".../R&D/ca.pem" is written literally.
    local port_repl ca_repl
    port_repl="$(printf '%s' "$MITM_PORT" | sed 's/[\\&|]/\\&/g')"
    ca_repl="$(printf '%s' "$CA_PATH" | sed 's/[\\&|]/\\&/g')"
    sed -i "s|__MITM_PORT__|${port_repl}|g" "$ls_path"
    sed -i "s|__CA_PATH__|${ca_repl}|g" "$ls_path"

    # Make executable
    chmod +x "$ls_path"

    echo -e " ${GREEN}✓${NC} Wrapper installed at ${ls_path}"
    echo ""
    echo -e " ${BOLD}How it works:${NC}"
    echo -e " When Antigravity starts the LS, the wrapper will:"
    echo -e " 1. Set ${CYAN}HTTPS_PROXY${NC}=http://127.0.0.1:${MITM_PORT}"
    echo -e " 2. Build combined CA bundle (system + MITM) at /tmp/antigravity-mitm-combined-ca.pem"
    echo -e " 3. Set ${CYAN}SSL_CERT_FILE${NC} to the combined bundle"
    echo -e " 4. Exec the real LS binary with all original args"
    echo ""
    echo -e " ${YELLOW}Note:${NC} Restart Antigravity for the wrapper to take effect."
    echo ""
}
|
||||||
|
|
||||||
|
# ── Uninstall ────────────────────────────────────────────────────────────────
|
||||||
|
# Restore the real language-server binary over the wrapper script.
#
# Reads:  BACKUP_SUFFIX, EUID, the colour variables, and find_ls_binary.
# Exits 1 (before printing the banner) when the binary's directory is not
# writable and the caller is not root.
cmd_uninstall() {
    local ls_path target_dir real_path

    # A failed lookup is reported further down, so do not let it abort here.
    ls_path=$(find_ls_binary) || true

    # Permission pre-flight: refuse early instead of failing half-way through.
    if [[ -n "$ls_path" ]]; then
        target_dir=$(dirname "$ls_path")
        if [[ ! -w "$target_dir" && "$EUID" -ne 0 ]]; then
            echo -e " ${RED}✗${NC} ${target_dir} requires elevated permissions"
            echo -e " run: sudo $0 uninstall"
            exit 1
        fi
    fi

    echo -e "${BOLD}${CYAN}Antigravity MITM Wrapper Uninstaller${NC}"
    echo -e "─────────────────────────────────────"
    echo ""

    if [[ -z "$ls_path" ]]; then
        echo -e " ${RED}✗${NC} Could not find Antigravity language server binary."
    else
        real_path="${ls_path}${BACKUP_SUFFIX}"
        if [[ ! -f "$real_path" ]]; then
            echo -e " ${YELLOW}!${NC} No backup found at ${real_path}"
            echo -e " The LS binary may not be wrapped."
        else
            # Moving the backup over the wrapper both restores and cleans up.
            mv -f "$real_path" "$ls_path"
            echo -e " ${GREEN}✓${NC} Restored real binary at ${ls_path}"
        fi
    fi

    echo ""
    echo -e " ${YELLOW}Note:${NC} Restart Antigravity for the change to take effect."
    echo ""
}
|
||||||
|
|
||||||
|
# ── Status ───────────────────────────────────────────────────────────────────
|
||||||
|
# Print a human-readable report: where the LS binary is, whether the wrapper
# is installed and intact, whether the MITM CA exists, and whether something
# is listening on the MITM port.
#
# Reads:  BACKUP_SUFFIX, CA_PATH, MITM_PORT, the colour variables, and
#         find_ls_binary. Never modifies anything.
cmd_status() {
    echo -e "${BOLD}${CYAN}Antigravity MITM Wrapper Status${NC}"
    echo -e "────────────────────────────────"
    echo ""

    local ls_path
    if ls_path=$(find_ls_binary); then
        echo -e " ${GREEN}✓${NC} LS binary: ${ls_path}"

        # The wrapper counts as installed when the backed-up real binary exists.
        local real_path="${ls_path}${BACKUP_SUFFIX}"
        if [[ -f "$real_path" ]]; then
            echo -e " ${GREEN}✓${NC} Wrapper: ${BOLD}installed${NC}"
            echo -e " ${GREEN}✓${NC} Real binary: ${real_path}"

            # Check if wrapper is valid: the marker comment must appear in the
            # first 200 bytes. grep reads its whole input (no -q) so the
            # producer can never be killed by SIGPIPE — with `set -o pipefail`
            # (NOTE(review): not visible here, confirm) that would turn a
            # successful match into a reported failure.
            if head -c 200 "$ls_path" | grep 'MITM LS Wrapper' > /dev/null; then
                echo -e " ${GREEN}✓${NC} Wrapper script: valid"
            else
                echo -e " ${RED}✗${NC} Wrapper script: ${BOLD}corrupted or replaced${NC}"
            fi
        else
            echo -e " ${YELLOW}○${NC} Wrapper: ${BOLD}not installed${NC}"
        fi
    else
        echo -e " ${RED}✗${NC} LS binary: not found"
    fi

    # Check CA cert
    if [[ -f "$CA_PATH" ]]; then
        echo -e " ${GREEN}✓${NC} CA cert: ${CA_PATH}"
    else
        echo -e " ${RED}✗${NC} CA cert: not found (start proxy first)"
    fi

    # Check MITM port. Same SIGPIPE consideration as above: on a host with
    # many listening sockets `grep -q` could exit while ss is still writing.
    if ss -tlnp 2>/dev/null | grep ":${MITM_PORT} " > /dev/null; then
        echo -e " ${GREEN}✓${NC} MITM proxy: listening on :${MITM_PORT}"
    else
        echo -e " ${YELLOW}○${NC} MITM proxy: not running on :${MITM_PORT}"
    fi

    echo ""
}
|
||||||
|
|
||||||
|
# ── Main ─────────────────────────────────────────────────────────────────────
|
||||||
|
# Dispatch on the first CLI argument; `install` forwards the optional second
# argument (an explicit binary path) to cmd_install.
subcommand="${1:-}"
case "$subcommand" in
    install)
        cmd_install "${2:-}"
        ;;
    uninstall)
        cmd_uninstall
        ;;
    status)
        cmd_status
        ;;
    -h|--help)
        cat <<USAGE
Usage: $0 {install|uninstall|status}

Commands:
 install [path] Install MITM wrapper (auto-detect or specify path)
 uninstall Restore original LS binary
 status Show wrapper installation status

Environment:
 ANTIGRAVITY_MITM_PORT MITM proxy port (default: 8742)
USAGE
        ;;
    *)
        # Anything else (including no argument at all) is a usage error.
        echo "Usage: $0 {install|uninstall|status}"
        exit 1
        ;;
esac
|
||||||
475
scripts/parse-snapshot.py
Normal file
475
scripts/parse-snapshot.py
Normal file
@@ -0,0 +1,475 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Parse Go GODEBUG=http2debug=2 output into a clean, readable snapshot.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 parse-snapshot.py < raw-http2-dump.log
|
||||||
|
python3 parse-snapshot.py /path/to/logfile
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import gzip
|
||||||
|
from collections import defaultdict
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
# ── Colors ────────────────────────────────────────────────────────────────────
# ANSI SGR escape sequences used by the renderer.
BOLD = "\033[1m"
DIM = "\033[2m"
RED = "\033[91m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
CYAN = "\033[96m"
MAGENTA = "\033[95m"
NC = "\033[0m"

# ── Regexes ───────────────────────────────────────────────────────────────────
# Go prints header values and DATA payloads with %q, so a value can contain
# escaped quotes (\"). _Q matches one complete double-quoted string and
# captures its still-escaped contents; a plain `".*?"` or `"[^"]*"` would stop
# at the first embedded \" and truncate JSON bodies.
_Q = r'"((?:[^"\\]|\\.)*)"'

# Go's frame summary omits " flags=…" entirely when no flag is set (typical for
# every DATA frame except the last), so the flags group is optional and is
# None in that case. Group numbering is unchanged: 1=flags, 2=stream, 3=len,
# 4=data.
_FLAGS = r'(?: flags=(\S+))?'

RE_ENCODING_HEADER = re.compile(
    r'http2: Transport encoding header "([^"]+)" = ' + _Q
)
RE_DECODED_HEADER = re.compile(
    r'http2: decoded hpack field header field "([^"]+)" = ' + _Q
)
RE_SERVER_ENCODING = re.compile(
    r'http2: server encoding header "([^"]+)" = ' + _Q
)
RE_WROTE_DATA = re.compile(
    r'http2: Framer [^:]+: wrote DATA' + _FLAGS + r' stream=(\d+) len=(\d+) data=' + _Q
)
RE_READ_DATA = re.compile(
    r'http2: Framer [^:]+: read DATA' + _FLAGS + r' stream=(\d+) len=(\d+) data=' + _Q
)
RE_TRANSPORT_CONN = re.compile(
    r'http2: Transport creating client conn [^ ]+ to (.+)'
)
RE_SERVER_READ_DATA = re.compile(
    r'http2: server read frame DATA' + _FLAGS + r' stream=(\d+) len=(\d+) data=' + _Q
)
RE_WROTE_HEADERS = re.compile(
    r'http2: Framer [^:]+: wrote HEADERS' + _FLAGS + r' stream=(\d+)'
)
# Leading "YYYY/MM/DD HH:MM:SS" stamp of Go's standard logger.
RE_TIMESTAMP = re.compile(r'^(\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2})')
# Lines starting with I/W/E + four digits are treated as LS process logs.
RE_LS_LOG = re.compile(r'^[IWE]\d{4} ')
RE_MAXPROCS = re.compile(r'^.*maxprocs:')
# Trailer Go appends when it truncated the logged payload.
RE_BYTES_OMITTED = re.compile(r'\((\d+) bytes omitted\)$')

# Known domain purposes: host -> (short label, longer description)
DOMAIN_INFO = {
    "antigravity-unleash.goog": ("Feature Flags", "Unleash SDK — controls A/B tests, feature rollouts"),
    "daily-cloudcode-pa.googleapis.com": ("LLM API (gRPC)", "Primary Gemini/Claude API endpoint"),
    "cloudcode-pa.googleapis.com": ("LLM API (gRPC)", "Production Gemini/Claude API endpoint"),
    "api.anthropic.com": ("Claude API", "Direct Anthropic API calls"),
    "lh3.googleusercontent.com": ("Profile Picture", "User avatar image"),
    "play.googleapis.com": ("Telemetry", "Google Play telemetry/logging"),
    "firebaseinstallations.googleapis.com": ("Firebase", "Firebase installation tracking"),
    "oauth2.googleapis.com": ("OAuth", "Token refresh/exchange"),
    "speech.googleapis.com": ("Speech", "Voice input processing"),
    "modelarmor.googleapis.com": ("Safety", "Content safety/filtering"),
}
|
||||||
|
|
||||||
|
|
||||||
|
class Request:
    """One HTTP/2 request reconstructed from the debug log.

    A plain mutable record: the parser creates it empty and fills the fields
    in as the corresponding log lines are seen.
    """

    def __init__(self):
        # Pseudo-headers (:method, :path, :authority, :scheme).
        self.method = self.path = self.authority = self.scheme = ""
        # Regular (non-pseudo) headers, name -> value.
        self.headers = {}
        # Body bytes recovered from DATA frames, and the frame length reported
        # by the log.
        self.data, self.data_len = b"", 0
        # HTTP/2 stream id as the string captured from the log, or None until
        # known.
        self.stream_id = None
        # Log timestamp of the first header line, "" when the line had none.
        self.timestamp = ""
        # outgoing = LS→upstream, incoming = LS←upstream
        self.direction = "outgoing"
|
||||||
|
|
||||||
|
|
||||||
|
class Snapshot:
    """Parsed view of a Go ``GODEBUG=http2debug=2`` log.

    ``parse`` consumes raw log lines and fills ``connections``, ``requests``
    and ``ls_logs``; ``render`` turns that state into a coloured text report.
    """

    # Single-character escapes that Go's %q can emit inside a double-quoted
    # string, mapped to the byte they stand for.
    _SIMPLE_ESCAPES = {
        "a": 7,
        "b": 8,
        "f": 12,
        "n": 10,
        "r": 13,
        "t": 9,
        "v": 11,
        "\\": 92,
        '"': 34,
    }

    def __init__(self):
        self.connections = []  # (timestamp, target)
        self.requests = []  # list of Request
        # Not populated by parse(); kept because it is a public attribute.
        self.responses = defaultdict(lambda: {"headers": {}, "data": b"", "data_len": 0})
        self.ls_logs = []

    def parse(self, lines):
        """Consume an iterable of log lines and populate this snapshot.

        Each line is classified by the first pattern that matches it: LS
        process log, new client connection, outgoing header, decoded header,
        server response header (ignored), HEADERS frame written (finalizes the
        pending request and assigns its stream id), or DATA frame (body bytes
        are attached to the most recent request with the same stream id).
        """
        current_headers = {}
        current_direction = "outgoing"
        current_stream = None
        # Name of the previous "Transport encoding header" line, used to spot
        # the ":authority" that precedes ":method".
        last_encoded_key = None

        for line in lines:
            line = line.rstrip()

            # Skip empty
            if not line:
                continue

            # LS process logs
            if RE_LS_LOG.match(line) or RE_MAXPROCS.match(line):
                self.ls_logs.append(line)
                continue

            # New connection
            m = RE_TRANSPORT_CONN.search(line)
            if m:
                ts_m = RE_TIMESTAMP.match(line)
                self.connections.append((ts_m.group(1) if ts_m else "", m.group(1)))
                continue

            # Outgoing headers (Transport encoding = LS sending to upstream)
            m = RE_ENCODING_HEADER.search(line)
            if m:
                key, val = m.group(1), m.group(2)
                if key == ":authority" and current_headers.get(":path"):
                    # A previous request never saw its HEADERS frame; flush it
                    # before this request's authority overwrites its own.
                    self._finalize_request(current_headers, current_direction, line)
                    current_headers = {}
                elif key == ":method":
                    # New request starting. Go's transport writes ":authority"
                    # immediately before ":method", so that header (and the
                    # timestamp recorded with it) belongs to the new request
                    # and must survive the reset.
                    carried = {}
                    if current_headers.get(":path"):
                        self._finalize_request(current_headers, "outgoing", line)
                    elif last_encoded_key == ":authority":
                        carried = {
                            k: current_headers[k]
                            for k in (":authority", "timestamp")
                            if k in current_headers
                        }
                    current_headers = carried
                    current_direction = "outgoing"
                last_encoded_key = key
                current_headers[key] = val
                ts_m = RE_TIMESTAMP.match(line)
                if ts_m and "timestamp" not in current_headers:
                    current_headers["timestamp"] = ts_m.group(1)
                continue

            # Incoming headers (decoded hpack = upstream responding, OR server receiving)
            m = RE_DECODED_HEADER.search(line)
            if m:
                key, val = m.group(1), m.group(2)
                if key == ":authority" and "server read frame" not in line:
                    # This is a request received by our LS
                    if current_headers.get(":path"):
                        self._finalize_request(current_headers, current_direction, line)
                    current_headers = {}
                    current_direction = "incoming"
                current_headers[key] = val
                continue

            # Server encoding (our LS responding)
            if RE_SERVER_ENCODING.search(line):
                continue  # Skip server response headers for now

            # Headers frame written (triggers finalization)
            m = RE_WROTE_HEADERS.search(line)
            if m:
                current_stream = m.group(2)
                if current_headers.get(":path") or current_headers.get(":method"):
                    req = self._finalize_request(current_headers, current_direction, line)
                    if req:
                        req.stream_id = current_stream
                    current_headers = {}
                continue

            # Data frames (wrote = LS sending, read = LS receiving)
            for pattern, direction in (
                (RE_WROTE_DATA, "sent"),
                (RE_READ_DATA, "received"),
                (RE_SERVER_READ_DATA, "server_received"),
            ):
                m = pattern.search(line)
                if not m:
                    continue
                stream, length, data_str = m.group(2), int(m.group(3)), m.group(4)
                # Find matching request by stream (most recent first).
                for req in reversed(self.requests):
                    if req.stream_id == stream:
                        # Only request bodies are kept; response payloads
                        # ("received") are ignored.
                        if direction in ("sent", "server_received"):
                            req.data += self._decode_data_str(data_str, line)
                            req.data_len = max(req.data_len, length)
                        break
                break

        # Finalize any remaining headers
        if current_headers.get(":path") or current_headers.get(":method"):
            self._finalize_request(current_headers, current_direction, "")

    def _finalize_request(self, headers, direction, _line):
        """Build a Request from ``headers``, append it to ``self.requests`` and return it.

        The pseudo-headers and the synthetic "timestamp" entry are popped from
        ``headers`` (the caller's dict is mutated); defaults are GET, "/", ""
        and "https" when a pseudo-header is absent.
        """
        req = Request()
        req.method = headers.pop(":method", "GET")
        req.path = headers.pop(":path", "/")
        req.authority = headers.pop(":authority", "")
        req.scheme = headers.pop(":scheme", "https")
        req.timestamp = headers.pop("timestamp", "")
        req.direction = direction
        req.headers = {k: v for k, v in headers.items() if not k.startswith(":")}
        self.requests.append(req)
        return req

    def _decode_data_str(self, s, full_line):
        """Decode the contents of a Go %q-quoted string back to bytes.

        Handles the single-character escapes in ``_SIMPLE_ESCAPES``, ``\\xNN``
        (one raw byte) and ``\\uXXXX`` / ``\\UXXXXXXXX`` (a code point, emitted
        as UTF-8). Every other character, including printable non-ASCII text
        that Go leaves unescaped, is emitted as its UTF-8 encoding. A backslash
        that does not start a well-formed escape is kept literally.

        ``full_line`` is accepted for interface compatibility and is unused.
        """
        hex_digits = "0123456789abcdefABCDEF"
        widths = {"x": 2, "u": 4, "U": 8}
        out = bytearray()
        i, n = 0, len(s)
        while i < n:
            ch = s[i]
            if ch == "\\" and i + 1 < n:
                esc = s[i + 1]
                simple = self._SIMPLE_ESCAPES.get(esc)
                if simple is not None:
                    out.append(simple)
                    i += 2
                    continue
                width = widths.get(esc)
                if width:
                    digits = s[i + 2 : i + 2 + width]
                    # Explicit digit check: int() alone would accept "+1" or " 1".
                    if len(digits) == width and all(c in hex_digits for c in digits):
                        value = int(digits, 16)
                        if esc == "x":
                            out.append(value)
                            i += 2 + width
                            continue
                        if value <= 0x10FFFF:
                            out += chr(value).encode("utf-8", errors="replace")
                            i += 2 + width
                            continue
            out += ch.encode("utf-8", errors="replace")
            i += 1
        return bytes(out)

    def render(self):
        """Return the whole snapshot as one ANSI-coloured, newline-joined string."""
        # Headers worth showing first, in this order.
        interesting = [
            "authorization",
            "content-type",
            "user-agent",
            "unleash-appname",
            "unleash-instanceid",
            "unleash-sdk",
            "x-goog-api-key",
            "x-goog-api-client",
            "grpc-encoding",
            "te",
        ]
        out = []

        # Header
        out.append(f"\n{BOLD}{CYAN}{'═' * 70}{NC}")
        out.append(f"{BOLD}{CYAN} STANDALONE LS TRAFFIC SNAPSHOT{NC}")
        out.append(f"{BOLD}{CYAN}{'═' * 70}{NC}\n")

        # LS Logs
        if self.ls_logs:
            out.append(f"{BOLD}▸ Language Server Logs{NC}")
            out.append(f"{DIM}{'─' * 60}{NC}")
            for log in self.ls_logs:
                out.append(f" {DIM}{log}{NC}")
            out.append("")

        # Connections
        if self.connections:
            out.append(f"{BOLD}▸ Outbound Connections{NC}")
            out.append(f"{DIM}{'─' * 60}{NC}")
            for ts, target in self.connections:
                domain = target.split(":")[0] if ":" in target else target
                info = DOMAIN_INFO.get(domain, ("Unknown", ""))
                out.append(
                    f" {GREEN}→{NC} {BOLD}{target}{NC} {DIM}({info[0]}){NC}"
                )
                if info[1]:
                    out.append(f" {DIM}{info[1]}{NC}")
            out.append("")

        # Group requests by domain (insertion order = first appearance)
        by_domain = defaultdict(list)
        for req in self.requests:
            by_domain[req.authority].append(req)

        # Render each domain's requests
        for domain, reqs in by_domain.items():
            if domain.startswith("127.0.0.1"):
                label = "Local (our requests to LS)"
                color = DIM
            else:
                info = DOMAIN_INFO.get(domain, ("External", ""))
                label = info[0]
                color = YELLOW if "API" in info[0] else CYAN

            out.append(f"{BOLD}{'═' * 70}{NC}")
            out.append(f"{BOLD}{color} {domain}{NC} {DIM}— {label}{NC}")
            out.append(f"{BOLD}{'═' * 70}{NC}")

            for req in reqs:
                arrow = "→" if req.direction == "outgoing" else "←"
                method_color = GREEN if req.method == "GET" else YELLOW

                out.append(f"\n {BOLD}{arrow} {method_color}{req.method}{NC} {req.path}")

                # Important headers
                shown = False
                for key in interesting:
                    if key not in req.headers:
                        continue
                    val = req.headers[key]
                    # Mask tokens partially
                    if key == "authorization" and len(val) > 30:
                        if val.startswith("Bearer "):
                            val = f"Bearer {val[7:20]}...{val[-10:]}"
                        elif len(val) > 40:
                            val = f"{val[:30]}...{val[-10:]}"
                    out.append(f" {DIM}{key}:{NC} {val}")
                    shown = True

                # All other headers (collapsed)
                other = {
                    k: v
                    for k, v in req.headers.items()
                    if k not in interesting and not k.startswith(":")
                }
                if other:
                    if not shown:
                        out.append(f" {DIM}Headers:{NC}")
                    for k, v in other.items():
                        out.append(f" {DIM}{k}:{NC} {v}")

                # Body
                if req.data:
                    out.append(self._render_body(req.data, req.data_len))

                out.append("")

        return "\n".join(out)

    def _render_body(self, data, total_len):
        """Render body data in the most readable format possible.

        Tries, in order: JSON, gzip (then JSON / text / binary strings),
        protobuf-looking bytes, plain UTF-8 text, PNG, and finally a binary
        fallback with extracted strings. ``total_len`` is the frame length
        taken from the log and is used in the labels of the formats where the
        captured ``data`` may be truncated.
        """
        lines = []

        # Try JSON
        try:
            text = data.decode("utf-8")
            obj = json.loads(text)
            pretty_lines = json.dumps(obj, indent=2, ensure_ascii=False).split("\n")
            lines.append(f" {BOLD}Body ({len(data)} bytes, JSON):{NC}")
            for l in pretty_lines[:30]:
                lines.append(f" {GREEN}{l}{NC}")
            if len(pretty_lines) > 30:
                lines.append(f" {DIM}... ({len(pretty_lines) - 30} more lines){NC}")
            return "\n".join(lines)
        except (json.JSONDecodeError, UnicodeDecodeError):
            pass

        # Try gzip
        if data[:2] == b"\x1f\x8b":
            try:
                decompressed = gzip.decompress(data)
                try:
                    text = decompressed.decode("utf-8")
                    try:
                        obj = json.loads(text)
                        pretty_lines = json.dumps(obj, indent=2, ensure_ascii=False).split("\n")
                        lines.append(
                            f" {BOLD}Body ({len(data)} bytes gzip → {len(decompressed)} bytes, JSON):{NC}"
                        )
                        for l in pretty_lines[:50]:
                            lines.append(f" {GREEN}{l}{NC}")
                        if len(pretty_lines) > 50:
                            lines.append(
                                f" {DIM}... ({len(pretty_lines) - 50} more lines){NC}"
                            )
                        return "\n".join(lines)
                    except json.JSONDecodeError:
                        lines.append(
                            f" {BOLD}Body ({len(data)} bytes gzip → {len(decompressed)} bytes, text):{NC}"
                        )
                        for l in text.split("\n")[:20]:
                            lines.append(f" {l[:200]}")
                        return "\n".join(lines)
                except UnicodeDecodeError:
                    lines.append(
                        f" {BOLD}Body ({len(data)} bytes gzip → {len(decompressed)} bytes, binary):{NC}"
                    )
                    lines.append(f" {DIM}{self._extract_strings(decompressed)}{NC}")
                    return "\n".join(lines)
            except Exception:
                # Deliberately best-effort: a truncated or corrupt gzip stream
                # simply falls through to the generic renderers below.
                pass

        # Try protobuf (extract readable strings). The first byte is compared
        # against a fixed set of common field-tag bytes.
        if data[:1] in (b"\x08", b"\x0a", b"\x10", b"\x12", b"\x18", b"\x1a", b"\x20", b"\x22"):
            strings = self._extract_strings(data)
            if strings:
                lines.append(f" {BOLD}Body ({total_len} bytes, protobuf):{NC}")
                lines.append(f" {DIM}Extracted strings:{NC}")
                for s in strings.split(" | ")[:20]:
                    s = s.strip()
                    if len(s) > 3:
                        lines.append(f" {MAGENTA}{s}{NC}")
                return "\n".join(lines)

        # Try plain text
        try:
            text = data.decode("utf-8")
            lines.append(f" {BOLD}Body ({len(data)} bytes, text):{NC}")
            for l in text.split("\n")[:10]:
                lines.append(f" {l[:200]}")
            return "\n".join(lines)
        except UnicodeDecodeError:
            pass

        # PNG
        if data[:4] == b"\x89PNG":
            lines.append(f" {BOLD}Body ({total_len} bytes, PNG image){NC}")
            return "\n".join(lines)

        # Binary fallback
        lines.append(f" {BOLD}Body ({total_len} bytes, binary):{NC}")
        strings = self._extract_strings(data)
        if strings:
            lines.append(f" {DIM}Extracted strings:{NC}")
            for s in strings.split(" | ")[:15]:
                s = s.strip()
                if len(s) > 3:
                    lines.append(f" {MAGENTA}{s}{NC}")
        else:
            lines.append(f" {DIM}(no readable strings){NC}")
        return "\n".join(lines)

    def _extract_strings(self, data, min_len=4):
        """Extract printable ASCII strings from binary data.

        Returns the first 30 distinct runs of at least ``min_len`` printable
        ASCII bytes (0x20–0x7E), in order of first appearance, joined with
        " | ". Returns "" when there are none.
        """
        strings = []
        current = bytearray()
        for b in data:
            if 32 <= b <= 126:
                current.append(b)
            else:
                if len(current) >= min_len:
                    strings.append(current.decode("ascii"))
                current = bytearray()
        if len(current) >= min_len:
            strings.append(current.decode("ascii"))

        # Deduplicate while preserving order
        seen = set()
        unique = []
        for s in strings:
            if s not in seen:
                seen.add(s)
                unique.append(s)
        return " | ".join(unique[:30])
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Read a raw http2debug log (file argument or stdin) and print the snapshot.

    Input is decoded as UTF-8 with undecodable bytes replaced, so a log that
    contains a stray non-UTF-8 byte is still rendered instead of aborting with
    UnicodeDecodeError.
    """
    if len(sys.argv) > 1:
        with open(sys.argv[1], encoding="utf-8", errors="replace") as f:
            lines = f.readlines()
    else:
        # reconfigure() exists on the standard text stdin; skip it when stdin
        # has been swapped for some other file-like object.
        if hasattr(sys.stdin, "reconfigure"):
            sys.stdin.reconfigure(errors="replace")
        lines = sys.stdin.readlines()

    snap = Snapshot()
    snap.parse(lines)
    print(snap.render())


if __name__ == "__main__":
    main()
|
||||||
277
scripts/standalone-ls.sh
Executable file
277
scripts/standalone-ls.sh
Executable file
@@ -0,0 +1,277 @@
|
|||||||
|
#!/usr/bin/env bash
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ Standalone Language Server Launcher
# ║
# ║ Launches an isolated LS instance that:
# ║   - Shares OAuth via the main app's extension server
# ║   - Has its own HTTPS port, data dir, and cascades
# ║   - Optionally routes traffic through our MITM proxy
# ║   - Can capture a clean traffic snapshot
# ║
# ║ Usage:
# ║   ./standalone-ls.sh                      # Launch, test, exit
# ║   ./standalone-ls.sh --fg                 # Foreground (stay alive)
# ║   ./standalone-ls.sh --mitm               # Route through MITM proxy
# ║   ./standalone-ls.sh --snapshot           # Capture clean traffic dump
# ║   ./standalone-ls.sh --snapshot --prompt "Say hello"
# ╚═══════════════════════════════════════════════════════════════════════════╝
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# ── Defaults ──────────────────────────────────────────────────────────────────
LS_BIN="/usr/share/antigravity/resources/app/extensions/antigravity/bin/language_server_linux_x64"
HTTPS_PORT="42200"
DATA_DIR="/tmp/antigravity-standalone"
FOREGROUND=false
USE_MITM=false
SNAPSHOT=false
TIMEOUT=15
PROMPT=""
MODEL="MODEL_PLACEHOLDER_M3"

# ── Parse args ────────────────────────────────────────────────────────────────
# Abort with a readable message when an option that takes a value is the last
# argument; without this, `set -u` kills the script with "unbound variable".
#   $1 = option name, $2 = number of arguments left (option included)
_need_value() {
    if [[ "$2" -lt 2 ]]; then
        echo "Option $1 requires a value" >&2
        exit 1
    fi
}

while [[ $# -gt 0 ]]; do
    case "$1" in
        --port) _need_value "$1" "$#"; HTTPS_PORT="$2"; shift 2 ;;
        --mitm) USE_MITM=true; shift ;;
        --fg) FOREGROUND=true; shift ;;
        --timeout) _need_value "$1" "$#"; TIMEOUT="$2"; shift 2 ;;
        # Snapshot mode also raises the timeout to 30 so the capture can finish.
        --snapshot) SNAPSHOT=true; TIMEOUT=30; shift ;;
        --prompt) _need_value "$1" "$#"; PROMPT="$2"; shift 2 ;;
        --model) _need_value "$1" "$#"; MODEL="$2"; shift 2 ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo ""
            echo "Options:"
            echo " --port PORT HTTPS port for standalone LS (default: 42200)"
            echo " --mitm Route traffic through MITM proxy"
            echo " --fg Run in foreground (stay alive)"
            echo " --timeout SECS Background mode timeout (default: 15)"
            echo " --snapshot Capture clean traffic snapshot"
            echo " --prompt TEXT Prompt to send (snapshot mode)"
            echo " --model MODEL Model alias (default: MODEL_PLACEHOLDER_M3)"
            exit 0 ;;
        *) echo "Unknown option: $1"; exit 1 ;;
    esac
done
|
||||||
|
|
||||||
|
# ── Discover main LS config ──────────────────────────────────────────────────
# The standalone LS reuses the running app's CSRF token and extension-server
# port, both read from the main LS process's command line.
MAIN_PID=$(pgrep -f 'language_server_linux_x64' | head -1 || true)
if [[ -z "$MAIN_PID" ]]; then
    echo "[-] No main LS process found. Main Antigravity must be running."
    exit 1
fi

# /proc/<pid>/cmdline is NUL-separated; the value is taken as the argument
# that follows the flag. `|| true` keeps a non-matching grep from silently
# aborting the script under `set -e`/`pipefail`; the explicit checks below
# report the problem instead.
MAIN_CSRF=$(tr '\0' '\n' < /proc/"$MAIN_PID"/cmdline | grep -A1 'csrf_token' | tail -1 || true)
EXT_PORT=$(tr '\0' '\n' < /proc/"$MAIN_PID"/cmdline | grep -A1 'extension_server_port' | tail -1 || true)

if [[ -z "$MAIN_CSRF" ]]; then
    echo "[-] Could not read csrf_token from main LS (PID $MAIN_PID) command line."
    exit 1
fi
if [[ -z "$EXT_PORT" ]]; then
    echo "[-] Could not read extension_server_port from main LS (PID $MAIN_PID) command line."
    exit 1
fi

echo "[*] Main LS PID: $MAIN_PID"
echo "[*] CSRF: $MAIN_CSRF"
echo "[*] Extension server: $EXT_PORT"

# ── Build protobuf metadata for stdin ─────────────────────────────────────────
# Hand-encodes a protobuf message of length-delimited string fields (wire
# type 2) and base64-encodes it so it survives being held in a shell variable.
# $TS is expanded by bash before Python runs.
TS=$(date +%s)
METADATA=$(python3 -c "
import sys
def v(n):
    r = bytearray()
    while n > 0x7f:
        r.append((n & 0x7f) | 0x80)
        n >>= 7
    r.append(n & 0x7f)
    return bytes(r)
def s(f, val):
    t = v((f << 3) | 2)
    d = val.encode()
    return t + v(len(d)) + d
buf = bytearray()
buf += s(1, 'standalone-api-key-$TS')
buf += s(3, 'antigravity')
buf += s(4, '1.15.8')
buf += s(5, '1.16.39')
buf += s(6, 'en_US')
buf += s(10, 'standalone-session-$TS')
buf += s(11, 'antigravity')
sys.stdout.buffer.write(bytes(buf))
" | base64)
|
||||||
|
|
||||||
|
# ── Setup data directory ──────────────────────────────────────────────────────
mkdir -p "$DATA_DIR/.gemini"

# ── MITM environment ─────────────────────────────────────────────────────────
# MITM_ENV stays empty unless --mitm was given; otherwise it carries the proxy
# address and the combined CA bundle for the LS process.
MITM_ENV=()
if ! $USE_MITM; then
    echo "[*] MITM: disabled (use --mitm to enable)"
else
    # Under sudo, use the invoking user's home rather than root's.
    REAL_HOME="${SUDO_USER:+$(getent passwd "$SUDO_USER" | cut -d: -f6)}"
    REAL_HOME="${REAL_HOME:-$HOME}"
    PROXY_CONFIG_DIR="${REAL_HOME}/.config/antigravity-proxy"
    MITM_PORT_FILE="${PROXY_CONFIG_DIR}/mitm-port"
    CA_PATH="${PROXY_CONFIG_DIR}/mitm-ca.pem"

    # Port written by the running proxy, falling back to the default.
    MITM_PORT="8742"
    if [[ -f "$MITM_PORT_FILE" ]]; then
        MITM_PORT=$(<"$MITM_PORT_FILE")
    fi

    if [[ ! -f "$CA_PATH" ]]; then
        echo "[-] MITM CA cert not found at $CA_PATH"
        echo " Start the proxy first to generate it."
        exit 1
    fi

    # Combined bundle = first system CA bundle found + the MITM CA.
    COMBINED_CA="/tmp/antigravity-mitm-combined-ca.pem"
    SYS_CA=""
    for candidate in \
        /etc/ssl/certs/ca-certificates.crt \
        /etc/pki/tls/certs/ca-bundle.crt \
        /etc/ssl/cert.pem
    do
        [[ -f "$candidate" ]] || continue
        SYS_CA="$candidate"
        break
    done
    if [[ -z "$SYS_CA" ]]; then
        echo "[-] No system CA bundle found"
        exit 1
    fi
    cat "$SYS_CA" "$CA_PATH" > "$COMBINED_CA"

    MITM_ENV=(
        "HTTPS_PROXY=http://127.0.0.1:${MITM_PORT}"
        "SSL_CERT_FILE=${COMBINED_CA}"
        "GRPC_DEFAULT_SSL_ROOTS_FILE_PATH=${COMBINED_CA}"
    )
    echo "[*] MITM: enabled (port $MITM_PORT)"
fi

# ── LS args ───────────────────────────────────────────────────────────────────
LS_ARGS=(
    -enable_lsp
    -extension_server_port "$EXT_PORT"
    -csrf_token "$MAIN_CSRF"
    -server_port "$HTTPS_PORT"
    -workspace_id "standalone_$TS"
    -cloud_code_endpoint "https://daily-cloudcode-pa.googleapis.com"
    -app_data_dir "antigravity-standalone"
    -gemini_dir "$DATA_DIR/.gemini"
)

# ── Extra env for snapshot mode ───────────────────────────────────────────────
# GODEBUG=http2debug=2 produces the HTTP/2 trace that parse-snapshot.py reads.
EXTRA_ENV=()
if $SNAPSHOT; then
    EXTRA_ENV+=("GODEBUG=http2debug=2")
    echo "[*] Snapshot: enabled (HTTP/2 debug tracing)"
fi
|
||||||
|
|
||||||
|
# ── Banner ────────────────────────────────────────────────────────────────────
if $FOREGROUND; then
    MODE_LABEL="foreground"
else
    MODE_LABEL="background ($TIMEOUT s)"
fi

echo ""
echo "========================================="
echo " Standalone LS"
echo " Port: $HTTPS_PORT (HTTPS)"
echo " Data: $DATA_DIR"
echo " Mode: $MODE_LABEL"
echo "========================================="
echo ""

# ── Foreground mode ───────────────────────────────────────────────────────────
# The decoded metadata is fed to the LS on stdin. The LS binary must be the
# command that env(1) runs: `exec` is a shell builtin, not a program, so
# `env VAR=… exec "$LS_BIN"` fails with "env: 'exec': No such file or
# directory" and the LS never starts.
# The nested "${ARR[@]+…}" expansions keep empty arrays from tripping `set -u`.
if $FOREGROUND; then
    echo "$METADATA" | base64 -d | \
        env "${MITM_ENV[@]+"${MITM_ENV[@]}"}" \
            "${EXTRA_ENV[@]+"${EXTRA_ENV[@]}"}" \
            ANTIGRAVITY_EDITOR_APP_ROOT="/usr/share/antigravity/resources/app" \
            "$LS_BIN" "${LS_ARGS[@]}"
    exit 0
fi
|
||||||
|
|
||||||
|
# ── Background mode ──────────────────────────────────────────────────────────
# Start the LS under timeout(1) with stdout and stderr captured in $LOG.
LOG="/tmp/standalone-ls.log"
rm -f "$LOG"

echo "$METADATA" | base64 -d | \
    env "${MITM_ENV[@]+"${MITM_ENV[@]}"}" \
        "${EXTRA_ENV[@]+"${EXTRA_ENV[@]}"}" \
        ANTIGRAVITY_EDITOR_APP_ROOT="/usr/share/antigravity/resources/app" \
        timeout "$TIMEOUT" "$LS_BIN" "${LS_ARGS[@]}" \
    > "$LOG" 2>&1 &

# $! is the PID of the last command of the background pipeline.
LS_PID=$!
echo "[*] PID: $LS_PID"

# Wait for init: poll once per second for five seconds; a process that has
# disappeared within that window is a failed start.
for i in 1 2 3 4 5; do
    sleep 1
    if kill -0 "$LS_PID" 2>/dev/null; then
        continue
    fi
    echo "[-] LS died after ${i}s"
    echo "=== LOGS ==="
    cat "$LOG"
    exit 1
done
echo "[+] LS alive and initialized"
|
||||||
|
|
||||||
|
# ── Snapshot mode: send a prompt and capture traffic ──────────────────────────
if $SNAPSHOT; then
    if [[ -z "$PROMPT" ]]; then
        PROMPT="Say exactly: Hello standalone world"
    fi

    echo ""
    echo "[*] Sending cascade: \"$PROMPT\""

    # Build the body with a real JSON encoder. Splicing $PROMPT into a
    # hand-written JSON string breaks (or lets the prompt rewrite) the request
    # as soon as it contains a quote, backslash or newline.
    REQUEST_BODY=$(python3 -c '
import json, sys
prompt, model, root = sys.argv[1:4]
print(json.dumps({
    "prompt": prompt,
    "modelOrAlias": {"model": model},
    "workspaceRootPaths": [root],
}))
' "$PROMPT" "$MODEL" "$DATA_DIR")

    # Best-effort: an unreachable LS or a non-JSON reply leaves CASCADE_ID empty.
    CASCADE_ID=$(curl -sk --max-time 10 \
        "https://127.0.0.1:${HTTPS_PORT}/exa.language_server_pb.LanguageServerService/StartCascade" \
        -H "Content-Type: application/json" \
        -H "x-codeium-csrf-token: $MAIN_CSRF" \
        -H "Origin: vscode-file://vscode-app" \
        -d "$REQUEST_BODY" 2>/dev/null | python3 -c "import json,sys; print(json.load(sys.stdin).get('cascadeId',''))" 2>/dev/null || true)

    echo "[*] Cascade: $CASCADE_ID"
    echo "[*] Waiting 15s for upstream API calls..."
    sleep 15

    # Kill LS to flush logs. timeout(1) may already have ended it by now;
    # without `|| true` the failing kill would abort the script under
    # `set -e` before the log is parsed and saved.
    kill "$LS_PID" 2>/dev/null || true
    wait "$LS_PID" 2>/dev/null || true

    # Parse and display
    echo ""
    python3 "$SCRIPT_DIR/parse-snapshot.py" "$LOG"

    # Also save raw log
    SNAPSHOT_FILE="/tmp/standalone-snapshot-$(date +%Y%m%d-%H%M%S).log"
    cp "$LOG" "$SNAPSHOT_FILE"
    echo ""
    echo "[*] Raw log saved to: $SNAPSHOT_FILE"
    exit 0
fi
|
||||||
|
|
||||||
|
# ── Normal mode: test and report ──────────────────────────────────────────────
# Query GetUserStatus on the standalone LS and print a short summary (plan,
# credits, tier, first five models).
#
# The trailing `|| true` matters: under `set -e`/`pipefail` a failing curl
# (for example, the LS is not accepting connections yet) would otherwise abort
# the script here, skipping the cleanup below and leaving the LS running until
# its timeout expires.
echo ""
echo "=== GetUserStatus ==="
curl -sk "https://127.0.0.1:${HTTPS_PORT}/exa.language_server_pb.LanguageServerService/GetUserStatus" \
    -H "Content-Type: application/json" \
    -H "x-codeium-csrf-token: $MAIN_CSRF" \
    -H "Origin: vscode-file://vscode-app" \
    -d '{}' 2>/dev/null | python3 -c "
import json, sys
try:
    d = json.load(sys.stdin)
    us = d.get('userStatus', {})
    ps = us.get('planStatus', {})
    pi = ps.get('planInfo', {})
    print(f'Plan: {pi.get(\"planName\",\"?\")}, Prompt: {ps.get(\"availablePromptCredits\",\"?\")}, Flow: {ps.get(\"availableFlowCredits\",\"?\")}')
    ut = us.get('userTier', {})
    print(f'Tier: {ut.get(\"name\",\"?\")}')
    models = us.get('cascadeModelConfigData', {}).get('clientModelConfigs', [])
    print(f'Models: {len(models)}')
    for m in models[:5]:
        qi = m.get('quotaInfo', {})
        print(f' - {m.get(\"label\")}: remaining={qi.get(\"remainingFraction\",\"?\")}')
except Exception as e:
    print(f'Error: {e}')
" 2>/dev/null || true

# Always stop the LS and reap it, whatever the query did.
echo ""
kill "$LS_PID" 2>/dev/null || true
wait "$LS_PID" 2>/dev/null || true
echo "[*] Done"
|
||||||
343
src/api/completions.rs
Normal file
343
src/api/completions.rs
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
//! OpenAI Chat Completions API (/v1/chat/completions) handler.
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::State,
|
||||||
|
http::StatusCode,
|
||||||
|
response::{sse::Event, IntoResponse, Json, Sse},
|
||||||
|
};
|
||||||
|
use rand::Rng;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
|
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
||||||
|
use super::polling::{extract_response_text, is_response_done, poll_for_response};
|
||||||
|
use super::types::*;
|
||||||
|
use super::util::{err_response, now_unix};
|
||||||
|
use super::AppState;
|
||||||
|
|
||||||
|
// ─── Input extraction ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Extract user text from Chat Completions messages array.
|
||||||
|
fn extract_chat_input(messages: &[CompletionMessage]) -> String {
|
||||||
|
let mut system_parts = Vec::new();
|
||||||
|
let mut user_parts = Vec::new();
|
||||||
|
|
||||||
|
for msg in messages {
|
||||||
|
let text = match &msg.content {
|
||||||
|
serde_json::Value::String(s) => s.clone(),
|
||||||
|
serde_json::Value::Array(arr) => arr
|
||||||
|
.iter()
|
||||||
|
.filter_map(|item| item["text"].as_str())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n"),
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
match msg.role.as_str() {
|
||||||
|
"system" | "developer" => system_parts.push(text),
|
||||||
|
"user" => user_parts.push(text),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = String::new();
|
||||||
|
if !system_parts.is_empty() {
|
||||||
|
result.push_str(&system_parts.join("\n"));
|
||||||
|
result.push_str("\n\n");
|
||||||
|
}
|
||||||
|
// Use the last user message
|
||||||
|
if let Some(last) = user_parts.last() {
|
||||||
|
result.push_str(last);
|
||||||
|
}
|
||||||
|
result.trim().to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Handler ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// POST /v1/chat/completions — OpenAI Chat Completions API compatibility shim.
|
||||||
|
/// Accepts standard messages format, reuses the same backend cascade, and
|
||||||
|
/// outputs in the Chat Completions streaming/sync format.
|
||||||
|
pub(crate) async fn handle_completions(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
Json(body): Json<CompletionRequest>,
|
||||||
|
) -> axum::response::Response {
|
||||||
|
let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
|
||||||
|
info!(
|
||||||
|
"POST /v1/chat/completions model={} stream={}",
|
||||||
|
model_name, body.stream
|
||||||
|
);
|
||||||
|
|
||||||
|
let model = match lookup_model(model_name) {
|
||||||
|
Some(m) => m,
|
||||||
|
None => {
|
||||||
|
let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
|
||||||
|
return err_response(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
format!("Unknown model: {model_name}. Available: {names:?}"),
|
||||||
|
"invalid_request_error",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let token = state.backend.oauth_token().await;
|
||||||
|
if token.is_empty() {
|
||||||
|
return err_response(
|
||||||
|
StatusCode::UNAUTHORIZED,
|
||||||
|
"No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
|
||||||
|
"authentication_error",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let user_text = extract_chat_input(&body.messages);
|
||||||
|
if user_text.is_empty() {
|
||||||
|
return err_response(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"No user message found".to_string(),
|
||||||
|
"invalid_request_error",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fresh cascade per request
|
||||||
|
let cascade_id = match state.backend.create_cascade().await {
|
||||||
|
Ok(cid) => cid,
|
||||||
|
Err(e) => {
|
||||||
|
return err_response(
|
||||||
|
StatusCode::BAD_GATEWAY,
|
||||||
|
format!("StartCascade failed: {e}"),
|
||||||
|
"server_error",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Send message
|
||||||
|
match state
|
||||||
|
.backend
|
||||||
|
.send_message(&cascade_id, &user_text, model.model_enum)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok((200, _)) => {
|
||||||
|
let bg = Arc::clone(&state.backend);
|
||||||
|
let cid = cascade_id.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let _ = bg.update_annotations(&cid).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok((status, _)) => {
|
||||||
|
return err_response(
|
||||||
|
StatusCode::BAD_GATEWAY,
|
||||||
|
format!("Backend returned {status}"),
|
||||||
|
"server_error",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
return err_response(
|
||||||
|
StatusCode::BAD_GATEWAY,
|
||||||
|
format!("Send failed: {e}"),
|
||||||
|
"server_error",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let completion_id = format!(
|
||||||
|
"chatcmpl-{}",
|
||||||
|
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||||
|
);
|
||||||
|
|
||||||
|
if body.stream {
|
||||||
|
chat_completions_stream(
|
||||||
|
state,
|
||||||
|
completion_id,
|
||||||
|
model_name.to_string(),
|
||||||
|
cascade_id,
|
||||||
|
body.timeout,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
} else {
|
||||||
|
chat_completions_sync(
|
||||||
|
state,
|
||||||
|
completion_id,
|
||||||
|
model_name.to_string(),
|
||||||
|
cascade_id,
|
||||||
|
body.timeout,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Streaming ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Streaming output in Chat Completions format.
|
||||||
|
async fn chat_completions_stream(
|
||||||
|
state: Arc<AppState>,
|
||||||
|
completion_id: String,
|
||||||
|
model_name: String,
|
||||||
|
cascade_id: String,
|
||||||
|
timeout: u64,
|
||||||
|
) -> axum::response::Response {
|
||||||
|
let stream = async_stream::stream! {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let mut last_text = String::new();
|
||||||
|
|
||||||
|
// Initial role chunk
|
||||||
|
yield Ok::<_, std::convert::Infallible>(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": now_unix(),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"role": "assistant", "content": ""},
|
||||||
|
"finish_reason": serde_json::Value::Null,
|
||||||
|
}],
|
||||||
|
})).unwrap_or_default()));
|
||||||
|
|
||||||
|
while start.elapsed().as_secs() < timeout {
|
||||||
|
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
|
||||||
|
if status == 200 {
|
||||||
|
if let Some(steps) = data["steps"].as_array() {
|
||||||
|
let text = extract_response_text(steps);
|
||||||
|
|
||||||
|
if !text.is_empty() && text != last_text {
|
||||||
|
let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) {
|
||||||
|
&text[last_text.len()..]
|
||||||
|
} else {
|
||||||
|
&text
|
||||||
|
};
|
||||||
|
|
||||||
|
if !delta.is_empty() {
|
||||||
|
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": now_unix(),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"content": delta},
|
||||||
|
"finish_reason": serde_json::Value::Null,
|
||||||
|
}],
|
||||||
|
})).unwrap_or_default()));
|
||||||
|
last_text = text.to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Done check: need DONE status AND non-empty text
|
||||||
|
if is_response_done(steps) && !last_text.is_empty() {
|
||||||
|
debug!("Completions stream done, text length={}", last_text.len());
|
||||||
|
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": now_unix(),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}],
|
||||||
|
})).unwrap_or_default()));
|
||||||
|
yield Ok(Event::default().data("[DONE]".to_string()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDLE fallback: check trajectory status periodically
|
||||||
|
// Only check every 5th step count to reduce backend traffic
|
||||||
|
let step_count = steps.len();
|
||||||
|
if step_count > 4 && step_count % 5 == 0 {
|
||||||
|
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
|
||||||
|
if ts == 200 {
|
||||||
|
let run_status = td["status"].as_str().unwrap_or("");
|
||||||
|
if run_status.contains("IDLE") && !last_text.is_empty() {
|
||||||
|
debug!("Completions IDLE, text length={}", last_text.len());
|
||||||
|
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": now_unix(),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}],
|
||||||
|
})).unwrap_or_default()));
|
||||||
|
yield Ok(Event::default().data("[DONE]".to_string()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let poll_ms: u64 = rand::thread_rng().gen_range(800..1200);
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timeout
|
||||||
|
warn!("Completions stream timeout after {}s", timeout);
|
||||||
|
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": now_unix(),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"content": if last_text.is_empty() { "[Timeout waiting for response]" } else { "" }},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}],
|
||||||
|
})).unwrap_or_default()));
|
||||||
|
yield Ok(Event::default().data("[DONE]".to_string()));
|
||||||
|
};
|
||||||
|
|
||||||
|
Sse::new(stream)
|
||||||
|
.keep_alive(
|
||||||
|
axum::response::sse::KeepAlive::new()
|
||||||
|
.interval(std::time::Duration::from_secs(15))
|
||||||
|
.text(""),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Sync ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Sync output in Chat Completions format.
|
||||||
|
async fn chat_completions_sync(
|
||||||
|
state: Arc<AppState>,
|
||||||
|
completion_id: String,
|
||||||
|
model_name: String,
|
||||||
|
cascade_id: String,
|
||||||
|
timeout: u64,
|
||||||
|
) -> axum::response::Response {
|
||||||
|
let result = poll_for_response(&state, &cascade_id, timeout).await;
|
||||||
|
|
||||||
|
// Check MITM store first for real intercepted usage
|
||||||
|
let (prompt_tokens, completion_tokens, cached_tokens) = if let Some(mitm_usage) = state.mitm_store.take_usage(&cascade_id).await {
|
||||||
|
(mitm_usage.input_tokens, mitm_usage.output_tokens, mitm_usage.cache_read_input_tokens)
|
||||||
|
} else if let Some(u) = &result.usage {
|
||||||
|
(u.input_tokens, u.output_tokens, 0)
|
||||||
|
} else {
|
||||||
|
(0, 0, 0)
|
||||||
|
};
|
||||||
|
|
||||||
|
Json(serde_json::json!({
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": now_unix(),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": result.text,
|
||||||
|
},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": prompt_tokens,
|
||||||
|
"completion_tokens": completion_tokens,
|
||||||
|
"total_tokens": prompt_tokens + completion_tokens,
|
||||||
|
"prompt_tokens_details": {
|
||||||
|
"cached_tokens": cached_tokens,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
176
src/api/mod.rs
Normal file
176
src/api/mod.rs
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
//! Axum API server — OpenAI-compatible Responses + Chat Completions endpoints.
|
||||||
|
|
||||||
|
mod completions;
|
||||||
|
mod models;
|
||||||
|
mod polling;
|
||||||
|
mod responses;
|
||||||
|
mod types;
|
||||||
|
mod util;
|
||||||
|
|
||||||
|
use crate::constants::safe_truncate;
|
||||||
|
use crate::session::SessionManager;
|
||||||
|
use axum::{
|
||||||
|
extract::{DefaultBodyLimit, State},
|
||||||
|
http::StatusCode,
|
||||||
|
response::{IntoResponse, Json},
|
||||||
|
routing::{delete, get, post},
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tower_http::cors::CorsLayer;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
use self::models::MODELS;
|
||||||
|
use self::types::TokenRequest;
|
||||||
|
|
||||||
|
// ─── Shared state ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub struct AppState {
|
||||||
|
pub backend: Arc<crate::backend::Backend>,
|
||||||
|
pub sessions: SessionManager,
|
||||||
|
pub mitm_store: crate::mitm::store::MitmStore,
|
||||||
|
pub quota_store: crate::quota::QuotaStore,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Router ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub fn router(state: Arc<AppState>) -> Router {
|
||||||
|
Router::new()
|
||||||
|
.route("/v1/responses", post(responses::handle_responses))
|
||||||
|
.route(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
post(completions::handle_completions),
|
||||||
|
)
|
||||||
|
.route("/v1/models", get(handle_models))
|
||||||
|
.route("/v1/sessions", get(handle_list_sessions))
|
||||||
|
.route("/v1/sessions/{id}", delete(handle_delete_session))
|
||||||
|
.route("/v1/token", post(handle_set_token))
|
||||||
|
.route("/v1/usage", get(handle_usage))
|
||||||
|
.route("/v1/quota", get(handle_quota))
|
||||||
|
.route("/health", get(handle_health))
|
||||||
|
.route("/", get(handle_root))
|
||||||
|
.layer(CorsLayer::permissive())
|
||||||
|
.layer(DefaultBodyLimit::max(1_048_576)) // 1 MB
|
||||||
|
.with_state(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Simple handlers ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async fn handle_root() -> Json<serde_json::Value> {
|
||||||
|
Json(serde_json::json!({
|
||||||
|
"service": "antigravity-openai-proxy",
|
||||||
|
"version": "3.2.0",
|
||||||
|
"runtime": "rust",
|
||||||
|
"endpoints": [
|
||||||
|
"/v1/chat/completions",
|
||||||
|
"/v1/responses",
|
||||||
|
"/v1/models",
|
||||||
|
"/v1/sessions",
|
||||||
|
"/v1/token",
|
||||||
|
"/v1/usage",
|
||||||
|
"/v1/quota",
|
||||||
|
"/health",
|
||||||
|
],
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_health() -> Json<serde_json::Value> {
|
||||||
|
Json(serde_json::json!({"status": "ok"}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_models() -> Json<serde_json::Value> {
|
||||||
|
let models: Vec<serde_json::Value> = MODELS
|
||||||
|
.iter()
|
||||||
|
.map(|m| {
|
||||||
|
serde_json::json!({
|
||||||
|
"id": m.name,
|
||||||
|
"object": "model",
|
||||||
|
"created": 1700000000u64,
|
||||||
|
"owned_by": "antigravity",
|
||||||
|
"permission": [],
|
||||||
|
"root": m.name,
|
||||||
|
"parent": null,
|
||||||
|
"meta": {
|
||||||
|
"label": m.label,
|
||||||
|
"enum_value": m.model_enum,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Json(serde_json::json!({"object": "list", "data": models}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_list_sessions(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
) -> Json<serde_json::Value> {
|
||||||
|
let sessions = state.sessions.list_sessions().await;
|
||||||
|
Json(serde_json::json!({"sessions": sessions}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_delete_session(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
axum::extract::Path(id): axum::extract::Path<String>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
if state.sessions.delete_session(&id).await {
|
||||||
|
(
|
||||||
|
StatusCode::OK,
|
||||||
|
Json(serde_json::json!({"status": "deleted", "session_id": id})),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(serde_json::json!({"error": format!("Session not found: {id}")})),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_set_token(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
Json(body): Json<TokenRequest>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
if !body.token.starts_with("ya29.") {
|
||||||
|
return (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(serde_json::json!({"error": "Invalid token. Must start with ya29."})),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
state.backend.set_oauth_token(body.token.clone()).await;
|
||||||
|
|
||||||
|
// Also persist to file
|
||||||
|
let token_path = crate::constants::token_file_path();
|
||||||
|
if let Err(e) = tokio::fs::write(&token_path, &body.token).await {
|
||||||
|
warn!("Failed to write token file: {e}");
|
||||||
|
}
|
||||||
|
|
||||||
|
let preview = safe_truncate(&body.token, 20);
|
||||||
|
(
|
||||||
|
StatusCode::OK,
|
||||||
|
Json(serde_json::json!({"status": "ok", "token_prefix": preview})),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_usage(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
) -> Json<serde_json::Value> {
|
||||||
|
let stats = state.mitm_store.stats().await;
|
||||||
|
Json(serde_json::json!({
|
||||||
|
"mitm": {
|
||||||
|
"total_requests": stats.total_requests,
|
||||||
|
"total_input_tokens": stats.total_input_tokens,
|
||||||
|
"total_output_tokens": stats.total_output_tokens,
|
||||||
|
"total_cache_read_tokens": stats.total_cache_read_tokens,
|
||||||
|
"total_cache_creation_tokens": stats.total_cache_creation_tokens,
|
||||||
|
"total_thinking_output_tokens": stats.total_thinking_output_tokens,
|
||||||
|
"total_response_output_tokens": stats.total_response_output_tokens,
|
||||||
|
"total_tokens": stats.total_input_tokens + stats.total_output_tokens,
|
||||||
|
"per_model": stats.per_model,
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_quota(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
) -> Json<serde_json::Value> {
|
||||||
|
let snap = state.quota_store.snapshot().await;
|
||||||
|
Json(serde_json::to_value(snap).unwrap_or_default())
|
||||||
|
}
|
||||||
49
src/api/models.rs
Normal file
49
src/api/models.rs
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
//! Model definitions and lookup.
|
||||||
|
|
||||||
|
/// Model definition: friendly name → (antigravity_id, protobuf_enum, label).
#[derive(Debug, Clone, Copy)]
pub(crate) struct ModelDef {
    /// Friendly name clients pass in the `model` request field.
    pub name: &'static str,
    /// Backend-side placeholder identifier; kept for reference, not read yet.
    #[allow(dead_code)]
    pub ag_id: &'static str,
    /// Numeric enum value sent to the backend for this model.
    pub model_enum: u32,
    /// Human-readable display label.
    pub label: &'static str,
}

/// All models the proxy exposes, in listing order.
pub(crate) const MODELS: &[ModelDef] = &[
    ModelDef {
        name: "opus-4.6",
        ag_id: "MODEL_PLACEHOLDER_M26",
        model_enum: 1026,
        label: "Claude Opus 4.6 (Thinking)",
    },
    ModelDef {
        name: "opus-4.5",
        ag_id: "MODEL_PLACEHOLDER_M12",
        model_enum: 1012,
        label: "Claude Opus 4.5 (Thinking)",
    },
    ModelDef {
        name: "gemini-3-pro-high",
        ag_id: "MODEL_PLACEHOLDER_M8",
        model_enum: 1008,
        label: "Gemini 3 Pro (High)",
    },
    ModelDef {
        name: "gemini-3-pro",
        ag_id: "MODEL_PLACEHOLDER_M7",
        model_enum: 1007,
        label: "Gemini 3 Pro (Low)",
    },
    ModelDef {
        name: "gemini-3-flash",
        ag_id: "MODEL_PLACEHOLDER_M18",
        model_enum: 1018,
        label: "Gemini 3 Flash",
    },
];

/// Model used when a request does not name one.
pub(crate) const DEFAULT_MODEL: &str = "opus-4.6";

/// Find a model by its friendly name.
///
/// Matching ignores ASCII case and surrounding whitespace, so `"Opus-4.6"` and
/// `" opus-4.6 "` resolve to the same entry as `"opus-4.6"`. Returns `None` when
/// no entry of [`MODELS`] matches.
pub(crate) fn lookup_model(name: &str) -> Option<&'static ModelDef> {
    let wanted = name.trim();
    MODELS.iter().find(|m| m.name.eq_ignore_ascii_case(wanted))
}
|
||||||
298
src/api/polling.rs
Normal file
298
src/api/polling.rs
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
//! Shared polling engine and step extraction helpers.
|
||||||
|
|
||||||
|
use rand::Rng;
|
||||||
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
|
use super::AppState;
|
||||||
|
|
||||||
|
/// Token usage as reported by the language server in a step's
/// `metadata.modelUsage` object.
#[derive(Debug, Clone, Default)]
// Not every field is read by every handler.
#[allow(dead_code)]
pub(crate) struct ModelUsage {
    /// Value of `inputTokens` (0 when absent or unparsable).
    pub input_tokens: u64,
    /// Value of `outputTokens` (0 when absent or unparsable).
    pub output_tokens: u64,
    /// Value of `apiProvider` (empty when absent).
    pub api_provider: String,
    /// Value of `model` (empty when absent).
    pub model: String,
}
|
||||||
|
|
||||||
|
/// Result of polling — text + optional real usage data + thinking data.
|
||||||
|
pub(crate) struct PollResult {
|
||||||
|
pub text: String,
|
||||||
|
pub usage: Option<ModelUsage>,
|
||||||
|
/// Opaque Anthropic thinking verification signature from PLANNER_RESPONSE.
|
||||||
|
/// Required for multi-turn thinking model chaining.
|
||||||
|
pub thinking_signature: Option<String>,
|
||||||
|
/// The model's internal reasoning/thinking content.
|
||||||
|
/// Available for both Opus (Anthropic) and Gemini models.
|
||||||
|
pub thinking: Option<String>,
|
||||||
|
/// Time the model spent thinking, as reported by the LS (e.g. "0.041999832s").
|
||||||
|
pub thinking_duration: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the response text from steps — scans in REVERSE to find the latest response.
|
||||||
|
pub(crate) fn extract_response_text(steps: &[serde_json::Value]) -> String {
|
||||||
|
for step in steps.iter().rev() {
|
||||||
|
let step_type = step["type"].as_str().unwrap_or("");
|
||||||
|
|
||||||
|
if step_type.contains("PLANNER_RESPONSE") {
|
||||||
|
let resp = &step["plannerResponse"];
|
||||||
|
let text = resp["rawResponse"]
|
||||||
|
.as_str()
|
||||||
|
.or_else(|| resp["response"].as_str())
|
||||||
|
.unwrap_or("");
|
||||||
|
if !text.is_empty() {
|
||||||
|
return text.to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if step_type.contains("AI_RESPONSE") || step_type.contains("MODEL_RESPONSE") {
|
||||||
|
if let Some(text) = step["response"]
|
||||||
|
.as_str()
|
||||||
|
.or_else(|| step["rawResponse"].as_str())
|
||||||
|
.or_else(|| step["text"].as_str())
|
||||||
|
{
|
||||||
|
if !text.is_empty() {
|
||||||
|
return text.to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract real token usage from the LS's modelUsage field.
|
||||||
|
/// The LS reports this in CHECKPOINT steps and sometimes in retryInfos.
|
||||||
|
/// Scans in reverse to find the latest usage data.
|
||||||
|
pub(crate) fn extract_model_usage(steps: &[serde_json::Value]) -> Option<ModelUsage> {
|
||||||
|
for step in steps.iter().rev() {
|
||||||
|
if let Some(usage) = step.get("metadata").and_then(|m| m.get("modelUsage")) {
|
||||||
|
let input = usage["inputTokens"]
|
||||||
|
.as_str()
|
||||||
|
.and_then(|s| s.parse::<u64>().ok())
|
||||||
|
.or_else(|| usage["inputTokens"].as_u64())
|
||||||
|
.unwrap_or(0);
|
||||||
|
let output = usage["outputTokens"]
|
||||||
|
.as_str()
|
||||||
|
.and_then(|s| s.parse::<u64>().ok())
|
||||||
|
.or_else(|| usage["outputTokens"].as_u64())
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
if input > 0 || output > 0 {
|
||||||
|
return Some(ModelUsage {
|
||||||
|
input_tokens: input,
|
||||||
|
output_tokens: output,
|
||||||
|
api_provider: usage["apiProvider"]
|
||||||
|
.as_str()
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string(),
|
||||||
|
model: usage["model"]
|
||||||
|
.as_str()
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the thinking signature from PLANNER_RESPONSE steps.
|
||||||
|
/// This is an opaque Base64 blob used by Anthropic for extended thinking
|
||||||
|
/// verification. Needed to chain multi-turn conversations with thinking models.
|
||||||
|
pub(crate) fn extract_thinking_signature(steps: &[serde_json::Value]) -> Option<String> {
|
||||||
|
for step in steps.iter().rev() {
|
||||||
|
let step_type = step["type"].as_str().unwrap_or("");
|
||||||
|
if step_type.contains("PLANNER_RESPONSE") {
|
||||||
|
if let Some(sig) = step["plannerResponse"]["thinkingSignature"].as_str() {
|
||||||
|
if !sig.is_empty() {
|
||||||
|
return Some(sig.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the model's thinking/reasoning content from PLANNER_RESPONSE steps.
|
||||||
|
/// This is the internal monologue produced during extended thinking.
|
||||||
|
/// Available for ALL models (Opus, Gemini Flash, Gemini Pro).
|
||||||
|
pub(crate) fn extract_thinking_content(steps: &[serde_json::Value]) -> Option<String> {
|
||||||
|
for step in steps.iter().rev() {
|
||||||
|
let step_type = step["type"].as_str().unwrap_or("");
|
||||||
|
if step_type.contains("PLANNER_RESPONSE") {
|
||||||
|
if let Some(thinking) = step["plannerResponse"]["thinking"].as_str() {
|
||||||
|
if !thinking.is_empty() {
|
||||||
|
return Some(thinking.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract thinking duration from PLANNER_RESPONSE steps.
|
||||||
|
/// Returns the raw duration string as reported by the LS (e.g. "0.041999832s").
|
||||||
|
pub(crate) fn extract_thinking_duration(steps: &[serde_json::Value]) -> Option<String> {
|
||||||
|
for step in steps.iter().rev() {
|
||||||
|
let step_type = step["type"].as_str().unwrap_or("");
|
||||||
|
if step_type.contains("PLANNER_RESPONSE") {
|
||||||
|
if let Some(dur) = step["plannerResponse"]["thinkingDuration"].as_str() {
|
||||||
|
if !dur.is_empty() {
|
||||||
|
return Some(dur.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the cascade has truly finished — the last PLANNER_RESPONSE must be DONE
|
||||||
|
/// AND the very last step must be a terminal type (CHECKPOINT or PLANNER_RESPONSE with DONE).
|
||||||
|
/// This prevents false positives during agentic tool-call loops where intermediate
|
||||||
|
/// PLANNER_RESPONSE steps show DONE but the cascade keeps going.
|
||||||
|
pub(crate) fn is_response_done(steps: &[serde_json::Value]) -> bool {
|
||||||
|
if steps.is_empty() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let last = &steps[steps.len() - 1];
|
||||||
|
let last_type = last["type"].as_str().unwrap_or("");
|
||||||
|
let last_status = last["status"].as_str().unwrap_or("");
|
||||||
|
|
||||||
|
// CHECKPOINT at the end = cascade is definitely done
|
||||||
|
if last_type.contains("CHECKPOINT") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last step is a PLANNER_RESPONSE with DONE = final answer (no more tool calls coming)
|
||||||
|
if (last_type.contains("PLANNER_RESPONSE")
|
||||||
|
|| last_type.contains("AI_RESPONSE")
|
||||||
|
|| last_type.contains("MODEL_RESPONSE"))
|
||||||
|
&& last_status.contains("DONE")
|
||||||
|
{
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Poll the backend until we get a response or timeout.
|
||||||
|
pub(crate) async fn poll_for_response(
|
||||||
|
state: &AppState,
|
||||||
|
cascade_id: &str,
|
||||||
|
timeout: u64,
|
||||||
|
) -> PollResult {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let short_id = &cascade_id[..8.min(cascade_id.len())];
|
||||||
|
info!("Polling for response on cascade {short_id} (timeout={timeout}s)");
|
||||||
|
|
||||||
|
let mut last_step_count: usize = 0;
|
||||||
|
|
||||||
|
while start.elapsed().as_secs() < timeout {
|
||||||
|
if let Ok((status, data)) = state.backend.get_steps(cascade_id).await {
|
||||||
|
if status == 200 {
|
||||||
|
if let Some(steps) = data["steps"].as_array() {
|
||||||
|
let step_count = steps.len();
|
||||||
|
|
||||||
|
// Only log when step count changes (denoised)
|
||||||
|
if step_count != last_step_count {
|
||||||
|
// Compact type summary: count unique types
|
||||||
|
let mut type_counts: std::collections::BTreeMap<&str, usize> =
|
||||||
|
std::collections::BTreeMap::new();
|
||||||
|
for s in steps.iter() {
|
||||||
|
let t = s["type"]
|
||||||
|
.as_str()
|
||||||
|
.unwrap_or("?")
|
||||||
|
.strip_prefix("CORTEX_STEP_TYPE_")
|
||||||
|
.unwrap_or("?");
|
||||||
|
*type_counts.entry(t).or_insert(0) += 1;
|
||||||
|
}
|
||||||
|
let summary: Vec<String> = type_counts
|
||||||
|
.iter()
|
||||||
|
.map(|(t, c)| {
|
||||||
|
if *c > 1 {
|
||||||
|
format!("{t}×{c}")
|
||||||
|
} else {
|
||||||
|
t.to_string()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
debug!(
|
||||||
|
"Poll {short_id}: {step_count} steps [{}]",
|
||||||
|
summary.join(", ")
|
||||||
|
);
|
||||||
|
last_step_count = step_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the cascade is truly done
|
||||||
|
if is_response_done(steps) {
|
||||||
|
let text = extract_response_text(steps);
|
||||||
|
if !text.is_empty() {
|
||||||
|
let usage = extract_model_usage(steps);
|
||||||
|
let thinking_signature = extract_thinking_signature(steps);
|
||||||
|
let thinking = extract_thinking_content(steps);
|
||||||
|
let thinking_duration = extract_thinking_duration(steps);
|
||||||
|
let elapsed = start.elapsed().as_secs_f32();
|
||||||
|
if let Some(ref u) = usage {
|
||||||
|
info!(
|
||||||
|
"Response done ({short_id}), {:.1}s, {} chars, tokens: {}in/{}out ({}){}{}",
|
||||||
|
elapsed, text.len(), u.input_tokens, u.output_tokens, u.model,
|
||||||
|
if thinking.is_some() { format!(", thinking: {} chars", thinking.as_ref().unwrap().len()) } else { String::new() },
|
||||||
|
if thinking_signature.is_some() { ", has sig" } else { "" }
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
info!(
|
||||||
|
"Response done ({short_id}), {:.1}s, {} chars (no usage){}{}",
|
||||||
|
elapsed, text.len(),
|
||||||
|
if thinking.is_some() { format!(", thinking: {} chars", thinking.as_ref().unwrap().len()) } else { String::new() },
|
||||||
|
if thinking_signature.is_some() { ", has sig" } else { "" }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return PollResult { text, usage, thinking_signature, thinking, thinking_duration };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: check trajectory IDLE status (catches edge cases)
|
||||||
|
// Only check every 5th poll to reduce network calls
|
||||||
|
if step_count > 4 && step_count % 5 == 0 {
|
||||||
|
if let Ok((ts, td)) = state.backend.get_trajectory(cascade_id).await
|
||||||
|
{
|
||||||
|
if ts == 200 {
|
||||||
|
let run_status =
|
||||||
|
td["status"].as_str().unwrap_or("");
|
||||||
|
if run_status.contains("IDLE") {
|
||||||
|
let text = extract_response_text(steps);
|
||||||
|
if !text.is_empty() {
|
||||||
|
let usage = extract_model_usage(steps);
|
||||||
|
let thinking_signature = extract_thinking_signature(steps);
|
||||||
|
let thinking = extract_thinking_content(steps);
|
||||||
|
let thinking_duration = extract_thinking_duration(steps);
|
||||||
|
let elapsed = start.elapsed().as_secs_f32();
|
||||||
|
info!(
|
||||||
|
"Trajectory IDLE ({short_id}), {:.1}s, {} chars",
|
||||||
|
elapsed,
|
||||||
|
text.len()
|
||||||
|
);
|
||||||
|
return PollResult { text, usage, thinking_signature, thinking, thinking_duration };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let poll_ms: u64 = rand::thread_rng().gen_range(1000..1800);
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
warn!("Timeout after {timeout}s on cascade {short_id}");
|
||||||
|
PollResult {
|
||||||
|
text: "[Timeout waiting for AI response]".to_string(),
|
||||||
|
usage: None,
|
||||||
|
thinking_signature: None,
|
||||||
|
thinking: None,
|
||||||
|
thinking_duration: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
686
src/api/responses.rs
Normal file
686
src/api/responses.rs
Normal file
@@ -0,0 +1,686 @@
|
|||||||
|
//! OpenAI Responses API (/v1/responses) handler.
|
||||||
|
//!
|
||||||
|
//! Strictly adheres to the official OpenAI Responses API protocol:
|
||||||
|
//! https://platform.openai.com/docs/api-reference/responses
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::State,
|
||||||
|
http::StatusCode,
|
||||||
|
response::{sse::Event, IntoResponse, Json, Sse},
|
||||||
|
};
|
||||||
|
use rand::Rng;
|
||||||
|
use std::sync::atomic::{AtomicU32, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::{debug, info};
|
||||||
|
|
||||||
|
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
||||||
|
use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content, extract_thinking_duration};
|
||||||
|
use super::types::*;
|
||||||
|
use super::util::{err_response, now_unix, responses_sse_event};
|
||||||
|
use super::AppState;
|
||||||
|
|
||||||
|
// ─── Input extraction ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Extract user text from Responses API `input` field.
|
||||||
|
fn extract_responses_input(input: &serde_json::Value, instructions: Option<&str>) -> String {
|
||||||
|
let user_text = match input {
|
||||||
|
serde_json::Value::String(s) => s.clone(),
|
||||||
|
serde_json::Value::Array(items) => {
|
||||||
|
items
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.find(|item| item["role"].as_str() == Some("user"))
|
||||||
|
.and_then(|item| match &item["content"] {
|
||||||
|
serde_json::Value::String(s) => Some(s.clone()),
|
||||||
|
serde_json::Value::Array(parts) => Some(
|
||||||
|
parts
|
||||||
|
.iter()
|
||||||
|
.filter(|p| {
|
||||||
|
let t = p["type"].as_str().unwrap_or("");
|
||||||
|
t == "input_text" || t == "text"
|
||||||
|
})
|
||||||
|
.filter_map(|p| p["text"].as_str())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" "),
|
||||||
|
),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
_ => String::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
match instructions {
|
||||||
|
Some(inst) if !inst.is_empty() => format!("{inst}\n\n{user_text}"),
|
||||||
|
_ => user_text,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract conversation/session ID from Responses API `conversation` field.
|
||||||
|
fn extract_conversation_id(conv: &Option<serde_json::Value>) -> Option<String> {
|
||||||
|
match conv {
|
||||||
|
Some(serde_json::Value::String(s)) => Some(s.clone()),
|
||||||
|
Some(obj) => obj["id"].as_str().map(|s| s.to_string()),
|
||||||
|
None => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a full Response object matching the official OpenAI schema.
|
||||||
|
fn build_response_object(
|
||||||
|
id: &str,
|
||||||
|
model: &str,
|
||||||
|
status: &'static str,
|
||||||
|
created_at: u64,
|
||||||
|
completed_at: Option<u64>,
|
||||||
|
output: Vec<ResponseOutput>,
|
||||||
|
usage: Option<Usage>,
|
||||||
|
instructions: Option<&str>,
|
||||||
|
store: bool,
|
||||||
|
temperature: f64,
|
||||||
|
top_p: f64,
|
||||||
|
max_output_tokens: Option<u64>,
|
||||||
|
previous_response_id: Option<&str>,
|
||||||
|
user: Option<&str>,
|
||||||
|
metadata: &serde_json::Value,
|
||||||
|
thinking_signature: Option<String>,
|
||||||
|
thinking: Option<String>,
|
||||||
|
thinking_duration: Option<String>,
|
||||||
|
) -> ResponsesResponse {
|
||||||
|
ResponsesResponse {
|
||||||
|
id: id.to_string(),
|
||||||
|
object: "response",
|
||||||
|
created_at,
|
||||||
|
status,
|
||||||
|
completed_at,
|
||||||
|
error: None,
|
||||||
|
incomplete_details: None,
|
||||||
|
instructions: instructions.map(|s| s.to_string()),
|
||||||
|
max_output_tokens,
|
||||||
|
model: model.to_string(),
|
||||||
|
output,
|
||||||
|
parallel_tool_calls: true,
|
||||||
|
previous_response_id: previous_response_id.map(|s| s.to_string()),
|
||||||
|
reasoning: Reasoning::default(),
|
||||||
|
store,
|
||||||
|
temperature,
|
||||||
|
text: TextFormat::default(),
|
||||||
|
tool_choice: "auto",
|
||||||
|
tools: vec![],
|
||||||
|
top_p,
|
||||||
|
truncation: "disabled",
|
||||||
|
usage,
|
||||||
|
user: user.map(|s| s.to_string()),
|
||||||
|
metadata: metadata.clone(),
|
||||||
|
thinking_signature,
|
||||||
|
thinking,
|
||||||
|
thinking_duration,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serialize a ResponsesResponse to serde_json::Value for SSE embedding.
|
||||||
|
fn response_to_json(resp: &ResponsesResponse) -> serde_json::Value {
|
||||||
|
serde_json::to_value(resp).unwrap_or(serde_json::json!({}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Handler ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub(crate) async fn handle_responses(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
Json(body): Json<ResponsesRequest>,
|
||||||
|
) -> axum::response::Response {
|
||||||
|
info!(
|
||||||
|
"POST /v1/responses model={} stream={}",
|
||||||
|
body.model.as_deref().unwrap_or(DEFAULT_MODEL),
|
||||||
|
body.stream
|
||||||
|
);
|
||||||
|
|
||||||
|
let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
|
||||||
|
let model = match lookup_model(model_name) {
|
||||||
|
Some(m) => m,
|
||||||
|
None => {
|
||||||
|
let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
|
||||||
|
return err_response(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
format!("Unknown model: {model_name}. Available: {names:?}"),
|
||||||
|
"invalid_request_error",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let token = state.backend.oauth_token().await;
|
||||||
|
if token.is_empty() {
|
||||||
|
return err_response(
|
||||||
|
StatusCode::UNAUTHORIZED,
|
||||||
|
"No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
|
||||||
|
"authentication_error",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let user_text = extract_responses_input(&body.input, body.instructions.as_deref());
|
||||||
|
if user_text.is_empty() {
|
||||||
|
return err_response(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"No user input found".to_string(),
|
||||||
|
"invalid_request_error",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let response_id = format!(
|
||||||
|
"resp_{}",
|
||||||
|
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||||
|
);
|
||||||
|
|
||||||
|
// Session/conversation management
|
||||||
|
let session_id_str = extract_conversation_id(&body.conversation);
|
||||||
|
let cascade_id = if let Some(ref sid) = session_id_str {
|
||||||
|
match state
|
||||||
|
.sessions
|
||||||
|
.get_or_create(Some(sid), || state.backend.create_cascade())
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(sr) => sr.cascade_id,
|
||||||
|
Err(e) => {
|
||||||
|
return err_response(
|
||||||
|
StatusCode::BAD_GATEWAY,
|
||||||
|
format!("StartCascade failed: {e}"),
|
||||||
|
"server_error",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
match state.backend.create_cascade().await {
|
||||||
|
Ok(cid) => cid,
|
||||||
|
Err(e) => {
|
||||||
|
return err_response(
|
||||||
|
StatusCode::BAD_GATEWAY,
|
||||||
|
format!("StartCascade failed: {e}"),
|
||||||
|
"server_error",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Send message
|
||||||
|
match state
|
||||||
|
.backend
|
||||||
|
.send_message(&cascade_id, &user_text, model.model_enum)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok((status, _)) if status == 200 => {
|
||||||
|
let bg = Arc::clone(&state.backend);
|
||||||
|
let cid = cascade_id.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let _ = bg.update_annotations(&cid).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok((status, _)) => {
|
||||||
|
return err_response(
|
||||||
|
StatusCode::BAD_GATEWAY,
|
||||||
|
format!("Antigravity returned {status}"),
|
||||||
|
"server_error",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
return err_response(
|
||||||
|
StatusCode::BAD_GATEWAY,
|
||||||
|
format!("Send message failed: {e}"),
|
||||||
|
"server_error",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture request params for response building
|
||||||
|
let req_params = RequestParams {
|
||||||
|
user_text: user_text.clone(),
|
||||||
|
instructions: body.instructions.clone(),
|
||||||
|
store: body.store,
|
||||||
|
temperature: body.temperature.unwrap_or(1.0),
|
||||||
|
top_p: body.top_p.unwrap_or(1.0),
|
||||||
|
max_output_tokens: body.max_output_tokens,
|
||||||
|
previous_response_id: body.previous_response_id.clone(),
|
||||||
|
user: body.user.clone(),
|
||||||
|
metadata: body.metadata.clone().unwrap_or(serde_json::json!({})),
|
||||||
|
};
|
||||||
|
|
||||||
|
if body.stream {
|
||||||
|
handle_responses_stream(
|
||||||
|
state, response_id, model_name.to_string(), cascade_id,
|
||||||
|
body.timeout, req_params,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
} else {
|
||||||
|
handle_responses_sync(
|
||||||
|
state, response_id, model_name.to_string(), cascade_id,
|
||||||
|
body.timeout, req_params,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Captured request parameters needed to echo back in the response.
///
/// Built once by the handler after the message has been sent, then passed to
/// the sync/streaming responders.
struct RequestParams {
    /// Prompt text that was sent to the backend; also used as the input side
    /// of token-usage estimation.
    user_text: String,
    /// `instructions` from the request, if any.
    instructions: Option<String>,
    /// `store` flag from the request.
    store: bool,
    /// Request `temperature`; the handler substitutes 1.0 when it is absent.
    temperature: f64,
    /// Request `top_p`; the handler substitutes 1.0 when it is absent.
    top_p: f64,
    /// `max_output_tokens` from the request, if any.
    max_output_tokens: Option<u64>,
    /// `previous_response_id` from the request, if any.
    previous_response_id: Option<String>,
    /// `user` from the request, if any.
    user: Option<String>,
    /// Request `metadata`; the handler substitutes an empty JSON object when it is absent.
    metadata: serde_json::Value,
}
|
||||||
|
|
||||||
|
/// Build Usage from the best available source:
|
||||||
|
/// 1. MITM intercepted data (real API tokens, including cache stats)
|
||||||
|
/// 2. LS trajectory data (real tokens, no cache info)
|
||||||
|
/// 3. Estimation from text lengths (fallback)
|
||||||
|
async fn usage_from_poll(
|
||||||
|
mitm_store: &crate::mitm::store::MitmStore,
|
||||||
|
cascade_id: &str,
|
||||||
|
model_usage: &Option<super::polling::ModelUsage>,
|
||||||
|
input_text: &str,
|
||||||
|
output_text: &str,
|
||||||
|
) -> Usage {
|
||||||
|
// Priority 1: MITM intercepted data (most accurate — includes cache tokens)
|
||||||
|
if let Some(mitm_usage) = mitm_store.take_usage(cascade_id).await {
|
||||||
|
tracing::debug!(
|
||||||
|
input = mitm_usage.input_tokens,
|
||||||
|
output = mitm_usage.output_tokens,
|
||||||
|
cache_read = mitm_usage.cache_read_input_tokens,
|
||||||
|
cache_create = mitm_usage.cache_creation_input_tokens,
|
||||||
|
thinking = mitm_usage.thinking_output_tokens,
|
||||||
|
"Using MITM intercepted usage"
|
||||||
|
);
|
||||||
|
return Usage {
|
||||||
|
input_tokens: mitm_usage.input_tokens,
|
||||||
|
input_tokens_details: InputTokensDetails {
|
||||||
|
cached_tokens: mitm_usage.cache_read_input_tokens,
|
||||||
|
},
|
||||||
|
output_tokens: mitm_usage.output_tokens,
|
||||||
|
output_tokens_details: OutputTokensDetails {
|
||||||
|
reasoning_tokens: mitm_usage.thinking_output_tokens,
|
||||||
|
},
|
||||||
|
total_tokens: mitm_usage.input_tokens + mitm_usage.output_tokens,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 2: LS trajectory data (from CHECKPOINT/metadata steps)
|
||||||
|
if let Some(u) = model_usage {
|
||||||
|
return Usage {
|
||||||
|
input_tokens: u.input_tokens,
|
||||||
|
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
|
||||||
|
output_tokens: u.output_tokens,
|
||||||
|
output_tokens_details: OutputTokensDetails { reasoning_tokens: 0 },
|
||||||
|
total_tokens: u.input_tokens + u.output_tokens,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 3: Estimate from text lengths
|
||||||
|
Usage::estimate(input_text, output_text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Sync response ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async fn handle_responses_sync(
|
||||||
|
state: Arc<AppState>,
|
||||||
|
response_id: String,
|
||||||
|
model_name: String,
|
||||||
|
cascade_id: String,
|
||||||
|
timeout: u64,
|
||||||
|
params: RequestParams,
|
||||||
|
) -> axum::response::Response {
|
||||||
|
let created_at = now_unix();
|
||||||
|
let poll_result = poll_for_response(&state, &cascade_id, timeout).await;
|
||||||
|
let completed_at = now_unix();
|
||||||
|
let msg_id = format!(
|
||||||
|
"msg_{}",
|
||||||
|
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||||
|
);
|
||||||
|
|
||||||
|
let usage = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, ¶ms.user_text, &poll_result.text).await;
|
||||||
|
|
||||||
|
let resp = build_response_object(
|
||||||
|
&response_id,
|
||||||
|
&model_name,
|
||||||
|
"completed",
|
||||||
|
created_at,
|
||||||
|
Some(completed_at),
|
||||||
|
vec![ResponseOutput {
|
||||||
|
output_type: "message",
|
||||||
|
id: msg_id,
|
||||||
|
status: "completed",
|
||||||
|
role: "assistant",
|
||||||
|
content: vec![OutputContent {
|
||||||
|
content_type: "output_text",
|
||||||
|
text: poll_result.text,
|
||||||
|
annotations: vec![],
|
||||||
|
}],
|
||||||
|
}],
|
||||||
|
Some(usage),
|
||||||
|
params.instructions.as_deref(),
|
||||||
|
params.store,
|
||||||
|
params.temperature,
|
||||||
|
params.top_p,
|
||||||
|
params.max_output_tokens,
|
||||||
|
params.previous_response_id.as_deref(),
|
||||||
|
params.user.as_deref(),
|
||||||
|
¶ms.metadata,
|
||||||
|
poll_result.thinking_signature,
|
||||||
|
poll_result.thinking,
|
||||||
|
poll_result.thinking_duration,
|
||||||
|
);
|
||||||
|
|
||||||
|
Json(resp).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Streaming response ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async fn handle_responses_stream(
|
||||||
|
state: Arc<AppState>,
|
||||||
|
response_id: String,
|
||||||
|
model_name: String,
|
||||||
|
cascade_id: String,
|
||||||
|
timeout: u64,
|
||||||
|
params: RequestParams,
|
||||||
|
) -> axum::response::Response {
|
||||||
|
let stream = async_stream::stream! {
|
||||||
|
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||||
|
let created_at = now_unix();
|
||||||
|
let seq = AtomicU32::new(0);
|
||||||
|
let next_seq = || seq.fetch_add(1, Ordering::Relaxed);
|
||||||
|
const CONTENT_IDX: u32 = 0;
|
||||||
|
const OUTPUT_IDX: u32 = 0;
|
||||||
|
|
||||||
|
// Build the in-progress response shell (no output yet)
|
||||||
|
let in_progress_resp = build_response_object(
|
||||||
|
&response_id, &model_name, "in_progress", created_at, None,
|
||||||
|
vec![], None,
|
||||||
|
params.instructions.as_deref(), params.store,
|
||||||
|
params.temperature, params.top_p,
|
||||||
|
params.max_output_tokens, params.previous_response_id.as_deref(),
|
||||||
|
params.user.as_deref(), ¶ms.metadata,
|
||||||
|
None, None, None,
|
||||||
|
);
|
||||||
|
let resp_json = response_to_json(&in_progress_resp);
|
||||||
|
|
||||||
|
// 1. response.created
|
||||||
|
yield Ok::<_, std::convert::Infallible>(responses_sse_event(
|
||||||
|
"response.created",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.created",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"response": resp_json,
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
|
||||||
|
// 2. response.in_progress
|
||||||
|
yield Ok(responses_sse_event(
|
||||||
|
"response.in_progress",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.in_progress",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"response": resp_json,
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
|
||||||
|
// 3. response.output_item.added
|
||||||
|
yield Ok(responses_sse_event(
|
||||||
|
"response.output_item.added",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.output_item.added",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"output_index": OUTPUT_IDX,
|
||||||
|
"item": {
|
||||||
|
"type": "message",
|
||||||
|
"id": &msg_id,
|
||||||
|
"status": "in_progress",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [],
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
|
||||||
|
// 4. response.content_part.added
|
||||||
|
yield Ok(responses_sse_event(
|
||||||
|
"response.content_part.added",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.content_part.added",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"output_index": OUTPUT_IDX,
|
||||||
|
"content_index": CONTENT_IDX,
|
||||||
|
"part": {
|
||||||
|
"type": "output_text",
|
||||||
|
"text": "",
|
||||||
|
"annotations": [],
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
|
||||||
|
// 5. Poll and emit text deltas
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let mut last_text = String::new();
|
||||||
|
|
||||||
|
while start.elapsed().as_secs() < timeout {
|
||||||
|
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
|
||||||
|
if status == 200 {
|
||||||
|
if let Some(steps) = data["steps"].as_array() {
|
||||||
|
let text = extract_response_text(steps);
|
||||||
|
|
||||||
|
if !text.is_empty() && text != last_text {
|
||||||
|
let new_content = if text.len() > last_text.len()
|
||||||
|
&& text.starts_with(&*last_text)
|
||||||
|
{
|
||||||
|
&text[last_text.len()..]
|
||||||
|
} else {
|
||||||
|
&text
|
||||||
|
};
|
||||||
|
|
||||||
|
if !new_content.is_empty() {
|
||||||
|
yield Ok(responses_sse_event(
|
||||||
|
"response.output_text.delta",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.output_text.delta",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"item_id": &msg_id,
|
||||||
|
"output_index": OUTPUT_IDX,
|
||||||
|
"content_index": CONTENT_IDX,
|
||||||
|
"delta": new_content,
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
last_text = text.to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if response is done AND we have text
|
||||||
|
if is_response_done(steps) && !last_text.is_empty() {
|
||||||
|
debug!("Response done, text length={}", last_text.len());
|
||||||
|
let mu = extract_model_usage(steps);
|
||||||
|
let usage = usage_from_poll(&state.mitm_store, &cascade_id, &mu, ¶ms.user_text, &last_text).await;
|
||||||
|
let ts = extract_thinking_signature(steps);
|
||||||
|
let tc = extract_thinking_content(steps);
|
||||||
|
let td = extract_thinking_duration(steps);
|
||||||
|
for evt in completion_events(
|
||||||
|
&response_id, &model_name, &msg_id,
|
||||||
|
OUTPUT_IDX, CONTENT_IDX, &last_text, usage,
|
||||||
|
created_at, &seq, ¶ms, ts, tc, td,
|
||||||
|
) {
|
||||||
|
yield Ok(evt);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDLE fallback: check trajectory status periodically
|
||||||
|
let step_count = steps.len();
|
||||||
|
if step_count > 4 && step_count % 5 == 0 {
|
||||||
|
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
|
||||||
|
if ts == 200 {
|
||||||
|
let run_status = td["status"].as_str().unwrap_or("");
|
||||||
|
if run_status.contains("IDLE") && !last_text.is_empty() {
|
||||||
|
debug!("Trajectory IDLE, text length={}", last_text.len());
|
||||||
|
let mu = extract_model_usage(steps);
|
||||||
|
let usage = usage_from_poll(&state.mitm_store, &cascade_id, &mu, ¶ms.user_text, &last_text).await;
|
||||||
|
let ts = extract_thinking_signature(steps);
|
||||||
|
let tc = extract_thinking_content(steps);
|
||||||
|
let td = extract_thinking_duration(steps);
|
||||||
|
for evt in completion_events(
|
||||||
|
&response_id, &model_name, &msg_id,
|
||||||
|
OUTPUT_IDX, CONTENT_IDX, &last_text, usage,
|
||||||
|
created_at, &seq, ¶ms, ts, tc, td,
|
||||||
|
) {
|
||||||
|
yield Ok(evt);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let poll_ms: u64 = rand::thread_rng().gen_range(800..1200);
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timeout — emit incomplete response
|
||||||
|
let timeout_resp = build_response_object(
|
||||||
|
&response_id, &model_name, "incomplete", created_at, None,
|
||||||
|
vec![], Some(Usage::estimate(¶ms.user_text, "")),
|
||||||
|
params.instructions.as_deref(), params.store,
|
||||||
|
params.temperature, params.top_p,
|
||||||
|
params.max_output_tokens, params.previous_response_id.as_deref(),
|
||||||
|
params.user.as_deref(), ¶ms.metadata,
|
||||||
|
None, None, None,
|
||||||
|
);
|
||||||
|
yield Ok(responses_sse_event(
|
||||||
|
"response.completed",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.completed",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"response": response_to_json(&timeout_resp),
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
Sse::new(stream)
|
||||||
|
.keep_alive(
|
||||||
|
axum::response::sse::KeepAlive::new()
|
||||||
|
.interval(std::time::Duration::from_secs(15))
|
||||||
|
.text(""),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── SSE completion events ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Build the completion SSE events sequence matching the official protocol:
|
||||||
|
/// 1. response.output_text.done
|
||||||
|
/// 2. response.content_part.done
|
||||||
|
/// 3. response.output_item.done
|
||||||
|
/// 4. response.completed
|
||||||
|
fn completion_events(
|
||||||
|
resp_id: &str,
|
||||||
|
model: &str,
|
||||||
|
msg_id: &str,
|
||||||
|
out_idx: u32,
|
||||||
|
content_idx: u32,
|
||||||
|
text: &str,
|
||||||
|
usage: Usage,
|
||||||
|
created_at: u64,
|
||||||
|
seq: &AtomicU32,
|
||||||
|
params: &RequestParams,
|
||||||
|
thinking_signature: Option<String>,
|
||||||
|
thinking: Option<String>,
|
||||||
|
thinking_duration: Option<String>,
|
||||||
|
) -> Vec<Event> {
|
||||||
|
let next_seq = || seq.fetch_add(1, Ordering::Relaxed);
|
||||||
|
let completed_at = now_unix();
|
||||||
|
|
||||||
|
let output_item = serde_json::json!({
|
||||||
|
"type": "message",
|
||||||
|
"id": msg_id,
|
||||||
|
"status": "completed",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{
|
||||||
|
"type": "output_text",
|
||||||
|
"text": text,
|
||||||
|
"annotations": [],
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
|
||||||
|
let completed_resp = build_response_object(
|
||||||
|
resp_id, model, "completed", created_at, Some(completed_at),
|
||||||
|
vec![ResponseOutput {
|
||||||
|
output_type: "message",
|
||||||
|
id: msg_id.to_string(),
|
||||||
|
status: "completed",
|
||||||
|
role: "assistant",
|
||||||
|
content: vec![OutputContent {
|
||||||
|
content_type: "output_text",
|
||||||
|
text: text.to_string(),
|
||||||
|
annotations: vec![],
|
||||||
|
}],
|
||||||
|
}],
|
||||||
|
Some(usage),
|
||||||
|
params.instructions.as_deref(),
|
||||||
|
params.store,
|
||||||
|
params.temperature,
|
||||||
|
params.top_p,
|
||||||
|
params.max_output_tokens,
|
||||||
|
params.previous_response_id.as_deref(),
|
||||||
|
params.user.as_deref(),
|
||||||
|
¶ms.metadata,
|
||||||
|
thinking_signature,
|
||||||
|
thinking,
|
||||||
|
thinking_duration,
|
||||||
|
);
|
||||||
|
|
||||||
|
vec![
|
||||||
|
// 1. response.output_text.done
|
||||||
|
responses_sse_event(
|
||||||
|
"response.output_text.done",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.output_text.done",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"item_id": msg_id,
|
||||||
|
"output_index": out_idx,
|
||||||
|
"content_index": content_idx,
|
||||||
|
"text": text,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
// 2. response.content_part.done
|
||||||
|
responses_sse_event(
|
||||||
|
"response.content_part.done",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.content_part.done",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"output_index": out_idx,
|
||||||
|
"content_index": content_idx,
|
||||||
|
"part": {
|
||||||
|
"type": "output_text",
|
||||||
|
"text": text,
|
||||||
|
"annotations": [],
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
// 3. response.output_item.done
|
||||||
|
responses_sse_event(
|
||||||
|
"response.output_item.done",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.output_item.done",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"output_index": out_idx,
|
||||||
|
"item": output_item,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
// 4. response.completed
|
||||||
|
responses_sse_event(
|
||||||
|
"response.completed",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.completed",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"response": response_to_json(&completed_resp),
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
}
|
||||||
241
src/api/types.rs
Normal file
241
src/api/types.rs
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
//! Request/response types for the OpenAI-compatible API.
|
||||||
|
//!
|
||||||
|
//! All response shapes strictly match the official OpenAI Responses API spec:
|
||||||
|
//! https://platform.openai.com/docs/api-reference/responses
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
// ─── Request types ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Request body for the Responses endpoint (`POST /v1/responses`).
#[derive(Deserialize)]
pub(crate) struct ResponsesRequest {
    /// Requested model name; optional.
    pub model: Option<String>,
    /// Prompt input as raw JSON (required field).
    pub input: serde_json::Value,
    /// Optional instructions text.
    #[serde(default)]
    pub instructions: Option<String>,
    /// Whether the client wants a streamed response; defaults to `false`.
    #[serde(default)]
    pub stream: bool,
    /// Proxy extension: how long to wait for the backend; defaults to 120.
    #[serde(default = "default_timeout")]
    pub timeout: u64,
    /// Optional conversation reference as raw JSON.
    pub conversation: Option<serde_json::Value>,
    /// `store` flag; defaults to `true`.
    #[serde(default = "default_true")]
    pub store: bool,
    /// Optional sampling temperature.
    #[serde(default)]
    pub temperature: Option<f64>,
    /// Optional nucleus-sampling parameter.
    #[serde(default)]
    pub top_p: Option<f64>,
    /// Optional output token limit.
    #[serde(default)]
    pub max_output_tokens: Option<u64>,
    /// Optional id of a previous response.
    #[serde(default)]
    pub previous_response_id: Option<String>,
    /// Optional client metadata as raw JSON.
    #[serde(default)]
    pub metadata: Option<serde_json::Value>,
    /// Optional end-user identifier.
    #[serde(default)]
    pub user: Option<String>,
}
|
||||||
|
|
||||||
|
/// Chat Completions request (OpenAI-compatible).
#[derive(Deserialize)]
pub(crate) struct CompletionRequest {
    /// Requested model name; optional.
    pub model: Option<String>,
    /// Conversation messages (required field).
    pub messages: Vec<CompletionMessage>,
    /// Whether the client wants a streamed response; defaults to `false`.
    #[serde(default)]
    pub stream: bool,
    /// Proxy extension: how long to wait for the backend; defaults to 120.
    #[serde(default = "default_timeout")]
    pub timeout: u64,
}
|
||||||
|
|
||||||
|
/// One message of a Chat Completions request.
#[derive(Deserialize)]
pub(crate) struct CompletionMessage {
    /// Message role as sent by the client.
    pub role: String,
    /// Message content, kept as raw JSON.
    pub content: serde_json::Value,
}
|
||||||
|
|
||||||
|
/// Serde default for the request `timeout` fields.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT: u64 = 120;
    DEFAULT_TIMEOUT
}
|
||||||
|
|
||||||
|
/// Serde default for boolean fields that are enabled unless the client says otherwise.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
|
||||||
|
|
||||||
|
// ─── Response types (official OpenAI Responses API shape) ────────────────────
|
||||||
|
|
||||||
|
/// Top-level Response object — matches OpenAI exactly.
///
/// Fields are serialized in declaration order. `Option` fields serialize as
/// JSON `null` when `None`, except the three proxy-extension fields at the
/// end, which are omitted entirely.
#[derive(Serialize, Clone)]
pub(crate) struct ResponsesResponse {
    /// Response identifier.
    pub id: String,
    /// Object type tag.
    pub object: &'static str,
    /// Creation timestamp.
    pub created_at: u64,
    /// Lifecycle status string.
    pub status: &'static str,
    /// Completion timestamp; serialized as `null` while unset.
    #[serde(serialize_with = "serialize_option_u64")]
    pub completed_at: Option<u64>,
    /// Error details, if any.
    pub error: Option<serde_json::Value>,
    /// Details explaining an incomplete response, if any.
    pub incomplete_details: Option<serde_json::Value>,
    /// Instructions echoed from the request.
    pub instructions: Option<String>,
    /// Output token limit echoed from the request; serialized as `null` when unset.
    #[serde(serialize_with = "serialize_option_u64")]
    pub max_output_tokens: Option<u64>,
    /// Model name.
    pub model: String,
    /// Generated output items.
    pub output: Vec<ResponseOutput>,
    pub parallel_tool_calls: bool,
    pub previous_response_id: Option<String>,
    pub reasoning: Reasoning,
    pub store: bool,
    pub temperature: f64,
    pub text: TextFormat,
    pub tool_choice: &'static str,
    pub tools: Vec<serde_json::Value>,
    pub top_p: f64,
    pub truncation: &'static str,
    /// Token accounting, when available.
    pub usage: Option<Usage>,
    pub user: Option<String>,
    pub metadata: serde_json::Value,
    /// Proxy extension: opaque thinking verification signature.
    /// Present for all models. Required for multi-turn chaining with thinking models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_signature: Option<String>,
    /// Proxy extension: the model's internal reasoning/thinking content.
    /// Available for all models (Opus, Gemini Flash, Gemini Pro).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<String>,
    /// Proxy extension: time spent thinking (e.g. "0.041999832s").
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_duration: Option<String>,
}
|
||||||
|
|
||||||
|
/// One item of a response's `output` array.
#[derive(Serialize, Clone)]
pub(crate) struct ResponseOutput {
    /// Item type tag; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub output_type: &'static str,
    /// Item identifier.
    pub id: String,
    /// Item status string.
    pub status: &'static str,
    /// Author role of the item.
    pub role: &'static str,
    /// Content parts of the item.
    pub content: Vec<OutputContent>,
}
|
||||||
|
|
||||||
|
/// One content part of an output item.
#[derive(Serialize, Clone)]
pub(crate) struct OutputContent {
    /// Part type tag; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub content_type: &'static str,
    /// Text of the part.
    pub text: String,
    /// Annotations attached to the text, as raw JSON values.
    pub annotations: Vec<serde_json::Value>,
}
|
||||||
|
|
||||||
|
/// Token accounting attached to a response.
#[derive(Serialize, Clone)]
pub(crate) struct Usage {
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Breakdown of the input tokens.
    pub input_tokens_details: InputTokensDetails,
    /// Number of output tokens.
    pub output_tokens: u64,
    /// Breakdown of the output tokens.
    pub output_tokens_details: OutputTokensDetails,
    /// Total token count reported for the response.
    pub total_tokens: u64,
}
|
||||||
|
|
||||||
|
/// Breakdown of input tokens within [`Usage`].
#[derive(Serialize, Clone)]
pub(crate) struct InputTokensDetails {
    /// Number of cached input tokens.
    pub cached_tokens: u64,
}
|
||||||
|
|
||||||
|
/// Breakdown of output tokens within [`Usage`].
#[derive(Serialize, Clone)]
pub(crate) struct OutputTokensDetails {
    /// Number of reasoning tokens.
    pub reasoning_tokens: u64,
}
|
||||||
|
|
||||||
|
/// `reasoning` section of a response; both members default to `None`
/// and therefore serialize as JSON `null`.
#[derive(Serialize, Clone)]
pub(crate) struct Reasoning {
    /// Reasoning effort setting, if any.
    pub effort: Option<String>,
    /// Reasoning summary setting, if any.
    pub summary: Option<String>,
}
|
||||||
|
|
||||||
|
/// `text` section of a response; wraps the output format descriptor.
#[derive(Serialize, Clone)]
pub(crate) struct TextFormat {
    /// Output format descriptor.
    pub format: TextFormatInner,
}
|
||||||
|
|
||||||
|
/// Output format descriptor nested inside [`TextFormat`].
#[derive(Serialize, Clone)]
pub(crate) struct TextFormatInner {
    /// Format type tag; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub format_type: &'static str,
}
|
||||||
|
|
||||||
|
impl Usage {
|
||||||
|
/// Estimate token counts from actual text.
|
||||||
|
/// Uses ~4 chars/token heuristic (standard GPT tokenizer average).
|
||||||
|
pub fn estimate(input_text: &str, output_text: &str) -> Self {
|
||||||
|
let input_tokens = (input_text.len() as u64 + 3) / 4;
|
||||||
|
let output_tokens = (output_text.len() as u64 + 3) / 4;
|
||||||
|
Self {
|
||||||
|
input_tokens,
|
||||||
|
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
|
||||||
|
output_tokens,
|
||||||
|
output_tokens_details: OutputTokensDetails {
|
||||||
|
reasoning_tokens: 0,
|
||||||
|
},
|
||||||
|
total_tokens: input_tokens + output_tokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Usage {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
input_tokens: 0,
|
||||||
|
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
|
||||||
|
output_tokens: 0,
|
||||||
|
output_tokens_details: OutputTokensDetails {
|
||||||
|
reasoning_tokens: 0,
|
||||||
|
},
|
||||||
|
total_tokens: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Reasoning {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
effort: None,
|
||||||
|
summary: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for TextFormat {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
format: TextFormatInner {
|
||||||
|
format_type: "text",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Helpers ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Serialize Option<u64> as either the number or JSON null (not omitted).
|
||||||
|
fn serialize_option_u64<S>(val: &Option<u64>, s: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
match val {
|
||||||
|
Some(v) => s.serialize_u64(*v),
|
||||||
|
None => s.serialize_none(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Shared types ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub(crate) struct TokenRequest {
|
||||||
|
pub token: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub(crate) struct ErrorResponse {
|
||||||
|
pub error: ErrorDetail,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub(crate) struct ErrorDetail {
|
||||||
|
pub message: String,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub error_type: String,
|
||||||
|
}
|
||||||
36
src/api/util.rs
Normal file
36
src/api/util.rs
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
//! Shared utilities for API handlers.
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
http::StatusCode,
|
||||||
|
response::{sse::Event, IntoResponse, Json},
|
||||||
|
};
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
use super::types::{ErrorDetail, ErrorResponse};
|
||||||
|
|
||||||
|
pub(crate) fn err_response(
|
||||||
|
status: StatusCode,
|
||||||
|
message: String,
|
||||||
|
error_type: &str,
|
||||||
|
) -> axum::response::Response {
|
||||||
|
let body = ErrorResponse {
|
||||||
|
error: ErrorDetail {
|
||||||
|
message,
|
||||||
|
error_type: error_type.to_string(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
(status, Json(body)).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the current time as whole seconds since the Unix epoch.
///
/// Yields 0 if the system clock reports a time before the epoch.
pub(crate) fn now_unix() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
|
||||||
|
|
||||||
|
pub(crate) fn responses_sse_event(event_type: &str, data: serde_json::Value) -> Event {
|
||||||
|
Event::default()
|
||||||
|
.event(event_type)
|
||||||
|
.data(serde_json::to_string(&data).unwrap())
|
||||||
|
}
|
||||||
462
src/backend.rs
Normal file
462
src/backend.rs
Normal file
@@ -0,0 +1,462 @@
|
|||||||
|
//! Backend: discovery of the local Antigravity language server and HTTP client.
|
||||||
|
//!
|
||||||
|
//! Uses wreq (BoringSSL) to impersonate Chrome's TLS + HTTP/2 fingerprint,
|
||||||
|
//! making our requests indistinguishable from the real Electron webview.
|
||||||
|
|
||||||
|
use crate::constants::*;
|
||||||
|
use flate2::read::{DeflateDecoder, GzDecoder};
|
||||||
|
use std::fs;
|
||||||
|
use std::io::Read;
|
||||||
|
use std::process::Command;
|
||||||
|
use std::sync::LazyLock;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tracing::{debug, info, warn};
|
||||||
|
use wreq::header::{HeaderMap, HeaderName, HeaderValue};
|
||||||
|
|
||||||
|
/// Connection details for the local language server.
|
||||||
|
pub struct Backend {
|
||||||
|
inner: RwLock<BackendInner>,
|
||||||
|
client: wreq::Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connection details produced by `discover`.
struct BackendInner {
    /// PID of the language server process, as text.
    pid: String,
    /// CSRF token sent in the `x-codeium-csrf-token` header.
    csrf: String,
    /// HTTPS port of the language server, as text.
    https_port: String,
    /// OAuth token; empty when none is configured.
    oauth_token: String,
}
|
||||||
|
|
||||||
|
/// Static headers that never change — built once, in Chrome's exact emission order.
|
||||||
|
///
|
||||||
|
/// Order matters: wreq preserves insertion order in HTTP/2 HEADERS frames.
|
||||||
|
/// This matches the order captured from Chrome DevTools on the real webview.
|
||||||
|
static STATIC_HEADERS: LazyLock<HeaderMap> = LazyLock::new(|| {
|
||||||
|
let mut h = HeaderMap::with_capacity(14);
|
||||||
|
// Chrome order: Origin → UA → Accept → Accept-Encoding → Accept-Language
|
||||||
|
// → sec-ch-ua → sec-ch-ua-mobile → sec-ch-ua-platform
|
||||||
|
// → Sec-Fetch-Dest → Sec-Fetch-Mode → Sec-Fetch-Site
|
||||||
|
// → Referer → Priority → Connect-Protocol-Version
|
||||||
|
h.insert("Origin", hv("vscode-file://vscode-app"));
|
||||||
|
h.insert("User-Agent", hv(&USER_AGENT));
|
||||||
|
h.insert("Accept", hv("*/*"));
|
||||||
|
h.insert("Accept-Encoding", hv("gzip, deflate, br, zstd"));
|
||||||
|
h.insert("Accept-Language", hv("en-US"));
|
||||||
|
h.insert(
|
||||||
|
HeaderName::from_static("sec-ch-ua"),
|
||||||
|
hv(&format!(
|
||||||
|
"\"Not_A Brand\";v=\"99\", \"Chromium\";v=\"{}\"",
|
||||||
|
*CHROME_MAJOR,
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
h.insert(
|
||||||
|
HeaderName::from_static("sec-ch-ua-mobile"),
|
||||||
|
hv("?0"),
|
||||||
|
);
|
||||||
|
h.insert(
|
||||||
|
HeaderName::from_static("sec-ch-ua-platform"),
|
||||||
|
hv("\"Linux\""),
|
||||||
|
);
|
||||||
|
h.insert("Sec-Fetch-Dest", hv("empty"));
|
||||||
|
h.insert("Sec-Fetch-Mode", hv("cors"));
|
||||||
|
h.insert("Sec-Fetch-Site", hv("cross-site"));
|
||||||
|
h.insert("Priority", hv("u=1, i"));
|
||||||
|
h.insert("Connect-Protocol-Version", hv("1"));
|
||||||
|
h
|
||||||
|
});
|
||||||
|
|
||||||
|
impl Backend {
|
||||||
|
/// Discover the running language server and build a BoringSSL-backed connection.
|
||||||
|
pub fn new() -> Result<Self, String> {
|
||||||
|
let inner = discover()?;
|
||||||
|
|
||||||
|
// wreq with Chrome impersonation: BoringSSL + Chrome JA3/JA4 + H2 fingerprint
|
||||||
|
let client = wreq::Client::builder()
|
||||||
|
.emulation(wreq_util::Emulation::Chrome142)
|
||||||
|
.cert_verification(false) // LS uses self-signed cert
|
||||||
|
.verify_hostname(false)
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("wreq client build failed: {e}"))?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
inner: RwLock::new(inner),
|
||||||
|
client,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Re-discover language server connection details.
|
||||||
|
/// Runs blocking I/O on a spawn_blocking thread to avoid starving tokio.
|
||||||
|
pub async fn refresh(&self) -> Result<(), String> {
|
||||||
|
let new_inner = tokio::task::spawn_blocking(discover)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("spawn_blocking failed: {e}"))??;
|
||||||
|
let mut guard = self.inner.write().await;
|
||||||
|
*guard = new_inner;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get current connection info (for startup banner).
|
||||||
|
pub async fn info(&self) -> (String, String, String, String) {
|
||||||
|
let guard = self.inner.read().await;
|
||||||
|
let token_preview = if guard.oauth_token.is_empty() {
|
||||||
|
"NOT SET".to_string()
|
||||||
|
} else {
|
||||||
|
safe_truncate(&guard.oauth_token, 20)
|
||||||
|
};
|
||||||
|
let csrf_preview = safe_truncate(&guard.csrf, 8);
|
||||||
|
(
|
||||||
|
guard.pid.clone(),
|
||||||
|
guard.https_port.clone(),
|
||||||
|
csrf_preview,
|
||||||
|
token_preview,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get current OAuth token.
|
||||||
|
///
|
||||||
|
/// Priority: token file > env var > cached value.
|
||||||
|
/// Uses async I/O for file reads. Single write-lock acquisition
|
||||||
|
/// eliminates the TOCTOU race of read-check-then-write.
|
||||||
|
pub async fn oauth_token(&self) -> String {
|
||||||
|
// Check file first (async I/O — won't block tokio)
|
||||||
|
let token_path = token_file_path();
|
||||||
|
if let Ok(contents) = tokio::fs::read_to_string(&token_path).await {
|
||||||
|
let token = contents.trim().to_string();
|
||||||
|
if !token.is_empty() && token.starts_with("ya29.") {
|
||||||
|
// Single lock: compare-and-set atomically
|
||||||
|
let mut guard = self.inner.write().await;
|
||||||
|
if guard.oauth_token != token {
|
||||||
|
info!("Token updated from file");
|
||||||
|
guard.oauth_token = token.clone();
|
||||||
|
}
|
||||||
|
return token;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then env var
|
||||||
|
if let Ok(env_token) = std::env::var("ANTIGRAVITY_OAUTH_TOKEN") {
|
||||||
|
if !env_token.is_empty() {
|
||||||
|
let mut guard = self.inner.write().await;
|
||||||
|
if guard.oauth_token != env_token {
|
||||||
|
info!("Token updated from env var");
|
||||||
|
guard.oauth_token = env_token.clone();
|
||||||
|
}
|
||||||
|
return env_token;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.inner.read().await.oauth_token.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fire-and-forget: update conversation annotations alongside SendUserCascadeMessage.
|
||||||
|
///
|
||||||
|
/// The real webview calls this after every message to track lastUserViewTime.
|
||||||
|
/// Without it, the LS sees messages without annotation updates — a fingerprint.
|
||||||
|
pub async fn update_annotations(&self, cascade_id: &str) -> Result<(), String> {
|
||||||
|
let now = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true);
|
||||||
|
let body = serde_json::json!({
|
||||||
|
"cascadeId": cascade_id,
|
||||||
|
"annotations": {
|
||||||
|
"lastUserViewTime": now
|
||||||
|
},
|
||||||
|
"mergeAnnotations": true
|
||||||
|
});
|
||||||
|
match self.call_json("UpdateConversationAnnotations", &body).await {
|
||||||
|
Ok((status, _)) => {
|
||||||
|
debug!("UpdateConversationAnnotations: {status}");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("UpdateConversationAnnotations failed: {e}");
|
||||||
|
Err(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set OAuth token at runtime.
|
||||||
|
pub async fn set_oauth_token(&self, token: String) {
|
||||||
|
let mut guard = self.inner.write().await;
|
||||||
|
guard.oauth_token = token;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── RPC calls ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Common headers: clone cached static + insert per-request CSRF.
|
||||||
|
fn common_headers(csrf: &str) -> HeaderMap {
|
||||||
|
let mut h = STATIC_HEADERS.clone();
|
||||||
|
if let Ok(val) = HeaderValue::from_str(csrf) {
|
||||||
|
h.insert(
|
||||||
|
HeaderName::from_static("x-codeium-csrf-token"),
|
||||||
|
val,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
warn!("CSRF token contains invalid header characters, omitting");
|
||||||
|
}
|
||||||
|
h
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call a JSON RPC method on the language server.
|
||||||
|
pub async fn call_json(
|
||||||
|
&self,
|
||||||
|
method: &str,
|
||||||
|
body: &serde_json::Value,
|
||||||
|
) -> Result<(u16, serde_json::Value), String> {
|
||||||
|
let (base, csrf) = {
|
||||||
|
let guard = self.inner.read().await;
|
||||||
|
(
|
||||||
|
format!("https://127.0.0.1:{}", guard.https_port),
|
||||||
|
guard.csrf.clone(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let url = format!("{base}/{LS_SERVICE}/{method}");
|
||||||
|
let mut headers = Self::common_headers(&csrf);
|
||||||
|
headers.insert("Content-Type", HeaderValue::from_static("application/json"));
|
||||||
|
|
||||||
|
let body_bytes = serde_json::to_vec(body)
|
||||||
|
.map_err(|e| format!("JSON serialize error: {e}"))?;
|
||||||
|
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.post(&url)
|
||||||
|
.headers(headers)
|
||||||
|
.body(body_bytes)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("HTTP error: {e}"))?;
|
||||||
|
|
||||||
|
let status = resp.status().as_u16();
|
||||||
|
let encoding = resp
|
||||||
|
.headers()
|
||||||
|
.get("content-encoding")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
let raw = resp.bytes().await
|
||||||
|
.map_err(|e| format!("Read body error: {e}"))?;
|
||||||
|
let resp_bytes = decompress(method, &raw, &encoding);
|
||||||
|
tracing::debug!(
|
||||||
|
"{method} response ({status}, {} bytes, enc={encoding})",
|
||||||
|
resp_bytes.len(),
|
||||||
|
);
|
||||||
|
tracing::trace!(
|
||||||
|
"{method} body: {}",
|
||||||
|
String::from_utf8_lossy(&resp_bytes[..resp_bytes.len().min(200)])
|
||||||
|
);
|
||||||
|
let data: serde_json::Value = match serde_json::from_slice(&resp_bytes) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("{method} response is not valid JSON: {e}");
|
||||||
|
serde_json::Value::Object(serde_json::Map::new())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok((status, data))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call a binary protobuf RPC method.
|
||||||
|
pub async fn call_proto(
|
||||||
|
&self,
|
||||||
|
method: &str,
|
||||||
|
body: Vec<u8>,
|
||||||
|
) -> Result<(u16, Vec<u8>), String> {
|
||||||
|
let (base, csrf) = {
|
||||||
|
let guard = self.inner.read().await;
|
||||||
|
(
|
||||||
|
format!("https://127.0.0.1:{}", guard.https_port),
|
||||||
|
guard.csrf.clone(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let url = format!("{base}/{LS_SERVICE}/{method}");
|
||||||
|
let mut headers = Self::common_headers(&csrf);
|
||||||
|
headers.insert("Content-Type", HeaderValue::from_static("application/proto"));
|
||||||
|
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.post(&url)
|
||||||
|
.headers(headers)
|
||||||
|
.body(body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("HTTP error: {e}"))?;
|
||||||
|
|
||||||
|
let status = resp.status().as_u16();
|
||||||
|
let encoding = resp
|
||||||
|
.headers()
|
||||||
|
.get("content-encoding")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
let raw = resp
|
||||||
|
.bytes()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Read body error: {e}"))?;
|
||||||
|
let decompressed = decompress(method, &raw, &encoding);
|
||||||
|
Ok((status, decompressed))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// StartCascade → returns cascade_id.
|
||||||
|
pub async fn create_cascade(&self) -> Result<String, String> {
|
||||||
|
let body = serde_json::json!({"prompt": "new chat"});
|
||||||
|
let (status, data) = self.call_json("StartCascade", &body).await?;
|
||||||
|
if status != 200 {
|
||||||
|
return Err(format!("StartCascade failed: {status} — {data}"));
|
||||||
|
}
|
||||||
|
tracing::debug!("StartCascade response: {data}");
|
||||||
|
data["cascadeId"]
|
||||||
|
.as_str()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.ok_or_else(|| format!("Missing cascadeId in response: {data}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// SendUserCascadeMessage with binary protobuf body.
|
||||||
|
pub async fn send_message(
|
||||||
|
&self,
|
||||||
|
cascade_id: &str,
|
||||||
|
text: &str,
|
||||||
|
model_enum: u32,
|
||||||
|
) -> Result<(u16, Vec<u8>), String> {
|
||||||
|
let token = self.oauth_token().await;
|
||||||
|
if token.is_empty() {
|
||||||
|
return Err("No OAuth token available".to_string());
|
||||||
|
}
|
||||||
|
let proto = crate::proto::build_request(cascade_id, text, &token, model_enum);
|
||||||
|
self.call_proto("SendUserCascadeMessage", proto).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GetCascadeTrajectorySteps → JSON with steps array.
|
||||||
|
pub async fn get_steps(
|
||||||
|
&self,
|
||||||
|
cascade_id: &str,
|
||||||
|
) -> Result<(u16, serde_json::Value), String> {
|
||||||
|
let body = serde_json::json!({"cascadeId": cascade_id});
|
||||||
|
self.call_json("GetCascadeTrajectorySteps", &body).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GetCascadeTrajectory → JSON with trajectory status.
|
||||||
|
pub async fn get_trajectory(
|
||||||
|
&self,
|
||||||
|
cascade_id: &str,
|
||||||
|
) -> Result<(u16, serde_json::Value), String> {
|
||||||
|
let body = serde_json::json!({"cascadeId": cascade_id});
|
||||||
|
self.call_json("GetCascadeTrajectory", &body).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Discovery helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
fn discover() -> Result<BackendInner, String> {
|
||||||
|
let pid_output = Command::new("sh")
|
||||||
|
.args(["-c", "pgrep -f language_server_linux | head -1"])
|
||||||
|
.output()
|
||||||
|
.map_err(|e| format!("pgrep failed: {e}"))?;
|
||||||
|
|
||||||
|
let pid = String::from_utf8_lossy(&pid_output.stdout)
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
|
if pid.is_empty() {
|
||||||
|
return Err("Language server not running".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let cmdline = fs::read(format!("/proc/{pid}/cmdline"))
|
||||||
|
.map_err(|e| format!("Can't read cmdline for PID {pid}: {e}"))?;
|
||||||
|
let args: Vec<&[u8]> = cmdline.split(|&b| b == 0).collect();
|
||||||
|
let mut csrf = String::new();
|
||||||
|
for (i, arg) in args.iter().enumerate() {
|
||||||
|
if let Ok(s) = std::str::from_utf8(arg) {
|
||||||
|
if s == "--csrf_token" {
|
||||||
|
if let Some(next) = args.get(i + 1) {
|
||||||
|
if let Ok(token) = std::str::from_utf8(next) {
|
||||||
|
csrf = token.to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let csrf_preview = safe_truncate(&csrf, 8);
|
||||||
|
debug!("Discovered LS PID={pid}, CSRF={csrf_preview}");
|
||||||
|
|
||||||
|
let log_base = log_base();
|
||||||
|
let mut https_port = String::new();
|
||||||
|
|
||||||
|
if let Ok(mut entries) = fs::read_dir(&log_base) {
|
||||||
|
let mut dirs: Vec<String> = Vec::new();
|
||||||
|
while let Some(Ok(entry)) = entries.next() {
|
||||||
|
let name = entry.file_name().to_string_lossy().to_string();
|
||||||
|
if name.starts_with("202") {
|
||||||
|
dirs.push(name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dirs.sort_unstable_by(|a, b| b.cmp(a));
|
||||||
|
|
||||||
|
static PORT_RE: LazyLock<regex::Regex> =
|
||||||
|
LazyLock::new(|| regex::Regex::new(r"port at (\d+) for HTTPS").unwrap());
|
||||||
|
|
||||||
|
for d in &dirs {
|
||||||
|
let log_path = format!(
|
||||||
|
"{log_base}/{d}/window1/exthost/google.antigravity/Antigravity.log"
|
||||||
|
);
|
||||||
|
if let Ok(contents) = fs::read_to_string(&log_path) {
|
||||||
|
for line in contents.lines() {
|
||||||
|
if line.contains(&pid) && line.contains("listening") && line.contains("HTTPS") {
|
||||||
|
if let Some(caps) = PORT_RE.captures(line) {
|
||||||
|
https_port = caps[1].to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !https_port.is_empty() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if https_port.is_empty() {
|
||||||
|
warn!("Could not find HTTPS port in logs, defaulting to 3100");
|
||||||
|
https_port = "3100".to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let oauth_token = std::env::var("ANTIGRAVITY_OAUTH_TOKEN")
|
||||||
|
.ok()
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.or_else(|| {
|
||||||
|
let home = std::env::var("HOME").unwrap_or_default();
|
||||||
|
let path = format!("{home}/.config/antigravity-proxy-token");
|
||||||
|
fs::read_to_string(&path)
|
||||||
|
.ok()
|
||||||
|
.map(|s| s.trim().to_string())
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
})
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
Ok(BackendInner {
|
||||||
|
pid,
|
||||||
|
csrf,
|
||||||
|
https_port,
|
||||||
|
oauth_token,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shorthand for HeaderValue (panics on invalid — only for known-safe static values).
|
||||||
|
fn hv(s: &str) -> HeaderValue {
|
||||||
|
HeaderValue::from_str(s).expect("invalid header value in static constant")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decompress response bytes based on Content-Encoding header.
|
||||||
|
fn decompress(method: &str, data: &[u8], encoding: &str) -> Vec<u8> {
|
||||||
|
let mut out = Vec::new();
|
||||||
|
let res = match encoding {
|
||||||
|
"gzip" => GzDecoder::new(data).read_to_end(&mut out),
|
||||||
|
"deflate" => DeflateDecoder::new(data).read_to_end(&mut out),
|
||||||
|
"br" => brotli::Decompressor::new(data, 4096).read_to_end(&mut out),
|
||||||
|
_ => return data.to_vec(),
|
||||||
|
};
|
||||||
|
|
||||||
|
match res {
|
||||||
|
Ok(_) => out,
|
||||||
|
Err(e) => {
|
||||||
|
if !encoding.is_empty() {
|
||||||
|
let preview = String::from_utf8_lossy(&data[..data.len().min(100)]);
|
||||||
|
warn!("{method}: {encoding} decompress failed ({} bytes): {e}. Raw: {}", data.len(), preview);
|
||||||
|
}
|
||||||
|
data.to_vec()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
217
src/constants.rs
Normal file
217
src/constants.rs
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
//! Shared constants — auto-detected from the installed Antigravity binary at startup.
|
||||||
|
//!
|
||||||
|
//! On first access, we locate the Antigravity installation (via the running
|
||||||
|
//! language server PID or well-known paths), parse `product.json` for version
|
||||||
|
//! strings, and extract Chrome/Electron versions from the binary. If detection
|
||||||
|
//! fails, we fall back to hardcoded values.
|
||||||
|
|
||||||
|
use std::fs;
|
||||||
|
use std::process::Command;
|
||||||
|
use std::sync::LazyLock;
|
||||||
|
|
||||||
|
/// Version strings produced by `detect_versions`, each either detected from
/// the installed app or taken from a hardcoded fallback.
struct DetectedVersions {
    /// Antigravity app version (e.g. "1.107.0").
    antigravity: String,
    /// Chrome version (e.g. "142.0.7444.175").
    chrome: String,
    /// Electron version (e.g. "39.2.3").
    electron: String,
    /// Client/IDE version (e.g. "1.16.5").
    client: String,
}
|
||||||
|
|
||||||
|
/// Locate the Antigravity install directory.
///
/// First tries the running language server: its `/proc/<pid>/exe` link is
/// resolved and everything before the first `/resources/` component is
/// returned (e.g. `/usr/share/antigravity`). If that fails, falls back to
/// the well-known install paths that contain `resources/app/product.json`.
/// Returns `None` when neither approach finds anything.
fn find_install_dir() -> Option<String> {
    // 1. Trace the running language server. pgrep is invoked directly: under
    //    `sh -c "pgrep -f … | head -1"` the shell's own command line contains
    //    the pattern, so `pgrep -f` could report the shell itself.
    if let Ok(output) = Command::new("pgrep")
        .args(["-f", "language_server_linux"])
        .output()
    {
        let stdout = String::from_utf8_lossy(&output.stdout);
        // First reported PID only (what `head -1` used to do).
        let pid = stdout.lines().next().unwrap_or("").trim();
        if !pid.is_empty() {
            if let Ok(exe) = fs::read_link(format!("/proc/{pid}/exe")) {
                let exe_str = exe.to_string_lossy();
                // exe is like: /usr/share/antigravity/resources/app/extensions/antigravity/bin/language_server_linux_x64
                // We want: /usr/share/antigravity
                if let Some(idx) = exe_str.find("/resources/") {
                    return Some(exe_str[..idx].to_string());
                }
            }
        }
    }

    // 2. Fall back to well-known install paths.
    ["/usr/share/antigravity", "/opt/Antigravity"]
        .into_iter()
        .find(|path| fs::metadata(format!("{path}/resources/app/product.json")).is_ok())
        .map(str::to_string)
}
|
||||||
|
|
||||||
|
/// Read `product.json` from the install dir and extract version fields.
|
||||||
|
fn read_product_json(install_dir: &str) -> (Option<String>, Option<String>) {
|
||||||
|
let path = format!("{install_dir}/resources/app/product.json");
|
||||||
|
let Ok(contents) = fs::read_to_string(&path) else {
|
||||||
|
return (None, None);
|
||||||
|
};
|
||||||
|
let Ok(json) = serde_json::from_str::<serde_json::Value>(&contents) else {
|
||||||
|
return (None, None);
|
||||||
|
};
|
||||||
|
|
||||||
|
let version = json["version"].as_str().map(|s| s.to_string());
|
||||||
|
let ide_version = json["ideVersion"].as_str().map(|s| s.to_string());
|
||||||
|
(version, ide_version)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract Chrome and Electron versions from the main binary via `strings`.
///
/// Looks for "Chrome/142.0.7444.175"-style (four components) and
/// "Electron/39.2.3"-style (three components) markers in
/// `{install_dir}/antigravity` and returns `(chrome, electron)` without the
/// product prefix. Returns `(None, None)` if the binary does not exist; each
/// slot is `None` if its marker is not found or the command cannot be run.
fn extract_binary_versions(install_dir: &str) -> (Option<String>, Option<String>) {
    let binary = format!("{install_dir}/antigravity");
    if fs::metadata(&binary).is_err() {
        return (None, None);
    }

    let chrome = grep_binary_version(&binary, "Chrome", 4);
    let electron = grep_binary_version(&binary, "Electron", 3);
    (chrome, electron)
}

/// Returns the first `{product}/<n dot-separated numbers>` marker found in
/// `binary`, with the `{product}/` prefix removed.
fn grep_binary_version(binary: &str, product: &str, components: usize) -> Option<String> {
    // ERE equivalent of `Product/[0-9]+\.[0-9]+…`, e.g. `Chrome/[0-9]+(\.[0-9]+){3}`.
    let pattern = format!(
        "{product}/[0-9]+(\\.[0-9]+){{{}}}",
        components.saturating_sub(1)
    );

    // The path and pattern are passed as positional parameters so that no
    // character in them is ever interpreted by the shell. Streaming through
    // `strings | grep | head` avoids loading the binary into memory.
    let output = Command::new("sh")
        .args([
            "-c",
            r#"strings "$1" | grep -oE "$2" | head -1"#,
            "sh",
            binary,
            pattern.as_str(),
        ])
        .output()
        .ok()?;

    let stdout = String::from_utf8_lossy(&output.stdout);
    let prefix = format!("{product}/");
    stdout
        .trim()
        .strip_prefix(prefix.as_str())
        .map(str::to_string)
}
|
||||||
|
|
||||||
|
/// Detect all versions from the installed Antigravity app.
|
||||||
|
fn detect_versions() -> DetectedVersions {
|
||||||
|
// Hardcoded fallbacks — last known good values
|
||||||
|
const FALLBACK_ANTIGRAVITY: &str = "1.107.0";
|
||||||
|
const FALLBACK_CHROME: &str = "142.0.7444.175";
|
||||||
|
const FALLBACK_ELECTRON: &str = "39.2.3";
|
||||||
|
const FALLBACK_CLIENT: &str = "1.16.5";
|
||||||
|
|
||||||
|
let Some(install_dir) = find_install_dir() else {
|
||||||
|
eprintln!(
|
||||||
|
"[constants] ⚠ Could not find Antigravity install — using fallback versions"
|
||||||
|
);
|
||||||
|
return DetectedVersions {
|
||||||
|
antigravity: FALLBACK_ANTIGRAVITY.to_string(),
|
||||||
|
chrome: FALLBACK_CHROME.to_string(),
|
||||||
|
electron: FALLBACK_ELECTRON.to_string(),
|
||||||
|
client: FALLBACK_CLIENT.to_string(),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
// product.json → antigravity version + client/IDE version
|
||||||
|
let (ag_ver, client_ver) = read_product_json(&install_dir);
|
||||||
|
|
||||||
|
// Binary → Chrome + Electron versions
|
||||||
|
let (chrome_ver, electron_ver) = extract_binary_versions(&install_dir);
|
||||||
|
|
||||||
|
let versions = DetectedVersions {
|
||||||
|
antigravity: ag_ver.unwrap_or_else(|| FALLBACK_ANTIGRAVITY.to_string()),
|
||||||
|
chrome: chrome_ver.unwrap_or_else(|| FALLBACK_CHROME.to_string()),
|
||||||
|
electron: electron_ver.unwrap_or_else(|| FALLBACK_ELECTRON.to_string()),
|
||||||
|
client: client_ver.unwrap_or_else(|| FALLBACK_CLIENT.to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
eprintln!(
|
||||||
|
"[constants] ✓ Detected versions: Antigravity={}, Chrome={}, Electron={}, Client={}",
|
||||||
|
versions.antigravity, versions.chrome, versions.electron, versions.client
|
||||||
|
);
|
||||||
|
|
||||||
|
versions
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Public API ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// All detected versions — computed once on first access.
|
||||||
|
static VERSIONS: LazyLock<DetectedVersions> = LazyLock::new(detect_versions);
|
||||||
|
|
||||||
|
/// Antigravity app version (e.g. "1.107.0").
|
||||||
|
pub fn antigravity_version() -> &'static str {
|
||||||
|
&VERSIONS.antigravity
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Chrome version bundled with Electron (e.g. "142.0.7444.175").
|
||||||
|
pub fn chrome_version() -> &'static str {
|
||||||
|
&VERSIONS.chrome
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Electron version (e.g. "39.2.3").
|
||||||
|
pub fn electron_version() -> &'static str {
|
||||||
|
&VERSIONS.electron
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Client/IDE version from product.json (e.g. "1.16.5").
|
||||||
|
pub fn client_version() -> &'static str {
|
||||||
|
&VERSIONS.client
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Client name string.
pub const CLIENT_NAME: &str = "antigravity";

/// Fully qualified language server service name; forms the URL path prefix
/// of RPC calls.
pub const LS_SERVICE: &str = "exa.language_server_pb.LanguageServerService";
|
||||||
|
|
||||||
|
/// Returns the Antigravity log directory, `$HOME/.config/Antigravity/logs`.
///
/// Uses `/root` as the home directory when `HOME` cannot be read.
pub fn log_base() -> String {
    let home = match std::env::var("HOME") {
        Ok(dir) => dir,
        Err(_) => String::from("/root"),
    };
    format!("{home}/.config/Antigravity/logs")
}
|
||||||
|
|
||||||
|
/// Returns the token file path, `$HOME/.config/antigravity-proxy-token`.
///
/// Uses `/root` as the home directory when `HOME` cannot be read.
pub fn token_file_path() -> String {
    let home = match std::env::var("HOME") {
        Ok(dir) => dir,
        Err(_) => String::from("/root"),
    };
    format!("{home}/.config/antigravity-proxy-token")
}
|
||||||
|
|
||||||
|
/// User-Agent string matching the Electron webview — computed once.
|
||||||
|
pub static USER_AGENT: LazyLock<String> = LazyLock::new(|| {
|
||||||
|
format!(
|
||||||
|
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 \
|
||||||
|
(KHTML, like Gecko) Antigravity/{} \
|
||||||
|
Chrome/{} Electron/{} Safari/537.36",
|
||||||
|
antigravity_version(),
|
||||||
|
chrome_version(),
|
||||||
|
electron_version()
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
/// Chrome major version for sec-ch-ua header — computed once.
|
||||||
|
pub static CHROME_MAJOR: LazyLock<String> = LazyLock::new(|| {
|
||||||
|
chrome_version()
|
||||||
|
.split('.')
|
||||||
|
.next()
|
||||||
|
.unwrap_or("142")
|
||||||
|
.to_string()
|
||||||
|
});
|
||||||
|
|
||||||
|
/// Shortens `s` to its first `max` characters (not bytes) for display.
///
/// A string with at most `max` characters is returned unchanged. A longer
/// one is cut on a character boundary and `...` is appended, so the result
/// can be up to `max + 3` characters long.
pub fn safe_truncate(s: &str, max: usize) -> String {
    // Byte offset of the first character that does not fit, if any.
    if let Some((cut, _)) = s.char_indices().nth(max) {
        let mut shortened = String::with_capacity(cut + 3);
        shortened.push_str(&s[..cut]);
        shortened.push_str("...");
        shortened
    } else {
        s.to_owned()
    }
}
|
||||||
332
src/main.rs
Normal file
332
src/main.rs
Normal file
@@ -0,0 +1,332 @@
|
|||||||
|
//! Antigravity OpenAI Proxy — Rust edition v3 (stealth hardened).
|
||||||
|
//!
|
||||||
|
//! Single-binary replacement for server.py. BoringSSL TLS impersonation,
|
||||||
|
//! byte-exact protobuf encoding, Chrome header fingerprinting, cascade
|
||||||
|
//! session management, warmup + heartbeat lifecycle mimicry.
|
||||||
|
|
||||||
|
mod api;
|
||||||
|
mod backend;
|
||||||
|
mod constants;
|
||||||
|
mod mitm;
|
||||||
|
mod proto;
|
||||||
|
mod quota;
|
||||||
|
mod session;
|
||||||
|
mod warmup;
|
||||||
|
|
||||||
|
use api::AppState;
|
||||||
|
use backend::Backend;
|
||||||
|
use clap::Parser;
|
||||||
|
use session::SessionManager;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::{info, warn};
|
||||||
|
|
||||||
|
use mitm::store::MitmStore;
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
#[command(name = "antigravity-proxy", about = "Antigravity OpenAI Proxy (stealth)")]
|
||||||
|
struct Cli {
|
||||||
|
/// Port to listen on
|
||||||
|
#[arg(long, default_value_t = 8741)]
|
||||||
|
port: u16,
|
||||||
|
|
||||||
|
/// Enable info-level logging (-v)
|
||||||
|
#[arg(short, long)]
|
||||||
|
verbose: bool,
|
||||||
|
|
||||||
|
/// Enable debug-level logging (-d)
|
||||||
|
#[arg(short, long)]
|
||||||
|
debug: bool,
|
||||||
|
|
||||||
|
/// Disable the MITM proxy (no API interception)
|
||||||
|
#[arg(long)]
|
||||||
|
no_mitm: bool,
|
||||||
|
|
||||||
|
/// MITM proxy port (default: 8742, matches wrapper script)
|
||||||
|
#[arg(long, default_value_t = 8742)]
|
||||||
|
mitm_port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
let cli = Cli::parse();
|
||||||
|
|
||||||
|
// Flag > env var > default (warn)
|
||||||
|
let log_level = if cli.debug {
|
||||||
|
"debug"
|
||||||
|
} else if cli.verbose {
|
||||||
|
"info"
|
||||||
|
} else {
|
||||||
|
// Fall back to RUST_LOG env, or warn-only
|
||||||
|
""
|
||||||
|
};
|
||||||
|
|
||||||
|
let filter = if log_level.is_empty() {
|
||||||
|
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||||
|
.unwrap_or_else(|_| "warn".into())
|
||||||
|
} else {
|
||||||
|
tracing_subscriber::EnvFilter::new(log_level)
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_env_filter(filter)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
// ── Step 1: Bind main port FIRST (fail fast, before spawning anything) ────
|
||||||
|
let addr = format!("127.0.0.1:{}", cli.port);
|
||||||
|
let listener = match tokio::net::TcpListener::bind(&addr).await {
|
||||||
|
Ok(l) => l,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Fatal: cannot bind to {addr}: {e}");
|
||||||
|
eprintln!("Hint: kill $(lsof -ti:{}) 2>/dev/null", cli.port);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// ── Step 2: Backend discovery ─────────────────────────────────────────────
|
||||||
|
let backend = Arc::new(match Backend::new() {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Fatal: {e}");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let (pid, https_port, csrf, token) = backend.info().await;
|
||||||
|
|
||||||
|
// ── Step 3: MITM proxy (after port is secured) ────────────────────────────
|
||||||
|
let mitm_store = MitmStore::new();
|
||||||
|
let (mitm_port_actual, mitm_handle) = if !cli.no_mitm {
|
||||||
|
let data_dir = dirs_data_dir();
|
||||||
|
match mitm::ca::MitmCa::load_or_generate(&data_dir) {
|
||||||
|
Ok(ca) => {
|
||||||
|
let ca = Arc::new(ca);
|
||||||
|
let ca_pem = ca.ca_pem_path.display().to_string();
|
||||||
|
let config = mitm::proxy::MitmConfig {
|
||||||
|
port: cli.mitm_port,
|
||||||
|
modify_requests: false,
|
||||||
|
};
|
||||||
|
match mitm::proxy::run(ca, mitm_store.clone(), config).await {
|
||||||
|
Ok((port, handle)) => {
|
||||||
|
info!(port, ca = %ca_pem, "MITM proxy started");
|
||||||
|
// Write actual port to file for wrapper script discovery
|
||||||
|
let port_file = data_dir.join("mitm-port");
|
||||||
|
if let Err(e) = std::fs::write(&port_file, port.to_string()) {
|
||||||
|
warn!("Failed to write MITM port file: {e}");
|
||||||
|
}
|
||||||
|
(Some((port, ca_pem)), Some(handle))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("MITM proxy failed to start: {e}");
|
||||||
|
(None, None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("MITM CA generation failed: {e}");
|
||||||
|
(None, None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
info!("MITM proxy disabled (--no-mitm)");
|
||||||
|
(None, None)
|
||||||
|
};
|
||||||
|
|
||||||
|
// ── Step 4: Warmup + heartbeat ────────────────────────────────────────────
|
||||||
|
warmup::warmup_sequence(&backend).await;
|
||||||
|
let heartbeat_handle = warmup::start_heartbeat(Arc::clone(&backend));
|
||||||
|
|
||||||
|
// ── Step 4b: Quota monitor ────────────────────────────────────────────────
|
||||||
|
let quota_store = quota::QuotaStore::new();
|
||||||
|
quota_store.clone().start_polling(Arc::clone(&backend));
|
||||||
|
info!("Quota monitor started (polling every 60s)");
|
||||||
|
|
||||||
|
let state = Arc::new(AppState {
|
||||||
|
backend,
|
||||||
|
sessions: SessionManager::new(),
|
||||||
|
mitm_store,
|
||||||
|
quota_store,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Periodic backend refresh — keeps LS connection details fresh
|
||||||
|
let refresh_backend = Arc::clone(&state.backend);
|
||||||
|
let refresh_handle = tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
|
||||||
|
if let Err(e) = refresh_backend.refresh().await {
|
||||||
|
warn!("Periodic refresh failed: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// ── Step 5: Start serving ─────────────────────────────────────────────────
|
||||||
|
let app = api::router(state.clone());
|
||||||
|
|
||||||
|
print_banner(cli.port, &pid, &https_port, &csrf, &token, &mitm_port_actual);
|
||||||
|
info!("Listening on http://{addr}");
|
||||||
|
|
||||||
|
axum::serve(listener, app)
|
||||||
|
.with_graceful_shutdown(shutdown_signal())
|
||||||
|
.await
|
||||||
|
.expect("server error");
|
||||||
|
|
||||||
|
// ── Cleanup: abort all background tasks ───────────────────────────────────
|
||||||
|
heartbeat_handle.abort();
|
||||||
|
refresh_handle.abort();
|
||||||
|
if let Some(h) = mitm_handle {
|
||||||
|
h.abort();
|
||||||
|
}
|
||||||
|
// Remove stale MITM port file
|
||||||
|
let _ = std::fs::remove_file(dirs_data_dir().join("mitm-port"));
|
||||||
|
info!("Server shutdown complete");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wait for SIGINT (Ctrl+C) or SIGTERM for graceful shutdown.
|
||||||
|
async fn shutdown_signal() {
|
||||||
|
let ctrl_c = async {
|
||||||
|
tokio::signal::ctrl_c()
|
||||||
|
.await
|
||||||
|
.expect("failed to install Ctrl+C handler");
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
let terminate = async {
|
||||||
|
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||||
|
.expect("failed to install SIGTERM handler")
|
||||||
|
.recv()
|
||||||
|
.await;
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
let terminate = std::future::pending::<()>();
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
_ = ctrl_c => info!("Received SIGINT, shutting down..."),
|
||||||
|
_ = terminate => info!("Received SIGTERM, shutting down..."),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str, mitm: &Option<(u16, String)>) {
|
||||||
|
let chrome_major = &*constants::CHROME_MAJOR;
|
||||||
|
let ver = crate::constants::antigravity_version();
|
||||||
|
|
||||||
|
println!();
|
||||||
|
println!(" \x1b[1;35m>> antigravity-proxy\x1b[0m \x1b[2mv{ver}\x1b[0m");
|
||||||
|
println!(" \x1b[2m────────────────────────────────────────────────\x1b[0m");
|
||||||
|
println!();
|
||||||
|
println!(" \x1b[1mcore\x1b[0m");
|
||||||
|
println!(" \x1b[36m tls\x1b[0m BoringSSL (Chrome {chrome_major})");
|
||||||
|
println!(" \x1b[36m listen\x1b[0m http://127.0.0.1:{port}");
|
||||||
|
println!(" \x1b[36m ls pid\x1b[0m {pid}");
|
||||||
|
println!(" \x1b[36m https\x1b[0m :{https_port}");
|
||||||
|
println!(" \x1b[36m csrf\x1b[0m {csrf}");
|
||||||
|
println!(" \x1b[36m oauth\x1b[0m {token}");
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// MITM section
|
||||||
|
if let Some((mitm_port, ca_path)) = mitm {
|
||||||
|
println!(" \x1b[1mmitm\x1b[0m");
|
||||||
|
println!(" \x1b[36m proxy\x1b[0m 127.0.0.1:{mitm_port}");
|
||||||
|
println!(" \x1b[36m ca cert\x1b[0m {ca_path}");
|
||||||
|
|
||||||
|
// Check if wrapper is installed
|
||||||
|
let wrapper_installed = check_wrapper_installed();
|
||||||
|
if wrapper_installed {
|
||||||
|
println!(" \x1b[36m wrapper\x1b[0m \x1b[32minstalled\x1b[0m");
|
||||||
|
} else {
|
||||||
|
println!(" \x1b[36m wrapper\x1b[0m \x1b[33mnot installed\x1b[0m");
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
} else {
|
||||||
|
println!(" \x1b[1mmitm\x1b[0m \x1b[33mdisabled\x1b[0m");
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Routes
|
||||||
|
println!(" \x1b[1mroutes\x1b[0m");
|
||||||
|
println!(" \x1b[33m POST\x1b[0m /v1/responses");
|
||||||
|
println!(" \x1b[33m POST\x1b[0m /v1/chat/completions");
|
||||||
|
println!(" \x1b[32m GET \x1b[0m /v1/models");
|
||||||
|
println!(" \x1b[32m GET \x1b[0m /v1/sessions");
|
||||||
|
println!(" \x1b[31m DEL \x1b[0m /v1/sessions/:id");
|
||||||
|
println!(" \x1b[33m POST\x1b[0m /v1/token");
|
||||||
|
println!(" \x1b[32m GET \x1b[0m /v1/usage");
|
||||||
|
println!(" \x1b[32m GET \x1b[0m /v1/quota");
|
||||||
|
println!(" \x1b[32m GET \x1b[0m /health");
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Status line
|
||||||
|
let mitm_tag = if mitm.is_some() { "\x1b[32mmitm\x1b[0m" } else { "\x1b[31mmitm\x1b[0m" };
|
||||||
|
println!(" \x1b[2mstealth:\x1b[0m \x1b[32mwarmup\x1b[0m \x1b[32mheartbeat\x1b[0m \x1b[32mjitter\x1b[0m {mitm_tag}");
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Setup hints
|
||||||
|
if let Some((mitm_port, ca_path)) = mitm {
|
||||||
|
if !check_wrapper_installed() {
|
||||||
|
println!(" \x1b[1;33m[!]\x1b[0m mitm wrapper not installed");
|
||||||
|
println!(" \x1b[2mrun:\x1b[0m ./scripts/mitm-wrapper.sh install");
|
||||||
|
println!(" \x1b[2mor:\x1b[0m HTTPS_PROXY=http://127.0.0.1:{mitm_port}");
|
||||||
|
println!(" NODE_EXTRA_CA_CERTS={ca_path}");
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "NOT SET" {
|
||||||
|
println!(" \x1b[1;33m[!]\x1b[0m no oauth token");
|
||||||
|
println!(" export ANTIGRAVITY_OAUTH_TOKEN=ya29.xxx");
|
||||||
|
println!(" curl -X POST http://127.0.0.1:{port}/v1/token -d '{{\"token\":\"ya29.xxx\"}}'");
|
||||||
|
println!(" echo 'ya29.xxx' > ~/.config/antigravity-proxy-token");
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the MITM wrapper is installed by looking for the .real backup file
|
||||||
|
/// next to the LS binary. Uses /proc to find the real LS path dynamically.
|
||||||
|
fn check_wrapper_installed() -> bool {
|
||||||
|
// Find the LS binary path from known PID or by scanning /proc
|
||||||
|
if let Some(ls_path) = find_ls_binary_path() {
|
||||||
|
let real_path = format!("{ls_path}.real");
|
||||||
|
return std::path::Path::new(&real_path).exists();
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Locate the language-server binary by scanning `/proc/<pid>/exe` links.
///
/// Returns the executable path of the first numeric `/proc` entry whose link
/// target contains `language_server_linux` or `antigravity-language-server`,
/// with any trailing ` (deleted)` and `.real` suffixes removed. Returns `None`
/// when `/proc` does not exist, cannot be listed, or no entry matches.
fn find_ls_binary_path() -> Option<String> {
    let proc_root = std::path::Path::new("/proc");
    if !proc_root.exists() {
        return None;
    }

    std::fs::read_dir(proc_root)
        .ok()?
        .flatten()
        // Only numeric directory names are PIDs.
        .filter(|entry| {
            entry
                .file_name()
                .to_string_lossy()
                .chars()
                .all(|c| c.is_ascii_digit())
        })
        // Entries whose `exe` link cannot be read are skipped.
        .filter_map(|entry| std::fs::read_link(entry.path().join("exe")).ok())
        .find_map(|target| {
            let raw = target.to_string_lossy();
            // Strip " (deleted)" suffix from unlinked binaries
            let cleaned = raw.trim_end_matches(" (deleted)");
            let looks_like_ls = cleaned.contains("language_server_linux")
                || cleaned.contains("antigravity-language-server");
            // Strip .real suffix — if the wrapper exec'd the backup, we want the base name
            looks_like_ls.then(|| cleaned.trim_end_matches(".real").to_string())
        })
}
|
||||||
|
|
||||||
|
/// Directory where the proxy keeps its data (MITM CA cert/key, port file).
///
/// Resolves to `$HOME/.config/antigravity-proxy`; when `HOME` cannot be read
/// from the environment, `/tmp` is used in place of the home directory.
fn dirs_data_dir() -> std::path::PathBuf {
    let base = match std::env::var("HOME") {
        Ok(home) => home,
        Err(_) => String::from("/tmp"),
    };

    let mut dir = std::path::PathBuf::from(base);
    dir.push(".config");
    dir.push("antigravity-proxy");
    dir
}
|
||||||
218
src/mitm/ca.rs
Normal file
218
src/mitm/ca.rs
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
//! Certificate Authority for MITM proxy.
|
||||||
|
//!
|
||||||
|
//! Generates a self-signed root CA at first run and caches it to disk.
|
||||||
|
//! Dynamically generates per-domain leaf certificates signed by this CA.
|
||||||
|
|
||||||
|
use rcgen::{
|
||||||
|
BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose,
|
||||||
|
IsCa, KeyPair, KeyUsagePurpose, SanType,
|
||||||
|
};
|
||||||
|
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
/// MITM Certificate Authority.
|
||||||
|
pub struct MitmCa {
|
||||||
|
/// Root CA certificate (DER-encoded for rustls).
|
||||||
|
ca_cert_der: CertificateDer<'static>,
|
||||||
|
/// Root CA private key.
|
||||||
|
ca_key: KeyPair,
|
||||||
|
/// Signed root CA cert (needed by rcgen to sign leaf certs).
|
||||||
|
ca_signed: rcgen::Certificate,
|
||||||
|
/// Cache of per-domain TLS configs.
|
||||||
|
domain_cache: Arc<RwLock<HashMap<String, Arc<rustls::ServerConfig>>>>,
|
||||||
|
/// Path to the CA PEM file (for SSL_CERT_FILE combined bundle).
|
||||||
|
pub ca_pem_path: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MitmCa {
|
||||||
|
/// Load or generate the MITM CA.
|
||||||
|
///
|
||||||
|
/// The CA cert/key are stored at:
|
||||||
|
/// `<data_dir>/mitm-ca.pem` (cert, for NODE_EXTRA_CA_CERTS)
|
||||||
|
/// `<data_dir>/mitm-ca.key` (private key)
|
||||||
|
pub fn load_or_generate(data_dir: &Path) -> Result<Self, String> {
|
||||||
|
let cert_path = data_dir.join("mitm-ca.pem");
|
||||||
|
let key_path = data_dir.join("mitm-ca.key");
|
||||||
|
|
||||||
|
if cert_path.exists() && key_path.exists() {
|
||||||
|
info!("Loading existing MITM CA from {}", cert_path.display());
|
||||||
|
let cert_pem = std::fs::read_to_string(&cert_path)
|
||||||
|
.map_err(|e| format!("Failed to read CA cert: {e}"))?;
|
||||||
|
let key_pem = std::fs::read_to_string(&key_path)
|
||||||
|
.map_err(|e| format!("Failed to read CA key: {e}"))?;
|
||||||
|
|
||||||
|
let ca_key = KeyPair::from_pem(&key_pem)
|
||||||
|
.map_err(|e| format!("Failed to parse CA key: {e}"))?;
|
||||||
|
|
||||||
|
// Re-create params and self-sign to get the rcgen Certificate object
|
||||||
|
// (needed for signing leaf certs — rcgen 0.13 doesn't have from_ca_cert_pem).
|
||||||
|
// The re-signed cert will have a different serial/notBefore, but that's fine
|
||||||
|
// because we only use it for the rcgen signing API, NOT for the on-disk PEM.
|
||||||
|
let params = Self::ca_params();
|
||||||
|
let ca_signed = params.self_signed(&ca_key)
|
||||||
|
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
|
||||||
|
|
||||||
|
// Use the ORIGINAL on-disk PEM cert for DER — this is what the LS trusts
|
||||||
|
// (via the combined CA bundle built by the wrapper script). Writing the
|
||||||
|
// re-signed cert back would desync the LS's trust anchor.
|
||||||
|
let ca_cert_der = Self::pem_to_der(&cert_pem)
|
||||||
|
.unwrap_or_else(|| CertificateDer::from(ca_signed.der().to_vec()));
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
ca_cert_der,
|
||||||
|
ca_key,
|
||||||
|
ca_signed,
|
||||||
|
domain_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
ca_pem_path: cert_path,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
info!("Generating new MITM CA at {}", cert_path.display());
|
||||||
|
|
||||||
|
// Ensure data dir exists
|
||||||
|
std::fs::create_dir_all(data_dir)
|
||||||
|
.map_err(|e| format!("Failed to create data dir: {e}"))?;
|
||||||
|
|
||||||
|
let ca_key = KeyPair::generate()
|
||||||
|
.map_err(|e| format!("Failed to generate CA key: {e}"))?;
|
||||||
|
|
||||||
|
let params = Self::ca_params();
|
||||||
|
let ca_signed = params.self_signed(&ca_key)
|
||||||
|
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
|
||||||
|
|
||||||
|
// Write cert and key to disk
|
||||||
|
std::fs::write(&cert_path, ca_signed.pem())
|
||||||
|
.map_err(|e| format!("Failed to write CA cert: {e}"))?;
|
||||||
|
std::fs::write(&key_path, ca_key.serialize_pem())
|
||||||
|
.map_err(|e| format!("Failed to write CA key: {e}"))?;
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
{
|
||||||
|
use std::os::unix::fs::PermissionsExt;
|
||||||
|
let _ = std::fs::set_permissions(&key_path, std::fs::Permissions::from_mode(0o600));
|
||||||
|
}
|
||||||
|
|
||||||
|
let ca_cert_der = CertificateDer::from(ca_signed.der().to_vec());
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
ca_cert_der,
|
||||||
|
ca_key,
|
||||||
|
ca_signed,
|
||||||
|
domain_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
ca_pem_path: cert_path,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build the CA certificate parameters (reusable for both generate and load).
|
||||||
|
fn ca_params() -> CertificateParams {
|
||||||
|
let mut params = CertificateParams::default();
|
||||||
|
|
||||||
|
let mut dn = DistinguishedName::new();
|
||||||
|
dn.push(DnType::CommonName, "Antigravity MITM CA");
|
||||||
|
dn.push(DnType::OrganizationName, "Antigravity Proxy");
|
||||||
|
params.distinguished_name = dn;
|
||||||
|
|
||||||
|
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
|
||||||
|
params.key_usages = vec![
|
||||||
|
KeyUsagePurpose::KeyCertSign,
|
||||||
|
KeyUsagePurpose::CrlSign,
|
||||||
|
];
|
||||||
|
|
||||||
|
// Valid for 10 years
|
||||||
|
let now = time::OffsetDateTime::now_utc();
|
||||||
|
params.not_before = now;
|
||||||
|
params.not_after = now + time::Duration::days(3650);
|
||||||
|
|
||||||
|
params
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a PEM certificate into a DER-encoded CertificateDer.
|
||||||
|
fn pem_to_der(pem: &str) -> Option<CertificateDer<'static>> {
|
||||||
|
// Extract base64 content between BEGIN/END markers
|
||||||
|
let mut in_cert = false;
|
||||||
|
let mut b64 = String::new();
|
||||||
|
for line in pem.lines() {
|
||||||
|
if line.contains("BEGIN CERTIFICATE") {
|
||||||
|
in_cert = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if line.contains("END CERTIFICATE") {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if in_cert {
|
||||||
|
b64.push_str(line.trim());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if b64.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
use base64::Engine;
|
||||||
|
let der = base64::engine::general_purpose::STANDARD.decode(&b64).ok()?;
|
||||||
|
Some(CertificateDer::from(der))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get or create a TLS ServerConfig for the given domain.
|
||||||
|
pub async fn server_config_for_domain(&self, domain: &str) -> Result<Arc<rustls::ServerConfig>, String> {
|
||||||
|
// Check cache first
|
||||||
|
{
|
||||||
|
let cache = self.domain_cache.read().await;
|
||||||
|
if let Some(config) = cache.get(domain) {
|
||||||
|
return Ok(config.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate leaf cert for this domain
|
||||||
|
let mut params = CertificateParams::default();
|
||||||
|
|
||||||
|
let mut dn = DistinguishedName::new();
|
||||||
|
dn.push(DnType::CommonName, domain);
|
||||||
|
params.distinguished_name = dn;
|
||||||
|
|
||||||
|
params.subject_alt_names = vec![SanType::DnsName(domain.try_into().map_err(|e| format!("Invalid domain: {e}"))?)];
|
||||||
|
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
|
||||||
|
params.key_usages = vec![
|
||||||
|
KeyUsagePurpose::DigitalSignature,
|
||||||
|
KeyUsagePurpose::KeyEncipherment,
|
||||||
|
];
|
||||||
|
|
||||||
|
// Valid for 1 year
|
||||||
|
let now = time::OffsetDateTime::now_utc();
|
||||||
|
params.not_before = now;
|
||||||
|
params.not_after = now + time::Duration::days(365);
|
||||||
|
|
||||||
|
let leaf_key = KeyPair::generate()
|
||||||
|
.map_err(|e| format!("Failed to generate leaf key: {e}"))?;
|
||||||
|
|
||||||
|
let leaf_cert = params.signed_by(&leaf_key, &self.ca_signed, &self.ca_key)
|
||||||
|
.map_err(|e| format!("Failed to sign leaf cert for {domain}: {e}"))?;
|
||||||
|
|
||||||
|
// Build rustls ServerConfig
|
||||||
|
let leaf_cert_der = CertificateDer::from(leaf_cert.der().to_vec());
|
||||||
|
let leaf_key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(leaf_key.serialize_der()));
|
||||||
|
|
||||||
|
let mut config = rustls::ServerConfig::builder()
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_single_cert(
|
||||||
|
vec![leaf_cert_der, self.ca_cert_der.clone()],
|
||||||
|
leaf_key_der,
|
||||||
|
)
|
||||||
|
.map_err(|e| format!("Failed to build ServerConfig for {domain}: {e}"))?;
|
||||||
|
|
||||||
|
// Advertise both h2 and http/1.1 so gRPC clients can negotiate HTTP/2
|
||||||
|
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||||
|
|
||||||
|
let config = Arc::new(config);
|
||||||
|
|
||||||
|
// Cache it
|
||||||
|
{
|
||||||
|
let mut cache = self.domain_cache.write().await;
|
||||||
|
cache.insert(domain.to_string(), config.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
}
|
||||||
512
src/mitm/h2_handler.rs
Normal file
512
src/mitm/h2_handler.rs
Normal file
@@ -0,0 +1,512 @@
|
|||||||
|
//! HTTP/2 handler for gRPC traffic interception.
|
||||||
|
//!
|
||||||
|
//! When the LS negotiates HTTP/2 via ALPN (which all gRPC connections do),
|
||||||
|
//! this module handles the bidirectional HTTP/2 connection:
|
||||||
|
//! 1. Accepts HTTP/2 frames from the client (LS)
|
||||||
|
//! 2. Connects to the real upstream via TLS + HTTP/2 (single connection reused)
|
||||||
|
//! 3. Forwards each request stream to upstream
|
||||||
|
//! 4. For non-streaming: buffers response, extracts usage, forwards
|
||||||
|
//! 5. For streaming: forwards response body chunks in real-time, tees to a
|
||||||
|
//! side buffer for usage extraction after stream completes
|
||||||
|
//!
|
||||||
|
//! ## Streaming vs Non-streaming
|
||||||
|
//!
|
||||||
|
//! gRPC has both unary (non-streaming) and server-streaming RPCs.
|
||||||
|
//! The LS uses server-streaming for methods like `StreamGenerateContent`.
|
||||||
|
//! We MUST forward streaming responses immediately — buffering would break
|
||||||
|
//! the LS's perception of real-time generation.
|
||||||
|
//!
|
||||||
|
//! For usage extraction: ModelUsageStats is typically in the LAST message
|
||||||
|
//! of a streaming response, so we tee the data and parse after stream ends.
|
||||||
|
|
||||||
|
use crate::mitm::proto::parse_grpc_response_for_usage;
|
||||||
|
use crate::mitm::store::{ApiUsage, MitmStore};
|
||||||
|
|
||||||
|
use bytes::Bytes;
|
||||||
|
use http_body_util::{BodyExt, Full, StreamBody};
|
||||||
|
use hyper::body::{Frame, Incoming};
|
||||||
|
use hyper::server::conn::http2::Builder as H2ServerBuilder;
|
||||||
|
use hyper::service::service_fn;
|
||||||
|
use hyper::{Request, Response};
|
||||||
|
use hyper_util::rt::TokioExecutor;
|
||||||
|
use hyper_util::rt::TokioIo;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tracing::{debug, info, trace, warn};
|
||||||
|
|
||||||
|
/// A lazily-initialized, shared HTTP/2 connection to the upstream server.
|
||||||
|
///
|
||||||
|
/// gRPC multiplexes many requests over a single HTTP/2 connection.
|
||||||
|
/// We mirror this by maintaining a single upstream connection per domain.
|
||||||
|
struct UpstreamPool {
|
||||||
|
domain: String,
|
||||||
|
tls_config: Arc<rustls::ClientConfig>,
|
||||||
|
sender: Mutex<Option<hyper::client::conn::http2::SendRequest<Full<Bytes>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UpstreamPool {
|
||||||
|
fn new(domain: String, tls_config: Arc<rustls::ClientConfig>) -> Self {
|
||||||
|
Self {
|
||||||
|
domain,
|
||||||
|
tls_config,
|
||||||
|
sender: Mutex::new(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get or create the upstream HTTP/2 sender.
|
||||||
|
async fn get_sender(
|
||||||
|
&self,
|
||||||
|
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
|
||||||
|
let mut guard = self.sender.lock().await;
|
||||||
|
|
||||||
|
// Check if existing sender is still usable
|
||||||
|
if let Some(ref sender) = *guard {
|
||||||
|
if !sender.is_closed() {
|
||||||
|
return Ok(sender.clone());
|
||||||
|
}
|
||||||
|
debug!(domain = %self.domain, "MITM H2: upstream connection closed, reconnecting");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new connection
|
||||||
|
let sender = self.connect().await?;
|
||||||
|
*guard = Some(sender.clone());
|
||||||
|
Ok(sender)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect(
|
||||||
|
&self,
|
||||||
|
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
|
||||||
|
let upstream_tcp = TcpStream::connect(format!("{}:443", self.domain))
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("upstream TCP connect to {} failed: {e}", self.domain))?;
|
||||||
|
|
||||||
|
let connector = tokio_rustls::TlsConnector::from(self.tls_config.clone());
|
||||||
|
let server_name = rustls::pki_types::ServerName::try_from(self.domain.clone())
|
||||||
|
.map_err(|e| format!("invalid domain {}: {e}", self.domain))?;
|
||||||
|
|
||||||
|
let upstream_tls = connector
|
||||||
|
.connect(server_name, upstream_tcp)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?;
|
||||||
|
|
||||||
|
let upstream_io = TokioIo::new(upstream_tls);
|
||||||
|
let (sender, conn) =
|
||||||
|
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||||
|
.handshake(upstream_io)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
|
||||||
|
|
||||||
|
let domain = self.domain.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = conn.await {
|
||||||
|
debug!(domain = %domain, error = %e, "MITM H2: upstream connection driver ended");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
info!(domain = %self.domain, "MITM H2: established upstream HTTP/2 connection");
|
||||||
|
Ok(sender)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Names of the gRPC methods whose responses carry ModelUsageStats.
///
/// A request is treated as a usage method when its path contains any of
/// these names.
const USAGE_METHODS: &[&str] = &[
    // Unary (request/response) methods
    "GenerateContent",
    "AsyncGenerateContent",
    "GenerateChat",
    "GenerateCode",
    "CompleteCode",
    "InternalAtomicAgenticChat",
    "Predict",
    "DirectPredict",
    // Server-streaming methods
    "StreamGenerateContent",
    "StreamAsyncGenerateContent",
    "StreamGenerateChat",
];
|
||||||
|
|
||||||
|
/// Handle an HTTP/2 connection from the LS after TLS termination.
|
||||||
|
///
|
||||||
|
/// Uses hyper's HTTP/2 server to accept requests and a shared upstream
|
||||||
|
/// HTTP/2 connection to forward them.
|
||||||
|
pub async fn handle_h2_connection<S>(
|
||||||
|
tls_stream: S,
|
||||||
|
domain: String,
|
||||||
|
store: MitmStore,
|
||||||
|
) -> Result<(), String>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
info!(domain = %domain, "MITM H2: handling HTTP/2 connection");
|
||||||
|
|
||||||
|
// Build TLS config for upstream connections
|
||||||
|
let mut root_store = rustls::RootCertStore::empty();
|
||||||
|
let native_certs = rustls_native_certs::load_native_certs();
|
||||||
|
for cert in native_certs.certs {
|
||||||
|
let _ = root_store.add(cert);
|
||||||
|
}
|
||||||
|
let mut upstream_tls_config = rustls::ClientConfig::builder()
|
||||||
|
.with_root_certificates(root_store)
|
||||||
|
.with_no_client_auth();
|
||||||
|
upstream_tls_config.alpn_protocols = vec![b"h2".to_vec()];
|
||||||
|
|
||||||
|
// Shared upstream connection pool (single connection, multiplexed)
|
||||||
|
let pool = Arc::new(UpstreamPool::new(
|
||||||
|
domain.clone(),
|
||||||
|
Arc::new(upstream_tls_config),
|
||||||
|
));
|
||||||
|
|
||||||
|
let io = TokioIo::new(tls_stream);
|
||||||
|
let domain = Arc::new(domain);
|
||||||
|
|
||||||
|
let result = H2ServerBuilder::new(TokioExecutor::new())
|
||||||
|
.serve_connection(
|
||||||
|
io,
|
||||||
|
service_fn(move |req: Request<Incoming>| {
|
||||||
|
let domain = domain.clone();
|
||||||
|
let store = store.clone();
|
||||||
|
let pool = pool.clone();
|
||||||
|
async move { handle_h2_request(req, &domain, store, pool).await }
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(()) => {
|
||||||
|
debug!("MITM H2: connection closed cleanly");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Connection errors are expected on clean close
|
||||||
|
debug!(error = %e, "MITM H2: connection ended");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response body type — either buffered or streaming.
|
||||||
|
type BoxBody = http_body_util::Either<
|
||||||
|
Full<Bytes>,
|
||||||
|
StreamBody<tokio_stream::wrappers::ReceiverStream<Result<Frame<Bytes>, hyper::Error>>>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
/// Handle a single HTTP/2 request: forward to upstream, capture usage.
|
||||||
|
///
|
||||||
|
/// For streaming responses, forwards chunks in real-time while teeing
|
||||||
|
/// data to a side buffer for post-stream usage extraction.
|
||||||
|
async fn handle_h2_request(
|
||||||
|
req: Request<Incoming>,
|
||||||
|
domain: &str,
|
||||||
|
store: MitmStore,
|
||||||
|
pool: Arc<UpstreamPool>,
|
||||||
|
) -> Result<Response<BoxBody>, hyper::Error> {
|
||||||
|
let method = req.method().clone();
|
||||||
|
let uri = req.uri().clone();
|
||||||
|
let path = uri.path().to_string();
|
||||||
|
|
||||||
|
// Identify gRPC method
|
||||||
|
let is_grpc = req
|
||||||
|
.headers()
|
||||||
|
.get("content-type")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(|ct| ct.starts_with("application/grpc"))
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
// Check if this method carries usage data
|
||||||
|
let is_usage_method = is_grpc
|
||||||
|
&& USAGE_METHODS.iter().any(|m| path.contains(m));
|
||||||
|
|
||||||
|
// Check if this is a streaming method
|
||||||
|
let is_streaming = is_grpc
|
||||||
|
&& (path.contains("Stream") || path.contains("stream"));
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
domain,
|
||||||
|
%method,
|
||||||
|
path = %path,
|
||||||
|
grpc = is_grpc,
|
||||||
|
usage_method = is_usage_method,
|
||||||
|
streaming = is_streaming,
|
||||||
|
"MITM H2: forwarding request"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Collect request body (we need it for cascade ID extraction)
|
||||||
|
let (parts, body) = req.into_parts();
|
||||||
|
let request_body = match body.collect().await {
|
||||||
|
Ok(collected) => collected.to_bytes(),
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "MITM H2: failed to collect request body");
|
||||||
|
Bytes::new()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get upstream sender from pool
|
||||||
|
let mut upstream_sender = match pool.get_sender().await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, domain, "MITM H2: upstream connect failed");
|
||||||
|
let resp = Response::builder()
|
||||||
|
.status(502)
|
||||||
|
.body(http_body_util::Either::Left(Full::new(
|
||||||
|
Bytes::from(format!("upstream connect failed: {e}")),
|
||||||
|
)))
|
||||||
|
.unwrap();
|
||||||
|
return Ok(resp);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Build the upstream request with proper authority
|
||||||
|
let upstream_uri = http::Uri::builder()
|
||||||
|
.scheme("https")
|
||||||
|
.authority(domain)
|
||||||
|
.path_and_query(
|
||||||
|
uri.path_and_query()
|
||||||
|
.map(|pq| pq.as_str())
|
||||||
|
.unwrap_or("/"),
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
.unwrap_or(uri);
|
||||||
|
|
||||||
|
let mut upstream_req = Request::builder()
|
||||||
|
.method(parts.method)
|
||||||
|
.uri(upstream_uri);
|
||||||
|
|
||||||
|
// Copy headers, skip hop-by-hop
|
||||||
|
for (name, value) in &parts.headers {
|
||||||
|
let n = name.as_str();
|
||||||
|
if n == "host" || n == "connection" || n == "transfer-encoding" {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
upstream_req = upstream_req.header(name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
let upstream_req = match upstream_req.body(Full::new(request_body.clone())) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
let resp = Response::builder()
|
||||||
|
.status(502)
|
||||||
|
.body(http_body_util::Either::Left(Full::new(
|
||||||
|
Bytes::from(format!("build request failed: {e}")),
|
||||||
|
)))
|
||||||
|
.unwrap();
|
||||||
|
return Ok(resp);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Send to upstream
|
||||||
|
let upstream_resp = match upstream_sender.send_request(upstream_req).await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed");
|
||||||
|
let resp = Response::builder()
|
||||||
|
.status(502)
|
||||||
|
.body(http_body_util::Either::Left(Full::new(
|
||||||
|
Bytes::from(format!("upstream request failed: {e}")),
|
||||||
|
)))
|
||||||
|
.unwrap();
|
||||||
|
return Ok(resp);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (resp_parts, resp_body) = upstream_resp.into_parts();
|
||||||
|
let status = resp_parts.status;
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────────────
|
||||||
|
// Streaming path: forward chunks immediately, tee for usage parsing
|
||||||
|
// ──────────────────────────────────────────────────────────────────
|
||||||
|
if is_streaming && status.is_success() {
|
||||||
|
let should_track_usage = is_usage_method;
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, hyper::Error>>(32);
|
||||||
|
|
||||||
|
let store_clone = store.clone();
|
||||||
|
let path_clone = path.clone();
|
||||||
|
let request_body_clone = request_body.clone();
|
||||||
|
|
||||||
|
// Spawn a task to forward body chunks and tee for usage extraction
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut tee_buffer = if should_track_usage { Some(Vec::new()) } else { None };
|
||||||
|
let mut body = resp_body;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match body.frame().await {
|
||||||
|
Some(Ok(frame)) => {
|
||||||
|
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) {
|
||||||
|
buf.extend_from_slice(data);
|
||||||
|
}
|
||||||
|
if tx.send(Ok(frame)).await.is_err() {
|
||||||
|
break; // client disconnected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(Err(e)) => {
|
||||||
|
warn!(error = %e, path = %path_clone, "MITM H2: streaming error");
|
||||||
|
let _ = tx.send(Err(e)).await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
None => break, // stream ended
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream completed — parse the tee buffer for usage
|
||||||
|
if let Some(tee_buffer) = tee_buffer {
|
||||||
|
if !tee_buffer.is_empty() {
|
||||||
|
if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) {
|
||||||
|
let usage = ApiUsage {
|
||||||
|
input_tokens: grpc_usage.input_tokens,
|
||||||
|
output_tokens: grpc_usage.output_tokens,
|
||||||
|
thinking_output_tokens: grpc_usage.thinking_output_tokens,
|
||||||
|
response_output_tokens: grpc_usage.response_output_tokens,
|
||||||
|
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
|
||||||
|
cache_read_input_tokens: grpc_usage.cache_read_tokens,
|
||||||
|
model: grpc_usage.model,
|
||||||
|
api_provider: grpc_usage.api_provider,
|
||||||
|
grpc_method: Some(path_clone.clone()),
|
||||||
|
stop_reason: None,
|
||||||
|
total_cost_usd: None,
|
||||||
|
captured_at: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs(),
|
||||||
|
};
|
||||||
|
let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone);
|
||||||
|
store_clone.record_usage(cascade_hint.as_deref(), usage).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
|
||||||
|
let stream_body = StreamBody::new(stream);
|
||||||
|
|
||||||
|
let mut client_resp = Response::builder().status(resp_parts.status);
|
||||||
|
for (name, value) in &resp_parts.headers {
|
||||||
|
client_resp = client_resp.header(name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
let client_resp = client_resp
|
||||||
|
.body(http_body_util::Either::Right(stream_body))
|
||||||
|
.unwrap_or_else(|_| {
|
||||||
|
Response::builder()
|
||||||
|
.status(500)
|
||||||
|
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
||||||
|
"internal error",
|
||||||
|
))))
|
||||||
|
.unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
|
return Ok(client_resp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────────────
|
||||||
|
// Non-streaming path: buffer full response, extract usage, forward
|
||||||
|
// ──────────────────────────────────────────────────────────────────
|
||||||
|
let response_body = match resp_body.collect().await {
|
||||||
|
Ok(collected) => collected.to_bytes(),
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "MITM H2: failed to collect response body");
|
||||||
|
Bytes::new()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
trace!(
|
||||||
|
domain,
|
||||||
|
path = %path,
|
||||||
|
status = %status,
|
||||||
|
body_len = response_body.len(),
|
||||||
|
"MITM H2: got upstream response"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Extract usage data from usage-carrying gRPC methods
|
||||||
|
if is_usage_method && !response_body.is_empty() && status.is_success() {
|
||||||
|
if let Some(grpc_usage) = parse_grpc_response_for_usage(&response_body) {
|
||||||
|
let usage = ApiUsage {
|
||||||
|
input_tokens: grpc_usage.input_tokens,
|
||||||
|
output_tokens: grpc_usage.output_tokens,
|
||||||
|
thinking_output_tokens: grpc_usage.thinking_output_tokens,
|
||||||
|
response_output_tokens: grpc_usage.response_output_tokens,
|
||||||
|
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
|
||||||
|
cache_read_input_tokens: grpc_usage.cache_read_tokens,
|
||||||
|
model: grpc_usage.model,
|
||||||
|
api_provider: grpc_usage.api_provider,
|
||||||
|
grpc_method: Some(path.clone()),
|
||||||
|
stop_reason: None,
|
||||||
|
total_cost_usd: None,
|
||||||
|
captured_at: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let cascade_hint = extract_cascade_from_grpc_request(&request_body);
|
||||||
|
store.record_usage(cascade_hint.as_deref(), usage).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build response for the client
|
||||||
|
let mut client_resp = Response::builder().status(resp_parts.status);
|
||||||
|
for (name, value) in &resp_parts.headers {
|
||||||
|
client_resp = client_resp.header(name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
let client_resp = client_resp
|
||||||
|
.body(http_body_util::Either::Left(Full::new(response_body)))
|
||||||
|
.unwrap_or_else(|_| {
|
||||||
|
Response::builder()
|
||||||
|
.status(500)
|
||||||
|
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
||||||
|
"internal error",
|
||||||
|
))))
|
||||||
|
.unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(client_resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to extract a cascade ID from a gRPC request body.
|
||||||
|
///
|
||||||
|
/// Looks for UUID-formatted strings in the protobuf fields.
|
||||||
|
fn extract_cascade_from_grpc_request(body: &[u8]) -> Option<String> {
|
||||||
|
use crate::mitm::proto::{decode_proto, extract_grpc_messages};
|
||||||
|
|
||||||
|
let messages = extract_grpc_messages(body);
|
||||||
|
for msg in &messages {
|
||||||
|
let fields = decode_proto(msg);
|
||||||
|
for field in &fields {
|
||||||
|
if let Some(id) = extract_uuid_from_field(field) {
|
||||||
|
return Some(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_uuid_from_field(field: &crate::mitm::proto::ProtoField) -> Option<String> {
|
||||||
|
use crate::mitm::proto::ProtoValue;
|
||||||
|
|
||||||
|
match &field.value {
|
||||||
|
ProtoValue::Bytes(b) => {
|
||||||
|
if let Ok(s) = std::str::from_utf8(b) {
|
||||||
|
if is_uuid(s) {
|
||||||
|
return Some(s.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ProtoValue::Message(nested) => {
|
||||||
|
for nf in nested {
|
||||||
|
if let Some(id) = extract_uuid_from_field(nf) {
|
||||||
|
return Some(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return `true` when `s` is a UUID in canonical textual form:
/// `xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx` (36 ASCII characters, hex digits of
/// either case, hyphens at byte offsets 8, 13, 18 and 23).
///
/// Checking the hyphen *positions* (not merely their count) rejects
/// look-alikes such as `"----"` followed by 32 hex digits.
fn is_uuid(s: &str) -> bool {
    const HYPHEN_POSITIONS: [usize; 4] = [8, 13, 18, 23];

    let bytes = s.as_bytes();
    bytes.len() == 36
        && bytes.iter().enumerate().all(|(index, &byte)| {
            if HYPHEN_POSITIONS.contains(&index) {
                byte == b'-'
            } else {
                byte.is_ascii_hexdigit()
            }
        })
}
|
||||||
271
src/mitm/intercept.rs
Normal file
271
src/mitm/intercept.rs
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
//! API response interceptor: parses Anthropic/Google API responses to extract usage data.
|
||||||
|
//!
|
||||||
|
//! Handles both streaming (SSE) and non-streaming (JSON) responses.
|
||||||
|
|
||||||
|
use super::store::ApiUsage;
|
||||||
|
use serde_json::Value;
|
||||||
|
use tracing::{debug, trace};
|
||||||
|
|
||||||
|
/// Parse a complete (non-streaming) Anthropic Messages API response body.
|
||||||
|
///
|
||||||
|
/// Response format:
|
||||||
|
/// ```json
|
||||||
|
/// {
|
||||||
|
/// "id": "msg_...",
|
||||||
|
/// "type": "message",
|
||||||
|
/// "model": "claude-sonnet-4-20250514",
|
||||||
|
/// "usage": {
|
||||||
|
/// "input_tokens": 1234,
|
||||||
|
/// "output_tokens": 567,
|
||||||
|
/// "cache_creation_input_tokens": 0,
|
||||||
|
/// "cache_read_input_tokens": 890
|
||||||
|
/// },
|
||||||
|
/// "stop_reason": "end_turn"
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub fn parse_non_streaming_response(body: &[u8]) -> Option<ApiUsage> {
|
||||||
|
let json: Value = serde_json::from_slice(body).ok()?;
|
||||||
|
extract_usage_from_message(&json)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse SSE events from a streaming Anthropic response body chunk.
|
||||||
|
///
|
||||||
|
/// Events of interest:
|
||||||
|
/// - `message_start` — contains `message.usage.input_tokens` + cache tokens
|
||||||
|
/// - `message_delta` — contains `usage.output_tokens`
|
||||||
|
/// - `message_stop` — marks end (no usage data)
|
||||||
|
///
|
||||||
|
/// Returns accumulated usage across all events in this chunk.
|
||||||
|
pub fn parse_streaming_chunk(chunk: &str, accumulator: &mut StreamingAccumulator) {
|
||||||
|
for line in chunk.lines() {
|
||||||
|
if let Some(data) = line.strip_prefix("data: ") {
|
||||||
|
if data.trim() == "[DONE]" {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if let Ok(event) = serde_json::from_str::<Value>(data) {
|
||||||
|
accumulator.process_event(&event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Running totals gathered from the SSE events of one streamed response.
///
/// All counters start at zero and the optional fields at `None` (via
/// `Default`).
#[derive(Debug, Default)]
pub struct StreamingAccumulator {
    /// Input token count taken from `message_start`.
    pub input_tokens: u64,
    /// Latest output token count reported by `message_delta`.
    pub output_tokens: u64,
    /// Cache-creation input tokens taken from `message_start`.
    pub cache_creation_input_tokens: u64,
    /// Cache-read input tokens taken from `message_start`.
    pub cache_read_input_tokens: u64,
    /// Model name taken from `message_start`, if present.
    pub model: Option<String>,
    /// Stop reason taken from `message_delta`, if present.
    pub stop_reason: Option<String>,
    /// Set once a `message_stop` event has been seen.
    pub is_complete: bool,
}
|
||||||
|
|
||||||
|
impl StreamingAccumulator {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process a single SSE event.
|
||||||
|
pub fn process_event(&mut self, event: &Value) {
|
||||||
|
let event_type = event["type"].as_str().unwrap_or("");
|
||||||
|
|
||||||
|
match event_type {
|
||||||
|
"message_start" => {
|
||||||
|
// message_start contains the initial usage (input tokens + cache)
|
||||||
|
if let Some(usage) = event.get("message").and_then(|m| m.get("usage")) {
|
||||||
|
self.input_tokens = usage["input_tokens"].as_u64().unwrap_or(0);
|
||||||
|
self.cache_creation_input_tokens = usage["cache_creation_input_tokens"].as_u64().unwrap_or(0);
|
||||||
|
self.cache_read_input_tokens = usage["cache_read_input_tokens"].as_u64().unwrap_or(0);
|
||||||
|
}
|
||||||
|
if let Some(model) = event.get("message").and_then(|m| m["model"].as_str()) {
|
||||||
|
self.model = Some(model.to_string());
|
||||||
|
}
|
||||||
|
trace!(
|
||||||
|
input = self.input_tokens,
|
||||||
|
cache_read = self.cache_read_input_tokens,
|
||||||
|
cache_create = self.cache_creation_input_tokens,
|
||||||
|
"SSE message_start: captured input usage"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
"message_delta" => {
|
||||||
|
// message_delta contains the output usage
|
||||||
|
if let Some(usage) = event.get("usage") {
|
||||||
|
self.output_tokens = usage["output_tokens"].as_u64().unwrap_or(self.output_tokens);
|
||||||
|
}
|
||||||
|
if let Some(reason) = event["delta"]["stop_reason"].as_str() {
|
||||||
|
self.stop_reason = Some(reason.to_string());
|
||||||
|
}
|
||||||
|
trace!(output = self.output_tokens, "SSE message_delta: updated output tokens");
|
||||||
|
}
|
||||||
|
"message_stop" => {
|
||||||
|
self.is_complete = true;
|
||||||
|
debug!(
|
||||||
|
input = self.input_tokens,
|
||||||
|
output = self.output_tokens,
|
||||||
|
cache_read = self.cache_read_input_tokens,
|
||||||
|
model = ?self.model,
|
||||||
|
"SSE message_stop: stream complete"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
"content_block_start" | "content_block_delta" | "content_block_stop" | "ping" => {
|
||||||
|
// Content events — no usage data, just pass through
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
trace!(event_type, "SSE: unknown event type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert accumulated data to an ApiUsage.
|
||||||
|
pub fn into_usage(self) -> ApiUsage {
|
||||||
|
ApiUsage {
|
||||||
|
input_tokens: self.input_tokens,
|
||||||
|
output_tokens: self.output_tokens,
|
||||||
|
cache_creation_input_tokens: self.cache_creation_input_tokens,
|
||||||
|
cache_read_input_tokens: self.cache_read_input_tokens,
|
||||||
|
thinking_output_tokens: 0,
|
||||||
|
response_output_tokens: 0,
|
||||||
|
total_cost_usd: None,
|
||||||
|
model: self.model,
|
||||||
|
stop_reason: self.stop_reason,
|
||||||
|
api_provider: Some("anthropic".to_string()),
|
||||||
|
grpc_method: None,
|
||||||
|
captured_at: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract usage from a complete Message JSON object.
|
||||||
|
fn extract_usage_from_message(msg: &Value) -> Option<ApiUsage> {
|
||||||
|
let usage = msg.get("usage")?;
|
||||||
|
|
||||||
|
Some(ApiUsage {
|
||||||
|
input_tokens: usage["input_tokens"].as_u64().unwrap_or(0),
|
||||||
|
output_tokens: usage["output_tokens"].as_u64().unwrap_or(0),
|
||||||
|
cache_creation_input_tokens: usage["cache_creation_input_tokens"].as_u64().unwrap_or(0),
|
||||||
|
cache_read_input_tokens: usage["cache_read_input_tokens"].as_u64().unwrap_or(0),
|
||||||
|
thinking_output_tokens: 0,
|
||||||
|
response_output_tokens: 0,
|
||||||
|
total_cost_usd: None,
|
||||||
|
model: msg["model"].as_str().map(|s| s.to_string()),
|
||||||
|
stop_reason: msg["stop_reason"].as_str().map(|s| s.to_string()),
|
||||||
|
api_provider: Some("anthropic".to_string()),
|
||||||
|
grpc_method: None,
|
||||||
|
captured_at: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to identify a cascade ID from the request body.
|
||||||
|
///
|
||||||
|
/// The LS includes cascade-related metadata in its API requests (as part of
|
||||||
|
/// the system prompt or metadata field). We try to find it.
|
||||||
|
pub fn extract_cascade_hint(request_body: &[u8]) -> Option<String> {
|
||||||
|
let json: Value = serde_json::from_slice(request_body).ok()?;
|
||||||
|
|
||||||
|
// Check for metadata field (some API configurations include it)
|
||||||
|
if let Some(metadata) = json.get("metadata") {
|
||||||
|
if let Some(user_id) = metadata["user_id"].as_str() {
|
||||||
|
// The LS often sets user_id to the cascadeId
|
||||||
|
return Some(user_id.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check system prompt for cascade/workspace markers
|
||||||
|
if let Some(system) = json.get("system") {
|
||||||
|
let system_str = match system {
|
||||||
|
Value::String(s) => s.clone(),
|
||||||
|
Value::Array(arr) => {
|
||||||
|
// Array of content blocks
|
||||||
|
arr.iter()
|
||||||
|
.filter_map(|b| b["text"].as_str())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" ")
|
||||||
|
}
|
||||||
|
_ => return None,
|
||||||
|
};
|
||||||
|
// Look for workspace_id or cascade_id patterns
|
||||||
|
if let Some(pos) = system_str.find("workspace_id") {
|
||||||
|
let rest = &system_str[pos..];
|
||||||
|
// Extract the value after workspace_id
|
||||||
|
if let Some(val) = rest.split_whitespace().nth(1) {
|
||||||
|
return Some(val.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A buffered Messages API body yields every usage counter and the model.
    #[test]
    fn test_parse_non_streaming() {
        let body = r#"{
            "id": "msg_123",
            "type": "message",
            "model": "claude-sonnet-4-20250514",
            "usage": {
                "input_tokens": 100,
                "output_tokens": 50,
                "cache_creation_input_tokens": 10,
                "cache_read_input_tokens": 30
            },
            "stop_reason": "end_turn"
        }"#;

        let parsed = parse_non_streaming_response(body.as_bytes()).unwrap();

        assert_eq!(parsed.input_tokens, 100);
        assert_eq!(parsed.output_tokens, 50);
        assert_eq!(parsed.cache_creation_input_tokens, 10);
        assert_eq!(parsed.cache_read_input_tokens, 30);
        assert_eq!(parsed.model.as_deref(), Some("claude-sonnet-4-20250514"));
    }

    /// start → delta → stop drives the accumulator through a full stream.
    #[test]
    fn test_streaming_accumulator() {
        let mut acc = StreamingAccumulator::new();

        // message_start carries the input-side counters.
        acc.process_event(&serde_json::json!({
            "type": "message_start",
            "message": {
                "model": "claude-sonnet-4-20250514",
                "usage": {
                    "input_tokens": 200,
                    "cache_creation_input_tokens": 5,
                    "cache_read_input_tokens": 50
                }
            }
        }));
        assert_eq!(acc.input_tokens, 200);
        assert_eq!(acc.cache_read_input_tokens, 50);

        // message_delta carries the output count.
        acc.process_event(&serde_json::json!({
            "type": "message_delta",
            "delta": { "stop_reason": "end_turn" },
            "usage": { "output_tokens": 75 }
        }));
        assert_eq!(acc.output_tokens, 75);

        // message_stop marks completion.
        acc.process_event(&serde_json::json!({ "type": "message_stop" }));
        assert!(acc.is_complete);

        let totals = acc.into_usage();
        assert_eq!(totals.input_tokens, 200);
        assert_eq!(totals.output_tokens, 75);
    }
}
|
||||||
19
src/mitm/mod.rs
Normal file
19
src/mitm/mod.rs
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
//! MITM proxy module: intercepts LS ↔ Google/Anthropic API traffic.
|
||||||
|
//!
|
||||||
|
//! The LS (Go binary with BoringCrypto) respects `HTTPS_PROXY` and `SSL_CERT_FILE`.
|
||||||
|
//! By setting these env vars via the wrapper script, we route all outbound HTTPS
|
||||||
|
//! traffic through our local MITM proxy, which:
|
||||||
|
//!
|
||||||
|
//! 1. Terminates TLS using dynamically-generated per-domain certificates
|
||||||
|
//! 2. Detects protocol: HTTP/1.1 (REST) or HTTP/2 (gRPC)
|
||||||
|
//! 3. For HTTP/1.1: parses JSON/SSE responses (Anthropic format)
|
||||||
|
//! 4. For HTTP/2: decodes gRPC protobuf responses (Google format)
|
||||||
|
//! 5. Captures token usage data (input, output, thinking, cache)
|
||||||
|
//! 6. Forwards everything transparently to real upstream servers
|
||||||
|
|
||||||
|
pub mod ca;
|
||||||
|
pub mod h2_handler;
|
||||||
|
pub mod intercept;
|
||||||
|
pub mod proto;
|
||||||
|
pub mod proxy;
|
||||||
|
pub mod store;
|
||||||
584
src/mitm/proto.rs
Normal file
584
src/mitm/proto.rs
Normal file
@@ -0,0 +1,584 @@
|
|||||||
|
//! Raw protobuf decoder for extracting ModelUsageStats from gRPC responses.
|
||||||
|
//!
|
||||||
|
//! We don't have the .proto schema, so we decode protobuf messages generically
|
||||||
|
//! and search for usage-like structures by matching field patterns.
|
||||||
|
//!
|
||||||
|
//! gRPC wire format:
|
||||||
|
//! - 1 byte: compression flag (0 = uncompressed, 1 = compressed)
|
||||||
|
//! - 4 bytes: message length (big-endian u32)
|
||||||
|
//! - N bytes: protobuf message
|
||||||
|
//!
|
||||||
|
//! Protobuf wire format:
|
||||||
|
//! - Each field: (field_number << 3 | wire_type) as varint, then value
|
||||||
|
//! - Wire type 0: varint
|
||||||
|
//! - Wire type 1: 64-bit fixed
|
||||||
|
//! - Wire type 2: length-delimited (string, bytes, embedded message)
|
||||||
|
//! - Wire type 5: 32-bit fixed
|
||||||
|
//!
|
||||||
|
//! ## ModelUsageStats schema (reverse-engineered from LS binary):
|
||||||
|
//!
|
||||||
|
//! ```protobuf
|
||||||
|
//! message ModelUsageStats {
|
||||||
|
//! Model model = 1; // enum (varint)
|
||||||
|
//! uint64 input_tokens = 2;
|
||||||
|
//! uint64 output_tokens = 3;
|
||||||
|
//! uint64 cache_write_tokens = 4;
|
||||||
|
//! uint64 cache_read_tokens = 5;
|
||||||
|
//! APIProvider api_provider = 6; // enum (varint)
|
||||||
|
//! string message_id = 7;
|
||||||
|
//! map<string,string> response_header = 8; // repeated message
|
||||||
|
//! uint64 thinking_output_tokens = 9;
|
||||||
|
//! uint64 response_output_tokens = 10;
|
||||||
|
//! string response_id = 11;
|
||||||
|
//! }
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use flate2::read::GzDecoder;
|
||||||
|
use std::io::Read;
|
||||||
|
use tracing::{debug, trace, warn};
|
||||||
|
|
||||||
|
/// A decoded protobuf field.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ProtoValue {
|
||||||
|
Varint(u64),
|
||||||
|
#[allow(dead_code)]
|
||||||
|
Fixed64(u64),
|
||||||
|
#[allow(dead_code)]
|
||||||
|
Fixed32(u32),
|
||||||
|
Bytes(Vec<u8>),
|
||||||
|
/// Nested message (parsed recursively)
|
||||||
|
Message(Vec<ProtoField>),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single protobuf field with its number and value.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ProtoField {
|
||||||
|
pub number: u32,
|
||||||
|
pub value: ProtoValue,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Usage figures pulled out of a gRPC response.
///
/// Field names mirror the reverse-engineered `ModelUsageStats` message; every
/// counter defaults to 0 and every optional string to `None`.
#[derive(Debug, Default)]
pub struct GrpcUsage {
    /// Input (prompt) token count.
    pub input_tokens: u64,
    /// Output token count.
    pub output_tokens: u64,
    /// Thinking output token count.
    pub thinking_output_tokens: u64,
    /// Response output token count.
    pub response_output_tokens: u64,
    /// Cache-read token count.
    pub cache_read_tokens: u64,
    /// Cache-write token count.
    pub cache_write_tokens: u64,
    /// Model name, when one could be determined.
    pub model: Option<String>,
    /// API provider name, when one could be determined.
    pub api_provider: Option<String>,
    /// Message id, when present.
    pub message_id: Option<String>,
    /// Response id, when present.
    pub response_id: Option<String>,
}
|
||||||
|
|
||||||
|
/// Extract gRPC message frames from a buffer.
|
||||||
|
///
|
||||||
|
/// A gRPC message is:
|
||||||
|
/// [1 byte compressed flag] [4 bytes length BE] [N bytes protobuf]
|
||||||
|
///
|
||||||
|
/// Multiple messages can be concatenated in a single buffer.
|
||||||
|
/// If compressed flag is 1, the message is gzip-decompressed.
|
||||||
|
pub fn extract_grpc_messages(data: &[u8]) -> Vec<Vec<u8>> {
|
||||||
|
let mut messages = Vec::new();
|
||||||
|
let mut offset = 0;
|
||||||
|
|
||||||
|
while offset + 5 <= data.len() {
|
||||||
|
let compressed = data[offset];
|
||||||
|
let length = u32::from_be_bytes([
|
||||||
|
data[offset + 1],
|
||||||
|
data[offset + 2],
|
||||||
|
data[offset + 3],
|
||||||
|
data[offset + 4],
|
||||||
|
]) as usize;
|
||||||
|
|
||||||
|
offset += 5;
|
||||||
|
|
||||||
|
if offset + length > data.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let payload = &data[offset..offset + length];
|
||||||
|
|
||||||
|
if compressed == 1 {
|
||||||
|
// gzip-compressed frame
|
||||||
|
let mut decoder = GzDecoder::new(payload);
|
||||||
|
let mut decompressed = Vec::new();
|
||||||
|
match decoder.read_to_end(&mut decompressed) {
|
||||||
|
Ok(_) => messages.push(decompressed),
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "Proto: failed to decompress gRPC frame");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
messages.push(payload.to_vec());
|
||||||
|
}
|
||||||
|
|
||||||
|
offset += length;
|
||||||
|
}
|
||||||
|
|
||||||
|
messages
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode a protobuf message into a list of fields.
|
||||||
|
///
|
||||||
|
/// This is a best-effort decoder that handles the common wire types.
|
||||||
|
/// Embedded messages (wire type 2) are attempted to be parsed recursively.
|
||||||
|
pub fn decode_proto(data: &[u8]) -> Vec<ProtoField> {
|
||||||
|
let mut fields = Vec::new();
|
||||||
|
let mut offset = 0;
|
||||||
|
|
||||||
|
while offset < data.len() {
|
||||||
|
// Read tag (varint)
|
||||||
|
let (tag, bytes_read) = match read_varint(&data[offset..]) {
|
||||||
|
Some(v) => v,
|
||||||
|
None => break,
|
||||||
|
};
|
||||||
|
offset += bytes_read;
|
||||||
|
|
||||||
|
let field_number = (tag >> 3) as u32;
|
||||||
|
let wire_type = (tag & 0x07) as u8;
|
||||||
|
|
||||||
|
if field_number == 0 {
|
||||||
|
break; // invalid
|
||||||
|
}
|
||||||
|
|
||||||
|
let value = match wire_type {
|
||||||
|
0 => {
|
||||||
|
// Varint
|
||||||
|
let (val, bytes_read) = match read_varint(&data[offset..]) {
|
||||||
|
Some(v) => v,
|
||||||
|
None => break,
|
||||||
|
};
|
||||||
|
offset += bytes_read;
|
||||||
|
ProtoValue::Varint(val)
|
||||||
|
}
|
||||||
|
1 => {
|
||||||
|
// 64-bit fixed
|
||||||
|
if offset + 8 > data.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let val = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
|
||||||
|
offset += 8;
|
||||||
|
ProtoValue::Fixed64(val)
|
||||||
|
}
|
||||||
|
2 => {
|
||||||
|
// Length-delimited
|
||||||
|
let (len, bytes_read) = match read_varint(&data[offset..]) {
|
||||||
|
Some(v) => v,
|
||||||
|
None => break,
|
||||||
|
};
|
||||||
|
offset += bytes_read;
|
||||||
|
let len = len as usize;
|
||||||
|
|
||||||
|
if offset + len > data.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let payload = &data[offset..offset + len];
|
||||||
|
offset += len;
|
||||||
|
|
||||||
|
// Try to parse as a nested message
|
||||||
|
let nested = decode_proto(payload);
|
||||||
|
if !nested.is_empty() && looks_like_valid_message(&nested, payload.len()) {
|
||||||
|
ProtoValue::Message(nested)
|
||||||
|
} else {
|
||||||
|
ProtoValue::Bytes(payload.to_vec())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
5 => {
|
||||||
|
// 32-bit fixed
|
||||||
|
if offset + 4 > data.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let val = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap());
|
||||||
|
offset += 4;
|
||||||
|
ProtoValue::Fixed32(val)
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Unknown wire type — stop parsing
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
fields.push(ProtoField {
|
||||||
|
number: field_number,
|
||||||
|
value,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fields
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Heuristic: does this list of fields look like a valid protobuf message?
|
||||||
|
/// (vs. a random string that happened to partially decode)
|
||||||
|
fn looks_like_valid_message(fields: &[ProtoField], original_len: usize) -> bool {
|
||||||
|
if fields.is_empty() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that field numbers are reasonable (< 10000)
|
||||||
|
let valid_numbers = fields.iter().all(|f| f.number < 10000);
|
||||||
|
if !valid_numbers {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have very few fields relative to the data size, it's probably not a message
|
||||||
|
// (e.g., a long string that happened to have a valid first-field prefix)
|
||||||
|
if fields.len() == 1 && original_len > 100 {
|
||||||
|
// Single-field messages of >100 bytes are suspicious unless the field is bytes/message
|
||||||
|
match &fields[0].value {
|
||||||
|
ProtoValue::Bytes(_) | ProtoValue::Message(_) => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read a base-128 varint from the start of `data`.
///
/// Returns `Some((value, bytes_consumed))` on success. Returns `None` when:
/// - `data` is empty or ends before a byte without the continuation bit,
/// - the encoding is longer than 10 bytes, or
/// - the 10th byte carries bits that do not fit in a `u64` (only bit 63 is
///   left at that point, so the byte's payload must be 0 or 1). Previously
///   such input was silently truncated to a wrong value.
pub fn read_varint(data: &[u8]) -> Option<(u64, usize)> {
    const MAX_VARINT_LEN: usize = 10;

    let mut result: u64 = 0;

    for (index, &byte) in data.iter().take(MAX_VARINT_LEN).enumerate() {
        let payload = u64::from(byte & 0x7F);

        // 9 bytes * 7 bits = 63 bits already used; the last byte may only
        // contribute the final bit.
        if index == MAX_VARINT_LEN - 1 && payload > 1 {
            return None;
        }

        result |= payload << (7 * index);

        if byte & 0x80 == 0 {
            return Some((result, index + 1));
        }
    }

    // Ran out of input, or all ten bytes had the continuation bit set.
    None
}
|
||||||
|
|
||||||
|
/// Search a decoded protobuf message tree for usage-like structures.
|
||||||
|
///
|
||||||
|
/// Uses the exact field numbers from the reverse-engineered ModelUsageStats schema:
|
||||||
|
///
|
||||||
|
/// field 1: model (enum/varint)
|
||||||
|
/// field 2: input_tokens (uint64)
|
||||||
|
/// field 3: output_tokens (uint64)
|
||||||
|
/// field 4: cache_write_tokens (uint64)
|
||||||
|
/// field 5: cache_read_tokens (uint64)
|
||||||
|
/// field 6: api_provider (enum/varint)
|
||||||
|
/// field 7: message_id (string)
|
||||||
|
/// field 8: response_header (map, repeated message)
|
||||||
|
/// field 9: thinking_output_tokens (uint64)
|
||||||
|
/// field 10: response_output_tokens (uint64)
|
||||||
|
/// field 11: response_id (string)
|
||||||
|
pub fn extract_usage_from_proto(fields: &[ProtoField]) -> Option<GrpcUsage> {
|
||||||
|
// Strategy: recursively search for any sub-message that looks like usage data
|
||||||
|
// Try this level first
|
||||||
|
if let Some(usage) = try_extract_usage(fields) {
|
||||||
|
return Some(usage);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recurse into nested messages
|
||||||
|
for field in fields {
|
||||||
|
if let ProtoValue::Message(ref nested) = field.value {
|
||||||
|
if let Some(usage) = extract_usage_from_proto(nested) {
|
||||||
|
return Some(usage);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to extract usage from this specific set of fields.
|
||||||
|
///
|
||||||
|
/// Uses verified field numbers from the binary's embedded proto descriptor.
|
||||||
|
fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
|
||||||
|
// We need:
|
||||||
|
// - At least 2 varint fields with values in token range
|
||||||
|
// - Ideally field 2 (input_tokens) or field 3 (output_tokens) present
|
||||||
|
let varint_fields: Vec<_> = fields
|
||||||
|
.iter()
|
||||||
|
.filter(|f| matches!(f.value, ProtoValue::Varint(_)))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let string_fields: Vec<_> = fields
|
||||||
|
.iter()
|
||||||
|
.filter_map(|f| {
|
||||||
|
if let ProtoValue::Bytes(ref b) = f.value {
|
||||||
|
std::str::from_utf8(b).ok().map(|s| (f.number, s.to_string()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Need at least 2 varint fields to be a candidate
|
||||||
|
if varint_fields.len() < 2 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the varint values make sense as token counts
|
||||||
|
let plausible_token_count = |v: u64| v <= 10_000_000;
|
||||||
|
let plausible_varints = varint_fields
|
||||||
|
.iter()
|
||||||
|
.filter(|f| {
|
||||||
|
if let ProtoValue::Varint(v) = f.value {
|
||||||
|
plausible_token_count(v) && v > 0
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.count();
|
||||||
|
|
||||||
|
// Need at least 2 non-zero plausible values
|
||||||
|
if plausible_varints < 2 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if there's a model-like string (field 7 = message_id or field 11 = response_id
|
||||||
|
// can contain model names, or model enum values map to known names)
|
||||||
|
let has_model_string = string_fields.iter().any(|(_, s)| {
|
||||||
|
s.contains("claude") || s.contains("gemini") || s.contains("gpt")
|
||||||
|
|| s.starts_with("models/") || s.contains("sonnet") || s.contains("opus")
|
||||||
|
|| s.contains("flash") || s.contains("pro")
|
||||||
|
});
|
||||||
|
|
||||||
|
// Check for fields at the known ModelUsageStats field numbers
|
||||||
|
let has_field_2 = fields.iter().any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_)));
|
||||||
|
let has_field_3 = fields.iter().any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_)));
|
||||||
|
|
||||||
|
// Strong signal: has both input and output token fields
|
||||||
|
let is_likely_usage = (has_field_2 && has_field_3) || has_model_string;
|
||||||
|
|
||||||
|
if !is_likely_usage && varint_fields.len() < 3 {
|
||||||
|
// Without strong signal, need more fields
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build usage from exact field numbers (verified from binary)
|
||||||
|
let mut usage = GrpcUsage::default();
|
||||||
|
|
||||||
|
for field in fields {
|
||||||
|
match &field.value {
|
||||||
|
ProtoValue::Varint(v) => {
|
||||||
|
let v = *v;
|
||||||
|
if !plausible_token_count(v) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match field.number {
|
||||||
|
// field 1 = model enum (varint, not string!)
|
||||||
|
2 => usage.input_tokens = v,
|
||||||
|
3 => usage.output_tokens = v,
|
||||||
|
4 => usage.cache_write_tokens = v, // VERIFIED: field 4
|
||||||
|
5 => usage.cache_read_tokens = v, // VERIFIED: field 5
|
||||||
|
// field 6 = api_provider enum (varint)
|
||||||
|
9 => usage.thinking_output_tokens = v, // VERIFIED: field 9
|
||||||
|
10 => usage.response_output_tokens = v, // VERIFIED: field 10
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ProtoValue::Bytes(ref b) => {
|
||||||
|
if let Ok(s) = std::str::from_utf8(b) {
|
||||||
|
match field.number {
|
||||||
|
7 => usage.message_id = Some(s.to_string()),
|
||||||
|
11 => usage.response_id = Some(s.to_string()),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model and api_provider are enums (varints), not strings
|
||||||
|
// We can map known enum values later if needed
|
||||||
|
// For now, extract the enum value as a string representation
|
||||||
|
for field in fields {
|
||||||
|
if let ProtoValue::Varint(v) = &field.value {
|
||||||
|
match field.number {
|
||||||
|
1 => {
|
||||||
|
// Model enum — we don't have the mapping, store as number
|
||||||
|
usage.model = Some(format!("model_enum_{v}"));
|
||||||
|
}
|
||||||
|
6 => {
|
||||||
|
// APIProvider enum
|
||||||
|
usage.api_provider = Some(match *v {
|
||||||
|
0 => "unknown".to_string(),
|
||||||
|
1 => "google".to_string(),
|
||||||
|
2 => "anthropic".to_string(),
|
||||||
|
_ => format!("provider_{v}"),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate — we should have at least input OR output tokens
|
||||||
|
if usage.input_tokens == 0 && usage.output_tokens == 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
input = usage.input_tokens,
|
||||||
|
output = usage.output_tokens,
|
||||||
|
thinking = usage.thinking_output_tokens,
|
||||||
|
response = usage.response_output_tokens,
|
||||||
|
cache_read = usage.cache_read_tokens,
|
||||||
|
cache_write = usage.cache_write_tokens,
|
||||||
|
model = ?usage.model,
|
||||||
|
api_provider = ?usage.api_provider,
|
||||||
|
"Proto: extracted ModelUsageStats from protobuf"
|
||||||
|
);
|
||||||
|
|
||||||
|
Some(usage)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a gRPC response body (may contain multiple messages) for usage data.
|
||||||
|
///
|
||||||
|
/// Handles both compressed and uncompressed gRPC frames.
|
||||||
|
pub fn parse_grpc_response_for_usage(body: &[u8]) -> Option<GrpcUsage> {
|
||||||
|
let messages = extract_grpc_messages(body);
|
||||||
|
|
||||||
|
trace!(count = messages.len(), "Proto: extracted gRPC messages");
|
||||||
|
|
||||||
|
// Check each message for usage data (last message usually has it)
|
||||||
|
for msg in messages.iter().rev() {
|
||||||
|
let fields = decode_proto(msg);
|
||||||
|
if let Some(usage) = extract_usage_from_proto(&fields) {
|
||||||
|
return Some(usage);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample message: field 1 = varint 150, field 2 = varint 66.
    const SAMPLE_MESSAGE: [u8; 5] = [0x08, 0x96, 0x01, 0x10, 0x42];

    /// Append `value` in protobuf base-128 varint encoding.
    fn push_varint(out: &mut Vec<u8>, mut value: u64) {
        while value >= 0x80 {
            out.push((value as u8) | 0x80);
            value >>= 7;
        }
        out.push(value as u8);
    }

    /// Append a wire-type-0 (varint) field: tag followed by value.
    fn push_varint_field(out: &mut Vec<u8>, field_num: u32, value: u64) {
        push_varint(out, u64::from(field_num << 3));
        push_varint(out, value);
    }

    #[test]
    fn test_read_varint() {
        let cases = [
            (&[0x00][..], 0, 1),
            (&[0x01][..], 1, 1),
            (&[0x96, 0x01][..], 150, 2),
            (&[0xAC, 0x02][..], 300, 2),
        ];
        for (bytes, value, consumed) in cases {
            assert_eq!(read_varint(bytes), Some((value, consumed)));
        }
    }

    #[test]
    fn test_extract_grpc_messages_uncompressed() {
        // Frame layout: [flag = 0] [big-endian u32 length = 5] [5 payload bytes]
        let mut frame = vec![0u8];
        frame.extend_from_slice(&5u32.to_be_bytes());
        frame.extend_from_slice(&SAMPLE_MESSAGE);

        let messages = extract_grpc_messages(&frame);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].len(), 5);
    }

    #[test]
    fn test_extract_grpc_messages_compressed() {
        use flate2::write::GzEncoder;
        use flate2::Compression;
        use std::io::Write;

        let payload = SAMPLE_MESSAGE.to_vec();

        // Gzip the payload.
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(&payload).unwrap();
        let compressed = encoder.finish().unwrap();

        // Frame layout: [flag = 1] [big-endian u32 length] [gzip bytes]
        let mut frame = vec![1u8];
        frame.extend_from_slice(&(compressed.len() as u32).to_be_bytes());
        frame.extend_from_slice(&compressed);

        let messages = extract_grpc_messages(&frame);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0], payload);
    }

    #[test]
    fn test_decode_proto_varints() {
        let fields = decode_proto(&SAMPLE_MESSAGE);

        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].number, 1);
        assert!(matches!(fields[0].value, ProtoValue::Varint(150)));
        assert_eq!(fields[1].number, 2);
        assert!(matches!(fields[1].value, ProtoValue::Varint(66)));
    }

    #[test]
    fn test_decode_proto_with_string() {
        // field 1 = "hello" (length-delimited), field 2 = varint 42
        let mut data = vec![
            0x0A, // tag: (1 << 3) | 2
            0x05, // length 5
        ];
        data.extend_from_slice(b"hello");
        data.extend_from_slice(&[
            0x10, // tag: (2 << 3) | 0
            0x2A, // 42
        ]);

        let fields = decode_proto(&data);
        assert!(fields.len() >= 2);
        assert_eq!(fields[0].number, 1);
    }

    #[test]
    fn test_extract_usage_correct_field_numbers() {
        // Mock ModelUsageStats using the verified field numbers.
        let mut data = Vec::new();
        for (field_num, value) in [
            (1, 5),     // model enum
            (2, 1000),  // input_tokens
            (3, 500),   // output_tokens
            (4, 100),   // cache_write_tokens
            (5, 200),   // cache_read_tokens
            (9, 300),   // thinking_output_tokens
            (10, 200),  // response_output_tokens
        ] {
            push_varint_field(&mut data, field_num, value);
        }

        let fields = decode_proto(&data);
        let usage = try_extract_usage(&fields).expect("should extract usage");

        assert_eq!(usage.input_tokens, 1000);
        assert_eq!(usage.output_tokens, 500);
        assert_eq!(usage.cache_write_tokens, 100);
        assert_eq!(usage.cache_read_tokens, 200);
        assert_eq!(usage.thinking_output_tokens, 300);
        assert_eq!(usage.response_output_tokens, 200);
    }
}
|
||||||
591
src/mitm/proxy.rs
Normal file
591
src/mitm/proxy.rs
Normal file
@@ -0,0 +1,591 @@
|
|||||||
|
//! MITM proxy server: handles CONNECT tunnels and TLS interception.
|
||||||
|
//!
|
||||||
|
//! Listens on a local port for HTTP CONNECT requests from the LS.
|
||||||
|
//! For intercepted domains, it terminates TLS with our CA-signed cert,
|
||||||
|
//! reads/modifies the request, forwards to the real upstream, and captures
|
||||||
|
//! the response (especially usage data).
|
||||||
|
//!
|
||||||
|
//! For non-intercepted domains, it acts as a transparent TCP tunnel.
|
||||||
|
|
||||||
|
use super::ca::MitmCa;
|
||||||
|
use super::intercept::{
|
||||||
|
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk,
|
||||||
|
StreamingAccumulator,
|
||||||
|
};
|
||||||
|
use super::store::MitmStore;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
use tokio_rustls::TlsAcceptor;
|
||||||
|
use tracing::{debug, error, info, trace, warn};
|
||||||
|
|
||||||
|
/// Domains we intercept (terminate TLS and inspect traffic).
///
/// `should_intercept_domain` treats a CONNECT host as a match when it equals
/// an entry exactly, ends with `.<entry>` (subdomain), or ends with
/// `-<entry>` (regional endpoints such as
/// `us-central1-aiplatform.googleapis.com`).
const INTERCEPT_DOMAINS: &[&str] = &[
    "cloudcode-pa.googleapis.com",
    "aiplatform.googleapis.com",
    "api.anthropic.com",
    "speech.googleapis.com",
    "modelarmor.googleapis.com",
];
|
||||||
|
|
||||||
|
/// Domains we NEVER intercept (transparent tunnel).
///
/// `should_intercept_domain` compares these by exact match only and checks
/// them before `INTERCEPT_DOMAINS`, so a host listed here is always tunnelled.
const PASSTHROUGH_DOMAINS: &[&str] = &[
    "oauth2.googleapis.com",
    "accounts.google.com",
    "storage.googleapis.com",
    "www.googleapis.com",
    "firebaseinstallations.googleapis.com",
    "crashlyticsreports-pa.googleapis.com",
    "play.googleapis.com",
    "update.googleapis.com",
    "dl.google.com",
];
|
||||||
|
|
||||||
|
/// Configuration for the MITM proxy.
///
/// The derived `Default` gives `port: 0` (auto-assigned by the OS) and
/// `modify_requests: false`, the same values the previous hand-written
/// `impl Default` produced.
#[derive(Debug, Clone, Default)]
pub struct MitmConfig {
    /// Port to listen on (0 = auto-assign).
    pub port: u16,
    /// Whether to enable request modification.
    pub modify_requests: bool,
}
|
||||||
|
|
||||||
|
/// Run the MITM proxy server.
|
||||||
|
///
|
||||||
|
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
|
||||||
|
pub async fn run(
|
||||||
|
ca: Arc<MitmCa>,
|
||||||
|
store: MitmStore,
|
||||||
|
config: MitmConfig,
|
||||||
|
) -> Result<(u16, tokio::task::JoinHandle<()>), String> {
|
||||||
|
let listener = TcpListener::bind(format!("127.0.0.1:{}", config.port))
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("MITM bind failed: {e}"))?;
|
||||||
|
|
||||||
|
let port = listener
|
||||||
|
.local_addr()
|
||||||
|
.map_err(|e| format!("MITM local_addr failed: {e}"))?
|
||||||
|
.port();
|
||||||
|
|
||||||
|
info!(port, "MITM proxy listening");
|
||||||
|
|
||||||
|
let modify_requests = config.modify_requests;
|
||||||
|
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
match listener.accept().await {
|
||||||
|
Ok((stream, addr)) => {
|
||||||
|
trace!(?addr, "MITM: new connection");
|
||||||
|
let ca = ca.clone();
|
||||||
|
let store = store.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await {
|
||||||
|
debug!(error = %e, "MITM connection error");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(error = %e, "MITM accept error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok((port, handle))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle a single incoming connection from the LS.
|
||||||
|
///
|
||||||
|
/// The LS sends an HTTP CONNECT request to establish a tunnel.
|
||||||
|
/// We then decide whether to intercept or passthrough.
|
||||||
|
async fn handle_connection(
|
||||||
|
mut stream: TcpStream,
|
||||||
|
ca: Arc<MitmCa>,
|
||||||
|
store: MitmStore,
|
||||||
|
modify_requests: bool,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
// Read the CONNECT request
|
||||||
|
let mut buf = vec![0u8; 8192];
|
||||||
|
let n = stream
|
||||||
|
.read(&mut buf)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Read CONNECT: {e}"))?;
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let request = String::from_utf8_lossy(&buf[..n]);
|
||||||
|
let first_line = request.lines().next().unwrap_or("");
|
||||||
|
|
||||||
|
// Parse "CONNECT host:port HTTP/1.1"
|
||||||
|
let parts: Vec<&str> = first_line.split_whitespace().collect();
|
||||||
|
if parts.len() < 3 || parts[0] != "CONNECT" {
|
||||||
|
// Not a CONNECT request — return 400
|
||||||
|
let resp = "HTTP/1.1 400 Bad Request\r\n\r\n";
|
||||||
|
let _ = stream.write_all(resp.as_bytes()).await;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let host_port = parts[1];
|
||||||
|
let (domain, _port) = match host_port.rsplit_once(':') {
|
||||||
|
Some((h, p)) => (h, p.parse::<u16>().unwrap_or(443)),
|
||||||
|
None => (host_port, 443),
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!(domain, "MITM: CONNECT request");
|
||||||
|
|
||||||
|
// Decide: intercept or passthrough
|
||||||
|
let should_intercept = should_intercept_domain(domain);
|
||||||
|
|
||||||
|
// Send 200 Connection Established
|
||||||
|
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
|
||||||
|
stream
|
||||||
|
.write_all(response.as_bytes())
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Write 200: {e}"))?;
|
||||||
|
|
||||||
|
if should_intercept {
|
||||||
|
handle_intercepted(stream, domain, ca, store, modify_requests).await
|
||||||
|
} else {
|
||||||
|
handle_passthrough(stream, domain, _port).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a domain should be intercepted.
|
||||||
|
fn should_intercept_domain(domain: &str) -> bool {
|
||||||
|
// Never intercept passthrough domains
|
||||||
|
for &pt in PASSTHROUGH_DOMAINS {
|
||||||
|
if domain == pt {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intercept known API domains (exact match, subdomain, or regional prefix)
|
||||||
|
for &intercept in INTERCEPT_DOMAINS {
|
||||||
|
if domain == intercept
|
||||||
|
|| domain.ends_with(&format!(".{intercept}"))
|
||||||
|
|| domain.ends_with(&format!("-{intercept}"))
|
||||||
|
{
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default: passthrough
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle an intercepted connection: terminate TLS, inspect traffic.
|
||||||
|
///
|
||||||
|
/// After TLS termination, checks the negotiated ALPN protocol:
|
||||||
|
/// - `h2` → HTTP/2 handler (for gRPC traffic to Google APIs)
|
||||||
|
/// - `http/1.1` or none → HTTP/1.1 handler (for REST/SSE traffic)
|
||||||
|
async fn handle_intercepted(
|
||||||
|
stream: TcpStream,
|
||||||
|
domain: &str,
|
||||||
|
ca: Arc<MitmCa>,
|
||||||
|
store: MitmStore,
|
||||||
|
modify_requests: bool,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
info!(domain, "MITM: intercepting TLS");
|
||||||
|
|
||||||
|
// Get or create server TLS config for this domain
|
||||||
|
let server_config = ca
|
||||||
|
.server_config_for_domain(domain)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let acceptor = TlsAcceptor::from(server_config);
|
||||||
|
|
||||||
|
// Perform TLS handshake with the client (LS)
|
||||||
|
let tls_stream = acceptor
|
||||||
|
.accept(stream)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("TLS handshake with client failed for {domain}: {e}"))?;
|
||||||
|
|
||||||
|
// Check negotiated ALPN protocol
|
||||||
|
let alpn = tls_stream.get_ref().1
|
||||||
|
.alpn_protocol()
|
||||||
|
.map(|p| String::from_utf8_lossy(p).to_string());
|
||||||
|
|
||||||
|
debug!(domain, alpn = ?alpn, "MITM: TLS handshake successful");
|
||||||
|
|
||||||
|
match alpn.as_deref() {
|
||||||
|
Some("h2") => {
|
||||||
|
// HTTP/2 — use the hyper-based gRPC handler
|
||||||
|
info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
|
||||||
|
super::h2_handler::handle_h2_connection(
|
||||||
|
tls_stream,
|
||||||
|
domain.to_string(),
|
||||||
|
store,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// HTTP/1.1 or no ALPN — use the existing handler
|
||||||
|
debug!(domain, "MITM: routing to HTTP/1.1 handler");
|
||||||
|
handle_http_over_tls(tls_stream, domain, store, modify_requests).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle HTTP traffic over the decrypted TLS connection.
///
/// Loops to handle multiple requests on the same connection (HTTP keep-alive).
/// For each request it: reads the full request from the client, ensures an
/// upstream TLS connection to `domain:443` exists, forwards the request bytes
/// unmodified, relays the response to the client, and records any usage data
/// found in the response into `store`.
///
/// Responses with `Content-Type: text/event-stream` are forwarded chunk by
/// chunk while being fed to a `StreamingAccumulator`; all other responses are
/// buffered in full and forwarded once the read loop ends.
///
/// `_modify_requests` is accepted but not used by this function.
///
/// Returns `Ok(())` when the client closes the connection or a client read
/// fails; returns `Err` only if connecting to upstream or the retried
/// upstream write fails.
///
/// NOTE(review): the response read loop ends only on upstream EOF, a read
/// error, the read timeout, or a client write failure — the response's own
/// Content-Length / chunked framing is not consulted. On an upstream that
/// keeps the connection open, a non-streaming response therefore looks like
/// it is held back until one of those events occurs. Confirm against real
/// traffic.
async fn handle_http_over_tls(
    mut client: tokio_rustls::server::TlsStream<TcpStream>,
    domain: &str,
    store: MitmStore,
    _modify_requests: bool,
) -> Result<(), String> {
    // Scratch buffer shared by client reads and upstream reads.
    let mut tmp = vec![0u8; 32768];

    // Build upstream TLS connector once for this connection.
    // Certificates the root store rejects are skipped silently.
    let mut root_store = rustls::RootCertStore::empty();
    let native_certs = rustls_native_certs::load_native_certs();
    for cert in native_certs.certs {
        let _ = root_store.add(cert);
    }
    let upstream_config = Arc::new(
        rustls::ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth(),
    );

    // Reusable upstream connection — created lazily, reconnected if stale
    let mut upstream: Option<tokio_rustls::client::TlsStream<TcpStream>> = None;

    /// Connect (or reconnect) to the real upstream via TLS.
    ///
    /// Always dials port 443 and uses `domain` as the TLS server name.
    async fn connect_upstream(
        domain: &str,
        config: &Arc<rustls::ClientConfig>,
    ) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> {
        let connector = tokio_rustls::TlsConnector::from(config.clone());
        let tcp = TcpStream::connect(format!("{domain}:443"))
            .await
            .map_err(|e| format!("Connect to upstream {domain}: {e}"))?;
        let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
            .map_err(|e| format!("Invalid server name: {e}"))?;
        connector
            .connect(server_name, tcp)
            .await
            .map_err(|e| format!("TLS connect to upstream {domain}: {e}"))
    }

    // Keep-alive loop: handle multiple requests on this connection
    loop {
        // ── Read the HTTP request from the client ─────────────────────────
        let mut request_buf = Vec::with_capacity(1024 * 64);

        loop {
            let n = match client.read(&mut tmp).await {
                Ok(0) => return Ok(()), // Client closed connection cleanly
                Ok(n) => n,
                Err(e) => {
                    // Connection reset / broken pipe is normal for keep-alive end
                    debug!(domain, error = %e, "MITM: client read finished");
                    return Ok(());
                }
            };

            request_buf.extend_from_slice(&tmp[..n]);

            // Check if we have the full request (headers + body)
            if has_complete_http_request(&request_buf) {
                break;
            }
        }

        // Defensive: the loop above only breaks after appending data.
        if request_buf.is_empty() {
            return Ok(());
        }

        // Parse the HTTP request to find headers and body
        let (headers_end, content_length, is_streaming_request) = parse_http_request_meta(&request_buf);

        // Try to extract cascade hint from request body
        let cascade_hint = if headers_end < request_buf.len() {
            extract_cascade_hint(&request_buf[headers_end..])
        } else {
            None
        };

        debug!(
            domain,
            content_length,
            streaming = is_streaming_request,
            cascade = ?cascade_hint,
            "MITM: forwarding request to upstream"
        );

        // ── Ensure upstream connection is alive ──────────────────────────────
        // Lazily connect on first request, or reconnect if the previous connection died
        let conn = match upstream.as_mut() {
            Some(c) => c,
            None => {
                let c = connect_upstream(domain, &upstream_config).await?;
                upstream.insert(c)
            }
        };

        // Forward the request — if write fails, reconnect and retry once
        if let Err(e) = conn.write_all(&request_buf).await {
            debug!(domain, error = %e, "MITM: upstream write failed, reconnecting");
            let c = connect_upstream(domain, &upstream_config).await?;
            let conn = upstream.insert(c);
            conn.write_all(&request_buf)
                .await
                .map_err(|e| format!("Write to upstream (retry): {e}"))?;
        }

        // Safe: both branches above leave `upstream` populated.
        let conn = upstream.as_mut().unwrap();

        // ── Stream response back to client ──────────────────────────────────
        let mut streaming_acc = StreamingAccumulator::new();
        let mut is_streaming_response = false;
        let mut headers_parsed = false;
        // Holds the response until headers are parsed; kept afterwards only
        // for non-streaming responses (for usage parsing)
        let mut non_streaming_buf: Option<Vec<u8>> = None;
        // Track if upstream connection is still usable after this response
        let mut upstream_ok = true;

        // Timeout applied to each individual upstream read: 5 minutes
        // (covers large context API calls)
        const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);

        loop {
            let n = match tokio::time::timeout(READ_TIMEOUT, conn.read(&mut tmp)).await {
                Ok(Ok(0)) => {
                    // Upstream closed — connection is no longer reusable
                    upstream_ok = false;
                    break;
                }
                Ok(Ok(n)) => n,
                Ok(Err(e)) => {
                    debug!(domain, error = %e, "MITM: upstream read finished");
                    upstream_ok = false;
                    break;
                }
                Err(_) => {
                    warn!(domain, "MITM: upstream read timed out after 5 minutes");
                    upstream_ok = false;
                    break;
                }
            };

            let chunk = &tmp[..n];

            // Check response headers for content-type
            if !headers_parsed {
                // We need to buffer until we see the end of headers
                let buf = non_streaming_buf.get_or_insert_with(|| Vec::with_capacity(1024 * 64));
                buf.extend_from_slice(chunk);
                if let Some(_hdr_end) = find_headers_end(buf) {
                    // Use httparse for response header parsing
                    let mut resp_headers = [httparse::EMPTY_HEADER; 64];
                    let mut resp = httparse::Response::new(&mut resp_headers);
                    let hdr_end = match resp.parse(buf) {
                        Ok(httparse::Status::Complete(n)) => n,
                        _ => _hdr_end, // Fallback to manual detection
                    };

                    // Detect content type and connection handling from parsed headers
                    for header in resp.headers.iter() {
                        if header.name.eq_ignore_ascii_case("content-type") {
                            if let Ok(val) = std::str::from_utf8(header.value) {
                                if val.contains("text/event-stream") {
                                    is_streaming_response = true;
                                }
                            }
                        }
                        if header.name.eq_ignore_ascii_case("connection") {
                            if let Ok(val) = std::str::from_utf8(header.value) {
                                if val.trim().eq_ignore_ascii_case("close") {
                                    upstream_ok = false;
                                }
                            }
                        }
                    }

                    headers_parsed = true;

                    if is_streaming_response {
                        // For streaming, parse any SSE data already in the buffer
                        let body_so_far = String::from_utf8_lossy(&buf[hdr_end..]);
                        if !body_so_far.is_empty() {
                            parse_streaming_chunk(&body_so_far, &mut streaming_acc);
                        }
                        // Forward the accumulated buffer to client
                        if let Err(e) = client.write_all(buf).await {
                            warn!(error = %e, "MITM: write to client failed");
                            // NOTE(review): `upstream_ok` stays true here, so the
                            // upstream connection is reused with the rest of this
                            // response still unread — confirm this is intended.
                            break;
                        }
                        // Headers and first body bytes are forwarded; stop buffering.
                        non_streaming_buf = None;
                        continue;
                    }
                    // Non-streaming: keep buffering the response body for parsing
                    continue;
                }
                // Headers still incomplete — read more.
                continue;
            }

            // If streaming, parse SSE events and forward immediately
            if is_streaming_response {
                let chunk_str = String::from_utf8_lossy(chunk);
                parse_streaming_chunk(&chunk_str, &mut streaming_acc);

                if let Err(e) = client.write_all(chunk).await {
                    warn!(error = %e, "MITM: write to client failed (client disconnected?)");
                    // NOTE(review): same as above — `upstream_ok` is not cleared.
                    break;
                }
            } else {
                // Non-streaming: keep accumulating to parse usage at the end
                if let Some(ref mut buf) = non_streaming_buf {
                    buf.extend_from_slice(chunk);
                }
            }
        }

        // Forward non-streaming response all at once
        if !is_streaming_response {
            if let Some(ref buf) = non_streaming_buf {
                if let Err(e) = client.write_all(buf).await {
                    warn!(error = %e, "MITM: write to client failed");
                }
            }
        }

        // Capture usage data
        if is_streaming_response {
            // Record only if the stream finished or produced some output tokens.
            if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
                let usage = streaming_acc.into_usage();
                store.record_usage(cascade_hint.as_deref(), usage).await;
            }
        } else if let Some(ref buf) = non_streaming_buf {
            if let Some(body_start) = find_headers_end(buf) {
                let body = &buf[body_start..];
                if let Some(usage) = parse_non_streaming_response(body) {
                    store.record_usage(cascade_hint.as_deref(), usage).await;
                }
            }
        }

        // If upstream closed, drop the connection so next iteration reconnects
        if !upstream_ok {
            upstream = None;
        }
    } // end keep-alive loop
}
|
||||||
|
|
||||||
|
/// Handle a passthrough connection: transparent TCP tunnel to upstream.
|
||||||
|
async fn handle_passthrough(
|
||||||
|
mut client: TcpStream,
|
||||||
|
domain: &str,
|
||||||
|
port: u16,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
trace!(domain, port, "MITM: transparent tunnel");
|
||||||
|
|
||||||
|
let mut upstream = TcpStream::connect(format!("{domain}:{port}"))
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Connect to {domain}:{port}: {e}"))?;
|
||||||
|
|
||||||
|
// Bidirectional copy
|
||||||
|
match tokio::io::copy_bidirectional(&mut client, &mut upstream).await {
|
||||||
|
Ok((client_to_server, server_to_client)) => {
|
||||||
|
trace!(domain, client_to_server, server_to_client, "MITM: tunnel closed");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if buffer contains a complete HTTP request (headers + full body).
|
||||||
|
/// Uses `httparse` for zero-copy, case-insensitive header parsing.
|
||||||
|
fn has_complete_http_request(buf: &[u8]) -> bool {
|
||||||
|
let mut headers = [httparse::EMPTY_HEADER; 64];
|
||||||
|
let mut req = httparse::Request::new(&mut headers);
|
||||||
|
|
||||||
|
let headers_end = match req.parse(buf) {
|
||||||
|
Ok(httparse::Status::Complete(n)) => n,
|
||||||
|
_ => return false, // Incomplete or parse error — need more data
|
||||||
|
};
|
||||||
|
|
||||||
|
// Look for Content-Length
|
||||||
|
for header in req.headers.iter() {
|
||||||
|
if header.name.eq_ignore_ascii_case("content-length") {
|
||||||
|
if let Ok(val) = std::str::from_utf8(header.value) {
|
||||||
|
if let Ok(len) = val.trim().parse::<usize>() {
|
||||||
|
return buf.len() >= headers_end + len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if header.name.eq_ignore_ascii_case("transfer-encoding") {
|
||||||
|
if let Ok(val) = std::str::from_utf8(header.value) {
|
||||||
|
if val.trim().eq_ignore_ascii_case("chunked") {
|
||||||
|
let body = &buf[headers_end..];
|
||||||
|
return body.len() >= 5 && body.ends_with(b"0\r\n\r\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No Content-Length or Transfer-Encoding — no body expected (e.g., GET)
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Locate the first blank line (`\r\n\r\n`) in `buf`.
///
/// Returns the offset of the byte immediately after that terminator, or
/// `None` when the buffer does not contain one.
fn find_headers_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    for (start, window) in buf.windows(TERMINATOR.len()).enumerate() {
        if window == TERMINATOR {
            return Some(start + TERMINATOR.len());
        }
    }
    None
}
|
||||||
|
|
||||||
|
/// Parse HTTP request metadata from raw bytes using `httparse`.
|
||||||
|
/// Returns (headers_end, content_length, is_streaming_request).
|
||||||
|
fn parse_http_request_meta(buf: &[u8]) -> (usize, usize, bool) {
|
||||||
|
let mut headers = [httparse::EMPTY_HEADER; 64];
|
||||||
|
let mut req = httparse::Request::new(&mut headers);
|
||||||
|
|
||||||
|
let headers_end = match req.parse(buf) {
|
||||||
|
Ok(httparse::Status::Complete(n)) => n,
|
||||||
|
_ => {
|
||||||
|
// Fallback if httparse can't parse
|
||||||
|
return (find_headers_end(buf).unwrap_or(buf.len()), 0, false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut content_length = 0usize;
|
||||||
|
|
||||||
|
for header in req.headers.iter() {
|
||||||
|
if header.name.eq_ignore_ascii_case("content-length") {
|
||||||
|
if let Ok(val) = std::str::from_utf8(header.value) {
|
||||||
|
content_length = val.trim().parse().unwrap_or(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if request body asks for streaming
|
||||||
|
let is_streaming = if headers_end < buf.len() {
|
||||||
|
let body_str = String::from_utf8_lossy(&buf[headers_end..]);
|
||||||
|
body_str.contains("\"stream\":true") || body_str.contains("\"stream\": true")
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
|
||||||
|
(headers_end, content_length, is_streaming)
|
||||||
|
}
|
||||||
|
|
||||||
163
src/mitm/store.rs
Normal file
163
src/mitm/store.rs
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
//! Shared store for intercepted API usage data.
|
||||||
|
//!
|
||||||
|
//! The MITM proxy writes usage data here; the API handlers read from it.
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
/// Token usage from an intercepted API response.
|
||||||
|
///
|
||||||
|
/// Covers both Anthropic JSON/SSE responses and Google gRPC protobuf responses.
|
||||||
|
/// Fields map to the superset of Anthropic's `usage` object and Google's `ModelUsageStats` proto.
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
|
pub struct ApiUsage {
|
||||||
|
pub input_tokens: u64,
|
||||||
|
pub output_tokens: u64,
|
||||||
|
/// Anthropic: cache_creation_input_tokens / Google: cache_write_tokens
|
||||||
|
pub cache_creation_input_tokens: u64,
|
||||||
|
/// Anthropic: cache_read_input_tokens / Google: cache_read_tokens
|
||||||
|
pub cache_read_input_tokens: u64,
|
||||||
|
/// Google-specific: thinking/reasoning output tokens (extended thinking)
|
||||||
|
pub thinking_output_tokens: u64,
|
||||||
|
/// Google-specific: response output tokens (non-thinking portion)
|
||||||
|
pub response_output_tokens: u64,
|
||||||
|
/// Total cost in USD (if provided by the API).
|
||||||
|
pub total_cost_usd: Option<f64>,
|
||||||
|
/// The actual model that served the request.
|
||||||
|
pub model: Option<String>,
|
||||||
|
/// Stop reason / finish reason from the API.
|
||||||
|
pub stop_reason: Option<String>,
|
||||||
|
/// API provider (e.g. "anthropic", "google")
|
||||||
|
pub api_provider: Option<String>,
|
||||||
|
/// gRPC method path (e.g. "/google.internal.cloud.code.v1internal.PredictionService/GenerateContent")
|
||||||
|
pub grpc_method: Option<String>,
|
||||||
|
/// Timestamp when this usage was captured.
|
||||||
|
pub captured_at: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Thread-safe store for intercepted data.
|
||||||
|
///
|
||||||
|
/// Keyed by a unique request ID that we can correlate with cascade operations.
|
||||||
|
/// In practice, we use the cascade ID + a sequence number.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct MitmStore {
|
||||||
|
/// Most recent usage per cascade ID.
|
||||||
|
latest_usage: Arc<RwLock<HashMap<String, ApiUsage>>>,
|
||||||
|
/// Global aggregate stats.
|
||||||
|
stats: Arc<RwLock<MitmStats>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Aggregate statistics across all intercepted traffic.
|
||||||
|
#[derive(Debug, Clone, Default, Serialize)]
|
||||||
|
pub struct MitmStats {
|
||||||
|
pub total_requests: u64,
|
||||||
|
pub total_input_tokens: u64,
|
||||||
|
pub total_output_tokens: u64,
|
||||||
|
pub total_cache_read_tokens: u64,
|
||||||
|
pub total_cache_creation_tokens: u64,
|
||||||
|
pub total_thinking_output_tokens: u64,
|
||||||
|
pub total_response_output_tokens: u64,
|
||||||
|
/// Per-model usage breakdown (model name → stats).
|
||||||
|
pub per_model: HashMap<String, ModelStats>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-model usage counters.
|
||||||
|
#[derive(Debug, Clone, Default, Serialize)]
|
||||||
|
pub struct ModelStats {
|
||||||
|
pub requests: u64,
|
||||||
|
pub input_tokens: u64,
|
||||||
|
pub output_tokens: u64,
|
||||||
|
pub cache_read_tokens: u64,
|
||||||
|
pub cache_creation_tokens: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MitmStore {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
latest_usage: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
stats: Arc::new(RwLock::new(MitmStats::default())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record a completed API exchange with usage data.
|
||||||
|
pub async fn record_usage(&self, cascade_id: Option<&str>, usage: ApiUsage) {
|
||||||
|
debug!(
|
||||||
|
input = usage.input_tokens,
|
||||||
|
output = usage.output_tokens,
|
||||||
|
cache_read = usage.cache_read_input_tokens,
|
||||||
|
cache_create = usage.cache_creation_input_tokens,
|
||||||
|
thinking = usage.thinking_output_tokens,
|
||||||
|
response = usage.response_output_tokens,
|
||||||
|
model = ?usage.model,
|
||||||
|
provider = ?usage.api_provider,
|
||||||
|
grpc = ?usage.grpc_method,
|
||||||
|
"MITM captured API usage"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Update aggregate stats
|
||||||
|
{
|
||||||
|
let mut stats = self.stats.write().await;
|
||||||
|
stats.total_requests += 1;
|
||||||
|
stats.total_input_tokens += usage.input_tokens;
|
||||||
|
stats.total_output_tokens += usage.output_tokens;
|
||||||
|
stats.total_cache_read_tokens += usage.cache_read_input_tokens;
|
||||||
|
stats.total_cache_creation_tokens += usage.cache_creation_input_tokens;
|
||||||
|
stats.total_thinking_output_tokens += usage.thinking_output_tokens;
|
||||||
|
stats.total_response_output_tokens += usage.response_output_tokens;
|
||||||
|
|
||||||
|
// Per-model breakdown
|
||||||
|
if let Some(ref model_name) = usage.model {
|
||||||
|
let model_stats = stats.per_model.entry(model_name.clone()).or_default();
|
||||||
|
model_stats.requests += 1;
|
||||||
|
model_stats.input_tokens += usage.input_tokens;
|
||||||
|
model_stats.output_tokens += usage.output_tokens;
|
||||||
|
model_stats.cache_read_tokens += usage.cache_read_input_tokens;
|
||||||
|
model_stats.cache_creation_tokens += usage.cache_creation_input_tokens;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store latest usage for the cascade (if we can identify it)
|
||||||
|
let key = cascade_id.map(|s| s.to_string()).unwrap_or_else(|| "_latest".to_string());
|
||||||
|
let mut latest = self.latest_usage.write().await;
|
||||||
|
latest.insert(key, usage);
|
||||||
|
|
||||||
|
// Evict old entries to prevent unbounded memory growth
|
||||||
|
const MAX_ENTRIES: usize = 500;
|
||||||
|
if latest.len() > MAX_ENTRIES {
|
||||||
|
// Find the oldest entry by captured_at and remove it
|
||||||
|
let oldest_key = latest
|
||||||
|
.iter()
|
||||||
|
.min_by_key(|(_, v)| v.captured_at)
|
||||||
|
.map(|(k, _)| k.clone());
|
||||||
|
if let Some(key) = oldest_key {
|
||||||
|
latest.remove(&key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the latest usage for a cascade, consuming it (one-shot read).
|
||||||
|
///
|
||||||
|
/// Only returns exact cascade_id matches — no cross-cascade fallback.
|
||||||
|
/// The `_latest` key is only consumed when the caller explicitly requests it
|
||||||
|
/// (i.e., when the MITM couldn't identify the cascade).
|
||||||
|
pub async fn take_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
|
||||||
|
let mut latest = self.latest_usage.write().await;
|
||||||
|
latest.remove(cascade_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Peek at the latest usage without consuming it.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub async fn peek_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
|
||||||
|
let latest = self.latest_usage.read().await;
|
||||||
|
latest.get(cascade_id)
|
||||||
|
.cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get aggregate stats.
|
||||||
|
pub async fn stats(&self) -> MitmStats {
|
||||||
|
self.stats.read().await.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
233
src/proto.rs
Normal file
233
src/proto.rs
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
//! Protobuf wire-format encoder — byte-exact match to the real Antigravity webview.
|
||||||
|
//!
|
||||||
|
//! This is a minimal, hand-rolled encoder. We do NOT use prost or any codegen
|
||||||
|
//! because we need precise control over field ordering and encoding to produce
|
||||||
|
//! byte-identical output to the captured webview traffic.
|
||||||
|
|
||||||
|
use crate::constants::{client_version, CLIENT_NAME};
|
||||||
|
|
||||||
|
// ─── Wire primitives ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Encode `val` as a protobuf varint.
///
/// Emits 7 bits per byte, least-significant group first; every byte except
/// the last has its high bit set. Zero encodes as the single byte `0x00`.
pub fn varint(mut val: u64) -> Vec<u8> {
    // A u64 needs at most ten 7-bit groups.
    let mut encoded = Vec::with_capacity(10);
    loop {
        let low7 = (val & 0x7F) as u8;
        val >>= 7;
        if val == 0 {
            encoded.push(low7);
            return encoded;
        }
        encoded.push(low7 | 0x80);
    }
}
|
||||||
|
|
||||||
|
/// Encode a field tag (field_number << 3 | wire_type).
|
||||||
|
pub fn tag(field: u32, wire: u8) -> Vec<u8> {
|
||||||
|
varint(((field as u64) << 3) | (wire as u64))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wire type 2: length-delimited string/bytes field.
|
||||||
|
pub fn proto_string(field: u32, val: &[u8]) -> Vec<u8> {
|
||||||
|
let mut out = tag(field, 2);
|
||||||
|
out.extend(varint(val.len() as u64));
|
||||||
|
out.extend_from_slice(val);
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wire type 2: length-delimited sub-message field.
|
||||||
|
pub fn proto_message(field: u32, inner: &[u8]) -> Vec<u8> {
|
||||||
|
let mut out = tag(field, 2);
|
||||||
|
out.extend(varint(inner.len() as u64));
|
||||||
|
out.extend_from_slice(inner);
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wire type 0: boolean field (varint 0 or 1).
|
||||||
|
pub fn bool_field(field: u32, val: bool) -> Vec<u8> {
|
||||||
|
let mut out = tag(field, 0);
|
||||||
|
out.extend(varint(if val { 1 } else { 0 }));
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wire type 0: varint field.
|
||||||
|
pub fn varint_field(field: u32, val: u64) -> Vec<u8> {
|
||||||
|
let mut out = tag(field, 0);
|
||||||
|
out.extend(varint(val));
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── SendUserCascadeMessageRequest builder ───────────────────────────────────
|
||||||
|
|
||||||
|
/// Build the `SendUserCascadeMessageRequest` protobuf binary.
|
||||||
|
///
|
||||||
|
/// Produces a byte-exact match to real Antigravity webview traffic.
|
||||||
|
/// Verified against Chrome DevTools network capture 2026-02-12.
|
||||||
|
///
|
||||||
|
/// Field layout:
|
||||||
|
/// 1: cascade_id (string)
|
||||||
|
/// 2: { 1: text } (message)
|
||||||
|
/// 3: metadata { 1: client_name, 3: oauth_token, 4: "en", 7: version, 12: client_name }
|
||||||
|
/// 5: PlannerConfig { 1: inner_config, 7: { 1: 1 } }
|
||||||
|
/// inner_config contains: f2 (conv mode), f13 (tool config), f15 (model), f21 (ephemeral), f32 (knowledge)
|
||||||
|
/// 11: conversation_history = true
|
||||||
|
pub fn build_request(cascade_id: &str, text: &str, oauth_token: &str, model_enum: u32) -> Vec<u8> {
|
||||||
|
let mut msg = Vec::with_capacity(256);
|
||||||
|
|
||||||
|
// Field 1: cascade_id
|
||||||
|
msg.extend(proto_string(1, cascade_id.as_bytes()));
|
||||||
|
|
||||||
|
// Field 2: { field 1: text }
|
||||||
|
msg.extend(proto_message(2, &proto_string(1, text.as_bytes())));
|
||||||
|
|
||||||
|
// Field 3: Metadata (Auth + Client ID)
|
||||||
|
let mut meta = Vec::new();
|
||||||
|
meta.extend(proto_string(1, CLIENT_NAME.as_bytes()));
|
||||||
|
meta.extend(proto_string(3, oauth_token.as_bytes()));
|
||||||
|
meta.extend(proto_string(4, b"en"));
|
||||||
|
meta.extend(proto_string(7, client_version().as_bytes()));
|
||||||
|
meta.extend(proto_string(12, CLIENT_NAME.as_bytes()));
|
||||||
|
msg.extend(proto_message(3, &meta));
|
||||||
|
|
||||||
|
// Field 5: PlannerConfig
|
||||||
|
let mut inner = Vec::new();
|
||||||
|
|
||||||
|
// field 2: conversational mode { f4: 1, f14: 0 }
|
||||||
|
let conv_mode = [varint_field(4, 1), varint_field(14, 0)].concat();
|
||||||
|
inner.extend(proto_message(2, &conv_mode));
|
||||||
|
|
||||||
|
// field 13: toolConfig
|
||||||
|
// field 8 (runCommand): field 3 (autoCommandConfig) -> field 6 (policy) = 3 (EAGER)
|
||||||
|
// field 33 (artifactReviewPolicy): field 1 = 2 (TURBO)
|
||||||
|
let run_cmd = proto_message(3, &varint_field(6, 3));
|
||||||
|
let tool_config = [
|
||||||
|
proto_message(8, &run_cmd),
|
||||||
|
proto_message(33, &varint_field(1, 2)),
|
||||||
|
]
|
||||||
|
.concat();
|
||||||
|
inner.extend(proto_message(13, &tool_config));
|
||||||
|
|
||||||
|
// field 15: requested model { f1: model_enum }
|
||||||
|
inner.extend(proto_message(15, &varint_field(1, model_enum as u64)));
|
||||||
|
|
||||||
|
// field 21: ephemeral messages config { f1: 1 }
|
||||||
|
inner.extend(proto_message(21, &varint_field(1, 1)));
|
||||||
|
|
||||||
|
// field 32: knowledge config { f1: true }
|
||||||
|
inner.extend(proto_message(32, &bool_field(1, true)));
|
||||||
|
|
||||||
|
// Field 5 wraps: field 1 (inner config) + field 7 { f1: 1 }
|
||||||
|
let f5_payload = [
|
||||||
|
proto_message(1, &inner),
|
||||||
|
proto_message(7, &varint_field(1, 1)),
|
||||||
|
]
|
||||||
|
.concat();
|
||||||
|
msg.extend(proto_message(5, &f5_payload));
|
||||||
|
|
||||||
|
// Field 11: conversation history flag
|
||||||
|
msg.extend(bool_field(11, true));
|
||||||
|
|
||||||
|
msg
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `needle` occurs as a contiguous run inside `haystack`.
    /// `needle` must be non-empty.
    fn contains(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    #[test]
    fn test_varint_zero() {
        let encoded = varint(0);
        assert_eq!(encoded, vec![0x00]);
    }

    #[test]
    fn test_varint_small() {
        for (input, expected) in [(1u64, vec![0x01u8]), (127, vec![0x7F])] {
            assert_eq!(varint(input), expected);
        }
    }

    #[test]
    fn test_varint_multibyte() {
        for (input, expected) in [(128u64, vec![0x80u8, 0x01]), (300, vec![0xAC, 0x02])] {
            assert_eq!(varint(input), expected);
        }
    }

    #[test]
    fn test_varint_1026() {
        // model_enum 1026 = 0x402 → varint [0x82, 0x08]
        let encoded = varint(1026);
        assert_eq!(encoded, vec![0x82, 0x08]);
    }

    #[test]
    fn test_tag() {
        // (field, wire type, expected key bytes)
        //   field 1,  LEN    = (1 << 3) | 2  = 0x0A
        //   field 3,  VARINT = (3 << 3) | 0  = 0x18
        //   field 33, LEN    = (33 << 3) | 2 = 266 → varint [0x8A, 0x02]
        let cases: [(u32, u8, Vec<u8>); 3] = [
            (1, 2, vec![0x0A]),
            (3, 0, vec![0x18]),
            (33, 2, vec![0x8A, 0x02]),
        ];
        for (field, wire, expected) in cases {
            assert_eq!(tag(field, wire), expected);
        }
    }

    #[test]
    fn test_proto_string() {
        // tag(1,2) = 0x0A, len=2, 'h'=0x68, 'i'=0x69
        let encoded = proto_string(1, b"hi");
        assert_eq!(encoded, vec![0x0A, 0x02, 0x68, 0x69]);
    }

    #[test]
    fn test_build_request_deterministic() {
        let build = || build_request("cid", "hello", "ya29.tok", 1026);
        let a = build();
        let b = build();
        assert_eq!(a, b, "build_request must be deterministic");
    }

    #[test]
    fn test_build_request_structure() {
        let msg = build_request("test-cascade-id", "hello", "ya29.test-token", 1026);

        assert_eq!(msg[0], 0x0A, "first byte must be field 1 tag");

        assert!(
            contains(&msg, b"test-cascade-id"),
            "cascade_id must appear in output"
        );
        assert!(contains(&msg, b"hello"), "text must appear in output");
        assert!(
            contains(&msg, b"ya29.test-token"),
            "oauth token must appear in output"
        );

        // model enum 1026 varint [0x82, 0x08]
        assert!(
            contains(&msg, &[0x82, 0x08]),
            "model enum 1026 varint must appear in output"
        );

        // field 11 bool true at end: tag(11,0)=0x58, varint(1)=0x01
        let tail = &msg[msg.len() - 2..];
        assert_eq!(tail[0], 0x58);
        assert_eq!(tail[1], 0x01);
    }

    /// Cross-verified against Python output: 127/127 bytes identical.
    #[test]
    fn test_byte_exact_match_with_python() {
        let msg = build_request("test-cascade-id", "hello", "ya29.test-token", 1026);
        let hex = msg
            .iter()
            .map(|byte| format!("{:02x}", byte))
            .collect::<String>();
        let expected = "0a0f746573742d636173636164652d696412070a0568656c6c6f\
                        1a370a0b616e7469677261766974791a0f796132392e746573742d746f6b656e\
                        2202656e3a06312e31362e35620b616e7469677261766974792a280a22120420\
                        0170006a0b42041a0230038a020208027a03088208aa010208018202020801\
                        3a0208015801";
        assert_eq!(hex, expected, "must be byte-exact match with Python");
        assert_eq!(msg.len(), 127);
    }
}
|
||||||
218
src/quota.rs
Normal file
218
src/quota.rs
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
//! Quota monitor — polls the local LS `GetUserStatus` to track
|
||||||
|
//! prompt/flow credits and per-model rate limits without touching Google servers.
|
||||||
|
|
||||||
|
use serde::Serialize;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
|
/// How often to poll the LS for fresh quota data (seconds).
|
||||||
|
const POLL_INTERVAL_SECS: u64 = 60;
|
||||||
|
|
||||||
|
// ─── Public types ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Default)]
|
||||||
|
pub struct QuotaSnapshot {
|
||||||
|
/// When this snapshot was last refreshed (ISO-8601 UTC).
|
||||||
|
pub last_updated: String,
|
||||||
|
/// Overall plan info.
|
||||||
|
pub plan: PlanInfo,
|
||||||
|
/// Monthly credit balances.
|
||||||
|
pub credits: CreditInfo,
|
||||||
|
/// Per-model rate limits.
|
||||||
|
pub models: Vec<ModelQuota>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Default)]
|
||||||
|
pub struct PlanInfo {
|
||||||
|
pub plan_name: String,
|
||||||
|
pub tier_id: String,
|
||||||
|
pub tier_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Default)]
|
||||||
|
pub struct CreditInfo {
|
||||||
|
pub prompt_available: i64,
|
||||||
|
pub prompt_total: i64,
|
||||||
|
pub prompt_used_pct: f64,
|
||||||
|
pub flow_available: i64,
|
||||||
|
pub flow_total: i64,
|
||||||
|
pub flow_used_pct: f64,
|
||||||
|
pub flex_purchasable: i64,
|
||||||
|
pub can_buy_more: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Default)]
|
||||||
|
pub struct ModelQuota {
|
||||||
|
pub label: String,
|
||||||
|
pub model_id: String,
|
||||||
|
/// 0.0–1.0 remaining fraction (1.0 = full quota).
|
||||||
|
pub remaining_fraction: f64,
|
||||||
|
/// Percentage remaining (0–100).
|
||||||
|
pub remaining_pct: f64,
|
||||||
|
/// ISO-8601 UTC reset time.
|
||||||
|
pub reset_time: String,
|
||||||
|
/// Seconds until reset (negative = already reset).
|
||||||
|
pub reset_in_secs: i64,
|
||||||
|
/// Human-readable countdown.
|
||||||
|
pub reset_in_human: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Quota Store ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct QuotaStore {
|
||||||
|
inner: Arc<RwLock<QuotaSnapshot>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QuotaStore {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
inner: Arc::new(RwLock::new(QuotaSnapshot::default())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the latest cached snapshot.
|
||||||
|
pub async fn snapshot(&self) -> QuotaSnapshot {
|
||||||
|
self.inner.read().await.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start the background polling loop. Call once at startup.
|
||||||
|
pub fn start_polling(self, backend: Arc<crate::backend::Backend>) {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
// Initial poll immediately.
|
||||||
|
self.poll_once(&backend).await;
|
||||||
|
|
||||||
|
let mut interval = tokio::time::interval(
|
||||||
|
std::time::Duration::from_secs(POLL_INTERVAL_SECS),
|
||||||
|
);
|
||||||
|
interval.tick().await; // consume the first immediate tick
|
||||||
|
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
self.poll_once(&backend).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn poll_once(&self, backend: &crate::backend::Backend) {
|
||||||
|
match backend
|
||||||
|
.call_json("GetUserStatus", &serde_json::json!({}))
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok((200, data)) => {
|
||||||
|
let snapshot = parse_user_status(&data);
|
||||||
|
debug!(
|
||||||
|
"Quota poll: prompt {}/{} flow {}/{}",
|
||||||
|
snapshot.credits.prompt_available,
|
||||||
|
snapshot.credits.prompt_total,
|
||||||
|
snapshot.credits.flow_available,
|
||||||
|
snapshot.credits.flow_total,
|
||||||
|
);
|
||||||
|
*self.inner.write().await = snapshot;
|
||||||
|
}
|
||||||
|
Ok((status, data)) => {
|
||||||
|
warn!("GetUserStatus returned {status}: {data}");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("GetUserStatus poll failed: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Parsing ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
fn parse_user_status(data: &serde_json::Value) -> QuotaSnapshot {
|
||||||
|
let now = chrono::Utc::now();
|
||||||
|
let us = &data["userStatus"];
|
||||||
|
let ps = &us["planStatus"];
|
||||||
|
let pi = &ps["planInfo"];
|
||||||
|
let ut = &us["userTier"];
|
||||||
|
|
||||||
|
let prompt_total = pi["monthlyPromptCredits"].as_i64().unwrap_or(0);
|
||||||
|
let prompt_avail = ps["availablePromptCredits"].as_i64().unwrap_or(0);
|
||||||
|
let flow_total = pi["monthlyFlowCredits"].as_i64().unwrap_or(0);
|
||||||
|
let flow_avail = ps["availableFlowCredits"].as_i64().unwrap_or(0);
|
||||||
|
|
||||||
|
let prompt_used_pct = if prompt_total > 0 {
|
||||||
|
((prompt_total - prompt_avail) as f64 / prompt_total as f64) * 100.0
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
let flow_used_pct = if flow_total > 0 {
|
||||||
|
((flow_total - flow_avail) as f64 / flow_total as f64) * 100.0
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
let models = us["cascadeModelConfigData"]["clientModelConfigs"]
|
||||||
|
.as_array()
|
||||||
|
.map(|arr| {
|
||||||
|
arr.iter()
|
||||||
|
.map(|m| {
|
||||||
|
let label = m["label"].as_str().unwrap_or("").to_string();
|
||||||
|
let model_id = m["modelOrAlias"]["model"]
|
||||||
|
.as_str()
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
let frac = m["quotaInfo"]["remainingFraction"]
|
||||||
|
.as_f64()
|
||||||
|
.unwrap_or(0.0);
|
||||||
|
let reset_str = m["quotaInfo"]["resetTime"]
|
||||||
|
.as_str()
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let reset_in_secs = if !reset_str.is_empty() {
|
||||||
|
chrono::DateTime::parse_from_rfc3339(&reset_str)
|
||||||
|
.map(|dt| (dt.with_timezone(&chrono::Utc) - now).num_seconds())
|
||||||
|
.unwrap_or(0)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
|
let reset_in_human = if reset_in_secs > 0 {
|
||||||
|
let h = reset_in_secs / 3600;
|
||||||
|
let m = (reset_in_secs % 3600) / 60;
|
||||||
|
format!("{h}h {m}m")
|
||||||
|
} else {
|
||||||
|
"available".to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
ModelQuota {
|
||||||
|
label,
|
||||||
|
model_id,
|
||||||
|
remaining_fraction: frac,
|
||||||
|
remaining_pct: frac * 100.0,
|
||||||
|
reset_time: reset_str,
|
||||||
|
reset_in_secs,
|
||||||
|
reset_in_human,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
QuotaSnapshot {
|
||||||
|
last_updated: now.to_rfc3339(),
|
||||||
|
plan: PlanInfo {
|
||||||
|
plan_name: pi["planName"].as_str().unwrap_or("").to_string(),
|
||||||
|
tier_id: ut["id"].as_str().unwrap_or("").to_string(),
|
||||||
|
tier_name: ut["name"].as_str().unwrap_or("").to_string(),
|
||||||
|
},
|
||||||
|
credits: CreditInfo {
|
||||||
|
prompt_available: prompt_avail,
|
||||||
|
prompt_total,
|
||||||
|
prompt_used_pct,
|
||||||
|
flow_available: flow_avail,
|
||||||
|
flow_total,
|
||||||
|
flow_used_pct,
|
||||||
|
flex_purchasable: pi["monthlyFlexCreditPurchaseAmount"]
|
||||||
|
.as_i64()
|
||||||
|
.unwrap_or(0),
|
||||||
|
can_buy_more: pi["canBuyMoreCredits"].as_bool().unwrap_or(false),
|
||||||
|
},
|
||||||
|
models,
|
||||||
|
}
|
||||||
|
}
|
||||||
152
src/session.rs
Normal file
152
src/session.rs
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
//! Cascade session manager — maps session IDs to cascade IDs for reuse.
|
||||||
|
//!
|
||||||
|
//! Mimics real webview behavior: one chat tab = one cascade with many messages.
|
||||||
|
//! Without this, every API call creates a new cascade — an obvious automation
|
||||||
|
//! fingerprint (100 calls = 100 single-message cascades).
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::time::Instant;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
const DEFAULT_SESSION: &str = "__default__";
|
||||||
|
const SESSION_TTL_SECS: u64 = 3600 * 4; // 4 hours
|
||||||
|
|
||||||
|
/// Bookkeeping for one session → cascade mapping.
#[derive(Clone)]
struct Session {
    /// Cascade this session sends its messages to.
    cascade_id: String,
    /// When the session entry was created.
    created: Instant,
    /// Last time the session was resolved; drives TTL expiry.
    last_used: Instant,
    /// Number of times the session has been resolved.
    msg_count: u64,
}
|
||||||
|
|
||||||
|
pub struct SessionManager {
|
||||||
|
sessions: RwLock<HashMap<String, Session>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Outcome of resolving a session via `SessionManager::get_or_create`.
pub struct SessionResult {
    /// Cascade the caller should send its message to.
    pub cascade_id: String,
}
|
||||||
|
|
||||||
|
impl SessionManager {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sessions: RwLock::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get existing cascade for session, or create a new one.
|
||||||
|
///
|
||||||
|
/// - `session_id = None` → use default session
|
||||||
|
/// - `session_id = Some("new")` → always create fresh cascade
|
||||||
|
/// - `session_id = Some("my-task")` → reuse cascade for that task
|
||||||
|
///
|
||||||
|
/// Uses double-check locking to avoid TOCTOU races: after creating a cascade,
|
||||||
|
/// re-acquires the lock and checks if another request raced us.
|
||||||
|
pub async fn get_or_create<F, Fut>(
|
||||||
|
&self,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
create_fn: F,
|
||||||
|
) -> Result<SessionResult, String>
|
||||||
|
where
|
||||||
|
F: FnOnce() -> Fut,
|
||||||
|
Fut: std::future::Future<Output = Result<String, String>>,
|
||||||
|
{
|
||||||
|
// "new" always creates a fresh cascade
|
||||||
|
if session_id == Some("new") {
|
||||||
|
let cascade_id = create_fn().await?;
|
||||||
|
let new_sid = format!("s-{}", &uuid::Uuid::new_v4().to_string()[..8]);
|
||||||
|
let mut sessions = self.sessions.write().await;
|
||||||
|
sessions.insert(
|
||||||
|
new_sid.clone(),
|
||||||
|
Session {
|
||||||
|
cascade_id: cascade_id.clone(),
|
||||||
|
created: Instant::now(),
|
||||||
|
last_used: Instant::now(),
|
||||||
|
msg_count: 0,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
return Ok(SessionResult {
|
||||||
|
cascade_id,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let sid = session_id.unwrap_or(DEFAULT_SESSION).to_string();
|
||||||
|
|
||||||
|
// Check existing — only need write lock for cleanup + mutation
|
||||||
|
{
|
||||||
|
let mut sessions = self.sessions.write().await;
|
||||||
|
cleanup_expired(&mut sessions);
|
||||||
|
if let Some(sess) = sessions.get_mut(&sid) {
|
||||||
|
sess.last_used = Instant::now();
|
||||||
|
sess.msg_count += 1;
|
||||||
|
return Ok(SessionResult {
|
||||||
|
cascade_id: sess.cascade_id.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Lock released before async create_fn
|
||||||
|
|
||||||
|
// Create new cascade (this may take a while — lock is NOT held)
|
||||||
|
let cascade_id = create_fn().await?;
|
||||||
|
|
||||||
|
// Double-check: another request may have raced us and created the same session
|
||||||
|
{
|
||||||
|
let mut sessions = self.sessions.write().await;
|
||||||
|
if let Some(existing) = sessions.get_mut(&sid) {
|
||||||
|
// Another request won the race — use their cascade, discard ours
|
||||||
|
existing.last_used = Instant::now();
|
||||||
|
existing.msg_count += 1;
|
||||||
|
return Ok(SessionResult {
|
||||||
|
cascade_id: existing.cascade_id.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
sessions.insert(
|
||||||
|
sid.clone(),
|
||||||
|
Session {
|
||||||
|
cascade_id: cascade_id.clone(),
|
||||||
|
created: Instant::now(),
|
||||||
|
last_used: Instant::now(),
|
||||||
|
msg_count: 1,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(SessionResult {
|
||||||
|
cascade_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all active sessions.
|
||||||
|
pub async fn list_sessions(&self) -> serde_json::Value {
|
||||||
|
let mut sessions = self.sessions.write().await;
|
||||||
|
cleanup_expired(&mut sessions);
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
let mut map = serde_json::Map::new();
|
||||||
|
for (sid, sess) in sessions.iter() {
|
||||||
|
map.insert(
|
||||||
|
sid.clone(),
|
||||||
|
serde_json::json!({
|
||||||
|
"cascade_id": sess.cascade_id,
|
||||||
|
"msg_count": sess.msg_count,
|
||||||
|
"age_seconds": now.duration_since(sess.created).as_secs(),
|
||||||
|
"idle_seconds": now.duration_since(sess.last_used).as_secs(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
serde_json::Value::Object(map)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete a session. Returns true if it existed.
|
||||||
|
pub async fn delete_session(&self, session_id: &str) -> bool {
|
||||||
|
let mut sessions = self.sessions.write().await;
|
||||||
|
sessions.remove(session_id).is_some()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cleanup_expired(sessions: &mut HashMap<String, Session>) {
|
||||||
|
let now = Instant::now();
|
||||||
|
sessions.retain(|_, s| {
|
||||||
|
now.duration_since(s.last_used).as_secs() < SESSION_TTL_SECS
|
||||||
|
});
|
||||||
|
}
|
||||||
69
src/warmup.rs
Normal file
69
src/warmup.rs
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
//! Startup warmup and periodic heartbeat — mimics real webview lifecycle.
|
||||||
|
//!
|
||||||
|
//! The real Electron webview calls these methods on startup and then sends
|
||||||
|
//! Heartbeat every ~30 seconds. Without this, the LS sees a "user" that
|
||||||
|
//! never initializes and never heartbeats — an obvious bot fingerprint.
|
||||||
|
|
||||||
|
use crate::backend::Backend;
|
||||||
|
use rand::Rng;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
|
/// Run the exact startup sequence the real webview performs on load.
|
||||||
|
///
|
||||||
|
/// Called BEFORE accepting any API requests. Each call is fire-and-forget
|
||||||
|
/// (we don't care if some fail — the LS might not support all methods).
|
||||||
|
pub async fn warmup_sequence(backend: &Backend) {
|
||||||
|
info!("Running webview warmup sequence...");
|
||||||
|
|
||||||
|
let calls: &[(&str, serde_json::Value)] = &[
|
||||||
|
("GetStatus", serde_json::json!({})),
|
||||||
|
("Heartbeat", serde_json::json!({})),
|
||||||
|
("GetUserStatus", serde_json::json!({})),
|
||||||
|
("GetCascadeModelConfigs", serde_json::json!({})),
|
||||||
|
("GetCascadeModelConfigData", serde_json::json!({})),
|
||||||
|
("GetWorkspaceInfos", serde_json::json!({})),
|
||||||
|
("GetWorkingDirectories", serde_json::json!({})),
|
||||||
|
("GetAllCascadeTrajectories", serde_json::json!({})),
|
||||||
|
("GetMcpServerStates", serde_json::json!({})),
|
||||||
|
("GetWebDocsOptions", serde_json::json!({})),
|
||||||
|
("GetRepoInfos", serde_json::json!({})),
|
||||||
|
("GetAllSkills", serde_json::json!({})),
|
||||||
|
("InitializeCascadePanelState", serde_json::json!({})),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (method, body) in calls {
|
||||||
|
match backend.call_json(method, body).await {
|
||||||
|
Ok((status, _)) => debug!("Warmup {method}: {status}"),
|
||||||
|
Err(e) => warn!("Warmup {method} failed: {e}"),
|
||||||
|
}
|
||||||
|
// Small delay between calls — real webview doesn't blast them instantly
|
||||||
|
let delay = rand::thread_rng().gen_range(50..200);
|
||||||
|
tokio::time::sleep(Duration::from_millis(delay)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Warmup complete");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spawn a background task that sends Heartbeat every ~30s ± jitter.
|
||||||
|
///
|
||||||
|
/// Returns a JoinHandle that runs until the task is aborted (on shutdown).
|
||||||
|
pub fn start_heartbeat(backend: Arc<Backend>) -> JoinHandle<()> {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
// ~30s interval (± 500ms) — matches real setInterval(30000) precision
|
||||||
|
let interval_ms = rand::thread_rng().gen_range(29_500..30_500);
|
||||||
|
tokio::time::sleep(Duration::from_millis(interval_ms)).await;
|
||||||
|
|
||||||
|
match backend
|
||||||
|
.call_json("Heartbeat", &serde_json::json!({}))
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok((status, _)) => debug!("Heartbeat: {status}"),
|
||||||
|
Err(e) => warn!("Heartbeat failed: {e}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user