# -*- coding: utf-8 -*-

import pickle
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter


# Global matplotlib styling applied to every figure/artist created below.
# NOTE(review): "Arial" must be installed; otherwise matplotlib falls back to
# another sans-serif font and emits a "font family not found" warning.
# NOTE(review): "figure.dpi" also sets the resolution of the interactive
# window opened by plt.show(), not just of exported files; if only exports
# need 600 dpi, "savefig.dpi" may be the intended key — confirm.
plt.rcParams.update({
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial"],
    "axes.edgecolor": "black",
    "axes.linewidth": 1.0,
    "figure.dpi": 600,
})


# Load the pre-aggregated Split CIFAR-10 results.
# Judging by how it is indexed below, data_C10[0] and data_C10[1] are each a
# dict keyed by method name whose 'acc_test' entry is a per-iteration series;
# they are used as mean (mu) and spread (sd) respectively — TODO confirm
# against the script that produced the pickle.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('./data/Split_cifar10_avg.pkl', 'rb') as f:
    data_C10 = pickle.load(f)


# Methods plotted (keys into data_C10[...]), in drawing order.
methods = ['SGD', 'MSGD', 'ADAM', 'ENTROPY GAIN']
# Curve / band / marker colour per method.
colors = {
    'ADAM': 'mediumseagreen',
    'MSGD': 'royalblue',
    'ENTROPY GAIN': 'darkorange',
    'SGD': 'grey',
}
# Display name per method, used in the "SG_<name>" arrow labels.
label_map = {
    'ADAM': 'Adam',
    'MSGD': 'MSGD',
    'ENTROPY GAIN': 'NGM-SGD',
    'SGD': 'SGD'
}

# Methods that get a stability-gap arrow; their x positions come from the
# `offsets` dict defined further down, not from this list's order.
arrow_methods = ['ENTROPY GAIN', 'ADAM', 'SGD','MSGD'] 

# ----------------------------------
# Config
# ----------------------------------
interval = 400    # iterations per task (task boundaries every `interval`)
num_tasks = 2
# Length of the recorded accuracy series.
n_total = len(data_C10[0][methods[0]]['acc_test'])
# Number of iterations plotted: `num_tasks` tasks of `interval` iterations,
# capped at the recorded length so that `it` always has the same length as
# the sliced curves (mu[:n]) passed to fill_between / plot.
n = min(interval * num_tasks, n_total)
# 1-based iteration index used as the x coordinate of every curve.
it = np.arange(1, n + 1)

# ----------------------------------
# Helpers
# ----------------------------------
def fmt_pct(x, _pos=None):
    """Tick formatter: return *x* rounded to a whole number, with no % sign.

    *_pos* is the tick position supplied by FuncFormatter and is ignored.
    """
    return format(x, ".0f")

def draw_task_lines(ax, interval, n_total, num_tasks):
    """Draw a dashed vertical line at the end of each task.

    Lines go at x = interval, 2*interval, ... up to and including
    min(n_total, num_tasks * interval).
    """
    line_style = dict(linestyle='--', color='slategray',
                      linewidth=1.0, alpha=0.7)
    last_boundary = min(n_total, num_tasks * interval)
    for boundary in range(interval, last_boundary + 1, interval):
        ax.axvline(boundary, **line_style)

def draw_per_task_minima(ax, mu, interval, n_total, color, *, task=1):
    """Mark the minimum of *mu* inside one task window.

    Draws a dashed horizontal line at the window minimum, spanning the whole
    window, plus a dot at the iteration where the minimum occurs. x values are
    1-based iterations: array index i is plotted at x = i + 1.

    Parameters
    ----------
    ax : Axes-like object providing ``hlines`` and ``plot``.
    mu : 1-D array-like of per-iteration values.
    interval : iterations per task.
    n_total : number of plotted iterations; the window is clipped to it.
    color : colour of the line and marker.
    task : 0-based task index. The default (1) selects the second task, the
        only window the helper originally supported.

    Returns
    -------
    float
        The minimum of the window, ignoring NaNs.

    Raises
    ------
    ValueError
        If the window is empty (n_total <= task * interval) or all-NaN.
    """
    start = task * interval
    end = min((task + 1) * interval, n_total)
    seg = np.asarray(mu, dtype=float)[start:end]
    if seg.size == 0:
        raise ValueError(
            f"task {task} window [{start}, {end}) is empty for n_total={n_total}"
        )
    # nanargmin skips NaNs (still raises ValueError if every value is NaN).
    i_local = int(np.nanargmin(seg))
    y_min = float(seg[i_local])
    ax.hlines(y_min, start + 1, end,
              linestyles=(0, (3, 2)), linewidth=1.5,
              color=color, alpha=0.9)
    ax.plot(start + i_local + 1, y_min,
            'o', color=color, markersize=4,
            markeredgecolor='white', markeredgewidth=0.8)
    return y_min

# ----------------------------------
import os

# Rasterize the first page of stab_gap.pdf to stab_gap.png at 600 dpi, once.
# Later runs reuse the existing PNG.
# NOTE: the PNG is not regenerated when stab_gap.pdf changes; delete it to
# force a refresh.
if os.path.exists("stab_gap.png"):
    print("✓ PNG 600 dpi ya existe")
else:
    # pdf2image is only needed when the PNG has to be produced.
    from pdf2image import convert_from_path
    pages = convert_from_path("stab_gap.pdf", dpi=600)
    pages[0].save("stab_gap.png", "PNG")
    print("✓ Rasterizado 600 dpi → stab_gap.png")


# ----------------------------------
# Two panels side by side: the rasterized illustration on the left and the
# accuracy curves on the right (1.25x wider).
fig, (ax_img, ax) = plt.subplots(
    1, 2,
    figsize=(7, 3),
    gridspec_kw={'width_ratios': [1, 1.25]}
)


import matplotlib.image as mpimg

# Left panel: show the PNG produced above, with axes hidden.
stab_img = mpimg.imread("stab_gap.png")
ax_img.imshow(stab_img)
# Uncommenting the next line would stretch the image to fill the panel
# instead of keeping imshow's default aspect handling.
# ax_img.set_aspect('auto')
ax_img.axis("off")


# Per-method values used by the stability-gap arrows:
#   y_max_T1[m] -> highest mean accuracy in the first task window [0, interval)
#   y_min_T2[m] -> lowest mean accuracy in the second task window
y_max_T1 = {}
y_min_T2 = {}

for m in methods:
    # Old-task test accuracy, truncated to the plotted range `n`.
    # Assumes data_C10[0] / data_C10[1] hold mean / std — TODO confirm.
    mu = np.array(data_C10[0][m]['acc_test'], dtype=float)[:n]
    sd = np.array(data_C10[1][m]['acc_test'], dtype=float)[:n]

    # Shaded mu ± sd band first, then the mean curve on top of it.
    ax.fill_between(it, mu - sd, mu + sd, color=colors[m], alpha=0.15)
    ax.plot(it, mu, color=colors[m], linewidth=1.0)

    y_max_T1[m] = float(np.nanmax(mu[:interval]))
    # Also draws the dashed minimum line and marker for the second task.
    y_min_T2[m] = draw_per_task_minima(ax, mu, interval, n, colors[m])

draw_task_lines(ax, interval, n, num_tasks)


ax.set_ylabel("% Old Task Accuracy", fontsize=16)
ax.set_xlabel("Iteration", fontsize=16)
ax.set_ylim(40, 100)
ax.yaxis.set_major_formatter(FuncFormatter(fmt_pct))
# One tick per task boundary (0, interval, 2*interval, ...), labelled with its
# own value so the labels stay correct if `interval` or `num_tasks` change.
xticks = [k * interval for k in range(num_tasks + 1)]
ax.set_xticks(xticks)
ax.set_xticklabels([str(t) for t in xticks])

# Task names placed just below the bottom of the y-range (ylim starts at 40),
# centred on each task window.
ax.text(interval/2, 39, "OLD Task", ha='center', va='top', fontsize=13, fontweight='bold')
ax.text(interval + interval/2, 39, "NEW Task", ha='center', va='top', fontsize=13, fontweight='bold')

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.tick_params(axis='y', length=0)


# Start the view at 0 rather than 1: matplotlib does not draw ticks outside
# the view limits, so a left limit of 1 hid the "0" tick set above.
# The extra 200 units on the right leave room for the stability-gap arrows.
ax.set_xlim(0, n + 200)

# x offset, beyond the last plotted iteration n, of each method's
# stability-gap arrow.
offsets = {
    'ENTROPY GAIN': 10,
    'ADAM':         73,
    'SGD':          136,
    'MSGD':         200,
}

for m in arrow_methods:
    x_arrow = n + offsets[m]
    x_text  = x_arrow + 7  # label sits just right of its arrow

    # Double-headed arrow from the first-task peak (y_max_T1) down to the
    # second-task minimum (y_min_T2): the method's "stability gap".
    # The empty text makes ha/va irrelevant here.
    ax.annotate(
        '',
        xy=(x_arrow, y_min_T2[m]),
        xytext=(x_arrow, y_max_T1[m]),
        arrowprops=dict(arrowstyle='<->', color=colors[m], lw=1.3),
        ha='center', va='center'
    )

    # Vertical "SG_<name>" label, centred between the arrow's two ends.
    # NOTE(review): the subscript is rendered with mathtext, so names are
    # italicised and the '-' in "NGM-SGD" renders as a minus sign.
    ax.text(
        x_text,
        (y_min_T2[m] + y_max_T1[m]) / 2,
        f"SG$_{{{label_map[m]}}}$",
        color=colors[m],
        fontsize=10,
        rotation=90,
        va='center',
        ha='left'
    )


plt.tight_layout()


# Save: set SAVE_FIGURE = True to export the figure.
# Saving happens before plt.show(), as the pyplot documentation recommends:
# a blocking show() closes the figure once its window is dismissed.
SAVE_FIGURE = False
if SAVE_FIGURE:
    fig.savefig("fig_stability_gap_explainer.svg", format="svg", dpi=600, bbox_inches="tight")
    fig.savefig("fig_stability_gap_explainer.png", dpi=600, bbox_inches='tight')
    fig.savefig("fig_stability_gap_explainer.pdf", format='pdf', dpi=600, bbox_inches='tight')

plt.show()
