from torch.utils.data import Subset
import numpy as np
from pathlib import Path
import torch
import pickle
from tqdm import tqdm
import pandas as pd


def convert_json_to_pd(crowd_data):
    """Flatten nested crowd answers into a long-format DataFrame.

    Args:
        crowd_data: mapping ``{task: {worker: label}}``.

    Returns:
        A ``pandas.DataFrame`` with the columns ``task``, ``worker`` and
        ``label`` (in that order), holding one row per (task, worker) pair
        in the iteration order of the input mappings.
    """
    # One (task, worker, label) triple per recorded answer.
    triples = [
        (task, worker, label)
        for task, answers in crowd_data.items()
        for worker, label in answers.items()
    ]
    column_names = ("task", "worker", "label")
    # Build column-wise so that pandas infers each column's dtype on its own.
    return pd.DataFrame(
        {
            name: [triple[position] for triple in triples]
            for position, name in enumerate(column_names)
        }
    )


class WAUM:
    """Record per-worker training margins and aggregate them per task.

    For every worker found in the crowd answers, the model is trained on the
    tasks that worker labelled (using the worker's labels as targets) while
    logits, probabilities and margins are recorded at every step. The
    recorded margins are then averaged per task into ``self.waum`` by
    :meth:`get_psi1_waum` or :meth:`get_psi5_waum` (presumably the "Weighted
    Area Under the Margin" -- confirm against the project documentation).

    Args:
        pi: indexable by integer worker id; each entry is used as a 2-D
            matrix whose diagonal weights the votes (see
            :meth:`get_psuccess` and :meth:`get_final_labels`).
        all_ans: mapping ``{task: {worker: label}}``. Worker keys that are
            strings starting with ``"spam"`` are renumbered after the
            regular workers.
        train_set: dataset indexed by integer task id; batches must expose
            the input, the true label and the task index (see
            ``batch_idx`` in :meth:`get_aum`).
        n_classes: number of classes (width of the label distribution).
        criterion: loss called as ``criterion(logits, labels)``.
        model: ``torch.nn.Module`` to train; moved to CUDA in
            :meth:`get_aum`.
        n_epoch: number of epochs trained per worker.
        optimizer: optimizer bound to ``model``'s parameters.
        spam: number of spammer workers.
        nw: total number of workers.

    Side effects:
        Creates ``./temp/`` and writes the initial model/optimizer state to
        ``./temp/checkpoint_aum.pth`` so that :meth:`reset` can restore it.
    """

    _TEMP_DIR = "./temp/"
    _CHECKPOINT_PATH = "./temp/checkpoint_aum.pth"
    _CUT_PATH = "./temp/cut"

    def __init__(
        self,
        pi,
        all_ans,
        train_set,
        n_classes,
        criterion,
        model,
        n_epoch,
        optimizer,
        spam=0,
        nw=1,
    ):
        self.pi = pi
        self.all_ans = all_ans
        self.train_set = train_set
        self.n_classes = n_classes
        self.model = model
        self.n_spam = spam
        self.n_workers = nw

        self.all_ans_sp = self._normalize_answers()
        self.crowd_data = convert_json_to_pd(self.all_ans_sp)
        self.criterion = criterion
        self.n_epoch = n_epoch
        self.optimizer = optimizer
        self.initial_lr = self.optimizer.param_groups[0]["lr"]
        # NOTE: state_dict() hands back references to the live parameter
        # tensors, so this in-memory dict follows the model as it trains.
        # The copy written to disk below is the actual snapshot used by
        # reset().
        self.checkpoint = {
            "epoch": n_epoch,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

        # mkdir() returns None, so keep the Path object itself.
        self.path = Path(self._TEMP_DIR)
        self.path.mkdir(parents=True, exist_ok=True)
        torch.save(self.checkpoint, self._CHECKPOINT_PATH)

    def _normalize_answers(self):
        """Return ``all_ans`` re-keyed with integer task and worker ids.

        Workers named ``"spam<k>"`` are mapped to
        ``n_workers - n_spam - 1 + k``; every other worker key is passed
        through ``int``.
        """
        worker_last = self.n_workers - self.n_spam - 1
        normalized = {}
        for task, ans in self.all_ans.items():
            per_task = normalized[int(task)] = {}
            for worker, label in ans.items():
                if isinstance(worker, str) and worker.startswith("spam"):
                    worker = worker_last + int(worker.split("spam")[1])
                per_task[int(worker)] = label
        return normalized

    def get_aum(
        self,
        batch_idx=(0, 1, 2),
        weighted=True,
        batch_size=50,
        max_task=9500,
    ):
        """Train once per worker and record margins in ``self.AUM_recorder``.

        Args:
            batch_idx: positions of ``[image, true_label, index]`` inside a
                batch yielded by the data loader.
            weighted: when true, also record a per-sample ``score`` computed
                by :meth:`get_psuccess`; after training, the score of the
                last epoch is copied to every epoch of the same
                (task, worker) pair.
            batch_size: batch size of the per-worker data loader.
            max_task: tasks whose integer id is ``>= max_task`` are skipped.

        Side effects:
            Moves the model to CUDA, trains it, restores the initial state
            through :meth:`reset` after each worker and once more at the
            end, and sets ``self.weighted`` and ``self.AUM_recorder`` (a
            DataFrame with one row per task, worker and epoch).
        """
        AUM_recorder = {
            key: []
            for key in (
                "task",
                "worker",
                "label",
                "truth",
                "epoch",
                "label_logit",
                "label_prob",
                "other_max_logit",
                "other_max_prob",
                "secondlogit",
                "secondprob",
            )
        }
        self.weighted = weighted
        if self.weighted:
            AUM_recorder["score"] = []

        workers = self.crowd_data["worker"].unique()
        for j in tqdm(workers, total=len(workers), desc="workers"):
            data_j = self.crowd_data[self.crowd_data["worker"] == j]
            task_ids = [
                int(i) for i in data_j["task"].values if int(i) < max_task
            ]
            if not task_ids:
                # Nothing to train on for this worker; the model is
                # untouched, so there is nothing to reset either.
                continue
            dl = torch.utils.data.DataLoader(
                Subset(self.train_set, task_ids),
                batch_size=batch_size,
                shuffle=True,
            )
            pij = torch.as_tensor(
                self.pi[int(j)], dtype=torch.float32
            ).cuda()
            self.model.cuda()
            self.model.train()
            for epoch in range(self.n_epoch):
                for batch in dl:
                    self._train_and_record(
                        AUM_recorder, batch, batch_idx, data_j, j, epoch, pij
                    )
            self.reset()

        # reinit model at the end
        self.reset()
        recorder = pd.DataFrame(AUM_recorder)
        if self.weighted:
            # The "score" column only exists in weighted mode.
            recorder = self._propagate_final_scores(recorder)
        self.AUM_recorder = recorder

    def _train_and_record(
        self, recorder, batch, batch_idx, data_j, worker, epoch, pij
    ):
        """Run one optimisation step on ``batch`` and append its statistics.

        Only the samples of the batch for which ``data_j`` holds an answer
        are used; the batch is ignored when there is none. The recorded
        logits are the ones computed before the optimizer step.
        """
        xi = batch[batch_idx[0]]
        truth = batch[batch_idx[1]]
        idx = batch[batch_idx[2]].tolist()

        # Align the worker's answers with the order of the batch; tasks
        # without an answer become NaN rows.
        answers = (
            data_j[data_j["task"].isin(idx)]
            .set_index("task", drop=False)
            .reindex(idx)
        )
        capture = np.flatnonzero(answers["task"].notna().to_numpy())
        if len(capture) == 0:
            return

        self.optimizer.zero_grad()
        labels = torch.tensor(answers["label"].to_numpy()[capture]).type(
            torch.long
        )
        xi, labels = xi[capture].cuda(), labels.cuda()
        # Keep every recorded column aligned with the captured samples.
        truth = truth[capture]
        idx = np.array(idx)[capture]

        out = self.model(xi)
        loss = self.criterion(out, labels)
        loss.backward()
        self.optimizer.step()

        len_ = len(idx)
        recorder["task"].extend(idx.tolist())
        recorder["label"].extend(labels.tolist())
        recorder["truth"].extend(truth.tolist())
        recorder["worker"].extend([worker] * len_)
        recorder["epoch"].extend([epoch] * len_)

        with torch.no_grad():
            logits = out.detach()
            probs = logits.softmax(dim=1)
            label_col = labels.view(-1, 1)

            # s_y and P_y: value assigned to the worker's label.
            # squeeze(1) keeps a 1-D result even for a single sample.
            recorder["label_logit"].extend(
                logits.gather(1, label_col).squeeze(1).tolist()
            )
            recorder["label_prob"].extend(
                probs.gather(1, label_col).squeeze(1).tolist()
            )

            # (s\y)[1] and (P\y)[1]: largest value among the other classes.
            masked_logits = torch.scatter(
                logits, 1, label_col, float("-inf")
            )
            masked_probs = torch.scatter(probs, 1, label_col, float("-inf"))
            recorder["other_max_logit"].extend(
                masked_logits.max(1).values.tolist()
            )
            recorder["other_max_prob"].extend(
                masked_probs.max(1).values.tolist()
            )

            # s[2] and P[2]: second largest value overall.
            recorder["secondlogit"].extend(
                torch.sort(logits, dim=1)[0][:, -2].tolist()
            )
            recorder["secondprob"].extend(
                torch.sort(probs, dim=1)[0][:, -2].tolist()
            )

            if self.weighted:
                # One score per sample, computed for the whole batch at once.
                recorder["score"].extend(
                    self.get_psuccess(probs, pij).cpu().tolist()
                )

    def _propagate_final_scores(self, recorder):
        """Copy each (task, worker) last-epoch score to all of its epochs."""
        recorder = recorder.copy()
        last_epoch = self.n_epoch - 1
        is_last = recorder.epoch == last_epoch
        for task in tqdm(recorder.task.unique()):
            task_mask = recorder.task == task
            for j in recorder.loc[task_mask, "worker"].unique():
                pair_mask = task_mask & (recorder.worker == j)
                final = recorder.loc[pair_mask & is_last, "score"].values[0]
                recorder.loc[pair_mask, "score"] = final
        return recorder

    def get_psuccess(self, probas, pij):
        """Return ``probas @ diag(pij)`` without tracking gradients.

        Args:
            probas: class probabilities, either one vector or a batch with
                one row per sample.
            pij: 2-D tensor; only its diagonal is used.

        Returns:
            A scalar tensor for a single vector, or one value per row for a
            batch.
        """
        with torch.no_grad():
            return probas @ torch.diag(pij)

    def _compute_waum(self, rival_column):
        """Aggregate ``label_prob - rival_column`` margins into ``self.waum``.

        For each task, the per-worker mean of the (score-weighted when
        ``self.weighted``) margins is summed over workers and divided by the
        sum of the workers' scores (or by the number of workers when
        unweighted). Also writes a ``margin`` column into
        ``self.AUM_recorder``.
        """
        aum_df = self.AUM_recorder
        aum_df["margin"] = np.array(aum_df["label_prob"]) - np.array(
            aum_df[rival_column]
        )
        waum = {}
        unique_task = np.unique(np.array(aum_df["task"]))
        for each_task in tqdm(
            unique_task, total=len(unique_task), desc="computing WAUM"
        ):
            temp = aum_df[aum_df["task"] == each_task]
            avg = []
            # Accumulated over ALL workers of the task: it is the
            # normalising constant of the weighted average.
            score = []
            for j in np.unique(np.array(temp["worker"])):
                tempj = temp[temp["worker"] == j]
                margins = np.array(tempj["margin"])
                if self.weighted:
                    avg.append((margins * np.array(tempj["score"])).mean())
                    score.append(tempj["score"].iloc[0])
                else:
                    avg.append(margins.mean())
                    score.append(1.0)
            waum[each_task] = np.sum(avg) / sum(score)
        self.waum = waum

    def get_psi1_waum(self):
        """Set ``self.waum`` to ``{task: value}`` using the margin between
        the label probability and the largest other-class probability."""
        self._compute_waum("other_max_prob")

    def get_psi5_waum(self):
        """Set ``self.waum`` to ``{task: value}`` using the margin between
        the label probability and the second largest probability."""
        self._compute_waum("secondprob")

    def cut_lowests(self, alpha=0.01):
        """Return the tasks whose WAUM is at or below the ``alpha`` quantile.

        The list is also pickled to ``./temp/cut``. Requires ``self.waum``
        to have been set by :meth:`get_psi1_waum` or :meth:`get_psi5_waum`.
        """
        quantile = np.quantile(list(self.waum.values()), alpha)
        tasks_too_hard = [
            task for task, value in self.waum.items() if value <= quantile
        ]
        Path(self._TEMP_DIR).mkdir(parents=True, exist_ok=True)
        with open(self._CUT_PATH, "wb") as fp:  # Pickling cut tasks
            pickle.dump(tasks_too_hard, fp)
        return tasks_too_hard

    def get_final_labels(self, pi):  # return soft distribution reweighted
        """Return per-task label distributions weighted by ``pi``.

        Each vote for class ``k`` by worker ``w`` adds ``pi[w][k, k]`` to
        entry ``[task, k]``; the unnormalised table is kept in
        ``self.baseline`` and the rows are then normalised to sum to one.
        Task ids are used as row indices, so they must lie in
        ``range(len(self.all_ans))``. A task whose row sums to zero yields
        NaN entries.
        """
        baseline = np.zeros((len(self.all_ans), self.n_classes))
        for task_id, votes in self.all_ans_sp.items():
            for worker, vote in votes.items():
                baseline[int(task_id), int(vote)] += pi[int(worker)][
                    int(vote), int(vote)
                ]
        self.baseline = baseline
        return baseline / baseline.sum(axis=1).reshape(-1, 1)

    def reset(self):
        """Restore the model and optimizer saved at construction time.

        The state is read back from ``./temp/checkpoint_aum.pth`` rather
        than from ``self.checkpoint``: the in-memory state dict shares its
        tensors with the live model, so loading it would change nothing.
        The learning rate of the first parameter group is set back to
        ``self.initial_lr``.
        """
        check_ = torch.load(self._CHECKPOINT_PATH)
        self.n_epoch = check_["epoch"]
        self.model.load_state_dict(check_["model"])
        self.optimizer.load_state_dict(check_["optimizer"])
        self.optimizer.param_groups[0]["lr"] = self.initial_lr
