Files
schwab-mcp-custom/server.py
b3nw 1999392df7
All checks were successful
Build and Push Docker Image / build (push) Successful in 38s
fix(mcp): capture scraper logs and return them in tool responses
Scraper debug output goes to stderr which is invisible in MCP stdio mode.
Add capture_logs context manager that attaches a StringIO handler to the
schwab_scraper logger during tool execution, then includes captured logs
in the response envelope when debug=True or on failure.

Applied to login() and refresh_session() which are the critical paths
for authentication diagnostics.
2026-04-28 02:04:58 +00:00

380 lines
12 KiB
Python

import io
import json
import logging
import os
import time
from contextlib import contextmanager
from typing import Optional, Any, Tuple
from fastmcp import FastMCP
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route, Mount
import uvicorn
import schwab_scraper.unified_api as api
# ---------------------------------------------------------------------------
# Log capture helper — scraper logs go to stderr which is invisible in MCP
# stdio mode. This captures them so we can return them in the response.
# ---------------------------------------------------------------------------
@contextmanager
def capture_logs(logger_name: str = "schwab_scraper", level: int = logging.DEBUG):
    """Capture records emitted on ``logger_name`` into an in-memory buffer.

    A temporary ``StreamHandler`` writing to a ``StringIO`` is attached to the
    named logger for the duration of the ``with`` block. On exit the handler
    is detached and the logger's own level is restored.

    Args:
        logger_name: Name of the logger to capture from.
        level: Lowest severity to capture.

    Yields:
        The ``io.StringIO`` buffer; call ``getvalue()`` on it to read the
        formatted records (it stays readable after the block exits).
    """
    logger = logging.getLogger(logger_name)
    old_level = logger.level
    # A logger left at NOTSET (0) inherits its threshold from its ancestors
    # (WARNING on an unconfigured root). Comparing the raw ``logger.level``
    # would treat NOTSET as "already verbose enough" and records below the
    # inherited threshold would never reach the handler, so compare the
    # effective level instead.
    if logger.getEffectiveLevel() > level:
        logger.setLevel(level)
    buf = io.StringIO()
    handler = logging.StreamHandler(buf)
    handler.setLevel(level)
    # Match the format used by the scraper
    handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger.addHandler(handler)
    try:
        yield buf
    finally:
        logger.removeHandler(handler)
        # Restore the logger's own level (possibly NOTSET), not the effective one.
        logger.setLevel(old_level)
        # Flushes and unregisters the handler; the StringIO itself stays open.
        handler.close()
def _enrich_with_logs(result: dict, log_buffer: io.StringIO, debug: bool) -> dict:
    """Add captured log text to ``result`` under ``"logs"`` when it is useful.

    The text is attached only if something was captured and either ``debug``
    is set or ``result`` does not report success. ``result`` is modified in
    place and also returned.
    """
    captured = log_buffer.getvalue()
    if not captured:
        # Nothing was logged, so there is nothing to attach.
        return result
    if debug or not result.get("success", False):
        result["logs"] = captured
    return result
# ---------------------------------------------------------------------------
# Monkey-patch mcp.shared.session.RequestResponder to work around a
# cancellation race in mcp==1.27.0 (github.com/modelcontextprotocol/
# python-sdk/issues/2416). A concurrent notifications/cancelled can set
# _completed=True between handler return and respond(), crashing the session
# with "AssertionError: Request already responded to".
# Remove once upstream ships a fix (likely mcp>=1.28).
# ---------------------------------------------------------------------------
def _patch_request_responder():
    """Wrap ``RequestResponder.respond`` and ``.cancel`` with completion guards.

    Each wrapper returns without doing anything when ``self._completed`` is
    already set and otherwise delegates to the method it replaced.
    """
    from mcp.shared.session import RequestResponder

    # Keep references to the unpatched methods before swapping them out.
    original_respond = RequestResponder.respond
    original_cancel = RequestResponder.cancel

    async def _safe_respond(self, response):
        if not self._completed:
            return await original_respond(self, response)
        logging.debug(
            "respond() skipped for request %s — already completed (race with cancel)",
            self.request_id,
        )
        return None

    async def _safe_cancel(self):
        if not self._completed:
            return await original_cancel(self)
        return None

    RequestResponder.respond = _safe_respond
    RequestResponder.cancel = _safe_cancel


# Apply the patch once, at import time.
_patch_request_responder()
# ---------------------------------------------------------------------------
# Login safety manager — lives in the MCP server layer, not the scraper.
# Provides rate-limiting and backoff for automated login attempts.
# ---------------------------------------------------------------------------
class LoginManager:
    """Tracks login attempts and enforces safety limits to avoid account lockouts.

    Attempts are kept in memory as ``(timestamp, success)`` pairs and dropped
    once they are older than ``window_minutes``. When the number of failed
    attempts inside the window reaches ``max_attempts``, logins are refused
    until ``backoff_minutes`` have passed since the most recent failure.

    Limits are read from the environment when the instance is created:
    ``SCHWAB_LOGIN_MAX_ATTEMPTS`` (default 3), ``SCHWAB_LOGIN_WINDOW_MIN``
    (default 60) and ``SCHWAB_LOGIN_BACKOFF_MIN`` (default 30).
    """

    def __init__(self):
        self.max_attempts = int(os.getenv("SCHWAB_LOGIN_MAX_ATTEMPTS", "3"))
        self.window_minutes = int(os.getenv("SCHWAB_LOGIN_WINDOW_MIN", "60"))
        self.backoff_minutes = int(os.getenv("SCHWAB_LOGIN_BACKOFF_MIN", "30"))
        # (unix timestamp, success flag) for each attempt still inside the window.
        self._attempts: list[tuple[float, bool]] = []

    def _trim_window(self) -> None:
        """Drop attempts that are older than the tracking window."""
        cutoff = time.time() - (self.window_minutes * 60)
        self._attempts = [(ts, success) for ts, success in self._attempts if ts > cutoff]

    def _backoff_state(self) -> Tuple[int, float]:
        """Return ``(failure_count, remaining_backoff_seconds)`` for the window.

        ``remaining_backoff_seconds`` is 0.0 when logins are not blocked.
        """
        self._trim_window()
        failure_times = [ts for ts, success in self._attempts if not success]
        failure_count = len(failure_times)
        # With no failure on record there is nothing to back off from. This
        # also keeps max() away from an empty sequence when max_attempts <= 0.
        if not failure_times or failure_count < self.max_attempts:
            return failure_count, 0.0
        elapsed = time.time() - max(failure_times)
        return failure_count, max(0.0, (self.backoff_minutes * 60) - elapsed)

    def can_login(self) -> Tuple[bool, str]:
        """Return (allowed: bool, reason: str)."""
        failure_count, remaining = self._backoff_state()
        if remaining > 0:
            minutes, seconds = divmod(int(remaining), 60)
            return (
                False,
                f"Login blocked: {failure_count} failures in window. "
                f"Wait {minutes}m {seconds}s.",
            )
        recent_count = len(self._attempts)
        return True, f"Allowed ({recent_count} attempts in last {self.window_minutes}m)"

    def record_attempt(self, success: bool) -> None:
        """Record the outcome of a login attempt at the current time."""
        self._trim_window()
        self._attempts.append((time.time(), success))

    def get_status(self) -> dict:
        """Return a JSON-friendly snapshot of the limiter state and its limits."""
        failure_count, remaining = self._backoff_state()
        return {
            "blocked": remaining > 0,
            "remaining_backoff_seconds": int(remaining),
            "recent_attempts": len(self._attempts),
            "recent_failures": failure_count,
            "max_attempts_per_window": self.max_attempts,
            "window_minutes": self.window_minutes,
            "backoff_minutes": self.backoff_minutes,
        }
# Single module-level limiter instance used by the tool functions in this file.
login_manager = LoginManager()
# FastMCP server object; the @mcp.tool() / @mcp.resource() decorators in this file are called on it.
mcp = FastMCP("SchwabScraper")
def serialize(obj: Any) -> str:
    """Safely serialize Pydantic models or dataclasses to JSON string.

    Objects exposing ``model_dump_json`` are serialized by that method.
    Everything else goes through ``json.dumps`` with a fallback hook that is
    applied at any nesting depth:

    * objects with ``model_dump()`` are replaced by its result,
    * dataclass instances are converted with ``dataclasses.asdict``,
    * anything else is converted with ``str``.

    Args:
        obj: Value to serialize (model, dataclass, list, dict, scalar, ...).

    Returns:
        The JSON text.
    """
    import dataclasses  # local import keeps this block self-contained

    def _fallback(value: Any) -> Any:
        # Called by json.dumps only for values it cannot encode natively.
        if hasattr(value, "model_dump"):
            return value.model_dump()
        # is_dataclass() is also true for dataclass *types*; asdict needs an instance.
        if dataclasses.is_dataclass(value) and not isinstance(value, type):
            return dataclasses.asdict(value)
        return str(value)

    if hasattr(obj, "model_dump_json"):
        return obj.model_dump_json()
    return json.dumps(obj, default=_fallback)
# ---------------------------------------------------------------------------
# MCP tools
# ---------------------------------------------------------------------------
@mcp.tool()
async def get_session_status(debug: bool = False) -> str:
    """Get the current session status of the Schwab scraper.

    Args:
        debug: Enable debug logging
    """
    result = await api.get_session_status(debug=debug)
    # Enrich with login safety status
    if result.get("success"):
        data = result.get("data")
        if data is None:
            # Missing or null payload: attach a fresh dict to the result so
            # the safety status actually appears in the response (a default
            # from .get() would be a throwaway, and None is not subscriptable).
            data = result["data"] = {}
        data["login_safety"] = login_manager.get_status()
    return serialize(result)
@mcp.tool()
async def get_login_safety_status() -> str:
    """Get the current login safety status, including any active backoffs or limits.

    Useful to check if a login attempt is likely to be blocked.
    """
    # Snapshot of the shared limiter, returned as JSON text.
    safety_snapshot = login_manager.get_status()
    return json.dumps(safety_snapshot)
@mcp.tool()
async def login(
    username: Optional[str] = None, password: Optional[str] = None, debug: bool = False
) -> str:
    """Perform an automated login to Schwab to establish a new session.

    Args:
        username: Schwab username (optional, will use env/config if omitted)
        password: Schwab password (optional, will use env/config if omitted)
        debug: Enable debug logging
    """
    # Refuse early when the limiter is in backoff; no attempt is recorded for a refusal.
    allowed, reason = login_manager.can_login()
    if not allowed:
        return json.dumps({
            "success": False,
            "error": f"Login blocked by safety safeguards: {reason}",
            "error_type": "AUTHENTICATION",
            "retryable": False,
            "data": None,
        })
    with capture_logs(level=logging.DEBUG if debug else logging.INFO) as log_buf:
        try:
            result = await api.login(username=username, password=password, debug=debug)
        except Exception:
            # An attempt that raises must still count against the limit,
            # otherwise crashing logins bypass the lockout protection.
            login_manager.record_attempt(False)
            raise
        success = result.get("success", False)
        login_manager.record_attempt(success)
        # Attach captured scraper logs when debugging or when the login failed.
        result = _enrich_with_logs(result, log_buf, debug)
    return serialize(result)
@mcp.tool()
async def refresh_session(debug: bool = False) -> str:
    """Refresh the current Schwab session to prevent expiration.

    Args:
        debug: Enable debug logging
    """
    # Capture DEBUG output only when asked for; INFO and above otherwise.
    capture_level = logging.DEBUG if debug else logging.INFO
    with capture_logs(level=capture_level) as captured:
        outcome = await api.refresh_session(debug=debug)
        # Logs are attached when debugging or when the refresh did not succeed.
        outcome = _enrich_with_logs(outcome, captured, debug)
    return serialize(outcome)
@mcp.tool()
async def list_accounts(debug: bool = False) -> str:
    """List all Schwab accounts.

    Args:
        debug: Enable debug logging
    """
    # Forward to the scraper API and return its result as JSON text.
    return serialize(await api.list_accounts(debug=debug))
@mcp.tool()
async def get_account_overview(account: Optional[str] = None, debug: bool = False) -> str:
    """Get the overview for a specific account.

    Args:
        account: Account summary or ID (optional)
        debug: Enable debug logging
    """
    # Forward to the scraper API and return its result as JSON text.
    overview = await api.get_account_overview(debug=debug, account=account)
    return serialize(overview)
@mcp.tool()
async def get_positions(
    account: Optional[str] = None,
    include_non_equity: bool = False,
    debug: bool = False,
) -> str:
    """Get positions for a specific account.

    Args:
        account: Account summary or ID (optional)
        include_non_equity: Whether to include non-equity positions
        debug: Enable debug logging
    """
    # All tool arguments are passed through unchanged as keyword arguments.
    query = {
        "account": account,
        "include_non_equity": include_non_equity,
        "debug": debug,
    }
    return serialize(await api.get_positions(**query))
@mcp.tool()
async def get_transactions(
    account: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    time_period: Optional[str] = None,
    debug: bool = False,
) -> str:
    """Get transaction history.

    Args:
        account: Account ID (optional)
        start_date: Start date for transactions (optional)
        end_date: End date for transactions (optional)
        time_period: Time period (e.g., '1D', '1M') (optional)
        debug: Enable debug logging
    """
    # The tool is named get_transactions; the scraper call is get_transaction_history.
    filters = {
        "account": account,
        "start_date": start_date,
        "end_date": end_date,
        "time_period": time_period,
        "debug": debug,
    }
    history = await api.get_transaction_history(**filters)
    return serialize(history)
@mcp.tool()
async def get_morningstar_data(ticker: str, debug: bool = False) -> str:
    """Get Morningstar data for a ticker.

    Args:
        ticker: Stock ticker symbol
        debug: Enable debug logging
    """
    # ticker is passed positionally, debug by keyword, as the scraper call expects here.
    return serialize(await api.get_morningstar_data(ticker, debug=debug))
@mcp.tool()
async def upload_cookies(cookies_json: str) -> str:
    """Upload session cookies to the server to assist with authentication.

    Args:
        cookies_json: JSON string of cookies exported from a browser (Playwright format)
    """
    try:
        cookies = json.loads(cookies_json)
        # Any valid JSON used to be accepted, so "null" or "42" would overwrite
        # the existing cookie file and still report success. Only an array or
        # an object can hold cookies, so reject everything else before writing.
        if not isinstance(cookies, (list, dict)):
            return json.dumps({
                "status": "error",
                "message": (
                    "cookies_json must decode to a JSON array or object, "
                    f"got {type(cookies).__name__}"
                ),
            })
        # Written in place, relative to the process working directory.
        with open("cookies.json", "w", encoding="utf-8") as f:
            json.dump(cookies, f)
    except Exception as e:
        # Tool boundary: report parse and I/O failures to the caller as an
        # error payload instead of raising.
        return json.dumps({"status": "error", "message": str(e)})
    return json.dumps({"status": "success", "message": "cookies.json updated successfully"})
@mcp.tool()
async def api_call(endpoint: str, method: str = "GET", params: str = "{}") -> str:
    """Executes a raw API call to the Schwab service (placeholder).

    Refer to the 'api-reference' resource for available endpoints and parameters.

    Args:
        endpoint: The API path
        method: HTTP method (GET, POST, etc.)
        params: JSON string of parameters/body
    """
    # The arguments are accepted but unused: the same fixed payload is always returned.
    placeholder = {
        "status": "not_implemented",
        "message": "API pass-through not supported for scraper",
    }
    return json.dumps(placeholder)
@mcp.resource("service://api-reference")
def get_api_docs() -> str:
    """Returns the API documentation for using the 'api_call' tool."""
    # Fixed text, assembled from the same three fragments on every call.
    fragments = (
        "Schwab Scraper MCP Server — Unified API Documentation\n\n",
        "This server provides tools to interact with Schwab accounts via scraping.\n",
        "The 'api_call' tool is a placeholder.",
    )
    return "".join(fragments)
async def health(request):
    """Liveness probe: always answers with the JSON body ``{"status": "ok"}``.

    The incoming ``request`` is not inspected.
    """
    payload = {"status": "ok"}
    return JSONResponse(payload)
# HTTP application object produced by the FastMCP server.
mcp_app = mcp.http_app()
# Outer Starlette application: a plain /health route plus the MCP app mounted at "/".
app = Starlette(
    routes=[
        # Listed ahead of the "/" mount — presumably so it is matched before the
        # catch-all mount (Starlette is expected to try routes in order; confirm).
        Route("/health", health),
        Mount("/", app=mcp_app),
    ],
    # The MCP app's lifespan is reused for the outer app — presumably required
    # so FastMCP's startup/shutdown logic runs; confirm against FastMCP docs.
    lifespan=mcp_app.lifespan,
)
if __name__ == "__main__":
    # Run directly: serve on every interface; PORT overrides the default 8160.
    port = int(os.environ.get("PORT", 8160))
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
    )