import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F


"""================================================================================================="""
# Requirement flags for this criterion; they are also copied onto every Criterion instance in __init__.
ALLOWED_MINING_OPS  = None
REQUIRES_BATCHMINER = False
REQUIRES_OPTIM      = True


class Criterion(torch.nn.Module):
    """Multi-label ProxyNCA variant with one positive and one negative proxy per class.

    Every class ``c`` owns two learnable proxies stored interleaved in ``self.proxies``:
    the positive proxy at row ``2*c`` and the negative proxy at row ``2*c + 1``.
    For a sample with label vector ``y`` (one entry per class), its *assigned* proxy for
    class ``c`` is the positive proxy when ``y[c]`` is non-zero and the negative proxy
    otherwise; the *other* proxy is the remaining one of the pair.  The loss is

        loss1 = -mean_i logsumexp_c(-d(x_i, assigned_ic))   # pull samples to assigned proxies
        loss2 =  mean_i logsumexp_c(-d(x_i, other_ic))      # push samples from the other proxies
        loss3 =  logsumexp_c(-||pos_c - neg_c||^2)          # push each class's two proxies apart

    where ``d`` is ``max(||a - b||^2 - eps, 0)`` between L2-normalised vectors that are
    multiplied by ``self.scale``.
    """

    def __init__(self, opt):
        """
        Args:
            opt: Namespace containing all relevant parameters. Reads ``n_classes``,
                 ``embed_dim``, ``lr`` and ``loss_proxynca_lrmulti``.
        """
        super(Criterion, self).__init__()

        ####
        self.num_proxies        = opt.n_classes * 2
        self.embed_dim          = opt.embed_dim

        # Interleaved layout: even rows are positive proxies, odd rows are negative proxies.
        self.proxies            = torch.nn.Parameter(torch.randn(self.num_proxies, self.embed_dim)/8)
        # Owning class of every proxy row: [[0, 0, 1, 1, ..., C-1, C-1]] for C = opt.n_classes.
        self.class_idxs         = torch.arange(opt.n_classes).repeat_interleave(2).reshape(1, -1)

        self.pos_proxy_inds = torch.arange(0, len(self.proxies), 2)
        self.neg_proxy_inds = torch.arange(1, len(self.proxies), 2)

        self.name           = 'multiproxyncanegative'
        self.label_arange = torch.arange(opt.n_classes)

        self.optim_dict_list = [{'params':self.proxies, 'lr':opt.lr * opt.loss_proxynca_lrmulti}]
        # Margin subtracted from each sample-proxy squared distance before clamping at zero;
        # with eps = 0 the clamp is a no-op because squared distances are non-negative.
        self.eps = 0
        # Multiplier applied to normalised embeddings and proxies; acts as an inverse
        # temperature in the NCA objective.
        self.scale = 3

        ####
        self.ALLOWED_MINING_OPS  = ALLOWED_MINING_OPS
        self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
        self.REQUIRES_OPTIM      = REQUIRES_OPTIM

    def forward(self, batch, labels, **kwargs):
        """Compute ``loss1 + loss2 + loss3`` (see the class docstring) for one batch.

        Args:
            batch:  Embeddings of shape (batch_size, embed_dim); normalised over dim 1.
            labels: Per-sample label vectors of shape (batch_size, n_classes); an entry is
                    treated as positive when it is non-zero. May be a tensor, an array, or a
                    sequence of per-sample rows.
            **kwargs: Ignored; accepted for interface compatibility with other criteria.

        Returns:
            A scalar tensor.

        Raises:
            ValueError: If the batch is empty or ``labels`` does not have one row per sample.
        """
        device = self.proxies.device

        # Empirically, multiplying the embeddings during the computation of the loss seems to
        # allow for more stable training; acts as a temperature in the NCA objective.
        batch   = self.scale * F.normalize(batch, dim=1)
        proxies = self.scale * F.normalize(self.proxies, dim=1)

        # stacked_proxies[0] holds the negative proxies and stacked_proxies[1] the positive
        # ones (each n_classes x embed_dim), so a 0/1 label directly selects the matching proxy.
        stacked_proxies = torch.stack((proxies[self.neg_proxy_inds.to(device)],
                                       proxies[self.pos_proxy_inds.to(device)]))

        if not torch.is_tensor(labels):
            labels = torch.stack([torch.as_tensor(row) for row in labels])
        if len(batch) == 0:
            raise ValueError('multiproxyncanegative: received an empty batch')
        if len(labels) != len(batch):
            raise ValueError(f'multiproxyncanegative: got {len(batch)} embeddings but {len(labels)} label rows')

        # Binarise once so both proxy selections agree for int/float/bool labels: non-zero -> 1.
        assigned_idx = labels.to(device).bool().long()   # batch_size x n_classes
        other_idx    = 1 - assigned_idx
        class_idx    = self.label_arange.to(device)       # broadcast against every label row

        # batch_size x n_classes x embed_dim: entry [i, c] = stacked_proxies[idx[i, c], c].
        assigned_proxies = stacked_proxies[assigned_idx, class_idx]
        other_proxies    = stacked_proxies[other_idx, class_idx]

        loss1 = -torch.mean(self._nca_logsumexp(batch, assigned_proxies))  # pull samples to their assigned proxies
        loss2 =  torch.mean(self._nca_logsumexp(batch, other_proxies))     # push samples from the other proxies

        # Squared distance between each class's negative and positive proxy.
        proxy_dists = (stacked_proxies[0] - stacked_proxies[1]).pow(2).sum(dim=-1)
        loss3 = torch.logsumexp(-proxy_dists, dim=0)  # push each class's proxy pair apart

        return loss1 + loss2 + loss3

    def _nca_logsumexp(self, batch, proxies):
        """Return, per sample, the logsumexp over classes of the negated clamped squared distances.

        ``batch`` is (batch_size, embed_dim) and ``proxies`` is (batch_size, n_classes, embed_dim);
        the result has shape (batch_size,).
        """
        dist = (batch.unsqueeze(1) - proxies).pow(2).sum(dim=-1)
        return torch.logsumexp(-torch.clamp(dist - self.eps, min=0), dim=-1)
