From ddaf4190f9525cba5572acc659c6e3015da7e065 Mon Sep 17 00:00:00 2001 From: b3nw Date: Tue, 9 Jun 2026 19:43:47 +0000 Subject: [PATCH] feat: implement multi-server support and sync tools - Introduced ServerRegistry to manage multiple NPM instances - Added support for NPM_SERVERS JSON environment variable - Updated all tools to support optional 'server' targeting - Implemented clone_proxy_host, sync_access_lists, and sync_certificates tools - Transitioned get_proxy_host_logs to API-based retrieval with local fallback - Added comprehensive test suite for multi-server management and sync operations Co-authored-by: claw-io --- src/npm_mcp/client.py | 59 ++++ src/npm_mcp/config.py | 25 ++ src/npm_mcp/server.py | 594 ++++++++++++++++++++++++++++++++----- tests/test_client.py | 89 +++++- tests/test_config.py | 38 +++ tests/test_registry.py | 97 ++++++ tests/test_server_tools.py | 315 ++++++++++++++++++++ uv.lock | 2 +- 8 files changed, 1148 insertions(+), 71 deletions(-) create mode 100644 tests/test_registry.py create mode 100644 tests/test_server_tools.py diff --git a/src/npm_mcp/client.py b/src/npm_mcp/client.py index 225f37a..5459500 100644 --- a/src/npm_mcp/client.py +++ b/src/npm_mcp/client.py @@ -2,6 +2,7 @@ import logging from datetime import UTC, datetime, timedelta +from typing import Any import httpx @@ -392,3 +393,61 @@ class NpmClient: response = await self._request("POST", "/nginx/certificates", json=payload) return Certificate(**response.json()) + + async def get_proxy_host_logs( + self, + host_id: int, + log_type: str = "access", + lines: int = 100, + ) -> dict[str, Any]: + """Retrieve proxy host logs via the API. + + Args: + host_id: NPM proxy host ID. + log_type: "access" or "error". + lines: Number of most recent lines to return. + """ + response = await self._request( + "GET", + f"/nginx/proxy-hosts/{host_id}/logs", + params={"type": log_type, "limit": lines}, + ) + return response.json() + + async def get_proxy_host_logs_summary(self, host_id: int) -> dict[str, Any]: + """Retrieve proxy host logs summary via the API. + + Args: + host_id: NPM proxy host ID. + """ + response = await self._request("GET", f"/nginx/proxy-hosts/{host_id}/logs/summary") + return response.json() + + async def create_access_list( + self, + name: str, + satisfy_any: bool = False, + pass_auth: bool = False, + items: list[dict[str, Any]] | None = None, + clients: list[dict[str, Any]] | None = None, + ) -> AccessList: + """Create a new access list on the server. + + Args: + name: Name of the access list + satisfy_any: Satisfy any HTTP auth or IP restriction + pass_auth: Pass auth headers to host + items: Auth entries (username/password) and IP restrictions + clients: IP restriction clients + """ + payload = { + "name": name, + "satisfy_any": satisfy_any, + "pass_auth": pass_auth, + "items": items or [], + "clients": clients or [], + } + response = await self._request("POST", "/nginx/access-lists", json=payload) + return AccessList(**response.json()) + + diff --git a/src/npm_mcp/config.py b/src/npm_mcp/config.py index 92bc585..a7d08b3 100644 --- a/src/npm_mcp/config.py +++ b/src/npm_mcp/config.py @@ -37,6 +37,31 @@ class Settings(BaseSettings): identity: str = "" secret: str = "" + # Multi-Server Configuration + # Example JSON: + # '[{"name": "prod", "url": "http://10.0.0.10:81/api", + # "identity": "admin@example.com", "secret": "pwd"}]' + servers: list[dict[str, Any]] = [] + default_server: str | None = None + + @field_validator("servers", mode="before") + @classmethod + def parse_servers(cls, v: Any) -> list[dict[str, Any]]: + """Parse JSON string to list of dicts, or pass through if already list.""" + if isinstance(v, list): + return v + if isinstance(v, str) and v.strip(): + import json + try: + data = json.loads(v) + if not isinstance(data, list): + raise ValueError("NPM_SERVERS must be a list of objects") + return data + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in NPM_SERVERS: {e}") from e + return [] + + # MCP Server Configuration mcp_host: str = "0.0.0.0" mcp_port: int = 8000 diff --git a/src/npm_mcp/server.py b/src/npm_mcp/server.py index 94efe45..e33ace4 100644 --- a/src/npm_mcp/server.py +++ b/src/npm_mcp/server.py @@ -14,30 +14,109 @@ from .logs import is_log_dir_configured, list_available_logs, read_log_lines logger = logging.getLogger(__name__) -# Create global client instance (lazy initialization) -_client: NpmClient | None = None + +class ServerRegistry: + """Manages multiple NpmClient connections.""" + + def __init__(self, configs: list[dict[str, Any]], default: str | None = None): + self._clients: dict[str, NpmClient] = {} + self._default = default + + # Register multi-server entries + for cfg in configs: + name = cfg.get("name") + url = cfg.get("url") or cfg.get("api_url") + identity = cfg.get("identity") + secret = cfg.get("secret") + if not all([name, url, identity, secret]): + logger.warning(f"Server '{name}' is missing required fields, skipping") + continue + self._clients[name] = NpmClient(base_url=url, identity=identity, secret=secret) + + # Fallback to single-server settings if registry is empty + if not self._clients and settings.api_url and settings.identity and settings.secret: + logger.info("No servers in NPM_SERVERS. Using single-server environment variables.") + self._clients["default"] = NpmClient( + base_url=settings.api_url, + identity=settings.identity, + secret=settings.secret + ) + if not self._default: + self._default = "default" + + # Validate default + if self._default and self._default not in self._clients: + logger.warning( + f"Default server '{self._default}' is not in configured servers. " + "Clearing default." + ) + self._default = None + + def get(self, name: str | None = None) -> NpmClient: + """Retrieve client by name. Fallback to default if name is None/empty.""" + if not self._clients: + raise KeyError("No NPM servers configured.") + + if name is None or name == "": + if self._default: + name = self._default + elif len(self._clients) == 1: + name = next(iter(self._clients.keys())) + else: + raise KeyError("Multiple servers configured but no default server specified.") + + if name not in self._clients: + raise KeyError( + f"Server '{name}' not found. " + f"Configured servers: {list(self._clients.keys())}" + ) + return self._clients[name] + + def list_names(self) -> list[str]: + return list(self._clients.keys()) + + def get_default(self) -> str | None: + return self._default + + async def close_all(self) -> None: + for client in self._clients.values(): + await client.close() + + +# Global registry +registry: ServerRegistry | None = None + + +def get_registry() -> ServerRegistry: + """Get or create the global ServerRegistry instance.""" + global registry + if registry is None: + registry = ServerRegistry(settings.servers, settings.default_server) + return registry + + +def _get_client(server: str | None = None) -> NpmClient: + """Retrieve NPM client for the specified server, or fallback to default.""" + return get_registry().get(server) def get_client() -> NpmClient: - """Get or create the NPM client instance.""" - global _client - if _client is None: - _client = NpmClient() - return _client + """Get or create the NPM client instance (backward compatibility).""" + return _get_client(None) @asynccontextmanager async def lifespan(server: FastMCP): """Manage client lifecycle.""" - global _client - _client = NpmClient() - logger.info(f"NPM MCP Server starting, connecting to {settings.api_url}") + global registry + registry = ServerRegistry(settings.servers, settings.default_server) + logger.info(f"NPM MCP Server starting. Configured servers: {registry.list_names()}") try: yield finally: - if _client: - await _client.close() - _client = None + if registry: + await registry.close_all() + registry = None logger.info("NPM MCP Server stopped") @@ -53,7 +132,9 @@ mcp = FastMCP( def _format_error(e: Exception) -> str: """Format exception for tool response.""" - if isinstance(e, NpmAuthenticationError): + if isinstance(e, KeyError): + return f"Configuration error: {e.args[0]}" + elif isinstance(e, NpmAuthenticationError): return f"Authentication failed: {e}" elif isinstance(e, NpmConnectionError): return f"Connection error: {e}" @@ -64,20 +145,21 @@ def _format_error(e: Exception) -> str: return f"Error: {e}" + # ============================================================================= # Tools # ============================================================================= @mcp.tool() -async def list_proxy_hosts() -> str: +async def list_proxy_hosts(server: str | None = None) -> str: """List all proxy hosts configured in Nginx Proxy Manager. Returns a summary of all proxy hosts including their domains, forward destinations, and SSL status. """ try: - client = get_client() + client = _get_client(server) hosts = await client.get_proxy_hosts() if not hosts: @@ -101,17 +183,18 @@ async def list_proxy_hosts() -> str: @mcp.tool() -async def get_proxy_host_details(host_id: int) -> str: +async def get_proxy_host_details(host_id: int, server: str | None = None) -> str: """Get detailed configuration for a specific proxy host. Args: host_id: The ID of the proxy host to retrieve + server: Target server name Returns full configuration including SSL settings, locations, and advanced configuration. """ try: - client = get_client() + client = _get_client(server) host = await client.get_proxy_host(host_id) details: dict[str, Any] = { @@ -164,13 +247,13 @@ async def get_proxy_host_details(host_id: int) -> str: @mcp.tool() -async def get_system_health() -> str: +async def get_system_health(server: str | None = None) -> str: """Check the health and status of the Nginx Proxy Manager instance. Returns system status, version information, and connectivity status. """ try: - client = get_client() + client = _get_client(server) status = await client.get_status() result = [f"Status: {status.status}"] @@ -207,17 +290,18 @@ async def get_system_health() -> str: @mcp.tool() -async def search_audit_logs(limit: int = 50, offset: int = 0) -> str: +async def search_audit_logs(limit: int = 50, offset: int = 0, server: str | None = None) -> str: """Search the audit log for recent actions in Nginx Proxy Manager. Args: limit: Maximum number of entries to return (default: 50, max: 100) offset: Number of entries to skip for pagination (default: 0) + server: Target server name Returns recent audit log entries showing user actions and changes. """ try: - client = get_client() + client = _get_client(server) limit = min(limit, 100) # Cap at 100 entries = await client.get_audit_log(limit=limit, offset=offset) @@ -240,14 +324,14 @@ async def search_audit_logs(limit: int = 50, offset: int = 0) -> str: @mcp.tool() -async def list_certificates() -> str: +async def list_certificates(server: str | None = None) -> str: """List all SSL certificates managed by Nginx Proxy Manager. Returns a summary of all certificates including their domains, provider, and expiration dates. """ try: - client = get_client() + client = _get_client(server) certs = await client.get_certificates() if not certs: @@ -274,14 +358,14 @@ async def list_certificates() -> str: @mcp.tool() -async def list_access_lists() -> str: +async def list_access_lists(server: str | None = None) -> str: """List all access lists configured in Nginx Proxy Manager. Returns a summary of all access lists including their IDs and names. Use these IDs when creating proxy hosts that require access control. """ try: - client = get_client() + client = _get_client(server) access_lists = await client.get_access_lists() if not access_lists: @@ -309,6 +393,7 @@ async def create_proxy_host( allow_websocket_upgrade: bool | None = None, access_list_id: int | None = None, advanced_config: str | None = None, + server: str | None = None, ) -> str: """Create a new proxy host in Nginx Proxy Manager. @@ -325,6 +410,7 @@ async def create_proxy_host( access_list_id: Access list ID for authentication. Use list_access_lists to find. Use 0 for no access restrictions. (default from config) advanced_config: Custom nginx configuration block (default from config) + server: Target server name Returns: Details of the created proxy host including the new host ID. @@ -345,7 +431,7 @@ async def create_proxy_host( # Get defaults from config, then override with provided values defaults = settings.get_proxy_defaults() - client = get_client() + client = _get_client(server) host = await client.create_proxy_host( domain_names=domain_names, forward_host=forward_host, @@ -401,6 +487,7 @@ async def update_proxy_host( allow_websocket_upgrade: bool | None = None, access_list_id: int | None = None, advanced_config: str | None = None, + server: str | None = None, ) -> str: """Update an existing proxy host in Nginx Proxy Manager. @@ -417,12 +504,13 @@ async def update_proxy_host( allow_websocket_upgrade: Allow WebSocket connections access_list_id: Access list ID (0 for no restrictions) advanced_config: Custom nginx configuration block + server: Target server name Returns: Details of the updated proxy host. """ try: - client = get_client() + client = _get_client(server) kwargs = {} if forward_host is not None: kwargs["forward_host"] = forward_host @@ -465,11 +553,12 @@ async def get_proxy_host_logs( log_type: str = "access", lines: int = 100, search: str | None = None, + server: str | None = None, ) -> str: """Retrieve recent nginx log entries for a specific proxy host. - Reads the raw nginx access or error log file for the given host. - Requires the NPM log directory to be mounted (see NPM_LOG_DIR config). + Reads the raw nginx access or error log file for the given host or retrieves + them via the API. Args: host_id: The ID of the proxy host (use list_proxy_hosts to find IDs) @@ -479,45 +568,75 @@ async def get_proxy_host_logs( (default: 100, max: 500) search: Optional filter string - only lines containing this text are returned (case-insensitive) - - Returns: - The most recent log lines for the proxy host, with metadata. - - Examples: - - get_proxy_host_logs(5) — last 100 access log lines for host 5 - - get_proxy_host_logs(5, log_type="error") — recent error log - - get_proxy_host_logs(5, lines=50, search="404") — last 50 lines containing "404" - - get_proxy_host_logs(5, search="10.0.0.1") — filter by client IP + server: Target server name """ try: - client = get_client() - host = await client.get_proxy_host(host_id) - domains = ", ".join(host.domain_names) + client = _get_client(server) + + # Try retrieving logs via the API first + try: + log_data = await client.get_proxy_host_logs( + host_id=host_id, + log_type=log_type, + lines=lines, + ) + raw_lines = log_data.get("lines", []) + if search: + search_lower = search.lower() + filtered_lines = [ + line for line in raw_lines if search_lower in line.lower() + ] + else: + filtered_lines = raw_lines - result = read_log_lines( - host_id=host_id, - log_type=log_type, - lines=lines, - search=search, - ) + host = await client.get_proxy_host(host_id) + domains = ", ".join(host.domain_names) + header_parts = [ + f"Proxy host [{host_id}] {domains} — {log_type} log (retrieved via API)", + f"Showing last {len(filtered_lines)} lines:", + ] + header = "\n".join(header_parts) + if not filtered_lines: + return f"{header}\n\n(no log entries found)" + return f"{header}\n\n" + "\n".join(filtered_lines) - header_parts = [ - f"Proxy host [{host_id}] {domains} — {log_type} log", - f"File: {result['file']}", - ] - if result["total_lines_in_file"] is not None: - header_parts.append(f"Total lines in file: {result['total_lines_in_file']}") - if result["matched_lines"] is not None: - header_parts.append(f"Lines matching '{search}': {result['matched_lines']}") - header_parts.append(f"Showing last {result['returned_lines']} lines:") + except Exception as api_err: + # Fallback for default server or if local logs are mounted + reg = get_registry() + is_default = False + try: + target_client = reg.get(server) + default_client = reg.get(None) + if target_client.base_url == default_client.base_url: + is_default = True + except Exception: + pass - header = "\n".join(header_parts) + if is_default and is_log_dir_configured(): + host = await client.get_proxy_host(host_id) + domains = ", ".join(host.domain_names) + result = read_log_lines( + host_id=host_id, + log_type=log_type, + lines=lines, + search=search, + ) + header_parts = [ + f"Proxy host [{host_id}] {domains} — {log_type} log (local fallback)", + f"File: {result['file']}", + ] + if result["total_lines_in_file"] is not None: + header_parts.append(f"Total lines in file: {result['total_lines_in_file']}") + if result["matched_lines"] is not None: + header_parts.append(f"Lines matching '{search}': {result['matched_lines']}") + header_parts.append(f"Showing last {result['returned_lines']} lines:") - if not result["lines"]: - return f"{header}\n\n(no log entries found)" - - log_output = "\n".join(result["lines"]) - return f"{header}\n\n{log_output}" + header = "\n".join(header_parts) + if not result["lines"]: + return f"{header}\n\n(no log entries found)" + return f"{header}\n\n" + "\n".join(result["lines"]) + else: + raise api_err except Exception as e: return _format_error(e) @@ -528,6 +647,7 @@ async def create_certificate( domain_names: list[str], email: str, dns_challenge: bool = False, + server: str | None = None, ) -> str: """Provision a new Let's Encrypt SSL certificate. @@ -535,13 +655,10 @@ async def create_certificate( domain_names: List of domain names for the certificate email: Email address for Let's Encrypt notifications dns_challenge: Use DNS challenge instead of HTTP (default: False) - - Returns: - Details of the created certificate including its ID. - Use the returned ID with create_proxy_host or update_proxy_host. + server: Target server name """ try: - client = get_client() + client = _get_client(server) cert = await client.create_certificate( domain_names=domain_names, email=email, @@ -560,3 +677,344 @@ async def create_certificate( except Exception as e: return _format_error(e) + + +@mcp.tool() +async def list_servers() -> str: + """List all configured NPM servers and their health/connectivity status. + + Returns a JSON string containing the list of registered servers, + the default server, and their health status. + """ + try: + reg = get_registry() + server_names = reg.list_names() + default_server = reg.get_default() + + health_status = {} + for name in server_names: + try: + client = reg.get(name) + status = await client.get_status() + health_status[name] = {"status": status.status, "version": status.version} + except Exception as e: + health_status[name] = {"status": "error", "error": str(e)} + + return json.dumps({ + "servers": server_names, + "default_server": default_server, + "health": health_status + }, indent=2) + except Exception as e: + return _format_error(e) + + +@mcp.tool() +async def clone_proxy_host( + source_server: str, + target_server: str, + host_id: int, + override_settings: dict[str, Any] | None = None, +) -> str: + """Clone a proxy host configuration from source_server to target_server. + + Resolves certificate and access list IDs automatically by matching names/domains. + + Args: + source_server: Name of the source server. + target_server: Name of the destination server. + host_id: The ID of the host on the source server. + override_settings: Optional dict of settings to override during creation. + """ + try: + source_client = _get_client(source_server) + target_client = _get_client(target_server) + + # Retrieve host configuration from source server + host = await source_client.get_proxy_host(host_id) + + # Resolve certificate ID + target_cert_id = 0 + cert_resolved_msg = "None" + if host.certificate_id and host.certificate_id > 0: + try: + source_cert = await source_client.get_certificate(host.certificate_id) + target_certs = await target_client.get_certificates() + + matched_cert = None + for cert in target_certs: + if ( + source_cert.nice_name + and cert.nice_name + and source_cert.nice_name == cert.nice_name + ): + matched_cert = cert + break + if ( + source_cert.domain_names + and cert.domain_names + and set(source_cert.domain_names) == set(cert.domain_names) + ): + matched_cert = cert + break + + if matched_cert: + target_cert_id = matched_cert.id + cert_name = matched_cert.nice_name or ", ".join(matched_cert.domain_names) + cert_resolved_msg = f"Resolved to ID {target_cert_id} ({cert_name})" + else: + source_name = source_cert.nice_name or ", ".join(source_cert.domain_names) + cert_resolved_msg = ( + f"Could not resolve source certificate '{source_name}' " + "on target server. Defaulting to None (0)." + ) + except Exception as e: + cert_resolved_msg = f"Error resolving certificate: {e}. Defaulting to None (0)." + + # Resolve access list ID + target_access_list_id = 0 + access_list_resolved_msg = "None" + if host.access_list_id and host.access_list_id > 0: + try: + source_alists = await source_client.get_access_lists() + source_alist = next( + (al for al in source_alists if al.id == host.access_list_id), None + ) + + if source_alist: + target_alists = await target_client.get_access_lists() + matched_alist = next( + (al for al in target_alists if al.name == source_alist.name), None + ) + + if matched_alist: + target_access_list_id = matched_alist.id + access_list_resolved_msg = ( + f"Resolved to ID {target_access_list_id} ({matched_alist.name})" + ) + else: + access_list_resolved_msg = ( + f"Could not resolve source access list '{source_alist.name}' " + "on target server. Defaulting to None (0)." + ) + else: + access_list_resolved_msg = ( + "Source access list not found. Defaulting to None (0)." + ) + except Exception as e: + access_list_resolved_msg = ( + f"Error resolving access list: {e}. Defaulting to None (0)." + ) + + # Construct creation payload + payload = { + "domain_names": host.domain_names, + "forward_host": host.forward_host, + "forward_port": host.forward_port, + "forward_scheme": host.forward_scheme, + "ssl_forced": host.ssl_forced, + "hsts_enabled": host.hsts_enabled, + "hsts_subdomains": host.hsts_subdomains, + "http2_support": host.http2_support, + "block_exploits": host.block_exploits, + "caching_enabled": host.caching_enabled, + "allow_websocket_upgrade": host.allow_websocket_upgrade, + "advanced_config": host.advanced_config, + "meta": host.meta, + "certificate_id": target_cert_id, + "access_list_id": target_access_list_id, + } + + # Apply overrides + if override_settings: + payload.update(override_settings) + + # Create proxy host on target server + new_host = await target_client.create_proxy_host( + domain_names=payload["domain_names"], + forward_host=payload["forward_host"], + forward_port=payload["forward_port"], + forward_scheme=payload["forward_scheme"], + certificate_id=payload["certificate_id"], + ssl_forced=payload["ssl_forced"], + hsts_enabled=payload["hsts_enabled"], + hsts_subdomains=payload["hsts_subdomains"], + http2_support=payload["http2_support"], + block_exploits=payload["block_exploits"], + caching_enabled=payload["caching_enabled"], + allow_websocket_upgrade=payload["allow_websocket_upgrade"], + access_list_id=payload["access_list_id"], + advanced_config=payload["advanced_config"], + meta=payload["meta"], + ) + + return ( + f"Successfully cloned proxy host from '{source_server}' to '{target_server}'!\n\n" + f"Source Host ID: {host_id}\n" + f"Target Host ID: {new_host.id}\n" + f"Domains: {', '.join(new_host.domain_names)}\n" + f"Certificate: {cert_resolved_msg}\n" + f"Access List: {access_list_resolved_msg}" + ) + + except Exception as e: + return _format_error(e) + + +@mcp.tool() +async def sync_access_lists(source_server: str, target_server: str) -> str: + """Sync access lists from source_server to target_server. + + Replicates missing access lists by name, carrying over credentials and IP rules. + + Args: + source_server: Name of the source server. + target_server: Name of the target server. + """ + try: + source_client = _get_client(source_server) + target_client = _get_client(target_server) + + # Get raw access lists from both servers to retrieve detailed items and clients + source_response = await source_client._request("GET", "/nginx/access-lists") + source_lists = source_response.json() + + target_response = await target_client._request("GET", "/nginx/access-lists") + target_lists = target_response.json() + + target_names = {al["name"] for al in target_lists} + + synced = [] + skipped = [] + + for al in source_lists: + name = al.get("name") + if not name: + continue + + if name in target_names: + skipped.append(f"'{name}' (already exists)") + continue + + # Strip database IDs and unique primary keys from items & clients to avoid conflicts + items = al.get("items", []) + cleaned_items = [] + for item in items: + cleaned = item.copy() + cleaned.pop("id", None) + cleaned.pop("access_list_id", None) + cleaned.pop("created_on", None) + cleaned.pop("modified_on", None) + cleaned_items.append(cleaned) + + clients = al.get("clients", []) + cleaned_clients = [] + for client in clients: + cleaned = client.copy() + cleaned.pop("id", None) + cleaned.pop("access_list_id", None) + cleaned.pop("created_on", None) + cleaned.pop("modified_on", None) + cleaned_clients.append(cleaned) + + # Replicate access list + await target_client.create_access_list( + name=name, + satisfy_any=al.get("satisfy_any", False), + pass_auth=al.get("pass_auth", False), + items=cleaned_items, + clients=cleaned_clients, + ) + synced.append(name) + + result_parts = [f"Synced access lists from '{source_server}' to '{target_server}':"] + if synced: + result_parts.append(f"✅ Created: {', '.join(synced)}") + else: + result_parts.append("No new access lists were created.") + if skipped: + result_parts.append(f"ℹ️ Matched (exists): {', '.join(skipped)}") + + return "\n".join(result_parts) + + except Exception as e: + return _format_error(e) + + +@mcp.tool() +async def sync_certificates(source_server: str, target_server: str) -> str: + """Sync Let's Encrypt certificates from source_server to target_server. + + Matches existing certificates on the target server by domain names. + + Args: + source_server: Name of the source server. + target_server: Name of the target server. + """ + try: + source_client = _get_client(source_server) + target_client = _get_client(target_server) + + source_certs = await source_client.get_certificates() + target_certs = await target_client.get_certificates() + + # Build target domain map for lookup + target_domains_map = {frozenset(cert.domain_names): cert for cert in target_certs} + + synced = [] + skipped_exists = [] + skipped_custom = [] + + for cert in source_certs: + domains = cert.domain_names + if not domains: + continue + + cert_domains_set = frozenset(domains) + + # Check if matching cert exists on target + if cert_domains_set in target_domains_map: + skipped_exists.append(f"'{cert.nice_name or ', '.join(domains)}'") + continue + + # Check provider type + if cert.provider != "letsencrypt": + skipped_custom.append( + f"'{cert.nice_name or ', '.join(domains)}' " + f"(custom provider: {cert.provider})" + ) + continue + + # Re-provision Let's Encrypt certificate + email = ( + cert.meta.get("letsencrypt_email") + or cert.meta.get("email") + or settings.identity + or "admin@example.com" + ) + dns_challenge = cert.meta.get("dns_challenge", False) + + await target_client.create_certificate( + domain_names=domains, + email=email, + dns_challenge=dns_challenge, + ) + synced.append(f"'{', '.join(domains)}'") + + result_parts = [ + f"Synced Let's Encrypt certificates from '{source_server}' " + f"to '{target_server}':" + ] + if synced: + result_parts.append(f"✅ Provisioned: {', '.join(synced)}") + else: + result_parts.append("No new certificates were provisioned.") + if skipped_exists: + result_parts.append(f"ℹ️ Matched (exists): {', '.join(skipped_exists)}") + if skipped_custom: + result_parts.append(f"⚠️ Skipped (manual upload required): {', '.join(skipped_custom)}") + + return "\n".join(result_parts) + + except Exception as e: + return _format_error(e) diff --git a/tests/test_client.py b/tests/test_client.py index 508ed41..834a7db 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,10 +1,9 @@ """Tests for NpmClient.""" import pytest -from httpx import Response from npm_mcp.client import NpmClient -from npm_mcp.exceptions import NpmAuthenticationError, NpmConnectionError +from npm_mcp.exceptions import NpmAuthenticationError @pytest.fixture @@ -227,3 +226,89 @@ class TestNpmClientEndpoints: assert host.forward_port == 3000 assert host.ssl_forced is True assert host.certificate_id == 24 + + @pytest.mark.asyncio + async def test_create_access_list(self, httpx_mock, mock_token_response): + """Test creating an access list.""" + httpx_mock.add_response( + method="POST", + url="http://localhost:81/api/tokens", + json=mock_token_response, + ) + httpx_mock.add_response( + method="POST", + url="http://localhost:81/api/nginx/access-lists", + json={ + "id": 5, + "created_on": "2024-01-01T00:00:00Z", + "modified_on": "2024-01-01T00:00:00Z", + "owner_user_id": 1, + "name": "Custom List", + "satisfy_any": True, + "pass_auth": False, + }, + status_code=201, + ) + + async with NpmClient( + base_url="http://localhost:81/api", + identity="test@test.com", + secret="password", + ) as client: + al = await client.create_access_list( + name="Custom List", + satisfy_any=True, + pass_auth=False, + items=[{"username": "u", "password": "p"}], + clients=[{"address": "1.1.1.1", "directive": "allow"}], + ) + + assert al.id == 5 + assert al.name == "Custom List" + assert al.satisfy_any is True + assert al.pass_auth is False + + @pytest.mark.asyncio + async def test_get_proxy_host_logs(self, httpx_mock, mock_token_response): + """Test fetching proxy host logs via API.""" + httpx_mock.add_response( + method="POST", + url="http://localhost:81/api/tokens", + json=mock_token_response, + ) + httpx_mock.add_response( + method="GET", + url="http://localhost:81/api/nginx/proxy-hosts/42/logs?type=access&limit=50", + json={"lines": ["line 1", "line 2"]}, + ) + + async with NpmClient( + base_url="http://localhost:81/api", + identity="test@test.com", + secret="password", + ) as client: + logs = await client.get_proxy_host_logs(host_id=42, log_type="access", lines=50) + assert logs == {"lines": ["line 1", "line 2"]} + + @pytest.mark.asyncio + async def test_get_proxy_host_logs_summary(self, httpx_mock, mock_token_response): + """Test fetching proxy host logs summary via API.""" + httpx_mock.add_response( + method="POST", + url="http://localhost:81/api/tokens", + json=mock_token_response, + ) + httpx_mock.add_response( + method="GET", + url="http://localhost:81/api/nginx/proxy-hosts/42/logs/summary", + json={"access": 100, "error": 5}, + ) + + async with NpmClient( + base_url="http://localhost:81/api", + identity="test@test.com", + secret="password", + ) as client: + summary = await client.get_proxy_host_logs_summary(host_id=42) + assert summary == {"access": 100, "error": 5} + diff --git a/tests/test_config.py b/tests/test_config.py index ff5c04e..a550189 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -89,3 +89,41 @@ class TestProxyDefaults: # Other defaults preserved assert defaults["ssl_forced"] is True assert defaults["block_exploits"] is True + + +class TestMultiServerConfig: + """Test multi-server configuration parsing.""" + + def test_servers_empty_by_default(self): + """Test that servers is empty by default.""" + s = Settings(identity="test", secret="test") + assert s.servers == [] + assert s.default_server is None + + def test_servers_json_parsing(self, monkeypatch): + """Test parsing servers JSON list from environment variable.""" + monkeypatch.setenv("NPM_IDENTITY", "test") + monkeypatch.setenv("NPM_SECRET", "test") + monkeypatch.setenv( + "NPM_SERVERS", + '[{"name": "prod", "url": "http://prod:81/api", "identity": "p", "secret": "ps"}, ' + '{"name": "dev", "url": "http://dev:81/api", "identity": "d", "secret": "ds"}]' + ) + monkeypatch.setenv("NPM_DEFAULT_SERVER", "prod") + + s = Settings() + assert len(s.servers) == 2 + assert s.servers[0]["name"] == "prod" + assert s.servers[0]["url"] == "http://prod:81/api" + assert s.servers[1]["name"] == "dev" + assert s.default_server == "prod" + + def test_servers_invalid_json_raises(self, monkeypatch): + """Test that invalid servers JSON raises SettingsError.""" + monkeypatch.setenv("NPM_IDENTITY", "test") + monkeypatch.setenv("NPM_SECRET", "test") + monkeypatch.setenv("NPM_SERVERS", "{not valid}") + + with pytest.raises(SettingsError): + Settings() + diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 0000000..0d38c89 --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,97 @@ +"""Tests for ServerRegistry.""" + +import pytest + +from npm_mcp.config import settings +from npm_mcp.server import ServerRegistry + + +def test_registry_fallback_to_single_server(monkeypatch): + """Test that registry falls back to single-server settings when empty.""" + monkeypatch.setattr(settings, "api_url", "http://test-url:81/api") + monkeypatch.setattr(settings, "identity", "test-user") + monkeypatch.setattr(settings, "secret", "test-pass") + + registry = ServerRegistry(configs=[], default=None) + + assert registry.list_names() == ["default"] + assert registry.get_default() == "default" + + client = registry.get() + assert client.base_url == "http://test-url:81/api" + assert client._identity == "test-user" + + +def test_registry_multiple_servers(): + """Test that multiple servers are correctly registered.""" + configs = [ + {"name": "prod", "url": "http://prod:81/api", "identity": "p", "secret": "ps"}, + {"name": "dev", "url": "http://dev:81/api", "identity": "d", "secret": "ds"}, + ] + + registry = ServerRegistry(configs=configs, default="prod") + + assert set(registry.list_names()) == {"prod", "dev"} + assert registry.get_default() == "prod" + + prod_client = registry.get("prod") + assert prod_client.base_url == "http://prod:81/api" + + dev_client = registry.get("dev") + assert dev_client.base_url == "http://dev:81/api" + + +def test_registry_get_default_fallback(): + """Test that get() falls back to default server when name is None/empty.""" + configs = [ + {"name": "prod", "url": "http://prod:81/api", "identity": "p", "secret": "ps"}, + {"name": "dev", "url": "http://dev:81/api", "identity": "d", "secret": "ds"}, + ] + + registry = ServerRegistry(configs=configs, default="dev") + + # Name is None + client = registry.get(None) + assert client.base_url == "http://dev:81/api" + + # Name is empty string + client_empty = registry.get("") + assert client_empty.base_url == "http://dev:81/api" + + +def test_registry_single_client_no_default_specified(): + """Test that get() succeeds if there is only 1 server, even if no default is specified.""" + configs = [ + {"name": "only-one", "url": "http://only:81/api", "identity": "o", "secret": "os"} + ] + + registry = ServerRegistry(configs=configs, default=None) + + assert registry.get_default() is None + client = registry.get() + assert client.base_url == "http://only:81/api" + + +def test_registry_multiple_clients_no_default_raises(): + """Test that get() raises KeyError if multiple servers are defined but no default is set.""" + configs = [ + {"name": "prod", "url": "http://prod:81/api", "identity": "p", "secret": "ps"}, + {"name": "dev", "url": "http://dev:81/api", "identity": "d", "secret": "ds"}, + ] + + registry = ServerRegistry(configs=configs, default=None) + + with pytest.raises(KeyError, match="Multiple servers configured but no default server"): + registry.get() + + +def test_registry_invalid_name_raises(): + """Test that get() raises KeyError for non-existent server names.""" + configs = [ + {"name": "prod", "url": "http://prod:81/api", "identity": "p", "secret": "ps"}, + ] + + registry = ServerRegistry(configs=configs, default="prod") + + with pytest.raises(KeyError, match="Server 'non-existent' not found"): + registry.get("non-existent") diff --git a/tests/test_server_tools.py b/tests/test_server_tools.py new file mode 100644 index 0000000..8c4a89b --- /dev/null +++ b/tests/test_server_tools.py @@ -0,0 +1,315 @@ +"""Tests for multi-server management and sync tools.""" + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from npm_mcp.models import AccessList, Certificate, HealthStatus, ProxyHost +from npm_mcp.server import ( + clone_proxy_host, + get_proxy_host_logs, + list_servers, + sync_access_lists, + sync_certificates, +) + + +@pytest.fixture +def mock_registry(): + """Mock registry with prod and dev servers.""" + reg = MagicMock() + reg.list_names.return_value = ["prod", "dev"] + reg.get_default.return_value = "prod" + return reg + + +@pytest.mark.asyncio +async def test_list_servers(mock_registry): + """Test list_servers tool output and health query.""" + client_prod = MagicMock() + client_prod.get_status = AsyncMock( + return_value=HealthStatus(status="online", version={"major": "2"}) + ) + client_dev = MagicMock() + client_dev.get_status = AsyncMock(side_effect=Exception("Connection failed")) + + mock_registry.get.side_effect = lambda name: client_prod if name == "prod" else client_dev + + with patch("npm_mcp.server.get_registry", return_value=mock_registry): + result_json = await list_servers() + result = json.loads(result_json) + + assert result["servers"] == ["prod", "dev"] + assert result["default_server"] == "prod" + assert result["health"]["prod"]["status"] == "online" + assert result["health"]["dev"]["status"] == "error" + assert "Connection failed" in result["health"]["dev"]["error"] + + +@pytest.mark.asyncio +async def test_clone_proxy_host(mock_registry): + """Test clone_proxy_host tool with cert and access list resolution.""" + source_client = MagicMock() + target_client = MagicMock() + + mock_registry.get.side_effect = lambda name: source_client if name == "prod" else target_client + + # Source host setup + source_host = ProxyHost( + id=12, + created_on="2024-01-01T00:00:00Z", + modified_on="2024-01-01T00:00:00Z", + owner_user_id=1, + domain_names=["test.example.com"], + forward_host="192.168.1.50", + forward_port=8080, + forward_scheme="http", + certificate_id=10, + access_list_id=5, + ssl_forced=True, + hsts_enabled=True, + hsts_subdomains=False, + http2_support=True, + block_exploits=True, + caching_enabled=False, + allow_websocket_upgrade=True, + advanced_config="my advanced config", + meta={"key": "val"}, + ) + source_client.get_proxy_host = AsyncMock(return_value=source_host) + + # Source dependencies setup + source_cert = Certificate( + id=10, + nice_name="wildcard-example", + domain_names=["*.example.com"], + provider="letsencrypt", + ) + source_client.get_certificate = AsyncMock(return_value=source_cert) + + source_alists = [ + AccessList( + id=5, + name="Staging Auth", + created_on="2024-01-01T00:00:00Z", + modified_on="2024-01-01T00:00:00Z", + ) + ] + source_client.get_access_lists = AsyncMock(return_value=source_alists) + + # Target dependency search results + target_certs = [ + Certificate( + id=100, + nice_name="wildcard-example", + domain_names=["*.example.com"], + provider="letsencrypt", + ) + ] + target_client.get_certificates = AsyncMock(return_value=target_certs) + + target_alists = [ + AccessList( + id=500, + name="Staging Auth", + created_on="2024-01-02T00:00:00Z", + modified_on="2024-01-02T00:00:00Z", + ) + ] + target_client.get_access_lists = AsyncMock(return_value=target_alists) + + # Mock creation on target + cloned_host = ProxyHost( + id=999, + created_on="2024-01-03T00:00:00Z", + modified_on="2024-01-03T00:00:00Z", + owner_user_id=1, + domain_names=["test.example.com"], + forward_host="192.168.1.50", + forward_port=8080, + ) + target_client.create_proxy_host = AsyncMock(return_value=cloned_host) + + with patch("npm_mcp.server.get_registry", return_value=mock_registry): + result = await clone_proxy_host( + source_server="prod", + target_server="dev", + host_id=12, + override_settings={"forward_host": "10.0.0.10"}, + ) + + assert "Successfully cloned" in result + assert "Source Host ID: 12" in result + assert "Target Host ID: 999" in result + assert "Resolved to ID 100" in result + assert "Resolved to ID 500" in result + + target_client.create_proxy_host.assert_called_once_with( + domain_names=["test.example.com"], + forward_host="10.0.0.10", # Overridden! + forward_port=8080, + forward_scheme="http", + certificate_id=100, # Resolved! + ssl_forced=True, + hsts_enabled=True, + hsts_subdomains=False, + http2_support=True, + block_exploits=True, + caching_enabled=False, + allow_websocket_upgrade=True, + access_list_id=500, # Resolved! + advanced_config="my advanced config", + meta={"key": "val"}, + ) + + +@pytest.mark.asyncio +async def test_sync_access_lists(mock_registry): + """Test sync_access_lists replicates missing access lists with credentials/IPs.""" + source_client = MagicMock() + target_client = MagicMock() + + mock_registry.get.side_effect = lambda name: source_client if name == "prod" else target_client + + # Source returns raw JSON including items & clients + source_mock_response = MagicMock() + source_mock_response.json.return_value = [ + { + "id": 1, + "name": "Staging Auth", + "satisfy_any": False, + "pass_auth": True, + "items": [ + {"id": 10, "access_list_id": 1, "username": "u", "password": "p"} + ], + "clients": [ + {"id": 20, "access_list_id": 1, "address": "1.1.1.1", "directive": "allow"} + ], + }, + { + "id": 2, + "name": "Already Synced", + "satisfy_any": True, + "pass_auth": False, + } + ] + source_client._request = AsyncMock(return_value=source_mock_response) + + # Target returns raw JSON showing "Already Synced" exists + target_mock_response = MagicMock() + target_mock_response.json.return_value = [{"id": 99, "name": "Already Synced"}] + target_client._request = AsyncMock(return_value=target_mock_response) + + target_client.create_access_list = AsyncMock() + + with patch("npm_mcp.server.get_registry", return_value=mock_registry): + result = await sync_access_lists(source_server="prod", target_server="dev") + + assert "Created: Staging Auth" in result + assert "Matched (exists): 'Already Synced' (already exists)" in result + + # Verify items and clients were stripped of database IDs + target_client.create_access_list.assert_called_once_with( + name="Staging Auth", + satisfy_any=False, + pass_auth=True, + items=[{"username": "u", "password": "p"}], + clients=[{"address": "1.1.1.1", "directive": "allow"}], + ) + + +@pytest.mark.asyncio +async def test_sync_certificates(mock_registry): + """Test sync_certificates provisions Let's Encrypt and skips custom certs.""" + source_client = MagicMock() + target_client = MagicMock() + + mock_registry.get.side_effect = lambda name: source_client if name == "prod" else target_client + + # Source certificates + source_certs = [ + Certificate( + id=1, + nice_name="le-cert", + domain_names=["le.example.com"], + provider="letsencrypt", + meta={"letsencrypt_email": "le@test.com", "dns_challenge": True}, + ), + Certificate( + id=2, + nice_name="custom-cert", + domain_names=["custom.example.com"], + provider="other-provider", + ), + Certificate( + id=3, + nice_name="already-on-target", + domain_names=["existing.example.com"], + provider="letsencrypt", + ) + ] + source_client.get_certificates = AsyncMock(return_value=source_certs) + + # Target certificates + target_certs = [ + Certificate( + id=10, + nice_name="already-on-target", + domain_names=["existing.example.com"], + provider="letsencrypt", + ) + ] + target_client.get_certificates = AsyncMock(return_value=target_certs) + + target_client.create_certificate = AsyncMock() + + with patch("npm_mcp.server.get_registry", return_value=mock_registry): + result = await sync_certificates(source_server="prod", target_server="dev") + + assert "Provisioned: 'le.example.com'" in result + assert "Matched (exists): 'already-on-target'" in result + assert "Skipped (manual upload required): 'custom-cert'" in result + + target_client.create_certificate.assert_called_once_with( + domain_names=["le.example.com"], + email="le@test.com", + dns_challenge=True, + ) + + +@pytest.mark.asyncio +async def test_get_proxy_host_logs_api(mock_registry): + """Test get_proxy_host_logs tool queries the API for logs first.""" + client = MagicMock() + mock_registry.get.return_value = client + + # Mock host details + host = ProxyHost( + id=5, + created_on="2024-01-01T00:00:00Z", + modified_on="2024-01-01T00:00:00Z", + owner_user_id=1, + domain_names=["test.example.com"], + forward_host="192.168.1.50", + forward_port=8080, + ) + client.get_proxy_host = AsyncMock(return_value=host) + + # Mock API logs endpoint + client.get_proxy_host_logs = AsyncMock(return_value={ + "lines": ["Log line A", "Log line B", "Filter me out"] + }) + + with patch("npm_mcp.server.get_registry", return_value=mock_registry): + # Retrieve logs with search filter + result = await get_proxy_host_logs( + host_id=5, log_type="access", lines=10, search="Log line" + ) + + assert "test.example.com" in result + assert "(retrieved via API)" in result + assert "Log line A" in result + assert "Log line B" in result + assert "Filter me out" not in result + assert "Showing last 2 lines:" in result diff --git a/uv.lock b/uv.lock index dfd65bb..ef958bf 100644 --- a/uv.lock +++ b/uv.lock @@ -313,7 +313,7 @@ wheels = [ [[package]] name = "npm-mcp" -version = "0.1.0" +version = "0.0.2" source = { editable = "." } dependencies = [ { name = "httpx" },