# -*- coding: utf-8 -*-
"""
Metrics
"""

import pickle
import numpy as np
import pandas as pd

# Maps a method key as it may appear in stored results to the key it is renamed to.
RENAME = {"ADAM": "Adam", "ENTROPY GAIN": "NGM-SGD"}

def rename_methods_in_results(avg_res: dict, std_res: dict):
    """Apply the ``RENAME`` key mapping to both result dicts, in place.

    Each dict is processed independently. An entry is moved from the old key
    to the new key only when the old key is present and the new key is not,
    so an already-existing entry under the new key is never overwritten (the
    old entry is then left where it is).

    Args:
        avg_res: dict keyed by method name; mutated in place.
        std_res: dict keyed by method name; mutated in place.

    Returns:
        None.
    """
    for results in (avg_res, std_res):
        for legacy_key, new_key in RENAME.items():
            if legacy_key not in results or new_key in results:
                continue
            results[new_key] = results.pop(legacy_key)

def safe_acc_std(std_hist: dict, length: int):
    """Return the ``"acc_test"`` std series, or zeros when it is unusable.

    Args:
        std_hist: dict that may hold a sequence under ``"acc_test"``.
        length: number of entries the series must have.

    Returns:
        ``std_hist["acc_test"]`` unchanged when it exists and has exactly
        ``length`` entries; otherwise a list of ``length`` zeros.
    """
    acc_std = std_hist.get("acc_test")
    # Compare lengths instead of testing truthiness: ``not array`` raises
    # ValueError for NumPy arrays with more than one element.
    if acc_std is None or len(acc_std) != length:
        return [0.0] * length
    return acc_std

def _std_of_mean(stds):
    """Return sqrt(sum(std_i**2)) / n, the propagated std of the plain mean of n terms."""
    arr = np.array(stds, dtype=float)
    return float(np.sqrt(np.sum(arr ** 2)) / len(arr))

def calc_ce_metrics_with_uncertainty(avg_hist, std_hist, ctx_iter):
    """Compute per-task and aggregate accuracy metrics with propagated stds.

    The ``"acc_test"`` series is cut into ``len(series) // ctx_iter``
    consecutive blocks of ``ctx_iter`` entries (one block per task; trailing
    entries that do not fill a block are ignored). For every block the final
    value and the minimum value are recorded together with the std found at
    the same position.

    Reported keys (identical in both returned dicts):
        ``T{k}_final`` / ``T{k}_min``: last / minimum value of block k.
        ``avg-min-ACC``: mean of the block minima, skipping the first block.
        ``avg-ACC``: mean of the block finals.
        ``WC-ACC``: ``(1/n) * final_of_last_block + (1 - 1/n) * avg-min-ACC``.
        ``SG_T{i}_to_T{i+1}``: ``(final_i - min_{i+1}) / final_i``.
        ``Average_SG``: mean of the SG values.

    Stds are combined by first-order propagation with all terms treated as
    independent (root-sum-of-squares).

    Args:
        avg_hist: dict with the mean series under ``"acc_test"``.
        std_hist: dict with the matching std series under ``"acc_test"``;
            a missing or wrongly sized series is replaced by zeros.
        ctx_iter: number of entries per task block.

    Returns:
        ``(mean_metrics, std_metrics)``; both are empty dicts when the series
        is empty or contains fewer than two full blocks.

    Raises:
        ZeroDivisionError: if ``ctx_iter`` is 0 and the series is non-empty.
    """
    acc_mean = list(avg_hist.get("acc_test", []))
    total_len = len(acc_mean)
    if total_len == 0:
        return {}, {}
    acc_std = list(safe_acc_std(std_hist, total_len))
    num_tasks = total_len // ctx_iter
    if num_tasks < 2:
        return {}, {}

    m_mean, m_std = {}, {}
    finals_mean, finals_std = [], []
    mins_mean, mins_std = [], []
    for t in range(num_tasks):
        lo, hi = t * ctx_iter, (t + 1) * ctx_iter
        block_mean = acc_mean[lo:hi]
        block_std = acc_std[lo:hi]
        f_mean = float(block_mean[-1])
        f_std = float(block_std[-1])
        m_mean[f"T{t+1}_final"] = f_mean
        m_std[f"T{t+1}_final"] = f_std
        # The std reported for the minimum is the one at the argmin position.
        idx_min = int(np.argmin(block_mean))
        min_mean = float(block_mean[idx_min])
        min_std = float(block_std[idx_min])
        m_mean[f"T{t+1}_min"] = min_mean
        m_std[f"T{t+1}_min"] = min_std
        finals_mean.append(f_mean)
        finals_std.append(f_std)
        mins_mean.append(min_mean)
        mins_std.append(min_std)

    # num_tasks >= 2 is guaranteed here, so the first block can always be
    # skipped for the minimum-accuracy average.
    avg_min_acc = float(np.mean(mins_mean[1:]))
    std_min_acc = _std_of_mean(mins_std[1:])
    m_mean["avg-min-ACC"] = avg_min_acc
    m_std["avg-min-ACC"] = std_min_acc

    m_mean["avg-ACC"] = float(np.mean(finals_mean))
    m_std["avg-ACC"] = _std_of_mean(finals_std)

    n = num_tasks
    wc_mean = (1/n) * finals_mean[-1] + (1 - 1/n) * avg_min_acc
    wc_std = float(np.sqrt((1/n)**2 * finals_std[-1]**2 + (1 - 1/n)**2 * std_min_acc**2))
    m_mean["WC-ACC"] = float(wc_mean)
    m_std["WC-ACC"] = wc_std

    sg_list, sg_std_list = [], []
    for i in range(num_tasks - 1):
        fi_mean, fi_std = finals_mean[i], finals_std[i]
        mi1_mean, mi1_std = mins_mean[i+1], mins_std[i+1]
        if fi_mean:
            sg_mean = (fi_mean - mi1_mean) / fi_mean
            # Var[1 - m/f] ~ (m**2 / f**4) * var_f + (1 / f**2) * var_m
            sg_var = (mi1_mean**2 / (fi_mean**4)) * (fi_std**2) + (1.0 / (fi_mean**2)) * (mi1_std**2)
            sg_std = float(np.sqrt(sg_var))
        else:
            # The relative drop is undefined for a zero final accuracy; both
            # the value and its std become NaN instead of dividing by zero.
            sg_mean, sg_std = np.nan, np.nan
        key = f"SG_T{i+1}_to_T{i+2}"
        m_mean[key] = float(sg_mean)
        m_std[key] = sg_std
        sg_list.append(float(sg_mean))
        sg_std_list.append(sg_std)

    # There is at least one transition (num_tasks >= 2). A NaN transition
    # propagates into the average.
    m_mean["Average_SG"] = float(np.mean(sg_list))
    m_std["Average_SG"] = _std_of_mean(sg_std_list)
    return m_mean, m_std

def _load_avg_std(path):
    """Unpickle the two-element ``(avg, std)`` payload stored at *path*."""
    # NOTE(review): pickle.load can execute arbitrary code; this assumes the
    # files under ./data are trusted, locally produced artifacts.
    with open(path, 'rb') as handle:
        avg, std = pickle.load(handle)
    return avg, std

avg_splitMNIST, std_splitMNIST = _load_avg_std('./data/splitMNIST_avg.pkl')
avg_splitCIFAR10, std_splitCIFAR10 = _load_avg_std('./data/Split_cifar10_avg.pkl')
avg_splitMiniImageNet, std_splitMiniImageNet = _load_avg_std('./data/miniImageNet_avg.pkl')
avg_rotatedMNIST, std_rotatedMNIST = _load_avg_std('./data/rotatedMNIST_avg.pkl')
avg_domainCIFAR100, std_domainCIFAR100 = _load_avg_std('./data/domainCIFAR100_800_avg.pkl')

# Methods to report, in the order their rows are emitted.
methods_all = ['NGM-SGD', 'MSGD', 'Adam', 'SGD']

# Block length (entries of the accuracy series per task) for each benchmark.
ctx_iters = {
    'splitMNIST': 200,
    'splitCIFAR10': 400,
    'splitminiImageNet': 200,
    'rotatedMNIST': 400,
    'domainCIFAR100': 800,
}

# Benchmark key -> (avg results, std results, LaTeX label for the table).
datasets = {
    'splitMNIST': (avg_splitMNIST, std_splitMNIST, r'\shortstack{Split\\MNIST}'),
    'splitCIFAR10': (avg_splitCIFAR10, std_splitCIFAR10, r'\shortstack{Split\\CIFAR-10}'),
    'splitminiImageNet': (avg_splitMiniImageNet, std_splitMiniImageNet, r'\shortstack{Split\\mini-ImageNet}'),
    'rotatedMNIST': (avg_rotatedMNIST, std_rotatedMNIST, r'\shortstack{Rotated\\MNIST}'),
    'domainCIFAR100': (avg_domainCIFAR100, std_domainCIFAR100, r'\shortstack{Domain\\CIFAR-100}'),
}

# Normalise method names in every loaded result pair (in place), visiting the
# benchmarks in the same order in which they were loaded.
for _avg_res, _std_res, _label in datasets.values():
    rename_methods_in_results(_avg_res, _std_res)

def extract_summary(avg_res, std_res, ctx_iter, methods):
    """Collect the four table metrics for every requested method.

    Args:
        avg_res: dict mapping method name to its mean history.
        std_res: dict mapping method name to its std history; a missing
            method falls back to an empty dict.
        ctx_iter: block length forwarded to
            ``calc_ce_metrics_with_uncertainty``.
        methods: iterable of method names to look up.

    Returns:
        dict mapping method name to
        ``{'avg-ACC', 'avg-min-ACC', 'WC-ACC', 'avg-SG'} -> (mean, std)``.
        Methods absent from ``avg_res``, or for which no metrics could be
        computed, are left out.
    """
    # (key in the summary, key in the computed metric dicts)
    wanted = (
        ('avg-ACC', 'avg-ACC'),
        ('avg-min-ACC', 'avg-min-ACC'),
        ('WC-ACC', 'WC-ACC'),
        ('avg-SG', 'Average_SG'),
    )
    summary = {}
    for name in methods:
        if name not in avg_res:
            continue
        means, stds = calc_ce_metrics_with_uncertainty(
            avg_res[name], std_res.get(name, {}), ctx_iter
        )
        if not means:
            continue
        summary[name] = {label: (means[src], stds[src]) for label, src in wanted}
    return summary

def interval_overlap(a_mean, a_std, b_mean, b_std):
    """Return True when ``[a_mean ± a_std]`` and ``[b_mean ± b_std]`` intersect.

    Intervals that merely touch count as overlapping. Any comparison that
    involves NaN is False, so an interval with a NaN bound is reported as
    overlapping.
    """
    a_lo, a_hi = a_mean - a_std, a_mean + a_std
    b_lo, b_hi = b_mean - b_std, b_mean + b_std
    if a_hi < b_lo:
        return False
    if b_hi < a_lo:
        return False
    return True

def compute_bold_flags(summary, metric, higher_is_better=True):
    """Decide which methods get highlighted for one metric.

    The best method is the one with the largest (or smallest, when
    ``higher_is_better`` is False) mean. A method is flagged when its
    ``mean ± std`` interval overlaps the interval of the best method; the
    best method itself is therefore always flagged.

    Methods whose mean is NaN are ignored when picking the best entry and are
    never flagged. Without this, a NaN mean compares as "overlapping" with
    everything and, if it is picked as best, causes every method to be flagged.

    Args:
        summary: dict mapping method name to ``{metric: (mean, std)}``.
        metric: key of the metric to evaluate.
        higher_is_better: direction of the metric.

    Returns:
        dict mapping every method in ``summary`` to a bool; empty when
        ``summary`` is empty.
    """
    entries = {m: summary[m][metric] for m in summary}
    candidates = [(mu, sd) for mu, sd in entries.values() if not np.isnan(mu)]
    if not candidates:
        # Nothing comparable: no method can be called best.
        return {m: False for m in entries}

    pick = max if higher_is_better else min
    best_mu, best_sd = pick(candidates, key=lambda entry: entry[0])

    flags = {}
    for m, (mu, sd) in entries.items():
        if np.isnan(mu):
            flags[m] = False
        else:
            flags[m] = bool(interval_overlap(mu, sd, best_mu, best_sd))
    return flags

def fmt(mu, sd, bold=False):
    """Format ``mu`` and ``sd`` as ``"<mu> $\\pm$ <sd>"`` with three decimals.

    When ``bold`` is truthy the result is wrapped in ``\\textbf{...}``.
    """
    body = "{:.3f} $\\pm$ {:.3f}".format(mu, sd)
    if bold:
        return "\\textbf{" + body + "}"
    return body

def build_table(datasets, ctx_iters, methods):
    """Render the metrics table as LaTeX source.

    Args:
        datasets: dict mapping a benchmark key to
            ``(avg_res, std_res, ds_label)``; ``ds_label`` is the LaTeX text
            placed in the benchmark's ``\\multirow`` cell.
        ctx_iters: dict mapping the same benchmark keys to the block length
            passed to ``extract_summary``.
        methods: method names; defines which rows appear and their order.

    Returns:
        The complete ``table`` environment as one newline-joined string.
        Benchmarks for which no method produced a summary are skipped.
    """
    # (metric key, higher_is_better) in column order.
    columns = (
        ('avg-ACC', True),
        ('avg-min-ACC', True),
        ('WC-ACC', True),
        ('avg-SG', False),
    )

    lines = [
        r"\begin{table}[t]",
        r"  \centering",
        r"  \caption{\textbf{Main quantitative metrics.} For all benchmarks we report (across tasks) the average final accuracy (avg-ACC), average minimum accuracy (avg-min-ACC), average stability gap drop (avg-SG), and the worst-case accuracy (WC-ACC). Highlighted values indicate the best results; when multiple values are highlighted, it is because their standard error ranges overlap.}",
        r"  \label{tab:ce_metrics}",
        r"  \begin{adjustwidth}{-\oddsidemargin}{-\evensidemargin}",
        r"  \centering",
        r"  \begin{tabular}{@{}cccccc@{}}",
        r"    \hline \hline",
        r"     &  & \textbf{avg-ACC ($\uparrow$)} & \textbf{avg-min-ACC ($\uparrow$)} & \textbf{WC-ACC ($\uparrow$)} & \textbf{avg-SG ($\downarrow$)} \\",
        r"    \hline ",
    ]

    for key, (avg_res, std_res, ds_label) in datasets.items():
        summary = extract_summary(avg_res, std_res, ctx_iters[key], methods)
        if not summary:
            continue

        bold = {
            metric: compute_bold_flags(summary, metric, higher)
            for metric, higher in columns
        }

        rows = []
        for name in methods:
            if name not in summary:
                continue
            # NGM-SGD is marked with an asterisk in the method column.
            shown = name + "*" if name == "NGM-SGD" else name
            cells = [
                fmt(*summary[name][metric], bold[metric].get(name, False))
                for metric, _ in columns
            ]
            rows.append("      & " + " & ".join([shown] + cells) + " \\\\")

        # The multirow cell spans every method row of this benchmark.
        lines.append("    \\multirow[c]{" + str(len(rows)) + "}{*}{" + ds_label + "}")
        lines.extend(rows)
        lines.append(r"    \hline")

    lines.extend([
        r"    \hline",
        r"  \end{tabular}",
        r"  \end{adjustwidth}",
        r"\end{table}",
    ])
    return "\n".join(lines)

# Render the table for every configured benchmark and write it to stdout.
latex_table = build_table(
    datasets=datasets,
    ctx_iters=ctx_iters,
    methods=methods_all,
)
print(latex_table)
