import torch
import torch.nn.parallel.scatter_gather as torch_
from src.sparse_torch import CSRMatrix3d, CSCMatrix3d, concatenate


def scatter(inputs, target_gpus, dim=0):
    """
    Split ``inputs`` into per-device pieces for the devices in ``target_gpus``.

    Handling by type:

    * ``torch.Tensor``: sliced along ``dim`` by ``torch.nn.parallel``'s
      ``Scatter`` autograd function.
    * non-empty tuple / list / dict: every element is scattered recursively and
      the results are regrouped so that the i-th output entry collects the i-th
      piece of every element (tuple pieces stay tuples, list pieces become
      lists, dict pieces are rebuilt with the original dict type).
    * ``CSRMatrix3d`` / ``CSCMatrix3d``: split by ``scatter_sparse_matrix``
      (``dim`` is not forwarded to it).
    * anything else, including empty containers: the same reference is
      repeated once per entry of ``target_gpus``.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return torch_.Scatter.apply(target_gpus, None, dim, obj)

        if isinstance(obj, tuple) and len(obj) > 0:
            pieces = [scatter_map(item) for item in obj]
            return list(zip(*pieces))

        if isinstance(obj, list) and len(obj) > 0:
            pieces = [scatter_map(item) for item in obj]
            return [list(group) for group in zip(*pieces)]

        if isinstance(obj, dict) and len(obj) > 0:
            pieces = [scatter_map(pair) for pair in obj.items()]
            return [type(obj)(group) for group in zip(*pieces)]

        # Project-specific sparse matrices are split by slicing, not by Scatter.
        if isinstance(obj, (CSRMatrix3d, CSCMatrix3d)):
            return scatter_sparse_matrix(target_gpus, obj)

        return [obj for _ in target_gpus]

    # scatter_map is recursive, so its closure cell refers back to the function
    # itself; rebinding the name to None once done breaks that reference cycle.
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    """Scatter positional ``inputs`` and keyword ``kwargs`` across ``target_gpus``.

    Each argument is scattered independently with ``scatter`` (a falsy argument
    is treated as an empty list). The shorter per-device sequence is then
    padded, with ``()`` for inputs or ``{}`` for kwargs, so both line up.

    Returns:
        ``(inputs, kwargs)``, both converted to tuples of equal length.
    """
    scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
    scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []

    # Positive: inputs are short; negative: kwargs are short.
    shortfall = len(scattered_kwargs) - len(scattered_inputs)
    if shortfall > 0:
        scattered_inputs.extend(() for _ in range(shortfall))
    elif shortfall < 0:
        scattered_kwargs.extend({} for _ in range(-shortfall))

    return tuple(scattered_inputs), tuple(scattered_kwargs)


def scatter_sparse_matrix(target_gpus, obj):
    """Scatter a customized sparse matrix across devices by slicing its first axis.

    The matrix is cut into consecutive slices ``obj[start:start + step]`` where
    ``step = ceil(len(obj) / len(target_gpus))``, which is the same chunking rule
    as ``torch.Tensor.chunk``. The k-th slice is moved to ``target_gpus[k]``.
    At most ``len(target_gpus)`` slices are produced, and fewer when ``obj`` has
    fewer entries than there are devices.

    Args:
        target_gpus: sequence of CUDA device indices (``-1`` means the CPU)
            or ``torch.device`` objects.
        obj: object supporting ``len()``, slicing and ``.to(device)``
            (e.g. ``CSRMatrix3d`` / ``CSCMatrix3d``).

    Returns:
        tuple of slices; the k-th slice lives on ``target_gpus[k]``.

    Raises:
        ValueError: if ``target_gpus`` is empty or ``obj`` has length 0.
    """
    def get_device(target):
        if isinstance(target, torch.device):
            return target
        return torch.device('cuda:{}'.format(target)) if target != -1 else torch.device('cpu')

    num_devices = len(target_gpus)
    total = len(obj)
    if num_devices == 0:
        raise ValueError('target_gpus must contain at least one device')
    if total == 0:
        raise ValueError('cannot scatter an empty sparse matrix')

    # Ceil division keeps the number of slices <= number of devices; floor
    # division produced extra slices aimed at devices that do not exist.
    step = -(-total // num_devices)
    # Pair each slice with its actual target device, not with its position.
    return tuple(
        obj[start:start + step].to(get_device(target))
        for target, start in zip(target_gpus, range(0, total, step))
    )


def gather(outputs, target_device, dim=0):
    """
    Gathers tensors from different GPUs on a specified device (-1 means the CPU).

    ``outputs`` holds one result per replica. Results are merged recursively
    by the type of the first replica's result:

    * ``torch.Tensor``: combined along ``dim`` on ``target_device`` by
      ``torch.nn.parallel``'s ``Gather`` autograd function.
    * ``CSRMatrix3d`` / ``CSCMatrix3d``: merged with ``concatenate``.
    * ``None``: returns ``None``.
    * dict: gathered key by key into a dict of the same type.
    * int: every replica must report the same value, which is returned.
    * any other sequence: gathered element-wise and rebuilt with its type.

    Raises:
        ValueError: if dict results have differing numbers of keys, or int
            results differ between replicas.
    """
    def gather_map(per_replica):
        out = per_replica[0]
        if isinstance(out, torch.Tensor):
            return torch_.Gather.apply(target_device, dim, *per_replica)

        # Project-specific sparse matrices are merged by ``concatenate``.
        if isinstance(out, (CSRMatrix3d, CSCMatrix3d)):
            # NOTE(review): target_device may be -1 (CPU) per the docstring;
            # confirm ``concatenate`` accepts that value.
            return concatenate(*per_replica, device=target_device)

        if out is None:
            return None
        if isinstance(out, dict):
            if not all(len(out) == len(d) for d in per_replica):
                raise ValueError('All dicts must have the same number of keys')
            return type(out)((k, gather_map([d[k] for d in per_replica]))
                             for k in out)
        if isinstance(out, int):
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            if not all(out == value for value in per_replica):
                raise ValueError('All int outputs must be equal across replicas')
            return out

        return type(out)(map(gather_map, zip(*per_replica)))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        return gather_map(outputs)
    finally:
        gather_map = None
