# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Wraps the MoCo model to provide a consistent interface for loading and using the model.

Args:
    model_path (str): Path to the pre-trained MoCo model checkpoint.
    device (str): The device to load the model on ('cpu' or 'cuda').

Returns:
    nn.Module: The MoCo model instance.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class MoCoWrapper(nn.Module):
    """
    MoCo (Momentum Contrast) model: a query encoder, a key encoder, and a queue.

    The query encoder is trained by gradient descent. The key encoder has its
    gradients disabled and is updated as a moving average of the query encoder
    (momentum ``m``). Keys produced by the key encoder are stored in a circular
    queue of size ``K`` and used as negatives for the contrastive loss.

    The module works both inside an initialised ``torch.distributed`` process
    group and in a plain single-process run; in the latter case it behaves as
    a process group of size one (rank 0).

    Args:
        encoder_q (nn.Module): The query encoder module.
        encoder_k (nn.Module): The key encoder module.
        in_dim (int): Input dimension of the projection head appended to each encoder.
        fea_dim (int, optional): Output dimension of the projection head. Defaults to 128.
        K (int, optional): The size of the queue of negative samples. Defaults to 65536.
        m (float, optional): The momentum coefficient for updating the key encoder. Defaults to 0.999.
        T (float, optional): The softmax temperature for the contrastive loss. Defaults to 0.07.
        mlp (bool, optional): Whether to use a two-layer MLP head instead of a single linear layer. Defaults to False.
        symmetric (bool, optional): Whether to use a symmetric contrastive loss. Defaults to True.

    Calling the module (see ``forward``) returns the contrastive loss.
    """

    def __init__(self, encoder_q, encoder_k, in_dim, fea_dim=128, K=65536, m=0.999, T=0.07, mlp=False, symmetric=True):
        """
        Build both encoders (backbone + projection head), copy the query weights
        into the key encoder, freeze the key encoder, and create the queue.

        See the class docstring for the meaning of each argument.
        """
        super().__init__()

        self.symmetric = symmetric
        self.K = K
        self.m = m
        self.T = T

        # Each encoder is the supplied backbone followed by its own projection head.
        # The query head is built first, then the key head.
        self.encoder_q = nn.Sequential(
            encoder_q,
            self._build_head(in_dim, fea_dim, mlp),
        )
        self.encoder_k = nn.Sequential(
            encoder_k,
            self._build_head(in_dim, fea_dim, mlp),
        )

        # Cached parameter lists, reused on every momentum update.
        self.q_params = list(self.encoder_q.parameters())
        self.k_params = list(self.encoder_k.parameters())

        for param_q, param_k in zip(self.q_params, self.k_params):
            param_k.data.copy_(param_q.data)  # start from the query weights
            param_k.requires_grad = False  # never updated by gradient

        # Queue of keys, one column per key, each column L2-normalised.
        self.register_buffer("queue", torch.randn(fea_dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        # Write position of the circular queue.
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @staticmethod
    def _build_head(in_dim, fea_dim, mlp):
        """Return the projection head: a linear layer, or Linear-ReLU-Linear when ``mlp`` is true."""
        if not mlp:
            return nn.Linear(in_dim, fea_dim)
        return nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, fea_dim),
        )

    @staticmethod
    def _distributed_ready():
        """Return True when a ``torch.distributed`` process group is available and initialised."""
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder.

        Each key parameter becomes ``m * key + (1 - m) * query``. The update is
        done in place so no new parameter storage is allocated per step.
        """
        for param_q, param_k in zip(self.q_params, self.k_params):
            param_k.data.mul_(self.m).add_(param_q.data * (1. - self.m))

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Overwrite the oldest entries of the queue with a new batch of keys.

        The keys are first gathered from all processes, written at the current
        queue pointer, and the pointer is advanced modulo ``K``.

        Args:
            keys (torch.Tensor): A batch of keys, one key per row.

        Raises:
            AssertionError: If ``K`` is not a multiple of the gathered batch size.
        """
        keys = self.concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        # The write below never wraps around, so the batch must tile the queue exactly.
        assert self.K % batch_size == 0, (
            "queue size K must be a multiple of the gathered batch size"
        )

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        self.queue_ptr[0] = (ptr + batch_size) % self.K  # move pointer

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Shuffle the global batch across processes (for making use of BatchNorm).

        The input is gathered from all processes, a random permutation is drawn
        and (when distributed) broadcast from rank 0 so every process uses the
        same one, and this process receives its slice of the shuffled batch.
        Without a process group this is a plain local shuffle.

        Args:
            x (torch.Tensor): The local input batch.

        Returns:
            torch.Tensor: The shuffled batch assigned to this process.
            torch.Tensor: The inverse permutation, to be passed to
                ``_batch_unshuffle_ddp``.
        """
        batch_size_this = x.shape[0]
        x_gather = self.concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # Random permutation, placed on the data's device rather than a hard-coded one.
        idx_shuffle = torch.randperm(batch_size_all).to(x_gather.device)

        if self._distributed_ready():
            # every process must use the permutation drawn by rank 0
            torch.distributed.broadcast(idx_shuffle, src=0)
            gpu_idx = torch.distributed.get_rank()
        else:
            gpu_idx = 0

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this process
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo ``_batch_shuffle_ddp``.

        The shuffled data is gathered from all processes and this process picks
        the rows that belonged to its original (pre-shuffle) local batch.
        Without a process group this restores the local order.

        Args:
            x (torch.Tensor): The local batch computed from shuffled inputs.
            idx_unshuffle (torch.Tensor): The inverse permutation returned by
                ``_batch_shuffle_ddp``.

        Returns:
            torch.Tensor: The local batch in its original order.
        """
        # gather from all processes
        batch_size_this = x.shape[0]
        x_gather = self.concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this process
        gpu_idx = torch.distributed.get_rank() if self._distributed_ready() else 0
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def contrastive_loss(self, im_q, im_k):
        """
        Compute the contrastive loss of queries from ``im_q`` against keys from ``im_k``.

        Args:
            im_q (torch.Tensor): The query batch, fed to the query encoder.
            im_k (torch.Tensor): The key batch, fed to the key encoder.

        Returns:
            tuple: ``(loss, q, k)`` - the scalar loss, the L2-normalised query
            features, and the L2-normalised key features (no gradient).
        """
        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)  # L2-normalise each query

        # compute key features
        with torch.no_grad():  # no gradient to keys
            # shuffle for making use of BN
            im_k_, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k_)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)  # L2-normalise each key

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK (queue snapshot, so later in-place queue writes
        # do not interfere with autograd)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # The positive key sits in column 0 of every row; labels follow the
        # logits' device instead of assuming CUDA.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        loss = F.cross_entropy(logits, labels)

        return loss, q, k

    def forward(self, im_q, im_k, psedo_labels=None):
        """
        Forward pass: momentum-update the key encoder, compute the loss, update the queue.

        Args:
            im_q (torch.Tensor): A batch of query images.
            im_k (torch.Tensor): A batch of key images.
            psedo_labels (Optional[torch.Tensor]): Accepted for interface
                compatibility; not used.

        Returns:
            torch.Tensor: The computed loss (the sum of both directions when
            ``symmetric`` is true).
        """
        # update the key encoder (runs without gradient tracking)
        self._momentum_update_key_encoder()

        # compute loss
        if self.symmetric:  # symmetric loss: each input acts once as query, once as key
            loss_12, _, k2 = self.contrastive_loss(im_q, im_k)
            loss_21, _, k1 = self.contrastive_loss(im_k, im_q)
            loss = loss_12 + loss_21
            k = torch.cat([k1, k2], dim=0)
        else:  # asymmetric loss
            loss, _, k = self.contrastive_loss(im_q, im_k)

        self._dequeue_and_enqueue(k)

        return loss

    # utils
    @torch.no_grad()
    def concat_all_gather(self, tensor):
        """
        Gather ``tensor`` from all processes and concatenate along dim 0.

        This wraps ``torch.distributed.all_gather``; the result carries no
        gradient. When no process group is initialised the (detached) input is
        returned unchanged, which is what a gather over a single process yields.

        Args:
            tensor (torch.Tensor): The local tensor to be gathered.

        Returns:
            torch.Tensor: The tensors of all processes concatenated along dim 0.
        """
        if not self._distributed_ready():
            return tensor.detach()

        # receive buffers; their contents are overwritten by all_gather
        tensors_gather = [torch.empty_like(tensor)
                          for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

        return torch.cat(tensors_gather, dim=0)
