import asyncio
import dataclasses
import io
import json
import logging
import os
import sys
import time
from contextlib import contextmanager
from typing import Optional, Any, Tuple, Awaitable, Callable, TypeVar

from fastmcp import FastMCP
from starlette.applications import Starlette
from starlette.responses import JSONResponse, Response
from starlette.routing import Route, Mount
import uvicorn

import schwab_scraper.unified_api as api
from schwab_scraper.storage.cache import read_cached_pdf

# ---------------------------------------------------------------------------
# Configure logging so it actually reaches stderr (visible in docker logs).
# The scraper and MCP libraries log extensively but don't set up handlers
# when imported as a module; without a root handler Python's last-resort
# handler only prints WARNING and above, so INFO/DEBUG messages are dropped.
# ---------------------------------------------------------------------------
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    stream=sys.stderr,
)

# Ensure the scraper logger propagates to our root handler. SCHWAB_DEBUG=1/true
# makes it log at DEBUG for the lifetime of the process.
_scraper_logger = logging.getLogger("schwab_scraper")
_scraper_logger.setLevel(
    logging.DEBUG
    if os.getenv("SCHWAB_DEBUG", "").lower() in ("1", "true")
    else logging.INFO
)
_scraper_logger.propagate = True

_startup_logger = logging.getLogger("schwab_mcp_custom")


def _read_commit_file(path: str) -> str | None:
    """Return the stripped contents of the commit-id file at *path*, or None.

    None is returned when the file is missing, unreadable, not valid UTF-8,
    or empty. This read is purely diagnostic, so it must never abort startup.
    """
    try:
        with open(path, encoding="utf-8") as f:
            return f.read().strip() or None
    except (OSError, UnicodeDecodeError):
        return None


_scraper_commit = _read_commit_file(
    os.path.join(os.path.dirname(__file__), "schwab-scraper-commit.txt")
)
_mcp_commit = _read_commit_file(
    os.path.join(os.path.dirname(__file__), "mcp-server-commit.txt")
)

if _scraper_commit:
    _startup_logger.info("schwab-scraper commit: %s", _scraper_commit)
else:
    _startup_logger.info("schwab-scraper commit: (not available)")

if _mcp_commit:
    _startup_logger.info("mcp-server commit: %s", _mcp_commit)
else:
    _startup_logger.info("mcp-server commit: (not available)")

try:
    from importlib.metadata import version as _pkg_version

    _startup_logger.info("schwab-scraper package version: %s", _pkg_version("schwab-scraper"))
except Exception:
    # Diagnostic only: a missing distribution (PackageNotFoundError) or any
    # other metadata problem must not stop the server from starting.
    _startup_logger.info("schwab-scraper package version: (unknown)")

_DEFAULT_BASE_URL = "https://schwab-mcp.ext.ben.io"

# ---------------------------------------------------------------------------
# Log capture helper — captures scraper logs to a string buffer AND tees
# them to stderr so they remain visible in docker logs.
# ---------------------------------------------------------------------------
class _TeeHandler(logging.StreamHandler):
    """Stream handler that also copies each record into a StringIO buffer.

    Args:
        stream: Stream given to ``logging.StreamHandler`` (stderr in practice).
        extra_buf: Buffer that receives a formatted copy of each record.
        level: Minimum record level copied into *extra_buf*.
        echo: When False, records are NOT written to *stream*; used when
            another handler up the logger hierarchy already writes them out,
            so that each line appears only once.
    """

    def __init__(self, stream, extra_buf: io.StringIO, level=logging.NOTSET, echo: bool = True):
        super().__init__(stream)
        self.extra_buf = extra_buf
        self.tee_level = level
        self.echo = echo

    def emit(self, record):
        if self.echo:
            super().emit(record)
        if record.levelno >= self.tee_level:
            try:
                self.extra_buf.write(self.format(record) + "\n")
                self.extra_buf.flush()
            except Exception:
                # Standard Handler error path: never propagates; reports to
                # stderr only when logging.raiseExceptions is true.
                self.handleError(record)


# Level bookkeeping shared by overlapping capture_logs() contexts. Tool
# handlers await inside the context, so captures from concurrent requests
# interleave on the event loop. Restoring a per-context "old level" could then
# leave a logger stuck at DEBUG, or cut another active capture short.
# Maps logger -> (level before the first active capture, levels requested by
# the captures that are currently active).
_capture_level_state: dict[logging.Logger, tuple[int, list[int]]] = {}


def _apply_capture_level(logger: logging.Logger, baseline: int, active: list[int]) -> None:
    """Set *logger* to the lowest requested level, never above *baseline*."""
    target = min(baseline, *active)
    if logger.level != target:
        logger.setLevel(target)


def _push_capture_level(logger: logging.Logger, level: int) -> None:
    """Register one active capture on *logger* that needs *level*."""
    baseline, active = _capture_level_state.setdefault(logger, (logger.level, []))
    active.append(level)
    _apply_capture_level(logger, baseline, active)


def _pop_capture_level(logger: logging.Logger, level: int) -> None:
    """Unregister one capture; restore the baseline level after the last one."""
    baseline, active = _capture_level_state[logger]
    active.remove(level)
    if active:
        _apply_capture_level(logger, baseline, active)
        return
    del _capture_level_state[logger]
    if logger.level != baseline:
        logger.setLevel(baseline)


@contextmanager
def capture_logs(logger_name: str = "schwab_scraper", level: int = logging.DEBUG):
    """
    Context manager that captures log output to a string buffer while still
    writing to stderr (docker-visible). Yields the buffer so callers can read
    captured logs after the block.

    Args:
        logger_name: Logger whose records are captured (child loggers are
            included as long as they propagate, which is the default).
        level: Minimum level captured. The logger and the root logger are
            lowered to this level for the duration if they are above it, and
            restored once no capture needs the lower level any more.

    Each record reaches stderr exactly once: through the handlers already
    present up the hierarchy (e.g. the root handler from basicConfig) when
    there are any, otherwise through the capture handler itself.

    NOTE: the capture handler sits on a shared logger, so records logged by
    concurrent tool calls while this context is open also land in the buffer.
    """
    logger = logging.getLogger(logger_name)
    root = logging.getLogger()

    buf = io.StringIO()
    # Decide before adding our own handler: if something up the chain already
    # writes records out, echoing them here would duplicate every line.
    echo = not logger.hasHandlers()
    handler = _TeeHandler(sys.stderr, buf, level=level, echo=echo)
    handler.setLevel(level)
    handler.setFormatter(logging.Formatter(_LOG_FORMAT))

    _push_capture_level(logger, level)
    # Also lower the root logger so loggers that inherit their level from it
    # (those without an explicit level) emit at *level* too. Their records go
    # to the root handlers only; they are not captured into this buffer.
    _push_capture_level(root, level)
    logger.addHandler(handler)
    try:
        yield buf
    finally:
        logger.removeHandler(handler)
        _pop_capture_level(root, level)
        _pop_capture_level(logger, level)


def _enrich_with_logs(result: dict, log_buffer: io.StringIO, debug: bool) -> dict:
    """Attach captured logs to a result dict when debug=True or on error.

    Mutates *result* in place (adds a ``"logs"`` key when there is any captured
    output) and returns it for convenience.
    """
    logs = log_buffer.getvalue()
    if logs and (debug or not result.get("success", False)):
        result["logs"] = logs
    return result


# ---------------------------------------------------------------------------
# Monkey-patch mcp.shared.session.RequestResponder to work around a
# cancellation race in mcp==1.27.0 (github.com/modelcontextprotocol/
# python-sdk/issues/2416). A concurrent notifications/cancelled can set
# _completed=True between handler return and respond(), crashing the session
# with "AssertionError: Request already responded to".
# Remove once upstream ships a fix (likely mcp>=1.28).
# ---------------------------------------------------------------------------
def _patch_request_responder():
    """Make RequestResponder.respond()/cancel() no-ops once a request completed.

    The wrapped methods return early when the responder is already flagged
    ``_completed`` (e.g. by a racing ``notifications/cancelled``). Without the
    early return, mcp raises "AssertionError: Request already responded to"
    and tears down the session.
    """
    from mcp.shared.session import RequestResponder

    _orig_respond = RequestResponder.respond
    _orig_cancel = RequestResponder.cancel

    async def _safe_respond(self, response):
        # getattr: if the private flag is renamed upstream, fall back to the
        # unpatched behaviour instead of raising AttributeError on every reply.
        if getattr(self, "_completed", False):
            logging.debug(
                "respond() skipped for request %s — already completed (race with cancel)",
                self.request_id,
            )
            return
        return await _orig_respond(self, response)

    async def _safe_cancel(self):
        if getattr(self, "_completed", False):
            return
        return await _orig_cancel(self)

    RequestResponder.respond = _safe_respond
    RequestResponder.cancel = _safe_cancel


_patch_request_responder()


# ---------------------------------------------------------------------------
# Login safety manager — lives in the MCP server layer, not the scraper.
# Provides rate-limiting and backoff for automated login attempts.
# ---------------------------------------------------------------------------
class LoginManager:
    """Tracks login attempts and enforces safety limits to avoid account lockouts.

    Limits are read from the environment once, at construction:

    * ``SCHWAB_LOGIN_MAX_ATTEMPTS`` (default 3): failures inside the window
      after which further logins are blocked.
    * ``SCHWAB_LOGIN_WINDOW_MIN`` (default 60): sliding-window length, minutes.
    * ``SCHWAB_LOGIN_BACKOFF_MIN`` (default 30): block duration, measured from
      the most recent failure.
    """

    def __init__(self):
        self.max_attempts = int(os.getenv("SCHWAB_LOGIN_MAX_ATTEMPTS", "3"))
        self.window_minutes = int(os.getenv("SCHWAB_LOGIN_WINDOW_MIN", "60"))
        self.backoff_minutes = int(os.getenv("SCHWAB_LOGIN_BACKOFF_MIN", "30"))
        # (unix timestamp, succeeded) per recorded attempt, in recording order.
        self._attempts: list[tuple[float, bool]] = []

    def _trim_window(self) -> None:
        """Forget attempts older than the sliding window."""
        cutoff = time.time() - (self.window_minutes * 60)
        self._attempts = [(ts, ok) for ts, ok in self._attempts if ts > cutoff]

    def _failure_backoff(self) -> Tuple[int, float]:
        """Trim the window; return (failure_count, remaining_backoff_seconds).

        The remaining backoff is 0 unless the failure count has reached
        ``max_attempts``; it can be negative once the backoff has elapsed.
        """
        self._trim_window()
        failure_times = [ts for ts, ok in self._attempts if not ok]
        # Require at least one failure: with max_attempts <= 0 there may be no
        # failure to measure the backoff from (max() of nothing would raise).
        if failure_times and len(failure_times) >= self.max_attempts:
            elapsed = time.time() - max(failure_times)
            return len(failure_times), (self.backoff_minutes * 60) - elapsed
        return len(failure_times), 0

    def can_login(self) -> Tuple[bool, str]:
        """Return (allowed: bool, reason: str)."""
        failure_count, remaining = self._failure_backoff()
        if remaining > 0:
            return (
                False,
                f"Login blocked: {failure_count} failures in window. "
                f"Wait {int(remaining / 60)}m {int(remaining % 60)}s.",
            )
        return True, f"Allowed ({len(self._attempts)} attempts in last {self.window_minutes}m)"

    def record_attempt(self, success: bool) -> None:
        """Record one login attempt (timestamped now) and its outcome."""
        self._trim_window()
        self._attempts.append((time.time(), success))

    def get_status(self) -> dict:
        """Return a JSON-serializable snapshot of the current limits and usage."""
        failure_count, remaining = self._failure_backoff()
        return {
            "blocked": remaining > 0,
            "remaining_backoff_seconds": max(0, int(remaining)),
            "recent_attempts": len(self._attempts),
            "recent_failures": failure_count,
            "max_attempts_per_window": self.max_attempts,
            "window_minutes": self.window_minutes,
            "backoff_minutes": self.backoff_minutes,
        }


login_manager = LoginManager()

mcp = FastMCP("SchwabScraper")

T = TypeVar("T")

# ---------------------------------------------------------------------------
# Auth gate — serializes auth-sensitive scraper work (_run_auth_serialized).
# ---------------------------------------------------------------------------
_auth_gate_lock = asyncio.Lock()
# Task currently running auth-sensitive work, its operation name and start time.
_auth_active_task: asyncio.Task[Any] | None = None
_auth_active_operation: str | None = None
_auth_started_at: float | None = None
# Callers currently waiting for the active task to finish.
_auth_waiters = 0
# Strong references to pending _clear_auth_task() tasks: the event loop keeps
# only weak references, so an unreferenced task may be garbage collected.
_auth_background_tasks: set[asyncio.Task[Any]] = set()


def _auth_gate_status() -> dict:
    """Return a JSON-serializable snapshot of the auth gate."""
    return {
        "in_progress": _auth_active_task is not None and not _auth_active_task.done(),
        "operation": _auth_active_operation,
        "started_at": _auth_started_at,
        "elapsed_seconds": int(time.time() - _auth_started_at) if _auth_started_at else None,
        "waiters": _auth_waiters,
    }


async def _decrement_auth_waiters() -> None:
    """Decrement the waiter count, never going below zero."""
    global _auth_waiters
    async with _auth_gate_lock:
        _auth_waiters = max(0, _auth_waiters - 1)


async def _clear_auth_task(task: asyncio.Task[Any]) -> None:
    """Reset the gate state if *task* is still the recorded active task."""
    global _auth_active_task, _auth_active_operation, _auth_started_at
    async with _auth_gate_lock:
        if _auth_active_task is task:
            _auth_active_task = None
            _auth_active_operation = None
            _auth_started_at = None


def _schedule_auth_task_clear(task: asyncio.Task[Any]) -> None:
    """Done-callback: schedule _clear_auth_task(task) on the running loop.

    Silently does nothing when no event loop is running (create_task raises
    RuntimeError), e.g. during interpreter shutdown.
    """
    try:
        clear_task = asyncio.create_task(_clear_auth_task(task))
    except RuntimeError:
        return
    _auth_background_tasks.add(clear_task)
    clear_task.add_done_callback(_auth_background_tasks.discard)


async def _run_auth_serialized(
    operation: str,
    coro_factory: Callable[[], Awaitable[T]],
    *,
    share_same_operation: bool = False,
) -> T:
    """Run auth-sensitive work without letting request cancellation cancel it.

    At most one auth-sensitive task runs at a time. A caller that finds
    another task active waits for it (ignoring its failure) and retries. If
    ``share_same_operation`` is set and the active task has the same
    ``operation`` name, the caller gets that task's outcome instead. Work is
    started as its own task and awaited via ``asyncio.shield``, so a cancelled
    caller stops waiting while the task itself keeps running.

    Args:
        operation: Name used for status reporting and for result sharing.
        coro_factory: Zero-argument callable returning the coroutine to run.
        share_same_operation: Reuse an in-flight task with the same name.

    Returns:
        The result of the coroutine (or of the shared in-flight task).
    """
    global _auth_active_task, _auth_active_operation, _auth_started_at, _auth_waiters

    while True:
        async with _auth_gate_lock:
            active_task = _auth_active_task
            active_operation = _auth_active_operation
            if active_task is None or active_task.done():
                task = asyncio.create_task(coro_factory())
                _auth_active_task = task
                _auth_active_operation = operation
                _auth_started_at = time.time()
                task.add_done_callback(_schedule_auth_task_clear)
                break
            _auth_waiters += 1

        try:
            if share_same_operation and active_operation == operation:
                return await asyncio.shield(active_task)
            try:
                await asyncio.shield(active_task)
            except Exception:
                # The current operation should still get a chance to run after a
                # prior auth-sensitive task fails.
                pass
        finally:
            await _decrement_auth_waiters()

    return await asyncio.shield(task)


def _json_default(obj: Any) -> Any:
    """JSON fallback: dump pydantic models and dataclasses to dicts, else str().

    Nested pydantic models (e.g. inside a result dict) would otherwise be
    serialized as their repr string.
    """
    if isinstance(obj, type):
        return str(obj)
    if hasattr(obj, "model_dump"):
        return obj.model_dump()
    if dataclasses.is_dataclass(obj):
        return dataclasses.asdict(obj)
    return str(obj)


def serialize(obj: Any) -> str:
    """Safely serialize Pydantic models or dataclasses to JSON string."""
    if hasattr(obj, "model_dump_json"):
        return obj.model_dump_json()
    elif hasattr(obj, "model_dump"):
        return json.dumps(obj.model_dump(), default=_json_default)
    elif isinstance(obj, list):
        return json.dumps([
            o.model_dump() if hasattr(o, "model_dump") else o
            for o in obj
        ], default=_json_default)
    return json.dumps(obj, default=_json_default)


def _login_blocked_result(reason: str) -> dict:
    """Build the standard result dict for a login refused by LoginManager."""
    return {
        "success": False,
        "error": f"Login blocked by safety safeguards: {reason}",
        "error_type": "AUTHENTICATION",
        "retryable": False,
        "data": None,
    }


# ---------------------------------------------------------------------------
# MCP tools
# (Tool docstrings double as the tool descriptions shown to MCP clients.)
# ---------------------------------------------------------------------------
@mcp.tool()
async def get_session_status(debug: bool = False) -> str:
    """Get the current session status of the Schwab scraper.

    Args:
        debug: Enable debug logging
    """
    result = await api.get_session_status(debug=debug)

    # Enrich with login safety status. A missing/None "data" is replaced by a
    # dict stored back on the result, so the enrichment is actually returned.
    if isinstance(result, dict) and result.get("success"):
        data = result.get("data")
        if data is None:
            data = result["data"] = {}
        if isinstance(data, dict):
            data["login_safety"] = login_manager.get_status()
            data["auth_gate"] = _auth_gate_status()

    return serialize(result)


@mcp.tool()
async def get_login_safety_status() -> str:
    """Get the current login safety status, including any active backoffs or limits.

    Useful to check if a login attempt is likely to be blocked.
    """
    status = login_manager.get_status()
    status["auth_gate"] = _auth_gate_status()
    return json.dumps(status)


@mcp.tool()
async def login(
    username: Optional[str] = None,
    password: Optional[str] = None,
    debug: bool = False,
) -> str:
    """Perform an automated login to Schwab to establish a new session.

    Args:
        username: Schwab username (optional, will use env/config if omitted)
        password: Schwab password (optional, will use env/config if omitted)
        debug: Enable debug logging
    """
    # Fast path: refuse immediately while the backoff is active.
    allowed, reason = login_manager.can_login()
    if not allowed:
        return json.dumps(_login_blocked_result(reason))

    mcp_logger = logging.getLogger("schwab_mcp_custom")
    mcp_logger.info("=== LOGIN TOOL CALLED ===")
    mcp_logger.info(
        "debug=%s, username_provided=%s, password_provided=%s",
        debug, bool(username), bool(password),
    )

    # Diagnostic: if credentials not provided, show what config path would be used
    if not username or not password:
        from schwab_scraper.core.config import get_config_path

        config_path = get_config_path()
        config_exists = os.path.exists(config_path)
        mcp_logger.info("Config fallback: path=%s, exists=%s", config_path, config_exists)

    async def _login_impl() -> dict:
        mcp_logger.info("capture_logs context entered")

        # Re-check inside the auth gate: this call may have queued behind other
        # logins whose failures used up the attempt budget after the fast-path
        # check above ran. Without this, queued logins bypass the limit.
        allowed_now, reason_now = login_manager.can_login()
        if not allowed_now:
            return _login_blocked_result(reason_now)

        from schwab_scraper.browser.auth import login_to_schwab
        from schwab_scraper.core.config import get_schwab_credentials, load_config

        resolved_username = username
        resolved_password = password
        if not resolved_username or not resolved_password:
            config = load_config()
            resolved_username, resolved_password = get_schwab_credentials(config)

        if not resolved_username or not resolved_password:
            result = {
                "success": False,
                "error": "Username and password are required (or set in config.json)",
                "error_type": "AUTHENTICATION",
                "retryable": False,
                "data": None,
            }
            login_manager.record_attempt(False)
            return result

        try:
            cookies = await login_to_schwab(resolved_username, resolved_password)
            if cookies:
                result = {
                    "success": True,
                    "data": {"cookies_count": len(cookies)},
                    "error": None,
                    "error_type": None,
                    "retryable": False,
                }
            else:
                result = {
                    "success": False,
                    "error": "Login failed — no cookies returned. Check credentials or 2FA status.",
                    "error_type": "AUTHENTICATION",
                    "retryable": True,
                    "data": None,
                }
        except Exception as exc:
            result = {
                "success": False,
                "error": str(exc),
                "error_type": "UNKNOWN",
                "retryable": True,
                "data": None,
            }

        login_manager.record_attempt(result.get("success", False))
        return result

    with capture_logs(level=logging.DEBUG if debug else logging.INFO) as log_buf:
        if debug:
            mcp_logger.info("DEBUG MODE ENABLED — verbose logging active")

        # Credential-less logins are interchangeable, so concurrent ones share
        # a single in-flight attempt instead of each hitting Schwab.
        result = await _run_auth_serialized(
            "login",
            _login_impl,
            share_same_operation=not username and not password,
        )

        success = result.get("success", False)
        mcp_logger.info("login completed — success=%s", success)

        result = _enrich_with_logs(result, log_buf, debug)

    mcp_logger.info("capture_logs context exited, returning result")
    return serialize(result)


@mcp.tool()
async def refresh_session(debug: bool = False) -> str:
    """Refresh the current Schwab session to prevent expiration.

    Args:
        debug: Enable debug logging
    """
    with capture_logs(level=logging.DEBUG if debug else logging.INFO) as log_buf:
        result = await _run_auth_serialized(
            "refresh_session",
            lambda: api.refresh_session(debug=debug),
        )
        result = _enrich_with_logs(result, log_buf, debug)
    return serialize(result)


@mcp.tool()
async def list_accounts(debug: bool = False) -> str:
    """List all Schwab accounts.

    Args:
        debug: Enable debug logging
    """
    result = await _run_auth_serialized(
        "list_accounts",
        lambda: api.list_accounts(debug=debug),
    )
    return serialize(result)


@mcp.tool()
async def get_account_overview(account: Optional[str] = None, debug: bool = False) -> str:
    """Get the overview for a specific account.

    Args:
        account: Account summary or ID (optional)
        debug: Enable debug logging
    """
    result = await _run_auth_serialized(
        "get_account_overview",
        lambda: api.get_account_overview(account=account, debug=debug),
    )
    return serialize(result)


@mcp.tool()
async def get_positions(
    account: Optional[str] = None,
    include_non_equity: bool = False,
    debug: bool = False,
) -> str:
    """Get positions for a specific account.

    Args:
        account: Account summary or ID (optional)
        include_non_equity: Whether to include non-equity positions
        debug: Enable debug logging
    """
    result = await _run_auth_serialized(
        "get_positions",
        lambda: api.get_positions(
            account=account, include_non_equity=include_non_equity, debug=debug
        ),
    )
    return serialize(result)


@mcp.tool()
async def get_transactions(
    account: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    time_period: Optional[str] = None,
    debug: bool = False,
) -> str:
    """Get transaction history.

    Args:
        account: Account ID (optional)
        start_date: Start date for transactions (optional)
        end_date: End date for transactions (optional)
        time_period: Time period (e.g., '1D', '1M') (optional)
        debug: Enable debug logging
    """
    result = await _run_auth_serialized(
        "get_transactions",
        lambda: api.get_transaction_history(
            account=account,
            start_date=start_date,
            end_date=end_date,
            time_period=time_period,
            debug=debug,
        ),
    )
    return serialize(result)


@mcp.tool()
async def get_morningstar_data(ticker: str, debug: bool = False) -> str:
    """Get Morningstar data for a ticker.

    Args:
        ticker: Stock ticker symbol
        debug: Enable debug logging
    """
    result = await _run_auth_serialized(
        "get_morningstar_data",
        lambda: api.get_morningstar_data(ticker, debug=debug),
    )

    # When the scraper used blob URLs (modern Schwab web components), report_url
    # is None even though the PDF was downloaded and parsed successfully. Point
    # callers at the MCP server's cached-PDF endpoint instead.
    if (
        isinstance(result, dict)
        and result.get("success")
        and result.get("data") is not None
    ):
        data = result["data"]
        if (
            hasattr(data, "report_url")
            and data.report_url is None
            and getattr(data, "source", None) is not None
        ):
            from urllib.parse import quote

            base = os.getenv("SCHWAB_MCP_BASE_URL", _DEFAULT_BASE_URL).rstrip("/")
            # Escape the ticker as a single path segment; Starlette decodes it
            # again before serve_report_pdf sees it.
            data.report_url = f"{base}/reports/{quote(ticker.upper(), safe='')}/pdf"

    return serialize(result)


@mcp.tool()
async def upload_cookies(cookies_json: str) -> str:
    """Upload session cookies to the server to assist with authentication.

    Args:
        cookies_json: JSON string of cookies exported from a browser (Playwright format)
    """
    try:
        cookies = json.loads(cookies_json)

        # Some browser extensions wrap cookies in an object (e.g. {"cookies": [...]})
        if isinstance(cookies, dict):
            if "cookies" in cookies:
                cookies = cookies["cookies"]
            else:
                return json.dumps({
                    "status": "error",
                    "message": "Expected a list of cookies or an object with a 'cookies' key",
                })

        if not isinstance(cookies, list):
            return json.dumps({
                "status": "error",
                "message": f"Expected a list of cookies, got {type(cookies).__name__}",
            })

        from schwab_scraper.core.config import get_cookies_path

        cookies_path = get_cookies_path()
        with open(cookies_path, "w") as f:
            json.dump(cookies, f, indent=2)
        return json.dumps({
            "status": "success",
            "message": f"{cookies_path} updated with {len(cookies)} cookies",
        })
    except Exception as e:
        # Tool boundary: report any failure (bad JSON, unwritable path, ...)
        # to the client instead of raising.
        return json.dumps({"status": "error", "message": str(e)})


@mcp.tool()
async def api_call(endpoint: str, method: str = "GET", params: str = "{}") -> str:
    """Executes a raw API call to the Schwab service (placeholder).

    Refer to the 'api-reference' resource for available endpoints and parameters.

    Args:
        endpoint: The API path
        method: HTTP method (GET, POST, etc.)
        params: JSON string of parameters/body
    """
    # Placeholder: all arguments are accepted but ignored.
    return json.dumps({"status": "not_implemented", "message": "API pass-through not supported for scraper"})


@mcp.resource("service://api-reference")
def get_api_docs() -> str:
    """Returns the API documentation for using the 'api_call' tool."""
    return (
        "Schwab Scraper MCP Server — Unified API Documentation\n\n"
        "This server provides tools to interact with Schwab accounts via scraping.\n"
        "The 'api_call' tool is a placeholder."
    )


async def health(request):
    """Health check endpoint."""
    return JSONResponse({"status": "ok"})


async def serve_report_pdf(request):
    """Serve a cached Morningstar report PDF by ticker."""
    ticker = request.path_params["ticker"].upper()
    # The ticker comes straight from the URL. read_cached_pdf() is not visible
    # here (presumably it maps the ticker to a cache file), so reject values
    # that could act as path components before handing it over.
    if ticker in (".", "..") or any(ch in ticker for ch in "/\\\x00"):
        return JSONResponse({"error": "Invalid ticker."}, status_code=400)

    pdf_bytes = read_cached_pdf(ticker)
    if not pdf_bytes:
        return JSONResponse(
            {"error": f"No cached report for {ticker}. Call get_morningstar_data first."},
            status_code=404,
        )
    return Response(
        pdf_bytes,
        media_type="application/pdf",
        headers={"Content-Disposition": f'inline; filename="{ticker}_morningstar.pdf"'},
    )


mcp_app = mcp.http_app()

# Plain HTTP routes are matched first; everything else falls through to MCP.
app = Starlette(
    routes=[
        Route("/health", health),
        Route("/reports/{ticker}/pdf", serve_report_pdf),
        Mount("/", app=mcp_app),
    ],
    lifespan=mcp_app.lifespan,
)

if __name__ == "__main__":
    port = int(os.getenv("PORT", 8160))
    uvicorn.run(app, host="0.0.0.0", port=port)