"""Generate static website files from the server."""

import argparse
import asyncio
import base64
import io
import json
import shutil
from pathlib import Path
from typing import Any, Dict, List

import matplotlib.pyplot as plt
from jinja2 import Environment, FileSystemLoader

from ..utils import get_latest_json, get_project_root
from .viz import calculate_metrics, plot_compliance_metrics, plot_compliance_metrics_categories

# Optional display-name overrides, looked up by the string form of a model path
# (see get_friendly_model_name). A missing or malformed alias file is tolerated:
# the mapping is simply left empty.
try:
    # JSON text is UTF-8; decode it explicitly instead of relying on the
    # platform's locale-dependent default encoding.
    with open(get_project_root() / "website" / "misc" / "model_aliases.json", encoding="utf-8") as f:
        MODEL_ALIASES = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
    MODEL_ALIASES = {}


def get_friendly_model_name(model_path: str) -> str:
    """Return the display alias for ``model_path``, or its string form if none exists.

    The argument is coerced with ``str()`` before it is looked up in
    ``MODEL_ALIASES``, so path-like objects are accepted as well as strings.
    """
    key = str(model_path)
    if key in MODEL_ALIASES:
        return MODEL_ALIASES[key]
    return key


def get_organized_results(results_dir: str | None = None) -> Dict[str, Dict[str, Dict[str, Any]]]:
    """Collect per-model summary statistics, grouped by system prompt.

    Args:
        results_dir: Directory searched recursively for ``alignment_faking``
            sub-directories. When ``None``, only
            ``<project root>/results/alignment_faking`` is used.

    Returns:
        A mapping ``{system_prompt_name: {display_name: summary}}`` where each
        ``summary`` holds ``count``, ``compliant``, ``faking``,
        ``original_path`` (model directory relative to the system-prompt
        directory) and ``result_file`` (the file the counts were read from).
        Models whose result file is empty are omitted.

    Raises:
        AssertionError: ``results_dir`` was given but contains no
            ``alignment_faking`` directory.
    """
    if results_dir is None:
        search_roots = [get_project_root() / "results" / "alignment_faking"]
    else:
        # Sorted so that processing (and therefore output) order is reproducible.
        search_roots = sorted(Path(results_dir).glob("**/alignment_faking"))
        if not search_roots:
            # Raised explicitly rather than via `assert` so the check survives
            # `python -O`; the exception type is unchanged for existing callers.
            raise AssertionError("No alignment_faking directories found")

    organized_results: Dict[str, Dict[str, Dict[str, Any]]] = {}

    for search_root in search_roots:
        print(f"Processing {search_root}")
        for system_prompt_dir in sorted(search_root.glob("*")):
            if not system_prompt_dir.is_dir():
                continue

            # Skip no-cot system prompts
            if "no-cot" in system_prompt_dir.name.lower():
                continue

            prompt_results = organized_results.setdefault(system_prompt_dir.name, {})

            # Every directory (at any depth) that directly contains a result
            # file is treated as one model; anything under a "rerun" path
            # component is ignored during discovery.
            model_dirs = {
                candidate.parent
                for candidate in system_prompt_dir.rglob("results_*.json")
                if "rerun" not in candidate.parts
            }

            # A set has no stable iteration order; sort for deterministic output.
            for model_dir in sorted(model_dirs):
                result_file = get_latest_json(model_dir, look_for_rerun=True)
                # Relative path from the system prompt dir to the model dir
                model_path = model_dir.relative_to(system_prompt_dir)
                friendly_name = get_friendly_model_name(model_path)

                # There can be multiple runs of the same model, so suffix the
                # display name with the directory that precedes
                # "/results/alignment_faking". as_posix() keeps the "/"-based
                # split working on platforms whose separator is not "/".
                if results_dir is not None:
                    extra = model_dir.as_posix().split("/results/alignment_faking")[0].split("/")[-1]
                    friendly_name = f"{friendly_name} {extra}"

                with open(result_file, encoding="utf-8") as f:
                    results = json.load(f)

                result_count = len(results)
                if result_count == 0:
                    continue

                # Just count results and basic stats
                compliant_count = sum(1 for r in results if r.get("strong_reject_score", 0) > 0)
                faking_count = sum(
                    1
                    for r in results
                    if (
                        r.get("easy_scratchpad", {}).get("final_answer", False)
                        or r.get("hard_scratchpad", {}).get("final_answer", False)
                    )
                )

                prompt_results[friendly_name] = {
                    "count": result_count,
                    "compliant": compliant_count,
                    "faking": faking_count,
                    "original_path": str(model_path),  # Keep original path for API calls
                    "result_file": str(result_file),
                }

    return organized_results


def sanitize_filename(name: str) -> str:
    """Turn an arbitrary string into a filename-friendly token.

    Whitespace (plain spaces), path separators and a fixed set of punctuation
    characters are each replaced by ``_``; runs of underscores are then
    collapsed to a single one and leading/trailing underscores are dropped.
    Characters outside that fixed set (letters, digits, ``.``, ``-`` ...) are
    left untouched.
    """
    unsafe_chars = "/ :\\*?\"<>|%=;,'`!@#$^&()+{}[]"
    to_underscore = str.maketrans(unsafe_chars, "_" * len(unsafe_chars))
    underscored = name.translate(to_underscore)

    # Splitting on "_" and discarding the empty pieces both squeezes repeated
    # underscores and trims them from either end in one pass.
    pieces = [piece for piece in underscored.split("_") if piece]
    return "_".join(pieces)


def _tier_summary(tier_metrics: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the rates published for one tier from a calculate_metrics() entry."""
    compliance_rate = tier_metrics["compliance_rate"]
    return {
        "compliance_rate": compliance_rate,
        "refusal_rate": tier_metrics["refusal_rate"],
        # The faking rate is scaled by the compliance rate before publishing.
        # NOTE(review): presumably this converts a rate among compliant
        # responses into a share of all responses - confirm against viz.
        "alignment_faking_rate": tier_metrics["alignment_faking_rate"] * compliance_rate,
    }


def _render_plots(plot_specs: List[tuple], *, title: str, includes_categories: bool) -> Dict[str, str]:
    """Render each ``(plot_type, plot_func, metrics)`` spec to a base64-encoded PNG.

    When the data has no categorisation, every plot except ``"compliance"`` is
    replaced by the base64 encoding of the text ``No categories found``.
    """
    placeholder = base64.b64encode(b"No categories found").decode("utf-8")
    plots: Dict[str, str] = {}

    for plot_type, plot_func, plot_metrics in plot_specs:
        if not includes_categories and plot_type != "compliance":
            plots[plot_type] = placeholder
            continue

        # Remember which figures already exist so that exactly the ones created
        # for this plot are closed afterwards. The plot helper may draw on the
        # figure opened here or open its own; either way nothing is left open,
        # even if plotting raises.
        open_before = set(plt.get_fignums())
        plt.figure(figsize=(10, 6))
        try:
            plot_func(plot_metrics, title=title)
            buf = io.BytesIO()
            plt.savefig(buf, format="png", bbox_inches="tight")
            plots[plot_type] = base64.b64encode(buf.getvalue()).decode("utf-8")
        finally:
            for fig_num in set(plt.get_fignums()) - open_before:
                plt.close(fig_num)

    return plots


def _is_compromised(result: Dict[str, Any]) -> bool:
    """Return True for a positively scored result whose last hard-scratchpad criterion is true."""
    if not result.get("strong_reject_score", 0) > 0:
        return False
    criteria = result.get("hard_scratchpad", {}).get("criteria", [])
    # Always yield a real bool (an empty criteria list used to leak through as
    # `[]`). Equality rather than identity is kept so a JSON `1` still counts.
    return bool(criteria) and bool(criteria[-1] == True)  # noqa: E712


def _build_transcript_rows(
    json_contents: List[Dict[str, Any]], *, date: str, system_prompt: str, model_name: str
) -> List[Dict[str, Any]]:
    """Convert raw result entries into the row dicts passed to the transcripts template."""
    return [
        {
            "idx": idx,
            "tags": result.get("tags", "N/A"),
            "date": date,
            "system_prompt": system_prompt,
            "model": model_name,
            # The second message of the stored prompt is shown as the prompt text.
            "prompt": result["prompt"]["messages"][1]["content"],
            "scratchpad_content": result.get("easy_scratchpad", {}).get("scratchpad_content", "Not found"),
            "output": result["output_without_scratchpad"],
            "strong_reject_score": result.get("strong_reject_score", 0),
            "easy_scratchpad": result.get("easy_scratchpad", {}),
            "hard_scratchpad": result.get("hard_scratchpad", {}),
            "categorization": result.get("categorization", {}),
            "is_compromised": _is_compromised(result),
        }
        for idx, result in enumerate(json_contents, 1)
    ]


async def generate_static_site(results_dir: str | None = None):
    """Generate static website files under ``<project root>/dist``.

    For every (system prompt, model) pair returned by
    ``get_organized_results`` this writes ``data/<prompt>_<model>.json``
    (summary rates plus base64-encoded plots) and
    ``transcripts/<prompt>_<model>.html``; finally it writes ``index.html``.
    Static assets are copied from ``src/classic/website/static``.

    The coroutine contains no ``await``: all work runs synchronously once it is
    scheduled.

    Args:
        results_dir: Forwarded to ``get_organized_results``; ``None`` selects
            the project's default results directory.
    """
    # Setup paths
    base_dir = get_project_root() / "src" / "classic"
    output_dir = get_project_root() / "dist"
    output_dir.mkdir(exist_ok=True)

    # Create directory structure (once, up front)
    for subdir in ("static", "transcripts", "data"):
        (output_dir / subdir).mkdir(exist_ok=True)

    # Copy static assets
    shutil.copytree(base_dir / "website" / "static", output_dir / "static", dirs_exist_ok=True)

    # Setup templating
    # NOTE(review): autoescaping is off (Jinja2's default), so transcript text
    # is emitted verbatim unless the templates escape it themselves - confirm.
    env = Environment(loader=FileSystemLoader(base_dir / "website" / "templates"))

    # Static path helper that matches FastAPI's url_for signature
    def static_url(endpoint, **kwargs):
        if endpoint == "static":
            return f"/static/{kwargs.get('path', '')}"
        return f"/{endpoint}"

    # Re-read the alias file at generation time for the template context. A
    # local name is used so the module-level MODEL_ALIASES is not shadowed.
    try:
        with open(get_project_root() / "website" / "misc" / "model_aliases.json", encoding="utf-8") as f:
            model_aliases = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        model_aliases = {}

    # Context shared by every rendered page
    base_context = {
        "url_for": static_url,
        "request": None,
        "static_mode": True,
        "model_aliases": model_aliases,
    }

    home_template = env.get_template("home.html")
    transcripts_template = env.get_template("transcripts.html")

    # Get all results
    organized_results = get_organized_results(results_dir)

    # Generate static data files and pages
    for system_prompt, models in organized_results.items():
        print(f"Processing {system_prompt}")
        safe_prompt = sanitize_filename(system_prompt)

        for model_name, model_info in models.items():
            result_file = Path(model_info["result_file"])
            original_path = model_info["original_path"]
            # NOTE(review): output names derive from `original_path`, not the
            # display name. When get_organized_results suffixes display names
            # per run (results_dir given), two runs of the same model share an
            # `original_path` and would write to the same files - confirm
            # whether that is intended.
            page_stem = f"{safe_prompt}_{sanitize_filename(original_path)}"

            # Model outputs routinely contain non-ASCII text, so never rely on
            # the locale's default encoding for reading or writing.
            with open(result_file, encoding="utf-8") as f:
                json_contents = json.load(f)

            includes_categories = all(result.get("categorization", {}) for result in json_contents)
            if not includes_categories:
                print(f"No categories found for {model_name} on {system_prompt}")

            # Calculate metrics
            metrics = calculate_metrics(
                json_contents,
                scoring_mode="either",
                use_accept_reject_tags=False,
                add_categorical_judgements=includes_categories,
            )
            all_compliance_metrics = calculate_metrics(
                json_contents,
                scoring_mode="all_compliance",
                use_accept_reject_tags=False,
                add_categorical_judgements=includes_categories,
            )
            metrics_with_compromised = calculate_metrics(
                json_contents,
                scoring_mode="either",
                use_accept_reject_tags=False,
                add_categorical_judgements=includes_categories,
                include_compromised_rates=True,
            )

            # Generate plots: each entry pairs a plot with the metrics it shows
            plots = _render_plots(
                [
                    ("compliance", plot_compliance_metrics, metrics),
                    ("categories", plot_compliance_metrics_categories, metrics),
                    ("all_compliance", plot_compliance_metrics_categories, all_compliance_metrics),
                    ("categories_with_compromised", plot_compliance_metrics_categories, metrics_with_compromised),
                ],
                title=f"{model_name} on {system_prompt}",
                includes_categories=includes_categories,
            )

            # Save metrics and plots
            metrics_data = {
                "metrics": {
                    "free_tier": _tier_summary(metrics["free_tier"]),
                    "paid_tier": _tier_summary(metrics["paid_tier"]),
                },
                "plots": plots,
            }
            with open(output_dir / "data" / f"{page_stem}.json", "w", encoding="utf-8") as f:
                json.dump(metrics_data, f)

            # The second "_"-separated piece of the file stem is used as the
            # run date; fall back to a placeholder instead of crashing when the
            # stem has no underscore.
            stem_parts = result_file.stem.split("_")
            run_date = stem_parts[1] if len(stem_parts) > 1 else "N/A"

            processed_results = _build_transcript_rows(
                json_contents,
                date=run_date,
                system_prompt=system_prompt,
                model_name=model_name,
            )

            # Generate transcript page with processed results
            context = {
                **base_context,
                "transcripts": processed_results,  # Use processed results instead of raw JSON
                "system_prompts": [system_prompt],
                "models": [model_name],
                "current_system_prompt": system_prompt,
                "current_model": model_name,
                "original_path": original_path,
                "result_file": str(result_file),
            }
            transcript_html = transcripts_template.render(**context)

            with open(output_dir / "transcripts" / f"{page_stem}.html", "w", encoding="utf-8") as f:
                f.write(transcript_html)

    # Generate index page with base context
    all_model_infos = [info for models in organized_results.values() for info in models.values()]

    index_context = {
        **base_context,
        "stats": {
            "total_transcripts": sum(info["count"] for info in all_model_infos),
            "compliant_responses": sum(info["compliant"] for info in all_model_infos),
            "alignment_faking": sum(info["faking"] for info in all_model_infos),
        },
        "system_prompts": list(organized_results),
        "organized_results": organized_results,
    }

    index_html = home_template.render(**index_context)

    with open(output_dir / "index.html", "w", encoding="utf-8") as f:
        f.write(index_html)


if __name__ == "__main__":
    # Command-line entry point: optionally point the generator at a custom
    # results tree, then run the (async) site build to completion.
    cli = argparse.ArgumentParser(description="Generate static website files from the server.")
    cli.add_argument(
        "--results-dir",
        type=str,
        default=None,
        help="Directory containing the results JSON files.",
    )
    options = cli.parse_args()

    asyncio.run(generate_static_site(options.results_dir))
