from typing import Any, List, Optional, Union

import torch
import torch.distributed as dist

def ddp_gather_to_rank0(
    tensor: torch.Tensor,
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
    """
    Gather *tensor* from every rank onto rank-0 using ``dist.gather``.

    Parameters
    ----------
    tensor
        Tensor contributed by the calling rank.  The receive buffers on
        rank-0 are allocated from the local tensor, so every rank is
        expected to pass the same shape and dtype.

    Returns
    -------
    • No initialized process group → *tensor* itself, unchanged.
    • On rank-0  → list holding one tensor per rank.
    • On other ranks → None.
    """
    if not dist.is_available() or not dist.is_initialized():
        return tensor  # single-process fallback

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Same device handling as the all-gather helper in this file: a CPU
    # tensor is moved to the current CUDA device when the backend is NCCL.
    if tensor.device.type == 'cpu' and dist.get_backend() == 'nccl':
        tensor = tensor.cuda()

    # zeros_like() copies the input's strides, so a non-contiguous input would
    # give non-contiguous send *and* receive buffers; normalise first.
    tensor = tensor.contiguous()

    # Only the destination rank supplies receive buffers; others pass None.
    gather_list = (
        [torch.zeros_like(tensor) for _ in range(world_size)]
        if rank == 0
        else None
    )

    dist.gather(tensor, gather_list=gather_list, dst=0)

    # gather_list is already None on every rank except 0.
    return gather_list
    
def ddp_all_gather_to_rank0(obj: Any) -> Optional[Any]:
    """
    Gather *obj* from every rank and hand the result to rank-0.

    Parameters
    ----------
    obj
        Either a torch.Tensor **with the same shape on every rank**
        or an arbitrary picklable Python object (e.g. list of strings).

    Returns
    -------
    • No initialized process group → *obj* itself, unchanged.
    • On rank-0  → list with one entry per rank (tensors for the tensor
      path, the original objects otherwise).  Entries are returned as
      gathered; they are **not** flattened or concatenated.
    • On other ranks → None.
    """
    if not dist.is_available() or not dist.is_initialized():
        return obj  # single-GPU fallback

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    gathered: List[Any]
    if torch.is_tensor(obj):
        # --------------- tensor path ----------------
        if obj.device.type == 'cpu' and dist.get_backend() == 'nccl':
            obj = obj.cuda()  # Ensure tensor is on GPU for NCCL backend
        # empty_like() copies the input's strides; make the input contiguous
        # so both the send and receive buffers are contiguous.
        obj = obj.contiguous()
        gathered = [torch.empty_like(obj) for _ in range(world_size)]
        dist.all_gather(gathered, obj)
    else:
        # --------------- generic Python object path ----------------
        gathered = [None] * world_size
        dist.all_gather_object(gathered, obj)  # no type restrictions

    # Every rank took part in the collective; only rank-0 keeps the result.
    return gathered if rank == 0 else None
    

def ddp_all_gather_variable_tensor_to_rank0(
    tensor: torch.Tensor,
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
    """
    Gather tensors of potentially different first-dim sizes across all ranks to rank-0.

    Each rank pads its tensor with zeros along dim 0 up to the largest
    first-dim size in the group, the padded tensors are all-gathered, and
    rank-0 slices every result back to its true length.

    Parameters
    ----------
    tensor
        Tensor with at least one dimension.  Only dim 0 may differ between
        ranks; the padded buffers are built from the local trailing shape
        and dtype, so those are expected to match on every rank.

    Returns
    -------
    • No initialized process group → *tensor* itself, unchanged.
    • On rank-0  → list with one tensor per rank, each cut to that rank's
      original first-dim size.
    • On other ranks → None.
    """
    if not dist.is_available() or not dist.is_initialized():
        return tensor

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Same device handling as ddp_all_gather_to_rank0: a CPU tensor is moved
    # to the current CUDA device when the backend is NCCL.
    if tensor.device.type == 'cpu' and dist.get_backend() == 'nccl':
        tensor = tensor.cuda()
    device = tensor.device

    # Step 1: Gather sizes
    local_size = torch.tensor([tensor.size(0)], device=device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(all_sizes, local_size)
    # One host transfer for all sizes instead of one .item() call per rank.
    sizes = torch.cat(all_sizes).tolist()
    max_size = max(sizes)

    # Step 2: Pad local tensor
    pad_len = max_size - tensor.size(0)
    if pad_len > 0:
        padding = torch.zeros((pad_len, *tensor.shape[1:]), dtype=tensor.dtype, device=device)
        tensor = torch.cat([tensor, padding], dim=0)
    # Without padding the caller's tensor is used as-is and may be
    # non-contiguous; zeros_like() below would copy those strides.
    tensor = tensor.contiguous()

    # Step 3: Gather padded tensors
    gather_list = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gather_list, tensor)

    # Step 4: Truncate to actual sizes on rank 0
    if rank != 0:
        return None
    return [g[:s] for g, s in zip(gather_list, sizes)]