# -*- coding: utf-8 -*-
"""
Class- and domain-incremental plots (test accuracy only; a single figure with one panel per benchmark)
"""

import pickle
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

# Global matplotlib styling applied to every panel below. Entries are grouped
# by topic and merged (in the same order) into a single rcParams update.
_FONT_RC = {
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial"],
}
_AXES_RC = {
    "axes.edgecolor": "black",
    "axes.linewidth": 1.0,
    "axes.titleweight": "regular",
}
_TICK_RC = {
    "xtick.color": "black",
    "ytick.color": "black",
    "xtick.major.size": 4,
    "ytick.major.size": 0,  # zero-length y tick marks; only labels remain
}
plt.rcParams.update({**_FONT_RC, **_AXES_RC, **_TICK_RC, "figure.dpi": 600})

# Maps method names as stored in the pickled results to the display names
# used throughout this script (``methods``, ``colors``, ``label_map``).
RENAME = {"ADAM": "Adam", "ENTROPY GAIN": "NGM-SGD", "ADAM reset": "Adam reset"}


def rename_methods_in_results(avg_dict, std_dict, mapping=None):
    """Rename method keys of the avg/std result dicts in place.

    Args:
        avg_dict: per-method mean results, keyed by method name.
        std_dict: per-method standard deviations, keyed by method name.
        mapping: old-name -> new-name dict. Defaults to the module-level
            ``RENAME``. Names absent from a dict are simply skipped.

    Returns:
        None. Both dicts are mutated; a renamed entry overwrites any entry
        that already uses the new name.
    """
    if mapping is None:
        mapping = RENAME
    for old, new in mapping.items():
        for results in (avg_dict, std_dict):
            if old in results:
                results[new] = results.pop(old)

def _load_pair(path):
    """Return the ``(avg, std)`` tuple stored in the pickle file at *path*.

    NOTE: ``pickle.load`` can execute arbitrary code; only point this at
    result files produced by your own experiments.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# --- Class-incremental benchmarks -------------------------------------------
# Trailing comments list alternative result files for each benchmark.
avg_splitMNIST, std_splitMNIST = _load_pair('./data/ER_ex1_splitMNIST_1K_avg.pkl')  # R_splitMNIST_avg ER_ex1_splitMNIST_1K_avg
avg_splitCIFAR10, std_splitCIFAR10 = _load_pair('./data/ER_ex2_splitCIFAR10_1K_avg.pkl')  # R_splitCIFAR10_avg ER_ex2_splitCIFAR10_1K_avg
avg_splitMiniImageNet, std_splitMiniImageNet = _load_pair('./data/ER_ex3_splitminiImagenet_1K_avg.pkl')  # R_splitminiImageNet_avg ER_ex3_splitminiImagenet_1K_avg

data_cls_L = (avg_splitMNIST,      std_splitMNIST)
data_cls_M = (avg_splitCIFAR10,    std_splitCIFAR10)
data_cls_R = (avg_splitMiniImageNet, std_splitMiniImageNet)
# The tuples share the dict objects, so renaming through them updates both.
for _avg, _std in (data_cls_L, data_cls_M, data_cls_R):
    rename_methods_in_results(_avg, _std)

# --- Domain-incremental benchmarks ------------------------------------------
avg_rotMNIST, std_rotMNIST = _load_pair('./data/ER_ex4_rotatedMNIST_1K_avg.pkl')  # R_rotatedMNIST_avg  ER_ex4_rotatedMNIST_2K_avg
avg_domCIFAR, std_domCIFAR = _load_pair('./data/ER_ex5_domainCIFAR100_1K_avg.pkl')  # R_domainCIFAR100_avg ER_ex5_domainCIFAR100_1K_avg

data_dom_L = (avg_rotMNIST, std_rotMNIST)
data_dom_R = (avg_domCIFAR, std_domCIFAR)
for _avg, _std in (data_dom_L, data_dom_R):
    rename_methods_in_results(_avg, _std)

# Plotting order of the optimizers; also determines the legend order.
methods = ['SGD', 'Adam', 'Adam reset', 'MSGD', 'MSGD reset', 'NGM-SGD']

# Curve/band colour for each method, listed in the same order as ``methods``.
colors = dict(zip(
    methods,
    ['grey', 'mediumseagreen', 'olive', 'royalblue', 'skyblue', 'darkorange'],
))

# Legend text per method: identical to the key except "reset" is capitalised.
label_map = {name: name.replace(' reset', ' Reset') for name in methods}


def fmt_pct(x, _pos=None):
    """Tick formatter: render *x* as a whole number (no decimals, no '%').

    ``_pos`` is the tick position passed by ``FuncFormatter``; it is unused.
    """
    return format(x, ".0f")

def draw_task_lines(ax, interval, n_total, num_tasks):
    """Draw dashed vertical lines at multiples of *interval*.

    Lines are placed at x = interval, 2*interval, ... up to and including
    ``min(n_total, num_tasks * interval)``, drawn behind the data (zorder 0.5).
    """
    last_boundary = min(n_total, num_tasks * interval)
    line_style = {
        'color': 'slategray',
        'linestyle': '--',
        'linewidth': 1.0,
        'zorder': 0.5,
        'alpha': 0.7,
    }
    for x_pos in range(interval, last_boundary + 1, interval):
        ax.axvline(x_pos, **line_style)

def label_tasks_under(ax, interval, num_tasks):
    """Write "Task k" centred beneath each task's span of *interval* iterations.

    The text uses the x-axis blended transform: x is in data coordinates and
    y=-0.025 is in axes coordinates, so labels sit just below the axis
    regardless of the y-limits.
    """
    blended = ax.get_xaxis_transform()
    for task_idx in range(num_tasks):
        centre = task_idx * interval + interval / 2
        ax.text(centre, -0.025, f"Task {task_idx + 1}",
                transform=blended, ha='center', va='top', fontsize=11)

def draw_per_task_minima(ax, mu, interval, n_total, color, start_task=2):
    """Mark the lowest value of *mu* inside each task window, from *start_task* on.

    Task ``t`` covers array indices ``[(t-1)*interval, min(t*interval, n_total))``.
    For each window with at least one non-NaN value, a dashed horizontal line
    is drawn at the window minimum across the window's x-range, plus a dot at
    the iteration where that minimum occurs.

    Args:
        ax: matplotlib Axes to draw on.
        mu: 1-D numpy array of values, indexed by iteration - 1.
        interval: iterations per task.
        n_total: number of iterations plotted.
        color: colour used for both the line and the marker.
        start_task: first (1-based) task whose minimum is marked.
    """
    n_windows = int(np.ceil(n_total / interval))
    for task in range(start_task, n_windows + 1):
        lo = (task - 1) * interval
        hi = min(task * interval, n_total)
        window = mu[lo:hi]
        # Skip empty windows and all-NaN windows (np.nanargmin raises on those).
        if lo >= hi or window.size == 0 or np.isnan(window).all():
            continue
        offset = int(np.nanargmin(window))
        lowest = float(window[offset])

        # Plot x coordinates are 1-based iterations, hence the "+ 1" shifts.
        ax.hlines(lowest, lo + 1, hi,
                  linestyles=(0, (3, 2)), linewidth=1.5, alpha=0.9,
                  color=color, zorder=4)
        ax.plot(lo + offset + 1, lowest,
                marker='o', markersize=4.2,
                markeredgecolor='white', markeredgewidth=0.8,
                color=color, zorder=5)

def plot_series_with_band(ax, x, mu, sd, color, label=None):
    """Plot the mean curve *mu* over *x* with a translucent mu±sd band behind it."""
    lower, upper = mu - sd, mu + sd
    ax.fill_between(x, lower, upper, color=color, alpha=0.15, zorder=1)
    ax.plot(x, mu,
            color=color, linewidth=1.0, solid_capstyle='round',
            zorder=3, label=label)

def tidy_axes(ax, title=None, yformatter=None, ylim=None,
              xlabel=None, ylabel=None, xticks=None,
              show_xticklabels=True):
    """Apply the shared panel styling to *ax*.

    Each optional setting is applied only when given (truthy). ``xticks`` is
    tested against None instead, because a numpy array has no single truth
    value. Tick labels are always set to black at size 9, y tick marks get
    zero length, and the top/right spines are hidden.
    """
    optional_setters = (
        (title,      lambda v: ax.set_title(v, fontsize=18)),
        (yformatter, lambda v: ax.yaxis.set_major_formatter(FuncFormatter(v))),
        (ylim,       lambda v: ax.set_ylim(*v)),
        (ylabel,     lambda v: ax.set_ylabel(v, fontsize=16)),
        (xlabel,     lambda v: ax.set_xlabel(v, fontsize=16)),
    )
    for value, apply in optional_setters:
        if value:
            apply(value)
    if xticks is not None:
        ax.set_xticks(xticks)

    label_style = {'which': 'both', 'labelcolor': 'black', 'labelsize': 9}
    ax.tick_params(axis='y', length=0, **label_style)
    ax.tick_params(axis='x', width=0.8, **label_style)
    if not show_xticklabels:
        ax.tick_params(labelbottom=False)

    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)


def _series_length(data_pair):
    """Length of the 'acc_test' series of the first method in the pair's avg dict."""
    avg = data_pair[0]
    first_method = next(iter(avg))
    return len(avg[first_method]['acc_test'])


# Class-incremental: number of logged points, 1-based x axes, task lengths.
n_L_cls = _series_length(data_cls_L)
n_M_cls = _series_length(data_cls_M)
n_R_cls = _series_length(data_cls_R)

it_L_cls = np.arange(n_L_cls) + 1
it_M_cls = np.arange(n_M_cls) + 1
it_R_cls = np.arange(n_R_cls) + 1

interval_L_cls, interval_M_cls, interval_R_cls = 200, 400, 200
num_tasks_cls = 5

# Domain-incremental: same quantities for the two domain benchmarks.
n_L_dom = _series_length(data_dom_L)
n_R_dom = _series_length(data_dom_R)

it_L_dom = np.arange(n_L_dom) + 1
it_R_dom = np.arange(n_R_dom) + 1

interval_L_dom, interval_R_dom = 400, 800
num_tasks_dom = 3


def _plot_panel(ax, data, iterations, interval, n_total, num_tasks, **tidy_kwargs):
    """Draw one accuracy panel onto *ax*.

    The steps are: task-boundary lines, then a mean±std curve with per-task
    minima for every method in ``methods`` that is present in *data*, then
    the shared axis styling (``fmt_pct`` y labels, ticks every *interval*
    iterations), then "Task k" labels.

    Args:
        ax: matplotlib Axes to draw on.
        data: ``(avg_dict, std_dict)`` pair keyed by method name; each entry
            is indexed with 'acc_test', as in the rest of this script.
        iterations: x values (1-based iterations) for the curves.
        interval: iterations per task.
        n_total: number of plotted iterations.
        num_tasks: number of tasks in the benchmark.
        **tidy_kwargs: extra keyword arguments forwarded to ``tidy_axes``
            (title, ylim, labels, show_xticklabels).
    """
    avg, std = data
    draw_task_lines(ax, interval, n_total, num_tasks)
    for m in methods:
        if m not in avg:
            continue
        mu = np.array(avg[m]['acc_test'], dtype=float)
        sd = np.array(std[m]['acc_test'], dtype=float)
        plot_series_with_band(ax, iterations, mu, sd, colors[m])
        draw_per_task_minima(ax, mu, interval, n_total, colors[m], start_task=2)
    tidy_axes(
        ax,
        yformatter=fmt_pct,
        xticks=np.arange(0, n_total + 1, interval),
        **tidy_kwargs
    )
    label_tasks_under(ax, interval, num_tasks)


fig = plt.figure(figsize=(12, 5), constrained_layout=True)
gs  = fig.add_gridspec(2, 6, hspace=0.05)

# Top row: class-incremental benchmarks (three panels, two grid columns each).
# Split MNIST
ax1 = fig.add_subplot(gs[0, 0:2])
_plot_panel(ax1, data_cls_L, it_L_cls, interval_L_cls, n_L_cls, num_tasks_cls,
            title='Split MNIST', ylim=(90, 100),
            ylabel='% Task 1\nAccuracy', show_xticklabels=False)

# Split CIFAR-10
ax2 = fig.add_subplot(gs[0, 2:4])
_plot_panel(ax2, data_cls_M, it_M_cls, interval_M_cls, n_M_cls, num_tasks_cls,
            title='Split CIFAR-10', ylim=(0, 100), show_xticklabels=False)

# Split mini-ImageNet
ax3 = fig.add_subplot(gs[0, 4:6])
_plot_panel(ax3, data_cls_R, it_R_cls, interval_R_cls, n_R_cls, num_tasks_cls,
            title='Split mini-ImageNet', ylim=(0, 70), show_xticklabels=False)

# Bottom row: domain-incremental benchmarks (two panels, three grid columns each).
# Rotated MNIST
ax4 = fig.add_subplot(gs[1, 0:3])
_plot_panel(ax4, data_dom_L, it_L_dom, interval_L_dom, n_L_dom, num_tasks_dom,
            title='Rotated MNIST', ylim=(80, 100),
            ylabel='% Task 1\nAccuracy', xlabel='Iteration',
            show_xticklabels=True)

# Domain CIFAR-100
ax5 = fig.add_subplot(gs[1, 3:6])
_plot_panel(ax5, data_dom_R, it_R_dom, interval_R_dom, n_R_dom, num_tasks_dom,
            title='Domain CIFAR-100', ylim=(20, 80), xlabel='Iteration',
            show_xticklabels=True)



# One shared legend for the whole figure. It lists, in ``methods`` order,
# only the methods that appear in at least one benchmark's avg dict.
all_results = (data_cls_L, data_cls_M, data_cls_R, data_dom_L, data_dom_R)
handles = []
labels  = []
for m in methods:
    if any(m in result[0] for result in all_results):
        handles.append(plt.Line2D([], [], color=colors[m], linewidth=2.5))
        labels.append(label_map[m])

fig.legend(
    handles, labels,
    loc='upper center',
    bbox_to_anchor=(0.5, 1.1),
    # Keep at least one column: with zero columns the legend layout fails.
    ncol=max(1, len(labels)),
    frameon=False,
    prop={'size': 16}
)

# =========================
# Save
# =========================
# Save before plt.show(). With an interactive backend, the user may resize
# or close the window, so a save done afterwards can capture a different
# canvas.
filename = 'app_ER_testAcc'
fig.savefig(f'{filename}.pdf', format='pdf', dpi=600, bbox_inches='tight')
# fig.savefig(f'{filename}.png', dpi=600, bbox_inches='tight')

plt.show()
