import os

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from kal.active_strategies import NAME_MAPPINGS_LATEX, KAL_PLUS_DU, \
    KAL_PLUS_DROP_DU, STRATEGIES, color_mappings, NAME_MAPPINGS, KAL_STARS

# Display title used for each dataset panel, in the order the panels are drawn.
dataset_mappings = dict(
    xor="XOR",
    iris="Iris",
    animals="Animals",
    cub200="CUB200",
)

# Y-axis tick range per dataset as (start, stop) for np.arange with step 0.1;
# each stop is set slightly past the last tick so that tick is included.
_tick_bounds = {
    "xor": (0.6, 1.01),
    "iris": (0.7, 1.01),
    "animals": (0.2, 0.71),
    "cub200": (0.1, 0.61),
}
ticks_mappings = {
    name: np.arange(start, stop, step=0.1)
    for name, (start, stop) in _tick_bounds.items()
}

# Dataset keys (a live view of dataset_mappings), iterated to build the panels.
datasets = dataset_mappings.keys()
image_folder = "images"

# %%
sns.set(style="whitegrid", font_scale=1.2,
        rc={'figure.figsize': (10, 10)})

n_cols = 2
fig2, axes2 = plt.subplots(2, n_cols)
# Flatten the grid so panels are filled row by row, one per dataset.
axes2 = np.ravel(axes2)


def _load_results(dataset):
    """Load one dataset's results and prepare them for the line plot.

    Reads ``<dataset>/results.pkl``, drops the strategies that are not shown
    (KAL_PLUS_DU, KAL_PLUS_DROP_DU and every entry of KAL_STARS), renames the
    remaining strategies with NAME_MAPPINGS and adds the plotted columns
    'Points', 'Ours' and 'F1'.

    Args:
        dataset: dataset key; also the folder holding ``results.pkl``.

    Returns:
        (dfs, palette): the prepared DataFrame and a dict mapping each
        displayed strategy name to its colour from ``color_mappings``.

    Raises:
        KeyError: if a kept strategy is missing from NAME_MAPPINGS or
        color_mappings.
    """
    dfs = pd.read_pickle(os.path.join(dataset, "results.pkl"))
    # Number of points used at each recorded iteration.
    dfs['Points'] = [len(used) for used in dfs['Used Idx']]
    # "Our" methods are those whose raw strategy name contains "KAL".
    dfs['Ours'] = ["KAL" in strategy for strategy in dfs['Strategy']]

    dfs = dfs.sort_values(['Strategy', 'Seed', 'Iteration'])
    dfs = dfs.reset_index()

    # Vectorised filter; dropping rows one by one inside iterrows() copied
    # the whole frame on every drop.
    excluded = [KAL_PLUS_DU, KAL_PLUS_DROP_DU] + KAL_STARS
    dfs = dfs[~dfs['Strategy'].isin(excluded)].copy()

    raw_methods = np.unique(dfs['Strategy'])
    # Keyed by display name so colours can never be misaligned with
    # seaborn's hue order (a plain list relied on both being sorted alike).
    palette = {NAME_MAPPINGS[method]: color_mappings[method]
               for method in raw_methods}
    print([(method, color_mappings[method]) for method in raw_methods])

    dfs['Strategy'] = [NAME_MAPPINGS[strategy] for strategy in dfs['Strategy']]
    if "Accuracy" in dfs.columns:
        dfs['Test Accuracy'] = dfs['Accuracy']
    # NOTE(review): plotted and labelled as "F1 score" but taken from the
    # 'Test Accuracy' column (scaled from percent) — confirm the metric.
    dfs['F1'] = dfs['Test Accuracy'] / 100
    return dfs, palette


for idx, (dataset, ax2) in enumerate(zip(datasets, axes2)):
    dfs, palette = _load_results(dataset)

    # NOTE(review): newer seaborn releases deprecate ci=None in favour of
    # errorbar=None — check against the installed version.
    sns.lineplot(data=dfs, x="Points", y="F1",
                 hue="Strategy", style="Ours", size="Ours",
                 ci=None, style_order=[1, 0], palette=palette,
                 size_order=[1, 0], sizes=[4, 2], ax=ax2, legend="brief")
    sns.despine(ax=ax2, left=True, bottom=True)
    # Per-panel legends are hidden; a single figure legend is built later.
    ax2.legend().set_visible(False)
    # Only panels in the first grid column carry the y-axis label.
    if idx % n_cols == 0:
        ax2.set_ylabel("F1 score", fontsize=18)
    else:
        ax2.set_ylabel("")
    ax2.set_yticks(ticks_mappings[dataset])
    ax2.set_xlabel("# Points", fontsize=18)
    ax2.set_title(dataset_mappings[dataset], fontsize=24, pad=10)

#%%

handles, labels = axes2[1].get_legend_handles_labels()

# Only strategies that survive the plotting filter have hue entries; counting
# the excluded ones too would shift the slice into the style/size entries.
excluded = [KAL_PLUS_DU, KAL_PLUS_DROP_DU] + KAL_STARS
n_strategies = len([s for s in STRATEGIES if s not in excluded])

# Assumed brief-legend layout (verify against the seaborn version in use):
# labels[0] is the hue title, labels[1:n_strategies + 1] the strategies, and
# the last two entries the Ours samples in style_order [1, 0].
ours_handle, others_handle = handles[-2], handles[-1]
strategy_slice = slice(1, n_strategies + 1)
for handle, label in zip(handles[strategy_slice], labels[strategy_slice]):
    # Copy the "ours"/"others" line style onto each strategy's legend entry.
    template = ours_handle if "KAL" in label else others_handle
    handle.set_linestyle(template.get_linestyle())
    handle.set_linewidth(template.get_linewidth())
    print(label, "linestyle", handle.get_linestyle(),
          "linewidth", handle.get_linewidth())

lgd = fig2.legend(handles[:n_strategies + 1], labels[:n_strategies + 1],
                  fontsize=18, loc='upper center',
                  bbox_to_anchor=(1.1, 0.8), ncol=1)
fig2.tight_layout()
# The legend is anchored outside the axes grid; list it explicitly so the
# tight bounding box includes it.
fig2.savefig("grouped_curves.pdf", dpi=200, bbox_inches='tight',
             bbox_extra_artists=(lgd,))
plt.show()
