import argparse
import glob
import json
import os
from typing import Dict, List, Tuple

import numpy as np
import matplotlib.pyplot as plt


def _load_final(exp_dir: str):
    """Load per-algorithm final-step statistics for one experiment directory.

    Reads three files from *exp_dir*:
      - ``algo_keymap.json``: mapping ``algorithm name -> array key`` in ``runs.npz``;
      - ``config.json``: the experiment config, returned unchanged;
      - ``runs.npz``: one 2-D array per key, treated as ``(runs, steps)``.

    Returns:
        ``(means, ses, cfg)``. ``means[algo]`` is the mean over runs of the last
        column of that algorithm's array. ``ses[algo]`` is its standard error:
        sample std (``ddof=1``) divided by ``sqrt(#runs)``, or ``0.0`` when there
        is a single run. ``cfg`` is the parsed config dict.
    """
    runs_path = os.path.join(exp_dir, "runs.npz")
    keymap_path = os.path.join(exp_dir, "algo_keymap.json")
    cfg_path = os.path.join(exp_dir, "config.json")

    with open(keymap_path, "r", encoding="utf-8") as f:
        keymap = json.load(f)
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    means = {}
    ses = {}
    # np.load on an .npz returns an NpzFile that keeps the archive open and
    # reads members lazily; use it as a context manager so the handle is
    # released deterministically (indexing below materializes each array).
    with np.load(runs_path) as data:
        for algo, key in keymap.items():
            arr = data[key]  # R x T
            final = arr[:, -1]
            n_runs = arr.shape[0]
            means[algo] = float(final.mean())
            # ddof=1 needs at least two samples; a single run has no spread.
            ses[algo] = float(final.std(ddof=1) / np.sqrt(n_runs)) if n_runs > 1 else 0.0
    return means, ses, cfg


def _filter_series(series: Dict[str, List[Tuple[float, float, float]]], keep: List[str] | None):
    if not keep:
        return series
    keep_set = set(keep)
    return {a: pts for a, pts in series.items() if a in keep_set}


def _plot_ablation(exp_dirs: List[str], param_key: str, out_path: str, algos_keep: List[str] | None) -> None:
    """Plot final cumulative regret against one swept config parameter.

    Each directory in *exp_dirs* contributes one x value (``cfg[param_key]``).
    For every algorithm it contains, it also contributes one point: the mean
    final regret with a +/-1 standard-error bar. Each algorithm's points are
    sorted by x, drawn as one line, and the figure is saved to *out_path*.

    Args:
        exp_dirs: experiment directories readable by ``_load_final``.
        param_key: config key used as the x axis (e.g. ``"eta"``).
        out_path: output image path; its parent directory is created if needed.
        algos_keep: optional algorithm whitelist; falsy means keep all.
    """
    series: Dict[str, List[Tuple[float, float, float]]] = {}
    for d in exp_dirs:
        means, ses, cfg = _load_final(d)
        xval = cfg[param_key]
        for algo in means:
            series.setdefault(algo, []).append((xval, means[algo], ses[algo]))

    series = _filter_series(series, algos_keep)

    for pts in series.values():
        pts.sort(key=lambda t: t[0])

    # os.path.dirname("plot.png") == "" and os.makedirs("") raises, so only
    # create a directory when the path actually names one.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    markers = ["o", "s", "D", "^", "v", ">", "<", "P", "X", "*"]
    fig = plt.figure(figsize=(6.8, 3.2))
    try:
        for i, (algo, pts) in enumerate(series.items()):
            xs = [p[0] for p in pts]
            ms = [p[1] for p in pts]
            ss = [p[2] for p in pts]
            (line,) = plt.plot(xs, ms, label=algo, marker=markers[i % len(markers)], linewidth=1.8, markersize=4)
            # Reuse the line's auto-assigned colour so the error bars match it.
            plt.errorbar(xs, ms, yerr=ss, fmt="none", ecolor=line.get_color(), elinewidth=1.2, capsize=3, alpha=0.7)

        plt.xlabel(param_key)
        plt.ylabel("Final cumulative regret")
        if series:  # a legend with no labelled artists only produces a warning
            plt.legend(ncol=2, loc="upper left")
        plt.tight_layout()
        plt.savefig(out_path, dpi=220)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def _table_from_series(exp_dirs: List[str], param_key: str, algos_keep: List[str] | None):
    series = {}
    x_values = set()

    for d in exp_dirs:
        means, ses, cfg = _load_final(d)
        xval = cfg[param_key]
        x_values.add(xval)
        for algo in means:
            series.setdefault(algo, []).append((xval, means[algo], ses[algo]))

    if algos_keep:
        series = {a: pts for a, pts in series.items() if a in set(algos_keep)}

    for a in list(series.keys()):
        series[a].sort(key=lambda t: t[0])

    x_list = sorted(x_values)
    header = ["Algorithm"] + [f"{param_key}={x}" for x in x_list]
    rows = []
    for algo, pts in series.items():
        byx = {x: (m, s) for x, m, s in pts}
        row = [algo]
        for x in x_list:
            m, s = byx.get(x, (np.nan, np.nan))
            if np.isnan(m):
                row.append("")
            else:
                row.append(f"{m:.2f}±{s:.2f}")
        rows.append(row)
    return header, rows


def _save_csv(path: str, header: List[str], rows: List[List[str]]):
    """Write *header* followed by *rows* to *path* as a UTF-8 CSV file.

    Cells are quoted only when needed (e.g. they contain a comma or quote), so
    ordinary tables are written exactly as ``",".join(cells)`` lines. The
    parent directory of *path* is created if it is missing.
    """
    import csv  # local import: the module header does not import csv

    out_dir = os.path.dirname(path)
    if out_dir:  # a bare filename has dirname "" and os.makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)
    # newline="" is what the csv module requires; lineterminator="\n" keeps
    # Unix line endings instead of csv's default "\r\n".
    with open(path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(header)
        writer.writerows(rows)


def _save_md(path: str, header: List[str], rows: List[List[str]]):
    """Write *header* and *rows* to *path* as a Markdown (pipe) table.

    Literal ``|`` characters inside cells are escaped as ``\\|`` so they do not
    split a cell. The parent directory of *path* is created if it is missing.
    """
    def md_row(cells: List[str]) -> str:
        # Escape pipes so a cell such as "a|b" stays a single column.
        return "| " + " | ".join(str(c).replace("|", "\\|") for c in cells) + " |\n"

    out_dir = os.path.dirname(path)
    if out_dir:  # a bare filename has dirname "" and os.makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(md_row(header))
        f.write("| " + " | ".join(["---"] * len(header)) + " |\n")
        for r in rows:
            f.write(md_row(r))


# Characters that are special in LaTeX text mode, mapped to safe spellings.
_TEX_ESCAPES = {
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
}


def _tex_escape(text: str) -> str:
    """Escape *text* for LaTeX text mode and typeset the plus-minus sign in math mode."""
    # Single pass per character, so replacements are never re-escaped.
    escaped = "".join(_TEX_ESCAPES.get(ch, ch) for ch in text)
    return escaped.replace("±", r"$\pm$")


def _save_tex(path: str, header: List[str], rows: List[List[str]], caption: str, label: str):
    """Write *header* and *rows* to *path* as a LaTeX ``table`` environment.

    The output uses ``\\toprule``/``\\midrule``/``\\bottomrule``, so the including
    document needs the ``booktabs`` package. Header and body cells are
    LaTeX-escaped (e.g. the underscore in ``n_users``) and every row ends with
    a proper ``\\\\`` line break. *caption* and *label* are written verbatim,
    because callers pass LaTeX markup such as ``$\\eta$``. The parent directory
    of *path* is created if it is missing.
    """
    def tex_row(cells: List[str]) -> str:
        return " & ".join(_tex_escape(str(c)) for c in cells) + r" \\"

    out_dir = os.path.dirname(path)
    if out_dir:  # a bare filename has dirname "" and os.makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)

    lines = [
        "\\begin{table}[t]",
        "\\centering",
        "\\small",
        # First column (algorithm names) left-aligned, value columns centred.
        "\\begin{tabular}{" + "l" + "c" * (len(header) - 1) + "}",
        "\\toprule",
        tex_row(header),
        "\\midrule",
        *(tex_row(r) for r in rows),
        "\\bottomrule",
        "\\end{tabular}",
        f"\\caption{{{caption}}}",
        f"\\label{{{label}}}",
        "\\end{table}",
    ]
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")


def _run_ablation(
    pattern: str | None,
    param_key: str,
    out_png: str,
    algos_keep: List[str] | None,
    table_csv: str | None,
    table_md: str | None,
    table_tex: str | None,
    caption: str,
    label: str,
) -> None:
    """Plot one ablation family and write whichever table formats were requested.

    Does nothing when *pattern* is unset. When the glob matches no directory,
    prints a notice and returns.
    """
    if not pattern:
        return
    # Only directories can hold experiment outputs; a stray file matching the
    # glob would make _load_final fail when it joins file names onto it.
    ds = sorted(p for p in glob.glob(pattern) if os.path.isdir(p))
    if not ds:
        print(f"[PLOT-ABL] No experiment directories match {pattern!r}; skipping {param_key}.")
        return

    _plot_ablation(ds, param_key, out_png, algos_keep)

    if not (table_csv or table_md or table_tex):
        return
    header, rows = _table_from_series(ds, param_key, algos_keep)
    if table_csv:
        _save_csv(table_csv, header, rows)
    if table_md:
        _save_md(table_md, header, rows)
    if table_tex:
        _save_tex(table_tex, header, rows, caption=caption, label=label)


def main():
    """CLI entry point: plot (and optionally tabulate) the eta and n_users ablations."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--glob_eta", type=str, default=None, help="Glob for eta experiments, e.g., 'out/abl_eta_*'")
    ap.add_argument("--glob_n", type=str, default=None, help="Glob for n experiments, e.g., 'out/abl_n_*'")
    ap.add_argument("--out_eta", type=str, default="out/ablations/abl_eta.png")
    ap.add_argument("--out_n", type=str, default="out/ablations/abl_n.png")
    ap.add_argument("--algos", type=str, default=None, help="Comma-separated algorithms to include")

    ap.add_argument("--table_eta_csv", type=str, default=None)
    ap.add_argument("--table_eta_md", type=str, default=None)
    ap.add_argument("--table_eta_tex", type=str, default=None)
    ap.add_argument("--table_n_csv", type=str, default=None)
    ap.add_argument("--table_n_md", type=str, default=None)
    ap.add_argument("--table_n_tex", type=str, default=None)

    args = ap.parse_args()
    # Drop blank tokens so "A,B," or "A, ,B" does not add an empty algorithm
    # name; an empty result means "no filtering".
    algos_keep = [tok.strip() for tok in args.algos.split(",") if tok.strip()] if args.algos else None

    _run_ablation(
        args.glob_eta, "eta", args.out_eta, algos_keep,
        args.table_eta_csv, args.table_eta_md, args.table_eta_tex,
        caption="Ablation over homophily strength $\\eta$ (final cumulative regret; mean$\\pm$SE).",
        label="tab:abl_eta",
    )
    _run_ablation(
        args.glob_n, "n_users", args.out_n, algos_keep,
        args.table_n_csv, args.table_n_md, args.table_n_tex,
        caption="Ablation over number of users $n$ (final cumulative regret; mean$\\pm$SE).",
        label="tab:abl_n",
    )

    print("[PLOT-ABL] Done.")


if __name__ == "__main__":
    main()
