import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
import batchminer


"""================================================================================================="""
# Module-level configuration flags for this loss. Criterion.__init__ copies each
# of them onto the instance under the same attribute name.
# NOTE(review): presumably inspected by the surrounding training framework to
# decide how to wire up the loss - confirm against the caller.

# Batch-mining operations this loss accepts; None here, as no list is declared.
ALLOWED_MINING_OPS = None
# This loss does not take a batchminer (forward() never uses one).
REQUIRES_BATCHMINER = False
# The loss owns learnable parameters (the proxies, optionally eps) and exposes
# them through Criterion.optim_dict_list.
REQUIRES_OPTIM = True


class Criterion(torch.nn.Module):
    """Proxy-based NCA loss with one learnable proxy per class and multi-hot labels.

    Every class owns a proxy vector in embedding space. For each sample, the
    proxies selected by its label row act as positives and all remaining proxies
    as negatives. The per-sample loss is

        logsumexp_{neg}(-max(d - eps, 0)) - logsumexp_{pos}(-max(d - eps, 0))

    where ``d`` is the squared Euclidean distance between the (scaled, normalised)
    embedding and a proxy, and ``eps`` is an optional margin that is either a
    constant, a single learnable value, or one learnable value per class.

    Attributes:
        proxies: Learnable tensor of shape (n_classes, embed_dim).
        eps: ``0``, a Python number, or a learnable 1-D ``nn.Parameter`` of
            size 1 or n_classes, depending on the options.
        optim_dict_list: Parameter-group dicts (``params`` and ``lr``) for the
            learnable tensors owned by this loss.
    """

    def __init__(self, opt):
        """
        Args:
            opt: Namespace containing all relevant parameters. Fields read:
                ``n_classes``, ``embed_dim``, ``lr``, ``loss_proxynca_lrmulti``,
                ``loss_proxynca_eps``; and, only if ``loss_proxynca_eps`` is
                truthy, ``loss_proxynca_eps_per`` and
                ``loss_proxynca_eps_constant``; and, only for a learnable
                margin, ``loss_proxynca_eps_lrmulti``.
        """
        super(Criterion, self).__init__()

        ####
        self.num_proxies = opt.n_classes
        self.embed_dim = opt.embed_dim

        # Small random initialisation; proxies are re-normalised in forward().
        self.proxies = torch.nn.Parameter(torch.randn(self.num_proxies, self.embed_dim) / 8)
        self.class_idxs = torch.arange(self.num_proxies)

        self.name = 'multiproxynca'

        # Proxies get their own learning rate, scaled relative to the base lr.
        self.optim_dict_list = [{'params': self.proxies, 'lr': opt.lr * opt.loss_proxynca_lrmulti}]

        if opt.loss_proxynca_eps:
            # One margin per class, or a single margin shared by all classes.
            size = opt.n_classes if opt.loss_proxynca_eps_per else 1
            if opt.loss_proxynca_eps_constant:
                # Fixed (non-learnable) margin: keep the plain number.
                self.eps = opt.loss_proxynca_eps
            else:
                # Learnable margin, stored as a 1-D tensor of length `size`.
                self.eps = torch.nn.Parameter(torch.ones(size) * opt.loss_proxynca_eps)
                self.optim_dict_list += [{'params': self.eps, 'lr': opt.lr * opt.loss_proxynca_eps_lrmulti}]
        else:
            self.eps = 0

        ####
        self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS
        self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
        self.REQUIRES_OPTIM = REQUIRES_OPTIM

    def forward(self, batch, labels, **kwargs):
        """Compute the multi-label proxy-NCA loss for a batch of embeddings.

        Args:
            batch: Tensor of embeddings; normalised along dim 1 and matched
                row-by-row with ``labels``.
            labels: Iterable of per-sample multi-hot rows. Each row is converted
                with ``.bool()`` and used as a mask over the proxies, so it must
                have one entry per class.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Scalar tensor: mean over the batch of
            ``logsumexp(negatives) - logsumexp(positives)``.

        NOTE(review): a row with no positive (or no negative) class selects an
        empty proxy set; the result then follows ``torch.logsumexp`` on an empty
        tensor. Confirm that such rows cannot occur in the data.
        """
        #Empirically, multiplying the embeddings during the computation of the loss seems to allow for more stable training;
        #Acts as a temperature in the NCA objective.
        batch = 3 * torch.nn.functional.normalize(batch, dim=1)
        proxies = 3 * torch.nn.functional.normalize(self.proxies, dim=1)

        pos_terms, neg_terms = [], []
        for x, one_hot in zip(batch, labels):
            is_pos = one_hot.bool()
            # Squared distance from this sample to every proxy.
            sq_dist = torch.sum((x - proxies).pow(2), dim=-1)
            # Apply the margin before splitting into positives and negatives.
            # self.eps broadcasts here whether it is a number, a size-1
            # parameter, or a per-class parameter aligned with the proxies.
            # The previous code indexed the 1-D per-class eps with two indices
            # (`eps[mask, :]`), which raised an IndexError.
            logits = -torch.clamp(sq_dist - self.eps, min=0)
            pos_terms.append(torch.logsumexp(logits[is_pos], dim=-1))
            neg_terms.append(torch.logsumexp(logits[is_pos.logical_not()], dim=-1))

        pos_sum = torch.stack(pos_terms)
        neg_sum = torch.stack(neg_terms)

        #Compute final proxy-based NCA loss
        loss = torch.mean(neg_sum - pos_sum, dim=-1)
        return loss
