# -*- coding: utf-8 -*-
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re

# Load the averaged (and std) results of the hyper-parameter sweep.
# NOTE(review): pickle.load can execute arbitrary code stored in the file --
# only point this at result files you produced yourself.
file_path = "./data/cifar10_hyperparamsweep_lr001_avg.pkl"
with open(file_path, "rb") as f:
    data = pickle.load(f)

# Two on-disk layouts are accepted: a dict with "avg"/"std" entries,
# or a 2-element (avg, std) list/tuple.
if isinstance(data, dict) and "avg" in data and "std" in data:
    avg_results = data["avg"]
    std_results = data["std"]
elif isinstance(data, (list, tuple)) and len(data) == 2:
    avg_results, std_results = data
else:
    raise ValueError("Unexpected pickle format")

# Number of consecutive "acc_test" entries that make up one task block.
ctx_iter = 400

#  Metrics 
def _std_of_mean(stds):
    """Return the std of the mean of independent quantities with the given stds."""
    stds = np.asarray(stds, dtype=float)
    return np.sqrt(np.sum(stds**2)) / len(stds)


def calc_ce_metrics_with_uncertainty(avg_hist, std_hist, ctx_iter):
    """Compute per-task and aggregate metrics together with std estimates.

    ``avg_hist["acc_test"]`` (mean accuracy) and ``std_hist["acc_test"]`` (its
    std) are split into consecutive blocks of ``ctx_iter`` entries, one block
    per task.  For every block the last value ("final") and the lowest mean
    value ("min", std taken at the same step) are recorded.  Aggregates:

    * ``avg-min-ACC``  -- mean of the block minima of all tasks but the last.
    * ``avg-ACC``      -- mean of the block finals.
    * ``WC-ACC``       -- ``(1/n) * final_last + (1 - 1/n) * avg-min-ACC``.
    * ``SG_Ti_to_Tj``  -- relative drop ``(final_i - min_j) / final_i`` between
      consecutive tasks; NaN (mean and std) when ``final_i`` is 0.
    * ``Average_SG``   -- mean of the SG values.

    Stds of averages assume independent errors (``sqrt(sum(std**2)) / count``).
    The SG std uses first-order propagation of ``1 - min_j / final_i`` with
    ``min_j`` and ``final_i`` treated as independent.

    Args:
        avg_hist: mapping with an ``"acc_test"`` sequence of mean accuracies.
        std_hist: mapping with an ``"acc_test"`` sequence of stds aligned with
            ``avg_hist["acc_test"]``; if absent or empty, all stds are 0.
        ctx_iter: number of entries per task block.

    Returns:
        ``(m_mean, m_std, finals_mean, finals_std, sg_keys, num_tasks)``.  With
        fewer than two task blocks the first five items are empty.

    Raises:
        ValueError: if the accuracy series length is not a multiple of
            ``ctx_iter``, or the std series length differs from it.
    """
    acc_mean = avg_hist.get("acc_test", [])
    acc_std = std_hist.get("acc_test", [])
    total = len(acc_mean)
    if total % ctx_iter != 0:
        raise ValueError(f"len(acc_test)={total} is not a multiple of ctx_iter={ctx_iter}")
    num_tasks = total // ctx_iter
    if num_tasks < 2:
        return {}, {}, [], [], [], num_tasks

    # len() rather than truthiness so numpy arrays are accepted too.
    if len(acc_std) == 0:
        # No std recorded: treat the means as exact (same as the caller's fallback).
        acc_std = [0.0] * total
    elif len(acc_std) != total:
        raise ValueError(
            f"len(std acc_test)={len(acc_std)} does not match len(acc_test)={total}"
        )

    m_mean, m_std = {}, {}
    finals_mean, finals_std, mins_mean, mins_std = [], [], [], []

    for t in range(num_tasks):
        start, stop = t * ctx_iter, (t + 1) * ctx_iter
        block_mean = acc_mean[start:stop]
        block_std = acc_std[start:stop]

        # Value at the end of the task block.
        f_mean, f_std = float(block_mean[-1]), float(block_std[-1])
        m_mean[f"T{t+1}_final"], m_std[f"T{t+1}_final"] = f_mean, f_std

        # Lowest mean value in the block; its std is read at the same step.
        idx_min = int(np.argmin(block_mean))
        min_mean, min_std = float(block_mean[idx_min]), float(block_std[idx_min])
        m_mean[f"T{t+1}_min"], m_std[f"T{t+1}_min"] = min_mean, min_std

        finals_mean.append(f_mean)
        finals_std.append(f_std)
        mins_mean.append(min_mean)
        mins_std.append(min_std)

    n = num_tasks

    # avg-min-ACC: minima of every block except the last.
    # NOTE(review): this uses mins of blocks 1..n-1 while the SG below uses
    # mins of blocks 2..n -- confirm which set of minima is intended here.
    m_mean["avg-min-ACC"] = np.mean(mins_mean[:-1])
    m_std["avg-min-ACC"] = _std_of_mean(mins_std[:-1])

    # avg-ACC
    m_mean["avg-ACC"] = np.mean(finals_mean)
    m_std["avg-ACC"] = _std_of_mean(finals_std)

    # WC-ACC
    m_mean["WC-ACC"] = (1 / n) * finals_mean[-1] + (1 - 1 / n) * m_mean["avg-min-ACC"]
    m_std["WC-ACC"] = np.sqrt(
        (1 / n) ** 2 * finals_std[-1] ** 2
        + (1 - 1 / n) ** 2 * m_std["avg-min-ACC"] ** 2
    )

    # Stability gap between consecutive tasks: relative drop from the final
    # value of task i to the minimum reached during task i+1.
    sg_keys, sg_list, sg_std_list = [], [], []
    for i in range(n - 1):
        ref, ref_std = finals_mean[i], finals_std[i]
        low, low_std = mins_mean[i + 1], mins_std[i + 1]
        if ref:
            sg_mean = (ref - low) / ref
            # SG = 1 - low/ref.  Propagating the ratio directly avoids treating
            # the drop (ref - low) as independent of its own denominator ref.
            sg_var = low_std**2 / ref**2 + low**2 * ref_std**2 / ref**4
            sg_std = np.sqrt(sg_var)
        else:
            # Relative drop is undefined for a zero reference value.
            sg_mean = sg_std = np.nan

        k = f"SG_T{i+1}_to_T{i+2}"
        m_mean[k] = sg_mean
        m_std[k] = sg_std
        sg_keys.append(k)
        sg_list.append(sg_mean)
        sg_std_list.append(sg_std)

    m_mean["Average_SG"] = np.mean(sg_list)
    m_std["Average_SG"] = _std_of_mean(sg_std_list)

    return m_mean, m_std, finals_mean, finals_std, sg_keys, num_tasks

def compute_score(avg_acc, std_acc, avg_sg, std_sg, acc_scale_is_percent=True):
    """Return a pessimistic accuracy-vs-stability-gap score.

    Accuracy is lowered by one std and the stability gap is raised by one std.
    The gap (a relative drop) is multiplied by 100 when accuracies are given in
    percent, so both terms are on the same scale before subtracting.
    """
    sg_factor = 1.0
    if acc_scale_is_percent:
        sg_factor = 100.0
    pessimistic_acc = avg_acc - std_acc
    pessimistic_sg = avg_sg + std_sg
    return pessimistic_acc - sg_factor * pessimistic_sg

# =========
# Detect the accuracy scale from the first configuration that has any data;
# fail with a clear message instead of StopIteration/KeyError when none does.
example_hist = next(
    (h for h in avg_results.values() if len(h.get("acc_test", [])) > 0),
    None,
)
if example_hist is None:
    raise ValueError(f"No non-empty 'acc_test' history found in {file_path}")
# Values above 2.0 cannot be fractions, so the series must be in percent.
acc_scale_is_percent = (np.nanmax(example_hist["acc_test"]) > 2.0)

rows = []
all_sg_columns = set()
all_accTi_columns = set()

for key, avg_hist in avg_results.items():
    # Keys are expected to be (gamma, eta) pairs; anything else gets no grid slot.
    if isinstance(key, tuple) and len(key) == 2:
        gamma, eta = key
    else:
        gamma, eta = None, None

    std_hist = std_results.get(key)
    if std_hist is None:
        # No std stored for this configuration: treat every series as exact.
        std_hist = {k: [0.0] * len(v) for k, v in avg_hist.items()}

    m_mean, m_std, finals_mean, finals_std, sg_keys, num_tasks = calc_ce_metrics_with_uncertainty(
        avg_hist, std_hist, ctx_iter
    )
    if not m_mean:
        # Fewer than two tasks: cross-task metrics are undefined.
        continue

    task_final_keys = [f"T{t+1}_final" for t in range(num_tasks)]
    all_accTi_columns.update(task_final_keys)
    all_accTi_columns.update(k + "_std" for k in task_final_keys)
    all_sg_columns.update(sg_keys)
    all_sg_columns.update(k + "_std" for k in sg_keys)

    score = compute_score(
        m_mean["avg-ACC"], m_std["avg-ACC"],
        m_mean["Average_SG"], m_std["Average_SG"],
        acc_scale_is_percent
    )

    row = {
        "gamma": gamma, "eta": eta,
        "avg-ACC": m_mean["avg-ACC"], "avg-ACC_std": m_std["avg-ACC"],
        "Average_SG": m_mean["Average_SG"], "Average_SG_std": m_std["Average_SG"],
        "WC-ACC": m_mean["WC-ACC"], "WC-ACC_std": m_std["WC-ACC"],
        "avg-min-ACC": m_mean["avg-min-ACC"], "avg-min-ACC_std": m_std["avg-min-ACC"],
        "score_raw": score
    }

    # Per-task finals followed by stability gaps, each with its std column.
    for k in task_final_keys + sg_keys:
        row[k] = m_mean.get(k, np.nan)
        row[k + "_std"] = m_std.get(k, np.nan)

    rows.append(row)

# Leading summary columns of the exported table, in display order.
base_cols = [
    "gamma", "eta",
    "avg-ACC", "avg-ACC_std",
    "Average_SG", "Average_SG_std",
    "WC-ACC", "WC-ACC_std",
    "avg-min-ACC", "avg-min-ACC_std",
    "score_raw",
]


def _task_final_sort_key(col):
    """Sort "T{i}_final[_std]" names by task index, value before its std."""
    task_idx = int(col.split("_")[0][1:])
    return task_idx, "_std" in col


# T{i}_final and T{i}_final_std, ordered T1_final, T1_final_std, T2_final, ...
accTi_cols = sorted(all_accTi_columns, key=_task_final_sort_key)

# Stability-gap column names: "SG_Ti_to_Tj" and "SG_Ti_to_Tj_std".
sg_pattern = re.compile(r"^SG_T(\d+)_to_T(\d+)(?:_std)?$")


def parse_sg_key(s: str):
    """Return the sort key ``(from_task, to_task, is_std)`` for an SG column.

    Names that do not match ``sg_pattern`` sort after every matching name.
    """
    is_std = s.endswith("_std")
    found = sg_pattern.match(s)
    if found is None:
        # Unknown names go last; sorted() keeps ties in their input order.
        return (10**9, 10**9, is_std)
    src, dst = (int(group) for group in found.groups())
    return (src, dst, is_std)
sg_cols = sorted(all_sg_columns, key=parse_sg_key)

# =====  DF  =====
# Column order: summary metrics, per-task finals, then stability gaps.
# reindex also adds any column missing from every row (filled with NaN).
df = pd.DataFrame(rows).reindex(columns=[*base_cols, *accTi_cols, *sg_cols])

out_xlsx = "metrics_score_by_gamma_eta.xlsx"
df.to_excel(out_xlsx, index=False)
print(f"[OK] table saved to {out_xlsx}")

# ========= Heatmap =========
# NOTE(review): colour limits are hard-coded and presumably tuned for
# percent-scale scores of this sweep -- adjust for other runs.
SCORE_VMIN, SCORE_VMAX = 45, 75

gammas = sorted(df["gamma"].dropna().unique())
etas   = sorted(df["eta"].dropna().unique())

if gammas and etas:
    # Reindex so the matrix rows/columns are exactly the tick labels built
    # from `gammas`/`etas` below, in the same order; rows without a
    # (gamma, eta) key are dropped first so they cannot add a stray NaN row.
    pivot = (
        df.dropna(subset=["gamma", "eta"])
        .pivot(index="gamma", columns="eta", values="score_raw")
        .reindex(index=gammas, columns=etas)
    )

    plt.figure(figsize=(7, 6))
    im = plt.imshow(
        pivot.values,
        origin="lower",  # first (smallest) gamma at the bottom
        cmap="viridis",
        vmin=SCORE_VMIN, vmax=SCORE_VMAX,
        aspect="equal"
    )

    cbar = plt.colorbar(im, shrink=0.5, label="score")

    plt.xticks(ticks=range(len(etas)), labels=[str(e) for e in etas])
    plt.yticks(ticks=range(len(gammas)), labels=[str(g) for g in gammas])
    plt.xlabel("eta", fontsize=16)
    plt.ylabel("gamma", fontsize=16)
    plt.tight_layout()
    plt.show()
else:
    print("[WARN] no (gamma, eta) keys in results; heatmap skipped")

