import argparse
import json
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np

from plot_style import FONT_SIZES, apply_global_style, line_colors

def load_runs(runs_dir: Path) -> List[dict]:
    """Load every ``*.json`` run file found directly inside ``runs_dir``.

    Files are read in sorted path order and decoded as UTF-8. A file that
    cannot be decoded or parsed is reported on stdout and skipped, so one bad
    file does not abort the whole load.

    Args:
        runs_dir: Directory searched (non-recursively) for ``*.json`` files.

    Returns:
        The parsed JSON payload of each readable file, in sorted path order.
        The list may be empty when every file had to be skipped.

    Raises:
        FileNotFoundError: If ``runs_dir`` contains no ``*.json`` files.
    """
    run_files = sorted(runs_dir.glob("*.json"))
    if not run_files:
        raise FileNotFoundError(f"No JSON run files in {runs_dir}")

    runs: List[dict] = []
    for path in run_files:
        try:
            # JSON text is UTF-8; do not depend on the platform's locale
            # default, which differs between machines.
            with path.open(encoding="utf-8") as f:
                payload = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as exc:
            # Best-effort load: report the unreadable file and keep going.
            print(f"Skipping {path}: {exc}")
        else:
            runs.append(payload)
    return runs


def find_latest_trends_dir(runs_root: Path, weight_type: str) -> Path:
    """Pick the trends directory for ``weight_type`` whose name sorts last.

    Only direct subdirectories of ``runs_root`` whose name ends with
    ``trends_pow_<weight_type>`` are considered. "Latest" is decided purely by
    lexicographic order of the directory names.
    NOTE(review): this presumably relies on names starting with a sortable
    timestamp -- confirm against whatever creates these directories.

    Args:
        runs_root: Directory whose immediate children are inspected.
        weight_type: Weight type used to build the expected name suffix.

    Returns:
        The matching subdirectory with the greatest name.

    Raises:
        FileNotFoundError: If no subdirectory carries the expected suffix.
    """
    suffix = f"trends_pow_{weight_type}"

    candidates = []
    for entry in runs_root.iterdir():
        if entry.is_dir() and entry.name.endswith(suffix):
            candidates.append(entry)

    if len(candidates) == 0:
        raise FileNotFoundError(
            f"No subdirectories ending with '{suffix}' found under {runs_root}."
        )

    candidates.sort(key=lambda entry: entry.name)
    return candidates[-1]


def select_latest_runs(runs: List[dict]) -> Dict[str, Dict[int, dict]]:
    """Keep one run per ``(objective, num_alloc)`` pair: the newest one.

    Both keys are read from each run's ``"config"`` mapping; runs lacking
    either one are ignored. "Newest" means the greatest ``"timestamp"`` value
    under plain ``>`` comparison, with a missing timestamp treated as ``""``.
    On a tie the run seen first is kept.

    Args:
        runs: Parsed run dictionaries.

    Returns:
        A ``defaultdict`` mapping objective -> num_alloc -> selected run.
    """
    latest: Dict[str, Dict[int, dict]] = defaultdict(dict)

    for run in runs:
        cfg = run.get("config", {})
        objective, num_alloc = cfg.get("objective"), cfg.get("num_alloc")
        if objective is None or num_alloc is None:
            continue

        per_alloc = latest[objective]
        current = per_alloc.get(num_alloc)
        if current is not None:
            is_newer = run.get("timestamp", "") > current.get("timestamp", "")
            if not is_newer:
                continue
        per_alloc[num_alloc] = run

    return latest


def extract_series(run: dict) -> Tuple[List[float], List[float], List[float] | None]:
    results = run.get("results", {})
    pow_vals = results.get("pow_val_ls") or run.get("config", {}).get("pow_val_ls")
    mean_regret = results.get("mean_regret")
    std_regret = results.get("std_regret")
    if pow_vals is None or mean_regret is None:
        raise ValueError("Run is missing pow_val_ls or mean_regret")
    return pow_vals, mean_regret, std_regret


def split_pow_values(
    pow_vals: np.ndarray, mean_regret: np.ndarray, std_regret: np.ndarray | None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray | None, np.ndarray, np.ndarray, np.ndarray | None]:
    neg_inf_mask = np.isneginf(pow_vals)
    finite_mask = np.isfinite(pow_vals)

    pow_neg_inf = pow_vals[neg_inf_mask]
    mean_neg_inf = mean_regret[neg_inf_mask]
    std_neg_inf = std_regret[neg_inf_mask] if std_regret is not None else None

    pow_finite = pow_vals[finite_mask]
    mean_finite = mean_regret[finite_mask]
    std_finite = std_regret[finite_mask] if std_regret is not None else None

    return pow_neg_inf, mean_neg_inf, std_neg_inf, pow_finite, mean_finite, std_finite


def add_axis_break(ax_left: plt.Axes, ax_right: plt.Axes) -> None:
    """Style two side-by-side axes so they read as one broken x-axis.

    The spines facing each other are hidden, the right axes gets its y ticks
    on its right side with left labels turned off, and short diagonal marks
    are drawn (in axes coordinates, unclipped) at the corners where the two
    axes meet.

    Args:
        ax_left: Axes on the left of the break.
        ax_right: Axes on the right of the break.
    """
    ax_left.spines["right"].set_visible(False)
    ax_right.spines["left"].set_visible(False)
    ax_right.yaxis.tick_right()
    ax_right.tick_params(labelleft=False)

    slant = 0.015  # half-extent of each diagonal mark, in axes coordinates
    lower_ys = (-slant, +slant)
    upper_ys = (1 - slant, 1 + slant)

    # Marks sit on the right edge of the left axes and the left edge of the
    # right axes; each edge gets one mark at the bottom and one at the top.
    edges = (
        (ax_left, (1 - slant, 1 + slant)),
        (ax_right, (-slant, +slant)),
    )
    for axis, xs in edges:
        mark_style = dict(transform=axis.transAxes, color="black", clip_on=False)
        for ys in (lower_ys, upper_ys):
            axis.plot(xs, ys, **mark_style)


def output_path(plots_dir: Path, objective: str) -> Path:
    """Return the PDF path inside ``plots_dir`` for one objective's figure."""
    filename = f"{objective}_trends_pow.pdf"
    return plots_dir.joinpath(filename)


def legend_output_path(plots_dir: Path) -> Path:
    """Return the PDF path inside ``plots_dir`` for the shared legend."""
    filename = "legend_trends_pow.pdf"
    return plots_dir.joinpath(filename)


def add_plot_metadata(
    fig: plt.Figure,
    run_ids: List[str],
    analysis_timestamp: str,
    max_ids_per_line: int = 3,
) -> None:
    """Stamp run IDs and a creation timestamp along the bottom of ``fig``.

    The run IDs go in the bottom-left corner in a monospace font, wrapped
    after ``max_ids_per_line`` IDs, with continuation lines indented to line
    up under the first ID. The timestamp goes in the bottom-right corner in
    italics. Both are gray and use the ``"metadata"`` font size.

    Args:
        fig: Figure to annotate.
        run_ids: Identifiers to list.
        analysis_timestamp: Text shown after ``"created: "``.
        max_ids_per_line: Number of IDs placed on each line.
    """
    id_rows = []
    for start in range(0, len(run_ids), max_ids_per_line):
        row = run_ids[start:start + max_ids_per_line]
        id_rows.append(", ".join(row))

    # The continuation indent matches the width of the "Run IDs: " prefix.
    continuation = "\n         "
    run_ids_text = "Run IDs: " + continuation.join(id_rows)
    timestamp_text = f"created: {analysis_timestamp}"

    shared_style = {
        "fontsize": FONT_SIZES["metadata"],
        "color": "gray",
        "verticalalignment": "bottom",
    }

    fig.text(
        0.02,
        0.02,
        run_ids_text,
        horizontalalignment="left",
        family="monospace",
        **shared_style,
    )
    fig.text(
        0.98,
        0.02,
        timestamp_text,
        horizontalalignment="right",
        style="italic",
        **shared_style,
    )


def save_common_legend(
    handles: List[plt.Line2D],
    labels: List[str],
    output_path: Path,
    title: str,
) -> None:
    """Save a standalone, single-column legend figure to ``output_path``.

    Entries are de-duplicated by label, keeping the first handle seen for
    each. Empty labels and labels starting with ``"_"`` are dropped. Nothing
    is written when no handles are given or no label survives the filtering.
    The figure size grows with the longest label and the number of entries.

    Args:
        handles: Artists to show in the legend.
        labels: Label for each handle, paired by position.
        output_path: File the legend figure is saved to.
        title: Legend title.
    """
    if not handles:
        return

    first_handle_by_label: Dict[str, plt.Line2D] = {}
    for handle, label in zip(handles, labels):
        hidden = (not label) or label.startswith("_")
        if not hidden:
            # setdefault keeps the first handle registered for a label.
            first_handle_by_label.setdefault(label, handle)

    if not first_handle_by_label:
        return

    legend_labels = list(first_handle_by_label)
    legend_handles = [first_handle_by_label[label] for label in legend_labels]

    longest_label = max(map(len, legend_labels))
    entry_count = len(legend_labels)
    fig_size = (
        max(3.0, 0.18 * longest_label + 1.8),
        max(1.4, 0.55 * entry_count + 0.8),
    )

    fig = plt.figure(figsize=fig_size)
    legend = fig.legend(
        legend_handles,
        legend_labels,
        title=title,
        loc="center",
        ncol=1,
        frameon=False,
        handlelength=2.2,
        handletextpad=0.8,
        labelspacing=0.6,
        borderaxespad=0.0,
    )
    legend.get_title().set_fontsize(FONT_SIZES["legend"])
    fig.tight_layout(pad=0.6)
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def _load_series_arrays(run: dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray | None]:
    """Return a run's pow values, mean regret and optional std regret as float arrays."""
    pow_vals, mean_regret, std_regret = extract_series(run)
    std_arr = np.array(std_regret, dtype=float) if std_regret is not None else None
    return np.array(pow_vals, dtype=float), np.array(mean_regret, dtype=float), std_arr


def plot_objective(
    runs_by_alloc: Dict[int, dict],
    objective: str,
    plots_dir: Path,
    output: Path | None,
    show: bool,
    metadata: str | None,
    analysis_timestamp: str | None,
) -> Tuple[List[plt.Line2D], List[str]]:
    """Plot mean regret against pow value for one objective, one line per ``num_alloc``.

    Finite pow values are drawn as a marked line with a shaded band of
    ``mean +/- std`` when a std series is available. If any run contains a
    ``-inf`` pow value, the figure gets a narrow extra panel on the left,
    separated by an axis break, where those points are drawn at a single tick
    labelled as minus infinity (with error bars when std is available).

    Args:
        runs_by_alloc: Run dictionary to plot for each ``num_alloc`` value.
        objective: Objective name, used only in the "no runs" message.
        plots_dir: Not used by this function; kept for interface compatibility.
        output: File to save the figure to, or ``None`` to skip saving.
        show: If true, call ``plt.show()``; otherwise close the figure.
        metadata: Identifier printed below the plot. It is only drawn when
            ``analysis_timestamp`` is also given.
        analysis_timestamp: Creation timestamp printed below the plot. It is
            only drawn when ``metadata`` is also given.

    Returns:
        ``(handles, labels)`` as reported by the axes' legend entries, right
        panel first and then the left panel. Both lists are empty when
        ``runs_by_alloc`` is empty.

    Raises:
        ValueError: Propagated from ``extract_series`` when a run lacks
            ``pow_val_ls`` or ``mean_regret``.
    """
    if not runs_by_alloc:
        print(f"No runs found for objective '{objective}'.")
        return [], []

    num_alloc_values = sorted(runs_by_alloc.keys())
    colors = line_colors(len(num_alloc_values))

    # Parse every run exactly once and before any figure exists, so a
    # malformed run raises without leaving an unclosed pyplot figure behind.
    series_by_alloc = {
        num_alloc: _load_series_arrays(runs_by_alloc[num_alloc])
        for num_alloc in num_alloc_values
    }
    has_neg_inf = any(
        np.isneginf(pow_arr).any() for pow_arr, _, _ in series_by_alloc.values()
    )

    if has_neg_inf:
        # Narrow left panel for the -inf points, wide right panel for the rest.
        fig, (ax_left, ax_right) = plt.subplots(
            1, 2, figsize=(8, 5), sharey=True, gridspec_kw={"width_ratios": [1, 4]}
        )
        fig.subplots_adjust(wspace=0.05)
        add_axis_break(ax_left, ax_right)
        ax_left.set_xlim(-0.5, 0.5)
        ax_left.set_xticks([0.0])
        ax_left.set_xticklabels([r"$-\infty$"])
    else:
        fig, ax_right = plt.subplots(1, 1, figsize=(7, 5))
        ax_left = None

    for color, num_alloc in zip(colors, num_alloc_values):
        pow_vals_arr, mean_arr, std_arr = series_by_alloc[num_alloc]
        (
            pow_neg_inf,
            mean_neg_inf,
            std_neg_inf,
            pow_finite,
            mean_finite,
            std_finite,
        ) = split_pow_values(pow_vals_arr, mean_arr, std_arr)

        label = f"{num_alloc}"
        # Each num_alloc should appear once in the legend: prefer the line in
        # the right panel and label the left-panel artist only as a fallback.
        label_used = False

        if pow_finite.size > 0:
            ax_right.plot(pow_finite, mean_finite, marker="o", color=color, label=label)
            if std_finite is not None:
                ax_right.fill_between(
                    pow_finite,
                    mean_finite - std_finite,
                    mean_finite + std_finite,
                    color=color,
                    alpha=0.2,
                    linewidth=0,
                )
            label_used = True

        # ax_left exists exactly when some run has a -inf pow value.
        if ax_left is not None and pow_neg_inf.size > 0:
            # All -inf points share the single tick at x = 0 of the left panel.
            x_neg_inf = np.zeros_like(pow_neg_inf, dtype=float)
            left_label = label if not label_used else "_nolegend_"
            if std_neg_inf is not None:
                ax_left.errorbar(
                    x_neg_inf,
                    mean_neg_inf,
                    yerr=std_neg_inf,
                    fmt="o",
                    color=color,
                    label=left_label,
                )
            else:
                ax_left.plot(
                    x_neg_inf,
                    mean_neg_inf,
                    marker="o",
                    color=color,
                    label=left_label,
                )

    ax_right.set_xlabel("q")
    # The y label belongs on the leftmost visible panel.
    if ax_left is None:
        ax_right.set_ylabel("Regret")
    else:
        ax_left.set_ylabel("Regret")
    ax_right.grid(True, ls=":", alpha=0.5)
    if ax_left is not None:
        ax_left.grid(True, ls=":", alpha=0.5)

    handles, labels = ax_right.get_legend_handles_labels()
    if ax_left is not None:
        handles_left, labels_left = ax_left.get_legend_handles_labels()
        handles = handles + handles_left
        labels = labels + labels_left

    if metadata is not None and analysis_timestamp is not None:
        # Reserve a strip at the bottom of the figure for the metadata text.
        fig.subplots_adjust(bottom=0.15)
        add_plot_metadata(fig, [metadata], analysis_timestamp)
        fig.tight_layout(rect=[0, 0.08, 1, 0.88])
        fig.subplots_adjust(top=0.88)
    else:
        fig.tight_layout()
        fig.subplots_adjust(top=0.88)

    if output is not None:
        fig.savefig(output, dpi=300)
        print(f"Saved figure to {output}")

    if show:
        plt.show()
    else:
        plt.close(fig)

    return handles, labels


def main() -> None:
    """Command-line entry point: plot regret trends for the latest trends folder.

    Finds the newest ``*trends_pow_<weight-type>`` directory under
    ``--runs-dir``, loads its run files, keeps the latest run per objective
    and ``num_alloc``, and writes one figure per objective ("wpm" and "kolm")
    plus a shared legend into ``plots/<weight-type>-weights``.

    Raises:
        FileNotFoundError: If ``--runs-dir`` does not exist. The helpers it
            calls raise the same error when no trends directory or no JSON
            run file is found.
    """
    apply_global_style()
    parser = argparse.ArgumentParser(description="Plot regret trends over pow_val from logged runs.")
    parser.add_argument("--runs-dir", type=Path, default=Path("runs"), help="Directory containing run JSON files.")
    parser.add_argument(
        "--weight-type",
        choices=("linear", "geometric"),
        default="linear",
        help="Choose which weight type trends folder suffix to load.",
    )
    parser.add_argument("--output", type=Path, default=None, help="(Deprecated) Optional path to save the plot image.")
    parser.add_argument("--add-metadata", action="store_true", help="Include run metadata below plots.")
    parser.add_argument("--no-show", action="store_true", help="Skip displaying the plot window.")

    args = parser.parse_args()
    runs_dir: Path = args.runs_dir

    if args.output is not None:
        # Output locations are derived from --weight-type below, so this
        # deprecated option has no effect; say so instead of ignoring it
        # silently.
        print(
            f"Warning: --output ({args.output}) is deprecated and ignored; "
            "figures are saved under the plots directory."
        )

    if not runs_dir.exists():
        raise FileNotFoundError(f"Runs directory not found: {runs_dir}")

    latest_dir = find_latest_trends_dir(runs_dir, args.weight_type)
    print(f"Loading runs from {latest_dir}")

    runs = load_runs(latest_dir)
    data = select_latest_runs(runs)

    plots_dir = Path("plots") / f"{args.weight_type}-weights"
    plots_dir.mkdir(parents=True, exist_ok=True)
    # Metadata is all-or-nothing: both values are set only with --add-metadata.
    analysis_timestamp = datetime.now().strftime("%Y-%m-%d_%H.%M.%S") if args.add_metadata else None
    metadata = latest_dir.name if args.add_metadata else None

    all_handles: List[plt.Line2D] = []
    all_labels: List[str] = []

    for objective in ["wpm", "kolm"]:
        handles, labels = plot_objective(
            data.get(objective, {}),
            objective=objective,
            plots_dir=plots_dir,
            output=output_path(plots_dir, objective),
            show=not args.no_show,
            metadata=metadata,
            analysis_timestamp=analysis_timestamp,
        )
        all_handles.extend(handles)
        all_labels.extend(labels)

    # One legend file shared by all objective figures; entries are
    # de-duplicated by label inside save_common_legend.
    save_common_legend(all_handles, all_labels, legend_output_path(plots_dir), title="k")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
