import torch
import torch.nn as nn
from torch.nn import functional as F

import torch.distributed.nn
from torch import distributed as dist

# Soft-target cross-entropy helper (from CLIP).
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits and a dense target distribution.

    Unlike ``F.cross_entropy`` with class indices, ``targets`` is multiplied
    element-wise with the log-probabilities, so it may hold soft labels.

    Args:
        preds: Logits; the log-softmax is taken over the last dimension.
        targets: Target weights, broadcastable against ``preds``.
        reduction: ``'none'`` returns the per-row loss (sum over dim 1),
            ``'mean'`` returns its average and ``'sum'`` its total.

    Returns:
        The per-row loss tensor for ``'none'``, otherwise a scalar tensor.

    Raises:
        ValueError: If ``reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``.
    """
    # Reject unknown modes explicitly; falling through would return None and
    # only fail later, far away from the actual mistake.
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            f"Unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'."
        )

    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss

class ContrasLoss(nn.Module):
    """Symmetric contrastive loss computed on a single (local) batch.

    Row ``i`` of ``c_features`` is treated as the positive match of row ``i``
    of ``h_features``; every other row acts as a negative.
    """

    def __init__(self):
        super().__init__()

    def forward(self, c_features, h_features, logit_scale):
        """Return the mean of the c->h and h->c cross-entropy losses.

        Args:
            c_features: 2-D feature matrix, one row per sample.
            h_features: 2-D feature matrix with the same number of rows.
            logit_scale: Multiplier applied to the features before the
                similarity matmul.

        Returns:
            Scalar loss tensor (each direction uses the default 'mean'
            reduction of ``F.cross_entropy``).
        """
        c_to_h = logit_scale * c_features @ h_features.T
        h_to_c = logit_scale * h_features @ c_features.T

        # The matching pair sits on the diagonal, so the target class of
        # row i is simply i.
        targets = torch.arange(
            c_to_h.shape[0], device=c_features.device, dtype=torch.long
        )

        loss_c = F.cross_entropy(c_to_h, targets)
        loss_h = F.cross_entropy(h_to_c, targets)
        return (loss_c + loss_h) / 2

class ClipLoss(nn.Module):  # change from open-clip
    """Symmetric contrastive loss with negatives gathered from every rank.

    Local features are compared against the features collected from all
    processes, and the loss is evaluated on each rank for its local rows only.
    """

    def __init__(self, rank=0, world_size=1):
        super().__init__()
        self.rank = rank
        self.world_size = world_size

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return the target column index for each of the local rows.

        Local row ``i`` gets target ``i + num_logits * rank``.
        NOTE(review): this presumes the gathered features are concatenated in
        rank order with ``num_logits`` rows per rank -- confirm against
        ``gather_features``.
        """
        offset = num_logits * self.rank
        return torch.arange(num_logits, device=device, dtype=torch.long) + offset

    def get_logits(self, c_features, h_features, logit_scale):
        """Compute local-vs-global similarity logits in both directions."""
        gathered_c, gathered_h = gather_features(c_features, h_features)

        # Only the local rows are scored against the global set (local loss).
        c_vs_all_h = logit_scale * c_features @ gathered_h.T
        h_vs_all_c = logit_scale * h_features @ gathered_c.T
        return c_vs_all_h, h_vs_all_c

    def forward(self, c_features, h_features, logit_scale, output_dict=False):
        """Return the scalar loss for this rank.

        ``output_dict`` is accepted for signature compatibility but is not
        used; a tensor is always returned.
        """
        assert self.world_size > 1, 'The world_size is less than or equal to 1.'

        per_c, per_h = self.get_logits(c_features, h_features, logit_scale)

        # Number of rows equals the local batch size.
        targets = self.get_ground_truth(c_features.device, per_c.shape[0])

        # Each term uses the default 'mean' reduction; average both directions.
        loss_c = F.cross_entropy(per_c, targets)
        loss_h = F.cross_entropy(per_h, targets)
        return (loss_c + loss_h) / 2


def gather_features(
        c_features,
        h_features,
):
    """Collect both feature tensors from every process and stack them on dim 0.

    Uses ``torch.distributed.nn.all_gather``, so a process group must already
    be initialised. The ``c_features`` collective is issued before the
    ``h_features`` one.

    Returns:
        Tuple ``(all_c_features, all_h_features)``.
    """
    def _gather_and_concat(local_features):
        pieces = torch.distributed.nn.all_gather(local_features)
        return torch.cat(pieces, dim=0)

    # Tuple elements are evaluated left to right, preserving collective order.
    return _gather_and_concat(c_features), _gather_and_concat(h_features)


class SigLipLoss(nn.Module):
    """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343

    @article{zhai2023sigmoid,
      title={Sigmoid loss for language image pre-training},
      author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
      journal={arXiv preprint arXiv:2303.15343},
      year={2023}
    }

    Variant with block-diagonal positives: consecutive groups of
    ``num_img_views`` image rows are positives for the matching group of
    ``num_text_views`` text rows; every other pair is a negative.

    Args:
        cache_labels: Stored on the instance; not used by the loss.
        rank: Stored on the instance; not used by the loss.
        world_size: Stored on the instance; not used by the loss.
        bidir: Stored on the instance; not used by the loss.
        weight: Multiplier applied to the loss of positive pairs.
        num_img_views: Image rows per group (default 8).
        num_text_views: Text rows per group (default 4).
    """
    def __init__(
            self,
            cache_labels=False,
            rank=0,
            world_size=1,
            bidir=True,
            weight=1,
            num_img_views=8,
            num_text_views=4,
    ):
        super().__init__()
        if num_img_views < 1 or num_text_views < 1:
            raise ValueError('num_img_views and num_text_views must be positive.')
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.bidir = bidir
        self.weight = weight
        self.num_img_views = num_img_views
        self.num_text_views = num_text_views

        # cache state FIXME cache not currently used, worthwhile?
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, dtype, num_img, num_text, negative_only=False) -> torch.Tensor:
        """Build the ``(num_img, num_text)`` label matrix of +1 / -1 entries.

        Args:
            device: Device on which the labels are created.
            dtype: Dtype of the returned labels.
            num_img: Number of image rows.
            num_text: Number of text rows.
            negative_only: If True, every entry is -1 and no shape check runs.

        Returns:
            Tensor with +1 on the block diagonal and -1 elsewhere.

        Raises:
            AssertionError: If the row counts are not the same whole number
                of image/text view groups.
        """
        if negative_only:
            return -torch.ones((num_img, num_text), device=device, dtype=dtype)

        img_views, text_views = self.num_img_views, self.num_text_views
        groups = num_img // img_views
        consistent = (
            num_img % img_views == 0
            and num_text % text_views == 0
            and groups == num_text // text_views
        )
        if not consistent:
            # Raised explicitly (not via `assert`) so the check survives -O
            # while keeping the exception type callers already see.
            raise AssertionError('The view is not enought!')

        # Create the mask directly on the target device and cast once to the
        # requested dtype; this avoids a CPU float32 intermediate that would
        # promote half-precision labels and force a host-to-device copy.
        group_ids = torch.arange(groups, device=device)
        img_group = group_ids.repeat_interleave(img_views)
        text_group = group_ids.repeat_interleave(text_views)
        positive = img_group.unsqueeze(1) == text_group.unsqueeze(0)
        return (2 * positive.to(torch.int8) - 1).to(dtype)

    def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
        """Return ``logit_scale * image_features @ text_features.T`` plus optional bias."""
        logits = logit_scale * image_features @ text_features.T
        if logit_bias is not None:
            logits += logit_bias
        return logits

    def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
        """Mean pairwise sigmoid loss, with positive pairs scaled by ``self.weight``."""
        logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
        labels = self.get_ground_truth(
            image_features.device,
            image_features.dtype,
            image_features.shape[0],
            text_features.shape[0],
            negative_only=negative_only,
        )
        loss = -F.logsigmoid(labels * logits)
        if self.weight != 1:
            # Out-of-place reweighting of the positive entries only.
            loss = torch.where(labels == 1, loss * self.weight, loss)
        return loss.mean()

    def forward(self, image_features, text_features, logit_scale, logit_bias=None, output_dict=False):
        """Compute the loss on the local batch.

        Returns:
            The scalar loss, or ``{"contrastive_loss": loss}`` when
            ``output_dict`` is True.
        """
        # NOTE: only pairs inside the local batch contribute. The ring exchange
        # of text features across ranks (neighbour_exchange*_with_grad in this
        # module) is not performed, so rank / world_size / bidir are unused.
        loss = self._loss(image_features, text_features, logit_scale, logit_bias)
        return {"contrastive_loss": loss} if output_dict else loss

# ============== Point-to-point neighbour-exchange helpers and their autograd wrappers ==============
def neighbour_exchange(from_rank, to_rank, tensor, group=None):
    """Send ``tensor`` to ``to_rank`` while receiving a same-shaped tensor from ``from_rank``.

    Both operations are submitted together through ``batch_isend_irecv`` and
    waited on before returning.

    Returns:
        The tensor received from ``from_rank``.
    """
    td = torch.distributed
    received = torch.zeros_like(tensor)

    # Order matters for the batched call: send first, then receive.
    pending = td.batch_isend_irecv([
        td.P2POp(td.isend, tensor, to_rank, group=group),
        td.P2POp(td.irecv, received, from_rank, group=group),
    ])
    for work in pending:
        work.wait()
    return received


def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
    """Exchange tensors with both neighbours in one batched P2P call.

    ``tensor_to_left`` goes to ``left_rank`` and ``tensor_to_right`` to
    ``right_rank``; one tensor is received from each side. The receive
    buffer for the left neighbour is shaped like ``tensor_to_right`` and
    the one for the right neighbour like ``tensor_to_left``.

    Returns:
        Tuple ``(tensor_from_right, tensor_from_left)``.
    """
    td = torch.distributed
    from_left = torch.zeros_like(tensor_to_right)
    from_right = torch.zeros_like(tensor_to_left)

    def _op(kind, buffer, peer):
        return td.P2POp(kind, buffer, peer, group=group)

    # Submission order: send right, send left, receive right, receive left.
    ops = [
        _op(td.isend, tensor_to_right, right_rank),
        _op(td.isend, tensor_to_left, left_rank),
        _op(td.irecv, from_right, right_rank),
        _op(td.irecv, from_left, left_rank),
    ]
    for work in td.batch_isend_irecv(ops):
        work.wait()
    return from_right, from_left


class NeighbourExchange(torch.autograd.Function):
    """Autograd wrapper around ``neighbour_exchange``.

    The backward pass performs the same exchange with the two ranks swapped,
    so the incoming gradient is sent to the rank the forward tensor was
    received from.
    """

    @staticmethod
    def forward(ctx, from_rank, to_rank, group, tensor):
        # Remember the routing so backward can reverse it.
        ctx.group = group
        ctx.from_rank = from_rank
        ctx.to_rank = to_rank
        return neighbour_exchange(from_rank, to_rank, tensor, group=group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_tensor = NeighbourExchange.apply(
            ctx.to_rank, ctx.from_rank, ctx.group, grad_output
        )
        # from_rank, to_rank and group take no gradient.
        return None, None, None, grad_tensor


def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
    """Differentiable one-directional exchange; see ``NeighbourExchange``.

    Returns:
        The tensor received from ``from_rank``.
    """
    # The autograd Function takes ``group`` before the tensor argument.
    received = NeighbourExchange.apply(from_rank, to_rank, group, tensor)
    return received


class NeighbourExchangeBidir(torch.autograd.Function):
    """Autograd wrapper around ``neighbour_exchange_bidir``.

    The backward pass runs the same bidirectional exchange with the left and
    right ranks swapped, applied to the incoming gradients.
    """

    @staticmethod
    def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right):
        # Remember the routing so backward can mirror it.
        ctx.group = group
        ctx.left_rank = left_rank
        ctx.right_rank = right_rank
        return neighbour_exchange_bidir(
            left_rank, right_rank, tensor_to_left, tensor_to_right, group=group
        )

    @staticmethod
    def backward(ctx, *grad_outputs):
        tensor_grads = NeighbourExchangeBidir.apply(
            ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs
        )
        # left_rank, right_rank and group take no gradient.
        return (None, None, None, *tensor_grads)


def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
    """Differentiable bidirectional exchange; see ``NeighbourExchangeBidir``.

    Returns:
        Whatever ``NeighbourExchangeBidir.apply`` yields, i.e. the pair
        ``(tensor_from_right, tensor_from_left)``.
    """
    # The autograd Function takes ``group`` before the tensor arguments.
    received = NeighbourExchangeBidir.apply(
        left_rank, right_rank, group, tensor_to_left, tensor_to_right
    )
    return received



class SigLipLossTest(nn.Module):
    """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343

    @article{zhai2023sigmoid,
      title={Sigmoid loss for language image pre-training},
      author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
      journal={arXiv preprint arXiv:2303.15343},
      year={2023}
    }

    Unweighted variant with block-diagonal positives: consecutive groups of
    ``num_img_views`` image rows are positives for the matching group of
    ``num_text_views`` text rows; every other pair is a negative.

    Args:
        cache_labels: Stored on the instance; not used by the loss.
        rank: Stored on the instance; not used by the loss.
        world_size: Stored on the instance; not used by the loss.
        bidir: Stored on the instance; not used by the loss.
        num_img_views: Image rows per group (default 8).
        num_text_views: Text rows per group (default 4).
    """
    def __init__(
            self,
            cache_labels=False,
            rank=0,
            world_size=1,
            bidir=True,
            num_img_views=8,
            num_text_views=4,
    ):
        super().__init__()
        if num_img_views < 1 or num_text_views < 1:
            raise ValueError('num_img_views and num_text_views must be positive.')
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.bidir = bidir
        self.num_img_views = num_img_views
        self.num_text_views = num_text_views

        # cache state FIXME cache not currently used, worthwhile?
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, dtype, num_img, num_text, negative_only=False) -> torch.Tensor:
        """Build the ``(num_img, num_text)`` label matrix of +1 / -1 entries.

        Args:
            device: Device on which the labels are created.
            dtype: Dtype of the returned labels.
            num_img: Number of image rows.
            num_text: Number of text rows.
            negative_only: If True, every entry is -1 and no shape check runs.

        Returns:
            Tensor with +1 on the block diagonal and -1 elsewhere.

        Raises:
            AssertionError: If the row counts are not the same whole number
                of image/text view groups.
        """
        if negative_only:
            return -torch.ones((num_img, num_text), device=device, dtype=dtype)

        img_views, text_views = self.num_img_views, self.num_text_views
        groups = num_img // img_views
        consistent = (
            num_img % img_views == 0
            and num_text % text_views == 0
            and groups == num_text // text_views
        )
        if not consistent:
            # Raised explicitly (not via `assert`) so the check survives -O
            # while keeping the exception type callers already see.
            raise AssertionError('The view is not enought!')

        # Create the mask directly on the target device and cast once to the
        # requested dtype; this avoids a CPU float32 intermediate that would
        # promote half-precision labels and force a host-to-device copy.
        group_ids = torch.arange(groups, device=device)
        img_group = group_ids.repeat_interleave(img_views)
        text_group = group_ids.repeat_interleave(text_views)
        positive = img_group.unsqueeze(1) == text_group.unsqueeze(0)
        return (2 * positive.to(torch.int8) - 1).to(dtype)

    def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
        """Return ``logit_scale * image_features @ text_features.T`` plus optional bias."""
        logits = logit_scale * image_features @ text_features.T
        if logit_bias is not None:
            logits += logit_bias
        return logits

    def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
        """Mean pairwise sigmoid loss over all image/text pairs."""
        logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
        labels = self.get_ground_truth(
            image_features.device,
            image_features.dtype,
            image_features.shape[0],
            text_features.shape[0],
            negative_only=negative_only,
        )
        return -F.logsigmoid(labels * logits).mean()

    def forward(self, image_features, text_features, logit_scale, logit_bias=None, output_dict=False):
        """Compute the loss on the local batch.

        With the default view counts the label check requires
        ``image_features`` to have ``groups * 8`` rows and ``text_features``
        ``groups * 4`` rows.

        Returns:
            The scalar loss, or ``{"contrastive_loss": loss}`` when
            ``output_dict`` is True.
        """
        # NOTE: only pairs inside the local batch contribute. The ring exchange
        # of text features across ranks (neighbour_exchange*_with_grad in this
        # module) is not performed, so rank / world_size / bidir are unused.
        loss = self._loss(image_features, text_features, logit_scale, logit_bias)
        return {"contrastive_loss": loss} if output_dict else loss