# Code adapted from SimCSE (https://github.com/princeton-nlp/SimCSE) governed by MIT license.

# Copyright (c) 2023, Salesforce, Inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

import torch
import torch.distributed as dist

class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all process, supporting backward propagation.
    https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/gather.py
    """
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        (input,) = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out


def dist_gather(x: torch.tensor):
    if not dist.is_initialized():  return x
    if len(x.shape) == 0:
        x = x.reshape(1)
    x_gather = GatherLayer.apply(x)
    x_gather = torch.cat(x_gather, dim=0)
    return x_gather


@torch.no_grad()
def dist_gather_nograd(x: torch.tensor):
    if not dist.is_initialized():  return x
    x_gather = [torch.ones_like(x) for _ in range(get_world_size())]
    dist.all_gather(x_gather, x, async_op=False)
    x_gather = torch.cat(x_gather, dim=0)
    return x_gather


def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_main():
    return get_rank() == 0


def get_world_size():
    if not dist.is_initialized():
        return 1
    else:
        return dist.get_world_size()

def barrier():
    if dist.is_initialized():
        dist.barrier()


@torch.no_grad()
def varsize_gather_nograd(x: torch.Tensor):
    """gather tensors of different sizes along the first dimension"""
    if not dist.is_initialized():
        return x

    # determine max size
    size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
    allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
    dist.all_gather(allsizes, size)
    max_size = max([size.cpu().max() for size in allsizes])

    padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device)
    padded[: x.shape[0]] = x
    output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())]
    dist.all_gather(output, padded)

    output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)]
    output = torch.cat(output, dim=0)

    return output