import argparse
import json
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np

from plot_style import FONT_SIZES, apply_global_style, line_colors

def load_runs(runs_dir: Path) -> List[dict]:
    """Load every usable JSON run record found directly inside ``runs_dir``.

    Files are visited in sorted path order. A file that cannot be decoded as
    UTF-8 JSON, or whose top-level value is not a JSON object, is reported on
    stdout and skipped so that one bad file does not abort the whole analysis.

    Args:
        runs_dir: Directory searched (non-recursively) for ``*.json`` files.

    Returns:
        The decoded run dictionaries, in sorted file order.

    Raises:
        FileNotFoundError: If ``runs_dir`` contains no ``*.json`` files.
    """
    run_files = sorted(runs_dir.glob("*.json"))
    if not run_files:
        raise FileNotFoundError(f"No JSON run files in {runs_dir}")

    runs: List[dict] = []
    for path in run_files:
        try:
            # Read as UTF-8 explicitly instead of relying on the locale's
            # default text encoding, which differs between platforms.
            with path.open(encoding="utf-8") as f:
                payload = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as exc:
            print(f"Skipping {path}: {exc}")
            continue
        if not isinstance(payload, dict):
            # Downstream code calls ``run.get(...)``; a list/number/string
            # payload would only fail later with a less helpful error.
            print(
                f"Skipping {path}: expected a JSON object, "
                f"got {type(payload).__name__}"
            )
            continue
        runs.append(payload)
    return runs


def find_latest_trends_dir(runs_root: Path, weight_type: str) -> Path:
    """Return the matching trends directory whose name sorts last.

    Only immediate subdirectories of ``runs_root`` whose name ends with
    ``trends_k_<weight_type>`` are considered. "Latest" means last in plain
    string order of the directory names (presumably they carry a sortable
    timestamp prefix -- verify against the code that writes them).

    Args:
        runs_root: Directory whose immediate children are inspected.
        weight_type: Weight type used to build the required name suffix.

    Returns:
        The candidate directory with the greatest name.

    Raises:
        FileNotFoundError: If no subdirectory carries the expected suffix.
    """
    suffix = f"trends_k_{weight_type}"

    candidates: List[Path] = []
    for entry in runs_root.iterdir():
        if entry.is_dir() and entry.name.endswith(suffix):
            candidates.append(entry)

    if not candidates:
        raise FileNotFoundError(
            f"No subdirectories ending with '{suffix}' found under {runs_root}."
        )

    candidates.sort(key=lambda entry: entry.name)
    return candidates[-1]


def normalize_pow_val(value) -> float | None:
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def select_latest_runs(runs: List[dict]) -> Dict[str, Dict[float | None, dict]]:
    """Keep only the newest run for each (objective, exponent) pair.

    Runs whose config has no ``objective`` are ignored. The exponent is read
    from ``config["pow_val"]`` and normalised with ``normalize_pow_val``, so a
    missing or unparseable exponent is grouped under the key ``None``.

    "Newest" is decided by comparing the ``timestamp`` values with ``>``
    (assumes sortable string timestamps such as ISO-8601 -- TODO confirm).
    When two runs compare equal, the one seen first is kept.

    Args:
        runs: Decoded run records.

    Returns:
        Mapping ``objective -> {pow_val -> run}``. The outer mapping is a
        ``defaultdict``.
    """
    grouped: Dict[str, Dict[float | None, dict]] = defaultdict(dict)
    for run in runs:
        # ``or {}`` also covers an explicit JSON null, which ``.get(key, {})``
        # would hand back as None and crash on the next ``.get``.
        config = run.get("config") or {}
        objective = config.get("objective")
        if objective is None:
            continue
        pow_val = normalize_pow_val(config.get("pow_val"))
        existing = grouped[objective].get(pow_val)
        if existing is None:
            grouped[objective][pow_val] = run
            continue
        # A null/missing timestamp counts as the oldest possible value.
        candidate_stamp = run.get("timestamp") or ""
        existing_stamp = existing.get("timestamp") or ""
        if candidate_stamp > existing_stamp:
            grouped[objective][pow_val] = run
    return grouped


def extract_series(run: dict) -> Tuple[List[float], List[float], List[float] | None]:
    results = run.get("results", {})
    num_alloc = results.get("num_alloc_ls") or run.get("config", {}).get("num_alloc_ls")
    mean_regret = results.get("mean_regret")
    std_regret = results.get("std_regret")
    if num_alloc is None or mean_regret is None:
        raise ValueError("Run is missing num_alloc_ls or mean_regret")
    return num_alloc, mean_regret, std_regret


def output_path(plots_dir: Path, objective: str) -> Path:
    """Return the PDF path of one objective's trend figure inside ``plots_dir``."""
    filename = f"{objective}_trends_k.pdf"
    return plots_dir.joinpath(filename)


def legend_output_path(plots_dir: Path) -> Path:
    """Return the PDF path of the shared legend figure inside ``plots_dir``."""
    filename = "legend_trends_k.pdf"
    return plots_dir.joinpath(filename)


def add_plot_metadata(
    fig: plt.Figure,
    run_ids: List[str],
    analysis_timestamp: str,
    max_ids_per_line: int = 3,
) -> None:
    """Write provenance text along the bottom edge of ``fig``.

    The run ids go at the bottom-left in a monospace face, wrapped so that
    each row holds at most ``max_ids_per_line`` ids. The analysis timestamp
    goes at the bottom-right in italics. Both use the ``"metadata"`` font size
    and a gray colour.

    Args:
        fig: Figure that receives the two text items.
        run_ids: Identifiers to list after the ``Run IDs:`` prefix.
        analysis_timestamp: Text shown after ``created:``.
        max_ids_per_line: Number of ids per row; expected to be positive.
    """
    rows = []
    for start in range(0, len(run_ids), max_ids_per_line):
        rows.append(", ".join(run_ids[start:start + max_ids_per_line]))
    # The continuation indent is as wide as the "Run IDs: " prefix, so wrapped
    # rows line up under the first id.
    run_ids_text = "Run IDs: " + "\n         ".join(rows)

    shared_style = {
        "fontsize": FONT_SIZES["metadata"],
        "color": "gray",
        "verticalalignment": "bottom",
    }

    fig.text(
        0.02,
        0.02,
        run_ids_text,
        horizontalalignment="left",
        family="monospace",
        **shared_style,
    )
    fig.text(
        0.98,
        0.02,
        f"created: {analysis_timestamp}",
        horizontalalignment="right",
        style="italic",
        **shared_style,
    )


def save_common_legend(
    handles: List[plt.Line2D],
    labels: List[str],
    output_path: Path,
    title: str,
) -> None:
    """Save a figure that contains only a legend, one entry per distinct label.

    Empty labels and labels starting with ``_`` are dropped. When a label
    occurs more than once, the handle of its first occurrence is used. Nothing
    is written when no handles are given or no label survives the filtering.

    Args:
        handles: Legend handles, paired positionally with ``labels``.
        labels: Legend labels.
        output_path: Destination file for the legend figure.
        title: Legend title.
    """
    if not handles:
        return

    first_handle_by_label: Dict[str, plt.Line2D] = {}
    for handle, label in zip(handles, labels):
        if label and not label.startswith("_"):
            # setdefault keeps the first handle seen for a repeated label.
            first_handle_by_label.setdefault(label, handle)

    if not first_handle_by_label:
        return

    legend_labels = list(first_handle_by_label)
    legend_handles = [first_handle_by_label[label] for label in legend_labels]

    # Size the canvas from the content: width grows with the longest label,
    # height with the number of entries, each with a fixed minimum.
    longest_label = max(map(len, legend_labels))
    width = max(3.0, 0.18 * longest_label + 1.8)
    height = max(1.4, 0.55 * len(legend_labels) + 0.8)

    fig = plt.figure(figsize=(width, height))
    legend = fig.legend(
        legend_handles,
        legend_labels,
        title=title,
        loc="center",
        ncol=1,
        frameon=False,
        handlelength=2.2,
        handletextpad=0.8,
        labelspacing=0.6,
        borderaxespad=0.0,
    )
    legend.get_title().set_fontsize(FONT_SIZES["legend"])
    fig.tight_layout(pad=0.6)
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def plot_objective(
    runs_by_pow: Dict[float | None, dict],
    objective: str,
    plots_dir: Path,
    output: Path | None,
    show: bool,
    metadata: str | None,
    analysis_timestamp: str | None,
) -> Tuple[List[plt.Line2D], List[str]]:
    """Plot mean regret against k for every exponent logged for one objective.

    One line (with circle markers) is drawn per exponent, in ascending
    exponent order with the ``None`` key last. When a run provides
    ``std_regret``, a translucent band of mean +/- std is drawn behind the
    line in the same colour.

    Args:
        runs_by_pow: Run record per exponent; ``None`` keys a run without one.
        objective: Objective name. ``"gini"`` and ``"kolm"`` adjust the
            palette (see below); it is also used in the "no runs" message.
        plots_dir: Not used by this function; kept so existing callers that
            pass it keep working.
        output: Where to save the figure, or ``None`` to skip saving.
        show: Call ``plt.show()`` when true; otherwise close the figure.
        metadata: Identifier passed on to ``add_plot_metadata``.
        analysis_timestamp: Timestamp passed on to ``add_plot_metadata``.
            The metadata footer is added only when both ``metadata`` and
            ``analysis_timestamp`` are given.

    Returns:
        The legend handles and labels reported by the axes, or two empty
        lists when ``runs_by_pow`` is empty.

    Raises:
        ValueError: Propagated from ``extract_series`` for an incomplete run.
    """
    if not runs_by_pow:
        print(f"No runs found for objective '{objective}'.")
        return [], []

    # The leading boolean puts the None key last and keeps None from ever
    # being compared with a float.
    pow_vals = sorted(
        runs_by_pow.keys(),
        key=lambda v: (v is None, v),
    )

    # Work on a copy: the "kolm" override below assigns into the palette and
    # must not modify whatever object ``line_colors`` handed back.
    palette = list(line_colors(None))
    if objective == "gini":
        palette = palette[-1:]
    if objective == "kolm" and len(palette) > 3:
        palette[3] = palette[-2]

    fig, ax = plt.subplots(1, 1, figsize=(7, 5))

    for index, pow_val in enumerate(pow_vals):
        run = runs_by_pow[pow_val]
        num_alloc, mean_regret, std_regret = extract_series(run)

        num_alloc_arr = np.array(num_alloc, dtype=float)
        mean_arr = np.array(mean_regret, dtype=float)
        std_arr = np.array(std_regret, dtype=float) if std_regret is not None else None

        if pow_val is None:
            label = "gini"
        elif pow_val == -float("inf"):
            label = r"$-\infty$"
        else:
            label = f"{pow_val:g}"

        # Cycle through the palette so that a run is never silently left out
        # when there are more exponents than colours. With an empty palette
        # the colour is omitted and matplotlib's own cycle is used.
        color_kwargs = {"color": palette[index % len(palette)]} if palette else {}
        lines = ax.plot(num_alloc_arr, mean_arr, marker="o", label=label, **color_kwargs)
        if std_arr is not None:
            ax.fill_between(
                num_alloc_arr,
                mean_arr - std_arr,
                mean_arr + std_arr,
                # Reuse the line's colour so the band always matches it.
                color=lines[0].get_color(),
                alpha=0.2,
                linewidth=0,
            )

    ax.set_xlabel("k")
    ax.set_ylabel("Regret")
    ax.grid(True, ls=":", alpha=0.5)

    handles, labels = ax.get_legend_handles_labels()

    if metadata is not None and analysis_timestamp is not None:
        # Act on this figure explicitly instead of pyplot's "current figure".
        fig.subplots_adjust(bottom=0.15)
        add_plot_metadata(fig, [metadata], analysis_timestamp)
        # Keep the bottom strip of the canvas free for the metadata text.
        fig.tight_layout(rect=[0, 0.08, 1, 1])
    else:
        fig.tight_layout()

    if output is not None:
        fig.savefig(output, dpi=300)
        print(f"Saved figure to {output}")

    if show:
        plt.show()
    else:
        plt.close(fig)

    return handles, labels


def main() -> None:
    """Command-line entry point: plot regret trends from the latest run folder.

    Steps: apply the global plot style, parse the arguments, locate the
    trends directory for the chosen weight type under ``--runs-dir``, keep the
    newest run per objective and exponent, write one figure per objective
    into ``plots/<weight-type>-weights`` and finally save a shared legend.

    Raises:
        FileNotFoundError: If the runs directory does not exist (the helpers
            it calls raise the same error for missing trend folders or files).
    """
    apply_global_style()

    parser = argparse.ArgumentParser(
        description="Plot regret trends over num_alloc from logged runs."
    )
    parser.add_argument(
        "--runs-dir",
        type=Path,
        default=Path("runs"),
        help="Directory containing run JSON files.",
    )
    parser.add_argument(
        "--weight-type",
        choices=("linear", "geometric"),
        default="linear",
        help="Choose which weight type trends folder suffix to load.",
    )
    # Still accepted on the command line, but its value is not read below.
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="(Deprecated) Optional path to save the plot image.",
    )
    parser.add_argument(
        "--add-metadata",
        action="store_true",
        help="Include run metadata below plots.",
    )
    parser.add_argument(
        "--no-show",
        action="store_true",
        help="Skip displaying the plot window.",
    )
    cli = parser.parse_args()

    runs_dir: Path = cli.runs_dir
    if not runs_dir.exists():
        raise FileNotFoundError(f"Runs directory not found: {runs_dir}")

    trends_dir = find_latest_trends_dir(runs_dir, cli.weight_type)
    print(f"Loading runs from {trends_dir}")
    latest_by_objective = select_latest_runs(load_runs(trends_dir))

    plots_dir = Path("plots") / f"{cli.weight_type}-weights"
    plots_dir.mkdir(parents=True, exist_ok=True)

    # The footer needs both values, so they are set (or left unset) together.
    if cli.add_metadata:
        analysis_timestamp = datetime.now().strftime("%Y-%m-%d_%H.%M.%S")
        metadata = trends_dir.name
    else:
        analysis_timestamp = None
        metadata = None

    show_plots = not cli.no_show
    legend_handles: List[plt.Line2D] = []
    legend_labels: List[str] = []

    for objective in ("wpm", "kolm", "gini"):
        handles, labels = plot_objective(
            latest_by_objective.get(objective, {}),
            objective=objective,
            plots_dir=plots_dir,
            output=output_path(plots_dir, objective),
            show=show_plots,
            metadata=metadata,
            analysis_timestamp=analysis_timestamp,
        )
        legend_handles += handles
        legend_labels += labels

    save_common_legend(
        legend_handles,
        legend_labels,
        legend_output_path(plots_dir),
        title="q",
    )

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
