# modified from https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/bert/sequene_parallel


import torch
from torch.distributed.distributed_c10d import _get_global_rank

from MQSP_evaluation.distributed_utils import DistGroups

# Module-level selection of the sequence-parallel ("sp") method; defaults to
# "single". Update it through set_sp_method() and read it through
# get_sp_method(): a `from ... import sp_method` binding captures the value at
# import time and will not see later rebinding by set_sp_method().
sp_method = "single"


def set_sp_method(m):
    """Select the sequence-parallel method used by this module.

    Rebinds the module-level ``sp_method`` so that later calls to
    ``get_sp_method()`` return the new value.

    Args:
        m: value to store as the current method.
    """
    # `global` makes the assignment rebind the module attribute instead of
    # creating a function-local variable.
    global sp_method
    sp_method = m


def get_sp_method():
    """Return the currently selected sequence-parallel method.

    This is the value most recently passed to ``set_sp_method()``, or the
    module default if it has never been called.
    """
    # Reading a module global needs no `global` declaration; it is only
    # required when the function rebinds the name.
    return sp_method


# Number of shape slots reserved per key when shapes are broadcast. A tensor
# must have strictly fewer dims than this, so that at least one 0 terminator
# remains in its slots for the unpacking loop.
_MAX_DATA_DIM = 5

# Shapes decoded by _build_key_size_numel_dictionaries:
#   "key_size":    key -> shape,
#   "key_numel":   key -> number of elements,
#   "total_numel": sum of key_numel over the keys of the last call.
# Entries for the passed keys are rewritten on every call (the broadcast is
# never skipped). The original note "data always same" suggests batch shapes
# were expected to be identical across calls -- TODO confirm.
SHAPE_CACHE = {"key_size": {}, "key_numel": {}, "total_numel": 0}


def _build_key_size_numel_dictionaries(keys, data):
    """Build per-key shapes on rank 0 of the ``sp`` group and broadcast them.

    Rank 0 of ``DistGroups["sp"]`` packs the shape of every ``data[key]`` into
    a flat list with ``_MAX_DATA_DIM`` slots per key and broadcasts it. Every
    rank then decodes the same shapes from that list.

    Each dimension is stored as ``size + 1``, so a packed value of 0 can act
    as the end-of-shape terminator even when a tensor has a zero-length
    dimension.

    Args:
        keys: keys to describe, in packing order. Non-zero ranks decode with
            their own ``keys``, so the order must match rank 0.
        data: mapping from key to tensor. Only read on rank 0.

    Returns:
        ``(key_size, key_numel, total_numel)``:
        - ``key_size`` maps each key to its shape as a list of ints;
        - ``key_numel`` maps each key to its element count;
        - ``total_numel`` is the sum over ``keys``.
        The two dicts are the ``SHAPE_CACHE`` entries, updated in place.
    """
    max_dim = _MAX_DATA_DIM
    keys = list(keys)  # iterated several times below
    sizes = [0] * (max_dim * len(keys))

    # Pack the sizes on rank zero.
    if DistGroups["sp"].rank() == 0:
        for k, key in enumerate(keys):
            # Strictly fewer than max_dim dims, so at least one 0 terminator
            # is left in this key's slots.
            assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM:%s dim=%s" % (key, data[key].dim())
            base = k * max_dim
            for i, s in enumerate(data[key].size()):
                sizes[base + i] = s + 1  # +1 keeps 0 reserved as the terminator

    # Move to GPU and broadcast from rank 0 of the sp group.
    sizes_cuda = torch.tensor(sizes, dtype=torch.long, device=torch.cuda.current_device())
    torch.distributed.broadcast(sizes_cuda, _get_global_rank(DistGroups["sp"], 0), group=DistGroups["sp"])

    # Move back to host as Python ints and unpack.
    packed_sizes = sizes_cuda.tolist()
    key_size = SHAPE_CACHE["key_size"]
    key_numel = SHAPE_CACHE["key_numel"]
    total_numel = 0
    for k, key in enumerate(keys):
        base = k * max_dim
        shape = []
        numel = 1
        # Scan only this key's slots, stopping at the first terminator.
        for packed in packed_sizes[base:base + max_dim]:
            if packed <= 0:
                break
            dim_size = packed - 1
            shape.append(dim_size)
            numel *= dim_size
        key_size[key] = shape
        key_numel[key] = numel
        total_numel += numel
    SHAPE_CACHE["total_numel"] = total_numel
    return key_size, key_numel, total_numel


def broadcast_data(keys, data, datatype):
    """Broadcast data from rank zero of the sequence-parallel (``sp``) group
    to the other members of the same group.

    Rank 0 of ``DistGroups["sp"]`` flattens ``data[key]`` for every key into
    one CUDA buffer and broadcasts it. Every rank then slices the buffer back
    into tensors with the shapes broadcast by rank 0.

    Arguments:
        keys: list of keys in the data dictionary to be broadcasted. Must be
              the same, in the same order, on every rank of the group.
        data: data dictionary of string keys and cpu tensor values. Only read
              on rank 0.
        datatype: torch data type of all tensors in data associated
                  with keys.

    Returns:
        dict mapping each key to a tensor on the current CUDA device. Each
        tensor is a view into the single broadcast buffer.

    Raises:
        AssertionError: on rank 0, if any ``data[key]`` does not have dtype
            ``datatype`` (the buffer other ranks allocate uses ``datatype``).
    """
    group = DistGroups["sp"]

    # Build (key, size) and (key, number of elements) dictionaries along
    # with the total number of elements on all ranks.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)

    # Pack on rank zero.
    if group.rank() == 0:
        # Check that all keys have the same data type. Other ranks allocate
        # `total_numel` elements of `datatype`, so rank 0 must send exactly
        # that dtype or the broadcast buffers disagree in size and type.
        for key in keys:
            assert data[key].dtype == datatype, "%s has data type %s which is different than %s" % (
                key,
                data[key].dtype,
                datatype,
            )
        # Flatten the data associated with the keys
        flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
    else:
        flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)

    # Broadcast
    torch.distributed.broadcast(flatten_data, _get_global_rank(group, 0), group=group)

    # Unpack: each output is a view into flatten_data, not a copy.
    output = {}
    offset = 0
    for key in keys:
        size = key_size[key]
        numel = key_numel[key]
        output[key] = flatten_data.narrow(0, offset, numel).view(size)
        offset += numel

    return output
