# Copyright (c) Meta Platforms, Inc. and affiliates.
import copy

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from models.Attention import MyAttention as Attention

def global_pool(x, global_pool='avg', num_prefix_tokens=0):
    """Collapse the token axis (dim 1) of ``x`` into one vector per sample.

    Args:
        x: tensor indexed as ``x[:, token, ...]``.
        global_pool: ``'avg'`` averages the tokens from ``num_prefix_tokens``
            onwards; any other value returns the first token ``x[:, 0]``.
        num_prefix_tokens: leading tokens excluded from the average (only
            consulted in ``'avg'`` mode).

    Returns:
        ``x`` with dim 1 reduced away.
    """
    if global_pool != 'avg':
        return x[:, 0]
    return x[:, num_prefix_tokens:].mean(dim=1)
class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722

    In this variant the backbone features are (optionally) concatenated with
    prompt tokens, passed through ``self.attention``, and reduced to their
    first token before the projection head (``fc_q`` / ``fc_k``).
    """

    def __init__(self, base_encoder, args, dim=128, K=65536, m=0.999, T=0.07, mlp=False, num_heads=8):
        """
        base_encoder: backbone module; deep-copied into ``encoder_q`` and
            ``encoder_k``, whose ``fc`` attribute is replaced by ``nn.Identity``.
        args: config object; reads ``use_attention``, ``feature_dim``, ``proj``
            and, when attention is disabled, ``prompt_num_tokens_g`` /
            ``prompt_num_tokens_p`` (both must then be 0).
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        mlp: with ``args.proj``, use Linear-ReLU-Linear heads instead of a
            single Linear layer
        num_heads: number of heads of the attention module (if enabled)
        """
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        self.encoder_q = copy.deepcopy(base_encoder)
        self.encoder_k = copy.deepcopy(base_encoder)
        if args.use_attention == 1:
            self.attention = Attention(args.feature_dim, num_heads=num_heads, qkv_bias=True)
        else:
            # Without attention, prompt tokens appended after the features
            # would simply be dropped by the first-token pooling.
            assert args.prompt_num_tokens_g == 0 and args.prompt_num_tokens_p == 0
            self.attention = nn.Identity()

        # Prompt tokens appended to the features in ``forward``. Nothing in this
        # class assigns it; presumably set by the caller.
        self.global_prompt = None

        # put fc layers in moco
        if args.proj:
            if mlp:  # hack: brute-force replacement
                dim_mlp = args.feature_dim
                self.fc_q = nn.Sequential(
                    nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), nn.Linear(dim_mlp, dim)
                )
                self.fc_k = nn.Sequential(
                    nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), nn.Linear(dim_mlp, dim)
                )
            else:
                self.fc_q = nn.Linear(args.feature_dim, dim)
                self.fc_k = nn.Linear(args.feature_dim, dim)
        else:
            self.fc_q = nn.Identity()
            self.fc_k = nn.Identity()

        # remove the fc layers in backbone. Backbone only extracts features
        self.encoder_q.fc = nn.Identity()
        self.encoder_k.fc = nn.Identity()

        # key branch starts as a frozen copy of the query branch
        self.reset_momentum_key_encoder()

        # create the queue: K random unit-norm columns of size ``dim``
        self.register_buffer("queue", nn.functional.normalize(torch.randn(dim, K), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    def _param_pairs(self):
        """Yield ``(query_param, key_param)`` pairs for the backbone and the head."""
        yield from zip(self.encoder_q.parameters(), self.encoder_k.parameters())
        yield from zip(self.fc_q.parameters(), self.fc_k.parameters())

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder and key head:
        ``param_k = m * param_k + (1 - m) * param_q``.
        """
        for param_q, param_k in self._param_pairs():
            # in place, so param_k keeps its storage instead of being rebound
            # to a newly allocated tensor on every step
            param_k.data.mul_(self.m).add_(param_q.data, alpha=1.0 - self.m)

    @torch.no_grad()
    def reset_momentum_key_encoder(self):
        """Copy the query encoder/head weights into the key branch and freeze it."""
        for param_q, param_k in self._param_pairs():
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Gather ``keys`` (N x dim) across processes and push them into the queue."""
        # gather keys before updating queue
        self._enqueue_keys(concat_all_gather(keys))

    @torch.no_grad()
    def _enqueue_keys(self, keys):
        """
        Overwrite the oldest queue columns with ``keys`` (N x dim), FIFO.

        Writes start at ``queue_ptr`` and wrap around the end of the queue, so
        ``K`` need not be a multiple of the batch size. If more than ``K`` keys
        are given, only the last ``K`` are kept.
        """
        if keys.shape[0] > self.K:
            keys = keys[-self.K:]
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)

        # replace the keys at ptr (dequeue and enqueue), splitting the write
        # into a tail part and a wrapped-around head part when needed
        head = min(batch_size, self.K - ptr)
        self.queue[:, ptr : ptr + head] = keys[:head].T
        if head < batch_size:
            self.queue[:, : batch_size - head] = keys[head:].T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        Returns the shuffled local slice and the index that undoes the shuffle.
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index (drawn on CPU, then moved next to the data)
        idx_shuffle = torch.randperm(batch_size_all).to(x.device)

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def incorporate_prompt(self, x, prompt):
        """
        Append prompt tokens after the feature tokens along dim 1.

        x: features; a 2-D (B, C) input is reshaped to a single token (B, 1, C).
        prompt: tokens expanded with ``expand(B, -1, -1)``, or None to skip.
        """
        B = x.shape[0]
        if x.dim() == 2:
            x = x.reshape(B, 1, -1)
        if prompt is not None:
            x = torch.cat((x, prompt.expand(B, -1, -1)), dim=1)
        return x

    def _pool_tokens(self, feat, prompt):
        """Append ``prompt``, run attention, keep the first token, flatten to (B, -1)."""
        feat = self.incorporate_prompt(feat, prompt)
        feat = global_pool(self.attention(feat), global_pool='token')
        return feat.reshape(feat.shape[0], -1)

    def get_feature(self, img_q, personalized_prompt=None, if_update_encoder=True):
        """
        Input:
            img_q: a batch of query images
            personalized_prompt: optional prompt tokens appended before attention
            if_update_encoder: if False, the backbone output is detached so no
                gradient reaches ``encoder_q``
        Output:
            features to classify: (B, -1) pooled features, before ``fc_q``
        """
        # compute query features
        q = self.encoder_q(img_q)
        if not if_update_encoder:
            q = q.detach()
        return self._pool_tokens(q, personalized_prompt)

    def forward(self, im_q, im_k, if_update_encoder=True):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
            if_update_encoder: if False, the query backbone output is detached
        Output:
            logits: N x (1 + K), positive logit in column 0, divided by T
            labels: N zeros (index of the positive)
            encoder_feature: raw ``encoder_q`` output (not detached)
            prompt_feature: L2-normalized projected query embedding
        """
        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        encoder_feature = q
        if not if_update_encoder:
            q = q.detach()
        q = self.fc_q(self._pool_tokens(q, self.global_prompt))
        q = nn.functional.normalize(q, dim=1)
        prompt_feature = q

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            k = self.encoder_k(im_k)  # keys: NxC
            k = self.fc_k(self._pool_tokens(k, self.global_prompt))
            k = nn.functional.normalize(k, dim=1)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators; allocated on the logits' device
        # rather than the current CUDA device, so CPU and non-default GPUs work
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels, encoder_feature, prompt_feature

@torch.no_grad()
def concat_all_gather(tensor):
    """
    Identity stand-in for the distributed all_gather helper.

    Returns ``tensor`` itself, unchanged; no cross-process communication
    happens. Consequently callers only ever see their local batch (e.g. the
    "num_gpus" computed from the gathered size is always 1), and because the
    same tensor object is returned, its autograd history is not cut.
    """
    return tensor
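# utils: distributed all_gather implementation, commented out; the identity stub above is what is currently defined.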
# @torch.no_grad()
# def concat_all_gather(tensor):
#     """
#     Performs all_gather operation on the provided tensors.
#     *** Warning ***: torch.distributed.all_gather has no gradient.
#     """
#     tensors_gather = [
#         torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
#     ]
#     torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
#
#     output = torch.cat(tensors_gather, dim=0)
#     return output
