import os

import pandas as pd
import torch
from data.Animals import classes as anim_classes
from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler
from torchvision.transforms import transforms
from tqdm import tqdm

from data.Cub200 import CUBDataset
from kal.active_strategies import KAL_PLUS_DU, \
    KAL_PLUS_DROP_DU, NAME_MAPPINGS
from kal.knowledge import XORLoss, IrisLoss, AnimalLoss, CUB200Loss
from kal.utils import set_seed

# Fix the global seed before any random sampling below.
# (presumably set_seed seeds torch, making xor_data reproducible — confirm)
set_seed(0)
dev = torch.device("cpu")

# Location of the CUB200 data, relative to the current working directory.
data_folder = os.path.join("..", "..", "data", "CUB200")

# XOR inputs: xor_points samples drawn uniformly from [0, 1) in each of
# xor_input_size dimensions, moved to `dev`.
xor_points = 100000
xor_input_size = 2
xor_data = torch.rand(xor_points, xor_input_size).to(dev)

# Iris inputs: the sklearn Iris features, min-max scaled per column and
# wrapped as a torch tensor.
iris_data = torch.as_tensor(MinMaxScaler().fit_transform(load_iris().data))

# Raw input tensors, keyed by dataset name (only xor and iris have one here).
data_mappings = dict(xor=xor_data, iris=iris_data)

# Human-readable name for each dataset key.
dataset_name_mappings = dict(
    xor="XOR",
    iris="Iris",
    animals="Animals",
    cub200="CUB200",
)

# Dataset keys to process, in insertion order (xor, iris, animals, cub200).
datasets = dataset_name_mappings.keys()

# ticks_mappings = {
#     "xor": np.arange(0.6, 1.01, step=0.1),
#     "iris": np.arange(0.7, 1.01, step=0.1),
#     "animals": np.arange(0.2, 0.71, step=0.1),
#     "cub200": np.arange(0.1, 0.61, step=0.1),
# }
# CUB200 dataset; only its metadata (main_classes, attributes,
# class_attr_comb) is read here, to build the CUB200 knowledge loss.
# ToTensor must be *instantiated*: Compose calls each transform on the image,
# and calling the bare class would try to construct a ToTensor from the image
# (a TypeError) instead of converting it.
cub_dataset = CUBDataset(data_folder, transforms.Compose([transforms.ToTensor()]))

# Knowledge-loss callable for each dataset key.
knowledge_mappings = {
    "xor": XORLoss(),
    "iris": IrisLoss(),
    "animals": AnimalLoss(anim_classes),
    "cub200": CUB200Loss(cub_dataset.main_classes, cub_dataset.attributes,
                         cub_dataset.class_attr_comb),
}

# Per-dataset name used both as the cached CSV file stem and as the column
# header of that dataset in the final LaTeX table.
kloss_mappings = {
    "xor": "XOR_Kloss",
    "iris": "Iris_Kloss",
    "animals": "Animals_Kloss",
    "cub200": "CUB200_Kloss",
}

# For each dataset, build a one-column table indexed by strategy whose cells
# are the knowledge loss formatted for LaTeX as "mean ± std", where
#   mean = K_Loss averaged over all rows (seeds and iterations) of a strategy,
#   std  = standard deviation across seeds of each seed's mean K_Loss.
# Each table is cached in "<kloss name>.csv"; an existing CSV is reused
# instead of recomputing from "<dataset>/results.pkl".
list_df = []
for dataset in datasets:
    csv_file = f"{kloss_mappings[dataset]}.csv"

    if os.path.exists(csv_file):
        # The CSV was written with "Strategy" as its index (see to_csv below);
        # restore it so pd.concat(axis=1) aligns rows by strategy instead of
        # by a RangeIndex with a duplicated "Strategy" column.
        df_kloss = pd.read_csv(csv_file, index_col="Strategy")
    else:
        results_path = os.path.join(dataset, "results.pkl")
        # NOTE(review): read_pickle can execute arbitrary code; only load
        # trusted, locally produced result files.
        df = pd.read_pickle(results_path)
        df = df.sort_values(['Strategy', 'Seed', 'Iteration']).reset_index(drop=True)
        print("Loaded results:", results_path)

        k_df = {
            "Strategy": [],
            "Seed": [],
            "Iter": [],
            "K_Loss": [],
        }

        with tqdm(df.iterrows(), total=df.shape[0]) as pbar:
            for _, row in pbar:
                # KAL+DU variants are excluded from the table.
                if row['Strategy'] in (KAL_PLUS_DU, KAL_PLUS_DROP_DU):
                    continue

                preds = torch.as_tensor(row['Predictions'])
                if dataset in ('xor', 'iris'):
                    # These losses also need the training inputs the
                    # predictions were made on.
                    data = data_mappings[dataset][row['Train Idx']]
                    if dataset == "xor" and len(preds.shape) == 2:
                        preds = preds[:, 0]
                    k_loss = knowledge_mappings[dataset](preds, x=data).mean()
                else:
                    k_loss = knowledge_mappings[dataset](preds).mean()
                pbar.set_description(f"{row['Strategy']}, s: {row['Seed']}, loss: {k_loss:.2f}")

                k_df['Strategy'].append(NAME_MAPPINGS[row['Strategy']])
                k_df['Seed'].append(row['Seed'])
                k_df['Iter'].append(row['Iteration'])
                k_df['K_Loss'].append(k_loss.item())

        k_df = pd.DataFrame(k_df)
        dfs_kloss_mean = k_df.groupby("Strategy")["K_Loss"].mean()
        # Align std to the mean's index explicitly rather than relying on
        # both groupbys producing the same ordering.
        dfs_kloss_std = (
            k_df.groupby(["Strategy", "Seed"])["K_Loss"].mean()
            .groupby(level="Strategy").std()
            .reindex(dfs_kloss_mean.index)
        )
        klosses = [
            f"${mean:.2f}$ {{\\tiny $\\pm {std:.2f}$ }}"
            for mean, std in zip(dfs_kloss_mean, dfs_kloss_std)
        ]

        df_kloss = pd.DataFrame({
            "Strategy": dfs_kloss_mean.index,
            kloss_mappings[dataset]: klosses,
        }).set_index("Strategy")
        df_kloss.to_csv(csv_file)

    list_df.append(df_kloss)

    # Stop after animals: datasets later in the order (e.g. cub200) are
    # not tabulated.
    if dataset == "animals":
        break

# Join the per-dataset columns side by side (rows aligned on the Strategy
# index) and emit the table as LaTeX. escape=False keeps the "$...$" and
# "\tiny" markup in the cells intact.
dfs = pd.concat(list_df, axis=1)

# Render once and reuse for both stdout and the output file.
latex_table = dfs.to_latex(float_format="%.2f", escape=False)
print(latex_table)
with open("k_losses_table.txt", "w", encoding="utf-8") as f:
    f.write(latex_table)


#     palette = [color_mappings[method] for method in np.unique(dfs['Strategy'])]
#     print([(method, color) for method, color in zip(palette, np.unique(dfs['Strategy']))])
#     dfs['Strategy'] = Strategies
#     if "Accuracy" in dfs.columns:
#         dfs['Test Accuracy'] = dfs['Accuracy']
#     dfs['F1'] = dfs['Test Accuracy'] / 100
#
#     sns.lineplot(data=dfs, x="Points", y="F1",
#                  hue="Strategy", style="Ours", size="Ours",
#                  ci=None, style_order=[1, 0], palette=palette,
#                  size_order=[1, 0], sizes=[4, 2], ax=ax2)
#     sns.despine(left=True, bottom=True)
#     ax2.legend().set_visible(False)
#     if dataset == "xor" or dataset == "animals":
#         ax2.set_ylabel("F1 score",  fontsize=18)
#     else:
#         ax2.set_ylabel("")
#     ax2.set_yticks(ticks_mappings[dataset])
#     ax2.set_xlabel("# Points", fontsize=18)
#     ax2.set_title(dataset_mappings[dataset], fontsize=24, pad=10)
#
# #%%
#
# handles, labels = axes2[1].get_legend_handles_labels()
# n_strategies = len(STRATEGIES)
# lgd = fig2.legend(handles[:n_strategies], labels[:n_strategies], fontsize=18, loc='upper center',
#                   bbox_to_anchor=(1.12, 0.8), ncol=1)
# plt.tight_layout()
# plt.savefig(f"grouped_curves.pdf",
#             dpi=200, bbox_inches='tight')
# plt.show()
