Compare commits

...

2 Commits

Author SHA1 Message Date
89bb29e563 v0.2.0 — remove vendored fork, upstream login feature
Some checks failed
Build and Push Docker Image / build (push) Failing after 39s
- Delete vendor/schwab-scraper/ (now fetched at CI build time)
- Delete schwab_mcp_custom/ package (LoginManager moved into server.py)
- server.py: add inline LoginManager with env-configurable rate limits
- server.py: orchestrate login safety checks at MCP layer, not in scraper
- Dockerfile: restore vendor-based build with fresh upstream checkout
- pyproject.toml: bump mcp>=1.27.0, playwright>=1.54.0
2026-04-28 00:36:46 +00:00
2de3b709d8 feat: expose automated login and session refresh with safety status tool 2026-04-27 19:47:55 +00:00
5 changed files with 173 additions and 34 deletions

View File

@@ -4,11 +4,12 @@ ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy
WORKDIR /app WORKDIR /app
COPY pyproject.toml uv.lock ./ # Copy vendored schwab-scraper (checked out cleanly by CI) and pyproject.toml
COPY vendor/schwab-scraper /tmp/schwab-scraper COPY vendor/schwab-scraper /tmp/schwab-scraper
COPY pyproject.toml uv.lock ./
# Install schwab-scraper from vendored source, then all other deps. # Install schwab-scraper from the clean build-time checkout, then remaining deps.
# We strip the git dependency from pyproject.toml so uv doesn't try to fetch it. # We strip the git dependency line so uv doesn't try to fetch over the network.
RUN uv venv && \ RUN uv venv && \
uv pip install /tmp/schwab-scraper && \ uv pip install /tmp/schwab-scraper && \
sed -i '/schwab-scraper/d' pyproject.toml && \ sed -i '/schwab-scraper/d' pyproject.toml && \
@@ -20,7 +21,9 @@ COPY . .
FROM python:3.12-slim-bookworm FROM python:3.12-slim-bookworm
RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/* RUN apt-get update && \
apt-get install -y --no-install-recommends curl && \
rm -rf /var/lib/apt/lists/*
WORKDIR /app WORKDIR /app
COPY --from=builder /app /app COPY --from=builder /app /app

View File

@@ -1,12 +1,12 @@
[project] [project]
name = "schwab-mcp-custom" name = "schwab-mcp-custom"
version = "0.1.0" version = "0.2.0"
description = "Hybrid MCP Light server for Schwab scraper" description = "MCP server wrapping schwab-scraper"
readme = "README.md" readme = "README.md"
requires-python = ">=3.12" requires-python = ">=3.12"
dependencies = [ dependencies = [
"schwab-scraper @ git+ssh://gitea@git.local.ben.io/b3nw/schwab-scraper.git", "schwab-scraper @ git+ssh://gitea@git.local.ben.io/b3nw/schwab-scraper.git",
"mcp>=1.2.0", "mcp>=1.27.0",
"fastmcp>=0.4.1", "fastmcp>=0.4.1",
"starlette>=0.41.0", "starlette>=0.41.0",
"uvicorn>=0.32.0", "uvicorn>=0.32.0",

188
server.py
View File

@@ -1,7 +1,8 @@
import json import json
import logging import logging
import os import os
from typing import Optional, Any import time
from typing import Optional, Any, Tuple
from fastmcp import FastMCP from fastmcp import FastMCP
from starlette.applications import Starlette from starlette.applications import Starlette
@@ -9,7 +10,6 @@ from starlette.responses import JSONResponse
from starlette.routing import Route, Mount from starlette.routing import Route, Mount
import uvicorn import uvicorn
# Import the unified API from the schwab_scraper dependency
import schwab_scraper.unified_api as api import schwab_scraper.unified_api as api
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -46,48 +46,169 @@ def _patch_request_responder():
_patch_request_responder() _patch_request_responder()
# Initialize FastMCP
# ---------------------------------------------------------------------------
# Login safety manager — lives in the MCP server layer, not the scraper.
# Provides rate-limiting and backoff for automated login attempts.
# ---------------------------------------------------------------------------
class LoginManager:
"""Tracks login attempts and enforces safety limits to avoid account lockouts."""
def __init__(self):
self.max_attempts = int(os.getenv("SCHWAB_LOGIN_MAX_ATTEMPTS", "3"))
self.window_minutes = int(os.getenv("SCHWAB_LOGIN_WINDOW_MIN", "60"))
self.backoff_minutes = int(os.getenv("SCHWAB_LOGIN_BACKOFF_MIN", "30"))
self._attempts: list[tuple[float, bool]] = []
def _trim_window(self) -> None:
cutoff = time.time() - (self.window_minutes * 60)
self._attempts = [(ts, success) for ts, success in self._attempts if ts > cutoff]
def can_login(self) -> Tuple[bool, str]:
"""Return (allowed: bool, reason: str)."""
self._trim_window()
failure_count = sum(1 for _, success in self._attempts if not success)
if failure_count >= self.max_attempts:
# Compute remaining backoff from most recent failure
last_failure_ts = max(ts for ts, success in self._attempts if not success)
elapsed = time.time() - last_failure_ts
remaining = (self.backoff_minutes * 60) - elapsed
if remaining > 0:
return (
False,
f"Login blocked: {failure_count} failures in window. "
f"Wait {int(remaining / 60)}m {int(remaining % 60)}s.",
)
recent_count = len(self._attempts)
return True, f"Allowed ({recent_count} attempts in last {self.window_minutes}m)"
def record_attempt(self, success: bool) -> None:
self._trim_window()
self._attempts.append((time.time(), success))
def get_status(self) -> dict:
self._trim_window()
failure_count = sum(1 for _, success in self._attempts if not success)
recent_count = len(self._attempts)
if failure_count >= self.max_attempts:
last_failure_ts = max(ts for ts, success in self._attempts if not success)
elapsed = time.time() - last_failure_ts
remaining = (self.backoff_minutes * 60) - elapsed
blocked = remaining > 0
else:
remaining = 0
blocked = False
return {
"blocked": blocked,
"remaining_backoff_seconds": max(0, int(remaining)),
"recent_attempts": recent_count,
"recent_failures": failure_count,
"max_attempts_per_window": self.max_attempts,
"window_minutes": self.window_minutes,
"backoff_minutes": self.backoff_minutes,
}
login_manager = LoginManager()
mcp = FastMCP("SchwabScraper") mcp = FastMCP("SchwabScraper")
def serialize(obj: Any) -> str: def serialize(obj: Any) -> str:
"""Safely serialize Pydantic models or datclasses to JSON string.""" """Safely serialize Pydantic models or dataclasses to JSON string."""
if hasattr(obj, "model_dump_json"): if hasattr(obj, "model_dump_json"):
return obj.model_dump_json() return obj.model_dump_json()
elif hasattr(obj, "model_dump"): elif hasattr(obj, "model_dump"):
return json.dumps(obj.model_dump(), default=str) return json.dumps(obj.model_dump(), default=str)
elif isinstance(obj, list): elif isinstance(obj, list):
# Handle lists of models
return json.dumps([ return json.dumps([
o.model_dump() if hasattr(o, "model_dump") else o o.model_dump() if hasattr(o, "model_dump") else o
for o in obj for o in obj
], default=str) ], default=str)
return json.dumps(obj, default=str) return json.dumps(obj, default=str)
# ---------------------------------------------------------------------------
# MCP tools
# ---------------------------------------------------------------------------
@mcp.tool() @mcp.tool()
async def get_session_status(debug: bool = False) -> str: async def get_session_status(debug: bool = False) -> str:
"""Get the current session status of the Schwab scraper. """Get the current session status of the Schwab scraper.
Args: Args:
debug: Enable debug logging debug: Enable debug logging
""" """
result = await api.get_session_status(debug=debug) result = await api.get_session_status(debug=debug)
# Enrich with login safety status
if result.get("success"):
data = result.get("data", {})
data["login_safety"] = login_manager.get_status()
return serialize(result) return serialize(result)
@mcp.tool()
async def get_login_safety_status() -> str:
"""Get the current login safety status, including any active backoffs or limits.
Useful to check if a login attempt is likely to be blocked.
"""
return json.dumps(login_manager.get_status())
@mcp.tool()
async def login(
username: Optional[str] = None, password: Optional[str] = None, debug: bool = False
) -> str:
"""Perform an automated login to Schwab to establish a new session.
Args:
username: Schwab username (optional, will use env/config if omitted)
password: Schwab password (optional, will use env/config if omitted)
debug: Enable debug logging
"""
allowed, reason = login_manager.can_login()
if not allowed:
return json.dumps({
"success": False,
"error": f"Login blocked by safety safeguards: {reason}",
"error_type": "AUTHENTICATION",
"retryable": False,
"data": None,
})
result = await api.login(username=username, password=password, debug=debug)
success = result.get("success", False)
login_manager.record_attempt(success)
return serialize(result)
@mcp.tool()
async def refresh_session(debug: bool = False) -> str:
"""Refresh the current Schwab session to prevent expiration.
Args:
debug: Enable debug logging
"""
result = await api.refresh_session(debug=debug)
return serialize(result)
@mcp.tool() @mcp.tool()
async def list_accounts(debug: bool = False) -> str: async def list_accounts(debug: bool = False) -> str:
"""List all Schwab accounts. """List all Schwab accounts.
Args: Args:
debug: Enable debug logging debug: Enable debug logging
""" """
result = await api.list_accounts(debug=debug) result = await api.list_accounts(debug=debug)
return serialize(result) return serialize(result)
@mcp.tool() @mcp.tool()
async def get_account_overview(account: Optional[str] = None, debug: bool = False) -> str: async def get_account_overview(account: Optional[str] = None, debug: bool = False) -> str:
"""Get the overview for a specific account. """Get the overview for a specific account.
Args: Args:
account: Account summary or ID (optional) account: Account summary or ID (optional)
debug: Enable debug logging debug: Enable debug logging
@@ -95,28 +216,36 @@ async def get_account_overview(account: Optional[str] = None, debug: bool = Fals
result = await api.get_account_overview(account=account, debug=debug) result = await api.get_account_overview(account=account, debug=debug)
return serialize(result) return serialize(result)
@mcp.tool() @mcp.tool()
async def get_positions(account: Optional[str] = None, include_non_equity: bool = False, debug: bool = False) -> str: async def get_positions(
account: Optional[str] = None,
include_non_equity: bool = False,
debug: bool = False,
) -> str:
"""Get positions for a specific account. """Get positions for a specific account.
Args: Args:
account: Account summary or ID (optional) account: Account summary or ID (optional)
include_non_equity: Whether to include non-equity positions include_non_equity: Whether to include non-equity positions
debug: Enable debug logging debug: Enable debug logging
""" """
result = await api.get_positions(account=account, include_non_equity=include_non_equity, debug=debug) result = await api.get_positions(
account=account, include_non_equity=include_non_equity, debug=debug
)
return serialize(result) return serialize(result)
@mcp.tool() @mcp.tool()
async def get_transactions( async def get_transactions(
account: Optional[str] = None, account: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
time_period: Optional[str] = None, time_period: Optional[str] = None,
debug: bool = False debug: bool = False,
) -> str: ) -> str:
"""Get transaction history. """Get transaction history.
Args: Args:
account: Account ID (optional) account: Account ID (optional)
start_date: Start date for transactions (optional) start_date: Start date for transactions (optional)
@@ -129,14 +258,15 @@ async def get_transactions(
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
time_period=time_period, time_period=time_period,
debug=debug debug=debug,
) )
return serialize(result) return serialize(result)
@mcp.tool() @mcp.tool()
async def get_morningstar_data(ticker: str, debug: bool = False) -> str: async def get_morningstar_data(ticker: str, debug: bool = False) -> str:
"""Get Morningstar data for a ticker. """Get Morningstar data for a ticker.
Args: Args:
ticker: Stock ticker symbol ticker: Stock ticker symbol
debug: Enable debug logging debug: Enable debug logging
@@ -144,29 +274,29 @@ async def get_morningstar_data(ticker: str, debug: bool = False) -> str:
result = await api.get_morningstar_data(ticker, debug=debug) result = await api.get_morningstar_data(ticker, debug=debug)
return serialize(result) return serialize(result)
@mcp.tool() @mcp.tool()
async def upload_cookies(cookies_json: str) -> str: async def upload_cookies(cookies_json: str) -> str:
"""Upload session cookies to the server to assist with authentication. """Upload session cookies to the server to assist with authentication.
Args: Args:
cookies_json: JSON string of cookies exported from a browser (Playwright format) cookies_json: JSON string of cookies exported from a browser (Playwright format)
""" """
try: try:
# Validate JSON
cookies = json.loads(cookies_json) cookies = json.loads(cookies_json)
# Write to cookies.json
with open("cookies.json", "w") as f: with open("cookies.json", "w") as f:
json.dump(cookies, f) json.dump(cookies, f)
return json.dumps({"status": "success", "message": "cookies.json updated successfully"}) return json.dumps({"status": "success", "message": "cookies.json updated successfully"})
except Exception as e: except Exception as e:
return json.dumps({"status": "error", "message": str(e)}) return json.dumps({"status": "error", "message": str(e)})
@mcp.tool() @mcp.tool()
async def api_call(endpoint: str, method: str = "GET", params: str = "{}") -> str: async def api_call(endpoint: str, method: str = "GET", params: str = "{}") -> str:
"""Executes a raw API call to the Schwab service (Dummy implementation). """Executes a raw API call to the Schwab service (placeholder).
Refer to the 'api-reference' resource for available endpoints and parameters. Refer to the 'api-reference' resource for available endpoints and parameters.
Args: Args:
endpoint: The API path endpoint: The API path
method: HTTP method (GET, POST, etc.) method: HTTP method (GET, POST, etc.)
@@ -174,23 +304,29 @@ async def api_call(endpoint: str, method: str = "GET", params: str = "{}") -> st
""" """
return json.dumps({"status": "not_implemented", "message": "API pass-through not supported for scraper"}) return json.dumps({"status": "not_implemented", "message": "API pass-through not supported for scraper"})
@mcp.resource("service://api-reference") @mcp.resource("service://api-reference")
def get_api_docs() -> str: def get_api_docs() -> str:
"""Returns the API documentation for using the 'api_call' tool.""" """Returns the API documentation for using the 'api_call' tool."""
return "Schwab Scraper MCP Server - Unified API Documentation\n\nThis server provides tools to interact with Schwab accounts via scraping. The 'api_call' tool is a placeholder." return (
"Schwab Scraper MCP Server — Unified API Documentation\n\n"
"This server provides tools to interact with Schwab accounts via scraping.\n"
"The 'api_call' tool is a placeholder."
)
async def health(request): async def health(request):
"""Health check endpoint.""" """Health check endpoint."""
return JSONResponse({"status": "ok"}) return JSONResponse({"status": "ok"})
# Create the Starlette application
mcp_app = mcp.http_app() mcp_app = mcp.http_app()
app = Starlette( app = Starlette(
routes=[ routes=[
Route("/health", health), Route("/health", health),
Mount("/", app=mcp_app) Mount("/", app=mcp_app),
], ],
lifespan=mcp_app.lifespan lifespan=mcp_app.lifespan,
) )
if __name__ == "__main__": if __name__ == "__main__":

2
uv.lock generated
View File

@@ -1733,7 +1733,7 @@ requires-dist = [
{ name = "greenlet", specifier = ">=3.2.3" }, { name = "greenlet", specifier = ">=3.2.3" },
{ name = "mcp", specifier = ">=1.2.0" }, { name = "mcp", specifier = ">=1.2.0" },
{ name = "pdfplumber", specifier = ">=0.11.4" }, { name = "pdfplumber", specifier = ">=0.11.4" },
{ name = "playwright", specifier = "==1.54.0" }, { name = "playwright", specifier = ">=1.54.0" },
{ name = "pyee", specifier = ">=13.0.0" }, { name = "pyee", specifier = ">=13.0.0" },
{ name = "schwab-scraper", git = "ssh://git.local.ben.io/b3nw/schwab-scraper.git" }, { name = "schwab-scraper", git = "ssh://git.local.ben.io/b3nw/schwab-scraper.git" },
{ name = "starlette", specifier = ">=0.41.0" }, { name = "starlette", specifier = ">=0.41.0" },