import argparse
import json
import re
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np

from plot_style import FONT_SIZES, apply_global_style, line_colors

def load_runs(runs_dir: Path) -> List[dict]:
    """Load every ``*.json`` run file directly inside ``runs_dir``, in sorted path order.

    Files that cannot be decoded as UTF-8, are not valid JSON, or whose top-level
    value is not a JSON object are reported on stdout and skipped, so one bad
    file does not abort the whole load.

    Args:
        runs_dir: Directory holding the run JSON files.

    Returns:
        The parsed run objects (dicts), one per successfully loaded file.

    Raises:
        FileNotFoundError: If ``runs_dir`` contains no ``*.json`` regular files.
    """
    # glob() also matches directories whose name ends in ".json"; only read files.
    run_files = sorted(p for p in runs_dir.glob("*.json") if p.is_file())
    if not run_files:
        raise FileNotFoundError(f"No JSON run files in {runs_dir}")

    runs: List[dict] = []
    for path in run_files:
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default encoding.
        with path.open(encoding="utf-8") as f:
            try:
                run = json.load(f)
            except (json.JSONDecodeError, UnicodeDecodeError) as exc:
                print(f"Skipping {path}: {exc}")
                continue
        # Downstream code calls run.get(...), so anything but an object is unusable.
        if not isinstance(run, dict):
            print(f"Skipping {path}: top-level JSON value is not an object")
            continue
        runs.append(run)
    return runs


def find_latest_trends_dir(runs_root: Path, weight_type: str) -> Path:
    """Return the subdirectory of ``runs_root`` whose name ends with ``trends_T_<weight_type>``
    and sorts greatest as a plain string.

    This counts as the "latest" only if directory names sort chronologically.
    They presumably carry a timestamp prefix — confirm against whatever creates them.

    Args:
        runs_root: Directory whose immediate children are searched.
        weight_type: Weight-type tag expected at the end of the directory name.

    Raises:
        FileNotFoundError: If no matching subdirectory exists.
    """
    wanted_suffix = f"trends_T_{weight_type}"
    candidates = [
        child
        for child in runs_root.iterdir()
        if child.is_dir() and child.name.endswith(wanted_suffix)
    ]

    if not candidates:
        raise FileNotFoundError(
            f"No subdirectories ending with '{wanted_suffix}' found under {runs_root}."
        )

    # Sibling names are unique, so max() picks the same entry as sorted()[-1].
    return max(candidates, key=lambda child: child.name)


def select_latest_runs(runs: List[dict]) -> Dict[str, Dict[int, dict]]:
    """Keep the most recent run for every (objective, num_alloc) pair.

    The pair is read from each run's ``config`` dict. Runs lacking either key
    are ignored. "Most recent" means the greatest ``timestamp`` value, compared
    directly, with a missing timestamp treated as ``""``. On equal timestamps
    the run seen first wins.

    Args:
        runs: Loaded run dicts.

    Returns:
        A ``defaultdict`` mapping objective -> {num_alloc: run}.
    """
    latest: Dict[str, Dict[int, dict]] = defaultdict(dict)
    for run in runs:
        cfg = run.get("config", {})
        objective, num_alloc = cfg.get("objective"), cfg.get("num_alloc")
        if objective is None or num_alloc is None:
            continue

        bucket = latest[objective]
        current = bucket.get(num_alloc)
        is_newer = current is None or run.get("timestamp", "") > current.get("timestamp", "")
        if is_newer:
            bucket[num_alloc] = run
    return latest


def extract_series(run: dict) -> tuple[List[float], List[float], List[float] | None]:
    results = run.get("results", {})
    time_steps = results.get("time_steps")
    mean_regret = results.get("mean_regret")
    std_regret = results.get("std_regret")
    if time_steps is None or mean_regret is None:
        raise ValueError("Run is missing time_steps or mean_regret")
    return time_steps, mean_regret, std_regret


def sanitize_objective(name: str) -> str:
    """Turn an objective name into a filename-safe token.

    Every run of characters other than word characters, ``.`` and ``-`` becomes
    a single ``_``. Leading and trailing underscores are stripped. If nothing
    is left, ``"objective"`` is returned.
    """
    cleaned = re.sub(r"[^\w.-]+", "_", name)
    cleaned = cleaned.strip("_")
    if cleaned:
        return cleaned
    return "objective"


def build_output_path(plots_dir: Path, objective: str) -> Path:
    """Return ``<plots_dir>/<sanitized objective>_trends_T.pdf``, the figure path for one objective."""
    return plots_dir.joinpath(f"{sanitize_objective(objective)}_trends_T.pdf")


def build_legend_path(plots_dir: Path) -> Path:
    """Return the path of the shared legend figure inside ``plots_dir``."""
    return plots_dir.joinpath("legend_trends_T.pdf")


def add_plot_metadata(
    fig: plt.Figure,
    run_ids: List[str],
    analysis_timestamp: str,
    max_ids_per_line: int = 3,
) -> None:
    """Annotate ``fig`` with run identifiers and a creation timestamp.

    The run IDs go at the bottom-left in a grey monospace font, wrapped to
    ``max_ids_per_line`` IDs per line. The timestamp goes at the bottom-right
    in grey italics. Both use ``FONT_SIZES["metadata"]``.

    Args:
        fig: Figure to draw on; positions are given in figure coordinates.
        run_ids: Identifiers to list.
        analysis_timestamp: Text shown after ``"created: "``.
        max_ids_per_line: Maximum number of IDs on each line of the run-ID text.
    """
    rows = []
    for start in range(0, len(run_ids), max_ids_per_line):
        batch = run_ids[start:start + max_ids_per_line]
        rows.append(", ".join(batch))
    # The continuation indent is as wide as the "Run IDs: " prefix, so wrapped
    # lines line up under the first ID in the monospace font.
    joined_rows = "\n         ".join(rows)
    run_ids_text = "Run IDs: " + joined_rows

    text_size = FONT_SIZES["metadata"]

    fig.text(
        0.02, 0.02, run_ids_text,
        fontsize=text_size,
        color="gray",
        verticalalignment="bottom",
        horizontalalignment="left",
        family="monospace",
    )
    fig.text(
        0.98, 0.02, f"created: {analysis_timestamp}",
        fontsize=text_size,
        color="gray",
        verticalalignment="bottom",
        horizontalalignment="right",
        style="italic",
    )


def save_common_legend(
    handles: List[plt.Line2D],
    labels: List[str],
    output_path: Path,
    title: str,
) -> None:
    """Write a standalone figure containing only a legend.

    Handles are paired with labels by position. Empty labels and labels that
    start with ``"_"`` (matplotlib's "exclude from legend" convention) are
    dropped. Only the first handle is kept for a repeated label. The figure is
    sized from the number and length of the remaining labels. Nothing is
    written if there are no handles or no usable labels.

    Args:
        handles: Artists to show in the legend.
        labels: Label for each handle.
        output_path: Destination file.
        title: Legend title.
    """
    if not handles:
        return

    first_handle_for: Dict[str, plt.Line2D] = {}
    for handle, label in zip(handles, labels):
        if label and not label.startswith("_"):
            first_handle_for.setdefault(label, handle)

    if not first_handle_for:
        return

    entry_labels = list(first_handle_for)
    entry_handles = [first_handle_for[label] for label in entry_labels]

    # Grow the canvas with the longest label (width) and the entry count (height).
    longest_label = max(len(label) for label in entry_labels)
    canvas_size = (
        max(3.0, 0.18 * longest_label + 1.8),
        max(1.4, 0.55 * len(entry_labels) + 0.8),
    )
    fig = plt.figure(figsize=canvas_size)
    legend = fig.legend(
        entry_handles,
        entry_labels,
        loc="center",
        ncol=1,
        frameon=False,
        title=title,
        handlelength=2.2,
        handletextpad=0.8,
        labelspacing=0.6,
        borderaxespad=0.0,
    )
    legend.get_title().set_fontsize(FONT_SIZES["legend"])
    fig.tight_layout(pad=0.6)
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def _plot_normalized_regret(ax: plt.Axes, run: dict, color, label: str) -> plt.Line2D:
    """Draw Regret(T)/sqrt(T) for one run on ``ax`` in log-log scale and return the line.

    Points with ``T <= 0`` become NaN. When the run records ``std_regret``, the
    band (mean ± std)/sqrt(T) is shaded in the line's colour. Both band edges
    are clipped from below at 1e-12 so they stay positive on the log y-axis.
    """
    time_steps, mean_regret, std_regret = extract_series(run)
    normalized_regret = [
        (m / np.sqrt(t)) if t > 0 else np.nan
        for t, m in zip(time_steps, mean_regret)
    ]
    (line,) = ax.loglog(time_steps, normalized_regret, label=label, color=color, marker="o")

    if std_regret is not None:
        lower = [
            max((m - s) / np.sqrt(t), 1e-12) if t > 0 else np.nan
            for t, m, s in zip(time_steps, mean_regret, std_regret)
        ]
        upper = [
            max((m + s) / np.sqrt(t), 1e-12) if t > 0 else np.nan
            for t, m, s in zip(time_steps, mean_regret, std_regret)
        ]
        # Use the line's resolved colour so the band matches even if ``color`` was None.
        ax.fill_between(time_steps, lower, upper, color=line.get_color(), alpha=0.2, linewidth=0)
    return line


def plot_objectives(
    data: Dict[str, Dict[int, dict]],
    plots_dir: Path,
    show: bool,
    metadata: str | None,
    analysis_timestamp: str | None,
) -> None:
    """Save one log-log Regret(T)/sqrt(T) figure per objective, plus a shared legend.

    Each objective is written to ``build_output_path(plots_dir, objective)``,
    with one line per ``num_alloc`` value. A separate legend figure, titled
    ``"k"``, maps line colour to ``num_alloc`` and is written to
    ``build_legend_path(plots_dir)``.

    Args:
        data: ``{objective: {num_alloc: run}}``, as built by ``select_latest_runs``.
        plots_dir: Existing directory the figures are written to.
        show: If True, display the figures interactively after saving them.
        metadata: Text stamped at the bottom-left. It is used only when
            ``analysis_timestamp`` is also given.
        analysis_timestamp: Creation time stamped at the bottom-right. It is
            used only when ``metadata`` is also given.

    Raises:
        ValueError: If ``data`` is empty, or a run lacks the required series.
    """
    objectives = sorted(data.keys())
    if not objectives:
        raise ValueError("No runs found for plotting")

    # Colour over the union of num_alloc values across ALL objectives, so a given
    # num_alloc has the same colour in every figure. Otherwise the shared legend,
    # which keeps one handle per label, mislabels lines in figures that lack
    # some num_alloc values.
    all_num_alloc = sorted(
        {num_alloc for runs_by_alloc in data.values() for num_alloc in runs_by_alloc}
    )
    color_for = dict(zip(all_num_alloc, line_colors(len(all_num_alloc))))

    figs: List[plt.Figure] = []
    legend_handles: Dict[int, plt.Line2D] = {}

    for objective in objectives:
        runs_by_alloc = data[objective]
        if not runs_by_alloc:
            continue

        fig, ax = plt.subplots(figsize=(7, 5))
        figs.append(fig)

        for num_alloc in sorted(runs_by_alloc):
            line = _plot_normalized_regret(
                ax, runs_by_alloc[num_alloc], color_for.get(num_alloc), f"{num_alloc}"
            )
            legend_handles.setdefault(num_alloc, line)

        ax.set_xlabel("Time Steps (T)")
        ax.set_ylabel(r"$\mathrm{Regret}(T) / \sqrt{T}$")
        ax.grid(True, which="both", ls=":", alpha=0.5)

        if metadata is not None and analysis_timestamp is not None:
            # Fallback bottom margin for the metadata text in case tight_layout
            # cannot be applied; tight_layout's rect normally takes precedence.
            fig.subplots_adjust(bottom=0.15)
            add_plot_metadata(fig, [metadata], analysis_timestamp)
            fig.tight_layout(rect=[0, 0.08, 1, 1])
        else:
            fig.tight_layout()

        output_path = build_output_path(plots_dir, objective)
        fig.savefig(output_path, dpi=300)
        print(f"Saved figure to {output_path}")

    # Save the legend before plt.show(), which may block until the user closes
    # the windows. Entries are ordered by num_alloc.
    ordered = [num_alloc for num_alloc in all_num_alloc if num_alloc in legend_handles]
    save_common_legend(
        [legend_handles[num_alloc] for num_alloc in ordered],
        [f"{num_alloc}" for num_alloc in ordered],
        build_legend_path(plots_dir),
        title="k",
    )

    if show and figs:
        plt.show()
    # Closing is a no-op for figures already destroyed with their windows.
    for fig in figs:
        plt.close(fig)


def main() -> None:
    """Command-line entry point: plot regret trends from the newest trends run directory.

    Picks the newest ``*trends_T_<weight_type>`` subdirectory of ``--runs-dir``,
    keeps the most recent run per (objective, num_alloc), and writes the figures
    to ``plots/<weight_type>-weights/``.

    Raises:
        FileNotFoundError: If ``--runs-dir`` is not an existing directory, or it
            has no matching trends subdirectory or run files.
    """
    apply_global_style()
    parser = argparse.ArgumentParser(description="Plot regret trends from logged runs.")
    parser.add_argument("--runs-dir", type=Path, default=Path("runs"), help="Directory containing run JSON files.")
    parser.add_argument(
        "--weight-type",
        choices=("linear", "geometric"),
        default="linear",
        help="Choose which weight type trends folder suffix to load.",
    )
    parser.add_argument(
        "--output-prefix",
        type=Path,
        default=None,
        help="(Deprecated, ignored) Formerly an optional prefix for saving plots.",
    )
    parser.add_argument("--add-metadata", action="store_true", help="Include run metadata below plots.")
    parser.add_argument("--no-show", action="store_true", help="Skip displaying the plot window (useful for headless runs).")

    args = parser.parse_args()
    runs_dir: Path = args.runs_dir

    # is_dir() rather than exists(): a regular file here would otherwise fail
    # later inside iterdir() with a less helpful NotADirectoryError.
    if not runs_dir.is_dir():
        raise FileNotFoundError(f"Runs directory not found: {runs_dir}")

    plots_dir = Path("plots") / f"{args.weight_type}-weights"
    if args.output_prefix is not None:
        # Still accepted for backward compatibility but has no effect; say so
        # instead of silently ignoring it.
        print(f"Warning: --output-prefix is deprecated and ignored; plots are written to {plots_dir}")

    latest_dir = find_latest_trends_dir(runs_dir, args.weight_type)
    print(f"Loading runs from {latest_dir}")

    runs = load_runs(latest_dir)
    data = select_latest_runs(runs)
    plots_dir.mkdir(parents=True, exist_ok=True)
    analysis_timestamp = datetime.now().strftime("%Y-%m-%d_%H.%M.%S") if args.add_metadata else None
    # The trends directory name identifies the batch of runs being plotted.
    metadata = latest_dir.name if args.add_metadata else None
    plot_objectives(data, plots_dir, show=not args.no_show, metadata=metadata, analysis_timestamp=analysis_timestamp)


if __name__ == "__main__":
    main()
