# -*- coding: utf-8 -*-
import os
import re
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Number of consecutive `acc_test` entries that form one task block
# (the history is split into blocks of this size when computing metrics).
ctx_iter = 400
# Methods to report, in the order rows are produced for each learning rate.
methods = ["SGD", "MSGD", "Adam", "NGM-SGD"]

# (learning rate, path to the pickled `(avg_results, std_results)` pair).
LR_PATHS = [
    (1.0, "./data/cifar10_lr1_avg.pkl"),
    (0.1, "./data/cifar10_lr01_avg.pkl"),
    (0.01, "./data/cifar10_lr001_avg.pkl"),
    (0.001, "./data/cifar10_lr0001_avg.pkl"),
    (0.0001, "./data/cifar10_lr00001_avg.pkl"),
]

# Fixed colour-scale limits for the score heatmap; set either one to None to
# fall back to the data's nanmin / nanmax.
VMIN = 70
VMAX = 90

# Legacy method keys found in older pickles -> names used in `methods`.
RENAME_MAP = {
    "ADAM": "Adam",
    "ENTROPY GAIN": "NGM-SGD",
}

def rename_methods_in_pickle(pkl_path, rename_map=None):
    """Rename legacy method keys inside a pickled ``(avg_results, std_results)`` pair.

    For every ``old -> new`` entry of *rename_map* (default: the module-level
    ``RENAME_MAP``), key ``old`` is moved to ``new`` in each of the two dicts,
    unless ``new`` already exists there (then ``old`` is left untouched).

    The file is rewritten only when at least one key was actually renamed.
    The rewrite goes through a temporary file in the same directory followed
    by ``os.replace``, so an interrupted run cannot leave a truncated pickle.

    NOTE: uses ``pickle.load``; only call this on trusted, locally produced files.

    Args:
        pkl_path: path of the pickle file to migrate in place.
        rename_map: optional mapping of old -> new method names.

    Returns:
        True if the file was rewritten, False if nothing needed renaming.
    """
    import shutil
    import tempfile

    mapping = RENAME_MAP if rename_map is None else rename_map

    with open(pkl_path, "rb") as f:
        avg_results, std_results = pickle.load(f)

    changed = False
    for results in (avg_results, std_results):
        for old, new in mapping.items():
            if old in results and new not in results:
                results[new] = results.pop(old)
                changed = True

    if not changed:
        # Avoid touching the file (mtime, pickle protocol) when it is already migrated.
        return False

    target_dir = os.path.dirname(os.path.abspath(pkl_path))
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump((avg_results, std_results), f, protocol=pickle.HIGHEST_PROTOCOL)
        # mkstemp creates the file with 0600; keep the original file's permissions.
        shutil.copymode(pkl_path, tmp_path)
        os.replace(tmp_path, pkl_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original stays intact.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return True

# Migrate legacy method names in every configured result pickle that exists.
for pkl_path in (p for _lr, p in LR_PATHS if os.path.exists(p)):
    rename_methods_in_pickle(pkl_path)

def load_avg_std(pkl_path):
    """Load and return the ``(avg_results, std_results)`` pair pickled at *pkl_path*.

    The pickled object must unpack into exactly two items; otherwise the
    unpacking raises ``ValueError``.

    NOTE: uses ``pickle.load``; only call this on trusted, locally produced files.
    """
    with open(pkl_path, "rb") as handle:
        payload = pickle.load(handle)
    avg_results, std_results = payload
    return avg_results, std_results

def calc_ce_metrics_with_uncertainty(avg_hist, std_hist, ctx_iter):
    """Compute per-task accuracy metrics and their propagated uncertainties.

    The ``"acc_test"`` series of *avg_hist* (means) and *std_hist* (standard
    deviations) is split into consecutive blocks of ``ctx_iter`` entries;
    block ``t`` is reported as task ``t+1``.

    Metrics (same keys in both returned dicts):
      * ``T{k}_final``         -- last value of task k's block.
      * ``avg-ACC``            -- mean of all ``T{k}_final`` values.
      * ``SG_T{k}_to_T{k+1}``  -- relative drop from task k's final value to
        the minimum reached in task k+1's block: ``(final_k - min_{k+1}) / final_k``
        (NaN when ``final_k`` is 0).
      * ``Average_SG``         -- mean of the ``SG_*`` values.

    Uncertainties are first-order (delta-method) propagations of the given
    standard deviations, treating the individual entries as independent.

    Args:
        avg_hist: mapping with an ``"acc_test"`` sequence of mean accuracies.
        std_hist: mapping with an ``"acc_test"`` sequence of standard
            deviations; if missing or empty, all deviations are taken as 0.
        ctx_iter: number of ``acc_test`` entries per task block (> 0).

    Returns:
        ``(m_mean, m_std)`` dicts keyed as above, or ``({}, {})`` when fewer
        than two task blocks are available.

    Raises:
        ValueError: if ``ctx_iter`` is not positive, if the series length is
            not a multiple of ``ctx_iter``, or if a non-empty std series has a
            different length than the mean series.
    """
    if ctx_iter <= 0:
        raise ValueError(f"ctx_iter must be positive, got {ctx_iter}")

    acc_mean = list(avg_hist.get("acc_test", []))
    acc_std = list(std_hist.get("acc_test", []))
    total = len(acc_mean)
    if total % ctx_iter != 0:
        raise ValueError(f"len(acc_test)={total} is not a multiple of ctx_iter={ctx_iter}")
    if not acc_std:
        # No deviations recorded: treat every point as exact.
        acc_std = [0.0] * total
    elif len(acc_std) != total:
        # A shorter/longer std series would silently pair wrong indices.
        raise ValueError(
            f"len(std acc_test)={len(acc_std)} does not match len(acc_test)={total}"
        )

    num_tasks = total // ctx_iter
    if num_tasks < 2:
        return {}, {}

    m_mean, m_std = {}, {}
    finals_mean, finals_std = [], []
    mins_mean, mins_std = [], []

    for t in range(num_tasks):
        lo, hi = t * ctx_iter, (t + 1) * ctx_iter
        block_mean = acc_mean[lo:hi]
        block_std = acc_std[lo:hi]

        f_mean = float(block_mean[-1])
        f_std = float(block_std[-1])
        m_mean[f"T{t+1}_final"] = f_mean
        m_std[f"T{t+1}_final"] = f_std

        # The block minimum is used for the SG of the *previous* task.
        idx_min = int(np.argmin(block_mean))
        finals_mean.append(f_mean)
        finals_std.append(f_std)
        mins_mean.append(float(block_mean[idx_min]))
        mins_std.append(float(block_std[idx_min]))

    m_mean["avg-ACC"] = float(np.mean(finals_mean))
    # Std of a mean of n independent terms: sqrt(sum s_i^2) / n.
    m_std["avg-ACC"] = float(np.sqrt(np.sum(np.square(finals_std))) / num_tasks)

    sg_list, sg_std_list = [], []
    for i in range(num_tasks - 1):
        fin, fin_s = finals_mean[i], finals_std[i]
        low, low_s = mins_mean[i + 1], mins_std[i + 1]
        if fin == 0:
            # Relative drop is undefined; also avoids dividing by zero below.
            sg_mean = float("nan")
            sg_std = float("nan")
        else:
            sg_mean = (fin - low) / fin
            # SG = 1 - low/fin, so dSG/dfin = low/fin**2 and dSG/dlow = -1/fin.
            # fin appears in numerator and denominator, so it must not be
            # propagated as two independent terms.
            sg_std = float(np.hypot(low_s, low * fin_s / fin) / abs(fin))

        k = f"SG_T{i+1}_to_T{i+2}"
        m_mean[k] = float(sg_mean)
        m_std[k] = float(sg_std)
        sg_list.append(float(sg_mean))
        sg_std_list.append(float(sg_std))

    m_mean["Average_SG"] = float(np.mean(sg_list))
    m_std["Average_SG"] = float(np.sqrt(np.sum(np.square(sg_std_list))) / len(sg_std_list))

    return m_mean, m_std

def compute_score(avg_acc, std_acc, avg_sg, std_sg, acc_scale_is_percent=True):
    """Return a pessimistic accuracy score: ``avg_acc - std_acc``.

    Only the accuracy mean and its standard deviation influence the result.
    ``avg_sg``, ``std_sg`` and ``acc_scale_is_percent`` are accepted so the
    call signature stays stable, but the current rule ignores them.
    """
    # One standard deviation below the mean.
    return avg_acc - std_acc

def sg_sort_key(s):
    """Sort key for ``SG_T{i}_to_T{j}`` names: the integer pair ``(i, j)``."""
    i, j = map(int, re.findall(r"\d+", s))
    return (i, j)


# One row per (method, learning rate). The sets collect every per-task /
# per-transition column seen in any row, for the final column ordering.
rows = []
all_T_cols = set()
all_SG_cols = set()

for lr, pkl_path in LR_PATHS:
    if not os.path.exists(pkl_path):
        print(f"[WARN] Missing file: {pkl_path}")
        continue

    avg_lr, std_lr = load_avg_std(pkl_path)

    present = [m for m in methods if m in avg_lr]
    if not present:
        print(f"[WARN] No known methods found in {pkl_path}")
        continue

    # Accuracy scale (percent vs. fraction) is detected from the first known
    # method that actually has test-accuracy data; np.nanmax raises on an
    # empty array.
    first_acc = next(
        (
            acc
            for acc in (np.array(avg_lr[m].get("acc_test", [])) for m in present)
            if acc.size
        ),
        None,
    )
    if first_acc is None:
        print(f"[WARN] No acc_test data found in {pkl_path}")
        continue
    acc_scale_is_percent = np.nanmax(first_acc) > 2.0

    for m in present:
        avg_hist = avg_lr[m]
        # A missing std entry is passed as an empty mapping; the metric
        # function then treats every deviation as 0.
        std_hist = std_lr.get(m, {})

        m_mean, m_std = calc_ce_metrics_with_uncertainty(avg_hist, std_hist, ctx_iter)
        if not m_mean:
            continue  # fewer than two task blocks

        score = compute_score(
            m_mean["avg-ACC"], m_std["avg-ACC"],
            m_mean["Average_SG"], m_std["Average_SG"],
            acc_scale_is_percent,
        )

        row = {
            "method": m,
            "lr": lr,
            "avg-ACC": m_mean["avg-ACC"], "avg-ACC_std": m_std["avg-ACC"],
            "Average_SG": m_mean["Average_SG"], "Average_SG_std": m_std["Average_SG"],
            "score_raw": score,
        }

        T_keys = sorted(
            (k for k in m_mean if k.startswith("T") and k.endswith("_final")),
            key=lambda s: int(re.findall(r"\d+", s)[0]),
        )
        for tk in T_keys:
            row[tk] = m_mean[tk]
            row[f"{tk}_std"] = m_std[tk]
            all_T_cols.update((tk, f"{tk}_std"))

        SG_keys = sorted(
            (k for k in m_mean if re.match(r"^SG_T\d+_to_T\d+$", k)),
            key=sg_sort_key,
        )
        for sk in SG_keys:
            row[sk] = m_mean[sk]
            row[f"{sk}_std"] = m_std[sk]
            all_SG_cols.update((sk, f"{sk}_std"))

        rows.append(row)

df = pd.DataFrame(rows)

# Summary columns first, then per-task finals, then per-transition SG values.
base_cols = [
    "method",
    "lr",
    "avg-ACC",
    "avg-ACC_std",
    "Average_SG",
    "Average_SG_std",
    "score_raw",
]

# "T{k}_final" before "T{k}_final_std", ordered by task number k.
T_cols = sorted(
    all_T_cols,
    key=lambda col: (int(re.findall(r"\d+", col)[0]), "_std" in col),
)
# "SG_T{i}_to_T{j}" before its "_std" twin, ordered by (i, j).
SG_cols = sorted(
    all_SG_cols,
    key=lambda col: tuple(int(n) for n in re.findall(r"\d+", col)[:2]) + (col.endswith("_std"),),
)

ordered_cols = [*base_cols, *T_cols, *SG_cols]
df = df.reindex(columns=ordered_cols)

# Save to Excel
out_xlsx = "metrics_by_method_lr.xlsx"
df.to_excel(out_xlsx, index=False)
print(f"[OK] table saved to {out_xlsx}")

# ========= Heatmap =========
# Rows follow the configured `methods` order (pivot() alone would sort them
# alphabetically), columns are learning rates in ascending order, and the
# cell colour is `score_raw`.
if df.empty:
    # No result files / no methods with >= 2 tasks: nothing to plot.
    print("[WARN] No metrics to plot; skipping heatmap.")
else:
    piv = df.pivot(index="method", columns="lr", values="score_raw")
    piv = piv.reindex(
        index=[m for m in methods if m in piv.index],
        columns=sorted(piv.columns),
    )

    plt.figure(figsize=(10, 4.5))
    im = plt.imshow(
        piv.values,
        origin="lower",
        cmap="viridis",
        vmin=VMIN if VMIN is not None else np.nanmin(piv.values),
        vmax=VMAX if VMAX is not None else np.nanmax(piv.values),
        aspect="equal",
    )
    cbar = plt.colorbar(im)
    cbar.set_label("score")
    plt.xticks(ticks=range(len(piv.columns)), labels=[str(c) for c in piv.columns])
    plt.yticks(ticks=range(len(piv.index)), labels=list(piv.index))
    plt.xlabel("learning rate")
    plt.ylabel("method")
    plt.tight_layout()
    # plt.savefig("heatmap_score.svg", format="svg", dpi=600, bbox_inches="tight")
    plt.show()