import os
from random import shuffle

import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


class VisionTransformerDataset(Dataset):
    """Dataset of trained transformer checkpoints paired with their test accuracy.

    On first use the checkpoints listed in ``<data_path>/metadata/<name>.csv`` are
    filtered, shuffled, split 70/15/15 into train/val/test and written to one cache
    file per split (``<data_path>/<name>_<split>_<cut_off>.pt``); later
    instantiations read that cache directly.

    Cached layout (``N`` = number of checkpoints in the split):
      * ``embedding`` / ``classifier``: ``{"weight": [...], "bias": [...]}`` with one
        ``[N, 1, *param_shape]`` tensor per layer, in checkpoint key order.
      * ``encoder``: ``{cache_key: [...]}`` (cache keys from ``_ENCODER_NAMES``, e.g.
        ``"W_q"``) with one ``[N, 1, ...]`` tensor per encoder block. ``W_q``,
        ``W_k``, ``W_v`` and ``W_o`` are already split by head (``n_heads`` is the
        third dimension), and every encoder weight matrix is transposed because
        right multiplication is used downstream (``XW`` instead of ``WX``).
      * ``accuracy``: ``[N, 1]`` float tensor.

    NOTE: the cache file name encodes ``cut_off`` but not ``n_heads``; delete stale
    caches after changing ``n_heads``.
    """

    # Encoder sub-module name -> parameter kind -> key used in the cache.
    _ENCODER_NAMES = {
        "queries": {"weight": "W_q"},
        "keys": {"weight": "W_k"},
        "values": {"weight": "W_v"},
        "out_projection": {"weight": "W_o"},
        "fc1": {"weight": "W_A", "bias": "b_A"},
        "fc2": {"weight": "W_B", "bias": "b_B"},
    }

    def __init__(self, data_path="vision_data", n_heads=2, split="train", name='mnist', load_cache=True, cut_off=0.1) -> None:
        """Load one split, building the caches from the CSV metadata if needed.

        Args:
            data_path: folder holding ``metadata/<name>.csv``, the ``<name>/``
                checkpoint folder and the generated cache files.
            n_heads: number of attention heads used to split W_q/W_k/W_v/W_o.
            split: ``"train"``, ``"val"`` or ``"test"``.
            name: dataset name; ``"mnist"`` reads the ``test_top1_accuracy``
                column, any other name reads ``test_accuracy``.
            load_cache: reuse this split's cache file if it exists. If False (or
                the cache is missing) all three splits are regenerated.
            cut_off: when > 0, drop checkpoints whose test accuracy is below it.
        """
        super().__init__()
        self.split = split
        self.n_heads = n_heads
        self.name = name
        self.cut_off = cut_off
        if load_cache and os.path.exists(self._cache_path(data_path, name, split)):
            print("Cache file found! Loading from cache...")
        else:
            if load_cache:
                print("Cache file not found!")
            print("Loading from csv files...")
            self.load_from_csv(data_path, name)
            print("Cache file generated! Loading from cache...")
        self.load_from_cache(data_path, name)
        self.len_data = len(self.accuracy)
        print(f"Done loading {self.split} data !")

    def _cache_path(self, data_path, name, split):
        """Path of the cache file for ``split`` (depends on ``cut_off``)."""
        return os.path.join(data_path, f"{name}_{split}_{self.cut_off}.pt")

    def _accuracy_column(self):
        """Name of the metadata column holding this dataset's test accuracy."""
        return "test_top1_accuracy" if self.name == "mnist" else "test_accuracy"

    def load_from_cache(self, data_path, name):
        """Read this split's cache file into ``self.embedding`` / ``encoder`` / ``classifier`` / ``accuracy``."""
        data = torch.load(self._cache_path(data_path, name, self.split))
        self.embedding = data["embedding"]
        self.encoder = data["encoder"]
        self.classifier = data["classifier"]
        self.accuracy = data["accuracy"]
        self.len_data = self.accuracy.shape[0]

    def load_from_csv(self, data_path, name):
        """Build the train/val/test caches for ``name`` from the raw checkpoints.

        Reads ``<data_path>/metadata/<name>.csv`` and keeps the checkpoints whose
        ``ckpt_epoch`` is 75. When ``cut_off > 0`` it also requires test accuracy
        of at least ``cut_off``. The survivors are shuffled, split 70/15/15, and
        one cache file is written per split. All three splits come from a single
        shuffle, so they are rewritten together and stay disjoint.
        """
        data_folder = os.path.join(data_path, name)
        model_info = pd.read_csv(os.path.join(data_path, "metadata", f"{name}.csv"))
        acc_col = self._accuracy_column()

        # Keep only the models checkpointed at epoch 75. Compare as text so the
        # filter works whether pandas parsed the column as strings or integers.
        keep = model_info["ckpt_epoch"].astype(str) == "75"
        if self.cut_off > 0:
            keep = keep & (model_info[acc_col] >= self.cut_off)
        model_info = model_info[keep]

        # ckpt_file entries look like "<name>/<file>"; strip the prefix to get the
        # path relative to data_folder. Build the accuracy lookup once (O(n))
        # instead of filtering the DataFrame for every file.
        prefix_len = len(name) + 1
        file_names_all = [path[prefix_len:] for path in model_info["ckpt_file"]]
        accuracy_by_file = {}
        for file, acc in zip(file_names_all, model_info[acc_col]):
            accuracy_by_file.setdefault(file, acc)  # first matching row wins

        # Shuffle before splitting.
        shuffle(file_names_all)
        len_all = len(file_names_all)
        num_train = int(len_all * 0.7)
        num_val = int(len_all * 0.15)
        file_names_splitted = {
            "train": file_names_all[:num_train],
            "val": file_names_all[num_train:num_train + num_val],
            "test": file_names_all[num_train + num_val:],
        }

        for split, file_names in file_names_splitted.items():
            # Only .pt files are checkpoints. Filter *before* sizing the storage so
            # every preallocated slot gets filled and lines up with ``accuracy``.
            file_names = [f for f in file_names if f.endswith(".pt")]
            embedding, encoder, classifier, accuracy = self._load_split(data_folder, file_names, accuracy_by_file)
            torch.save({"embedding": embedding, "encoder": encoder, "classifier": classifier, "accuracy": accuracy},
                       self._cache_path(data_path, name, split))
        print("done")

    def _load_split(self, data_folder, file_names, accuracy_by_file):
        """Stack the given checkpoints into preallocated tensors.

        Returns ``(embedding, encoder, classifier, accuracy)`` in the cached layout
        described in the class docstring. The first checkpoint fixes the parameter
        layout; every other checkpoint is assumed to share its keys and shapes.
        Any failure is reported with the offending file name and re-raised, since
        a skipped file would leave an unfilled slot in the storage.
        """
        n_files = len(file_names)
        embedding = {"weight": [], "bias": []}
        classifier = {"weight": [], "bias": []}
        flat = {"embedding": embedding, "classifier": classifier}
        blocks = []  # one {cache_key: storage} dict per encoder block
        accuracy = []
        model_keys = None

        for idx, file in enumerate(file_names):
            try:
                # NOTE: torch.load unpickles; only point this at trusted checkpoints.
                checkpoint = torch.load(os.path.join(data_folder, file), map_location=torch.device("cpu"))
                first = model_keys is None
                if first:
                    model_keys = [k.split(".") for k in checkpoint.keys()]
                entries = list(self._iter_params(checkpoint, model_keys))
                if first:
                    # Allocate [n_files, 1, *shape] storage for every kept parameter.
                    for section, slot, cache_key, value in entries:
                        storage = torch.empty([n_files, *value.shape], device="cpu").unsqueeze(1)
                        if section == "encoder":
                            while len(blocks) <= slot:
                                blocks.append({})
                            blocks[slot][cache_key] = storage
                        else:
                            flat[section][cache_key].append(storage)
                # Write this checkpoint into row ``idx`` of each storage tensor.
                for section, slot, cache_key, value in entries:
                    target = blocks[slot][cache_key] if section == "encoder" else flat[section][cache_key][slot]
                    target[idx] = value.unsqueeze(0)
                accuracy.append(accuracy_by_file[file])
            except Exception:
                print(f"Error loading {file} !")
                raise

        # Regroup per-block dicts into {cache_key: [block_0, block_1, ...]};
        # an empty split yields an empty dict instead of an IndexError.
        encoder = {k: [block[k] for block in blocks] for k in (blocks[0] if blocks else {})}
        # accuracy as a float tensor of shape [n_files, 1]
        return embedding, encoder, classifier, torch.tensor(accuracy).unsqueeze(1).float()

    def _iter_params(self, checkpoint, model_keys):
        """Yield ``(section, slot, cache_key, tensor)`` for every parameter to cache.

        Only keys ending in ``weight``/``bias`` under ``embedding``, ``classifier``
        or ``encoder`` are yielded. For embedding/classifier, ``cache_key`` is
        ``"weight"``/``"bias"`` and ``slot`` is the number of earlier parameters of
        the same section and kind, i.e. the layer position; this works with or
        without biases. For the encoder, ``slot`` is the block index (``key[1]``),
        ``cache_key`` comes from ``_ENCODER_NAMES`` (KeyError on unknown modules),
        and the tensor is already reshaped by ``_transform_encoder_param``.
        """
        seen = {}
        for key in model_keys:
            section, param = key[0], key[-1]
            if param not in ("weight", "bias"):
                continue
            value = checkpoint[".".join(key)]
            if section in ("embedding", "classifier"):
                slot = seen.get((section, param), 0)
                seen[(section, param)] = slot + 1
                yield section, slot, param, value
            elif section == "encoder":
                module = key[-2]
                cache_key = self._ENCODER_NAMES[module][param]
                yield section, int(key[1]), cache_key, self._transform_encoder_param(module, param, value)

    def _transform_encoder_param(self, module, param, tensor):
        """Reshape one raw encoder parameter into its cached layout."""
        shape = tensor.shape
        if module in ("queries", "keys", "values"):
            # [rows, cols] -> [n_heads, cols, rows // n_heads]
            return tensor.reshape(self.n_heads, shape[0] // self.n_heads, shape[1]).transpose(-1, -2)
        if module == "out_projection":
            # [rows, cols] -> [cols, rows] -> [n_heads, cols // n_heads, rows]
            return tensor.transpose(0, 1).reshape(self.n_heads, shape[1] // self.n_heads, shape[0])
        if param == "weight":
            # Transposed so that downstream code can compute XW.
            return tensor.transpose(0, 1)
        return tensor

    def __len__(self):
        return self.len_data

    def __getitem__(self, index):
        """Return one checkpoint's parameters (per-layer lists preserved) and accuracy."""
        embedding = {k: [layer[index] for layer in layers] for k, layers in self.embedding.items()}
        classifier = {k: [layer[index] for layer in layers] for k, layers in self.classifier.items()}
        encoder = {k: [block[index] for block in blocks] for k, blocks in self.encoder.items()}
        return {"embedding": embedding, "encoder": encoder, "classifier": classifier, "accuracy": self.accuracy[index]}

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``__getitem__`` samples along a new leading batch dimension."""
        def stack(section):
            first = batch[0][section]
            return {k: [torch.stack([sample[section][k][layer] for sample in batch], dim=0)
                        for layer in range(len(first[k]))]
                    for k in first}

        return {"embedding": stack("embedding"), "classifier": stack("classifier"), "encoder": stack("encoder"),
                "accuracy": torch.stack([sample["accuracy"] for sample in batch])}

    @staticmethod
    def to_device(batch, device):
        """Move every tensor of a collated batch to ``device``."""
        def move(section):
            return {k: [t.to(device) for t in tensors] for k, tensors in batch[section].items()}

        return {"embedding": move("embedding"), "encoder": move("encoder"), "classifier": move("classifier"),
                "accuracy": batch["accuracy"].to(device)}

class TransformerClassificationDataset(VisionTransformerDataset):
    """Binary vision-vs-text dataset built from two ``VisionTransformerDataset`` instances.

    Only encoder weights are exposed. Vision checkpoints are labelled 0 and text
    checkpoints 1; labels are returned one-hot encoded (width 2).
    """

    def __init__(self, vision_data_path="vision_data", text_data_path="text_data", n_heads=2, split="train", name1='mnist', name2='ag_news', load_cache=True, cut_off=0.1) -> None:
        """Load the same ``split`` of a vision dataset (``name1``) and a text dataset (``name2``).

        ``n_heads``, ``load_cache`` and ``cut_off`` are forwarded to both
        underlying datasets.
        """
        # Deliberately composes two datasets instead of calling super().__init__(),
        # which would load a single dataset from disk.
        self.dataset_vision = VisionTransformerDataset(data_path=vision_data_path, n_heads=n_heads, split=split,
                                                       name=name1, load_cache=load_cache, cut_off=cut_off)
        self.dataset_text = VisionTransformerDataset(data_path=text_data_path, n_heads=n_heads, split=split,
                                                     name=name2, load_cache=load_cache, cut_off=cut_off)
        self.dataset_dict = {"vision": self.dataset_vision, "text": self.dataset_text}

        vision_encoder = self.dataset_vision.encoder
        text_encoder = self.dataset_text.encoder
        # Concatenate along the checkpoint dimension: vision rows first, then text rows.
        self.encoder = {
            k: [torch.cat([vision_encoder[k][i], text_encoder[k][i]], dim=0) for i in range(len(vision_encoder[k]))]
            for k in vision_encoder
        }

        n_vision, n_text = self.dataset_vision.len_data, self.dataset_text.len_data
        self.label = torch.cat([torch.zeros(n_vision, dtype=torch.long), torch.ones(n_text, dtype=torch.long)], dim=0)
        # Fix the width at 2 so the one-hot encoding does not shrink when one side is empty.
        self.one_hot_label = F.one_hot(self.label, num_classes=2)
        self.len_data = n_vision + n_text

    def __getitem__(self, index):
        """Return one checkpoint's encoder weights and its one-hot modality label."""
        encoder = {k: [block[index] for block in blocks] for k, blocks in self.encoder.items()}
        return {"encoder": encoder, "label": self.one_hot_label[index]}

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``__getitem__`` samples along a new leading batch dimension."""
        first = batch[0]["encoder"]
        encoder = {k: [torch.stack([sample["encoder"][k][layer] for sample in batch], dim=0)
                       for layer in range(len(first[k]))]
                   for k in first}
        return {"encoder": encoder, "label": torch.stack([sample["label"] for sample in batch])}

    @staticmethod
    def to_device(batch, device):
        """Move every tensor of a collated batch to ``device``."""
        encoder = {k: [t.to(device) for t in tensors] for k, tensors in batch["encoder"].items()}
        return {"encoder": encoder, "label": batch["label"].to(device)}


if __name__ == "__main__":
    # Smoke test: build the vision-vs-text classification dataset and iterate it once.
    # NOTE(review): text_data_path points at the same folder as vision_data_path — confirm.
    dataset = TransformerClassificationDataset(vision_data_path="nfn_transformer/nfn-vision-v2/model-v2",
                                               text_data_path="nfn_transformer/nfn-vision-v2/model-v2",
                                               split="test",
                                               name1='mnist')

    loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True, collate_fn=dataset.collate_fn)

    for batch in loader:
        batch = dataset.to_device(batch, "cpu")
        encoder, label = batch["encoder"], batch["label"]
        print("test")
