import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import math
import torch.distributed as dist
from util import Pack
from functools import partial
import torchvision.models as torchvision_models
from util import Pack

class GatherLayer(torch.autograd.Function):
    """Gather a tensor from every process while keeping it differentiable.

    ``forward`` all-gathers ``tensor`` and returns a tuple with one entry per
    rank (``dist.all_gather`` requires every rank to contribute a tensor of the
    same shape).

    ``backward`` sums the incoming gradients across all ranks before keeping
    this rank's slice. Each rank's loss may depend on every gathered tensor, so
    the gradient w.r.t. this rank's input is the sum of the contributions from
    all ranks' losses; returning only the local ``grads[rank]`` would silently
    drop the contributions computed on the other ranks.
    """

    @staticmethod
    def forward(ctx, tensor):
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        # grads[j] is d(local loss)/d(gathered[j]). After the all-reduce, entry
        # `rank` holds d(sum of every rank's loss)/d(this rank's input).
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads)
        return all_grads[dist.get_rank()]

@torch.no_grad()
def concat_all_gather(tensor):
    """
    Collect ``tensor`` from every rank and concatenate the pieces along dim 0.

    The result is ordered by rank. It is computed under ``torch.no_grad()``,
    so no gradient flows back through the gather.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = dist.get_world_size()
    buckets = []
    for _ in range(world_size):
        # Placeholder buffers; all_gather overwrites each one in place.
        buckets.append(torch.ones_like(tensor))
    dist.all_gather(buckets, tensor, async_op=False)
    return torch.cat(buckets, dim=0)

class SparseCL(nn.Module):
    """Online/target-encoder model trained with an alignment and a sparsity loss.

    An online ResNet-50 (its ``fc`` replaced by an MLP projector) feeds a
    predictor MLP. A target encoder with the same architecture starts as a copy
    of the online encoder, never receives gradients, and is updated as an
    exponential moving average of it. The loss combines:

    * an alignment term between predictions of one view and targets of the
      other view, and
    * a sparsity term penalising (through a shifted, temperature-scaled sigmoid)
      the similarity of each prediction to every non-matching target in the
      batch gathered across all ranks.

    Reads from ``args``: ``mlp_dim``, ``feature_dim``, ``threshold``,
    ``temperature`` and ``alpha``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # CosineSimilarity has no parameters or buffers, so no device placement is needed.
        self.criterion = nn.CosineSimilarity(dim=1, eps=1e-4)

        # num_classes only sizes the stock fc layer, which is replaced below.
        backbone = partial(
            torchvision_models.__dict__["resnet50"],
            zero_init_residual=True,
            num_classes=args.mlp_dim,
        )
        self.online_encoder = backbone()
        self.target_encoder = backbone()

        hidden_dim = self.online_encoder.fc.in_features
        del self.online_encoder.fc, self.target_encoder.fc  # remove original fc layer

        # Projectors (the original note suggests mlp_dim=4096, feature_dim=256 — set via args).
        self.online_encoder.fc = self._build_projector(args, hidden_dim)
        self.target_encoder.fc = self._build_projector(args, hidden_dim)

        # Predictor (online branch only).
        self.predictor = self._build_predictor(args)

        # Start the target as an exact copy of the online encoder and freeze it.
        with torch.no_grad():
            for param_o, param_t in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
                param_t.copy_(param_o)
                param_t.requires_grad = False

    def _build_projector(self, args, hidden_dim):
        """Return a 3-layer MLP ``hidden_dim -> mlp_dim -> mlp_dim -> feature_dim``."""
        return nn.Sequential(
            nn.Linear(hidden_dim, args.mlp_dim, bias=False),
            nn.BatchNorm1d(args.mlp_dim),
            nn.ReLU(inplace=True),
            nn.Linear(args.mlp_dim, args.mlp_dim, bias=False),
            nn.BatchNorm1d(args.mlp_dim),
            nn.ReLU(inplace=True),
            nn.Linear(args.mlp_dim, args.feature_dim, bias=True)
        )

    def _build_predictor(self, args):
        """Return a 2-layer MLP ``feature_dim -> mlp_dim -> feature_dim``."""
        return nn.Sequential(
            nn.Linear(args.feature_dim, args.mlp_dim, bias=False),
            nn.BatchNorm1d(args.mlp_dim),
            nn.ReLU(inplace=True),
            nn.Linear(args.mlp_dim, args.feature_dim, bias=True)
        )

    @torch.no_grad()
    def _update_momentum_encoder(self, m):
        """EMA step: ``target = m * target + (1 - m) * online`` for every parameter.

        Updated in place, so the target parameters keep their storage instead of
        being rebound to a freshly allocated tensor each step. Only parameters
        are averaged; buffers (e.g. BatchNorm running statistics) are untouched.
        """
        for param_o, param_t in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            param_t.mul_(m).add_(param_o * (1. - m))

    def alignment_loss(self, q, k):
        """Return ``2 * sum_i (1 - cos(q_i, k_i))`` over the rows of ``q`` and ``k``.

        For unit vectors this equals the summed squared L2 distance
        ``||q_i - k_i||^2``. It is summed (not averaged) over the local batch.
        """
        alignment_loss = 2.0 * (1.0 - self.criterion(q, k)).sum()
        return alignment_loss

    @staticmethod
    def _negative_mask(num_local, num_global, offset, device):
        """Return a ``(num_local, num_global)`` mask: 0 at ``(i, offset + i)``, 1 elsewhere.

        Row ``i`` of the local batch is positive only with global column
        ``offset + i``; every other column is a negative.
        """
        positives = torch.arange(num_local, dtype=torch.long, device=device) + offset
        return 1.0 - F.one_hot(positives, num_classes=num_global)

    def sparsity_loss(self, q, k):
        """Penalise similarity between each local prediction and all non-matching targets.

        ``k`` is gathered from every rank (without gradient). With ``N`` local
        rows, the positive for local row ``i`` is gathered column
        ``N * rank + i`` — this assumes every rank contributes the same batch
        size, which ``all_gather`` requires anyway. Each negative contributes
        ``sigmoid((cos - threshold) / temperature)``; the contributions are summed.
        """
        q = F.normalize(q, dim=1)
        k = F.normalize(k, dim=1)

        # Gather all targets: [per_gpu_batch_size, global_batch_size] similarities.
        k = concat_all_gather(k)
        cosine_distance = q @ k.t()
        num_local, num_global = cosine_distance.shape

        matrix = torch.sigmoid((cosine_distance - self.args.threshold) / self.args.temperature)
        # Build the mask on the similarity matrix's device rather than the
        # current CUDA device, so it never mismatches (and also works on CPU).
        negatives = self._negative_mask(
            num_local, num_global, num_local * dist.get_rank(), cosine_distance.device
        )
        # sigmoid output and mask are both non-negative, so no abs() is needed.
        sparsity_loss = (negatives * matrix).sum()
        return sparsity_loss

    def forward(self, x1, x2, m):
        """Compute the symmetrised SparseCL loss for two augmented views.

        Input:
            x1: first views of images
            x2: second views of images
            m: moco momentum (EMA coefficient for the target encoder)
        Output:
            ``(loss, loss_pack)`` where ``loss = alignment + alpha * sparsity``
            and ``loss_pack`` is a ``Pack`` holding the two unweighted terms.
        """
        # Online branch: encoder (with projector) followed by the predictor.
        q1 = self.predictor(self.online_encoder(x1))
        q2 = self.predictor(self.online_encoder(x2))

        with torch.no_grad():  # no gradient
            # The EMA step runs before the target features are computed.
            self._update_momentum_encoder(m)

            # Compute momentum features as targets.
            k1 = self.target_encoder(x1)
            k2 = self.target_encoder(x2)

        # Each view's predictions are matched against the other view's targets.
        sparsity_loss = 0.5 * (self.sparsity_loss(q1, k2) + self.sparsity_loss(q2, k1))
        alignment_loss = 0.5 * (self.alignment_loss(q1, k2) + self.alignment_loss(q2, k1))

        loss = alignment_loss + self.args.alpha * sparsity_loss
        loss_pack = Pack(alignment_loss=alignment_loss, sparsity_loss=sparsity_loss)
        return loss, loss_pack
