Files
schwab-mcp-custom/schwab_scraper/features/accounts_positions/positions_scraper.py
b3nw 650ea2d087
All checks were successful
Build and Push Docker Image / build (push) Successful in 34s
Fix build: Bundle schwab_scraper source and use local dependencies
2026-04-24 01:50:20 +00:00

433 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import re
from decimal import Decimal, InvalidOperation
from typing import Any, Optional, Sequence
from ...browser.auth import ensure_cookies
from ...browser.client import connect, new_context, new_page
from ...browser.navigation import goto_with_auth_check
from ...core import AccountSummary, Envelope, ErrorType, Lot, Position, fail, ok
from ...core.config import get_playwright_url, load_config
# Schwab web-client positions page that get_positions() navigates to and scrapes.
POSITIONS_URL = "https://client.schwab.com/app/accounts/positions/#/"
def _parse_decimal(value: str | None) -> Optional[Decimal]:
    """Parse a Schwab-formatted money/number/percent string into a Decimal.

    Strips ``$`` signs, thousands separators and ``%`` suffixes. Accepts
    accounting-style negatives such as ``(12.34)``. Also accepts the Unicode
    minus sign (U+2212) and en dash (U+2013) in place of an ASCII hyphen.

    Args:
        value: Raw cell text. ``None``, blank text, ``"-"`` and ``"--"`` are
            treated as "no value".

    Returns:
        The parsed amount, or ``None`` when the text is empty or is not a
        finite number. The result is negative when the text is wrapped in
        parentheses or carries a leading minus.
    """
    if not value:
        return None
    cleaned = value.strip()
    if not cleaned or cleaned in {"-", "--"}:
        return None
    # Accounting notation: "(12.34)" means -12.34.
    negative = cleaned.startswith("(") and cleaned.endswith(")")
    cleaned = (
        cleaned.replace("$", "")
        .replace(",", "")
        .replace("(", "")
        .replace(")", "")
        # Write the non-ASCII dashes as escapes. A literal U+2212 is invisible
        # in many editors and is easily lost. Losing it leaves
        # .replace("", "-"), which inserts "-" between every character and
        # breaks every value.
        .replace("\u2212", "-")  # MINUS SIGN
        .replace("\u2013", "-")  # EN DASH
        .replace("%", "")
        .strip()
    )
    if not cleaned:
        return None
    try:
        parsed = Decimal(cleaned)
    except InvalidOperation:
        return None
    # Decimal() also accepts "NaN"/"Infinity"; callers only want real amounts.
    if not parsed.is_finite():
        return None
    return -abs(parsed) if negative else parsed
def _parse_float(value: str | None) -> Optional[float]:
    """Parse a Schwab-formatted numeric string into a float.

    Delegates the text parsing to ``_parse_decimal`` and converts the result.
    Returns ``None`` when the text is not numeric or the conversion fails.
    """
    parsed = _parse_decimal(value)
    if parsed is not None:
        try:
            return float(parsed)
        except (ValueError, InvalidOperation):
            pass
    return None
def _normalize_account_label(label: str) -> AccountSummary:
    """Build an AccountSummary from a raw account label scraped from the page.

    - The label has its whitespace runs collapsed to single spaces.
    - ``last4`` is the leftmost 3-4 digit sequence followed by a word
      boundary, searched after all spaces are removed (``None`` if absent).
    - ``type`` is the leading run of letters, ``&``, ``'``, ``-`` and spaces,
      stripped and with inner whitespace turned into underscores. It is
      ``"Account"`` when the label does not start with such characters.
    - ``id`` is ``"<type>-<last4>"``, or just ``type`` when there is no
      ``last4``.
    - ``is_margin`` is set when the label mentions "margin"
      (case-insensitive).
    """
    collapsed = re.sub(r"\s+", " ", label).strip()

    digits = re.search(r"(\d{3,4})\b", collapsed.replace(" ", ""))
    last4 = None if digits is None else digits.group(1)[-4:]

    prefix = re.search(r"^[A-Za-z&'\- ]+", collapsed)
    if prefix is None:
        account_type = "Account"
    else:
        account_type = re.sub(r"\s+", "_", prefix.group(0).strip())

    return AccountSummary(
        id=f"{account_type}-{last4}" if last4 else account_type,
        label=collapsed,
        type=account_type,
        last4=last4,
        is_margin="margin" in collapsed.lower(),
    )
def _match_account(candidate: AccountSummary, requested: AccountSummary | str | None) -> bool:
    """Return True when ``candidate`` is the account the caller asked for.

    ``None`` matches every account. An AccountSummary request matches on its
    id, label or last4. A string request is stripped and compared against the
    candidate's id, label or last4. All comparisons are case-insensitive
    exact matches.
    """
    if requested is None:
        return True

    def _keys(summary: AccountSummary) -> set[str]:
        # Lower-cased identifiers that can be used to refer to an account.
        keys = {summary.id.lower(), summary.label.lower()}
        if summary.last4:
            keys.add(summary.last4.lower())
        return keys

    if isinstance(requested, AccountSummary):
        wanted = _keys(requested)
    else:
        wanted = {requested.strip().lower()}
    return not wanted.isdisjoint(_keys(candidate))
def classify_asset(symbol: str | None, description: str | None) -> str:
    """Heuristically classify a holding from its symbol and description.

    Both inputs are stripped and upper-cased before matching.

    Args:
        symbol: Identifier from the row's ``data-symbol`` (may be ``None``/empty).
        description: Free-text security description (may be ``None``).

    Returns:
        One of ``"CASH"``, ``"ETF"``, ``"MUTUAL_FUND"``, ``"EQUITY"``,
        ``"OPTION"``, ``"BOND"`` or ``"OTHER"``.
    """
    sym = symbol.strip().upper() if symbol else ""
    desc = (description or "").strip().upper()

    # Cash placeholders are themselves 1-5 letter strings. They must be
    # checked before the ticker rule below, which would otherwise label them
    # EQUITY. "CASH" can also be a genuine exchange ticker, so it only counts
    # as cash when the description is blank or itself mentions cash.
    if sym in {"MMDA", "SWEEP"} or (sym == "CASH" and (not desc or "CASH" in desc)):
        return "CASH"

    # Plain 1-5 letter tickers.
    if sym and re.fullmatch(r"[A-Z]{1,5}", sym):
        if "ETF" in desc:
            return "ETF"
        if any(kw in desc for kw in ("FUND", "MUTUAL")):
            return "MUTUAL_FUND"
        return "EQUITY"

    # Symbols longer than five characters that contain a digit are treated
    # as options.
    if sym and re.search(r"\d", sym) and len(sym) > 5:
        return "OPTION"

    # "CD" must be a whole word; a substring test also matches names such as
    # "MCDONALDS".
    if "BOND" in desc or "TREASURY" in desc or re.search(r"\bCD\b", desc):
        return "BOND"
    if "CASH" in desc:
        return "CASH"
    if "ETF" in desc:
        return "ETF"
    if "FUND" in desc:
        return "MUTUAL_FUND"
    return "OTHER"
async def _evaluate_table(page: Any) -> dict[str, Any] | None:
    """Scrape the ``#positionsDetails`` table inside the browser.

    Runs a JavaScript snippet through ``page.evaluate`` and returns its
    result. The result is ``None`` when the table is not in the DOM.
    Otherwise it is ``{"headers": [...], "rows": [...]}``, where ``headers``
    is the trimmed text of every ``thead tr th`` and each row looks like::

        {"type": "position", "cells": [<td texts>], "lots": [[<td texts>], ...],
         "symbol": <data-symbol attribute or "">, "account": <label or null>}

    Each ``tbody tr`` is classified in this order:

    - **Account header:** no ``position-row`` class, a ``highlight-row`` or
      ``border-top-dark`` class, and text containing "account panel". The
      text that follows "account panel" becomes the account label for later
      positions.
    - **No cells:** rows without any ``td`` are skipped.
    - **Lot row:** the class contains ``lot``/``sub``/``child``, or
      ``data-row-type`` contains ``lot``. Its cells are appended to the most
      recent position's ``lots``; they are dropped if no position has been
      seen yet.
    - **Position row:** the class contains ``position-row``. It starts a new
      position.
    - **Anything else:** ignored.

    Args:
        page: Page object created by ``new_page`` — presumably a Playwright
            async ``Page``; TODO confirm against ``browser.client``.
    """
    # The script below is runtime code sent to the browser; keep it unchanged.
    return await page.evaluate(
        """
() => {
const table = document.querySelector('#positionsDetails');
if (!table) {
return null;
}
const headers = Array.from(table.querySelectorAll('thead tr th')).map((th) =>
(th.innerText || th.textContent || '').trim()
);
const rowElements = Array.from(table.querySelectorAll('tbody tr'));
const rows = [];
let current = null;
let currentAccount = null;
const isLotRow = (row) => {
const klass = (row.className || '').toLowerCase();
if (klass.includes('lot') || klass.includes('sub') || klass.includes('child')) {
return true;
}
const dataRole = (row.getAttribute('data-row-type') || '').toLowerCase();
return dataRole.includes('lot');
};
const isPositionRow = (row) => {
const klass = (row.className || '').toLowerCase();
return klass.includes('position-row');
};
const isAccountHeader = (row) => {
const klass = (row.className || '').toLowerCase();
const text = (row.textContent || '').trim();
return !klass.includes('position-row') &&
(klass.includes('highlight-row') || klass.includes('border-top-dark')) &&
text.includes('account panel');
};
for (const row of rowElements) {
// Check if this is an account header row
if (isAccountHeader(row)) {
const text = row.textContent.trim();
// Extract account name from account panel text
const match = text.match(/account panel[\\s\\n]+([^\\n]+)/);
if (match) {
currentAccount = match[1].trim();
}
continue;
}
const cells = Array.from(row.querySelectorAll('td')).map((cell) =>
(cell.innerText || cell.textContent || '').trim()
);
if (!cells.length) {
continue;
}
if (isLotRow(row)) {
if (current) {
current.lots.push(cells);
}
} else if (isPositionRow(row)) {
// Extract symbol from data-symbol attribute
const symbol = row.getAttribute('data-symbol') || '';
current = {
type: 'position',
cells: cells,
lots: [],
symbol: symbol,
account: currentAccount
};
rows.push(current);
}
}
return { headers, rows };
}
"""
    )
def _map_row(headers: Sequence[str], cells: Sequence[str]) -> dict[str, str]:
    """Pair table header labels with a position row's cell texts.

    The leading header columns (a blank selector column, then ``Symbol``)
    have no matching ``td`` in the scraped row, e.g.::

        headers: ['', 'Symbol', 'Description', 'Qty', 'Price', ...]
        cells:   ['VANGUARD...', '192.5', '$328.17', ...]

    How headers are mapped:

    - Every header up to and including the first one whose text contains
      "symbol" (but not "description") maps to ``""``.
    - The remaining headers take ``cells`` in order, stripped.
    - Headers beyond the available cells map to ``""``.

    How keys are formed:

    - A key is the first line of the header text, stripped and lower-cased
      (headers look like ``"Label\\nsort\\nfield"``).
    - A blank key becomes ``"column_<index>"``.
    - Later duplicate keys overwrite earlier ones.
    """
    lowered = [text.strip().lower() for text in headers]
    symbol_pos = next(
        (i for i, text in enumerate(lowered) if "symbol" in text and "description" not in text),
        None,
    )
    # Number of leading headers with no corresponding cell.
    skip = 0 if symbol_pos is None else symbol_pos + 1

    mapped: dict[str, str] = {}
    for position, raw in enumerate(headers):
        label = raw.strip().partition("\n")[0].strip().lower() or f"column_{position}"
        cell_pos = position - skip
        if 0 <= cell_pos < len(cells):
            mapped[label] = cells[cell_pos].strip()
        else:
            mapped[label] = ""
    return mapped
def _parse_lots(lot_rows: Sequence[Sequence[str]]) -> list[Lot]:
    """Convert raw lot-row cell lists into Lot objects.

    Columns are read by position:

    - 0: acquired date
    - 1: quantity (via ``_parse_float``)
    - 2: cost basis (via ``_parse_decimal``)
    - 3: lot id

    Missing columns and blank text become ``None``. Empty rows are skipped.
    """
    def _to_lot(cells: Sequence[str]) -> Lot:
        width = len(cells)
        acquired = cells[0].strip()
        quantity = _parse_float(cells[1] if width > 1 else None)
        cost_basis = _parse_decimal(cells[2] if width > 2 else None)
        lot_id = cells[3].strip() if width > 3 else None
        return Lot(
            acquired_date=acquired or None,
            quantity=quantity,
            cost_basis=cost_basis,
            lot_id=lot_id or None,
        )

    return [_to_lot(cells) for cells in lot_rows if cells]
def _row_to_position(row_map: dict[str, str], lots_rows: Sequence[Sequence[str]], symbol: str = "") -> Position:
    """Build a Position from a header->cell mapping produced by ``_map_row``.

    ``symbol`` is taken from the row's ``data-symbol`` attribute by the
    caller. Each field is looked up under several candidate header keys; the
    first non-empty value wins. Values are parsed with ``_parse_decimal`` /
    ``_parse_float``. ``asset_type`` comes from ``classify_asset`` and
    ``lots`` from ``_parse_lots``.
    """
    def pick(*keys: str) -> str | None:
        # Same result as chaining row_map.get(k1) or row_map.get(k2) or ...
        found = None
        for key in keys:
            found = row_map.get(key)
            if found:
                break
        return found

    description = pick('description', 'name', 'column_1') or ""
    market_price = _parse_decimal(pick('price', 'market price', 'last price'))
    quantity = _parse_float(pick('quantity', 'qty'))
    market_value = _parse_decimal(pick('market value', 'mkt val'))
    cost_basis_total = _parse_decimal(pick('cost basis', 'total cost'))
    unrealized_gain = _parse_decimal(pick('gain/loss $', 'unrealized gain', 'gain/loss'))
    unrealized_gain_pct = _parse_float(pick('gain/loss %', 'unrealized gain %'))

    return Position(
        symbol=symbol or "",
        description=description or None,
        asset_type=classify_asset(symbol, description),
        quantity=quantity,
        market_price=market_price,
        market_value=market_value,
        cost_basis_total=cost_basis_total,
        unrealized_gain=unrealized_gain,
        unrealized_gain_pct=unrealized_gain_pct,
        lots=_parse_lots(lots_rows),
    )
def _collect_positions(
    table_data: dict[str, Any],
    headers: Sequence[str],
    account: AccountSummary | str | None,
    include_non_equity: bool,
) -> list[Position]:
    """Turn the rows returned by ``_evaluate_table`` into filtered Positions.

    Filtering rules:

    - Rows whose ``type`` is not ``"position"`` are ignored.
    - When ``account`` is given, rows without an account label, or whose
      label does not match via ``_match_account``, are dropped before their
      cells are parsed.
    - Unless ``include_non_equity`` is set, only ``EQUITY`` and ``ETF``
      positions are kept.
    """
    positions: list[Position] = []
    # All position rows under one account header share a label, so normalize
    # and match each distinct label only once.
    label_matches: dict[str, bool] = {}
    for row in table_data.get('rows') or []:
        if row.get('type') != 'position':
            continue
        if account is not None:
            account_label = row.get('account') or ""
            if not account_label:
                # Filtering by account but this row's account is unknown.
                continue
            if account_label not in label_matches:
                label_matches[account_label] = _match_account(
                    _normalize_account_label(account_label), account
                )
            if not label_matches[account_label]:
                continue
        row_map = _map_row(headers, row.get('cells') or [])
        position = _row_to_position(row_map, row.get('lots') or [], symbol=row.get('symbol') or "")
        if not include_non_equity and position.asset_type not in {"EQUITY", "ETF"}:
            continue
        positions.append(position)
    return positions


async def get_positions(
    account: AccountSummary | str | None = None,
    *,
    include_non_equity: bool = False,
    debug: bool = False,
) -> Envelope[list[Position]]:
    """Scrape the Schwab positions page and return the holdings it lists.

    Args:
        account: Restrict results to one account. ``None`` returns positions
            from every account. Otherwise each row's account label is matched
            with ``_match_account``, and rows without a label are dropped.
        include_non_equity: When False (the default), only positions
            classified as ``EQUITY`` or ``ETF`` are returned.
        debug: Forwarded to ``goto_with_auth_check``.

    Returns:
        ``ok(positions)`` on success. Otherwise a ``fail`` envelope with one of
        these error types:

        - ``AUTHENTICATION``: no session cookies (not retryable), or the page
          failed to load (retryable).
        - ``PARSING``: the table or its headers are missing (retryable).
        - ``VALIDATION``: no position survives the filters (not retryable).
        - ``UNKNOWN``: any other exception (retryable).

        Browser resources are always released in ``finally``.
    """
    cookies = await ensure_cookies()
    if not cookies:
        return fail("Unable to establish Schwab session.", ErrorType.AUTHENTICATION, retryable=False)
    config = load_config()
    playwright_url = get_playwright_url(config)
    playwright = browser = context = page = None
    try:
        playwright, browser = await connect(playwright_url)
        context = await new_context(browser, cookies=cookies)
        page = await new_page(context)
        if not await goto_with_auth_check(page, context, POSITIONS_URL, debug=debug):
            return fail("Failed to load Schwab positions page.", ErrorType.AUTHENTICATION, retryable=True)
        await page.wait_for_selector('#positionsDetails', timeout=45000)
        # Let the client-side app settle, then scroll to the bottom —
        # presumably so lazily rendered rows exist before scraping.
        await page.wait_for_timeout(1000)
        await page.evaluate('window.scrollTo(0, document.body.scrollHeight)')
        await page.wait_for_timeout(1500)
        table_data = await _evaluate_table(page)
        if not table_data:
            return fail("Unable to locate positions table.", ErrorType.PARSING, retryable=True)
        headers = [header.strip().lower() for header in table_data.get('headers') or []]
        if not headers:
            return fail("Positions table headers not found.", ErrorType.PARSING, retryable=True)
        positions = _collect_positions(table_data, headers, account, include_non_equity)
        if not positions:
            return fail("No positions matched the requested criteria.", ErrorType.VALIDATION, retryable=False)
        return ok(positions)
    except Exception as exc:
        # Top-level boundary: report the failure instead of raising. Some
        # exceptions stringify to "", so fall back to the class name.
        return fail(str(exc) or type(exc).__name__, ErrorType.UNKNOWN, retryable=True)
    finally:
        await _safe_close_page(page)
        await _safe_close_context(context)
        await _safe_close_browser(browser)
        await _safe_stop_playwright(playwright)
async def _safe_close_page(page) -> None:
    """Best-effort close of ``page``.

    Does nothing for ``None`` and deliberately swallows any exception raised
    by ``close()``, so cleanup never masks the caller's result.
    """
    if page is not None:
        try:
            await page.close()
        except Exception:
            pass
async def _safe_close_context(context) -> None:
    """Best-effort close of ``context``.

    ``None`` is ignored, and any exception from ``close()`` is swallowed on
    purpose so cleanup cannot fail.
    """
    if context is None:
        return
    try:
        await context.close()
    except Exception:
        return
async def _safe_close_browser(browser) -> None:
    """Best-effort close of ``browser``.

    A ``None`` argument is a no-op. Errors raised while closing are
    intentionally ignored.
    """
    if browser is not None:
        try:
            await browser.close()
        except Exception:
            return
async def _safe_stop_playwright(playwright) -> None:
    """Best-effort ``stop()`` of the ``playwright`` handle.

    Skips ``None`` and suppresses any exception raised while stopping, since
    this runs during cleanup.
    """
    if playwright is None:
        return
    try:
        await playwright.stop()
    except Exception:
        pass