import os
import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# ---------- Hyperparameters ----------
# Constants of the gain recursion that recover_entropy_from_gain_trace inverts:
#   g[t+1] = gamma * g[t] + (1 - gamma) * g0 + eta * H[t]
gamma = 0.9   # weight on the previous gain value
g0 = 1.0      # baseline gain; also the starting value of the EMA curves

# Per-dataset scale (eta) of the entropy term in the recursion above.
etas = {
    "splitMNIST":        0.4,
    "splitCIFAR10":      0.2,
    "splitMiniImageNet": 0.1,
    "rotatedMNIST":      0.5,
    "domainCIFAR100":    0.1,
}

# ---------- Paths ----------
# Built from path components instead of the Windows-only literal ".\\data",
# which on POSIX systems names a single directory called ".\data". On Windows
# the resulting string is unchanged.
base_dir = os.path.join(os.curdir, "data")

# Per-dataset pickle file read by the plotting code below.
paths = {
    "splitMNIST":        os.path.join(base_dir, "splitMNIST_avg.pkl"),
    "splitCIFAR10":      os.path.join(base_dir, "Split_cifar10_avg.pkl"),
    "splitMiniImageNet": os.path.join(base_dir, "miniImageNet_avg.pkl"),
    "rotatedMNIST":      os.path.join(base_dir, "rotatedMNIST_avg.pkl"),
    "domainCIFAR100":    os.path.join(base_dir, "domainCIFAR100_800_avg.pkl"),
}

def recover_entropy_from_gain_trace(gain_trace, gamma, g0, eta):
    """Invert the gain recursion to recover the per-step entropy term.

    Solves ``g[t+1] = gamma * g[t] + (1 - gamma) * g0 + eta * H[t]`` for
    ``H[t]`` using every pair of consecutive gain samples.

    Args:
        gain_trace: Sequence of gain values; converted to a float array.
        gamma: Weight applied to the previous gain value.
        g0: Baseline gain.
        eta: Scale of the entropy term (the result is divided by it).

    Returns:
        Float array with ``len(gain_trace) - 1`` entries (empty when fewer
        than two samples are given).
    """
    gains = np.asarray(gain_trace, dtype=float)
    previous, current = gains[:-1], gains[1:]
    baseline = (1 - gamma) * g0
    return (current - gamma * previous - baseline) / eta

def entropy_std_from_gain_std(gain_std_trace, gamma, eta):
    """Propagate gain standard deviations to the recovered entropy values.

    Each output combines two consecutive gain samples: their variances are
    added (the earlier one scaled by ``gamma**2``) and divided by ``eta**2``.
    No covariance term is included, i.e. the two samples are treated as
    uncorrelated.

    Args:
        gain_std_trace: Sequence of gain standard deviations; converted to a
            float array.
        gamma: Weight applied to the previous gain value.
        eta: Scale of the entropy term.

    Returns:
        Float array with ``len(gain_std_trace) - 1`` standard deviations.
    """
    stds = np.asarray(gain_std_trace, dtype=float)
    var_current = np.square(stds[1:])
    var_previous = (gamma ** 2) * np.square(stds[:-1])
    return np.sqrt((var_current + var_previous) / (eta ** 2))

# One row per dataset: (key, panel title, number of tasks, scatter colour).
_dataset_rows = [
    ("splitMNIST",        "Split MNIST",         5, "crimson"),
    ("splitCIFAR10",      "Split CIFAR-10",      5, "sienna"),
    ("splitMiniImageNet", "Split mini-ImageNet", 5, "teal"),
    ("rotatedMNIST",      "Rotated MNIST",       3, "orangered"),
    ("domainCIFAR100",    "Domain CIFAR-100",    3, "lightsalmon"),
]

# Human-readable titles used for panel headings and legend entries.
display_names = {key: title for key, title, _, _ in _dataset_rows}

# Number of tasks per dataset; the gain trace length is divided by this value
# to isolate the first segment that is plotted.
num_tasks_map = {key: n_tasks for key, _, n_tasks, _ in _dataset_rows}

# Datasets shown in the top row of gain panels.
split_keys = ["splitMNIST", "splitCIFAR10", "splitMiniImageNet"]

# Datasets shown in the bottom row of gain panels.
rotated_keys = ["rotatedMNIST", "domainCIFAR100"]

# EMA smoothing factors and the line colour paired with each of them.
alphas = [0.9, 0.5, 0.1]
ema_colors = ["palegreen", "mediumseagreen", "darkgreen"]

# Marker colour of each dataset in the entropy scatter plot.
entropy_colors = {key: colour for key, _, _, colour in _dataset_rows}

# Figure layout: a 2x4 grid whose first column is twice as wide as the rest.
fig = plt.figure(figsize=(16, 7))
gs = fig.add_gridspec(2, 4, width_ratios=[2, 1, 1, 1], wspace=0.5, hspace=0.6)

# Left column: legend-only panel on top, entropy scatter underneath.
ax_entropy_leg = fig.add_subplot(gs[0, 0])
ax_entropy = fig.add_subplot(gs[1, 0])

# Top row, columns 1-3: one gain panel per split dataset.
ax_split_mnist, ax_split_c10, ax_split_mini = (
    fig.add_subplot(gs[0, col]) for col in (1, 2, 3)
)

# Bottom row, columns 1-3: two gain panels, then the panel that only hosts
# the shared legend.
ax_rot_mnist, ax_domain_c100, ax_leg_bottom = (
    fig.add_subplot(gs[1, col]) for col in (1, 2, 3)
)

# Dataset key -> axes on which its gain trace is drawn.
axes_map = dict(
    splitMNIST=ax_split_mnist,
    splitCIFAR10=ax_split_c10,
    splitMiniImageNet=ax_split_mini,
    rotatedMNIST=ax_rot_mnist,
    domainCIFAR100=ax_domain_c100,
)


# Line handles taken from the first panel drawn; reused for the shared legend.
gain_line_handle = None
ema_handles = []

for name in split_keys + rotated_keys:
    panel = axes_map[name]

    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(paths[name], "rb") as fh:
        loaded = pkl.load(fh)

    # First entry of the pickled object (presumably the averaged results).
    mean_part = loaded[0]
    full_gain = np.array(mean_part["ENTROPY GAIN"]["gain_out"], dtype=float)

    # Keep the first len // num_tasks samples, i.e. the first task's share of
    # the trace if all tasks have equal length (the "(T1)" in the title).
    first_task_len = len(full_gain) // num_tasks_map[name]
    task1_gain = full_gain[:first_task_len]

    entropy = recover_entropy_from_gain_trace(task1_gain, gamma, g0, etas[name])
    steps = np.arange(len(task1_gain))

    (gain_line,) = panel.plot(
        steps,
        task1_gain,
        linewidth=4,
        color="darkorange",
        label="NGM-SGD",
    )
    if gain_line_handle is None:
        gain_line_handle = gain_line

    panel_ema_lines = []
    for alpha, shade in zip(alphas, ema_colors):
        # Exponential moving average toward (g0 + entropy), started at g0 and
        # driven by the entropy recovered from the plotted gain trace.
        ema_trace = np.zeros_like(task1_gain)
        ema_trace[0] = g0
        for step, h in enumerate(entropy, start=1):
            ema_trace[step] = (1 - alpha) * ema_trace[step - 1] + alpha * (g0 + h)

        (ema_line,) = panel.plot(
            steps,
            ema_trace,
            linestyle="--",
            linewidth=2.5,
            color=shade,
            label=f"EMA+b α={alpha}",
        )
        panel_ema_lines.append(ema_line)

    if not ema_handles:
        ema_handles = panel_ema_lines

    panel.set_title(display_names[name] + " (T1)", fontsize=20)
    panel.set_xlabel("Iterations", fontsize=20)
    # Only the left-most panel of each row carries the y-axis label.
    if name in ("splitMNIST", "rotatedMNIST"):
        panel.set_ylabel("Neuronal gain", fontsize=20)

    panel.tick_params(axis="both", labelsize=18)
    panel.spines["top"].set_visible(False)
    panel.spines["right"].set_visible(False)
    panel.grid(False)

# Shared legend for all gain panels, drawn on an otherwise empty axes.
ax_leg_bottom.axis("off")
handles_for_legend = [gain_line_handle, *ema_handles]
labels_for_legend = ["NGM-SGD", *(rf"EMA+b $\alpha={a}$" for a in alphas)]
ax_leg_bottom.legend(
    handles_for_legend,
    labels_for_legend,
    frameon=False,
    fontsize=20,
    loc="center",
)

# Per-dataset mean of the recovered entropy and the uncertainty of that mean,
# computed over the whole gain trace (no per-task slicing here).
H_means = {}
H_stds = {}

for name, path in paths.items():
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(path, "rb") as fh:
        loaded = pkl.load(fh)

    # Entries 0 and 1 of the pickled object supply the gain trace and its
    # standard deviation, respectively.
    mean_part = loaded[0]
    std_part = loaded[1]
    mean_gain = np.array(mean_part["ENTROPY GAIN"]["gain_out"], dtype=float)
    std_gain = np.array(std_part["ENTROPY GAIN"]["gain_out"], dtype=float)

    scale = etas[name]
    entropy = recover_entropy_from_gain_trace(mean_gain, gamma, g0, scale)
    entropy_sigma = entropy_std_from_gain_std(std_gain, gamma, scale)

    n_steps = len(entropy)
    H_means[name] = entropy.mean()
    # Std of the mean: per-step variances summed and divided by n^2, which
    # treats the per-step values as independent.
    H_stds[name] = np.sqrt(np.sum(np.square(entropy_sigma)) / (n_steps ** 2))

# Scatter of mean entropy (x) against eta (y): one hollow marker per dataset
# with a horizontal error bar.
for name in split_keys + rotated_keys:
    colour = entropy_colors[name]
    ax_entropy.errorbar(
        H_means[name],
        etas[name],
        xerr=H_stds[name],
        fmt="o",
        markersize=10,
        markerfacecolor="none",
        markeredgewidth=2,
        markeredgecolor=colour,
        ecolor=colour,
        capsize=4,
        elinewidth=1.5,
    )

ax_entropy.set_xlabel("Average entropy  $\\overline{H}$", fontsize=20)
ax_entropy.set_ylabel("$\\eta$ (gain scale)", fontsize=20)

ax_entropy.spines["top"].set_visible(False)
ax_entropy.spines["right"].set_visible(False)
ax_entropy.tick_params(axis="both", labelsize=20)

# Dataset legend, drawn on the empty axes above the scatter plot.
# NOTE(review): these proxy markers are filled, whereas the plotted markers
# are hollow - confirm the mismatch is intended.
ax_entropy_leg.axis("off")
legend_handles = [
    Line2D(
        [0],
        [0],
        marker="o",
        color="none",
        markerfacecolor=entropy_colors[name],
        markeredgecolor="none",
        markersize=10,
        linestyle="None",
        label=display_names[name],
    )
    for name in split_keys + rotated_keys
]

ax_entropy_leg.legend(
    handles=legend_handles,
    loc="lower right",
    frameon=False,
    fontsize=20,
)

# Bold panel letters. Positions are in axes-fraction coordinates (transAxes),
# so values outside [0, 1] place the text outside the respective axes.
panel_label_style = dict(fontsize=36, fontweight="bold")

ax_entropy.text(
    -0.3, 2.785, "A)",
    transform=ax_entropy.transAxes,
    **panel_label_style,
)
ax_split_mnist.text(
    -0.45, 1.2, "B)",
    transform=ax_split_mnist.transAxes,
    **panel_label_style,
)

plt.show()

# Save: base name of the exported figure. The export itself is disabled;
# uncomment the line below to write the PDF.
filename = "app_gainEMA_H"
# fig.savefig(f'{filename}.pdf', format='pdf', dpi=600, bbox_inches='tight')