mirror of
https://github.com/b3nw/nginx-proxy-manager-mcp.git
synced 2026-06-09 23:09:40 -05:00
feat: implement multi-server support and sync tools
- Introduced ServerRegistry to manage multiple NPM instances - Added support for NPM_SERVERS JSON environment variable - Updated all tools to support optional 'server' targeting - Implemented clone_proxy_host, sync_access_lists, and sync_certificates tools - Transitioned get_proxy_host_logs to API-based retrieval with local fallback - Added comprehensive test suite for multi-server management and sync operations Co-authored-by: claw-io <agent@ben.io>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -392,3 +393,61 @@ class NpmClient:
|
||||
|
||||
response = await self._request("POST", "/nginx/certificates", json=payload)
|
||||
return Certificate(**response.json())
|
||||
|
||||
async def get_proxy_host_logs(
|
||||
self,
|
||||
host_id: int,
|
||||
log_type: str = "access",
|
||||
lines: int = 100,
|
||||
) -> dict[str, Any]:
|
||||
"""Retrieve proxy host logs via the API.
|
||||
|
||||
Args:
|
||||
host_id: NPM proxy host ID.
|
||||
log_type: "access" or "error".
|
||||
lines: Number of most recent lines to return.
|
||||
"""
|
||||
response = await self._request(
|
||||
"GET",
|
||||
f"/nginx/proxy-hosts/{host_id}/logs",
|
||||
params={"type": log_type, "limit": lines},
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def get_proxy_host_logs_summary(self, host_id: int) -> dict[str, Any]:
|
||||
"""Retrieve proxy host logs summary via the API.
|
||||
|
||||
Args:
|
||||
host_id: NPM proxy host ID.
|
||||
"""
|
||||
response = await self._request("GET", f"/nginx/proxy-hosts/{host_id}/logs/summary")
|
||||
return response.json()
|
||||
|
||||
async def create_access_list(
|
||||
self,
|
||||
name: str,
|
||||
satisfy_any: bool = False,
|
||||
pass_auth: bool = False,
|
||||
items: list[dict[str, Any]] | None = None,
|
||||
clients: list[dict[str, Any]] | None = None,
|
||||
) -> AccessList:
|
||||
"""Create a new access list on the server.
|
||||
|
||||
Args:
|
||||
name: Name of the access list
|
||||
satisfy_any: Satisfy any HTTP auth or IP restriction
|
||||
pass_auth: Pass auth headers to host
|
||||
items: Auth entries (username/password) and IP restrictions
|
||||
clients: IP restriction clients
|
||||
"""
|
||||
payload = {
|
||||
"name": name,
|
||||
"satisfy_any": satisfy_any,
|
||||
"pass_auth": pass_auth,
|
||||
"items": items or [],
|
||||
"clients": clients or [],
|
||||
}
|
||||
response = await self._request("POST", "/nginx/access-lists", json=payload)
|
||||
return AccessList(**response.json())
|
||||
|
||||
|
||||
|
||||
@@ -37,6 +37,31 @@ class Settings(BaseSettings):
|
||||
identity: str = ""
|
||||
secret: str = ""
|
||||
|
||||
# Multi-Server Configuration
|
||||
# Example JSON:
|
||||
# '[{"name": "prod", "url": "http://10.0.0.10:81/api",
|
||||
# "identity": "admin@example.com", "secret": "pwd"}]'
|
||||
servers: list[dict[str, Any]] = []
|
||||
default_server: str | None = None
|
||||
|
||||
@field_validator("servers", mode="before")
|
||||
@classmethod
|
||||
def parse_servers(cls, v: Any) -> list[dict[str, Any]]:
|
||||
"""Parse JSON string to list of dicts, or pass through if already list."""
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
import json
|
||||
try:
|
||||
data = json.loads(v)
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("NPM_SERVERS must be a list of objects")
|
||||
return data
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON in NPM_SERVERS: {e}") from e
|
||||
return []
|
||||
|
||||
|
||||
# MCP Server Configuration
|
||||
mcp_host: str = "0.0.0.0"
|
||||
mcp_port: int = 8000
|
||||
|
||||
+509
-51
@@ -14,30 +14,109 @@ from .logs import is_log_dir_configured, list_available_logs, read_log_lines
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create global client instance (lazy initialization)
|
||||
_client: NpmClient | None = None
|
||||
|
||||
class ServerRegistry:
|
||||
"""Manages multiple NpmClient connections."""
|
||||
|
||||
def __init__(self, configs: list[dict[str, Any]], default: str | None = None):
|
||||
self._clients: dict[str, NpmClient] = {}
|
||||
self._default = default
|
||||
|
||||
# Register multi-server entries
|
||||
for cfg in configs:
|
||||
name = cfg.get("name")
|
||||
url = cfg.get("url") or cfg.get("api_url")
|
||||
identity = cfg.get("identity")
|
||||
secret = cfg.get("secret")
|
||||
if not all([name, url, identity, secret]):
|
||||
logger.warning(f"Server '{name}' is missing required fields, skipping")
|
||||
continue
|
||||
self._clients[name] = NpmClient(base_url=url, identity=identity, secret=secret)
|
||||
|
||||
# Fallback to single-server settings if registry is empty
|
||||
if not self._clients and settings.api_url and settings.identity and settings.secret:
|
||||
logger.info("No servers in NPM_SERVERS. Using single-server environment variables.")
|
||||
self._clients["default"] = NpmClient(
|
||||
base_url=settings.api_url,
|
||||
identity=settings.identity,
|
||||
secret=settings.secret
|
||||
)
|
||||
if not self._default:
|
||||
self._default = "default"
|
||||
|
||||
# Validate default
|
||||
if self._default and self._default not in self._clients:
|
||||
logger.warning(
|
||||
f"Default server '{self._default}' is not in configured servers. "
|
||||
"Clearing default."
|
||||
)
|
||||
self._default = None
|
||||
|
||||
def get(self, name: str | None = None) -> NpmClient:
|
||||
"""Retrieve client by name. Fallback to default if name is None/empty."""
|
||||
if not self._clients:
|
||||
raise KeyError("No NPM servers configured.")
|
||||
|
||||
if name is None or name == "":
|
||||
if self._default:
|
||||
name = self._default
|
||||
elif len(self._clients) == 1:
|
||||
name = next(iter(self._clients.keys()))
|
||||
else:
|
||||
raise KeyError("Multiple servers configured but no default server specified.")
|
||||
|
||||
if name not in self._clients:
|
||||
raise KeyError(
|
||||
f"Server '{name}' not found. "
|
||||
f"Configured servers: {list(self._clients.keys())}"
|
||||
)
|
||||
return self._clients[name]
|
||||
|
||||
def list_names(self) -> list[str]:
|
||||
return list(self._clients.keys())
|
||||
|
||||
def get_default(self) -> str | None:
|
||||
return self._default
|
||||
|
||||
async def close_all(self) -> None:
|
||||
for client in self._clients.values():
|
||||
await client.close()
|
||||
|
||||
|
||||
# Global registry
|
||||
registry: ServerRegistry | None = None
|
||||
|
||||
|
||||
def get_registry() -> ServerRegistry:
|
||||
"""Get or create the global ServerRegistry instance."""
|
||||
global registry
|
||||
if registry is None:
|
||||
registry = ServerRegistry(settings.servers, settings.default_server)
|
||||
return registry
|
||||
|
||||
|
||||
def _get_client(server: str | None = None) -> NpmClient:
|
||||
"""Retrieve NPM client for the specified server, or fallback to default."""
|
||||
return get_registry().get(server)
|
||||
|
||||
|
||||
def get_client() -> NpmClient:
|
||||
"""Get or create the NPM client instance."""
|
||||
global _client
|
||||
if _client is None:
|
||||
_client = NpmClient()
|
||||
return _client
|
||||
"""Get or create the NPM client instance (backward compatibility)."""
|
||||
return _get_client(None)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(server: FastMCP):
|
||||
"""Manage client lifecycle."""
|
||||
global _client
|
||||
_client = NpmClient()
|
||||
logger.info(f"NPM MCP Server starting, connecting to {settings.api_url}")
|
||||
global registry
|
||||
registry = ServerRegistry(settings.servers, settings.default_server)
|
||||
logger.info(f"NPM MCP Server starting. Configured servers: {registry.list_names()}")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if _client:
|
||||
await _client.close()
|
||||
_client = None
|
||||
if registry:
|
||||
await registry.close_all()
|
||||
registry = None
|
||||
logger.info("NPM MCP Server stopped")
|
||||
|
||||
|
||||
@@ -53,7 +132,9 @@ mcp = FastMCP(
|
||||
|
||||
def _format_error(e: Exception) -> str:
|
||||
"""Format exception for tool response."""
|
||||
if isinstance(e, NpmAuthenticationError):
|
||||
if isinstance(e, KeyError):
|
||||
return f"Configuration error: {e.args[0]}"
|
||||
elif isinstance(e, NpmAuthenticationError):
|
||||
return f"Authentication failed: {e}"
|
||||
elif isinstance(e, NpmConnectionError):
|
||||
return f"Connection error: {e}"
|
||||
@@ -64,20 +145,21 @@ def _format_error(e: Exception) -> str:
|
||||
return f"Error: {e}"
|
||||
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tools
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def list_proxy_hosts() -> str:
|
||||
async def list_proxy_hosts(server: str | None = None) -> str:
|
||||
"""List all proxy hosts configured in Nginx Proxy Manager.
|
||||
|
||||
Returns a summary of all proxy hosts including their domains,
|
||||
forward destinations, and SSL status.
|
||||
"""
|
||||
try:
|
||||
client = get_client()
|
||||
client = _get_client(server)
|
||||
hosts = await client.get_proxy_hosts()
|
||||
|
||||
if not hosts:
|
||||
@@ -101,17 +183,18 @@ async def list_proxy_hosts() -> str:
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_proxy_host_details(host_id: int) -> str:
|
||||
async def get_proxy_host_details(host_id: int, server: str | None = None) -> str:
|
||||
"""Get detailed configuration for a specific proxy host.
|
||||
|
||||
Args:
|
||||
host_id: The ID of the proxy host to retrieve
|
||||
server: Target server name
|
||||
|
||||
Returns full configuration including SSL settings, locations,
|
||||
and advanced configuration.
|
||||
"""
|
||||
try:
|
||||
client = get_client()
|
||||
client = _get_client(server)
|
||||
host = await client.get_proxy_host(host_id)
|
||||
|
||||
details: dict[str, Any] = {
|
||||
@@ -164,13 +247,13 @@ async def get_proxy_host_details(host_id: int) -> str:
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_system_health() -> str:
|
||||
async def get_system_health(server: str | None = None) -> str:
|
||||
"""Check the health and status of the Nginx Proxy Manager instance.
|
||||
|
||||
Returns system status, version information, and connectivity status.
|
||||
"""
|
||||
try:
|
||||
client = get_client()
|
||||
client = _get_client(server)
|
||||
status = await client.get_status()
|
||||
|
||||
result = [f"Status: {status.status}"]
|
||||
@@ -207,17 +290,18 @@ async def get_system_health() -> str:
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def search_audit_logs(limit: int = 50, offset: int = 0) -> str:
|
||||
async def search_audit_logs(limit: int = 50, offset: int = 0, server: str | None = None) -> str:
|
||||
"""Search the audit log for recent actions in Nginx Proxy Manager.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of entries to return (default: 50, max: 100)
|
||||
offset: Number of entries to skip for pagination (default: 0)
|
||||
server: Target server name
|
||||
|
||||
Returns recent audit log entries showing user actions and changes.
|
||||
"""
|
||||
try:
|
||||
client = get_client()
|
||||
client = _get_client(server)
|
||||
limit = min(limit, 100) # Cap at 100
|
||||
entries = await client.get_audit_log(limit=limit, offset=offset)
|
||||
|
||||
@@ -240,14 +324,14 @@ async def search_audit_logs(limit: int = 50, offset: int = 0) -> str:
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def list_certificates() -> str:
|
||||
async def list_certificates(server: str | None = None) -> str:
|
||||
"""List all SSL certificates managed by Nginx Proxy Manager.
|
||||
|
||||
Returns a summary of all certificates including their domains,
|
||||
provider, and expiration dates.
|
||||
"""
|
||||
try:
|
||||
client = get_client()
|
||||
client = _get_client(server)
|
||||
certs = await client.get_certificates()
|
||||
|
||||
if not certs:
|
||||
@@ -274,14 +358,14 @@ async def list_certificates() -> str:
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def list_access_lists() -> str:
|
||||
async def list_access_lists(server: str | None = None) -> str:
|
||||
"""List all access lists configured in Nginx Proxy Manager.
|
||||
|
||||
Returns a summary of all access lists including their IDs and names.
|
||||
Use these IDs when creating proxy hosts that require access control.
|
||||
"""
|
||||
try:
|
||||
client = get_client()
|
||||
client = _get_client(server)
|
||||
access_lists = await client.get_access_lists()
|
||||
|
||||
if not access_lists:
|
||||
@@ -309,6 +393,7 @@ async def create_proxy_host(
|
||||
allow_websocket_upgrade: bool | None = None,
|
||||
access_list_id: int | None = None,
|
||||
advanced_config: str | None = None,
|
||||
server: str | None = None,
|
||||
) -> str:
|
||||
"""Create a new proxy host in Nginx Proxy Manager.
|
||||
|
||||
@@ -325,6 +410,7 @@ async def create_proxy_host(
|
||||
access_list_id: Access list ID for authentication. Use list_access_lists to find.
|
||||
Use 0 for no access restrictions. (default from config)
|
||||
advanced_config: Custom nginx configuration block (default from config)
|
||||
server: Target server name
|
||||
|
||||
Returns:
|
||||
Details of the created proxy host including the new host ID.
|
||||
@@ -345,7 +431,7 @@ async def create_proxy_host(
|
||||
# Get defaults from config, then override with provided values
|
||||
defaults = settings.get_proxy_defaults()
|
||||
|
||||
client = get_client()
|
||||
client = _get_client(server)
|
||||
host = await client.create_proxy_host(
|
||||
domain_names=domain_names,
|
||||
forward_host=forward_host,
|
||||
@@ -401,6 +487,7 @@ async def update_proxy_host(
|
||||
allow_websocket_upgrade: bool | None = None,
|
||||
access_list_id: int | None = None,
|
||||
advanced_config: str | None = None,
|
||||
server: str | None = None,
|
||||
) -> str:
|
||||
"""Update an existing proxy host in Nginx Proxy Manager.
|
||||
|
||||
@@ -417,12 +504,13 @@ async def update_proxy_host(
|
||||
allow_websocket_upgrade: Allow WebSocket connections
|
||||
access_list_id: Access list ID (0 for no restrictions)
|
||||
advanced_config: Custom nginx configuration block
|
||||
server: Target server name
|
||||
|
||||
Returns:
|
||||
Details of the updated proxy host.
|
||||
"""
|
||||
try:
|
||||
client = get_client()
|
||||
client = _get_client(server)
|
||||
kwargs = {}
|
||||
if forward_host is not None:
|
||||
kwargs["forward_host"] = forward_host
|
||||
@@ -465,11 +553,12 @@ async def get_proxy_host_logs(
|
||||
log_type: str = "access",
|
||||
lines: int = 100,
|
||||
search: str | None = None,
|
||||
server: str | None = None,
|
||||
) -> str:
|
||||
"""Retrieve recent nginx log entries for a specific proxy host.
|
||||
|
||||
Reads the raw nginx access or error log file for the given host.
|
||||
Requires the NPM log directory to be mounted (see NPM_LOG_DIR config).
|
||||
Reads the raw nginx access or error log file for the given host or retrieves
|
||||
them via the API.
|
||||
|
||||
Args:
|
||||
host_id: The ID of the proxy host (use list_proxy_hosts to find IDs)
|
||||
@@ -479,30 +568,61 @@ async def get_proxy_host_logs(
|
||||
(default: 100, max: 500)
|
||||
search: Optional filter string - only lines containing this
|
||||
text are returned (case-insensitive)
|
||||
|
||||
Returns:
|
||||
The most recent log lines for the proxy host, with metadata.
|
||||
|
||||
Examples:
|
||||
- get_proxy_host_logs(5) — last 100 access log lines for host 5
|
||||
- get_proxy_host_logs(5, log_type="error") — recent error log
|
||||
- get_proxy_host_logs(5, lines=50, search="404") — last 50 lines containing "404"
|
||||
- get_proxy_host_logs(5, search="10.0.0.1") — filter by client IP
|
||||
server: Target server name
|
||||
"""
|
||||
try:
|
||||
client = get_client()
|
||||
client = _get_client(server)
|
||||
|
||||
# Try retrieving logs via the API first
|
||||
try:
|
||||
log_data = await client.get_proxy_host_logs(
|
||||
host_id=host_id,
|
||||
log_type=log_type,
|
||||
lines=lines,
|
||||
)
|
||||
raw_lines = log_data.get("lines", [])
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
filtered_lines = [
|
||||
line for line in raw_lines if search_lower in line.lower()
|
||||
]
|
||||
else:
|
||||
filtered_lines = raw_lines
|
||||
|
||||
host = await client.get_proxy_host(host_id)
|
||||
domains = ", ".join(host.domain_names)
|
||||
header_parts = [
|
||||
f"Proxy host [{host_id}] {domains} — {log_type} log (retrieved via API)",
|
||||
f"Showing last {len(filtered_lines)} lines:",
|
||||
]
|
||||
header = "\n".join(header_parts)
|
||||
if not filtered_lines:
|
||||
return f"{header}\n\n(no log entries found)"
|
||||
return f"{header}\n\n" + "\n".join(filtered_lines)
|
||||
|
||||
except Exception as api_err:
|
||||
# Fallback for default server or if local logs are mounted
|
||||
reg = get_registry()
|
||||
is_default = False
|
||||
try:
|
||||
target_client = reg.get(server)
|
||||
default_client = reg.get(None)
|
||||
if target_client.base_url == default_client.base_url:
|
||||
is_default = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if is_default and is_log_dir_configured():
|
||||
host = await client.get_proxy_host(host_id)
|
||||
domains = ", ".join(host.domain_names)
|
||||
result = read_log_lines(
|
||||
host_id=host_id,
|
||||
log_type=log_type,
|
||||
lines=lines,
|
||||
search=search,
|
||||
)
|
||||
|
||||
header_parts = [
|
||||
f"Proxy host [{host_id}] {domains} — {log_type} log",
|
||||
f"Proxy host [{host_id}] {domains} — {log_type} log (local fallback)",
|
||||
f"File: {result['file']}",
|
||||
]
|
||||
if result["total_lines_in_file"] is not None:
|
||||
@@ -512,12 +632,11 @@ async def get_proxy_host_logs(
|
||||
header_parts.append(f"Showing last {result['returned_lines']} lines:")
|
||||
|
||||
header = "\n".join(header_parts)
|
||||
|
||||
if not result["lines"]:
|
||||
return f"{header}\n\n(no log entries found)"
|
||||
|
||||
log_output = "\n".join(result["lines"])
|
||||
return f"{header}\n\n{log_output}"
|
||||
return f"{header}\n\n" + "\n".join(result["lines"])
|
||||
else:
|
||||
raise api_err
|
||||
|
||||
except Exception as e:
|
||||
return _format_error(e)
|
||||
@@ -528,6 +647,7 @@ async def create_certificate(
|
||||
domain_names: list[str],
|
||||
email: str,
|
||||
dns_challenge: bool = False,
|
||||
server: str | None = None,
|
||||
) -> str:
|
||||
"""Provision a new Let's Encrypt SSL certificate.
|
||||
|
||||
@@ -535,13 +655,10 @@ async def create_certificate(
|
||||
domain_names: List of domain names for the certificate
|
||||
email: Email address for Let's Encrypt notifications
|
||||
dns_challenge: Use DNS challenge instead of HTTP (default: False)
|
||||
|
||||
Returns:
|
||||
Details of the created certificate including its ID.
|
||||
Use the returned ID with create_proxy_host or update_proxy_host.
|
||||
server: Target server name
|
||||
"""
|
||||
try:
|
||||
client = get_client()
|
||||
client = _get_client(server)
|
||||
cert = await client.create_certificate(
|
||||
domain_names=domain_names,
|
||||
email=email,
|
||||
@@ -560,3 +677,344 @@ async def create_certificate(
|
||||
|
||||
except Exception as e:
|
||||
return _format_error(e)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def list_servers() -> str:
|
||||
"""List all configured NPM servers and their health/connectivity status.
|
||||
|
||||
Returns a JSON string containing the list of registered servers,
|
||||
the default server, and their health status.
|
||||
"""
|
||||
try:
|
||||
reg = get_registry()
|
||||
server_names = reg.list_names()
|
||||
default_server = reg.get_default()
|
||||
|
||||
health_status = {}
|
||||
for name in server_names:
|
||||
try:
|
||||
client = reg.get(name)
|
||||
status = await client.get_status()
|
||||
health_status[name] = {"status": status.status, "version": status.version}
|
||||
except Exception as e:
|
||||
health_status[name] = {"status": "error", "error": str(e)}
|
||||
|
||||
return json.dumps({
|
||||
"servers": server_names,
|
||||
"default_server": default_server,
|
||||
"health": health_status
|
||||
}, indent=2)
|
||||
except Exception as e:
|
||||
return _format_error(e)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def clone_proxy_host(
|
||||
source_server: str,
|
||||
target_server: str,
|
||||
host_id: int,
|
||||
override_settings: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Clone a proxy host configuration from source_server to target_server.
|
||||
|
||||
Resolves certificate and access list IDs automatically by matching names/domains.
|
||||
|
||||
Args:
|
||||
source_server: Name of the source server.
|
||||
target_server: Name of the destination server.
|
||||
host_id: The ID of the host on the source server.
|
||||
override_settings: Optional dict of settings to override during creation.
|
||||
"""
|
||||
try:
|
||||
source_client = _get_client(source_server)
|
||||
target_client = _get_client(target_server)
|
||||
|
||||
# Retrieve host configuration from source server
|
||||
host = await source_client.get_proxy_host(host_id)
|
||||
|
||||
# Resolve certificate ID
|
||||
target_cert_id = 0
|
||||
cert_resolved_msg = "None"
|
||||
if host.certificate_id and host.certificate_id > 0:
|
||||
try:
|
||||
source_cert = await source_client.get_certificate(host.certificate_id)
|
||||
target_certs = await target_client.get_certificates()
|
||||
|
||||
matched_cert = None
|
||||
for cert in target_certs:
|
||||
if (
|
||||
source_cert.nice_name
|
||||
and cert.nice_name
|
||||
and source_cert.nice_name == cert.nice_name
|
||||
):
|
||||
matched_cert = cert
|
||||
break
|
||||
if (
|
||||
source_cert.domain_names
|
||||
and cert.domain_names
|
||||
and set(source_cert.domain_names) == set(cert.domain_names)
|
||||
):
|
||||
matched_cert = cert
|
||||
break
|
||||
|
||||
if matched_cert:
|
||||
target_cert_id = matched_cert.id
|
||||
cert_name = matched_cert.nice_name or ", ".join(matched_cert.domain_names)
|
||||
cert_resolved_msg = f"Resolved to ID {target_cert_id} ({cert_name})"
|
||||
else:
|
||||
source_name = source_cert.nice_name or ", ".join(source_cert.domain_names)
|
||||
cert_resolved_msg = (
|
||||
f"Could not resolve source certificate '{source_name}' "
|
||||
"on target server. Defaulting to None (0)."
|
||||
)
|
||||
except Exception as e:
|
||||
cert_resolved_msg = f"Error resolving certificate: {e}. Defaulting to None (0)."
|
||||
|
||||
# Resolve access list ID
|
||||
target_access_list_id = 0
|
||||
access_list_resolved_msg = "None"
|
||||
if host.access_list_id and host.access_list_id > 0:
|
||||
try:
|
||||
source_alists = await source_client.get_access_lists()
|
||||
source_alist = next(
|
||||
(al for al in source_alists if al.id == host.access_list_id), None
|
||||
)
|
||||
|
||||
if source_alist:
|
||||
target_alists = await target_client.get_access_lists()
|
||||
matched_alist = next(
|
||||
(al for al in target_alists if al.name == source_alist.name), None
|
||||
)
|
||||
|
||||
if matched_alist:
|
||||
target_access_list_id = matched_alist.id
|
||||
access_list_resolved_msg = (
|
||||
f"Resolved to ID {target_access_list_id} ({matched_alist.name})"
|
||||
)
|
||||
else:
|
||||
access_list_resolved_msg = (
|
||||
f"Could not resolve source access list '{source_alist.name}' "
|
||||
"on target server. Defaulting to None (0)."
|
||||
)
|
||||
else:
|
||||
access_list_resolved_msg = (
|
||||
"Source access list not found. Defaulting to None (0)."
|
||||
)
|
||||
except Exception as e:
|
||||
access_list_resolved_msg = (
|
||||
f"Error resolving access list: {e}. Defaulting to None (0)."
|
||||
)
|
||||
|
||||
# Construct creation payload
|
||||
payload = {
|
||||
"domain_names": host.domain_names,
|
||||
"forward_host": host.forward_host,
|
||||
"forward_port": host.forward_port,
|
||||
"forward_scheme": host.forward_scheme,
|
||||
"ssl_forced": host.ssl_forced,
|
||||
"hsts_enabled": host.hsts_enabled,
|
||||
"hsts_subdomains": host.hsts_subdomains,
|
||||
"http2_support": host.http2_support,
|
||||
"block_exploits": host.block_exploits,
|
||||
"caching_enabled": host.caching_enabled,
|
||||
"allow_websocket_upgrade": host.allow_websocket_upgrade,
|
||||
"advanced_config": host.advanced_config,
|
||||
"meta": host.meta,
|
||||
"certificate_id": target_cert_id,
|
||||
"access_list_id": target_access_list_id,
|
||||
}
|
||||
|
||||
# Apply overrides
|
||||
if override_settings:
|
||||
payload.update(override_settings)
|
||||
|
||||
# Create proxy host on target server
|
||||
new_host = await target_client.create_proxy_host(
|
||||
domain_names=payload["domain_names"],
|
||||
forward_host=payload["forward_host"],
|
||||
forward_port=payload["forward_port"],
|
||||
forward_scheme=payload["forward_scheme"],
|
||||
certificate_id=payload["certificate_id"],
|
||||
ssl_forced=payload["ssl_forced"],
|
||||
hsts_enabled=payload["hsts_enabled"],
|
||||
hsts_subdomains=payload["hsts_subdomains"],
|
||||
http2_support=payload["http2_support"],
|
||||
block_exploits=payload["block_exploits"],
|
||||
caching_enabled=payload["caching_enabled"],
|
||||
allow_websocket_upgrade=payload["allow_websocket_upgrade"],
|
||||
access_list_id=payload["access_list_id"],
|
||||
advanced_config=payload["advanced_config"],
|
||||
meta=payload["meta"],
|
||||
)
|
||||
|
||||
return (
|
||||
f"Successfully cloned proxy host from '{source_server}' to '{target_server}'!\n\n"
|
||||
f"Source Host ID: {host_id}\n"
|
||||
f"Target Host ID: {new_host.id}\n"
|
||||
f"Domains: {', '.join(new_host.domain_names)}\n"
|
||||
f"Certificate: {cert_resolved_msg}\n"
|
||||
f"Access List: {access_list_resolved_msg}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return _format_error(e)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def sync_access_lists(source_server: str, target_server: str) -> str:
|
||||
"""Sync access lists from source_server to target_server.
|
||||
|
||||
Replicates missing access lists by name, carrying over credentials and IP rules.
|
||||
|
||||
Args:
|
||||
source_server: Name of the source server.
|
||||
target_server: Name of the target server.
|
||||
"""
|
||||
try:
|
||||
source_client = _get_client(source_server)
|
||||
target_client = _get_client(target_server)
|
||||
|
||||
# Get raw access lists from both servers to retrieve detailed items and clients
|
||||
source_response = await source_client._request("GET", "/nginx/access-lists")
|
||||
source_lists = source_response.json()
|
||||
|
||||
target_response = await target_client._request("GET", "/nginx/access-lists")
|
||||
target_lists = target_response.json()
|
||||
|
||||
target_names = {al["name"] for al in target_lists}
|
||||
|
||||
synced = []
|
||||
skipped = []
|
||||
|
||||
for al in source_lists:
|
||||
name = al.get("name")
|
||||
if not name:
|
||||
continue
|
||||
|
||||
if name in target_names:
|
||||
skipped.append(f"'{name}' (already exists)")
|
||||
continue
|
||||
|
||||
# Strip database IDs and unique primary keys from items & clients to avoid conflicts
|
||||
items = al.get("items", [])
|
||||
cleaned_items = []
|
||||
for item in items:
|
||||
cleaned = item.copy()
|
||||
cleaned.pop("id", None)
|
||||
cleaned.pop("access_list_id", None)
|
||||
cleaned.pop("created_on", None)
|
||||
cleaned.pop("modified_on", None)
|
||||
cleaned_items.append(cleaned)
|
||||
|
||||
clients = al.get("clients", [])
|
||||
cleaned_clients = []
|
||||
for client in clients:
|
||||
cleaned = client.copy()
|
||||
cleaned.pop("id", None)
|
||||
cleaned.pop("access_list_id", None)
|
||||
cleaned.pop("created_on", None)
|
||||
cleaned.pop("modified_on", None)
|
||||
cleaned_clients.append(cleaned)
|
||||
|
||||
# Replicate access list
|
||||
await target_client.create_access_list(
|
||||
name=name,
|
||||
satisfy_any=al.get("satisfy_any", False),
|
||||
pass_auth=al.get("pass_auth", False),
|
||||
items=cleaned_items,
|
||||
clients=cleaned_clients,
|
||||
)
|
||||
synced.append(name)
|
||||
|
||||
result_parts = [f"Synced access lists from '{source_server}' to '{target_server}':"]
|
||||
if synced:
|
||||
result_parts.append(f"✅ Created: {', '.join(synced)}")
|
||||
else:
|
||||
result_parts.append("No new access lists were created.")
|
||||
if skipped:
|
||||
result_parts.append(f"ℹ️ Matched (exists): {', '.join(skipped)}")
|
||||
|
||||
return "\n".join(result_parts)
|
||||
|
||||
except Exception as e:
|
||||
return _format_error(e)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def sync_certificates(source_server: str, target_server: str) -> str:
|
||||
"""Sync Let's Encrypt certificates from source_server to target_server.
|
||||
|
||||
Matches existing certificates on the target server by domain names.
|
||||
|
||||
Args:
|
||||
source_server: Name of the source server.
|
||||
target_server: Name of the target server.
|
||||
"""
|
||||
try:
|
||||
source_client = _get_client(source_server)
|
||||
target_client = _get_client(target_server)
|
||||
|
||||
source_certs = await source_client.get_certificates()
|
||||
target_certs = await target_client.get_certificates()
|
||||
|
||||
# Build target domain map for lookup
|
||||
target_domains_map = {frozenset(cert.domain_names): cert for cert in target_certs}
|
||||
|
||||
synced = []
|
||||
skipped_exists = []
|
||||
skipped_custom = []
|
||||
|
||||
for cert in source_certs:
|
||||
domains = cert.domain_names
|
||||
if not domains:
|
||||
continue
|
||||
|
||||
cert_domains_set = frozenset(domains)
|
||||
|
||||
# Check if matching cert exists on target
|
||||
if cert_domains_set in target_domains_map:
|
||||
skipped_exists.append(f"'{cert.nice_name or ', '.join(domains)}'")
|
||||
continue
|
||||
|
||||
# Check provider type
|
||||
if cert.provider != "letsencrypt":
|
||||
skipped_custom.append(
|
||||
f"'{cert.nice_name or ', '.join(domains)}' "
|
||||
f"(custom provider: {cert.provider})"
|
||||
)
|
||||
continue
|
||||
|
||||
# Re-provision Let's Encrypt certificate
|
||||
email = (
|
||||
cert.meta.get("letsencrypt_email")
|
||||
or cert.meta.get("email")
|
||||
or settings.identity
|
||||
or "admin@example.com"
|
||||
)
|
||||
dns_challenge = cert.meta.get("dns_challenge", False)
|
||||
|
||||
await target_client.create_certificate(
|
||||
domain_names=domains,
|
||||
email=email,
|
||||
dns_challenge=dns_challenge,
|
||||
)
|
||||
synced.append(f"'{', '.join(domains)}'")
|
||||
|
||||
result_parts = [
|
||||
f"Synced Let's Encrypt certificates from '{source_server}' "
|
||||
f"to '{target_server}':"
|
||||
]
|
||||
if synced:
|
||||
result_parts.append(f"✅ Provisioned: {', '.join(synced)}")
|
||||
else:
|
||||
result_parts.append("No new certificates were provisioned.")
|
||||
if skipped_exists:
|
||||
result_parts.append(f"ℹ️ Matched (exists): {', '.join(skipped_exists)}")
|
||||
if skipped_custom:
|
||||
result_parts.append(f"⚠️ Skipped (manual upload required): {', '.join(skipped_custom)}")
|
||||
|
||||
return "\n".join(result_parts)
|
||||
|
||||
except Exception as e:
|
||||
return _format_error(e)
|
||||
|
||||
+87
-2
@@ -1,10 +1,9 @@
|
||||
"""Tests for NpmClient."""
|
||||
|
||||
import pytest
|
||||
from httpx import Response
|
||||
|
||||
from npm_mcp.client import NpmClient
|
||||
from npm_mcp.exceptions import NpmAuthenticationError, NpmConnectionError
|
||||
from npm_mcp.exceptions import NpmAuthenticationError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -227,3 +226,89 @@ class TestNpmClientEndpoints:
|
||||
assert host.forward_port == 3000
|
||||
assert host.ssl_forced is True
|
||||
assert host.certificate_id == 24
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_access_list(self, httpx_mock, mock_token_response):
|
||||
"""Test creating an access list."""
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url="http://localhost:81/api/tokens",
|
||||
json=mock_token_response,
|
||||
)
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url="http://localhost:81/api/nginx/access-lists",
|
||||
json={
|
||||
"id": 5,
|
||||
"created_on": "2024-01-01T00:00:00Z",
|
||||
"modified_on": "2024-01-01T00:00:00Z",
|
||||
"owner_user_id": 1,
|
||||
"name": "Custom List",
|
||||
"satisfy_any": True,
|
||||
"pass_auth": False,
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
async with NpmClient(
|
||||
base_url="http://localhost:81/api",
|
||||
identity="test@test.com",
|
||||
secret="password",
|
||||
) as client:
|
||||
al = await client.create_access_list(
|
||||
name="Custom List",
|
||||
satisfy_any=True,
|
||||
pass_auth=False,
|
||||
items=[{"username": "u", "password": "p"}],
|
||||
clients=[{"address": "1.1.1.1", "directive": "allow"}],
|
||||
)
|
||||
|
||||
assert al.id == 5
|
||||
assert al.name == "Custom List"
|
||||
assert al.satisfy_any is True
|
||||
assert al.pass_auth is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_proxy_host_logs(self, httpx_mock, mock_token_response):
|
||||
"""Test fetching proxy host logs via API."""
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url="http://localhost:81/api/tokens",
|
||||
json=mock_token_response,
|
||||
)
|
||||
httpx_mock.add_response(
|
||||
method="GET",
|
||||
url="http://localhost:81/api/nginx/proxy-hosts/42/logs?type=access&limit=50",
|
||||
json={"lines": ["line 1", "line 2"]},
|
||||
)
|
||||
|
||||
async with NpmClient(
|
||||
base_url="http://localhost:81/api",
|
||||
identity="test@test.com",
|
||||
secret="password",
|
||||
) as client:
|
||||
logs = await client.get_proxy_host_logs(host_id=42, log_type="access", lines=50)
|
||||
assert logs == {"lines": ["line 1", "line 2"]}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_proxy_host_logs_summary(self, httpx_mock, mock_token_response):
|
||||
"""Test fetching proxy host logs summary via API."""
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url="http://localhost:81/api/tokens",
|
||||
json=mock_token_response,
|
||||
)
|
||||
httpx_mock.add_response(
|
||||
method="GET",
|
||||
url="http://localhost:81/api/nginx/proxy-hosts/42/logs/summary",
|
||||
json={"access": 100, "error": 5},
|
||||
)
|
||||
|
||||
async with NpmClient(
|
||||
base_url="http://localhost:81/api",
|
||||
identity="test@test.com",
|
||||
secret="password",
|
||||
) as client:
|
||||
summary = await client.get_proxy_host_logs_summary(host_id=42)
|
||||
assert summary == {"access": 100, "error": 5}
|
||||
|
||||
|
||||
@@ -89,3 +89,41 @@ class TestProxyDefaults:
|
||||
# Other defaults preserved
|
||||
assert defaults["ssl_forced"] is True
|
||||
assert defaults["block_exploits"] is True
|
||||
|
||||
|
||||
class TestMultiServerConfig:
|
||||
"""Test multi-server configuration parsing."""
|
||||
|
||||
def test_servers_empty_by_default(self):
|
||||
"""Test that servers is empty by default."""
|
||||
s = Settings(identity="test", secret="test")
|
||||
assert s.servers == []
|
||||
assert s.default_server is None
|
||||
|
||||
def test_servers_json_parsing(self, monkeypatch):
|
||||
"""Test parsing servers JSON list from environment variable."""
|
||||
monkeypatch.setenv("NPM_IDENTITY", "test")
|
||||
monkeypatch.setenv("NPM_SECRET", "test")
|
||||
monkeypatch.setenv(
|
||||
"NPM_SERVERS",
|
||||
'[{"name": "prod", "url": "http://prod:81/api", "identity": "p", "secret": "ps"}, '
|
||||
'{"name": "dev", "url": "http://dev:81/api", "identity": "d", "secret": "ds"}]'
|
||||
)
|
||||
monkeypatch.setenv("NPM_DEFAULT_SERVER", "prod")
|
||||
|
||||
s = Settings()
|
||||
assert len(s.servers) == 2
|
||||
assert s.servers[0]["name"] == "prod"
|
||||
assert s.servers[0]["url"] == "http://prod:81/api"
|
||||
assert s.servers[1]["name"] == "dev"
|
||||
assert s.default_server == "prod"
|
||||
|
||||
def test_servers_invalid_json_raises(self, monkeypatch):
|
||||
"""Test that invalid servers JSON raises SettingsError."""
|
||||
monkeypatch.setenv("NPM_IDENTITY", "test")
|
||||
monkeypatch.setenv("NPM_SECRET", "test")
|
||||
monkeypatch.setenv("NPM_SERVERS", "{not valid}")
|
||||
|
||||
with pytest.raises(SettingsError):
|
||||
Settings()
|
||||
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
"""Tests for ServerRegistry."""
|
||||
|
||||
import pytest
|
||||
|
||||
from npm_mcp.config import settings
|
||||
from npm_mcp.server import ServerRegistry
|
||||
|
||||
|
||||
def test_registry_fallback_to_single_server(monkeypatch):
|
||||
"""Test that registry falls back to single-server settings when empty."""
|
||||
monkeypatch.setattr(settings, "api_url", "http://test-url:81/api")
|
||||
monkeypatch.setattr(settings, "identity", "test-user")
|
||||
monkeypatch.setattr(settings, "secret", "test-pass")
|
||||
|
||||
registry = ServerRegistry(configs=[], default=None)
|
||||
|
||||
assert registry.list_names() == ["default"]
|
||||
assert registry.get_default() == "default"
|
||||
|
||||
client = registry.get()
|
||||
assert client.base_url == "http://test-url:81/api"
|
||||
assert client._identity == "test-user"
|
||||
|
||||
|
||||
def test_registry_multiple_servers():
|
||||
"""Test that multiple servers are correctly registered."""
|
||||
configs = [
|
||||
{"name": "prod", "url": "http://prod:81/api", "identity": "p", "secret": "ps"},
|
||||
{"name": "dev", "url": "http://dev:81/api", "identity": "d", "secret": "ds"},
|
||||
]
|
||||
|
||||
registry = ServerRegistry(configs=configs, default="prod")
|
||||
|
||||
assert set(registry.list_names()) == {"prod", "dev"}
|
||||
assert registry.get_default() == "prod"
|
||||
|
||||
prod_client = registry.get("prod")
|
||||
assert prod_client.base_url == "http://prod:81/api"
|
||||
|
||||
dev_client = registry.get("dev")
|
||||
assert dev_client.base_url == "http://dev:81/api"
|
||||
|
||||
|
||||
def test_registry_get_default_fallback():
|
||||
"""Test that get() falls back to default server when name is None/empty."""
|
||||
configs = [
|
||||
{"name": "prod", "url": "http://prod:81/api", "identity": "p", "secret": "ps"},
|
||||
{"name": "dev", "url": "http://dev:81/api", "identity": "d", "secret": "ds"},
|
||||
]
|
||||
|
||||
registry = ServerRegistry(configs=configs, default="dev")
|
||||
|
||||
# Name is None
|
||||
client = registry.get(None)
|
||||
assert client.base_url == "http://dev:81/api"
|
||||
|
||||
# Name is empty string
|
||||
client_empty = registry.get("")
|
||||
assert client_empty.base_url == "http://dev:81/api"
|
||||
|
||||
|
||||
def test_registry_single_client_no_default_specified():
|
||||
"""Test that get() succeeds if there is only 1 server, even if no default is specified."""
|
||||
configs = [
|
||||
{"name": "only-one", "url": "http://only:81/api", "identity": "o", "secret": "os"}
|
||||
]
|
||||
|
||||
registry = ServerRegistry(configs=configs, default=None)
|
||||
|
||||
assert registry.get_default() is None
|
||||
client = registry.get()
|
||||
assert client.base_url == "http://only:81/api"
|
||||
|
||||
|
||||
def test_registry_multiple_clients_no_default_raises():
|
||||
"""Test that get() raises KeyError if multiple servers are defined but no default is set."""
|
||||
configs = [
|
||||
{"name": "prod", "url": "http://prod:81/api", "identity": "p", "secret": "ps"},
|
||||
{"name": "dev", "url": "http://dev:81/api", "identity": "d", "secret": "ds"},
|
||||
]
|
||||
|
||||
registry = ServerRegistry(configs=configs, default=None)
|
||||
|
||||
with pytest.raises(KeyError, match="Multiple servers configured but no default server"):
|
||||
registry.get()
|
||||
|
||||
|
||||
def test_registry_invalid_name_raises():
|
||||
"""Test that get() raises KeyError for non-existent server names."""
|
||||
configs = [
|
||||
{"name": "prod", "url": "http://prod:81/api", "identity": "p", "secret": "ps"},
|
||||
]
|
||||
|
||||
registry = ServerRegistry(configs=configs, default="prod")
|
||||
|
||||
with pytest.raises(KeyError, match="Server 'non-existent' not found"):
|
||||
registry.get("non-existent")
|
||||
@@ -0,0 +1,315 @@
|
||||
"""Tests for multi-server management and sync tools."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from npm_mcp.models import AccessList, Certificate, HealthStatus, ProxyHost
|
||||
from npm_mcp.server import (
|
||||
clone_proxy_host,
|
||||
get_proxy_host_logs,
|
||||
list_servers,
|
||||
sync_access_lists,
|
||||
sync_certificates,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry():
|
||||
"""Mock registry with prod and dev servers."""
|
||||
reg = MagicMock()
|
||||
reg.list_names.return_value = ["prod", "dev"]
|
||||
reg.get_default.return_value = "prod"
|
||||
return reg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_servers(mock_registry):
|
||||
"""Test list_servers tool output and health query."""
|
||||
client_prod = MagicMock()
|
||||
client_prod.get_status = AsyncMock(
|
||||
return_value=HealthStatus(status="online", version={"major": "2"})
|
||||
)
|
||||
client_dev = MagicMock()
|
||||
client_dev.get_status = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
|
||||
mock_registry.get.side_effect = lambda name: client_prod if name == "prod" else client_dev
|
||||
|
||||
with patch("npm_mcp.server.get_registry", return_value=mock_registry):
|
||||
result_json = await list_servers()
|
||||
result = json.loads(result_json)
|
||||
|
||||
assert result["servers"] == ["prod", "dev"]
|
||||
assert result["default_server"] == "prod"
|
||||
assert result["health"]["prod"]["status"] == "online"
|
||||
assert result["health"]["dev"]["status"] == "error"
|
||||
assert "Connection failed" in result["health"]["dev"]["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_proxy_host(mock_registry):
|
||||
"""Test clone_proxy_host tool with cert and access list resolution."""
|
||||
source_client = MagicMock()
|
||||
target_client = MagicMock()
|
||||
|
||||
mock_registry.get.side_effect = lambda name: source_client if name == "prod" else target_client
|
||||
|
||||
# Source host setup
|
||||
source_host = ProxyHost(
|
||||
id=12,
|
||||
created_on="2024-01-01T00:00:00Z",
|
||||
modified_on="2024-01-01T00:00:00Z",
|
||||
owner_user_id=1,
|
||||
domain_names=["test.example.com"],
|
||||
forward_host="192.168.1.50",
|
||||
forward_port=8080,
|
||||
forward_scheme="http",
|
||||
certificate_id=10,
|
||||
access_list_id=5,
|
||||
ssl_forced=True,
|
||||
hsts_enabled=True,
|
||||
hsts_subdomains=False,
|
||||
http2_support=True,
|
||||
block_exploits=True,
|
||||
caching_enabled=False,
|
||||
allow_websocket_upgrade=True,
|
||||
advanced_config="my advanced config",
|
||||
meta={"key": "val"},
|
||||
)
|
||||
source_client.get_proxy_host = AsyncMock(return_value=source_host)
|
||||
|
||||
# Source dependencies setup
|
||||
source_cert = Certificate(
|
||||
id=10,
|
||||
nice_name="wildcard-example",
|
||||
domain_names=["*.example.com"],
|
||||
provider="letsencrypt",
|
||||
)
|
||||
source_client.get_certificate = AsyncMock(return_value=source_cert)
|
||||
|
||||
source_alists = [
|
||||
AccessList(
|
||||
id=5,
|
||||
name="Staging Auth",
|
||||
created_on="2024-01-01T00:00:00Z",
|
||||
modified_on="2024-01-01T00:00:00Z",
|
||||
)
|
||||
]
|
||||
source_client.get_access_lists = AsyncMock(return_value=source_alists)
|
||||
|
||||
# Target dependency search results
|
||||
target_certs = [
|
||||
Certificate(
|
||||
id=100,
|
||||
nice_name="wildcard-example",
|
||||
domain_names=["*.example.com"],
|
||||
provider="letsencrypt",
|
||||
)
|
||||
]
|
||||
target_client.get_certificates = AsyncMock(return_value=target_certs)
|
||||
|
||||
target_alists = [
|
||||
AccessList(
|
||||
id=500,
|
||||
name="Staging Auth",
|
||||
created_on="2024-01-02T00:00:00Z",
|
||||
modified_on="2024-01-02T00:00:00Z",
|
||||
)
|
||||
]
|
||||
target_client.get_access_lists = AsyncMock(return_value=target_alists)
|
||||
|
||||
# Mock creation on target
|
||||
cloned_host = ProxyHost(
|
||||
id=999,
|
||||
created_on="2024-01-03T00:00:00Z",
|
||||
modified_on="2024-01-03T00:00:00Z",
|
||||
owner_user_id=1,
|
||||
domain_names=["test.example.com"],
|
||||
forward_host="192.168.1.50",
|
||||
forward_port=8080,
|
||||
)
|
||||
target_client.create_proxy_host = AsyncMock(return_value=cloned_host)
|
||||
|
||||
with patch("npm_mcp.server.get_registry", return_value=mock_registry):
|
||||
result = await clone_proxy_host(
|
||||
source_server="prod",
|
||||
target_server="dev",
|
||||
host_id=12,
|
||||
override_settings={"forward_host": "10.0.0.10"},
|
||||
)
|
||||
|
||||
assert "Successfully cloned" in result
|
||||
assert "Source Host ID: 12" in result
|
||||
assert "Target Host ID: 999" in result
|
||||
assert "Resolved to ID 100" in result
|
||||
assert "Resolved to ID 500" in result
|
||||
|
||||
target_client.create_proxy_host.assert_called_once_with(
|
||||
domain_names=["test.example.com"],
|
||||
forward_host="10.0.0.10", # Overridden!
|
||||
forward_port=8080,
|
||||
forward_scheme="http",
|
||||
certificate_id=100, # Resolved!
|
||||
ssl_forced=True,
|
||||
hsts_enabled=True,
|
||||
hsts_subdomains=False,
|
||||
http2_support=True,
|
||||
block_exploits=True,
|
||||
caching_enabled=False,
|
||||
allow_websocket_upgrade=True,
|
||||
access_list_id=500, # Resolved!
|
||||
advanced_config="my advanced config",
|
||||
meta={"key": "val"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_access_lists(mock_registry):
|
||||
"""Test sync_access_lists replicates missing access lists with credentials/IPs."""
|
||||
source_client = MagicMock()
|
||||
target_client = MagicMock()
|
||||
|
||||
mock_registry.get.side_effect = lambda name: source_client if name == "prod" else target_client
|
||||
|
||||
# Source returns raw JSON including items & clients
|
||||
source_mock_response = MagicMock()
|
||||
source_mock_response.json.return_value = [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "Staging Auth",
|
||||
"satisfy_any": False,
|
||||
"pass_auth": True,
|
||||
"items": [
|
||||
{"id": 10, "access_list_id": 1, "username": "u", "password": "p"}
|
||||
],
|
||||
"clients": [
|
||||
{"id": 20, "access_list_id": 1, "address": "1.1.1.1", "directive": "allow"}
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Already Synced",
|
||||
"satisfy_any": True,
|
||||
"pass_auth": False,
|
||||
}
|
||||
]
|
||||
source_client._request = AsyncMock(return_value=source_mock_response)
|
||||
|
||||
# Target returns raw JSON showing "Already Synced" exists
|
||||
target_mock_response = MagicMock()
|
||||
target_mock_response.json.return_value = [{"id": 99, "name": "Already Synced"}]
|
||||
target_client._request = AsyncMock(return_value=target_mock_response)
|
||||
|
||||
target_client.create_access_list = AsyncMock()
|
||||
|
||||
with patch("npm_mcp.server.get_registry", return_value=mock_registry):
|
||||
result = await sync_access_lists(source_server="prod", target_server="dev")
|
||||
|
||||
assert "Created: Staging Auth" in result
|
||||
assert "Matched (exists): 'Already Synced' (already exists)" in result
|
||||
|
||||
# Verify items and clients were stripped of database IDs
|
||||
target_client.create_access_list.assert_called_once_with(
|
||||
name="Staging Auth",
|
||||
satisfy_any=False,
|
||||
pass_auth=True,
|
||||
items=[{"username": "u", "password": "p"}],
|
||||
clients=[{"address": "1.1.1.1", "directive": "allow"}],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_certificates(mock_registry):
|
||||
"""Test sync_certificates provisions Let's Encrypt and skips custom certs."""
|
||||
source_client = MagicMock()
|
||||
target_client = MagicMock()
|
||||
|
||||
mock_registry.get.side_effect = lambda name: source_client if name == "prod" else target_client
|
||||
|
||||
# Source certificates
|
||||
source_certs = [
|
||||
Certificate(
|
||||
id=1,
|
||||
nice_name="le-cert",
|
||||
domain_names=["le.example.com"],
|
||||
provider="letsencrypt",
|
||||
meta={"letsencrypt_email": "le@test.com", "dns_challenge": True},
|
||||
),
|
||||
Certificate(
|
||||
id=2,
|
||||
nice_name="custom-cert",
|
||||
domain_names=["custom.example.com"],
|
||||
provider="other-provider",
|
||||
),
|
||||
Certificate(
|
||||
id=3,
|
||||
nice_name="already-on-target",
|
||||
domain_names=["existing.example.com"],
|
||||
provider="letsencrypt",
|
||||
)
|
||||
]
|
||||
source_client.get_certificates = AsyncMock(return_value=source_certs)
|
||||
|
||||
# Target certificates
|
||||
target_certs = [
|
||||
Certificate(
|
||||
id=10,
|
||||
nice_name="already-on-target",
|
||||
domain_names=["existing.example.com"],
|
||||
provider="letsencrypt",
|
||||
)
|
||||
]
|
||||
target_client.get_certificates = AsyncMock(return_value=target_certs)
|
||||
|
||||
target_client.create_certificate = AsyncMock()
|
||||
|
||||
with patch("npm_mcp.server.get_registry", return_value=mock_registry):
|
||||
result = await sync_certificates(source_server="prod", target_server="dev")
|
||||
|
||||
assert "Provisioned: 'le.example.com'" in result
|
||||
assert "Matched (exists): 'already-on-target'" in result
|
||||
assert "Skipped (manual upload required): 'custom-cert'" in result
|
||||
|
||||
target_client.create_certificate.assert_called_once_with(
|
||||
domain_names=["le.example.com"],
|
||||
email="le@test.com",
|
||||
dns_challenge=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_proxy_host_logs_api(mock_registry):
|
||||
"""Test get_proxy_host_logs tool queries the API for logs first."""
|
||||
client = MagicMock()
|
||||
mock_registry.get.return_value = client
|
||||
|
||||
# Mock host details
|
||||
host = ProxyHost(
|
||||
id=5,
|
||||
created_on="2024-01-01T00:00:00Z",
|
||||
modified_on="2024-01-01T00:00:00Z",
|
||||
owner_user_id=1,
|
||||
domain_names=["test.example.com"],
|
||||
forward_host="192.168.1.50",
|
||||
forward_port=8080,
|
||||
)
|
||||
client.get_proxy_host = AsyncMock(return_value=host)
|
||||
|
||||
# Mock API logs endpoint
|
||||
client.get_proxy_host_logs = AsyncMock(return_value={
|
||||
"lines": ["Log line A", "Log line B", "Filter me out"]
|
||||
})
|
||||
|
||||
with patch("npm_mcp.server.get_registry", return_value=mock_registry):
|
||||
# Retrieve logs with search filter
|
||||
result = await get_proxy_host_logs(
|
||||
host_id=5, log_type="access", lines=10, search="Log line"
|
||||
)
|
||||
|
||||
assert "test.example.com" in result
|
||||
assert "(retrieved via API)" in result
|
||||
assert "Log line A" in result
|
||||
assert "Log line B" in result
|
||||
assert "Filter me out" not in result
|
||||
assert "Showing last 2 lines:" in result
|
||||
Reference in New Issue
Block a user