import argparse, json, random
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def infer_cfg(folder: Path):
    """Pick sampler hyper-parameters from the naming of the files in *folder*.

    The text before the first underscore of one ``*.json`` file name
    (lower-cased) is used as the dataset tag. Which file is inspected is
    whatever ``Path.glob`` yields first, so the folder is assumed to hold
    files of a single naming scheme.

    Args:
        folder: Directory containing the ``*.json`` files.

    Returns:
        A dict with the keys ``bins``, ``lam_k`` and ``eps_k``, suitable for
        passing to ``balanced_sample``.

    Raises:
        FileNotFoundError: If *folder* contains no ``*.json`` file.
    """
    # A default on next() avoids leaking a bare StopIteration to the caller.
    first = next(folder.glob("*.json"), None)
    if first is None:
        raise FileNotFoundError(f"No *.json files found in {folder}")

    prefix = first.name.split("_")[0].lower()
    if prefix == "ti2i":
        return dict(bins=10, lam_k=1.8, eps_k=0.05)
    # Every other prefix falls back to the second preset.
    return dict(bins=7, lam_k=1.5, eps_k=0.10)


def load_folder(folder: Path, key: str = "element_count") -> pd.DataFrame:
    """Read every ``*.json`` file in *folder* into a two-column table.

    Each file contributes one row. If the JSON document is a list, the value
    is its length; if it is a dict containing *key*, the value is
    ``data[key]``.

    Args:
        folder: Directory to scan (non-recursive).
        key: Name of the value column, and the dict key looked up in
            dict-shaped JSON documents.

    Returns:
        A DataFrame with the columns ``"__file"`` (file name) and *key*,
        with rows ordered by path. The columns are present even when the
        folder has no JSON files.

    Raises:
        ValueError: If a file is neither a list nor a dict containing *key*.
    """
    records = []
    # glob() yields files in arbitrary, filesystem-dependent order; sorting
    # keeps the row order (and any seeded sampling based on it) reproducible.
    for fp in sorted(folder.glob("*.json")):
        with open(fp, "r", encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, list):
            elem_cnt = len(data)
        elif isinstance(data, dict) and key in data:
            elem_cnt = data[key]
        else:
            raise ValueError(f"Unsupported JSON format in {fp}")
        records.append({"__file": fp.name, key: elem_cnt})
    # Explicit columns so that an empty folder still yields the expected schema.
    return pd.DataFrame(records, columns=["__file", key])


def dataset_stats(df: pd.DataFrame, key: str = "element_count") -> dict:
    """Summarise the *key* column of *df* as plain Python numbers.

    Args:
        df: Table holding the column to summarise.
        key: Column name.

    Returns:
        A dict, in this order: ``count`` (int), then ``mean``, ``std``
        (population, ``ddof=0``), ``min``, ``max`` and the quantiles
        ``q05``, ``q25``, ``q50``, ``q75``, ``q95`` (all float).
    """
    column = df[key].to_numpy()

    summary = dict(
        count=int(column.size),
        mean=float(column.mean()),
        std=float(column.std(ddof=0)),
        min=float(column.min()),
        max=float(column.max()),
    )
    # Quantiles are appended after the moments so the key order stays stable
    # for anything that serialises the dict.
    quantile_levels = (
        ("q05", 0.05),
        ("q25", 0.25),
        ("q50", 0.50),
        ("q75", 0.75),
        ("q95", 0.95),
    )
    for label, level in quantile_levels:
        summary[label] = float(np.quantile(column, level))
    return summary

def balanced_sample(
    df: pd.DataFrame,
    n: int,
    key: str = "element_count",
    bins: int = 5,
    seed: int = 42,
    lam_k: float = 1.0,      # <<< λ
    eps_k: float = 0.10      # <<< ε
) -> list[str]:
    """Draw ``n`` file names whose *key* statistics track the population.

    The rows are split into quantile bins on *key*, and ``n`` is divided as
    evenly as possible over the bins that actually exist. If the bins cannot
    supply ``n`` rows (tied values merge bins, or a bin is smaller than its
    quota), the shortfall is filled at random from the remaining rows. A
    greedy swap step then tries to lower

        J = |mean_s - mean_pop| + lam * |std_s - std_pop|

    where ``lam = lam_k * std_pop / (mean_pop + 1e-6)``. Swapping stops as
    soon as ``J <= eps_k * std_pop * 20 / n`` or a random swap fails to
    improve ``J``.

    Args:
        df: Table with a ``"__file"`` column and the *key* column. The index
            is assumed to be unique.
        n: Requested sample size.
        key: Column used for stratification and for the objective.
        bins: Upper bound on the number of quantile bins (capped at ``n``).
        seed: Seed of the private ``random.Random`` instance; the same seed
            and table give the same sample.
        lam_k: Scale of the std-mismatch weight (λ).
        eps_k: Scale of the acceptance tolerance (ε).

    Returns:
        ``min(n, len(df))`` distinct values of the ``"__file"`` column; an
        empty list when ``n <= 0`` or *df* is empty.
    """
    # Nothing to draw; also keeps the 20 / n term below well defined.
    if n <= 0 or df.empty:
        return []

    rng = random.Random(seed)

    mu_pop = df[key].mean()
    sig_pop = df[key].std(ddof=0)
    lam = lam_k * sig_pop / (mu_pop + 1e-6)           # <<< lam_k
    eps = eps_k * sig_pop * (20 / n)                  # <<< eps_k

    # stratification
    df = df.copy()
    bins = max(1, min(bins, n))
    df["bin"] = pd.qcut(df[key], q=bins, labels=False, duplicates="drop")

    # duplicates="drop" may return fewer bins than requested, so quotas are
    # split over the bin codes that are really present.
    codes = sorted(df["bin"].dropna().unique())
    chosen = []
    if codes:
        base, rem = divmod(n, len(codes))
        for pos, code in enumerate(codes):
            idx = df.index[df["bin"] == code].tolist()
            rng.shuffle(idx)
            take = base + (1 if pos < rem else 0)
            chosen.extend(idx[:take])

    # Top up when some bin held fewer rows than its quota.
    target = min(n, len(df))
    if len(chosen) < target:
        taken = set(chosen)
        spare = [i for i in df.index if i not in taken]
        rng.shuffle(spare)
        chosen.extend(spare[:target - len(chosen)])

    sub = df.loc[chosen]
    # Pool is built in index order (not set order) so that the seed alone
    # determines which swap candidates are drawn.
    taken = set(chosen)
    not_chosen = [i for i in df.index if i not in taken]

    def J(sdf: pd.DataFrame) -> float:
        """Mean gap plus weighted std gap between *sdf* and the population."""
        return abs(sdf[key].mean() - mu_pop) + \
               lam * abs(sdf[key].std(ddof=0) - sig_pop)

    score = J(sub)
    while score > eps and not_chosen:
        out_idx = rng.choice(sub.index.tolist())
        in_idx = rng.choice(not_chosen)
        new_sub = pd.concat([sub.drop(out_idx), df.loc[[in_idx]]])
        new_score = J(new_sub)
        if new_score < score:
            sub, score = new_sub, new_score
            not_chosen.remove(in_idx)
            not_chosen.append(out_idx)
        else:
            # First non-improving swap ends the refinement.
            break
    return sub["__file"].tolist()


def verify_folder(folder: Path, R: int, seed: int):
    """Monte-Carlo check of ``balanced_sample`` on the dataset in *folder*.

    Steps:
      1. Load the folder, print summary statistics and save them to
         ``metrics_<folder.name>_meta.json``.
      2. For each batch size in ``[5, 6, 10, 12, 15, 20]`` draw *R* samples
         (seeds ``seed`` .. ``seed + R - 1``) and record how far the batch
         means stray from the population mean; save the table to
         ``metrics_<folder.name>.csv``.
      3. Plot the population distribution, the stability curves and a
         coverage heat map over quantile bins to
         ``metrics_<folder.name>_f.png``.

    All output paths are relative, so files land in the current working
    directory.

    Args:
        folder: Directory with the ``*.json`` files.
        R: Number of Monte-Carlo repetitions per batch size.
        seed: Base seed; repetition ``i`` uses ``seed + i``.
    """
    cfg = infer_cfg(folder)
    df_meta = load_folder(folder)
    meta = dataset_stats(df_meta)

    print(f"\n=== Dataset summary: {folder.name} ===")
    print(f"count={meta['count']}  mean={meta['mean']:.2f}  std={meta['std']:.2f}  "
          f"min={meta['min']:.0f}  q25={meta['q25']:.0f}  med={meta['q50']:.0f}  "
          f"q75={meta['q75']:.0f}  max={meta['max']:.0f}\n")

    meta_path = f"metrics_{folder.name}_meta.json"
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
    print("Saved dataset meta to", meta_path)

    pop_mean = meta["mean"]
    # One-off file-name -> count map instead of a boolean-mask scan per pick.
    count_by_file = dict(zip(df_meta["__file"], df_meta["element_count"]))

    n_list, table = [5, 6, 10, 12, 15, 20], []
    # Element counts of every draw, kept so the heat map below can reuse them
    # rather than re-running the sampler with the same seeds.
    draws_per_n = []
    for n in n_list:
        draws = []
        batch_means = []
        for i in range(R):
            picks = balanced_sample(df_meta, n=n, seed=seed + i,
                                    bins=cfg["bins"],
                                    lam_k=cfg["lam_k"],
                                    eps_k=cfg["eps_k"])
            counts = [count_by_file[fn] for fn in picks]
            draws.append(counts)
            batch_means.append(np.mean(counts))
        draws_per_n.append(draws)
        batch_means = np.array(batch_means)
        table.append({
            "n": n,
            "delta_mean": abs(batch_means.mean() - pop_mean),
            "sigma_mean": batch_means.std(),
            "worst_gap":  abs(batch_means - pop_mean).max(),
            "pop_mean": meta["mean"],
            "pop_std":  meta["std"]
        })
        print(f"n={n:2d} | δ̄={table[-1]['delta_mean']:.2f}  "
              f"σ={table[-1]['sigma_mean']:.2f}  "
              f"worst={table[-1]['worst_gap']:.2f}")

    df_out = pd.DataFrame(table)
    csv_path = f"metrics_{folder.name}.csv"
    df_out.to_csv(csv_path, index=False)
    print("Saved MC metrics to", csv_path)

    bins_edges = pd.qcut(df_meta["element_count"], q=6,
                         duplicates="drop", retbins=True)[1]
    # duplicates="drop" can return fewer than 7 edges; size everything from
    # the edges actually produced so histogram and labels stay aligned.
    K = len(bins_edges) - 1
    heat_mat = np.zeros((len(n_list), K))
    for i, draws in enumerate(draws_per_n):
        bin_counts = np.zeros(K)
        for vals in draws:
            hist, _ = np.histogram(vals, bins=bins_edges)
            bin_counts += hist
        # Row-normalise; max(..., 1) guards against an all-zero row.
        heat_mat[i] = bin_counts / max(bin_counts.sum(), 1)

    fig, axs = plt.subplots(1, 3, figsize=(14, 4))
    counts_all = df_meta["element_count"]

    # Panel 1: population distribution.
    axs[0].hist(counts_all, bins="auto", density=True,
                color="steelblue", alpha=0.6, edgecolor="black", label="hist")
    counts_all.plot(kind="kde", ax=axs[0], lw=2,
                    color="darkorange", label="KDE")

    axs[0].axvline(meta["mean"], linestyle="--", linewidth=1.5, label="mean")
    axs[0].axvline(meta["mean"] - meta["std"], linestyle=":", linewidth=1.0)
    axs[0].axvline(meta["mean"] + meta["std"], linestyle=":", linewidth=1.0)
    axs[0].set_xlabel("element count"); axs[0].set_ylabel("density")
    axs[0].set_title(f"Distribution ({len(df_meta)})\nμ={meta['mean']:.2f}, σ={meta['std']:.2f}")
    axs[0].legend()

    # Panel 2: spread of batch means versus batch size.
    axs[1].plot(df_out["n"], df_out["sigma_mean"], marker="o",
                label=r"$\sigma_{\mathrm{mean}}$")
    axs[1].plot(df_out["n"], df_out["worst_gap"], marker="s",
                label="worst gap")
    axs[1].set_xlabel("batch size n"); axs[1].set_ylabel("difficulty gap")
    axs[1].set_title("Sampling stability")
    axs[1].legend(); axs[1].grid(alpha=0.3)

    # Panel 3: share of sampled items per population quantile bin.
    sns.heatmap(heat_mat, annot=True, fmt=".2f", cmap="YlGnBu",
                xticklabels=[f"{int(bins_edges[j])}-{int(bins_edges[j+1])}"
                             for j in range(K)],
                yticklabels=[str(n) for n in n_list], ax=axs[2])
    axs[2].set_xlabel("element-count bin"); axs[2].set_ylabel("batch size n")
    axs[2].set_title("Difficulty coverage")

    plt.tight_layout()
    out_png = f"metrics_{folder.name}_f.png"
    plt.savefig(out_png, dpi=600)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print("Saved figure to", out_png)

if __name__ == "__main__":
    # Command-line entry point: parse the options and run the verification.
    parser = argparse.ArgumentParser()
    option_table = (
        ("--folder", str, "datasets/t2i"),
        ("-R", int, 100),
        ("--seed", int, 48),
    )
    for flag, caster, fallback in option_table:
        parser.add_argument(flag, type=caster, default=fallback)
    cli = parser.parse_args()

    verify_folder(Path(cli.folder), R=cli.R, seed=cli.seed)


# python verify_sampler.py --folder T2I_GT -R 100
# python verify_sampler.py --folder TI2I_GT -R 100