from .earlytrain import EarlyTrain
import torch
import torch.nn as nn
import numpy as np
import os
import matplotlib.pyplot as plt


class Uncertainty(EarlyTrain):
    def __init__(self, dst_train, args, mean, std, fraction=0.5, random_seed=None, epochs=200, selection_method="LeastConfidence",
                 specific_model=None, balance=False, **kwargs):
        """Configure the uncertainty-based selector.

        All positional arguments are forwarded to the ``EarlyTrain`` base
        initializer; this subclass additionally records the scoring method and
        the bookkeeping state used by the per-epoch hooks.

        Args:
            dst_train, args, mean, std, random_seed, specific_model, **kwargs:
                forwarded unchanged to the base class.
            fraction: stored on the instance and forwarded to the base class.
            epochs: stored on the instance and forwarded to the base class.
            selection_method: one of "LeastConfidence", "Entropy", "Margin",
                "EL2N".
            balance: stored on the instance; read by ``finish_run``.

        Raises:
            NotImplementedError: if ``selection_method`` is not supported.
        """
        super().__init__(dst_train, args, mean, std, fraction, random_seed, epochs, specific_model, **kwargs)

        supported_methods = ("LeastConfidence", "Entropy", "Margin", "EL2N")
        if selection_method not in supported_methods:
            raise NotImplementedError("Selection algorithm unavailable.")

        # --- selection configuration ---
        self.selection_method = selection_method
        self.epochs = epochs
        self.balance = balance

        # Accumulated per-sample scores; filled in by after_epoch.
        self.scores = None

        self.repeat = self.args.repeats

        # --- progress counters ---
        self.epoch = 0
        self.cycle = 0
        self.warmup = args.warmup_epochs

        self.pool_rate = 0.5
        self.fraction = fraction

        # Initial adaptive subset size: half of the training set.
        self.adapt_size = int(0.5 * len(self.dst_train))

        # One averaged score per scored epoch; appended by after_epoch.
        self.avg_scores = []

        # Weight snapshot used as the regularization anchor in unlearn.
        self.original_state_dict = None

        # Count how many samples each class has in the unlabeled training set.
        self.cls_dist = np.zeros(self.num_classes)
        for label in self.dst_train_unlabel.targets:
            self.cls_dist[label] += 1
        print(f"Class-wise distribution of dataset {self.args.dataset}: ", self.cls_dist)

    def unlearn(self, ul_scores, correct_inds=None):
        '''
            Unlearn on samples with high uncertainty
        '''

        poison_ids = self.dst_train_unlabel.poison_ids
        clean_ids = self.dst_train_unlabel.clean_ids
        noise_ids = self.dst_train_unlabel.noise_ids

        if correct_inds is not None:
            mean_score = np.sum(ul_scores * correct_inds) / np.sum(correct_inds)
        else:
            mean_score = np.mean(ul_scores)

        num_ul = int(np.sum(ul_scores > mean_score))

        list_of_pool_idx = np.argsort(ul_scores)[::-1][:num_ul]

        cls_cnt_pool = np.zeros(self.dst_train_unlabel.num_classes)
        for id in list_of_pool_idx:
            target = self.dst_train_unlabel.targets[id]
            cls_cnt_pool[target] += 1

        ul_gamma = self.args.reg_gamma
        lambda_reg = 0.1  # 0.1 for the lowering of learning rate

        num_ul = len(list_of_pool_idx)

        ul_poi_ids = np.array([i for i in list_of_pool_idx if i in poison_ids])

        ul_noi_ids = np.array([i for i in list_of_pool_idx if i in noise_ids])

        print(f'\n=> Unlearning one epoch with gamma {ul_gamma:.5f} with L2Reg {lambda_reg:.4f} on {num_ul} (poi: {len(ul_poi_ids)}, noi: {len(ul_noi_ids)}) highly uncertain samples')
        print(f"=> Unlearn pool: class distribution {list(cls_cnt_pool.astype('int'))}")

        perm_list_of_unlearn_idx = np.random.permutation(list_of_pool_idx)

        unlearn_dataset = self.dst_pretrain_dict['dst_train'] if self.if_dst_pretrain else self.dst_train

        unlearn_sampler = torch.utils.data.BatchSampler(perm_list_of_unlearn_idx,
                                                        batch_size=self.args.selection_batch,
                                                        drop_last=False)

        unlearn_loader = torch.utils.data.DataLoader(unlearn_dataset, shuffle=False,
                                                     batch_sampler=unlearn_sampler,
                                                     num_workers=self.args.workers, pin_memory=True)

        smooth_v = self.args.unlearn_smooth   # 0.0-1.0, label gets more smoother
        ul_criterion = nn.CrossEntropyLoss(reduction='none', label_smoothing=smooth_v).to(self.args.device)
        print(f"=> Label smoothing = {smooth_v:.4f}")

        if self.original_state_dict is None:
            original_state_dict = {k: v.clone().detach() for k, v in self.model.state_dict().items()}
        else:
            original_state_dict = self.original_state_dict

        for i, (inputs, targets) in enumerate(unlearn_loader):
            inputs, targets = inputs.to(self.args.device), targets.to(self.args.device)

            # Forward propagation, compute loss, get predictions
            self.model_optimizer.zero_grad()
            outputs = self.model(inputs)

            ul_loss = ul_criterion(outputs, targets)

            # L2 regularization toward original weights
            l2_reg = 0
            for name, param in self.model.named_parameters():
                if name in original_state_dict:
                    l2_reg += torch.norm(param - original_state_dict[name]) ** 2

            # Update loss, backward propagate, update optimizer
            loss = ul_gamma * ul_loss.mean() + lambda_reg*l2_reg

            if i == 0 or i + 1 == len(unlearn_loader):
                print('| Unlearn Epoch: Iter[%3d/%3d]\t\tLoss: %.6f\tUL_Loss: %.6f\tL2Reg: %.6f' % (i + 1, len(unlearn_loader), loss.item(), ul_loss.mean(), l2_reg))

            loss.backward()

            self.model_optimizer.step()

    def before_train(self):
        pass

    def after_loss(self, outputs, loss, targets, batch_inds, epoch):
        pass

    def before_epoch(self):
        pass

    def after_epoch(self):
        calc_score_ind = True
        if self.args.prefix_coresize:
            calc_score_ind = True
        elif self.epoch == self.warmup:
            calc_score_ind = True
        else:
            calc_score_ind = False

        # if self.epoch == self.warmup:
        #     self.original_state_dict = {k: v.clone().detach() for k, v in self.model.state_dict().items()}
        self.original_state_dict = {k: v.clone().detach() for k, v in self.model.state_dict().items()}

        poison_ids = self.dst_train_unlabel.poison_ids
        clean_ids = self.dst_train_unlabel.clean_ids
        noise_ids = self.dst_train_unlabel.noise_ids

        if calc_score_ind:
            print("\n=> Run scores calculation ...")
            scores, correct_inds = self.rank_uncertainty(correct_out=True)
            avg_score = np.sum(scores * correct_inds) / np.sum(correct_inds)

            print(f"\n| Avg. Score (all) = {np.mean(scores):.4f}, Avg. Score (correct) = {avg_score:.4f}, Num. Correct = {int(np.sum(correct_inds))}")

        if self.args.unlearn and self.epoch > self.warmup:

            if self.args.unlearn:
                if calc_score_ind:
                    self.unlearn(ul_scores=scores, correct_inds=correct_inds)
                else:
                    self.unlearn(ul_scores=scores)

            if (self.dst_test is not None or self.dst_test_bad is not None) and self.args.selection_test_interval > 0 and (self.epoch + 1) % self.args.selection_test_interval == 0:
                if self.dst_test is not None:
                    self.test(self.epoch, mode='clean')
                if self.dst_test_bad is not None:
                    self.test(self.epoch, mode='bad')

            print("\n=> Run scores calculation after unlearning")
            scores, correct_inds = self.rank_uncertainty(correct_out=True)
            avg_score = np.sum(scores * correct_inds) / np.sum(correct_inds)

        if calc_score_ind or self.epoch >= self.warmup:
            if self.epoch == 0:
                self.scores = scores
            elif self.epoch < self.warmup:
                self.scores += scores
            elif self.epoch == self.warmup:
                if self.scores is None:
                    self.scores = scores
                if self.args.prefix_coresize:
                    self.coreset_size = int(np.sum(self.scores > np.sum(self.avg_scores))) if self.args.correct_only else int(np.sum(self.scores > np.mean(self.scores)))
                    print(f"\n=> Fix coreset size [{self.coreset_size}] with accumulative entropy of warm-up")

                else:
                    self.coreset_size = int(self.fraction*len(self.scores))
                    if not self.args.adaptsize:
                        print(f"\n=> Fix coreset size [{self.coreset_size}] by the pre-definition")

                self.scores = scores

                print("\n========== Restart entropy accumulation ==========\n")
            else:
                self.scores += scores

            # append averaged score only once!
            if self.epoch >= self.warmup:
                self.avg_scores.append(avg_score)
            elif self.epoch < self.warmup:
                self.avg_scores.append(avg_score)
            elif self.epoch == self.warmup:
                self.avg_scores = []
                self.avg_scores.append(avg_score)

            print("\n=> Intermediate selection:")

            if self.args.adaptsize:
                self.adapt_size = np.sum(scores > avg_score) if self.args.correct_only else np.sum(scores > np.mean(scores))
                print(f"Subset size of '> Avg. Uncertainty' = {self.adapt_size}")
                selection_result = np.argsort(scores)[::-1][:self.adapt_size]
            else:
                selection_result = np.argsort(scores)[::-1][:self.coreset_size]

            poison_scores = scores[poison_ids]
            clean_scores = scores[clean_ids]
            print(f"TrainSet: cln_un, poi_un = {clean_scores.mean():.4f}\t{poison_scores.mean():.4f}")

            num_subpoison = 0
            num_subnoise = 0
            cls_cnt = np.zeros(self.dst_train_unlabel.num_classes)
            for id in selection_result:
                target = self.dst_train_unlabel.targets[id]
                cls_cnt[target] += 1
                if id in self.dst_train_unlabel.poison_ids:
                    num_subpoison += 1
                if id in self.dst_train_unlabel.noise_ids:
                    num_subnoise += 1

            print(f"Subset: class distribution {list(cls_cnt.astype('int'))}")

            sub_cln_ids = np.array([i for i in selection_result if i not in poison_ids])
            clean_scores = scores[sub_cln_ids]

            if num_subpoison == 0:
                print(f'Subset: p_num, p_rate, n_num, cln_un, poi_un = {num_subpoison}\t{num_subpoison / len(selection_result):.4f}\t{num_subnoise}\t{clean_scores.mean():.4f}\tNone')
                print(f"Subset: min-max uncertainty (clean, poison): {clean_scores.min():.4f}\t{clean_scores.max():.4f}\tNone\tNone")
            else:
                sub_poi_ids = np.array([i for i in selection_result if i in poison_ids])
                poison_scores = scores[sub_poi_ids]

                print(f'Subset: p_num, p_rate, n_num, cln_un, poi_un = {num_subpoison}\t{num_subpoison / len(selection_result):.4f}\t{num_subnoise}\t{clean_scores.mean():.4f}\t{poison_scores.mean():.4f}')
                print(f"Subset: min-max uncertainty (clean, poison): {clean_scores.min():.4f}\t{clean_scores.max():.4f}\t{poison_scores.min():.4f}\t{poison_scores.max():.4f}")

            ################# Printing Accumulated Results #######################
            print("\n=> Accumulated selection:")

            if self.args.adaptsize:
                if self.args.prefix_coresize and self.epoch >= self.warmup:
                    adapt_size = self.coreset_size
                else:
                    if self.args.correct_only:
                        # avg_score = np.sum(self.avg_scores) if self.epoch > self.warmup else avg_score
                        if self.epoch > self.warmup:
                            avg_score = np.sum(self.avg_scores)
                        elif self.args.prefix_coresize:
                            avg_score = np.sum(self.avg_scores)
                        else:
                            avg_score = avg_score
                    else:
                        avg_score = np.mean(self.scores)
                    adapt_size = np.sum(self.scores > avg_score)
                print(f"Subset size of '> Avg. Uncertainty' = {adapt_size}")
                selection_result = np.argsort(self.scores)[::-1][:adapt_size]
            else:
                selection_result = np.argsort(self.scores)[::-1][:self.coreset_size]

            accu_poison_scores = self.scores[poison_ids]
            accu_clean_scores = self.scores[clean_ids]
            print(f"TrainSet: cln_un, poi_un = {accu_clean_scores.mean():.4f}\t{accu_poison_scores.mean():.4f}")

            accu_num_subpoison = 0
            accu_num_subnoise = 0
            cls_cnt = np.zeros(self.dst_train_unlabel.num_classes)
            for id in selection_result:
                target = self.dst_train_unlabel.targets[id]
                cls_cnt[target] += 1
                if id in self.dst_train_unlabel.poison_ids:
                    accu_num_subpoison += 1
                if id in self.dst_train_unlabel.noise_ids:
                    accu_num_subnoise += 1

            print(f"Subset: class distribution {list(cls_cnt.astype('int'))}")

            sub_cln_ids = np.array([i for i in selection_result if i not in poison_ids]).astype("int")
            accu_clean_scores = self.scores[sub_cln_ids]

            if accu_num_subpoison == 0:
                print(f'Subset: p_num, p_rate, n_num, cln_un, poi_un = {accu_num_subpoison}\t{accu_num_subpoison / len(selection_result):.4f}\t{accu_num_subnoise}\t{accu_clean_scores.mean():.4f}\tNone')
                print(f"Subset: min-max uncertainty (clean, poison): {accu_clean_scores.min():.4f}\t{accu_clean_scores.max():.4f}\tNone\tNone")
            else:
                sub_poi_ids = np.array([i for i in selection_result if i in poison_ids])
                accu_poison_scores = self.scores[sub_poi_ids]

                print(f'Subset: p_num, p_rate, n_num, cln_un, poi_un = {accu_num_subpoison}\t{accu_num_subpoison/len(selection_result):.4f}\t{accu_num_subnoise}\t{accu_clean_scores.mean():.4f}\t{accu_poison_scores.mean():.4f}')
                print(f"Subset: min-max uncertainty (clean, poison): {accu_clean_scores.min():.4f}\t{accu_clean_scores.max():.4f}\t{accu_poison_scores.min():.4f}\t{accu_poison_scores.max():.4f}")

        self.epoch += 1

    def before_run(self):
        pass

    def num_classes_mismatch(self):
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        if batch_idx % self.args.print_freq == 0:
            print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f' % (
            epoch, self.epochs, batch_idx + 1, (self.n_pretrain_size // batch_size) + 1, loss.item()))

    def finish_run(self):
        if self.args.accumulative:
            print("\n=> Subset selection by accumulative uncertainty")
            scores = self.scores
            avg_score = np.sum(self.avg_scores) if self.args.correct_only else np.mean(scores)
        else:
            print("\n=> Subset selection by one-shot uncertainty")
            scores, correct_inds = self.rank_uncertainty(correct_out=True)
            avg_score = np.sum(scores * correct_inds)/np.sum(correct_inds) if self.args.correct_only else np.mean(scores)

        split_score = avg_score if self.args.adaptsize else np.percentile(scores, self.args.fraction*100)

        if self.cycle+1 == self.repeat:
            if self.args.accumulative:
                all_scores = self.scores
            else:
                all_scores = scores
            poison_ids = self.dst_train_unlabel.poison_ids
            clean_ids = self.dst_train_unlabel.clean_ids

            outdir = f"./dist_plots/{self.args.dataset}/{self.args.model}/Uncertainty/{self.selection_method}"
            if not os.path.exists(outdir):
                os.makedirs(outdir)

            f_outname = f"{outdir}/{self.args.trigger_type}_{self.args.inject_portion}_{self.args.selection_epochs}_{self.args.selection_optimizer}_LR{self.args.selection_lr}"
            if self.args.accumulative:
                f_outname = f"{f_outname}_Accum"
            if self.args.unlearn:
                f_outname = f"{f_outname}_UL{self.args.reg_gamma}"
            if self.args.warmup_epochs != 0:
                f_outname = f"{f_outname}_WarmUp{self.args.warmup_epochs}"
            if self.args.norm_score:
                f_outname = f"{f_outname}_SNorm"

            np.save(f"{f_outname}_cleanIDs.npy", clean_ids)
            np.save(f"{f_outname}_poisonIDs.npy", poison_ids)
            np.save(f"{f_outname}_scores.npy", all_scores)

            scores_poi = all_scores[poison_ids]
            scores_cln = all_scores[clean_ids]
            _, bins = np.histogram(all_scores, bins=100)
            fig = plt.figure(figsize=(9, 6))
            ax = fig.add_axes((0.13, 0.16, 0.84, 0.77))
            ax.hist(scores_cln, bins=bins, color="#006BA4", alpha=1.0, label='Benign samples')
            ax.hist(scores_poi, bins=bins, color="#ff800e", alpha=0.8, label='Poisonous samples')
            ylims = ax.get_ylim()
            ax.plot([split_score, split_score], [0, ylims[1]], color='black', linestyle='dashed', linewidth=2)
            # ax.set_ylim(0, 1500)
            ax.set_xlabel(self.selection_method)
            ax.set_ylabel("Number of Samples")

            pic_name = f"{f_outname}.png"
            plt.savefig(pic_name)
            plt.close(fig)

        if self.balance:
            selection_result = np.array([], dtype=np.int64)
            unlabel_result = np.array([], dtype=np.int64)
            scores_bal = []
            for c in range(self.args.num_classes):
                class_index = np.arange(self.n_train)[self.dst_train_unlabel.targets == c]
                scores_bal.append(scores[class_index])
                selection_result = np.append(selection_result, class_index[
                    np.argsort(scores_bal[-1])[::-1][:round(len(class_index) * self.fraction)]])

                unlabel_result = np.append(selection_result, class_index[np.argsort(scores_bal[-1])[::-1][round(
                    len(class_index) * self.fraction):2 * round(len(class_index) * self.fraction)]])
        else:
            if self.args.adaptsize:
                if self.args.prefix_coresize:
                    adapt_size = self.coreset_size
                else:
                    adapt_size = np.sum(scores > avg_score) if self.args.correct_only else np.sum(scores > np.mean(scores))
                print(f"Subset size of '> Avg. Uncertainty' = {adapt_size}")
                selection_result = np.argsort(scores)[::-1][:adapt_size]
                unlabel_result = np.argsort(scores)[::-1][adapt_size:]
            else:
                selection_result = np.argsort(scores)[::-1][:self.coreset_size]
                unlabel_result = np.argsort(scores)[::-1][self.coreset_size:2*self.coreset_size]

        return {"indices": selection_result, "u_indices": unlabel_result, "scores": scores}

    def rank_uncertainty(self, index=None, correct_out=False):
        self.model.eval()
        with torch.no_grad():
            train_loader = torch.utils.data.DataLoader(
                self.dst_train_unlabel if index is None else torch.utils.data.Subset(self.dst_train_unlabel, index),
                batch_size=self.args.selection_batch,
                num_workers=self.args.workers)

            scores = np.array([])
            batch_num = len(train_loader)

            correct_inds = np.array([])
            for i, (input, target) in enumerate(train_loader):
                if i % self.args.print_freq == 0:
                    print("| Selecting for batch [%3d/%3d]" % (i + 1, batch_num))

                preds_logits = self.model(input.to(self.args.device))
                pred_labels = np.argmax(preds_logits.detach().cpu().numpy(), axis=1)
                corrects = np.equal(target.data.view(-1).numpy(), pred_labels)
                # _, pred_labels = preds_logits.topk(1, 1, True, True)
                # pred_labels = pred_labels.t().detach().cpu()
                # corrects = pred_labels.eq(target.view(1, -1).expand_as(pred_labels))
                correct_inds = np.concatenate((correct_inds, corrects), axis=0)

                if self.selection_method == "LeastConfidence":
                    lc_scores = torch.nn.functional.softmax(preds_logits, dim=1).min(axis=1).values.cpu().numpy()
                    scores = np.append(scores, lc_scores)
                elif self.selection_method == "Entropy":
                    preds = torch.nn.functional.softmax(preds_logits/self.args.uncertainty_temperature, dim=1).cpu().numpy()
                    scores = np.append(scores, (-np.log(preds + 1e-6) * preds).sum(axis=1))
                elif self.selection_method == 'Margin':
                    # preds_logits = self.model(input.to(self.args.device))
                    preds = torch.nn.functional.softmax(preds_logits, dim=1)
                    preds_argmax = torch.argmax(preds, dim=1)
                    max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax].clone()
                    preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax] = -1.0
                    preds_sub_argmax = torch.argmax(preds, dim=1)

                    # top2-top1 < 0: better samples -> smaller margin -> larger top2-top1
                    scores = np.append(scores, (preds[torch.ones(preds.shape[0], dtype=bool), preds_sub_argmax] - max_preds).cpu().numpy())
                elif self.selection_method == "EL2N":
                    probs = torch.nn.functional.softmax(preds_logits, dim=1)

                    # One-hot encode labels
                    labels_one_hot = torch.zeros_like(probs)
                    labels_one_hot[torch.arange(len(target)), target] = 1

                    # Compute L2 norm
                    l2_norms = torch.norm(probs - labels_one_hot, dim=1).cpu().numpy()

                    scores = np.append(scores, l2_norms)

            if self.args.norm_score:
                # scores = (scores - scores.mean())/scores.std()
                scores = (scores - scores.min())/(scores.max() - scores.min())

        if correct_out:
            return scores, np.array(correct_inds)
        else:
            return scores

    def select(self, **kwargs):
        selection_result = None
        for cycle in range(self.repeat):
            print(f"\n----------------------------------- Round {cycle+1} -----------------------------------\n")
            self.cycle = cycle

            if self.scores is None:  # Scores is empty in the first cycle
                list_of_train_idxes = np.arange(self.n_train)
            else:
                list_of_train_idxes = np.arange(self.n_train)

            selection_result = self.run(list_of_train_idxes=list_of_train_idxes, label_smooth=False)

            # Re-init all parameters
            self.random_seed += 10
            self.epoch = 0

        return selection_result
