# -*- coding: utf-8 -*-
"""
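Plot the mean neuronal gain per task, with a linear trend line, for the
class-incremental and domain-incremental benchmarks.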
"""

# -*- coding: utf-8 -*-
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.ticker import MaxNLocator

# Global matplotlib styling applied to every axes created below.
# The preferred face must be registered under the 'sans-serif' rc key, because
# that is the list consulted when family='sans-serif'. The key is not a valid
# Python identifier, hence the dict unpacking. Passing serif=[...] instead
# would only alter the serif list and leave the active font unchanged.
# 'DejaVu Sans' (bundled with matplotlib) is kept as a fallback for machines
# without Arial.
plt.rc('font', **{'family': 'sans-serif', 'sans-serif': ['Arial', 'DejaVu Sans']})
plt.rc('axes', edgecolor='black', linewidth=3)
plt.rc('xtick', color='black')
plt.rc('ytick', color='black')
plt.rc('grid', color='black')

# Method labels as they appear in the loaded results -> labels used in figures.
RENAME = {"ENTROPY GAIN": "NGM-SGD", "ADAM": "Adam"}


def rename_methods_in_results(avg_res: dict, std_res: dict):
    """Rename method keys in place according to ``RENAME``.

    Both dictionaries are processed independently. A key is renamed only when
    the old label is present and the new label is not, so an entry already
    stored under the new label is never overwritten.

    Args:
        avg_res: Mapping of method name to averaged results; mutated in place.
        std_res: Mapping of method name to std results; mutated in place.
            May be ``None`` or empty, in which case it is skipped (the same
            "std is optional" convention that ``get_gain_series`` follows).

    Returns:
        None. The inputs are modified in place.
    """
    for results in (avg_res, std_res):
        # Nothing to rename in a missing or empty mapping.
        if not results:
            continue
        for old, new in RENAME.items():
            if old in results and new not in results:
                results[new] = results.pop(old)

def pick_method_key(d: dict, candidates):
    """Return the key of ``d`` that matches the first usable candidate name.

    Candidates are tried in order. For each one, an exact key match wins
    outright; otherwise the candidate is compared case-insensitively against
    the string keys of ``d``. Non-string keys are ignored for the
    case-insensitive comparison instead of raising on ``.lower()``.

    Args:
        d: Mapping whose keys are searched.
        candidates: Iterable of candidate names (strings), in priority order.

    Returns:
        The matching key as stored in ``d``.

    Raises:
        KeyError: If none of the candidates matches a key of ``d``.
    """
    lowered = {k.lower(): k for k in d if isinstance(k, str)}
    for name in candidates:
        # Prefer the exact spelling: when two keys differ only by case, the
        # lowercase map can hold just one of them and might pick the wrong one.
        if name in d:
            return name
        match = lowered.get(name.lower())
        if match is not None:
            return match
    raise KeyError(f"None of {candidates} found. Available: {list(d.keys())}")

def get_gain_series(avg_res: dict, std_res: dict, method_name: str):
    """Fetch the ``"gain_out"`` series of one method as float arrays.

    The method is looked up in ``avg_res`` under ``method_name``; if that exact
    key is absent, ``pick_method_key`` is asked to resolve it against
    ``method_name``, ``"ENTROPY GAIN"`` and ``"Entropy Gain"``.

    Returns:
        ``(avg_series, std_series)`` where ``std_series`` is ``None`` when
        ``std_res`` is falsy, lacks the resolved key, or that entry has no
        ``"gain_out"`` field.
    """
    key = (
        method_name
        if method_name in avg_res
        else pick_method_key(avg_res, [method_name, "ENTROPY GAIN", "Entropy Gain"])
    )
    mean_curve = np.asarray(avg_res[key]["gain_out"], dtype=float)

    # The std results are optional at every level; bail out early if missing.
    if not std_res or key not in std_res or "gain_out" not in std_res[key]:
        return mean_curve, None
    return mean_curve, np.asarray(std_res[key]["gain_out"], dtype=float)

# Axis ranges for the class-incremental and domain-incremental panels.
# Both panels currently share the same limits.
xlim_class = xlim_domain = (0.9, 3.1)
ylim_class = ylim_domain = (1.2, 4.0)

def _load_avg_std(path):
    """Unpickle one results file and return its content.

    Callers unpack the returned object into an ``(avg, std)`` pair.
    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# One results file per benchmark, each unpacked into (avg, std).
avg_splitMNIST, std_splitMNIST = _load_avg_std('./data/splitMNIST_avg.pkl')
avg_splitCIFAR10, std_splitCIFAR10 = _load_avg_std('./data/Split_cifar10_avg.pkl')
avg_splitMiniImageNet, std_splitMiniImageNet = _load_avg_std('./data/miniImageNet_avg.pkl')
avg_rotatedMNIST, std_rotatedMNIST = _load_avg_std('./data/rotatedMNIST_avg.pkl')
avg_domainCIFAR100, std_domainCIFAR100 = _load_avg_std('./data/domainCIFAR100_800_avg.pkl')

# Switch every loaded result set to the figure's method labels (in place).
for _avg, _std in (
    (avg_splitMNIST, std_splitMNIST),
    (avg_splitCIFAR10, std_splitCIFAR10),
    (avg_splitMiniImageNet, std_splitMiniImageNet),
    (avg_rotatedMNIST, std_rotatedMNIST),
    (avg_domainCIFAR100, std_domainCIFAR100),
):
    rename_methods_in_results(_avg, _std)

# Value stored as 'iters_per_task' for each experiment.
ctx_iters = dict(
    splitMNIST=200,
    splitCIFAR10=400,
    splitminiImageNet=200,
    rotatedMNIST=400,
    domainCIFAR100=800,
)


def _dataset_entry(avg, std, tasks, iters_key, kind):
    """Assemble one ``datasets`` record; ``iters_key`` indexes ``ctx_iters``."""
    return {
        'avg': avg,
        'std': std,
        'tasks': tasks,
        'iters_per_task': ctx_iters[iters_key],
        'type': kind,
    }


# Per-benchmark configuration. Insertion order matters: the three
# class-incremental entries come first, then the two domain-incremental ones.
datasets = {
    'splitMNIST': _dataset_entry(
        avg_splitMNIST, std_splitMNIST, 5, 'splitMNIST', 'class-incremental'),
    'splitCIFAR10': _dataset_entry(
        avg_splitCIFAR10, std_splitCIFAR10, 5, 'splitCIFAR10', 'class-incremental'),
    'splitminiImageNet': _dataset_entry(
        avg_splitMiniImageNet, std_splitMiniImageNet, 5, 'splitminiImageNet', 'class-incremental'),
    # Note: the plot key is 'rotMNIST' while the ctx_iters key is 'rotatedMNIST'.
    'rotMNIST': _dataset_entry(
        avg_rotatedMNIST, std_rotatedMNIST, 3, 'rotatedMNIST', 'domain-incremental'),
    'domainCIFAR100': _dataset_entry(
        avg_domainCIFAR100, std_domainCIFAR100, 3, 'domainCIFAR100', 'domain-incremental'),
}

# The single method whose gain series is extracted for every dataset.
method_preferred = 'NGM-SGD'

# (dataset key, marker/line colour, legend label) - one row per dataset.
_DATASET_STYLES = (
    ('splitMNIST',        'crimson',     'Split MNIST'),
    ('splitCIFAR10',      'sienna',      'Split CIFAR-10'),
    ('splitminiImageNet', 'teal',        'Split mini-ImageNet'),
    ('rotMNIST',          'orangered',   'Rotated MNIST'),
    ('domainCIFAR100',    'lightsalmon', 'Domain CIFAR-100'),
)

colors = {key: color for key, color, _label in _DATASET_STYLES}

marker = 'o'
display_names = {key: label for key, _color, label in _DATASET_STYLES}

# Per-dataset summary consumed by the plotting code:
#   fits[name] = (slope, intercept, mean_per_task, std_per_task, n_tasks)
fits = {}
for name, cfg in datasets.items():
    avg_res, std_res = cfg['avg'], cfg['std']
    avg_gain, std_gain = get_gain_series(avg_res, std_res, method_preferred)
    mts, its = cfg['tasks'], cfg['iters_per_task']
    if avg_gain.size != mts * its:
        raise ValueError(f"{name}: gain length={avg_gain.size} != tasks*iters {mts*its}")
    # One row per task (assumes the flat series is ordered task by task).
    avg_arr = avg_gain.reshape(mts, its)
    # A std series that is missing or has the wrong length is dropped; the
    # fallback further down is used instead.
    std_arr = std_gain.reshape(mts, its) if std_gain is not None and std_gain.size == avg_gain.size else None
    if cfg['type'] == 'class-incremental':
        # Class-incremental runs are summarised on at most their first 3 tasks.
        avg_arr = avg_arr[:3, :]
        std_arr = std_arr[:3, :] if std_arr is not None else None
    # Take the task count from the rows actually kept, so x always has the
    # same length as mean_per_task even when a dataset has fewer than 3 tasks.
    n_tasks = avg_arr.shape[0]
    mean_per_task = avg_arr.mean(axis=1)
    if std_arr is not None:
        std_per_task = std_arr.mean(axis=1)
    else:
        # No usable std series: use the spread of the mean curve within each task.
        std_per_task = avg_arr.std(axis=1)
    x = np.arange(1, n_tasks + 1)
    # Degree-1 least-squares fit of mean gain against the 1-based task index.
    m, b = np.polyfit(x, mean_per_task, 1)
    fits[name] = (m, b, mean_per_task, std_per_task, n_tasks)

# Two side-by-side panels: class-incremental on ax_c, domain-incremental on ax_d.
fig, (ax_c, ax_d) = plt.subplots(1, 2, figsize=(10, 5), sharey=False)

# Per dataset: error-bar markers for the per-task means plus the dashed fit line.
for name, cfg in datasets.items():
    slope, intercept, task_means, task_stds, n_tasks = fits[name]
    task_idx = np.arange(1, n_tasks + 1)
    panel = ax_d if cfg['type'] != 'class-incremental' else ax_c
    panel.errorbar(
        task_idx, task_means, yerr=task_stds,
        fmt=marker, ms=10, color=colors[name],
        capsize=4, linestyle='none', label='_nolegend_'
    )
    panel.plot(task_idx, slope * task_idx + intercept, linestyle='--', color=colors[name])

# Axis ranges, ticks, labels and cosmetic settings, applied panel by panel.
panel_setup = (
    (ax_c, xlim_class, ylim_class, 'Class-Incremental Tasks'),
    (ax_d, xlim_domain, ylim_domain, 'Domain-Incremental Tasks'),
)
for ax, x_range, y_range, title in panel_setup:
    ax.set_xlim(*x_range)
    ax.set_ylim(*y_range)
    ax.set_xticks([1, 2, 3])
    ax.set_title(title, fontsize=18)
    ax.set_xlabel('Task Index', fontsize=20)
    ax.set_ylabel('Neuronal Gain', fontsize=20)
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.tick_params(axis='both', which='both', labelcolor='black')
    ax.yaxis.set_major_locator(MaxNLocator(nbins=4, prune='both'))

# Marker-only proxy artists, one per dataset, in the order of `datasets`.
handles = []
for key in datasets:
    handles.append(
        Line2D([], [], marker=marker, linestyle='none', color=colors[key], label=display_names[key])
    )

# Two stacked legend rows above the panels: first three entries, then the rest.
upper_row, lower_row = handles[:3], handles[3:]
fig.legend(handles=upper_row, loc='upper center', ncol=3, frameon=False, fontsize=16, bbox_to_anchor=(0.5, 1.09))
fig.legend(handles=lower_row, loc='upper center', ncol=2, frameon=False, fontsize=16, bbox_to_anchor=(0.5, 1.03))

# Leave the top 6% of the figure free for the legends.
plt.tight_layout(rect=[0, 0, 1, 0.94])
plt.savefig('fig_extra_gtaskComplexity.pdf', format='pdf', dpi=600, bbox_inches='tight')
plt.show()

