# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Shard
from torch.distributed.tensor import DTensor


def get_ranks_in_group(process_group):
    """
    Map every group-local rank of ``process_group`` to its global rank.

    Args:
        process_group: The process group whose members should be listed.

    Returns:
        List of global ranks ordered by group-local rank, i.e. element ``i`` is
        the global rank of group-local rank ``i``.
    """
    group_size = dist.get_world_size(process_group)
    return [
        dist.distributed_c10d.get_global_rank(process_group, group_rank)
        for group_rank in range(group_size)
    ]


def transfer_layer_across_pp(
    tensor: Optional[torch.Tensor],
    src_pp_rank: int,
    pp_group: dist.ProcessGroup,
    fsdp_mesh: Optional[DeviceMesh] = None,
) -> torch.Tensor:
    """
    Transfer a tensor from one PP rank to all other PP ranks in a hybrid parallelism setup.

    Every rank of ``pp_group`` must call this function, because it issues
    collectives. The source rank materializes the full tensor (gathering DTensor
    shards if needed). It then broadcasts the tensor's shape and dtype, followed
    by its data. Receiver ranks allocate a matching buffer and receive into it.

    Args:
        tensor: On the source rank, the tensor to transfer (a regular Tensor or
            a DTensor). On receiver ranks it is ignored and may be ``None``.
        src_pp_rank: Rank *within* ``pp_group`` that currently holds the tensor.
        pp_group: Process group for pipeline parallel communication.
        fsdp_mesh: Optional device mesh. When given, the full tensor is
            re-distributed over it with a ``Shard(0)`` placement.

    Returns:
        On every rank: a DTensor sharded on dim 0 over ``fsdp_mesh`` when one is
        provided, otherwise the full dense tensor.

    Raises:
        ValueError: If ``tensor`` is ``None`` on the source rank.
    """
    from torchtitan.tools.logging import logger

    # Group-local rank of this process inside the PP group.
    curr_rank = dist.get_rank(pp_group)

    logger.debug(
        f"PP Rank {curr_rank}: Starting transfer_layer_across_pp, tensor is None: {tensor is None}"
    )

    # ``dist.broadcast`` takes a *global* rank for ``src`` even when a group is
    # passed, so translate the group-local source rank.
    ranks = get_ranks_in_group(pp_group)
    global_src_rank_pp = ranks[src_pp_rank]
    logger.debug(f"PP Rank {curr_rank}: Global source rank: {global_src_rank_pp}")

    # Get device from tensor or current device if tensor is None
    device = (
        tensor.device
        if tensor is not None
        else torch.device(f"cuda:{torch.cuda.current_device()}")
    )
    logger.debug(f"PP Rank {curr_rank}: Using device: {device}")

    is_source = curr_rank == src_pp_rank
    if is_source:
        logger.debug(f"PP Rank {curr_rank}: I am the source rank")
        if tensor is None:
            raise ValueError(
                f"Source PP rank {src_pp_rank} must provide a non-None tensor"
            )
        # Gather distributed tensor shards into one dense tensor.
        if isinstance(tensor, DTensor):
            logger.debug(f"PP Rank {curr_rank}: Converting DTensor to full tensor")
            full_tensor = tensor.full_tensor()
        else:
            logger.debug(f"PP Rank {curr_rank}: Using tensor directly")
            full_tensor = tensor
        # Collectives operate on dense, contiguous storage.
        full_tensor = full_tensor.contiguous()
        logger.debug(
            f"PP Rank {curr_rank}: Full tensor shape: {full_tensor.shape}, type: {type(full_tensor)}"
        )
        metadata = [tuple(full_tensor.shape), full_tensor.dtype]
    else:
        logger.debug(f"PP Rank {curr_rank}: I am a receiver rank")
        # Placeholders, filled in by the metadata broadcast below.
        metadata = [None, None]

    # Receivers cannot know the rank (ndim) or dtype in advance. Broadcast both
    # so the receive buffer matches the source byte-for-byte. A fixed-size
    # shape buffer or the default float32 dtype would mismatch the collective.
    logger.debug(f"PP Rank {curr_rank}: Broadcasting tensor metadata")
    dist.broadcast_object_list(
        metadata, src=global_src_rank_pp, group=pp_group, device=device
    )
    shape, dtype = metadata
    logger.debug(
        f"PP Rank {curr_rank}: Tensor metadata: shape={shape}, dtype={dtype}"
    )

    if not is_source:
        full_tensor = torch.empty(shape, dtype=dtype, device=device)

    logger.debug(f"PP Rank {curr_rank}: Broadcasting full tensor")
    dist.broadcast(full_tensor, src=global_src_rank_pp, group=pp_group)
    logger.debug(
        f"PP Rank {curr_rank}: Broadcast complete, tensor shape: {full_tensor.shape}"
    )

    # Re-shard the tensor according to FSDP mesh
    if fsdp_mesh is not None:
        logger.debug(f"PP Rank {curr_rank}: Resharding tensor with FSDP mesh")
        dist_tensor = distribute_tensor(full_tensor, fsdp_mesh, [Shard(0)])
        logger.debug(f"PP Rank {curr_rank}: Resharded tensor type: {type(dist_tensor)}")
    else:
        logger.debug(f"PP Rank {curr_rank}: No resharding needed")
        dist_tensor = full_tensor

    logger.debug(f"PP Rank {curr_rank}: Waiting at barrier")
    # Best-effort synchronization: a barrier failure is logged, not raised.
    try:
        dist.barrier(group=pp_group)
        logger.debug(f"PP Rank {curr_rank}: Passed barrier, returning tensor")
    except Exception as e:
        logger.warning(
            f"PP Rank {curr_rank}: Barrier failed: {str(e)}, continuing anyway"
        )

    return dist_tensor


def sync_tensor_across_group(
    tensor: Optional[torch.Tensor],
    src_rank: int,
    process_group: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    shard_mesh: Optional[DeviceMesh] = None,
) -> torch.Tensor:
    """
    Synchronize a tensor from one rank to all other ranks in a process group.

    Every rank of ``process_group`` must call this function, because it issues
    collectives. The source rank broadcasts the tensor's shape and dtype, then
    its data. The other ranks allocate a matching buffer and receive into it.
    Tensors of any rank (ndim) and dtype are supported.

    Args:
        tensor: The tensor to synchronize. It must be non-None on the source
            rank; on other ranks it may be ``None``.
        src_rank: Global rank that holds the source tensor.
        process_group: Process group for communication (None for world group).
        device: Device for the metadata collective and the receive buffer.
            Auto-detected from ``tensor`` if None, falling back to the current
            CUDA device.
        shard_mesh: Optional device mesh. When given, the result is
            re-distributed over it with a ``Shard(0)`` placement.

    Returns:
        The synchronized tensor on all ranks: a DTensor sharded on dim 0 over
        ``shard_mesh`` if provided, otherwise the full dense tensor.

    Raises:
        ValueError: If ``tensor`` is ``None`` on the source rank.
    """
    from torchtitan.tools.logging import logger

    # Get current rank
    global_rank = dist.get_rank()

    # Determine device
    if device is None:
        if tensor is not None:
            device = tensor.device
        else:
            device = torch.device(f"cuda:{torch.cuda.current_device()}")

    logger.debug(f"Rank {global_rank}: Starting sync_tensor_across_group")

    is_source = global_rank == src_rank
    if is_source:
        if tensor is None:
            raise ValueError(f"Source rank {src_rank} must provide a non-None tensor")

        # Handle DTensor by converting to full tensor
        if isinstance(tensor, DTensor):
            logger.debug(f"Rank {global_rank}: Converting DTensor to full tensor")
            full_tensor = tensor.full_tensor()
        else:
            logger.debug(f"Rank {global_rank}: Using tensor directly")
            full_tensor = tensor
        # Collectives operate on dense, contiguous storage.
        full_tensor = full_tensor.contiguous()
        metadata = [tuple(full_tensor.shape), full_tensor.dtype]
    else:
        logger.debug(f"Rank {global_rank}: Receiving from rank {src_rank}")
        # Placeholders, filled in by the metadata broadcast below.
        metadata = [None, None]

    # Receivers cannot know the rank (ndim) or dtype in advance. Broadcast both
    # so the receive buffer matches the source byte-for-byte. A fixed-size
    # shape buffer or the default float32 dtype would mismatch the collective.
    logger.debug(f"Rank {global_rank}: Broadcasting tensor metadata")
    dist.broadcast_object_list(
        metadata, src=src_rank, group=process_group, device=device
    )
    shape, dtype = metadata
    logger.debug(f"Rank {global_rank}: Tensor metadata: shape={shape}, dtype={dtype}")

    if not is_source:
        # Create empty tensor to receive data into
        full_tensor = torch.empty(shape, dtype=dtype, device=device)

    logger.debug(f"Rank {global_rank}: Broadcasting tensor")
    dist.broadcast(full_tensor, src=src_rank, group=process_group)
    logger.debug(
        f"Rank {global_rank}: Broadcast complete, tensor shape: {full_tensor.shape}"
    )

    # Re-shard if requested
    if shard_mesh is not None:
        logger.debug(f"Rank {global_rank}: Resharding tensor")
        result_tensor = distribute_tensor(full_tensor, shard_mesh, [Shard(0)])
        logger.debug(
            f"Rank {global_rank}: Resharded tensor type: {type(result_tensor)}"
        )
    else:
        logger.debug(f"Rank {global_rank}: No resharding needed")
        result_tensor = full_tensor

    # Best-effort synchronization: a barrier failure is logged, not raised.
    logger.debug(f"Rank {global_rank}: Waiting at barrier")
    try:
        dist.barrier(group=process_group)
        logger.debug(f"Rank {global_rank}: Passed barrier, returning tensor")
    except Exception as e:
        logger.warning(
            f"Rank {global_rank}: Barrier failed: {str(e)}, continuing anyway"
        )

    return result_tensor
