from typing import Callable

import torch
from torch import Tensor, distributed, normal, zeros_like
from torch.nn.functional import linear, normalize, one_hot
from torch.nn.parameter import Parameter


class MultiFilterArcLoss(torch.nn.Module):
    """Additive angular margin applied to multi-hot targets.

    The margin is added (in angle space) to every entry selected by the
    multi-hot label mask, after which the logits are multiplied by ``scale``.
    ``neg_filter`` is stored on the module but not read by ``forward``.
    """

    def __init__(self,
                 scale: float,
                 margin: float,
                 neg_filter: float = 0):
        super().__init__()
        self.neg_filter = neg_filter
        self.margin = margin
        self.scale = scale

    def forward(self,
                logit: Tensor,
                label4pfc_o: Tensor):
        """
        Parameters
        ----------
        logit: Tensor
            cosine logits, shape: (batch_size, num_class); modified IN PLACE
        label4pfc_o: Tensor
            multi-hot label mask, shape: (batch_size, num_class)
        Returns
        -------
        logit: Tensor
            the same tensor object that was passed in, shape:
            (batch_size, num_class)
        """
        # Nothing is done when scale == 1, and nothing (not even scaling)
        # is done when the margin is not strictly positive.
        if self.scale == 1 or not (self.margin > 0):
            return logit

        # The angle manipulation is hidden from autograd: only the final
        # scaling is recorded in the graph.
        with torch.no_grad():
            logit.arccos_()
            logit.add_(label4pfc_o * self.margin)
            logit.cos_()

        logit.mul_(self.scale)
        return logit

 
class PartialFC_V4(torch.nn.Module):
    """
    Multi-label Partial FC classification head for distributed training.

    Based on `"Partial FC: Training 10 Million Identities on a Single Machine"
    <https://arxiv.org/abs/2010.05222>`_. The ``num_class`` class centers are
    split into contiguous shards, one per rank. Every forward pass gathers
    the embeddings and labels of all ranks, optionally sub-samples the local
    class centers, and evaluates a distributed multi-label softmax loss.
    Each sample may carry several labels; labels that do not fall in the
    local shard (e.g. ``-1`` padding) are ignored on this rank.
    """
    _version = 4

    def __init__(
        self,
        margin_loss: Callable,
        embedding_size: int,
        num_class: int,
        neg_class_sample_rate: float = 1.0,
        pos_class_num: int = 1,
        is_normlize: int = 1,
    ):
        """
        Parameters
        ----------
        margin_loss: Callable
            called as ``margin_loss(logit, multi_hot_label)``; may be ``None``
            to skip the margin step
        embedding_size: int
            dimension of the input embeddings
        num_class: int
            total number of classes over all ranks
        neg_class_sample_rate: float
            fraction of the local class centers kept per step; values below 1
            enable sub-sampling
        pos_class_num: int
            stored on the module; not read by this class
        is_normlize: int
            truthy -> cosine logits (normalized embedding and weight)
        """
        super(PartialFC_V4, self).__init__()
        assert (
            distributed.is_initialized()
        ), "must initialize distributed before create this"
        self.rank = distributed.get_rank()
        self.world_size = distributed.get_world_size()
        self.embedding_size = embedding_size
        self.neg_class_sample_rate = neg_class_sample_rate
        self.pos_class_num = pos_class_num
        self.is_normlize = is_normlize
        # Shard the classes as evenly as possible: the first
        # (num_class % world_size) ranks own one extra class.
        self.num_local = num_class // self.world_size + int(
            self.rank < num_class % self.world_size)
        self.class_start = num_class // self.world_size * self.rank + min(
            self.rank, num_class % self.world_size)
        self.num_sample = int(self.neg_class_sample_rate * self.num_local)
        self.weight = Parameter(
            normal(0, 0.01, (self.num_local, embedding_size)))
        self.dist_multi_softmax = DistMultiSoftmax()
        self.margin_loss = margin_loss

    def forward(
        self,
        local_embedding: torch.Tensor,
        local_label: torch.Tensor,
    ):
        """
        Parameters
        ----------
        local_embedding: torch.Tensor
            embeddings of this rank, shape: (batch_size, embedding_size)
        local_label: torch.Tensor
            global class ids of this rank, one row per embedding; shape
            (batch_size,), (batch_size, 1) or (batch_size, num_labels).
            The tensor is not modified.
        Returns
        -------
        loss: torch.Tensor
            scalar loss produced by the distributed multi-label softmax
        """
        # One label row per local embedding. Reshaping by the batch size
        # (instead of an in-place squeeze) leaves the caller's tensor intact
        # and keeps a (1, num_labels) batch from collapsing to (num_labels, 1).
        local_label = local_label.reshape(
            local_embedding.size(0), -1).long().contiguous()

        # Gather all embedding and labels from different gpus
        cache_embedding = [
            zeros_like(local_embedding) for _ in range(self.world_size)
        ]
        embedding = torch.cat(AllGather(local_embedding, *cache_embedding))

        cache_label = [
            zeros_like(local_label) for _ in range(self.world_size)
        ]
        distributed.all_gather(cache_label, local_label)
        label = torch.cat(cache_label)

        # True where the label belongs to this rank's shard, i.e. lies in
        # [self.class_start, self.class_start + self.num_local)
        label_in_this_rank = (self.class_start <= label) & (
            label < self.class_start + self.num_local
        )

        label4pfc = label.clone()
        # label -1 means this label is not in this rank
        label4pfc[~label_in_this_rank] = -1
        # label in this rank is in [0, self.num_local) now
        label4pfc[label_in_this_rank] -= self.class_start

        label4pfc, weight_index = self.sample_class_center(
            label4pfc, label_in_this_rank)

        if self.neg_class_sample_rate < 1:
            weight = self.weight[weight_index]
        else:
            weight = self.weight

        if self.is_normlize:
            logit = linear(normalize(embedding), normalize(weight))
            logit = logit.clamp(-1, 1)
        else:
            logit = linear(embedding, weight)
            logit = torch.clip(logit, -64, 64)

        # Size the multi-hot mask from the logits: when more positives than
        # ``num_sample`` land on this rank, the sampled index (and therefore
        # the logit matrix) is wider than ``num_sample``.
        label4pfc_o = self.onehot4pfc(label4pfc, logit.size(1))

        if self.margin_loss is not None:
            logit = self.margin_loss(logit, label4pfc_o)

        loss = self.dist_multi_softmax(logit, label4pfc_o)
        return loss

    @torch.no_grad()
    def onehot4pfc(self, label: torch.Tensor, num_classes=None):
        """Turn per-sample label lists into a multi-hot mask.

        Parameters
        ----------
        label: torch.Tensor
            integer labels, shape: (batch_size, num_labels); ``-1`` marks
            "no label on this rank". The tensor is not modified.
        num_classes: int, optional
            width of the returned mask; defaults to ``self.num_sample``
        Returns
        -------
        torch.Tensor
            0/1 mask, shape: (batch_size, num_classes)
        """
        if num_classes is None:
            num_classes = self.num_sample
        # Shift by one so that -1 lands in column 0, one-hot encode, then
        # drop column 0 so ignored labels contribute nothing.
        shifted = label + 1
        multi_hot = one_hot(shifted, num_classes + 1)[:, :, 1:]
        # Merge the per-label one-hots of each sample; clamp so duplicated
        # labels still yield 1.
        return multi_hot.sum(dim=1).clamp(0, 1)

    @torch.no_grad()
    def sample_class_center(self, label4pfc, label_in_this_rank):
        """Pick the local class centers used in this step and remap labels.

        Parameters
        ----------
        label4pfc: torch.Tensor
            shard-local labels, ``-1`` where the label is not on this rank
        label_in_this_rank: torch.Tensor
            boolean mask of the same shape marking the shard-local labels
        Returns
        -------
        current: torch.Tensor
            labels of the input shape, rewritten as positions inside ``index``
            (entries outside the mask are left untouched)
        index: torch.Tensor
            sorted 1-D tensor of the selected local class ids
        """
        shape_label4pfc = label4pfc.size()
        label4pfc = label4pfc.flatten()
        label_in_this_rank = label_in_this_rank.flatten()

        positives = label4pfc[label_in_this_rank]
        unique = torch.unique(positives, sorted=True)

        if unique.size(0) <= self.num_sample:
            # The positives fit in the budget: fill the remaining slots with
            # random negatives. Positives get a score of 2.0, above any
            # uniform sample in [0, 1), so top-k always keeps them.
            perm = torch.rand(self.num_local, device=label4pfc.device)
            perm[unique] = 2.0
            index = torch.topk(perm, k=self.num_sample)[1]
            index = index.sort()[0]
        else:
            # More positive classes than the sampling budget: keep all of
            # them and sample no negatives.
            index = unique

        current = label4pfc.clone()
        # ``index`` is sorted and contains every positive, so searchsorted
        # returns the position of each positive inside ``index``.
        current[label_in_this_rank] = torch.searchsorted(index, positives)
        current = current.reshape(shape_label4pfc)
        return current, index


class DistMultiSoftmaxFunc(torch.autograd.Function):
    """Multi-label softmax-style loss over class-sharded logits.

    With ``P`` the positive and ``N`` the negative classes of a sample
    (summed over the shards of all ranks), the per-sample loss is
    ``log(1 + sum_{p in P} exp(-s_p)) + log(1 + sum_{n in N} exp(s_n))``
    and the returned value is its mean over the batch.
    """

    @staticmethod
    def forward(ctx, logit: torch.Tensor, label: torch.Tensor):
        """
        Parameters
        ----------
        logit: torch.Tensor
            local logits, shape: (batch_size, num_local_class)
        label: torch.Tensor
            0/1 multi-hot mask of the same shape
        Returns
        -------
        loss: torch.Tensor
            scalar loss
        """
        with torch.no_grad():
            mask_pos = label.float()
            mask_neg = torch.ones_like(mask_pos) - mask_pos
            # +1 on negatives, -1 on positives (i.e. 1 - 2 * label)
            sign = mask_neg - mask_pos
        exp_logit = torch.exp(logit * sign)
        logit_pos = exp_logit * mask_pos
        logit_neg = exp_logit * mask_neg
        sum_pos = logit_pos.sum(1)
        sum_neg = logit_neg.sum(1)
        # Each rank only holds a shard of the classes: add up the partial
        # sums of all ranks (in place).
        distributed.all_reduce(sum_pos)
        distributed.all_reduce(sum_neg)
        # log1p keeps precision when the sums are tiny.
        loss = torch.mean(torch.log1p(sum_pos) + torch.log1p(sum_neg))
        ctx.save_for_backward(logit_pos, logit_neg, sum_pos, sum_neg)
        return loss

    @staticmethod
    def backward(ctx, loss_gradient):
        """Return the gradient w.r.t. ``logit`` (and ``None`` for ``label``)."""
        logit_pos, logit_neg, sum_pos, sum_neg = ctx.saved_tensors
        batch_size = logit_pos.size(0)
        # d/ds_p log(1 + sum_pos) = -exp(-s_p) / (1 + sum_pos)
        grad_pos = logit_pos / (1 + sum_pos).unsqueeze(1)
        # d/ds_n log(1 + sum_neg) =  exp(s_n) / (1 + sum_neg)
        grad_neg = logit_neg / (1 + sum_neg).unsqueeze(1)
        logit_gradient = (grad_neg - grad_pos) / batch_size
        # Multiply by the 0-dim gradient tensor directly; calling ``.item()``
        # would force a device-to-host synchronization on every step.
        return logit_gradient * loss_gradient, None


class DistMultiSoftmax(torch.nn.Module):
    """``nn.Module`` front end for ``DistMultiSoftmaxFunc``."""

    def __init__(self):
        super().__init__()

    def forward(self, logit, label):
        """Apply ``DistMultiSoftmaxFunc`` to ``logit`` and ``label`` and return its result."""
        loss = DistMultiSoftmaxFunc.apply(logit, label)
        return loss


class AllGatherFunc(torch.autograd.Function):
    """All-gather whose backward pass routes gradients back to their source rank."""

    @staticmethod
    def forward(ctx, tensor, *gather_list):
        # ``gather_list`` holds one pre-allocated buffer per rank; they are
        # filled in place and handed back as a tuple.
        buffers = list(gather_list)
        distributed.all_gather(buffers, tensor)
        return tuple(buffers)

    @staticmethod
    def backward(ctx, *grads):
        grad_list = list(grads)
        own_rank = distributed.get_rank()
        own_grad = grad_list[own_rank]

        # Slot ``dst`` holds the gradient of the tensor contributed by rank
        # ``dst``: sum-reduce every slot onto its owner. All reductions are
        # launched asynchronously, then awaited together.
        pending = []
        for dst in range(distributed.get_world_size()):
            work = distributed.reduce(
                grad_list[dst], dst, distributed.ReduceOp.SUM, async_op=True
            )
            pending.append(work)
        for work in pending:
            work.wait()

        own_grad *= len(grad_list)  # cooperate with distributed loss function
        # Only ``tensor`` receives a gradient; the gather buffers get None.
        return (own_grad, *([None] * len(grad_list)))


# Functional alias: ``AllGather(tensor, *buffers)`` runs ``AllGatherFunc``
# through autograd's ``apply``.
AllGather = AllGatherFunc.apply
