              
                                                      
                                                                 

from datetime import datetime
from typing import Callable
from dataclasses import dataclass
import asyncio
import io
import math
import time
import json
from typing import Dict, Any, List

from torch.multiprocessing.reductions import reduce_tensor
import torch

try:
    from vllm.platforms import current_platform
except ImportError:
    pass

try:
    from gpatch.core.sampler_v3.sglang import get_device_uuid
    from sglang.srt.utils import MultiprocessingSerializer
except ImportError:
    pass

from megatron.core import mpu
from megatron.core.utils import get_model_config

from gpatch.core.aligner_helper import clear_memory
from gpatch.core.utils import to_cuda_if_not_none, to_cpu_if_not_none
from gpatch.core.utils import print_with_rank_and_datetime
from gpatch.core.parallel_state import (
    is_tp_dp_and_cp_head,
    is_update_weight_head,
    is_mp_and_cp_head,
    cpu_barrier,
)
from gpatch.core.aligner_helper import flatten_weights
from gpatch.core.aligner_helper import broadcast_object_within_mp_and_cp
from gpatch.core.models.gpt.weight_conversion import mcore_to_hf_weights
from gpatch.rpc import call_once_rpc


@dataclass
class GptPpoSamplerClientV3:
    """HTTP RPC client used by the PPO actor to drive remote sampler servers.

    Every sampler endpoint is addressed as ``http://{ep_ips[i]}:{ep_ports[i]}``
    and all requests go through ``call_once_rpc``.
    """
    # Host / port of each sampler endpoint; the two lists are paired by index.
    ep_ips: List[str]
    ep_ports: List[int]
    # NOTE(review): not read anywhere in the visible part of this class.
    timeout: int
    # Timeout passed to the RPCs issued by ``update_weights``.
    update_timeout: int
    # NOTE(review): not read anywhere in the visible part of this class.
    rpc_max_retries: int
    # NOTE(review): not read anywhere in the visible part of this class.
    num_rm: int = 1
    # Called as ``unwrap_model_func(model)``; ``update_weights`` uses element [0].
    unwrap_model_func: Callable = None
    # 'vllm' selects the vLLM IPC path in ``update_weights``; any other value
    # takes the ``MultiprocessingSerializer`` / sglang path.
    infer_engine_impl: str = None
    # Weight-bucket threshold in MiB; converted to bytes in ``__post_init__``.
    update_weight_max_size_mb: int = 64

    def __post_init__(self):
        assert len(self.ep_ips) == len(self.ep_ports)
        dp_size = mpu.get_data_parallel_world_size()
        dp_rank = mpu.get_data_parallel_rank()
        nep = len(self.ep_ips)

        self.max_size_bytes_to_send = self.update_weight_max_size_mb * 1024 * 1024

        if torch.distributed.get_rank() == 0:
            print(
                f"GptPpoSamplerClientV3 update_weights_max_size_in_mb {self.max_size_bytes_to_send}"
            )

    def wait_until_sampler_server_is_ready(self):
        for ep_ip, ep_port in zip(self.ep_ips, self.ep_ports, strict=True):
            url = f'http://{ep_ip}:{ep_port}/heartbeat'
            co = call_once_rpc(url, {})
            asyncio.run(co)

    def setup(self):
        """Send a ``/setup`` request to every sampler endpoint.

        Only global rank 0 issues the requests; they are awaited together via
        ``asyncio.gather``. Other ranks return without doing anything.
        """
        if torch.distributed.get_rank() != 0:
            return

        async def call_setup(host, port):
            endpoint = f'http://{host}:{port}/setup'
            return await call_once_rpc(endpoint, {})

        async def fan_out():
            pending = [
                call_setup(host, port)
                for host, port in zip(self.ep_ips, self.ep_ports, strict=True)
            ]
            return await asyncio.gather(*pending)

        asyncio.run(fan_out())

    def sleep(self):
        """Send a ``/sleep`` request to every sampler endpoint.

        Only global rank 0 issues the requests; they are awaited together via
        ``asyncio.gather``. Other ranks return without doing anything.
        """
        if torch.distributed.get_rank() != 0:
            return

        async def put_to_sleep(host, port):
            endpoint = f'http://{host}:{port}/sleep'
            return await call_once_rpc(endpoint, {})

        async def fan_out():
            pending = [
                put_to_sleep(host, port)
                for host, port in zip(self.ep_ips, self.ep_ports, strict=True)
            ]
            return await asyncio.gather(*pending)

        asyncio.run(fan_out())

    def wake_up(self):
        """Send ``/wake_up`` for the ``weights`` and ``kv_cache`` tags to every endpoint.

        Only global rank 0 issues the requests; they are awaited together via
        ``asyncio.gather``. Other ranks return without doing anything.

        Raises:
            ValueError: if ``ep_ips`` and ``ep_ports`` differ in length.
        """

        async def wake_rpc(ep_ip, ep_port, req_dict):
            url = f'http://{ep_ip}:{ep_port}/wake_up'
            return await call_once_rpc(url, req_dict)

        async def co():
            req_dict = {"tag_names": ["weights", "kv_cache"]}
            # strict=True (as in sleep()/setup()) so a mismatched endpoint list
            # fails loudly instead of silently skipping endpoints.
            rpc_cos = [
                wake_rpc(ip, port, req_dict)
                for ip, port in zip(self.ep_ips, self.ep_ports, strict=True)
            ]
            return await asyncio.gather(*rpc_cos)

        if torch.distributed.get_rank() == 0:
            asyncio.run(co())

    def mark_sampler_ppo_step_begin(self, ppo_step):
        """Tell every sampler endpoint that PPO step ``ppo_step`` is starting.

        The request carries the step, the caller's data-parallel rank and the
        tag names ``weights`` / ``kv_cache``. The payload is built on every
        rank, but only global rank 0 actually sends it.
        """
        payload = dict(
            ppo_step=ppo_step,
            actor_dp_rank=mpu.get_data_parallel_rank(),
            tag_names=['weights', 'kv_cache'],
        )

        async def notify(host, port):
            endpoint = f'http://{host}:{port}/mark_sampler_ppo_step_begin'
            return await call_once_rpc(endpoint, payload)

        async def fan_out():
            pending = [
                notify(host, port)
                for host, port in zip(self.ep_ips, self.ep_ports, strict=True)
            ]
            return await asyncio.gather(*pending)

        if torch.distributed.get_rank() != 0:
            return
        asyncio.run(fan_out())

    def mark_sampler_ppo_step_end(self, ppo_step):
        """Tell every sampler endpoint that PPO step ``ppo_step`` has finished.

        The request carries the step and the caller's data-parallel rank. The
        payload is built on every rank, but only global rank 0 sends it.
        """
        payload = dict(
            ppo_step=ppo_step,
            actor_dp_rank=mpu.get_data_parallel_rank(),
        )

        async def notify(host, port):
            endpoint = f'http://{host}:{port}/mark_sampler_ppo_step_end'
            return await call_once_rpc(endpoint, payload)

        async def fan_out():
            pending = [
                notify(host, port)
                for host, port in zip(self.ep_ips, self.ep_ports, strict=True)
            ]
            return await asyncio.gather(*pending)

        if torch.distributed.get_rank() != 0:
            return
        asyncio.run(fan_out())

    def pick_ep_idx(self, sample_idx):
        """Choose which sampler endpoint this data-parallel rank should use.

        With as many endpoints as DP ranks, rank ``r`` maps to endpoint ``r``
        and ``sample_idx`` is ignored. With more endpoints than DP ranks, rank
        ``r`` starts at ``r * (n_ep // dp_size)``; with fewer, it starts at
        ``r // (dp_size // n_ep)``. In both of those cases ``sample_idx`` is
        added as an offset. The result is always reduced modulo the number of
        endpoints.

        Args:
            sample_idx: integer offset added in the unequal-size cases.

        Returns:
            Index into ``self.ep_ips`` / ``self.ep_ports``.
        """
        actor_dp_size = mpu.get_data_parallel_world_size()
        actor_dp_rank = mpu.get_data_parallel_rank()
        num_endpoints = len(self.ep_ips)

        if actor_dp_size == num_endpoints:
            chosen = actor_dp_rank
        elif actor_dp_size < num_endpoints:
            endpoints_per_rank = num_endpoints // actor_dp_size
            chosen = actor_dp_rank * endpoints_per_rank + sample_idx
        else:
            ranks_per_endpoint = actor_dp_size // num_endpoints
            chosen = actor_dp_rank // ranks_per_endpoint + sample_idx
        return chosen % num_endpoints

    async def generate(self, ppo_step, sample_idx, batch, sampling_repeat) -> Dict[str, List[Any]]:
        """Send one ``/generate`` request to the endpoint chosen by ``pick_ep_idx``.

        Args:
            ppo_step: forwarded to the server as ``ppo_step``.
            sample_idx: forwarded as ``sample_idx`` and used to pick the endpoint.
            batch: forwarded as ``batch``.
            sampling_repeat: forwarded as ``sampling_repeat``.

        Returns:
            Whatever ``call_once_rpc`` returns for the request.

        NOTE(review): the RPC timeout is fixed at ``10 * 60`` (presumably
        seconds) and ``self.timeout`` is not used here -- confirm this is
        intended.
        """
        actor_dp_size = mpu.get_data_parallel_world_size()
        actor_dp_rank = mpu.get_data_parallel_rank()
        target = self.pick_ep_idx(sample_idx)

        payload = dict(
            actor_dp_rank=actor_dp_rank,
            actor_dp_size=actor_dp_size,
            # The chosen endpoint index doubles as the sampler DP rank.
            sampler_dp_rank=target,
            ppo_step=ppo_step,
            sample_idx=sample_idx,
            batch=batch,
            sampling_repeat=sampling_repeat,
        )
        endpoint = f'http://{self.ep_ips[target]}:{self.ep_ports[target]}/generate'
        return await call_once_rpc(endpoint, payload, timeout=10 * 60)

    def broadcast_rollout_batch(self, rbs):
        """Broadcast ``rbs`` with ``broadcast_object_within_mp_and_cp``.

        On the rank where ``is_mp_and_cp_head()`` is true, the batch is passed
        through ``remove_before_broadcast_rollout_batch`` before the broadcast
        and the result through ``add_back_after_broadcast_rollout_batch``
        afterwards. All other ranks return the broadcast result unchanged.
        """
        if is_mp_and_cp_head():
            payload = self.remove_before_broadcast_rollout_batch(rbs)
        else:
            payload = rbs

        received = broadcast_object_within_mp_and_cp(payload)

        if not is_mp_and_cp_head():
            return received
        return self.add_back_after_broadcast_rollout_batch(received)

                                     
    def remove_before_broadcast_rollout_batch(self, rbs):
        """
        remove unserializable objects before broadcast
        """
        return rbs

                                      
    def add_back_after_broadcast_rollout_batch(self, rbs):
        """
        add back unserializable objects after broadcast
        """
        return rbs

    def init_weight_update_group(self, setup_head=True):
        """Initialise the weight-update group on every sampler endpoint.

        Phase 1: global rank 0 calls ``/init_weight_update_group`` on every
        endpoint and asserts that every ``init_group_ret`` is truthy; all ranks
        then meet at ``cpu_barrier()``.

        Phase 2 (skipped when ``setup_head`` is false): ranks for which
        ``is_update_weight_head(...)`` holds call ``/setup_head_and_src_rank``
        on every endpoint with their device UUID, pipeline-parallel rank and
        expert-parallel rank, and assert that every ``setup_ret`` is truthy.
        A second ``cpu_barrier()`` follows.

        NOTE(review): the device UUID always comes from
        ``current_platform.get_device_uuid`` here, while ``update_weights``
        picks the UUID helper based on ``infer_engine_impl`` -- confirm the two
        are meant to differ.
        """

        def endpoints():
            return zip(self.ep_ips, self.ep_ports, strict=True)

        async def post(host, port, route, payload, result_key):
            reply = await call_once_rpc(f'http://{host}:{port}/{route}', payload)
            return reply[result_key]

        async def init_all():
            pending = [
                post(host, port, 'init_weight_update_group', {}, 'init_group_ret')
                for host, port in endpoints()
            ]
            return await asyncio.gather(*pending)

        if torch.distributed.get_rank() == 0:
            ret = all(asyncio.run(init_all()))
            assert ret, f"init_weight_update_group failed {ret=}"

        cpu_barrier()
        if not setup_head:
            return

        async def setup_all():
            head_uuid = current_platform.get_device_uuid(torch.cuda.current_device())
            payload = {
                "head_device_uuid": head_uuid,
                "actor_pp_rank": mpu.get_pipeline_model_parallel_rank(),
                "actor_ep_rank": mpu.get_expert_model_parallel_rank(),
            }
            pending = [
                post(host, port, 'setup_head_and_src_rank', payload, 'setup_ret')
                for host, port in endpoints()
            ]
            return await asyncio.gather(*pending)

        expert_parallel = mpu.get_expert_model_parallel_world_size() > 1
        if is_update_weight_head(enable_expert_parallel=expert_parallel):
            ret = all(asyncio.run(setup_all()))
            assert ret, f"setup_head_and_src_rank failed {ret=}"
        cpu_barrier()

    def update_weights(self, model, replace_zeros=False, early_swap_model=False, cpu_memory_model=None):
        """Push the actor's current weights to every sampler endpoint.

        Steps, in order:
          1. Global rank 0 calls ``/wake_up`` with the ``weights`` tag on every
             endpoint; all ranks then meet at ``cpu_barrier()``.
          2. Every rank iterates ``mcore_to_hf_weights``. Update-weight head
             ranks collect the yielded tensors into a bucket; once the bucket
             reaches ``self.max_size_bytes_to_send`` bytes it is flattened with
             ``flatten_weights``, wrapped in an IPC handle and sent through the
             ``/update_weights`` RPC. A partial bucket left after the loop is
             sent the same way.
          3. A marker request named ``update_finish`` or ``moe_finish`` is sent
             (see the inline comments for which rank sends which).
        When ``'moe'`` is found in the model config's ``model_arch`` the whole
        sequence of steps 2-3 runs twice: first for the "moe" parameters, then
        for the "dense" ones.

        Args:
            model: passed to ``self.unwrap_model_func`` and ``mcore_to_hf_weights``.
            replace_zeros: send ``torch.zeros_like`` of each weight instead of
                the weight itself.
            early_swap_model: forwarded to ``mcore_to_hf_weights``; requires
                ``cpu_memory_model``.
            cpu_memory_model: forwarded to ``mcore_to_hf_weights``.

        Raises:
            AssertionError: if ``early_swap_model`` is set without
                ``cpu_memory_model``, if weights of different dtypes are
                produced, or if any endpoint reports a failed update.
            ValueError: if ``ep_ips`` and ``ep_ports`` differ in length.
        """
        if early_swap_model:
            assert cpu_memory_model is not None

        torch.cuda.synchronize()
        time_t1 = time.time()

        def on_update_weight_head():
            # Re-evaluated on every use, exactly like the inline calls it replaces.
            return is_update_weight_head(
                enable_expert_parallel=mpu.get_expert_model_parallel_world_size() > 1)

        async def wake_rpc(ep_ip, ep_port, req_dict):
            url = f'http://{ep_ip}:{ep_port}/wake_up'
            return await call_once_rpc(url, req_dict)

        async def co():
            # Only the "weights" tag is woken up here (no "kv_cache").
            req_dict = {"tag_names": ["weights"]}
            rpc_cos = [
                wake_rpc(ip, port, req_dict)
                for ip, port in zip(self.ep_ips, self.ep_ports, strict=True)
            ]
            return await asyncio.gather(*rpc_cos)

        if torch.distributed.get_rank() == 0:
            asyncio.run(co())
        cpu_barrier()

        torch.cuda.synchronize()
        time_t2 = time.time()

        async def rpc_fn(ip, port, req_dict, rpc_func_str, resp_str):
            url = f'http://{ip}:{port}/{rpc_func_str}'
            return (await call_once_rpc(url, req_dict, timeout=self.update_timeout))[resp_str]

        async def rpc_co(req_dict, rpc_func_str, resp_str):
            # Same request to every endpoint, awaited together.
            rpc_cos = [
                rpc_fn(ip, port, req_dict, rpc_func_str, resp_str)
                for ip, port in zip(self.ep_ips, self.ep_ports, strict=True)
            ]
            return await asyncio.gather(*rpc_cos)

        def serialize_weight(src_head_device_uuid, flatten_weight):
            # vLLM receives {device_uuid: reduce_tensor(...)}; every other
            # engine value goes through sglang's MultiprocessingSerializer.
            if self.infer_engine_impl == 'vllm':
                return {src_head_device_uuid: reduce_tensor(flatten_weight)}
            return MultiprocessingSerializer.serialize(flatten_weight)

        def update_func(
            name_lst,
            key_size,
            key_numel,
            ipc_handle,
            flatten_shape,
            flatten_dtype,
            src_head_device_uuid,
            rpc_co_func,
        ):
            """Send one ``/update_weights`` request to all endpoints and check the replies."""
            req_dict = dict(
                name=name_lst,
                key_size=key_size,
                key_numel=key_numel,
                actor_pp_rank=mpu.get_pipeline_model_parallel_rank(),
                actor_ep_rank=mpu.get_expert_model_parallel_rank(),
                ipc_handle=ipc_handle,
                flatten_shape=flatten_shape,
                flatten_dtype=flatten_dtype,
                src_head_device_uuid=src_head_device_uuid,
            )

            resps = asyncio.run(rpc_co_func(req_dict, "update_weights", "update_success"))
            resp_ret = all(resps)
            assert resp_ret, f"update_weight {name_lst=} failed {resp_ret=}"
            return resp_ret

        def update_model_func(model_param_type, is_moe_model=False):
            """Stream one parameter group ("moe" or "dense") to the samplers."""
            current_size_bytes = 0
            name_lst = []
            weight_lst = []
            weight_dtype = None
            # Held at this scope (not inside flush_bucket) so the previous flat
            # buffer is released only when the next one is bound, as before.
            flatten_weight = None
            if self.infer_engine_impl == 'vllm':
                src_head_device_uuid = current_platform.get_device_uuid(torch.cuda.current_device())
            else:
                src_head_device_uuid = get_device_uuid(torch.cuda.current_device())

            assert model_param_type in ["moe", "dense"]
            if not is_moe_model:
                assert model_param_type in ["dense"]

            def flush_bucket():
                """Flatten the pending weights, send them, and reset the bucket."""
                nonlocal name_lst, weight_lst, current_size_bytes, flatten_weight
                torch.cuda.synchronize()
                assert len(name_lst) == len(weight_lst)
                flatten_weight, key_size, key_numel, _ = flatten_weights(name_lst, weight_lst)
                ipc_handle = serialize_weight(src_head_device_uuid, flatten_weight)

                update_func(name_lst, key_size, key_numel, ipc_handle, flatten_weight.shape,
                            weight_dtype, src_head_device_uuid, rpc_co)

                name_lst = []
                current_size_bytes = 0
                # Drop the handle and the source tensors before asking the
                # allocator to release cached blocks.
                del ipc_handle
                weight_lst.clear()
                weight_lst = []
                torch.cuda.empty_cache()

            for name, weight in mcore_to_hf_weights(model,
                                                    self.unwrap_model_func,
                                                    model_param_type,
                                                    early_swap_model=early_swap_model,
                                                    cpu_memory_model=cpu_memory_model):
                # Non-head ranks still iterate the generator but send nothing.
                if not on_update_weight_head():
                    continue

                if replace_zeros:
                    weight = torch.zeros_like(weight)
                else:
                    weight = weight.contiguous()

                # All weights are flattened into one tensor per bucket, so a
                # single dtype is enforced across the whole pass.
                if weight_dtype is None:
                    weight_dtype = weight.dtype
                else:
                    assert weight_dtype == weight.dtype, \
                        f"weight dtype {weight.dtype} of {name} must be same as {weight_dtype=}"

                name_lst.append(name)
                weight_lst.append(weight)
                current_size_bytes += weight.element_size() * weight.numel()

                if current_size_bytes >= self.max_size_bytes_to_send:
                    flush_bucket()

            # Send whatever is left in a partially filled bucket.
            if on_update_weight_head():
                assert len(name_lst) == len(weight_lst)
                if len(name_lst) > 0:
                    flush_bucket()

            # Choose the end-of-pass marker:
            #  - non-MoE model: always "update_finish";
            #  - MoE model, "moe" pass: "moe_finish" on tp/dp/cp heads,
            #    "update_finish" on the other ranks;
            #  - MoE model, "dense" pass: "update_finish" on tp/dp/cp heads,
            #    nothing at all on the other ranks.
            if not is_moe_model:
                finished_name = ["update_finish"]
            elif model_param_type == "moe":
                if is_tp_dp_and_cp_head():
                    finished_name = ["moe_finish"]
                else:
                    finished_name = ["update_finish"]
            else:
                if not is_tp_dp_and_cp_head():
                    return
                finished_name = ["update_finish"]

            if on_update_weight_head():
                resp_ret = update_func(name_lst=finished_name,
                                       key_size=None,
                                       key_numel=None,
                                       ipc_handle=None,
                                       flatten_shape=None,
                                       flatten_dtype=None,
                                       src_head_device_uuid=None,
                                       rpc_co_func=rpc_co)
                print_with_rank_and_datetime(f"update_finished {finished_name} {resp_ret=}")

        unwrapped_model = self.unwrap_model_func(model)[0]
        model_config = get_model_config(unwrapped_model)
        is_moe_model = 'moe' in model_config.model_arch
        if is_moe_model:
            update_model_func(model_param_type="moe", is_moe_model=is_moe_model)

        update_model_func(model_param_type="dense", is_moe_model=is_moe_model)

        clear_memory()
        cpu_barrier()
        torch.cuda.synchronize()
        time_t3 = time.time()

        # total, then weight-transfer phase, then wake-up phase
        print_with_rank_and_datetime(f"update using time {time_t3 - time_t1} detail"
                                     f" {time_t3 - time_t2} {time_t2 - time_t1}")

    def infer_engine_flush_cache(self):
        """Send a ``/flush_cache`` request to every sampler endpoint.

        Only global rank 0 issues the requests (awaited together via
        ``asyncio.gather``); every rank then meets at ``cpu_barrier()``.

        Raises:
            ValueError: if ``ep_ips`` and ``ep_ports`` differ in length.
        """

        async def flush_cache_rpc(ep_ip, ep_port):
            url = f'http://{ep_ip}:{ep_port}/flush_cache'
            return await call_once_rpc(url, {})

        async def co():
            # strict=True matches the other fan-out methods and avoids
            # silently skipping endpoints when the lists are mismatched.
            rpc_cos = [
                flush_cache_rpc(ip, port)
                for ip, port in zip(self.ep_ips, self.ep_ports, strict=True)
            ]
            return await asyncio.gather(*rpc_cos)

        if torch.distributed.get_rank() == 0:
            asyncio.run(co())
        cpu_barrier()

                                    
    def check_sampler_zeros_weight(self):
        """Call ``/check_zeros_weight`` on every endpoint and print the replies.

        Only global rank 0 issues the requests and prints the gathered
        responses; every rank then meets at ``torch.distributed.barrier()``.
        NOTE(review): presumably paired with
        ``update_weights(replace_zeros=True)`` -- confirm.

        Raises:
            ValueError: if ``ep_ips`` and ``ep_ports`` differ in length.
        """

        async def zeros_rpc(ep_ip, ep_port):
            url = f'http://{ep_ip}:{ep_port}/check_zeros_weight'
            return await call_once_rpc(url, {})

        async def co():
            # strict=True matches the other fan-out methods and avoids
            # silently skipping endpoints when the lists are mismatched.
            rpc_cos = [
                zeros_rpc(ip, port)
                for ip, port in zip(self.ep_ips, self.ep_ports, strict=True)
            ]
            return await asyncio.gather(*rpc_cos)

        if torch.distributed.get_rank() == 0:
            ret = asyncio.run(co())
            print_with_rank_and_datetime(f"check all zeros {ret=}")
        torch.distributed.barrier()

    def save_infer_engine_ckpt(self, save_path):
        """Ask every endpoint to save an inference-engine checkpoint.

        Only global rank 0 issues the ``/save_infer_engine_ckpt`` requests and
        prints the gathered responses; every rank then meets at
        ``torch.distributed.barrier()``.

        Args:
            save_path: sent to the server as ``checkpoint_dir``.

        Raises:
            ValueError: if ``ep_ips`` and ``ep_ports`` differ in length.
        """

        async def save_rpc(ep_ip, ep_port, req_dict):
            url = f'http://{ep_ip}:{ep_port}/save_infer_engine_ckpt'
            return await call_once_rpc(url, req_dict)

        async def co():
            req_dict = {
                "checkpoint_dir": save_path,
            }
            # strict=True matches the other fan-out methods and avoids
            # silently skipping endpoints when the lists are mismatched.
            rpc_cos = [
                save_rpc(ip, port, req_dict)
                for ip, port in zip(self.ep_ips, self.ep_ports, strict=True)
            ]
            return await asyncio.gather(*rpc_cos)

        if torch.distributed.get_rank() == 0:
            ret = asyncio.run(co())
            print_with_rank_and_datetime(f"save infer_engine ckpt {ret=}")

        torch.distributed.barrier()

                              
    def test_generate(self, ):
        """Smoke-test ``/test_generate`` on every endpoint with three fixed prompts.

        Only global rank 0 sends the requests, one endpoint at a time, and
        prints each response; every rank then meets at
        ``torch.distributed.barrier()``.

        Raises:
            ValueError: if ``ep_ips`` and ``ep_ports`` differ in length.
        """

        async def co():
            # The request is identical for every endpoint, so build it once.
            prompts = [
                "I am good at", "The capital of France is", "Is the statement '3 - 4 > 0' true?"
            ]
            req_dict = {
                'prompt': prompts,
            }
            # strict=True matches the other fan-out methods and avoids
            # silently skipping endpoints when the lists are mismatched.
            for ep_ip, ep_port in zip(self.ep_ips, self.ep_ports, strict=True):
                url = f'http://{ep_ip}:{ep_port}/test_generate'
                resp = await call_once_rpc(url, req_dict)
                print(f"test_generate: {ep_ip=}, {ep_port=}, {resp=}")

        if torch.distributed.get_rank() == 0:
            asyncio.run(co())
        torch.distributed.barrier()

                                                  
    def test_sglang_sleep_wakeup_generate(self, ):
        """Stress-test generate / sleep / wake_up cycles on the sampler endpoints.

        On global rank 0, ten trials are run. Each trial sends 16 prompts of
        512 randomly chosen words to ``/test_sglang_sleep_wakeup_generate`` on
        every endpoint, then calls ``self.sleep()`` and ``self.wake_up()``.
        Every rank meets at ``cpu_barrier()`` at the end.

        Raises:
            ValueError: if ``ep_ips`` and ``ep_ports`` differ in length.
        """
        # `random` is not imported at module level in this file; import it here
        # so the prompt generation below does not fail with NameError.
        import random

        async def test_sglang_rpc(ep_ip, ep_port, req_dict):
            url = f'http://{ep_ip}:{ep_port}/test_sglang_sleep_wakeup_generate'
            return await call_once_rpc(url, req_dict)

        async def co(trial):
            # Source text for the random vocabulary. Only the set of
            # whitespace-separated words matters, so the text is kept as
            # adjacent literals joined by single spaces.
            some_sentences = (
                'InstructRetro (Wang et al., 2023b) further scales up the size of Retro to 48B, '
                'featuring the largest LLM pretrained with retrieval (as of December 2023). '
                'The obtained foundation model, Retro 48B, largely outperforms the GPT '
                'counterpart in terms of perplexity. With instruction tuning on Retro, '
                'InstructRetro demonstrates significant improvement over the instruction tuned '
                'GPT on downstream tasks in the zero-shot setting. Specifically, the average '
                'improvement of InstructRetro is 7% over its GPT counterpart across 8 '
                'short-form QA tasks, and 10% over GPT across 4 challenging long-form QA tasks. '
                'We also find that one can ablate the encoder from InstructRetro architecture '
                'and directly use the InstructRetro decoder backbone as GPT, while achieving '
                'comparable results. Megatron-Core offers core building blocks such as '
                'attention mechanisms, transformer blocks and layers, normalization layers, and '
                'embedding techniques. Additional functionality like activation recomputation, '
                'distributed checkpointing is also natively built-in to the library. The '
                'building blocks and functionality are all GPU optimized, and can be built with '
                'advanced parallelization strategies for optimal training speed and stability '
                'on NVIDIA Accelerated Computing Infrastructure. Another key component of the '
                'Megatron-Core library includes advanced model parallelism techniques (tensor, '
                'sequence, pipeline, context, and MoE expert parallelism). Our codebase is '
                'capable of efficiently training large language models (i.e., models with '
                'hundreds of billions of parameters) with both model and data parallelism. To '
                'demonstrate how our software scales with multiple GPUs and model sizes, we '
                'consider GPT models ranging from 2 billion parameters to 462 billion '
                'parameters. All models use a vocabulary size of 131,072 and a sequence length '
                'of 4096. We vary hidden size, number of attention heads, and number of layers '
                'to arrive at a specific model size. '
                'As the model size increases, we also modestly increase batch size. Our '
                'experiments use up to 6144 xxx GPUs. We perform fine-grained overlapping of '
                'data-parallel (--overlap-grad-reduce --overlap-param-gather), tensor-parallel '
                '(--tp-comm-overlap) and pipeline-parallel communication (enabled by default) '
                'with computation to improve scalability. The reported throughputs are measured '
                'for end-to-end training and include all operations including data loading, '
                'optimizer steps, communication, and even logging. Note that we did not train '
                'these models to convergence.'
            )
            vocab = list(set(some_sentences.split()))
            # 16 prompts, each made of 512 words drawn at random from the vocabulary.
            prompts = [' '.join([random.choice(vocab) for j in range(512)]) for i in range(16)]

            req_dict = {
                "trial": trial,
                "prompt": prompts,
            }
            # strict=True matches the other fan-out methods and avoids
            # silently skipping endpoints when the lists are mismatched.
            rpc_cos = [
                test_sglang_rpc(ep_ip, ep_port, req_dict)
                for ep_ip, ep_port in zip(self.ep_ips, self.ep_ports, strict=True)
            ]
            return await asyncio.gather(*rpc_cos)

        if torch.distributed.get_rank() == 0:
            for i in range(10):
                asyncio.run(co(trial=f'trial {i}'))
                self.sleep()
                self.wake_up()

        cpu_barrier()
