#!/usr/bin/python3

import argparse
import json
import os
import sys
import secrets
import time
from fastmcp.server import create_proxy
from fastmcp.server.providers.proxy import ProxyProvider, _create_client_factory
from fastmcp.server.auth.providers.jwt import JWTVerifier
from authlib.jose import JsonWebToken

DEFAULT_CONFIG_PATH = "/etc/mcp-proxy/mcp-proxy-config.json"

parser = argparse.ArgumentParser(description="Run the MCP proxy server")
parser.add_argument(
    "-c", "--config",
    default=DEFAULT_CONFIG_PATH,
    help=f"Path to the JSON configuration file (default: {DEFAULT_CONFIG_PATH})"
)
args = parser.parse_args()

# Path of the JSON configuration file selected on the command line.
CONFIG_FILE = args.config

if not os.path.isfile(CONFIG_FILE):
    print(f"Error: Configuration file '{CONFIG_FILE}' not found.", file=sys.stderr)
    sys.exit(1)

# Read and parse the configuration. Unreadable files (OSError) and invalid
# UTF-8 / malformed JSON (both ValueError subclasses) are reported with a
# one-line error and exit code 1, like the "not found" case, instead of a
# traceback.
try:
    with open(CONFIG_FILE, "r", encoding="utf-8") as f:
        config = json.load(f)
except (OSError, ValueError) as err:
    print(
        f"Error: Could not load configuration file '{CONFIG_FILE}': {err}",
        file=sys.stderr,
    )
    sys.exit(1)

# Settings are read with config.get(...), so the top-level JSON value must be
# an object; a list or scalar would otherwise fail with an AttributeError.
if not isinstance(config, dict):
    print(
        f"Error: Configuration file '{CONFIG_FILE}' must contain a JSON object.",
        file=sys.stderr,
    )
    sys.exit(1)

# Extract configuration options with defaults
transport = config.get("transport", "stdio")
host = config.get("host", "0.0.0.0")
port = config.get("port", 9000)
# An explicit "auth": null (or an empty object) is treated like a missing
# section, i.e. authentication disabled.
auth_config = config.get("auth") or {"enabled": False, "type": "hmac"}

# Setup authentication if enabled
verifier = None
HMAC_SECRET = None
ISSUER = "mcp-proxy"
AUDIENCE = "mcp-client"
ALGORITHM = "HS256"

if auth_config.get("enabled"):
    # "type" defaults to "hmac", the only supported scheme. Without this
    # default, an auth section of just {"enabled": true} would silently leave
    # the server unauthenticated.
    auth_type = auth_config.get("type", "hmac")
    if auth_type != "hmac":
        # Fail closed: auth was requested, so never fall back to running the
        # server without a verifier.
        print(
            f"Error: Unsupported auth type '{auth_type}' (only 'hmac' is supported).",
            file=sys.stderr,
        )
        sys.exit(1)
    # Generate a secure random key for HMAC verification. A new key is
    # created on every start, so tokens signed before a restart stop
    # validating.
    HMAC_SECRET = secrets.token_hex(32)
    # For the symmetric HS256 algorithm, the shared secret is passed as
    # `public_key`.
    verifier = JWTVerifier(
        public_key=HMAC_SECRET,
        issuer=ISSUER,
        audience=AUDIENCE,
        algorithm=ALGORITHM
    )

def generate_testing_token(expires_in=3600):
    """Return a JWT signed with HMAC_SECRET for manually testing the proxy.

    The token uses the module's ISSUER and AUDIENCE, which are the same values
    the JWT verifier is configured with. Both ``sub`` and ``client_id`` are
    set to "mcp-user".

    Args:
        expires_in: Token lifetime in seconds (default 3600, i.e. one hour).

    Returns:
        The encoded token as a ``str``, or ``None`` when no HMAC secret has
        been generated (HMAC authentication not enabled).
    """
    if not HMAC_SECRET:
        return None
    # Read the clock once so "exp" is always exactly expires_in after "iat".
    now = int(time.time())
    jwt = JsonWebToken([ALGORITHM])
    header = {"alg": ALGORITHM}
    payload = {
        "iss": ISSUER,
        "aud": AUDIENCE,
        "iat": now,
        "exp": now + expires_in,
        "sub": "mcp-user",
        "client_id": "mcp-user"
    }
    # encode() yields bytes; decode to get a printable str.
    return jwt.encode(header, payload, HMAC_SECRET).decode("utf-8")

# Create the proxy from the loaded configuration. Only the HTTP transport is
# given the JWT verifier (which is None when auth is disabled). The stdio
# variant omits the `auth` keyword entirely rather than passing None.
mcp_proxy = (
    create_proxy(config, name="MCP Proxy", auth=verifier)
    if transport == "http"
    else create_proxy(config, name="MCP Proxy")
)

if __name__ == "__main__":
    if transport == "http":
        if verifier:
            token = generate_testing_token()
            # NOTE(review): this prints the signing key itself, so anyone who
            # can read this output can mint valid tokens.
            # flush=True: run() below blocks for the server's lifetime, so
            # buffered output (e.g. when stdout is piped to a log) would
            # otherwise not appear.
            print(f"HMAC Secret Key: {HMAC_SECRET}", flush=True)
            print(f"Testing Token:   {token}", flush=True)
        mcp_proxy.run(transport="http", host=host, port=port)
    elif transport == "stdio":
        mcp_proxy.run(transport="stdio")
    else:
        # Previously any unrecognised transport (e.g. a typo) silently fell
        # back to stdio.
        print(
            f"Error: Unsupported transport '{transport}' (expected 'http' or 'stdio').",
            file=sys.stderr,
        )
        sys.exit(1)
