from .earlytrain import EarlyTrain
import torch, time
import numpy as np
from ..nets.nets_utils import MyDataParallel
from scipy.special import softmax


class GraNdSampling(EarlyTrain):
    """GraNd-score based *probabilistic* coreset selection.

    Over ``repeat`` independent early-training runs (``self.run()``), each run
    scores every training sample in :meth:`finish_run` by the L2 norm of the
    gradient of its loss with respect to the last layer's bias and weight
    (assuming the last layer is linear, as ``get_last_layer().in_features``
    suggests). Scores are averaged over the runs, and :meth:`sample` draws
    ``round(fraction * n_train)`` indices *without replacement* with
    probabilities ``softmax(mean_score)``.

    With ``balance=True`` the draw is done per class instead: class ``c``
    receives ``round(fraction * |class c|)`` indices drawn with probabilities
    ``softmax(mean_score[class c])``.

    Scores are computed on the first :meth:`select` call only; later calls
    reuse them and just draw a new random subset.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, repeat=10,
                 specific_model=None, balance=False, **kwargs):
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)
        self.epochs = epochs
        self.n_train = len(dst_train)
        self.coreset_size = round(self.n_train * fraction)
        self.specific_model = specific_model
        self.repeat = repeat  # number of independent training runs to average scores over

        self.balance = balance  # if True, sample a per-class budget within each class
        # Global sampling distribution; None until the scores have been computed.
        # It doubles as the "scores already computed" marker in select().
        self.probabilities = None
        # Per-sample score averaged over repeats (numpy array); None until computed.
        self.norm_mean = None

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        """Print training progress every ``args.print_freq`` batches."""
        if batch_idx % self.args.print_freq == 0:
            num_batches = (self.n_train + batch_size - 1) // batch_size  # ceil division
            print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f' % (
                epoch, self.epochs, batch_idx + 1, num_batches, loss.item()))

    def before_run(self):
        # Unwrap data-parallel wrappers so that attributes such as
        # embedding_recorder / get_last_layer are reached on the bare module.
        if isinstance(self.model, MyDataParallel):
            self.model = self.model.module

    def finish_run(self):
        """Score every training sample for the current repeat.

        Fills column ``self.cur_repeat`` of ``self.norm_matrix``. Write ``g``
        for the gradient of a sample's loss w.r.t. the logits (the bias
        gradient of a linear last layer) and ``e`` for the embedding feeding
        that layer. The weight gradient is then the outer product ``g e^T``, so::

            ||[g, vec(g e^T)]||_2 = ||g||_2 * sqrt(1 + ||e||_2^2)

        The right-hand side is evaluated directly instead of materialising a
        ``(batch, num_classes, embedding_dim)`` tensor. The model is restored
        to train mode with embedding recording off, even if scoring raises.
        """
        self.model.embedding_recorder.record_embedding = True  # recording embedding vector
        self.model.eval()
        try:
            embedding_dim = self.model.get_last_layer().in_features
            batch_size = self.args.selection_batch
            # No shuffling: batch i covers dataset positions
            # [i * batch_size, i * batch_size + len(batch)), which is what the
            # row slice into norm_matrix below relies on.
            batch_loader = torch.utils.data.DataLoader(
                self.dst_train, batch_size=batch_size, num_workers=self.args.workers)

            for i, data in enumerate(batch_loader):
                inputs, targets = data[0], data[1]
                batch_num = targets.shape[0]

                self.model_optimizer.zero_grad()
                # Only d(loss)/d(logits) is needed, so the forward pass does not
                # record an autograd graph through the network.
                with torch.no_grad():
                    outputs = self.model(inputs.to(self.args.device))
                    embedding = self.model.embedding_recorder.embedding.view(batch_num, embedding_dim)
                outputs = outputs.detach().requires_grad_(True)

                # NOTE(review): the criterion is applied to softmax probabilities, not
                # raw logits; if it is CrossEntropyLoss this is a double softmax.
                # Kept as-is to preserve the scoring. The `.sum()` presumably relies
                # on the criterion using reduction='none' -- confirm.
                loss = self.criterion(torch.nn.functional.softmax(outputs, dim=1),
                                      targets.to(self.args.device)).sum()
                logit_grads = torch.autograd.grad(loss, outputs)[0]

                with torch.no_grad():
                    scores = logit_grads.norm(p=2, dim=1) * torch.sqrt(1.0 + embedding.pow(2).sum(dim=1))
                    start = i * batch_size
                    self.norm_matrix[start:start + batch_num, self.cur_repeat] = scores
        finally:
            self.model.train()
            self.model.embedding_recorder.record_embedding = False

    def sample(self):
        """Draw coreset indices without replacement, weighted by softmax of the scores.

        Returns ``{'indices': np.ndarray}``. If no scores are available yet,
        sampling falls back to uniform (``p=None``).
        """
        if not self.balance:
            return {'indices': np.random.choice(range(self.n_train), round(self.fraction * self.n_train),
                                                p=self.probabilities, replace=False)}

        # Balanced mode: each class gets its own budget and its own distribution.
        # np.asarray makes the comparison element-wise even if targets is a list.
        targets = np.asarray(self.dst_train.targets)
        per_class = []
        for c in range(self.num_classes):
            c_indx = np.flatnonzero(targets == c)
            budget = round(self.fraction * len(c_indx))
            if budget == 0:
                continue
            p = None if self.norm_mean is None else softmax(self.norm_mean[c_indx])
            per_class.append(np.random.choice(c_indx, budget, p=p, replace=False))
        indices = np.concatenate(per_class) if per_class else np.array([], dtype=np.int64)
        return {'indices': indices}

    def select(self, **kwargs):
        """Return ``({'indices': ...}, 0.0)``.

        The first call computes the GraNd scores over ``self.repeat`` training
        runs. Later calls reuse the cached scores and only re-sample.
        """
        if self.probabilities is None:
            # Initialize a matrix to save norms of each sample on independent runs
            # (one column per repeat).
            self.norm_matrix = torch.zeros([self.n_train, self.repeat], requires_grad=False).to(self.args.device)

            for self.cur_repeat in range(self.repeat):
                # finish_run() is expected to be triggered by run() and fills column
                # cur_repeat; the return value of run() is not needed here.
                self.run()
                # Time-derived seed so the next repeat trains differently.
                # NOTE(review): this makes the scores non-reproducible across invocations.
                self.random_seed = int(time.time() * 1000) % 100000

            self.norm_mean = torch.mean(self.norm_matrix, dim=1).cpu().detach().numpy()
            # Set in both modes so the expensive scoring runs only once.
            self.probabilities = softmax(self.norm_mean)

        return self.sample(), 0.0
