import os, csv
import matplotlib.pyplot as plt
from utils import config_id, results_paths

def _load_round_csv(path):
    rounds, loss, acc = [], [], []
    rc, re, rs, cc, ce, cs, ct = [], [], [], [], [], [], []
    with open(path, "r") as f:
        r = csv.DictReader(f)
        for row in r:
            rounds.append(int(row["round"]))
            loss.append(float(row["loss"]))
            acc.append(float(row["acc"]))
            rc.append(float(row["round_client_time_sec"]))
            re.append(float(row["round_each_client_time_sec"]))
            rs.append(float(row["round_server_time_sec"]))
            cc.append(float(row["cum_client_time_sec"]))
            ce.append(float(row["cum_each_client_time_sec"]))
            cs.append(float(row["cum_server_time_sec"]))
            ct.append(float(row["cum_time_sec"]))
    return {
        "round": rounds, "loss": loss, "acc": acc,
        "round_client_time_sec": rc, "round_each_client_time_sec": re, "round_server_time_sec": rs,
        "cum_client_time_sec": cc,  "cum_each_client_time_sec": ce,  "cum_server_time_sec": cs,
        "cum_time_sec": ct
    }

def _csv_path_from_cfg(cfg):
    """Resolve a run config to its experiment id and per-round CSV path.

    Returns
    -------
    tuple
        ``(eid, path)`` where ``eid`` is ``config_id(cfg)`` and ``path`` is
        ``results_paths(cfg)["round_csv"]`` when that file exists, or ``None``
        when it does not.
    """
    eid = config_id(cfg)
    path = results_paths(cfg)["round_csv"]
    # The parentheses matter: without them Python parses this as
    # ``eid, (path if exists else (eid, None))``, so a missing file produced a
    # truthy nested tuple instead of ``None`` and callers tried to open it.
    return eid, (path if os.path.exists(path) else None)

def aggregate_group(cfg_list):
    """Aggregate per-round metrics across a group of runs.

    Each config is resolved with ``_csv_path_from_cfg``; configs whose CSV path
    comes back falsy are skipped, the rest are loaded with ``_load_round_csv``.
    Every loaded run is truncated to the length of the shortest one, and the
    "round" column is taken from the first loaded run.

    Returns
    -------
    dict
        ``round``, ``acc_mean``, ``acc_std`` (population std, ddof=0),
        ``loss_mean``, ``round_client_mean``, ``round_server_mean``,
        ``cum_time_mean``, ``cum_client_mean``, ``cum_server_mean`` (all plain
        lists) and ``n_runs`` (number of runs actually loaded).

    Raises
    ------
    FileNotFoundError
        If none of the configs resolved to a loadable CSV.
    """
    import numpy as np

    # Lazy generator: path resolution and loading stay interleaved per config.
    csv_paths = (_csv_path_from_cfg(cfg)[1] for cfg in cfg_list)
    loaded = [_load_round_csv(csv_path) for csv_path in csv_paths if csv_path]
    if len(loaded) == 0:
        raise FileNotFoundError("no CSVs found for given cfg_list")

    # Runs may have logged different numbers of rounds; keep the common prefix.
    n_rounds = min(len(run["round"]) for run in loaded)

    def _matrix(column):
        # One row per run, one column per round.
        return np.stack([run[column][:n_rounds] for run in loaded], axis=0)

    acc = _matrix("acc")
    summary = {
        "round": loaded[0]["round"][:n_rounds],
        "acc_mean": acc.mean(0).tolist(),
        "acc_std": acc.std(0, ddof=0).tolist(),
    }
    mean_columns = (
        ("loss_mean", "loss"),
        ("round_client_mean", "round_client_time_sec"),
        ("round_server_mean", "round_server_time_sec"),
        ("cum_time_mean", "cum_time_sec"),
        ("cum_client_mean", "cum_client_time_sec"),
        ("cum_server_mean", "cum_server_time_sec"),
    )
    for out_key, column in mean_columns:
        summary[out_key] = _matrix(column).mean(0).tolist()
    summary["n_runs"] = len(loaded)
    return summary


def print_group_stats(stats, label=None):
    """Print a per-round summary table for one aggregated group.

    ``stats`` needs ``"round"``, ``"acc_mean"``, ``"acc_std"`` and
    ``"cum_time_mean"`` sequences (rows stop at the shortest of them).
    ``"n_runs"`` is optional and only used to build the default header when
    ``label`` is falsy.
    """
    header = label if label else f"{stats.get('n_runs', '?')} runs"
    print(f"[stats] {header}")
    rows = zip(stats["round"], stats["acc_mean"], stats["acc_std"], stats["cum_time_mean"])
    for rnd, mean_acc, std_acc, mean_time in rows:
        line = f"  round {rnd:>3d} | acc={mean_acc:6.2f}% ± {std_acc:4.2f} | cum_time={mean_time:7.2f}s"
        print(line)

def plot_groups(list_of_cfg_lists, labels=None,
                title="Group Comparison", plot_name=None,
                show=True, save_path=None):
    """Plot mean accuracy and mean cumulative time per round for several groups.

    Each group is aggregated with ``aggregate_group`` and its per-round stats
    are printed with ``print_group_stats`` before plotting.

    Parameters
    ----------
    list_of_cfg_lists : list of list
        ``[[cfg1, cfg2, ...], [cfgA, cfgB, ...], ...]``: one inner list per group.
    labels : sequence of str, optional
        Legend label per group. If omitted, or shorter than the number of
        groups, the missing labels default to ``"group<i>"`` (1-based).
    title : str
        Prefix for both subplot titles.
    plot_name : str, optional
        File name (without extension) to save as ``./plots/<plot_name>.png``.
        Takes precedence over ``save_path``. If neither is given, nothing is saved.
    show : bool
        If True, call ``plt.show()``; otherwise close the figure.
    save_path : str, optional
        Explicit output file path, used only when ``plot_name`` is not given.

    Raises
    ------
    FileNotFoundError
        Propagated from ``aggregate_group`` when a group has no CSVs.
    """
    # Fill in default labels for any groups the caller did not label; a short
    # labels list previously raised IndexError.
    given = list(labels) if labels else []
    labels = given + [f"group{i+1}" for i in range(len(given), len(list_of_cfg_lists))]

    group_stats = []
    for cfgs, lab in zip(list_of_cfg_lists, labels):
        st = aggregate_group(cfgs)
        group_stats.append(st)
        print_group_stats(st, label=lab)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5), dpi=120)

    # (axis, stats key, title suffix, y-axis label) for each panel.
    panels = (
        (axes[0], "acc_mean", "Accuracy", "Accuracy (%)"),
        (axes[1], "cum_time_mean", "Cumulative Time", "Time (s)"),
    )
    for ax, key, subtitle, ylabel in panels:
        for st, lab in zip(group_stats, labels):
            ax.plot(
                st["round"], st[key],
                linestyle="-", linewidth=2, marker="o", markersize=3,
                label=lab
            )
        ax.set_title(f"{title} — {subtitle}")
        ax.set_xlabel("Round")
        ax.set_ylabel(ylabel)
        ax.grid(True)
        ax.legend(loc="best")

    out_path = None
    if plot_name:
        os.makedirs("./plots", exist_ok=True)
        out_path = os.path.join("./plots", f"{plot_name}.png")
    elif save_path:
        out_dir = os.path.dirname(save_path)
        # os.makedirs("") raises FileNotFoundError, so only create a directory
        # when save_path actually contains one (not for a bare "plot.png").
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        out_path = save_path

    if out_path:
        fig.savefig(out_path, bbox_inches="tight", dpi=300)
        print(f"[viz] saved plot to {out_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)

