"""
vllm doesn't support ddp by default

There are some parameter validations w.r.t ParallelConfig, for tensor parallel and pipeline parallel, in vllm. However, these validations don't take data parallel into consideration.
When using accelerate to implement ddp, such validations will prevent the code execution.
So to enable vllm+ddp using accelerate, we need to disable these parameter validations.
"""

import torch
from typing import Optional, List, Dict, Set, Any, Union
import vllm.worker.worker
import vllm.model_executor.parallel_utils.parallel_state


# disable ``assert world_size == tensor_model_parallel_size * pipeline_model_parallel_size``
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
) -> None:
    """
    Create the tensor and pipeline model-parallel process groups.

    Unlike a strict implementation, this version does NOT require
    ``world_size == tensor_model_parallel_size * pipeline_model_parallel_size``
    and does NOT require the groups to be uninitialized beforehand, so it
    can run inside a larger (e.g. data-parallel) process group and can
    overwrite previously stored groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.

    Example: with 8 GPUs g0 ... g7, ``tensor_model_parallel_size=2`` and
    ``pipeline_model_parallel_size=4`` this builds
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    With 4 GPUs and both sizes left at 1, every rank ends up alone in its
    own tensor group and its own pipeline group.

    As in the original notes: for efficiency the caller should keep adjacent
    ranks on the same machine (e.g. with two 8-GPU boxes, ranks 0-7 on the
    first box and ranks 8-15 on the second).

    The selected groups are stored on vllm's ``parallel_state`` module
    (``_TENSOR_MODEL_PARALLEL_GROUP``, ``_PIPELINE_MODEL_PARALLEL_GROUP``,
    ``_PIPELINE_GLOBAL_RANKS``); nothing is returned.
    """
    assert torch.distributed.is_initialized()

    # Short handle for the vllm module whose globals are being populated.
    parallel_state = vllm.model_executor.parallel_utils.parallel_state

    total_ranks: int = torch.distributed.get_world_size()
    my_rank = torch.distributed.get_rank()

    # Deliberately no check that total_ranks equals
    # tensor_model_parallel_size * pipeline_model_parallel_size: the world
    # may be larger than a single model replica.
    tensor_group_count: int = total_ranks // tensor_model_parallel_size
    pipeline_group_count: int = total_ranks // pipeline_model_parallel_size

    # Tensor model-parallel groups: consecutive blocks of ranks.
    # new_group is invoked for every group on every rank; only the group that
    # contains this rank is kept. Deliberately no "already initialized" check.
    for block_idx in range(tensor_group_count):
        first_rank = block_idx * tensor_model_parallel_size
        members = range(first_rank, first_rank + tensor_model_parallel_size)
        tensor_group = torch.distributed.new_group(members)
        if my_rank in members:
            parallel_state._TENSOR_MODEL_PARALLEL_GROUP = tensor_group

    # Pipeline model-parallel groups: ranks strided by the number of groups.
    # Again, no "already initialized" check is performed.
    for offset in range(pipeline_group_count):
        members = range(offset, total_ranks, pipeline_group_count)
        pipeline_group = torch.distributed.new_group(members)
        if my_rank in members:
            parallel_state._PIPELINE_MODEL_PARALLEL_GROUP = pipeline_group
            # Stored as the range object itself, not converted to a list.
            parallel_state._PIPELINE_GLOBAL_RANKS = members
            

# disable ``assert torch_world_size == parallel_config.world_size``
def _init_distributed_environment(
    parallel_config: vllm.config.ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize torch.distributed (if needed) and the model-parallel groups.

    This variant does not compare the torch world size with
    ``parallel_config.world_size``: an already-initialized process group is
    accepted as-is, even if it is larger than the model-parallel world
    (e.g. when the processes were launched for data parallelism).

    Args:
        parallel_config: Supplies ``world_size`` (used only when the process
            group has to be created here), ``tensor_parallel_size`` and
            ``pipeline_parallel_size``.
        rank: Rank passed to ``init_process_group`` when the process group
            has to be created here; unused otherwise.
        distributed_init_method: ``init_method`` for ``init_process_group``.
            Only required when torch.distributed is not yet initialized.

    Raises:
        ValueError: If torch.distributed is not initialized and
            ``distributed_init_method`` is empty or None.
    """
    if not torch.distributed.is_initialized():
        if not distributed_init_method:
            raise ValueError(
                "distributed_init_method must be set if torch.distributed "
                "is not already initialized")
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )
    # When torch.distributed is already initialized, the existing process
    # group is reused without validating its size against parallel_config;
    # skipping that validation is the whole point of this override.

    # A small all_reduce for warmup (requires a CUDA device).
    torch.distributed.all_reduce(torch.zeros(1).cuda())

    # Looked up on the module at call time, so whichever implementation is
    # currently installed on parallel_state is the one that runs.
    vllm.model_executor.parallel_utils.parallel_state.initialize_model_parallel(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size)


def disable_parallel_config_checking():
    """Install this module's relaxed functions into vllm.

    Replaces ``vllm.worker.worker._init_distributed_environment`` and
    ``vllm.model_executor.parallel_utils.parallel_state.initialize_model_parallel``
    with the versions defined in this file, which skip the world-size
    validations.

    NOTE(review): presumably this has to be called before vllm sets up its
    workers for the patch to have any effect -- confirm against the caller.
    """
    worker_module = vllm.worker.worker
    worker_module._init_distributed_environment = _init_distributed_environment

    parallel_state = vllm.model_executor.parallel_utils.parallel_state
    parallel_state.initialize_model_parallel = initialize_model_parallel


