import argparse, json, sys, re
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

def normalize_dataset_name(s: str, fallback_from_path: str = "") -> str:
    """Map a free-form dataset label onto a canonical short name.

    The label is lower-cased and every character outside ``[a-z0-9]`` is
    removed. Known CIFAR / EuroSAT spellings collapse to ``"cifar"`` /
    ``"eurosat"``; any other non-empty result is returned unchanged.

    If ``s`` is not a string, or nothing survives the cleanup, the name is
    inferred from ``fallback_from_path`` instead ("cifar" is checked before
    "eurosat"). ``"unknown"`` is returned when no hint is found.
    """
    aliases = {
        "cifar10": "cifar",
        "cifar010": "cifar",
        "cifar": "cifar",
        "eurosat": "eurosat",
        "eurosatallbands": "eurosat",
    }
    key = ""
    if isinstance(s, str) and s.strip():
        key = re.sub(r"[^a-z0-9]+", "", s.strip().lower())
    if key:
        return aliases.get(key, key)

    # No usable label: look for a dataset hint in the path instead.
    hint = fallback_from_path.lower()
    for needle in ("cifar", "eurosat"):
        if needle in hint:
            return needle
    return "unknown"

def pick_col(df, candidates):
    """Return the first name in *candidates* that is a column of *df*, else ``None``."""
    return next((name for name in candidates if name in df.columns), None)

def safe_kl(p, q, eps=1e-12):
    """Return the KL divergence D(p || q) in nats as a Python float.

    Both inputs are clipped into ``[eps, 1]`` and renormalized to sum to 1
    before the divergence is taken, so zero entries do not produce infinities.
    """
    def _as_distribution(v):
        # Clip first, then renormalize so the clipped vector is a distribution.
        clipped = np.clip(v, eps, 1.0)
        return clipped / clipped.sum()

    p_dist = _as_distribution(p)
    q_dist = _as_distribution(q)
    log_ratio = np.log(p_dist) - np.log(q_dist)
    return float(np.sum(p_dist * log_ratio))

def jsd(p, q, eps=1e-12):
    """Return the Jensen–Shannon divergence between *p* and *q* (natural log).

    Computed as the mean of the KL divergences of *p* and *q* from their
    midpoint mixture; ``eps`` is forwarded to ``safe_kl`` for zero-clipping.
    """
    mixture = 0.5 * (p + q)
    kl_p = safe_kl(p, mixture, eps)
    kl_q = safe_kl(q, mixture, eps)
    return 0.5 * kl_p + 0.5 * kl_q

def tvd(p, q):
    """Return the total variation distance, half the L1 norm of ``p - q``."""
    l1 = float(np.abs(p - q).sum())
    return 0.5 * l1

def load_rows(csv_path: Path) -> pd.DataFrame:
    """Parse one results CSV into per-row distance metrics.

    Column names are matched against small alias lists (first match wins).
    Each input row must carry JSON-encoded lists for the true and measured
    probability vectors plus a shot count. The dataset label is optional; when
    it is absent or blank, the name is inferred from the file path.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        DataFrame with one row per successfully parsed input row and columns
        ``dataset``, ``shots``, ``tvd`` and ``jsd``. An empty DataFrame is
        returned (after printing an error) when the file is missing,
        unreadable, or lacks the required columns. Rows whose vectors fail to
        parse, are not 1-D, differ in length, are empty, or contain non-finite
        values are skipped and counted in the summary line.
    """
    if not csv_path.exists():
        print(f"[plot] ERROR: file not found: {csv_path}", flush=True)
        return pd.DataFrame()
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:  # file boundary: report any reader failure and move on
        print(f"[plot] ERROR reading {csv_path}: {e}", flush=True)
        return pd.DataFrame()

    print(f"[plot] Loaded {csv_path} with {len(df)} rows and columns: {list(df.columns)}", flush=True)

    # figure out columns
    p_true_col = pick_col(df, ["p_true_json", "p_true", "ptrue_json", "ptrue"])
    p_meas_col = pick_col(df, ["p_meas_json", "p_meas", "pmeas_json", "pmeas", "phat_json", "phat"])
    shots_col = pick_col(df, ["shots", "num_shots", "nshots"])
    dataset_col = pick_col(df, ["dataset", "data_set", "ds"])

    if not p_true_col or not p_meas_col:
        print(f"[plot] ERROR: Could not find p_true/p_meas JSON columns in {csv_path}.", flush=True)
        return pd.DataFrame()
    if not shots_col:
        print(f"[plot] ERROR: Could not find shots column in {csv_path}.", flush=True)
        return pd.DataFrame()

    fallback = str(csv_path)
    dataset_values = df[dataset_col] if dataset_col else [""] * len(df)

    rows = []
    skipped = 0
    first_error = None
    # Zip only the needed columns rather than building a Series per row.
    for raw_true, raw_meas, raw_shots, raw_dset in zip(
        df[p_true_col], df[p_meas_col], df[shots_col], dataset_values
    ):
        try:
            p_true = np.array(json.loads(raw_true), dtype=float)
            p_meas = np.array(json.loads(raw_meas), dtype=float)
            shots = int(raw_shots)
        except (ValueError, TypeError, OverflowError) as e:
            # Malformed JSON raises ValueError; blank cells arrive as float NaN,
            # which json.loads rejects (TypeError) and int() rejects (ValueError).
            skipped += 1
            if first_error is None:
                first_error = f"{type(e).__name__}: {e}"
            continue
        if (
            p_true.ndim != 1
            or p_true.shape != p_meas.shape
            or p_true.size == 0
            or not np.isfinite(p_true).all()
            or not np.isfinite(p_meas).all()
        ):
            skipped += 1
            continue
        # A blank dataset cell is read as NaN, and str(NaN) == "nan" would
        # become a bogus dataset name; treat it as missing so the file-path
        # fallback applies.
        label = "" if pd.isna(raw_dset) else str(raw_dset)
        rows.append({"dataset": normalize_dataset_name(label, fallback),
                     "shots": shots,
                     "tvd": tvd(p_meas, p_true),
                     "jsd": jsd(p_meas, p_true)})

    out = pd.DataFrame(rows)
    print(f"[plot] Parsed {len(out)} rows (skipped {skipped}) from {csv_path}", flush=True)
    if first_error is not None:
        print(f"[plot]  first parse error: {first_error}", flush=True)
    if len(out):
        print(f"[plot] Datasets found: {sorted(out['dataset'].unique())}", flush=True)
        for d in sorted(out['dataset'].unique()):
            ss = sorted(out[out['dataset'] == d]['shots'].unique())
            print(f"[plot]  {d}: shots groups = {ss}", flush=True)
    return out

def plot_metric_vs_shots(df_all, metric, out_prefix):
    """Plot mean ± std of *metric* against shot count, one figure per dataset.

    For every dataset in ``df_all`` except ``"unknown"``, rows are grouped by
    ``shots`` and the mean/std/count of the *metric* column are computed.
    Writes ``{out_prefix}_{metric}_vs_shots_{dataset}.pdf`` and a matching
    ``..._summary.csv`` holding the aggregated table, sorted by shots.

    Args:
        df_all: DataFrame with at least ``dataset``, ``shots`` and *metric* columns.
        metric: Column to aggregate, e.g. ``"tvd"`` or ``"jsd"``. Unrecognized
            metrics use the column name as the y-axis label.
        out_prefix: Prefix for output file names. It may contain directories;
            any missing parent directories are created.
    """
    ylabels = {
        "tvd": "Total Variation Distance (TVD)",
        "jsd": "Jensen–Shannon divergence",
    }
    ylabel = ylabels.get(metric, metric)
    any_written = False
    # Scope the font size to these figures instead of mutating global rcParams.
    with plt.rc_context({"font.size": 16}):
        for dset, dfD in df_all.groupby("dataset"):
            if dset == "unknown":
                print("[plot] WARNING: dataset 'unknown' — filenames lacked hints; skipping.", flush=True)
                continue
            g = (
                dfD.groupby("shots")[metric]
                .agg(["mean", "std", "count"])
                .reset_index()
                .sort_values("shots")
            )
            if g.empty:
                print(f"[plot] WARNING: no data for dataset={dset}, metric={metric}", flush=True)
                continue
            marker = "." if dset == "cifar" else "|"

            out_pdf = Path(f"{out_prefix}_{metric}_vs_shots_{dset}.pdf")
            out_pdf.parent.mkdir(parents=True, exist_ok=True)
            fig, ax = plt.subplots()
            try:
                ax.errorbar(g["shots"].to_numpy(), g["mean"].to_numpy(),
                            yerr=g["std"].to_numpy(), marker=marker, capsize=4,
                            linestyle="-", linewidth=1.5)
                ax.set_xlabel("shots")
                ax.set_ylabel(ylabel)
                fig.tight_layout()
                fig.savefig(out_pdf, bbox_inches="tight")
            finally:
                # Release the figure; otherwise one figure per dataset/metric
                # stays open for the life of the process.
                plt.close(fig)
            print(f"[plot] wrote {out_pdf}", flush=True)
            any_written = True

            # also dump summary csv
            out_csv = Path(f"{out_prefix}_{metric}_vs_shots_{dset}_summary.csv")
            g.to_csv(out_csv, index=False)
            print(f"[plot] wrote {out_csv}", flush=True)

    if not any_written:
        print("[plot] NOTE: no plots were written; check dataset names/columns above.", flush=True)

def main():
    """Command-line entry point.

    Loads every ``--csv`` file with ``load_rows``, concatenates the parsed
    rows, and writes TVD and JSD plots with ``--out_prefix``. Exits with
    status 1 when no file yields any usable row.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", nargs="+", required=True, help="one or more CSV files")
    ap.add_argument("--out_prefix", default="plots", help="prefix for output filenames")
    args = ap.parse_args()

    # Drop files that produced nothing (load_rows returns an empty,
    # column-less frame on errors). They contribute no rows, and dropping them
    # keeps pd.concat from mixing in empty inputs.
    frames = [
        frame
        for frame in (load_rows(Path(p).expanduser()) for p in args.csv)
        if not frame.empty
    ]
    if not frames:
        print("[plot] ERROR: combined dataframe is empty; check logs above.", flush=True)
        sys.exit(1)
    df_all = pd.concat(frames, ignore_index=True)

    for metric in ("tvd", "jsd"):
        plot_metric_vs_shots(df_all, metric, args.out_prefix)

if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()