import torch
import torch.nn as nn


class fmutual(nn.Module):
    """
    Contrastive model whose loss is built from Gaussian-kernel estimators
    selected by ``estimator``.

    Adapted from the MoCo reference code (https://arxiv.org/abs/1911.05722),
    but unlike MoCo there is no separate key encoder and no queue. Both views
    are embedded by the same encoder, ``encoder_q``. Negatives are all
    distinct pairs of samples within each view of the current batch.
    """

    # Values accepted for the ``estimator`` argument.
    _ESTIMATORS = (
        'Gaussian-KL',
        'Gaussian-Pearson',
        'Gaussian-JS',
        'Gaussian-SH',
        'Gaussian-Tsallis',
        'Gaussian-VLC',
    )

    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False, alpha=40, tsallis=3, estimator='Gaussian-KL'):
        """
        base_encoder: callable invoked as ``base_encoder(num_classes=dim)``
            that returns the encoder module.
        dim: feature dimension (default: 128)
        K: MoCo queue size (default: 65536); stored but not read by this class
        m: MoCo key-encoder momentum (default: 0.999); stored but not read by this class
        T: softmax temperature (default: 0.07); stored but not read by this class
        mlp: if True, insert a Linear + ReLU in front of the encoder's ``fc``
            head (the encoder must expose ``fc`` with a 2-D ``weight``).
        alpha: weight of the negative term for every estimator except
            'Gaussian-Pearson'.
        tsallis: order parameter used only by 'Gaussian-Tsallis'.
        estimator: loss to compute; one of ``fmutual._ESTIMATORS``.
        """
        super(fmutual, self).__init__()

        self.alpha = alpha
        self.tsallis = tsallis
        self.estimator = estimator
        self.K = K
        self.m = m
        self.T = T

        # create the encoder
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)

        if mlp:  # hack: brute-force replacement of the fc head with an MLP
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***

        Returns this rank's slice of the globally shuffled batch and the
        permutation that undoes the shuffle. Not called by ``forward``.
        Assumes CUDA is available (``.cuda()`` below).
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast so every rank uses rank 0's permutation
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***

        ``idx_unshuffle`` is the second value returned by
        ``_batch_shuffle_ddp``. Not called by ``forward``.
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def _estimator_loss(self, pos_sq, neg_sq_q, neg_sq_k):
        """
        Combine squared distances into the loss selected by ``self.estimator``.

        pos_sq: squared distance of each positive pair (row i of q vs row i of k).
        neg_sq_q, neg_sq_k: squared distances of all distinct row pairs within
            the query view and within the key view (``torch.pdist`` layout).
        Raises ValueError if ``self.estimator`` is not a known estimator.
        """
        def neg_mean(fn):
            # Apply fn to each view's pairwise distances and average the two means.
            return (fn(neg_sq_q).mean() + fn(neg_sq_k).mean()) / 2.0

        est = self.estimator
        if est == 'Gaussian-KL':
            l_pos = pos_sq.mean()
            l_neg = neg_mean(lambda d: d.mul(-2).exp())
            return l_pos + self.alpha * l_neg
        if est == 'Gaussian-Pearson':
            l_pos = -pos_sq.mul(-1).exp().mean()
            l_neg = neg_mean(lambda d: d.mul(-2).exp())
            return l_pos + l_neg  # note: alpha is not applied for this estimator
        if est == 'Gaussian-JS':
            l_pos = pos_sq.mul(2).exp().add(1).log().mean()
            l_neg = neg_mean(lambda d: d.mul(-2).exp().add(1).log())
            return 0.5 * l_pos + self.alpha * l_neg
        if est == 'Gaussian-SH':
            l_pos = pos_sq.exp().mean()
            l_neg = neg_mean(lambda d: d.mul(-1).exp())
            return 0.5 * l_pos + self.alpha * l_neg
        if est == 'Gaussian-Tsallis':
            l_pos = -pos_sq.mul(-(self.tsallis - 1)).exp().mean()
            l_neg = neg_mean(lambda d: d.mul(-self.tsallis).exp())
            return 2.0 * l_pos + self.alpha * l_neg
        if est == 'Gaussian-VLC':
            l_pos = pos_sq.mul(-2).exp().add(1).pow(-2).mean()
            l_neg = -neg_mean(lambda d: d.mul(-2).exp().add(1).pow(-1))
            return l_pos + self.alpha * l_neg
        raise ValueError(
            f"unknown estimator {est!r}; expected one of {self._ESTIMATORS}"
        )

    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images (first view)
            im_k: a batch of key images (second view), same batch size as im_q
        Output:
            scalar loss tensor for the configured ``estimator``.

        Both views go through ``encoder_q``. The negative term averages over
        all distinct pairs in the batch. With a batch of one sample
        ``torch.pdist`` is empty, so that term (and the loss) is NaN.
        Raises ValueError for an unknown ``estimator``.
        """
        # compute features and project onto the unit sphere
        q = nn.functional.normalize(self.encoder_q(im_q), dim=1)  # queries: NxC
        k = nn.functional.normalize(self.encoder_q(im_k), dim=1)  # keys: NxC

        # squared L2 distance between matching rows (positive pairs)
        pos_sq = (q - k).norm(p=2, dim=1).pow(2)
        # squared L2 distances between all distinct rows within each view (negatives)
        neg_sq_q = torch.pdist(q, p=2).pow(2)
        neg_sq_k = torch.pdist(k, p=2).pow(2)

        return self._estimator_loss(pos_sq, neg_sq_q, neg_sq_k)
    

# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Collect ``tensor`` from every rank of the default process group and stack
    the results along dimension 0, ordered by rank.

    Assumes every rank passes a tensor of the same shape and dtype.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()

    # all_gather fills pre-allocated receive buffers, one per rank.
    buffers = []
    for _ in range(world_size):
        buffers.append(torch.ones_like(tensor))

    torch.distributed.all_gather(buffers, tensor, async_op=False)
    return torch.cat(buffers, dim=0)
