# -*- coding: utf-8 -*-
"""
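Class-incremental learning figure: Task 1 accuracy, Task 1 loss and neuronal gain
over training iterations on Split MNIST, Split CIFAR-10 and Split mini-ImageNet.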
"""

import pickle
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

# =========================
# Global plot style
# =========================
plt.rcParams.update({
    "font.family": "sans-serif",
    # Arial is preferred; the fallbacks keep a sans-serif face (and avoid
    # matplotlib's font-not-found warning) on systems without Arial.
    # DejaVu Sans ships with matplotlib, so the list always resolves.
    "font.sans-serif": ["Arial", "Helvetica", "DejaVu Sans"],
    "axes.edgecolor": "black",
    "axes.linewidth": 1.0,
    "axes.titleweight": "regular",
    "xtick.color": "black",
    "ytick.color": "black",
    "xtick.major.size": 4,
    "ytick.major.size": 0,  # y ticks are hidden; only their labels are shown
    # High rendering DPI; note this also scales interactive windows.
    "figure.dpi": 600,
})

# =========================
# Load averaged results
# =========================
# NOTE(review): pickle.load can execute arbitrary code; only point these
# paths at trusted, locally produced result files.
def _load_results(path):
    """Unpickle and return the object stored at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Indexed below as results[0][method][metric] (means) and
# results[1][method][metric] (spreads) — inferred from usage; confirm
# against whatever wrote these files.
data_L = _load_results('./data/splitMNIST_avg.pkl')
data_M = _load_results('./data/Split_cifar10_avg.pkl')
data_R = _load_results('./data/miniImageNet_avg.pkl')

# Plot order of the methods (also the legend order).
methods = ['SGD', 'MSGD', 'ADAM', 'ENTROPY GAIN']
colors = {
    'ADAM': 'mediumseagreen',
    'MSGD': 'royalblue',
    'ENTROPY GAIN': 'darkorange',
    'SGD': 'grey',
}

# Internal method key -> legend text.
label_map = {'ADAM': 'Adam', 'MSGD': 'MSGD', 'ENTROPY GAIN': 'NGM-SGD', 'SGD': 'SGD'}

# =========================
# Series lengths and 1-based iteration axes (length taken from the first method)
# =========================
_reference_method = methods[0]
n_L, n_M, n_R = (
    len(results[0][_reference_method]['acc_test'])
    for results in (data_L, data_M, data_R)
)
it_L, it_M, it_R = (np.arange(1, count + 1) for count in (n_L, n_M, n_R))

# =========================
# Task schedule: iterations per task window for each dataset
# =========================
interval_L = 200
interval_M = 400
interval_R = 200
num_tasks = 5

# =========================
# Plot helpers
# =========================
def fmt_pct(x, _pos=None):
    """Tick formatter: render *x* rounded to a whole number, with no '%' sign.

    The ``(value, position)`` signature is what ``FuncFormatter`` passes;
    the position argument is ignored.
    """
    return format(x, ".0f")

def draw_task_lines(ax, interval, n_total, *, n_tasks=None):
    """Draw dashed vertical lines at the task boundaries.

    Lines are placed at ``interval, 2*interval, ...`` up to and including
    ``min(n_total, n_tasks * interval)``. Boundaries past the plotted range
    are skipped.

    Args:
        ax: matplotlib Axes to draw on.
        interval: iterations per task (spacing between boundaries).
        n_total: number of plotted iterations.
        n_tasks: number of tasks; defaults to the module-level ``num_tasks``.
    """
    if n_tasks is None:
        n_tasks = num_tasks
    last_boundary = min(n_total, n_tasks * interval)
    for boundary in range(interval, last_boundary + 1, interval):
        # Low zorder keeps the guides behind the data curves.
        ax.axvline(boundary, color='slategray', linestyle='--', linewidth=1.0, zorder=0.5, alpha=0.7)

def label_tasks_under(ax, interval, *, n_tasks=None, y=-0.025, fontsize=12):
    """Write "Task i" centred under each task window, just below the x axis.

    Args:
        ax: matplotlib Axes to annotate.
        interval: iterations per task; label i is centred at
            ``i * interval + interval / 2`` in data coordinates.
        n_tasks: number of labels; defaults to the module-level ``num_tasks``.
        y: vertical position in axes-fraction units (negative = below axis).
        fontsize: label font size.
    """
    if n_tasks is None:
        n_tasks = num_tasks
    # Blended transform: x in data units, y as a fraction of the axes height.
    x_data_y_axes = ax.get_xaxis_transform()
    for i in range(n_tasks):
        ax.text(
            i * interval + interval / 2, y, f"Task {i+1}",
            transform=x_data_y_axes, ha='center', va='top',
            fontsize=fontsize
        )

def _draw_per_task_extremum(ax, mu, interval, n_total, color, start_task, arg_extremum):
    """Mark one extremum of ``mu`` per task window, from ``start_task`` onward.

    Task ``t`` (1-based) covers the 0-based samples
    ``[(t - 1) * interval, min(t * interval, n_total))``. The curves are
    plotted against 1-based iterations, so those samples sit at x positions
    ``start + 1 .. end``. For each window, a dashed horizontal line is drawn
    across it at the extremum value, and a dot marks where the extremum
    occurs. Windows that are empty or entirely NaN are skipped.

    Args:
        ax: matplotlib Axes to draw on.
        mu: per-iteration series (anything ``np.asarray`` accepts).
        interval: samples per task window.
        n_total: number of samples considered.
        color: colour of the line and the marker.
        start_task: first (1-based) task window to annotate.
        arg_extremum: ``np.nanargmin`` or ``np.nanargmax``.
    """
    mu = np.asarray(mu, dtype=float)
    n_windows = int(np.ceil(n_total / interval))
    for t in range(start_task, n_windows + 1):
        start = (t - 1) * interval
        end = min(t * interval, n_total)
        seg = mu[start:end]
        # An empty slice also covers start >= end (e.g. mu shorter than n_total).
        if seg.size == 0 or np.all(np.isnan(seg)):
            continue
        i_local = int(arg_extremum(seg))
        y_val = float(seg[i_local])
        i_abs = start + i_local + 1  # back to 1-based iteration numbering

        ax.hlines(
            y_val, start + 1, end,
            linestyles=(0, (3, 2)),
            linewidth=1.5,
            alpha=0.9,
            color=color,
            zorder=4
        )
        ax.plot(
            i_abs, y_val,
            marker='o', markersize=4.2,
            markeredgecolor='white', markeredgewidth=0.8,
            color=color,
            zorder=5
        )


def draw_per_task_minima(ax, mu, interval, n_total, color, start_task=2):
    """Mark the minimum of ``mu`` inside each task window (NaNs ignored).

    Used on the accuracy curves in this script. See ``_draw_per_task_extremum``
    for the window layout and drawing details.
    """
    _draw_per_task_extremum(ax, mu, interval, n_total, color, start_task, np.nanargmin)


def draw_per_task_maxima(ax, mu, interval, n_total, color, start_task=2):
    """Mark the maximum of ``mu`` inside each task window (NaNs ignored).

    Used on the loss curves in this script. See ``_draw_per_task_extremum``
    for the window layout and drawing details.
    """
    _draw_per_task_extremum(ax, mu, interval, n_total, color, start_task, np.nanargmax)

def plot_series_with_band(ax, x, mu, sd, color, label=None):
    """Draw the curve ``mu`` over ``x`` with a translucent band ``mu ± sd``.

    The band is drawn first with a lower zorder so the line sits on top.
    ``mu`` and ``sd`` must support elementwise ``+``/``-`` (e.g. numpy arrays).
    """
    lower, upper = mu - sd, mu + sd
    ax.fill_between(x, lower, upper, color=color, alpha=0.15, zorder=1)
    line_style = dict(color=color, linewidth=1.0, solid_capstyle='round', zorder=3)
    ax.plot(x, mu, label=label, **line_style)

def tidy_axes(ax, title=None, yformatter=None, ylim=None, xlabel=None, ylabel=None, xticks=None, show_xticklabels=True):
    """Apply the shared panel styling to *ax*.

    Optional items (title, y tick formatter, y limits, axis labels) are set only
    when truthy. ``xticks`` is applied whenever it is not None. Tick styling is
    always applied, and the top/right spines are always hidden. When
    ``show_xticklabels`` is false, the bottom tick labels are turned off.
    """
    axis_label_size = 14
    if title:
        ax.set_title(title, fontsize=16)
    if yformatter:
        ax.yaxis.set_major_formatter(FuncFormatter(yformatter))
    if ylim:
        ax.set_ylim(*ylim)
    for text, setter in ((ylabel, ax.set_ylabel), (xlabel, ax.set_xlabel)):
        if text:
            setter(text, fontsize=axis_label_size)
    if xticks is not None:
        ax.set_xticks(xticks)

    shared_tick_style = dict(which='both', labelcolor='black', labelsize=10)
    ax.tick_params(axis='y', length=0, **shared_tick_style)
    ax.tick_params(axis='x', width=0.8, **shared_tick_style)
    if not show_xticklabels:
        ax.tick_params(labelbottom=False)

    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

# =========================
# Figure layout: 8 grid rows x 3 dataset columns
#   rows 0-3 -> Task 1 accuracy, rows 4-5 -> Task 1 loss, rows 6-7 -> neuronal gain
#   columns  -> Split MNIST, Split CIFAR-10, Split mini-ImageNet
# =========================
fig = plt.figure(figsize=(12, 5), constrained_layout=True)
gs = fig.add_gridspec(8, 3)


def _add_accuracy_panel(col, results, iters, interval, n_total, title, ylim, ylabel=None):
    """Add one top-row panel of per-method test accuracy (mean +/- sd band).

    Each method's curve also gets its per-task minima marked from task 2 on.
    ``results[0]`` / ``results[1]`` are used as the mean / spread tables.
    NOTE(review): that layout is inferred from usage; confirm with the writer
    of the pickles. Returns the new Axes.
    """
    ax = fig.add_subplot(gs[0:4, col])
    draw_task_lines(ax, interval, n_total)
    mean_table, sd_table = results[0], results[1]
    for method in methods:
        tone = colors[method]
        mean_curve = np.array(mean_table[method]['acc_test'], dtype=float)
        sd_curve = np.array(sd_table[method]['acc_test'], dtype=float)
        plot_series_with_band(ax, iters, mean_curve, sd_curve, tone)
        draw_per_task_minima(ax, mean_curve, interval, n_total, tone, start_task=2)
    tidy_axes(
        ax,
        title=title,
        yformatter=fmt_pct,
        ylim=ylim,
        ylabel=ylabel,
        xticks=np.arange(0, n_total + 1, interval),
        show_xticklabels=False,
    )
    # Tick labels are hidden on this row; task names go under the axis instead.
    label_tasks_under(ax, interval)
    return ax


ax_acc_L = _add_accuracy_panel(0, data_L, it_L, interval_L, n_L, 'Split MNIST', (95, 100), ylabel='% Task 1\nAccuracy')
ax_acc_M = _add_accuracy_panel(1, data_M, it_M, interval_M, n_M, 'Split CIFAR-10', (40, 100))
ax_acc_R = _add_accuracy_panel(2, data_R, it_R, interval_R, n_R, 'Split mini-ImageNet', (0, 70))

# ====== Middle row: Task 1 test loss (one panel per dataset) ======
def _add_loss_panel(col, results, iters, interval, n_total, ylim, ylabel=None):
    """Add one middle-row panel of per-method test loss (mean +/- sd band).

    Each method's curve also gets its per-task maxima marked from task 2 on.
    Returns the new Axes.
    """
    ax = fig.add_subplot(gs[4:6, col])
    draw_task_lines(ax, interval, n_total)
    mean_table, sd_table = results[0], results[1]
    for method in methods:
        tone = colors[method]
        mean_curve = np.array(mean_table[method]['loss_test'], dtype=float)
        sd_curve = np.array(sd_table[method]['loss_test'], dtype=float)
        plot_series_with_band(ax, iters, mean_curve, sd_curve, tone)
        draw_per_task_maxima(ax, mean_curve, interval, n_total, tone, start_task=2)
    tidy_axes(
        ax,
        ylim=ylim,
        ylabel=ylabel,
        xticks=np.arange(0, n_total + 1, interval),
        show_xticklabels=False,
    )
    return ax


ax_loss_L = _add_loss_panel(0, data_L, it_L, interval_L, n_L, (0, 0.25), ylabel='Task 1\nLoss')
ax_loss_M = _add_loss_panel(1, data_M, it_M, interval_M, n_M, (0, 2.0))
ax_loss_R = _add_loss_panel(2, data_R, it_R, interval_R, n_R, (0.5, 6.0))

# ====== Bottom row: neuronal gain of the ENTROPY GAIN (NGM-SGD) run ======
def _add_gain_panel(col, results, iters, interval, n_total, ylabel=None):
    """Add one bottom-row panel of the 'gain_out' series for 'ENTROPY GAIN'.

    Only that method is plotted on this row. The x tick labels and the
    'Iteration' axis label are shown here. Returns the new Axes.
    """
    ax = fig.add_subplot(gs[6:8, col])
    draw_task_lines(ax, interval, n_total)
    gain_mean = np.array(results[0]['ENTROPY GAIN']['gain_out'], dtype=float)
    gain_sd = np.array(results[1]['ENTROPY GAIN']['gain_out'], dtype=float)
    plot_series_with_band(ax, iters, gain_mean, gain_sd, colors['ENTROPY GAIN'])
    tidy_axes(
        ax,
        ylim=(0.8, 3.5),
        ylabel=ylabel,
        xlabel='Iteration',
        xticks=np.arange(0, n_total + 1, interval),
        show_xticklabels=True,
    )
    return ax


ax_gain_L = _add_gain_panel(0, data_L, it_L, interval_L, n_L, ylabel='Neuronal\nGain')
ax_gain_M = _add_gain_panel(1, data_M, it_M, interval_M, n_M)
ax_gain_R = _add_gain_panel(2, data_R, it_R, interval_R, n_R)

# =========================
# Legend
# =========================
# One proxy line per method, in the same order as `methods`.
handles = [plt.Line2D([], [], color=colors[m], linewidth=2.5) for m in methods]
legend_labels = [label_map[m] for m in methods]
fig.legend(
    handles, legend_labels,
    loc='upper center',
    # Placed above the axes; bbox_inches='tight' at save time keeps it in the file.
    bbox_to_anchor=(0.5, 1.10),
    ncol=len(legend_labels),
    frameon=False,
    prop={'size': 16}
)

# The left column carries the y labels of all three rows: align them.
fig.align_ylabels([ax_gain_L, ax_loss_L, ax_acc_L])
# The bottom row carries the 'Iteration' x labels: align them across columns.
# (align_ylabels only groups axes within the same column, so calling it on
# one axis per column, as before, had no effect.)
fig.align_xlabels([ax_gain_L, ax_gain_M, ax_gain_R])

# NOTE: no plt.subplots_adjust(...) here. The figure uses constrained layout,
# which is incompatible with subplots_adjust (current matplotlib warns and
# does not apply it).

# =========================
# Save (before plt.show(), which blocks on interactive backends and after
# which the window/figure may already be closed)
# =========================
filename = 'fig_classIncremental'
# fig.savefig(f'{filename}.svg', format='svg', dpi=600, bbox_inches='tight')
# fig.savefig(f'{filename}.png', dpi=600, bbox_inches='tight')
fig.savefig(f'{filename}.pdf', format='pdf', dpi=600, bbox_inches='tight')

plt.show()
