Files
schwab-mcp-custom/schwab_scraper/features/accounts_positions/portfolio_scraper.py
b3nw 650ea2d087
All checks were successful
Build and Push Docker Image / build (push) Successful in 34s
Fix build: Bundle schwab_scraper source and use local dependencies
2026-04-24 01:50:20 +00:00

135 lines
4.6 KiB
Python

from __future__ import annotations
from decimal import Decimal, InvalidOperation
from typing import Iterable, Optional
from ...core import AccountSummary, Envelope, ErrorType, PortfolioSnapshot, Position, fail, ok
from .positions_scraper import get_positions
def _add_optional(current, increment):
    """Return ``current + increment``, treating ``None`` on either side as "no value".

    If ``increment`` is ``None`` the current value is returned unchanged; if
    only ``current`` is ``None`` the increment becomes the new value.
    """
    if increment is None:
        return current
    if current is None:
        return increment
    return current + increment


def _aggregate_positions(positions: Iterable[Position]) -> tuple[list[Position], Optional[Decimal]]:
    """Merge positions that share a symbol and total their market value.

    Positions are grouped by upper-cased ``symbol``. Every position without a
    symbol lands in one shared ``"UNKNOWN"`` group.
    NOTE(review): unrelated symbol-less positions are therefore merged
    together -- confirm that is intended.

    Within a group:

    * ``symbol`` keeps the spelling of the group's first position;
    * ``quantity``, ``market_value``, ``cost_basis_total`` and
      ``unrealized_gain`` are summed, skipping ``None`` values;
    * ``market_price`` takes the last non-``None`` value seen;
    * ``asset_type`` is overwritten by any later truthy value;
    * ``description`` keeps the first truthy value;
    * ``lots`` are concatenated in input order.

    ``unrealized_gain_pct`` is recomputed as
    ``unrealized_gain / cost_basis_total * 100`` whenever both are present and
    the cost basis is non-zero. When it cannot be recomputed:

    * a single, unmerged position keeps its own reported percentage;
    * a merged group gets ``None``, because no single member's percentage
      describes the group as a whole.

    The input objects are not mutated. Each group is built on a fresh
    ``Position`` with its own ``lots`` list.

    Args:
        positions: Positions to aggregate; iterated exactly once.

    Returns:
        ``(aggregated_positions, total_value)``. Groups appear in order of
        first occurrence. ``total_value`` is the sum of every non-``None``
        ``market_value`` across the input, or ``None`` if no position
        reported one.
    """
    aggregated: dict[str, Position] = {}
    merged_keys: set[str] = set()  # groups built from more than one position
    total_value = Decimal("0")
    has_value = False

    for position in positions:
        if position.market_value is not None:
            total_value += position.market_value
            has_value = True

        key = position.symbol.upper() if position.symbol else "UNKNOWN"
        existing = aggregated.get(key)
        if existing is None:
            aggregated[key] = Position(
                symbol=position.symbol,
                description=position.description,
                asset_type=position.asset_type,
                quantity=position.quantity,
                market_price=position.market_price,
                market_value=position.market_value,
                cost_basis_total=position.cost_basis_total,
                unrealized_gain=position.unrealized_gain,
                unrealized_gain_pct=position.unrealized_gain_pct,
                # Copy so extending the group's lots never touches the input;
                # tolerate a missing (None) lots value.
                lots=list(position.lots or ()),
            )
            continue

        merged_keys.add(key)
        existing.quantity = _add_optional(existing.quantity, position.quantity)
        existing.market_value = _add_optional(existing.market_value, position.market_value)
        existing.cost_basis_total = _add_optional(existing.cost_basis_total, position.cost_basis_total)
        existing.unrealized_gain = _add_optional(existing.unrealized_gain, position.unrealized_gain)
        if position.market_price is not None:
            existing.market_price = position.market_price
        if position.description and not existing.description:
            existing.description = position.description
        if position.asset_type:
            existing.asset_type = position.asset_type
        if position.lots:
            existing.lots.extend(position.lots)

    for key, item in aggregated.items():
        if item.unrealized_gain is not None and item.cost_basis_total not in (None, Decimal("0")):
            try:
                item.unrealized_gain_pct = float((item.unrealized_gain / item.cost_basis_total) * 100)
            except (InvalidOperation, ZeroDivisionError):
                # e.g. NaN/Infinity operands under the default decimal context.
                item.unrealized_gain_pct = None
        elif key in merged_keys:
            # A member's own percentage is not the group's percentage.
            item.unrealized_gain_pct = None

    total_value_out = total_value if has_value else None
    return list(aggregated.values()), total_value_out
async def get_portfolio_snapshot(
    account: AccountSummary | str | None = None,
    *,
    aggregate_by_symbol: bool = True,
    include_non_equity: bool = False,
    debug: bool = False,
) -> Envelope[PortfolioSnapshot]:
    """Fetch an account's positions and wrap them in a ``PortfolioSnapshot``.

    Args:
        account: Account selector, forwarded unchanged to ``get_positions``.
        aggregate_by_symbol: If true, positions sharing a symbol are merged
            with ``_aggregate_positions``. Otherwise the positions are
            returned exactly as ``get_positions`` produced them.
        include_non_equity: Forwarded to ``get_positions``.
        debug: Forwarded to ``get_positions``.

    Returns:
        On success, ``ok(PortfolioSnapshot(...))``. ``count`` is the number
        of (possibly aggregated) positions. ``total_value`` is the sum of
        all non-``None`` market values, or ``None`` if there were none.

        If ``get_positions`` fails, a ``fail`` envelope that relays its
        ``error``, ``error_type`` and ``retryable`` fields. These fall back
        to a generic message, ``ErrorType.UNKNOWN`` and ``True``.
    """
    result = await get_positions(
        account=account,
        include_non_equity=include_non_equity,
        debug=debug,
    )

    if not result["success"]:
        message = result.get("error") or "Failed to retrieve positions."
        error_type = result.get("error_type") or ErrorType.UNKNOWN
        retryable = result.get("retryable", True)
        return fail(message, error_type, retryable)

    positions = result["data"] or []

    if aggregate_by_symbol:
        equities, total_value = _aggregate_positions(positions)
    else:
        equities = positions
        valued = [p.market_value for p in positions if p.market_value is not None]
        total_value = sum(valued, Decimal("0")) if valued else None

    return ok(
        PortfolioSnapshot(
            equities=equities,
            total_value=total_value,
            count=len(equities),
        )
    )