import torch
import time
import torch.distributed as dist

class DataSelectionStrategy(object):
    """
    Base class for data selection strategies in general learning frameworks.

    It stores the data loaders, model and loss, and provides helpers that
    subclasses can use to collect labels and last-layer gradients.

    Parameters
    ----------
        trainloader: class
            Loading the training data using pytorch dataloader
        valloader: class
            Loading the validation data using pytorch dataloader
        model: class
            Model architecture used for training
        num_classes: int
            Number of target classes in the dataset
        linear_layer: bool
            If True, we use the last fc layer weights and biases gradients
            If False, we use the last fc layer biases gradients
        loss: class
            PyTorch Loss function
        device: str
            The device being utilized - cpu | cuda
        logger: class
            logger object for logging the information
    """

    def __init__(self, trainloader, valloader, model, num_classes, linear_layer, loss, device, logger):
        """
        Constructor method
        """
        # Distributed state is sampled once at construction time.
        self.is_distributed = dist.is_initialized()
        self.rank = dist.get_rank() if self.is_distributed else 0
        self.trainloader = trainloader
        self.valloader = valloader
        self.model = model
        self.N_trn = len(trainloader.sampler)
        self.N_val = len(valloader.sampler)
        self.grads_per_elem = None
        self.val_grads_per_elem = None
        self.numSelected = 0
        self.linear_layer = linear_layer
        self.num_classes = num_classes
        self.trn_lbls = None
        self.val_lbls = None
        self.loss = loss
        self.device = device
        self.logger = logger

    def select(self, budget, model_params):
        """
        Selection hook; the base class does nothing and returns None.

        Parameters
        ----------
        budget:
            Not used by the base implementation.
        model_params:
            Not used by the base implementation.
        """
        pass

    @staticmethod
    def _gather_labels(loader):
        """
        Iterate over ``loader`` once and return all labels as a 1-D tensor.

        If the first dataset sample is a dict, labels are read from
        ``batch['labels']``; otherwise each batch is unpacked as a
        ``(inputs, targets)`` pair.

        Raises
        ------
        ValueError
            If the loader yields no batches.
        """
        labels_in_dict = isinstance(loader.dataset[0], dict)
        chunks = []
        for batch in loader:
            if labels_in_dict:
                targets = batch['labels']
            else:
                _, targets = batch
            chunks.append(targets.reshape(-1))
        if not chunks:
            raise ValueError("Cannot collect labels: the data loader yielded no batches.")
        # Concatenate once instead of growing a tensor batch by batch.
        return torch.cat(chunks, dim=0)

    def get_labels(self, valid=False):
        """
        Collect the labels of the training set (and optionally the validation set).

        The results are stored as 1-D tensors in ``self.trn_lbls`` and, when
        ``valid`` is True, ``self.val_lbls``.

        Parameters
        ----------
        valid: bool
            If True, the validation labels are collected as well.

        Raises
        ------
        ValueError
            If a loader that has to be read yields no batches.
        """
        self.trn_lbls = self._gather_labels(self.trainloader)
        if valid:
            self.val_lbls = self._gather_labels(self.valloader)

    def _sync_model_params(self):
        """Broadcast every model parameter from rank 0 when running distributed."""
        if self.is_distributed:
            for param in self.model.parameters():
                dist.broadcast(param.data, src=0)

    def _embedding_dim(self):
        """Return the model's embedding dimension, unwrapping a ``.module`` wrapper if present."""
        if hasattr(self.model, 'module'):
            return self.model.module.get_embedding_dim()
        return self.model.get_embedding_dim()

    def _batch_gradients(self, inputs, targets, emb_dim, per_batch):
        """
        Compute the gradient rows contributed by a single batch.

        Returns a 2-D tensor whose columns are the gradients w.r.t. the model
        outputs, followed (when ``self.linear_layer`` is True) by the product of
        those gradients with the second model output.
        With ``per_batch`` the rows are averaged into a single row.
        """
        inputs = inputs.to(self.device)
        targets = targets.to(self.device, non_blocking=True)

        # The model is expected to return (outputs, embedding) for last=True.
        # presumably the embedding is (batch, emb_dim) - confirm against the model.
        out, l1 = self.model(inputs, last=True, freeze=True)
        loss = self.loss(out, targets).sum()
        l0_grads = torch.autograd.grad(loss, out)[0]

        l1_grads = None
        if self.linear_layer:
            # Column c * emb_dim + e holds l0_grads[:, c] * l1[:, e].
            l0_expand = torch.repeat_interleave(l0_grads, emb_dim, dim=1)
            l1_grads = l0_expand * l1.repeat(1, self.num_classes)

        if per_batch:
            l0_grads = l0_grads.mean(dim=0).view(1, -1)
            if l1_grads is not None:
                l1_grads = l1_grads.mean(dim=0).view(1, -1)

        if l1_grads is None:
            return l0_grads
        return torch.cat((l0_grads, l1_grads), dim=1)

    def _collect_train_gradients(self, per_batch, max_batches=None):
        """
        Fill ``self.grads_per_elem`` with gradients from the training loader.

        At most ``max_batches`` batches are processed (all of them when None).

        Raises
        ------
        ValueError
            If no batch was processed.
        """
        self._sync_model_params()
        emb_dim = self._embedding_dim()

        chunks = []
        for batch_idx, (inputs, targets) in enumerate(self.trainloader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            chunks.append(self._batch_gradients(inputs, targets, emb_dim, per_batch))

        if not chunks:
            raise ValueError("Cannot compute gradients: no training batches were processed.")

        torch.cuda.empty_cache()

        # Single concatenation; repeated torch.cat in the loop would re-copy
        # the accumulated tensor for every batch.
        self.grads_per_elem = torch.cat(chunks, dim=0)

    def compute_gradients(self, valid=False, perBatch=False, perClass=False):
        """
        Compute last-layer gradients over the whole training loader.

        The result is stored in ``self.grads_per_elem`` with one row per sample,
        or one row per batch when ``perBatch`` is True. Each row holds the
        gradients of the summed loss w.r.t. the model outputs and, when
        ``self.linear_layer`` is True, the last linear layer weight gradients
        appended after them.

        Parameters
        ----------
        valid: bool
            Accepted for interface compatibility; not used by this implementation.
        perBatch: bool
            If True, gradients are averaged within each batch.
        perClass: bool
            Accepted for interface compatibility; not used by this implementation.

        Raises
        ------
        ValueError
            If the training loader yields no batches.
        """
        self._collect_train_gradients(per_batch=perBatch)

    def compute_gradients_small(self, valid=False, perBatch=False, perClass=False, max_batches=36):
        """
        Compute per-sample last-layer gradients over the first training batches only.

        Same as ``compute_gradients`` but stops after ``max_batches`` batches and
        always keeps one row per sample; the result is stored in
        ``self.grads_per_elem``.

        Parameters
        ----------
        valid: bool
            Accepted for interface compatibility; not used by this implementation.
        perBatch: bool
            Accepted for interface compatibility; ignored, gradients are never
            averaged per batch here.
        perClass: bool
            Accepted for interface compatibility; not used by this implementation.
        max_batches: int or None
            Maximum number of batches to process (default 36). None processes
            every batch.

        Raises
        ------
        ValueError
            If no batch was processed.
        """
        self._collect_train_gradients(per_batch=False, max_batches=max_batches)

    def update_model(self, model_params):
        """
        Update the models parameters

        Parameters
        ----------
        model_params: OrderedDict
            Python dictionary object containing models parameters
        """
        self.model.load_state_dict(model_params)