# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

import graph.loss as graph_loss


class MoCo(nn.Module):
    """
    Build a MoCo model with a base encoder, a momentum encoder, and two MLPs
    https://arxiv.org/abs/1911.05722

    This base class only creates the two encoders and keeps the momentum
    encoder in sync with the base encoder. Subclasses are expected to override
    `_build_projector_and_predictor_mlps` to attach the projection heads and to
    define `self.predictor`, which `forward` relies on.
    """
    def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0):
        """
        base_encoder: callable invoked as `base_encoder(num_classes=mlp_dim)`;
            it is called twice to create the base and the momentum encoder
        dim: feature dimension (default: 256)
        mlp_dim: hidden dimension in MLPs (default: 4096)
        T: softmax temperature (default: 1.0)
        """
        super().__init__()

        self.T = T

        # build encoders
        self.base_encoder = base_encoder(num_classes=mlp_dim)
        self.momentum_encoder = base_encoder(num_classes=mlp_dim)

        # heads are built before the copy below so that they are synchronized too
        self._build_projector_and_predictor_mlps(dim, mlp_dim)

        for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()):
            param_m.data.copy_(param_b.data)  # initialize
            param_m.requires_grad = False  # not update by gradient

    def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
        """
        Build an MLP of `num_layers` bias-free linear layers.

        Every layer except the last is followed by BatchNorm1d and ReLU. The
        last layer is followed by a non-affine BatchNorm1d when `last_bn` is
        True, and by nothing otherwise.

        num_layers: number of linear layers
        input_dim: input width of the first linear layer
        mlp_dim: width of all hidden layers
        output_dim: output width of the last linear layer
        last_bn: whether to append BatchNorm1d(affine=False) after the last layer
        Returns: nn.Sequential holding the layers in order
        """
        mlp = []
        last_idx = num_layers - 1
        for layer_idx in range(num_layers):
            dim1 = input_dim if layer_idx == 0 else mlp_dim
            dim2 = output_dim if layer_idx == last_idx else mlp_dim

            mlp.append(nn.Linear(dim1, dim2, bias=False))

            if layer_idx < last_idx:
                mlp.append(nn.BatchNorm1d(dim2))
                mlp.append(nn.ReLU(inplace=True))
            elif last_bn:
                # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
                # for simplicity, we further removed gamma in BN
                mlp.append(nn.BatchNorm1d(dim2, affine=False))

        return nn.Sequential(*mlp)

    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        """Hook for subclasses; the base implementation builds nothing."""
        pass

    @torch.no_grad()
    def _update_momentum_encoder(self, m):
        """Momentum update of the momentum encoder

        Each momentum parameter becomes `m * momentum + (1 - m) * base`.
        Only parameters are updated; buffers are not touched here.
        """
        for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()):
            param_m.data = param_m.data * m + param_b.data * (1. - m)

    def contrastive_loss(self, q, k):
        """
        Cross-entropy contrastive loss between queries `q` and targets `k`.

        Both inputs are L2-normalized along dim 1. When torch.distributed is
        initialized, the targets are gathered from all processes and the
        positive for local row i is row `i + N * rank` of the gathered
        targets. Without an initialized process group the local targets are
        used as-is (rank 0). The targets never receive gradient.

        Returns: scalar loss, scaled by `2 * T`
        """
        # normalize
        q = nn.functional.normalize(q, dim=1)
        k = nn.functional.normalize(k, dim=1)

        if torch.distributed.is_available() and torch.distributed.is_initialized():
            # gather all targets
            k = concat_all_gather(k)
            rank = torch.distributed.get_rank()
        else:
            # single-process fallback: keep the "targets carry no gradient"
            # behavior of the gathered path
            k = k.detach()
            rank = 0

        # Einstein sum is more intuitive
        logits = torch.einsum('nc,mc->nm', [q, k]) / self.T
        N = logits.shape[0]  # batch size per GPU
        # create the labels directly on the device of the logits instead of
        # assuming the current CUDA device
        labels = torch.arange(N, dtype=torch.long, device=logits.device) + N * rank
        return nn.functional.cross_entropy(logits, labels) * (2 * self.T)

    def forward(self, x1, x2, y1, y2, m):
        """
        Input:
            x1: first views of images
            x2: second views of images
            y1: extra input passed as second positional argument to the
                encoders together with x1 (meaning is defined by the encoder)
            y2: same as y1, paired with x2
            m: moco momentum
        Output:
            loss

        Requires `self.predictor`, which is created by subclasses.
        """

        # compute features
        q1 = self.predictor(self.base_encoder(x1, y1))
        q2 = self.predictor(self.base_encoder(x2, y2))

        with torch.no_grad():  # no gradient
            self._update_momentum_encoder(m)  # update the momentum encoder

            # compute momentum features as targets
            k1 = self.momentum_encoder(x1, y1)
            k2 = self.momentum_encoder(x2, y2)

        # symmetrized loss: each view's query is matched against the other view's target
        return self.contrastive_loss(q1, k2) + self.contrastive_loss(q2, k1)


class MoCo_ViT(MoCo):
    """MoCo variant for encoders that expose a single linear `head` layer.

    The original `head` of both encoders is replaced by a 3-layer projector,
    and a 2-layer predictor is attached to the model itself.
    """

    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        """Swap the encoders' `head` for projectors and create `self.predictor`.

        dim: output width of projector and predictor
        mlp_dim: hidden width of projector and predictor
        """
        # input width of the fc layer that is about to be discarded
        encoder_width = self.base_encoder.head.weight.shape[1]

        both_encoders = (self.base_encoder, self.momentum_encoder)

        # remove original fc layer
        for encoder in both_encoders:
            delattr(encoder, 'head')

        # projectors (base encoder first, then momentum encoder)
        for encoder in both_encoders:
            encoder.head = self._build_mlp(3, encoder_width, mlp_dim, dim)

        # predictor
        self.predictor = self._build_mlp(2, dim, mlp_dim, dim)


class MoCo_HrchViT(MoCo_ViT):
    """Define losses for Hierarchical ViT.

    Combines the MoCo contrastive loss with DMon clustering losses computed
    at four levels of the encoder.
    """

    def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0):
        """
        base_encoder: callable invoked as `base_encoder(num_classes=mlp_dim)`
        dim: feature dimension (default: 256)
        mlp_dim: hidden dimension in MLPs (default: 4096)
        T: softmax temperature (default: 1.0)
        """
        super().__init__(base_encoder, dim, mlp_dim, T)
        # one separate, identically configured DMon loss object per level
        for level in (1, 2, 3, 4):
            setattr(
                self,
                f'dmon_loss{level}',
                graph_loss.DMonLoss(adj_knn=2, remove_self_loop=True, binarize=True),
            )

    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        """Replace `head1`..`head4` of both encoders with 2-layer projectors
        and create the 2-layer `self.predictor`.

        dim: output width of projectors and predictor
        mlp_dim: hidden width of projectors and predictor
        """
        head_names = ['head{:d}'.format(level) for level in (1, 2, 3, 4)]

        # record the input width of every original fc layer before removing any
        head_widths = {
            name: getattr(self.base_encoder, name).weight.shape[1]
            for name in head_names
        }

        # remove original fc layer
        for name in head_names:
            delattr(self.base_encoder, name)
            delattr(self.momentum_encoder, name)

        # projectors, built level by level: base encoder first, then momentum encoder
        for name, width in head_widths.items():
            setattr(self.base_encoder, name, self._build_mlp(2, width, mlp_dim, dim))
            setattr(self.momentum_encoder, name, self._build_mlp(2, width, mlp_dim, dim))

        # predictor
        self.predictor = self._build_mlp(2, dim, mlp_dim, dim)

    def forward(self, x1, x2, y1, y2, m):
        """
        Input:
            x1: first views of images
            x2: second views of images
            y1: extra input passed to the encoders together with x1
            y2: extra input passed to the encoders together with x2
            m: moco momentum
        Output:
            loss = contrastive_losses + clustering_losses

        The encoders are called with `return_intermediate=True` and must
        return a pair whose second item is a dict providing, for each level
        l in 1..4, the keys `out{l}`, `block{l}`, `logit{l}` and
        `padding_mask{l}`.
        """

        # compute features; only the level-4 output goes through the predictor
        _, online1 = self.base_encoder(x1, y1, return_intermediate=True)
        online1['out4'] = self.predictor(online1['out4'])
        _, online2 = self.base_encoder(x2, y2, return_intermediate=True)
        online2['out4'] = self.predictor(online2['out4'])

        with torch.no_grad():  # no gradient
            self._update_momentum_encoder(m)  # update the momentum encoder

            # compute momentum features as targets
            _, target1 = self.momentum_encoder(x1, y1, return_intermediate=True)
            _, target2 = self.momentum_encoder(x2, y2, return_intermediate=True)

        levels = (1, 2, 3, 4)
        contrastive_weights = (0.0, 0.0, 0.0, 1.0)
        clustering_weights = (0.4, 0.4, 0.4, 0.4)
        dmon_loss_fns = (self.dmon_loss1, self.dmon_loss2, self.dmon_loss3, self.dmon_loss4)

        # Compute contrastive losses.
        # NOTE(review): levels with weight 0.0 are still evaluated; presumably
        # this keeps their heads in the autograd graph -- confirm before
        # skipping them.
        contrastive_terms = []
        for level, weight in zip(levels, contrastive_weights):
            key = f'out{level}'
            view1_vs_view2 = self.contrastive_loss(online1[key], target2[key]) * weight
            view2_vs_view1 = self.contrastive_loss(online2[key], target1[key]) * weight
            contrastive_terms.append(view1_vs_view2 + view2_vs_view1)

        # Compute clustering losses on the base-encoder outputs of both views.
        clustering_terms = []
        for level, dmon_loss_fn, weight in zip(levels, dmon_loss_fns, clustering_weights):
            for view in (online1, online2):
                dmon_term, reg_term = dmon_loss_fn(
                    view[f'logit{level}'],
                    view[f'block{level}'],
                    view[f'padding_mask{level}'])
                clustering_terms.append(dmon_term * weight)
                clustering_terms.append(reg_term * weight)

        return sum(contrastive_terms) + sum(clustering_terms)

# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.

    tensor: local tensor; all_gather is called with same-shaped receive
        buffers, one per process
    Returns: the per-process tensors concatenated along dim 0, in rank order.
        When torch.distributed is unavailable or no process group has been
        initialized, a detached copy of `tensor` is returned instead, which
        equals the result for a world size of one.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        # single-process run: nothing to gather, but still hand back a new
        # tensor without gradient, like the gathered path does
        return tensor.detach().clone()

    world_size = torch.distributed.get_world_size()
    # receive buffers; their contents are overwritten by all_gather
    tensors_gather = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    return torch.cat(tensors_gather, dim=0)

