from .earlytrain import EarlyTrain
import torch
import numpy as np


class DeepFool(EarlyTrain):
    """Coreset selection by DeepFool perturbation size.

    After the training handled by ``EarlyTrain``, every training sample is perturbed with the DeepFool
    algorithm until the model's prediction changes (or ``max_iter`` iterations elapse). A sample's score
    is the squared L2 norm of its accumulated perturbation, and the samples with the *smallest* scores
    are selected.

    Args:
        dst_train: Training dataset to select from.
        args: Experiment configuration. This class reads ``print_freq``, ``selection_batch``,
            ``workers``, ``device`` and ``num_classes`` from it.
        mean, std: Stored on the instance (presumably the dataset's normalization statistics).
        fraction: Fraction of samples to keep. Forwarded to ``EarlyTrain`` and also used directly for
            the per-class sizes when ``balance`` is True.
        random_seed, epochs, specific_model, **kwargs: Forwarded to ``EarlyTrain``.
        balance: If True, select ``round(fraction * class_size)`` samples from each class. Otherwise
            select the ``coreset_size`` lowest scores overall.
        max_iter: Maximum number of DeepFool iterations per selection batch.
    """

    def __init__(self, dst_train, args, mean, std, fraction=0.5, random_seed=None, epochs=200,
                 specific_model=None, balance: bool = False, max_iter: int = 50, **kwargs):
        super().__init__(dst_train, args, mean, std, fraction, random_seed, epochs, specific_model, **kwargs)

        self.balance = balance
        self.max_iter = max_iter

        self.mean = mean
        self.std = std

    def num_classes_mismatch(self):
        """Raise ValueError: the pretrain dataset's class count differs from the training dataset's."""
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        """Per-batch training hook (presumably invoked by ``EarlyTrain``).

        Prints the loss every ``args.print_freq`` batches.
        """
        if batch_idx % self.args.print_freq == 0:
            print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f' % (
                epoch, self.epochs, batch_idx + 1, (self.n_pretrain_size // batch_size) + 1, loss.item()))

    def finish_run(self):
        """Score every training sample with DeepFool and pick the coreset.

        Returns:
            dict with ``"indices"`` (the selected indices into ``dst_train``) and ``"scores"``
            (float32 array of per-sample squared perturbation norms, indexed like ``dst_train``).
        """
        self.model.no_grad = False

        # shuffle stays at its default (False), so batches arrive in dataset order. The running offset
        # below therefore maps each batch's scores back to the right dataset indices.
        batch_loader = torch.utils.data.DataLoader(self.dst_train, batch_size=self.args.selection_batch,
                                                   num_workers=self.args.workers)

        scores = np.zeros(self.n_train, dtype=np.float32)
        batch_num = len(batch_loader)
        start = 0
        for i, (inputs, targets) in enumerate(batch_loader):
            if i % self.args.print_freq == 0:
                print('| Selecting Batch [%3d/%3d]' % (i + 1, batch_num))
            end = start + targets.shape[0]
            scores[start:end] = self.deep_fool(inputs)
            start = end

        if self.balance:
            # np.asarray accepts both tensor and plain-list targets. Comparing a Python list to an int
            # would yield one scalar False and silently select nothing for every class.
            labels = np.asarray(self.dst_train.targets)
            selection_result = np.array([], dtype=np.int64)
            for c in range(self.args.num_classes):
                class_index = np.arange(self.n_train)[labels == c]
                # Keep the lowest-scoring round(fraction * class_size) samples of this class.
                selection_result = np.append(selection_result, class_index[
                    scores[class_index].argsort()[:round(len(class_index) * self.fraction)]])
        else:
            selection_result = scores.argsort()[:self.coreset_size]
        return {"indices": selection_result, "scores": scores}

    def deep_fool(self, inputs):
        """Run batched DeepFool on ``inputs`` and return each sample's squared total perturbation norm.

        Each iteration linearizes the model around the current inputs. For every still-active sample it
        picks the class k whose linearized boundary is closest, i.e. the one minimizing
        ``|f_k - f_pred| / ||grad f_k - grad f_pred||``. It then steps onto that boundary with
        ``r = |f_k - f_pred| / ||w||**2 * w``, where ``w = grad f_k - grad f_pred``. A sample stops
        being updated once its prediction differs from the prediction on the unperturbed input.

        Side effects: puts ``self.model`` in eval mode and sets ``requires_grad=False`` on all model
        parameters. The caller's ``inputs`` tensor is not modified.

        Args:
            inputs: Batch tensor with samples on dim 0 (any number of trailing dims).

        Returns:
            numpy array of shape ``(inputs.shape[0],)`` holding the squared L2 norm of each sample's
            accumulated perturbation.
        """
        self.model.eval()

        sample_size = inputs.shape[0]
        # True at positions (in the original batch) whose samples are still being perturbed.
        boolean_mask = np.ones(sample_size, dtype=bool)
        all_idx = np.arange(sample_size)

        original_shape = inputs.shape[1:]
        # Accumulated perturbation per sample, flattened over all non-batch dims.
        r_tot = np.zeros([sample_size, int(np.prod(original_shape))])

        # Work on a detached copy that is re-created (never modified in place) each iteration. An
        # in-place add on a leaf that requires grad raises RuntimeError, which happens whenever .to()
        # is a no-op (e.g. on CPU). This also keeps autograd history from growing across iterations.
        cur_inputs = inputs.detach().to(self.args.device).requires_grad_(True)

        # Only input gradients are needed; freezing parameters avoids computing their gradients.
        for p in self.model.parameters():
            p.requires_grad_(False)

        self.model.no_grad = True
        first_preds = self.model(cur_inputs).argmax(dim=1)
        self.model.no_grad = False

        num_classes = self.args.num_classes
        for _ in range(self.max_iter):
            f_all = self.model(cur_inputs)
            # Row indices 0..m-1 of the samples still active in this iteration.
            rows = torch.arange(cur_inputs.shape[0], device=f_all.device)

            # w_k[c] is the input gradient of logit c for every active sample, shape (num_classes, m, D).
            # The graph is kept alive until the last class, then freed.
            w_k = torch.stack([
                torch.autograd.grad(f_all[:, c].sum(), cur_inputs, retain_graph=c + 1 < num_classes)[0].flatten(1)
                for c in range(num_classes)
            ], dim=0)
            # Make gradients relative to the originally predicted class: w'_k = w_k - w_pred.
            w_k = w_k - w_k[first_preds, rows].unsqueeze(0)
            w_k_norm = w_k.norm(dim=2)  # (num_classes, m)
            # The predicted-class entries are zero vectors; set their norm to 1 to avoid division by zero.
            w_k_norm[first_preds, rows] = 1.

            # Distance to each linearized boundary, shape (m, num_classes).
            l_all = (f_all - f_all[rows, first_preds].unsqueeze(1)).detach().abs() / w_k_norm.T
            # Exclude the predicted class itself from the argmin below.
            l_all[rows, first_preds] = np.inf

            l_hat = l_all.argmin(dim=1)
            # Step onto the closest boundary: r_i = l / ||w'|| * w', shape (m, D).
            r_i = (l_all[rows, l_hat] / w_k_norm[l_hat, rows]).unsqueeze(1) * w_k[l_hat, rows]

            # Update r_tot values.
            r_tot[boolean_mask] += r_i.cpu().numpy()

            with torch.no_grad():
                cur_inputs = cur_inputs + r_i.reshape(cur_inputs.shape)

            # Re-input the updated sample into the network and get new predictions.
            self.model.no_grad = True
            preds = self.model(cur_inputs).argmax(dim=1)
            self.model.no_grad = False

            # DeepFool stops for a sample once its perturbed version yields a different prediction.
            index_unfinished = preds == first_preds
            if not index_unfinished.any():
                break

            # Advanced indexing copies, so the result is a fresh leaf that may require grad again.
            cur_inputs = cur_inputs[index_unfinished].requires_grad_(True)
            first_preds = first_preds[index_unfinished]
            boolean_mask[all_idx[boolean_mask][~index_unfinished.cpu().numpy()]] = False

        return (r_tot * r_tot).sum(axis=1)

    def select(self, **kwargs):
        """Run the selection pipeline via the inherited ``run()`` and return its result."""
        selection_result = self.run()
        return selection_result
