Files
schwab-mcp-custom/server.py

246 lines
8.1 KiB
Python

import json
import logging
import os
import sys
from typing import Optional, Any
# Make the project directory and the vendored schwab-scraper checkout
# importable. Entries are prepended in this order, so the vendor directory
# ends up ahead of the project root on sys.path.
project_root = os.path.dirname(os.path.abspath(__file__))
vendor_path = os.path.join(project_root, "vendor", "schwab-scraper")
for _extra_path in (project_root, vendor_path):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
from fastmcp import FastMCP
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route, Mount
import uvicorn
# Import the unified API from the schwab_scraper dependency
import schwab_scraper.unified_api as api
# ---------------------------------------------------------------------------
# Monkey-patch mcp.shared.session.RequestResponder to work around a
# cancellation race in mcp==1.27.0 (github.com/modelcontextprotocol/
# python-sdk/issues/2416). A concurrent notifications/cancelled can set
# _completed=True between handler return and respond(), crashing the session
# with "AssertionError: Request already responded to".
# Remove once upstream ships a fix (likely mcp>=1.28).
# ---------------------------------------------------------------------------
def _patch_request_responder():
    """Replace RequestResponder.respond/cancel with completion-aware wrappers.

    Each wrapper checks ``self._completed`` first: when it is already set the
    call becomes a no-op (``respond`` additionally logs at debug level),
    otherwise the call is delegated to the original method.
    """
    from mcp.shared.session import RequestResponder

    # Capture the unpatched methods before overwriting them on the class.
    original_respond = RequestResponder.respond
    original_cancel = RequestResponder.cancel

    async def _safe_respond(self, response):
        if not self._completed:
            return await original_respond(self, response)
        logging.debug(
            "respond() skipped for request %s — already completed (race with cancel)",
            self.request_id,
        )
        return None

    async def _safe_cancel(self):
        if not self._completed:
            return await original_cancel(self)
        return None

    RequestResponder.respond = _safe_respond
    RequestResponder.cancel = _safe_cancel


# Apply the patch at import time, before the FastMCP server is created.
_patch_request_responder()
# Create the FastMCP server instance (named "SchwabScraper"). The
# @mcp.tool() / @mcp.resource() decorators in this module register on it,
# and mcp.http_app() later turns it into the ASGI app that is served.
mcp = FastMCP("SchwabScraper")
def _json_default(value: Any) -> Any:
    """Convert an object ``json`` cannot encode natively into something it can.

    Used as the ``default=`` hook of ``json.dumps``, so it is applied
    recursively to every non-serializable value, however deeply nested.

    Args:
        value: The object the JSON encoder could not handle.

    Returns:
        ``value.model_dump()`` for instances exposing a callable
        ``model_dump``, ``dataclasses.asdict(value)`` for dataclass
        instances, and ``str(value)`` for anything else.
    """
    import dataclasses  # function-scope import keeps this block self-contained

    # Classes themselves (as opposed to instances) also expose ``model_dump``
    # or satisfy ``is_dataclass``; those must fall through to ``str``.
    if not isinstance(value, type):
        dump = getattr(value, "model_dump", None)
        if callable(dump):
            return dump()
        if dataclasses.is_dataclass(value):
            return dataclasses.asdict(value)
    return str(value)


def serialize(obj: Any) -> str:
    """Safely serialize Pydantic-style models or dataclasses to a JSON string.

    Args:
        obj: A model exposing ``model_dump_json``/``model_dump``, a dataclass
            instance, or any structure of dicts/lists/scalars that may
            contain such objects.

    Returns:
        The JSON text. Objects with ``model_dump_json`` are encoded by that
        method directly; everything else goes through ``json.dumps`` with a
        fallback hook that expands nested models and dataclasses and
        stringifies any remaining unknown type.
    """
    if hasattr(obj, "model_dump_json"):
        return obj.model_dump_json()
    if hasattr(obj, "model_dump"):
        return json.dumps(obj.model_dump(), default=_json_default)
    # Lists of models, dataclasses and nested containers are all handled by
    # the recursive default hook.
    return json.dumps(obj, default=_json_default)
@mcp.tool()
async def get_session_status(debug: bool = False) -> str:
    """Get the current session status of the Schwab scraper.

    Args:
        debug: Enable debug logging

    Returns:
        The scraper API's session-status result, serialized to a JSON string.
    """
    return serialize(await api.get_session_status(debug=debug))
@mcp.tool()
async def get_login_safety_status() -> str:
    """Get the current login safety status, including any active backoffs or limits.

    Useful to check if a login attempt is likely to be blocked.

    Returns:
        A JSON string: the ``login_safety`` section of the session status
        when the status reports success and contains one; otherwise
        ``{"status": "unknown", ...}``. Any exception is reported as
        ``{"error": "<message>"}`` instead of being raised.
    """
    try:
        raw_status = await api.get_session_status()
        # get_session_status() elsewhere in this module passes the same API
        # result through serialize(), i.e. it is treated as an object rather
        # than JSON text. Accept both shapes: parse text directly, and
        # normalise anything else through serialize() first.
        if isinstance(raw_status, (str, bytes, bytearray)):
            status = json.loads(raw_status)
        else:
            status = json.loads(serialize(raw_status))

        if isinstance(status, dict) and status.get("success"):
            # "data" may be missing or null; only a mapping can carry the section.
            data = status.get("data")
            if isinstance(data, dict) and "login_safety" in data:
                return json.dumps(data["login_safety"])
        return json.dumps({"status": "unknown", "message": "Login safety info not available"})
    except Exception as e:  # tool boundary: report the failure as JSON
        return json.dumps({"error": str(e)})
@mcp.tool()
async def login(username: Optional[str] = None, password: Optional[str] = None, debug: bool = False) -> str:
    """Perform an automated login to Schwab to establish a new session.

    Args:
        username: Schwab username (optional, will use env if omitted)
        password: Schwab password (optional, will use env if omitted)
        debug: Enable debug logging

    Returns:
        The scraper API's login result, serialized to a JSON string.
    """
    credentials = {"username": username, "password": password}
    outcome = await api.login(**credentials, debug=debug)
    return serialize(outcome)
@mcp.tool()
async def refresh_session(debug: bool = False) -> str:
    """Refresh the current Schwab session to prevent expiration.

    Args:
        debug: Enable debug logging

    Returns:
        The scraper API's refresh result, serialized to a JSON string.
    """
    return serialize(await api.refresh_session(debug=debug))
@mcp.tool()
async def list_accounts(debug: bool = False) -> str:
    """List all Schwab accounts.

    Args:
        debug: Enable debug logging

    Returns:
        The scraper API's account listing, serialized to a JSON string.
    """
    return serialize(await api.list_accounts(debug=debug))
@mcp.tool()
async def get_account_overview(account: Optional[str] = None, debug: bool = False) -> str:
    """Get the overview for a specific account.

    Args:
        account: Account summary or ID (optional)
        debug: Enable debug logging

    Returns:
        The scraper API's account overview, serialized to a JSON string.
    """
    overview = await api.get_account_overview(account=account, debug=debug)
    return serialize(overview)
@mcp.tool()
async def get_positions(account: Optional[str] = None, include_non_equity: bool = False, debug: bool = False) -> str:
    """Get positions for a specific account.

    Args:
        account: Account summary or ID (optional)
        include_non_equity: Whether to include non-equity positions
        debug: Enable debug logging

    Returns:
        The scraper API's positions result, serialized to a JSON string.
    """
    positions = await api.get_positions(
        account=account,
        include_non_equity=include_non_equity,
        debug=debug,
    )
    return serialize(positions)
@mcp.tool()
async def get_transactions(
    account: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    time_period: Optional[str] = None,
    debug: bool = False
) -> str:
    """Get transaction history.

    Args:
        account: Account ID (optional)
        start_date: Start date for transactions (optional)
        end_date: End date for transactions (optional)
        time_period: Time period (e.g., '1D', '1M') (optional)
        debug: Enable debug logging

    Returns:
        The scraper API's transaction history, serialized to a JSON string.
    """
    # All filters are forwarded unchanged, including those left as None.
    query = {
        "account": account,
        "start_date": start_date,
        "end_date": end_date,
        "time_period": time_period,
        "debug": debug,
    }
    return serialize(await api.get_transaction_history(**query))
@mcp.tool()
async def get_morningstar_data(ticker: str, debug: bool = False) -> str:
    """Get Morningstar data for a ticker.

    Args:
        ticker: Stock ticker symbol
        debug: Enable debug logging

    Returns:
        The scraper API's Morningstar result, serialized to a JSON string.
    """
    # ticker is passed positionally, debug by keyword.
    return serialize(await api.get_morningstar_data(ticker, debug=debug))
@mcp.tool()
async def upload_cookies(cookies_json: str) -> str:
    """Upload session cookies to the server to assist with authentication.

    The payload is parsed first and must decode to a JSON array or object;
    only then is ``cookies.json`` (relative to the current working directory)
    overwritten with it.

    Args:
        cookies_json: JSON string of cookies exported from a browser (Playwright format)

    Returns:
        A JSON string with ``status`` ("success" or "error") and ``message``.
        Failures are reported in the result rather than raised.
    """
    # Validate before touching the file so a bad payload never clobbers
    # previously stored cookies.
    try:
        cookies = json.loads(cookies_json)
    except (TypeError, ValueError, RecursionError) as e:
        return json.dumps({"status": "error", "message": str(e)})

    # A bare scalar or null is valid JSON but cannot be a cookie export.
    if not isinstance(cookies, (list, dict)):
        return json.dumps({
            "status": "error",
            "message": "cookies_json must decode to a JSON array or object of cookies",
        })

    try:
        with open("cookies.json", "w", encoding="utf-8") as f:
            json.dump(cookies, f)
    except OSError as e:
        return json.dumps({"status": "error", "message": str(e)})
    return json.dumps({"status": "success", "message": "cookies.json updated successfully"})
@mcp.tool()
async def api_call(endpoint: str, method: str = "GET", params: str = "{}") -> str:
    """Executes a raw API call to the Schwab service (Dummy implementation).

    Refer to the 'api-reference' resource for available endpoints and parameters.
    The arguments are accepted but ignored: the tool always returns the same
    "not_implemented" JSON payload.

    Args:
        endpoint: The API path
        method: HTTP method (GET, POST, etc.)
        params: JSON string of parameters/body
    """
    placeholder = {
        "status": "not_implemented",
        "message": "API pass-through not supported for scraper",
    }
    return json.dumps(placeholder)
@mcp.resource("service://api-reference")
def get_api_docs() -> str:
    """Returns the API documentation for using the 'api_call' tool."""
    # Adjacent literals are concatenated into one fixed documentation string.
    return (
        "Schwab Scraper MCP Server - Unified API Documentation\n\n"
        "This server provides tools to interact with Schwab accounts via scraping. "
        "The 'api_call' tool is a placeholder."
    )
async def health(request):
    """Health check endpoint: always answers with ``{"status": "ok"}``.

    Args:
        request: The incoming request (unused).
    """
    body = {"status": "ok"}
    return JSONResponse(body)
# Build the Starlette application: the plain /health route is listed ahead of
# the root mount that hands everything else to the FastMCP HTTP app, whose
# lifespan is reused as the application lifespan.
mcp_app = mcp.http_app()
_routes = [
    Route("/health", health),
    Mount("/", app=mcp_app),
]
app = Starlette(routes=_routes, lifespan=mcp_app.lifespan)
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, falling back to 8160.
    listen_port = int(os.environ.get("PORT", 8160))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)