              
                                                      
                                                                 

                 

from datetime import datetime
import os

import torch
from torch.distributed import ReduceOp

try:
    from vllm import SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    try:
        from vllm.model_executor.model_loader.loader import ShardedStateLoader
    except ImportError:
        from vllm.model_executor.model_loader.sharded_state_loader import ShardedStateLoader
    from vllm.usage.usage_lib import UsageContext
    from vllm.v1.engine.async_llm import AsyncLLM as _AsyncLLM
    from vllm.distributed.parallel_state import get_tp_group, get_pp_group
    from vllm.platforms import current_platform
except ImportError:
    pass


                                                                                                                                            
def stateless_init_process_group(master_address, master_port, rank, world_size, device):
    """
    Build a PyNccl communicator on top of a vLLM `StatelessProcessGroup`.

    A `StatelessProcessGroup` is created from the given rendezvous
    host/port, rank and world size, and is then wrapped in a
    `PyNcclCommunicator` bound to `device`. Per the vLLM recommendation,
    this is the way to set up data-plane (NCCL) communication between
    external (training) processes and vLLM workers without involving the
    global process group in torch.distributed.

    Returns:
        The `PyNcclCommunicator` created for this rank.
    """
    # Imported lazily so that importing this module does not require vLLM.
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    process_group = StatelessProcessGroup.create(
        host=master_address,
        port=master_port,
        rank=rank,
        world_size=world_size,
    )
    return PyNcclCommunicator(process_group, device=device)


def init_model_update_group(worker_wrap, master_address, master_port, dp_rank, world_size):
    """
    Create the model-update communicator for this worker and attach it.

    The worker's rank inside the update group is computed from its
    data-parallel rank and its tensor-/pipeline-parallel ranks as
    ``dp_rank * (tp_size * pp_size) + pp_rank * tp_size + tp_rank``.
    The rank and the communicator returned by
    `stateless_init_process_group` are stored on `worker_wrap` as
    `_gcore_rank` and `_gcore_model_update_group`.

    Progress and errors are appended to
    ``./debug-tmp/init_group_dp{dp_rank}.txt``.

    Returns:
        True when the group was set up, False if any exception was raised
        (the exception is logged, not propagated).
    """
    debug_dir = "./debug-tmp"
    os.makedirs(debug_dir, exist_ok=True)
    log_path = f"{debug_dir}/init_group_dp{dp_rank}.txt"

    with open(log_path, 'a') as outf:
        outf.write(f"init_model_update_group {datetime.now()}\n")
        try:
            tensor_parallel = get_tp_group()
            pipeline_parallel = get_pp_group()
            tp_size = tensor_parallel.world_size
            pp_size = pipeline_parallel.world_size
            tp_rank = tensor_parallel.rank_in_group
            pp_rank = pipeline_parallel.rank_in_group
            mp_size = tp_size * pp_size
            device = torch.cuda.current_device()

            # Ranks are laid out dp-major, then pp, then tp.
            rank = dp_rank * mp_size + pp_rank * tp_size + tp_rank

            worker_wrap._gcore_rank = rank
            worker_wrap._gcore_model_update_group = stateless_init_process_group(
                master_address, master_port, rank, world_size, device=device)

            topology = [tp_size, pp_size, mp_size, tp_rank, pp_rank, dp_rank, rank]
            outf.write(f"enter {rank=} {dp_rank=} {device=} {worker_wrap.device=} "
                       f"{worker_wrap._gcore_rank=} \n"
                       f"meta_info {topology} {world_size=} {master_address=} {master_port=}\n")
            outf.flush()
        except Exception as e:
            # Failures are reported through the return value and the log file.
            outf.write(f"cat exception {e}\n")
            outf.flush()
            return False
    return True


def setup_head_and_src_rank(worker_wrap, actor_pp_rank, actor_ep_rank, head_device_uuid,
                            sampler_dp_rank):
    """
    Mark whether this worker is the head for an actor (pp, ep) pair and
    agree on the source rank across the model-update group.

    The worker compares the UUID of its own device with `head_device_uuid`.
    The head contributes its `_gcore_rank` and every other worker
    contributes 0 to a SUM all-reduce over
    `worker_wrap._gcore_model_update_group`; the reduced value is taken as
    the source rank (this equals the head's rank assuming exactly one
    participant matches `head_device_uuid` -- not checked here).

    The following attributes are set on `worker_wrap`, suffixed with
    ``{actor_pp_rank}_{actor_ep_rank}``:
      * ``_gcore_mp_head_<suffix>``: True on the head, False otherwise.
      * ``_gcore_mp_head_device_uuid_<suffix>``: `head_device_uuid` on the
        head, None otherwise.
      * ``_gcore_src_rank_for_<suffix>``: the reduced source rank.

    Progress and errors are appended to a per-(sampler dp, actor pp,
    actor ep) file under ``./debug-tmp``.

    Returns:
        A ``(ok, src_rank)`` tuple. ``ok`` is False if any exception was
        raised (it is logged, not propagated); ``src_rank`` is None when
        the failure happened before the source rank was computed.
    """
    ret = True
    # Bound before the try block so the final return cannot raise
    # UnboundLocalError when an exception occurs before the all-reduce.
    src_rank = None
    debug_dir = "./debug-tmp"
    os.makedirs(debug_dir, exist_ok=True)

    with open(
            f"{debug_dir}/define_sdp_{sampler_dp_rank}_app_{actor_pp_rank}_aep_{actor_ep_rank}.txt",
            'a') as outf:
        outf.write(f"setup_head_and_src_rank {datetime.now()}\n")
        try:
            rank = worker_wrap._gcore_rank
            device = torch.cuda.current_device()
            curr_device_uuid = current_platform.get_device_uuid(worker_wrap.device.index)

            is_head = curr_device_uuid == head_device_uuid
            setattr(worker_wrap, f'_gcore_mp_head_{actor_pp_rank}_{actor_ep_rank}', is_head)
            setattr(worker_wrap, f'_gcore_mp_head_device_uuid_{actor_pp_rank}_{actor_ep_rank}',
                    head_device_uuid if is_head else None)

            # Only the head contributes a non-zero value to the sum.
            src_rank_tensor = torch.tensor(rank if is_head else 0,
                                           dtype=torch.int64,
                                           device=device)
            src_rank_tensor = worker_wrap._gcore_model_update_group.all_reduce(src_rank_tensor,
                                                                               op=ReduceOp.SUM)
            src_rank = src_rank_tensor.item()
            setattr(worker_wrap, f'_gcore_src_rank_for_{actor_pp_rank}_{actor_ep_rank}', src_rank)

            outf.write(
                f"setup_head_and_src_rank {actor_pp_rank=} {actor_ep_rank=} {device=} "
                f"{rank=} {src_rank=} head {curr_device_uuid == head_device_uuid} {curr_device_uuid=} "
                f"{head_device_uuid=}\n")
            ret = True
        except Exception as e:
            outf.write(f"cat exception {e}\n")
            outf.flush()
            ret = False
    return ret, src_rank


def gcore_save_vllm_checkpoint(worker_wrap, checkpoint_dir):
    """
    Save this worker's model to `checkpoint_dir` as a sharded-state checkpoint.

    The model is read from ``worker_wrap.worker.model_runner.model`` and
    written with `ShardedStateLoader.save_model`, passing ``pattern=None``
    and ``max_size=None``.

    Progress and errors are appended to ``./debug-tmp/save_checkpoint.txt``.

    Returns:
        True when the checkpoint was written, False if any exception was
        raised (the exception is logged, not propagated), matching the
        other helpers in this module.
    """
    debug_dir = "./debug-tmp"
    os.makedirs(debug_dir, exist_ok=True)
    with open(f'{debug_dir}/save_checkpoint.txt', 'a') as outf:
        outf.write(f"gcore_save_vllm_checkpoint {datetime.now()}\n")
        try:
            model = worker_wrap.worker.model_runner.model
            ShardedStateLoader.save_model(
                model,
                checkpoint_dir,
                pattern=None,
                max_size=None,
            )
            # Newline-terminated so appended entries do not run together.
            outf.write(f"save ckpt at {checkpoint_dir}\n")
            outf.flush()
        except Exception as e:
            outf.write(f"exception {e}\n")
            outf.flush()
            # Report the failure instead of claiming success.
            return False
    return True
