import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from kal.active_strategies import color_mappings, KALS

# Strategy identifiers, exactly as they appear in the result file names.
CAL = "constrained"
SUPERVISED = "supervised"
RANDOM = "random"

# Iteration order of the strategies when loading results.
strategies = [CAL, SUPERVISED, RANDOM]

# Display name for each strategy; these names are what ends up in the
# "Strategy" column and are used as keys into color_mappings when plotting.
strategies_mappings = dict(zip(strategies, ["KAL", "SupLoss", "Random"]))

# Experiment configuration: directory holding the metric pickles, the plot
# title, the point counts evaluated at each iteration, and the seeds.
result_folder = "dogvsperson"
dataset = "Dog vs Person"
points = list(range(1000, 3001, 500))
seeds = range(3)

# Column-oriented accumulator: one (initially empty) list per column, filled
# with one entry per (seed, iteration, strategy) row and turned into a
# DataFrame further down.
dfs = {column: [] for column in (
    "Dataset", "Seed", "Iter", "Points", "mAP", "Ours", "Strategy",
)}

# Read the stored test accuracy of every (seed, iteration, strategy) run and
# append one row per run to the column accumulator.
for seed in seeds:
    for iteration, used_points in enumerate(points):
        for strategy in strategies:
            dfs["Dataset"].append(dataset)
            dfs["Seed"].append(seed)
            dfs["Iter"].append(iteration)
            dfs["Points"].append(used_points)
            dfs["Strategy"].append(strategies_mappings[strategy])
            dfs["Ours"].append(strategy in KALS)

            # At iteration 0 every strategy reads the CAL ("constrained")
            # file, so all curves share the same first measurement.
            source_strategy = strategy if iteration != 0 else CAL
            file_name = (f"metrics_{source_strategy}_strategy_{seed}_seed_"
                         f"{used_points}_points.pkl")
            metrics = pd.read_pickle(os.path.join(result_folder, file_name))

            # Stored under "mAP" but taken from the 'test accuracy' entry,
            # scaled to a percentage.
            dfs['mAP'].append(metrics['test accuracy'].item() * 100)

# Assemble the long-format results table. Sorting by display name makes the
# first-appearance order of strategies alphabetical.
dfs = pd.DataFrame(dfs).sort_values("Strategy")

sns.set(style="whitegrid", font_scale=2.5,
        rc={'figure.figsize': (10, 10)})
# Explicit display-name -> colour mapping. A dict palette binds each colour
# to its strategy by name. The previous positional list was only correct
# because the hue order happened to coincide with np.unique's sorted order.
palette = {method: color_mappings[method] for method in np.unique(dfs['Strategy'])}

# One coloured line per strategy. Rows flagged Ours get the first style and
# the thicker width (style_order/size_order=[1, 0], sizes=[4, 2]).
ax2 = sns.lineplot(data=dfs, x="Points", y="mAP",
                   hue="Strategy", style="Ours", size="Ours",
                   style_order=[1, 0], palette=palette,
                   size_order=[1, 0], sizes=[4, 2], legend="brief")
sns.despine(left=True, bottom=True)
# Hide the axes-level legend; a trimmed figure-level legend is drawn below.
ax2.legend().set_visible(False)
ax2.set_xlabel("# Points")
ax2.set_title(dataset, fontsize=32, pad=10)

handles, labels = ax2.get_legend_handles_labels()

# NOTE(review): assumes seaborn's "brief" legend is laid out as
#   [Strategy title, <one entry per strategy>, ..., Ours=1 entry, Ours=0 entry]
# (style_order=[1, 0]); verify against the installed seaborn version.
n_strategies = len(strategies)
ours_template, others_template = handles[-2], handles[-1]
# Display names whose rows are flagged Ours -- the same criterion that drives
# the line style in the plot itself (instead of a hard-coded "KAL" substring).
ours_labels = set(dfs.loc[dfs["Ours"], "Strategy"])
for handle, label in zip(handles[1:n_strategies + 1], labels[1:n_strategies + 1]):
    # Copy the dash style/width of the matching Ours entry onto the coloured
    # strategy entry, so a single legend row shows both colour and style.
    template = ours_template if label in ours_labels else others_template
    handle.set_linestyle(template.get_linestyle())
    handle.set_linewidth(template.get_linewidth())
    print(label, "linestyle", handle.get_linestyle(),
          "linewidth", handle.get_linewidth())

fig2 = ax2.get_figure()
# Keep only the title entry and the per-strategy entries.
fig2.legend(handles[:n_strategies + 1], labels[:n_strategies + 1],
            bbox_to_anchor=(0.5, 0.9), ncol=1)
plt.tight_layout()
plt.savefig("Obj_rec_curves.pdf", dpi=200, bbox_inches='tight')
plt.show()

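# NOTE(review): disabled summary-table code, stale as written. It reads
# df_mean['SupLoss'], df_mean['Random'] and df_mean['KAL'], but `dfs` is in
# long format (a single "mAP" column, strategy name in the "Strategy" column),
# so those columns do not exist. It would need e.g. a pivot on "Strategy"
# (and numeric-only aggregation) before these lookups can work.
# df_mean = dfs.groupby("Dataset").mean()
# df_std = dfs.groupby(["Dataset", "Seed"]).mean().groupby("Dataset").std()
#
# df_final = pd.DataFrame({
#     "Dataset": [dataset],
#     "SupLoss": [f"${df_mean['SupLoss'].item():.2f}$ {{\\tiny $\\pm {df_std['SupLoss'].item():.2f}$ }}"],
#     "Random": [f"${df_mean['Random'].item():.2f}$ {{\\tiny $\\pm {df_std['Random'].item():.2f}$ }}"],
#     "KAL": [f"${df_mean['KAL'].item():.2f}$ {{\\tiny $\\pm {df_std['KAL'].item():.2f}$ }}"]
# })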
