                                                      
                                                                 

from contextlib import asynccontextmanager
from datetime import datetime
import asyncio
import io
import os
import time
import queue
import uuid
import inspect

import fastapi
import torch
import uvicorn

try:
    from sglang.srt.utils import MultiprocessingSerializer
    from sglang.srt.model_executor.model_runner import LocalSerializedTensor
    from gpatch.core.sampler_v3.sglang import (
        GcoreUpdateWeightsFromTensorReqInput,
        InitWeightsUpdateGroupReqInput,
    )
except ImportError:
    pass

from megatron.training.initialize import initialize_megatron
from megatron.core import parallel_state
from megatron.core import mpu
from megatron.core.utils import divide
from megatron.training.global_vars import (
    get_args,
    get_timers,
    get_tokenizer,
)
from megatron.training.utils import print_rank_0

from gpatch.rpc import once_rpc
from gpatch.training.arguments import validate_rl_args
from gpatch.training.global_vars import set_global_variables
from gpatch.training.utils import (
    print_with_rank_and_datetime,
    pack_name_lst_to_str,
    find_process_using_port,
)
from gpatch.core.utils import  print_memory_tracking
from gpatch.core.utils import get_nvml_memory_info
from gpatch.core.parallel_state import (
    init_pg,
    is_mp_and_cp_head,
    get_mp_and_cp_size,
    cpu_barrier,
    is_mp_head,
)
from gpatch.core.aligner_helper import (
    clear_memory,
)
from gpatch.core.sampler_v3.vllm import (
    init_model_update_group,
    setup_head_and_src_rank,
    gcore_save_vllm_checkpoint,
)
from gpatch.core.parallel_state import get_model_parallel_group_gloo, get_model_parallel_src_rank_gloo
from gpatch.core.wecube import report_ppo_metrics, init_wecube_reporter
from gpatch.rpc.monitor import start_monitor_client_in_background
from megatron_datasets.args import parse_dataset_config


def run_grpo_sampler_server(sampler, model_provider, gen_func):
    """Build the sampler's FastAPI control app and serve it with uvicorn (blocking).

    The app is bound to the ip / port configured in ``args.ppo_sampler_ips`` /
    ``args.ppo_sampler_ports`` for this process's data-parallel rank. Every
    route is a POST handler wrapped with ``once_rpc`` (project decorator,
    configured here with the monitor settings from ``args``).

    Several handlers broadcast a ``{'cmd': ...}`` object over the
    model-parallel gloo group; ``run_grpo_sampler_worker`` in this file is the
    receiving side of those broadcasts.

    Parameters
    ----------
    sampler :
        Object holding ``engine``, the per-(pp, ep) weight queues and the
        coroutines used below (``GrpoSamplerV3`` in this file).
    model_provider : Callable
        Zero-argument callable returning the inference engine. It is called by
        ``/setup`` while ``sampler.engine`` is still ``None``.
    gen_func : Callable
        Forwarded unchanged to ``sampler.generate_rollouts``.

    Returns
    -------
    None. The call returns only when ``uvicorn.run`` returns.
    """
    args = get_args()
    dp_rank = mpu.get_data_parallel_rank()
    ep_ip = args.ppo_sampler_ips[dp_rank]
    ep_port = args.ppo_sampler_ports[dp_rank]

    monitor_kwargs = {
        "do_monitor": args.do_monitor,
        "monitor_server_ip": args.monitor_server_ip,
        "monitor_port": args.monitor_port,
    }

    # State shared by the handlers below. `lock` serializes engine
    # setup / sleep / wake_up and the bookkeeping of `weight_update_finished`.
    lock = asyncio.Lock()
    actor_pp_size = args.ppo_actor_pipeline_model_parallel_size
    actor_ep_size = args.ppo_actor_expert_model_parallel_size
    # weight_update_finished[pp][ep] is set once that actor rank sent "update_finish".
    weight_update_finished = [[False for _ in range(actor_ep_size)] for _ in range(actor_pp_size)]
    # (update_pp_idx, update_ep_idx) is the actor rank whose request is served next.
    update_pp_idx = 0
    update_ep_idx = 0

    @asynccontextmanager
    async def lifespan(app: fastapi.FastAPI):
        # No startup / shutdown work is needed.
        yield

    app = fastapi.FastAPI(lifespan=lifespan)

    @app.post("/heartbeat")
    @once_rpc(**monitor_kwargs)
    async def heartbeat(req_dict):
        # Liveness probe.
        return {"ret": "ok"}

    @app.post("/exit")
    @once_rpc(**monitor_kwargs)
    async def exit(req_dict):
        # Only acknowledges the request; nothing is torn down here.
        return {'ret': 'ok'}

    @app.post("/setup")
    @once_rpc(**monitor_kwargs)
    async def setup(req_dict):
        # Create the engine on first call, then ask the workers to log memory.
        assert is_mp_head()
        print(f'setup rank {torch.distributed.get_rank()}')

        async with lock:
            if sampler.engine is None:
                # Tell the workers about the setup before building the engine here.
                cmd_obj = [{'cmd': 'setup'}]
                torch.distributed.broadcast_object_list(cmd_obj,
                                                        src=get_model_parallel_src_rank_gloo(),
                                                        group=get_model_parallel_group_gloo())
                sampler.engine = model_provider()

        assert sampler.engine is not None
        print_with_rank_and_datetime("setup sampler engine")
        torch.cuda.synchronize()
        print_memory_tracking("Memory tracking: sampler after setup", verbose=True, rank=0)
        cmd_obj = [{'cmd': 'setup_log_memory'}]
        torch.distributed.broadcast_object_list(cmd_obj,
                                                src=get_model_parallel_src_rank_gloo(),
                                                group=get_model_parallel_group_gloo())
        return {"ret": 'ok'}

    @app.post("/test_generate")
    @once_rpc(**monitor_kwargs)
    async def test_generate(req_dict):
        # Debug endpoint: generation with temperature 0 / top_k 1 / fixed seed.
        prompts = req_dict.pop("prompt")
        if not isinstance(prompts, list):
            prompts = [prompts]
        res_gens = []
        for prompt in prompts:
            sampling_params = sampler.engine.get_sampling_params(temperature=0.,
                                                                 top_k=1,
                                                                 seed=123,
                                                                 n=1)

            _prompt = prompt
            if sampler.engine.infer_engine_impl == 'sglang':
                # On the sglang path the prompt is tokenized here and passed as token ids.
                _prompt = {
                    'prompt_token_ids':
                    get_tokenizer()._tokenizer(prompt, add_special_tokens=False).input_ids,
                }
            res_gens.append(
                sampler.engine.async_generate(_prompt, sampling_params, str(uuid.uuid4().hex)))

        outputs = await sampler.engine.wait_and_get_async_generate_output(res_gens)
        if sampler.engine.infer_engine_impl == 'sglang':
            # Fill in the text from the generated token ids.
            for output in outputs:
                output.outputs[0].text = get_tokenizer()._tokenizer.decode(
                    output.outputs[0].token_ids, skip_special_tokens=False)

        text_outputs = [prompt + output.outputs[0].text for prompt, output in zip(prompts, outputs)]
        output_token_ids = [output.outputs[0].token_ids for output in outputs]
        print_with_rank_and_datetime(f"test_generate {text_outputs=} {output_token_ids=}")
        return {"text": text_outputs, "token_ids": output_token_ids}

    @app.post("/test_sglang_sleep_wakeup_generate")
    @once_rpc(**monitor_kwargs)
    async def test_sglang_sleep_wakeup_generate(req_dict):
        # Debug endpoint: run one generation pass and log NVML memory before / after.
        trial = req_dict.pop("trial")
        prompts = req_dict.pop("prompt")
        if not isinstance(prompts, list):
            prompts = [prompts]

        meminfo = get_nvml_memory_info()
        print_with_rank_and_datetime(f'Before {trial}: {meminfo}')

        async def generate_and_print(trial):
            res_gens = []
            prompt_lens = []
            for prompt in prompts:
                sampling_params = sampler.engine.get_sampling_params(temperature=0.,
                                                                     top_k=1,
                                                                     max_tokens=512,
                                                                     seed=123,
                                                                     n=1)

                _prompt = prompt
                if sampler.engine.infer_engine_impl == 'sglang':
                    _prompt = {
                        'prompt_token_ids':
                        get_tokenizer()._tokenizer(prompt, add_special_tokens=False).input_ids,
                    }
                    prompt_lens.append(len(_prompt['prompt_token_ids']))
                res_gens.append(
                    sampler.engine.async_generate(_prompt, sampling_params, str(uuid.uuid4().hex)))

            outputs = await sampler.engine.wait_and_get_async_generate_output(res_gens)
            if sampler.engine.infer_engine_impl == 'sglang':
                for output in outputs:
                    output.outputs[0].text = get_tokenizer()._tokenizer.decode(
                        output.outputs[0].token_ids, skip_special_tokens=False)

            # Logged texts are truncated to 200 characters and 50 entries.
            text_outputs = [output.outputs[0].text[:200] for output in outputs]
            output_lens = [len(output.outputs[0].token_ids) for output in outputs]
            # `prompt_lens` is only filled on the sglang path and `output_lens`
            # is empty for an empty prompt list; guard both means so the
            # handler does not die with ZeroDivisionError.
            mean_prompt_len = sum(prompt_lens) / len(prompt_lens) if prompt_lens else 0.0
            mean_output_len = sum(output_lens) / len(output_lens) if output_lens else 0.0

            print_with_rank_and_datetime(
                message={
                    "trial": trial,
                    "mean_prompt_len": mean_prompt_len,
                    "mean_output_len": mean_output_len,
                    "prompt_lens": prompt_lens[:50],
                    "output_lens": output_lens[:50],
                    "prompts": [prompt[:200] for prompt in prompts[:50]],
                    "text_outputs": text_outputs[:50],
                },
                rank=0,
            )

        await generate_and_print(trial=trial)

        meminfo = get_nvml_memory_info()
        print_with_rank_and_datetime(f'After {trial}: {meminfo}')

        return {"ret": 'ok'}

    @app.post("/sleep")
    @once_rpc(**monitor_kwargs)
    async def sleep(req_dict):
        # Put the engine to sleep, then run clear_memory().
        await sampler.engine.sleep()
        clear_memory()
        return {"ret": 'ok'}

    @app.post("/wake_up")
    @once_rpc(**monitor_kwargs)
    async def wake_up(req_dict):
        # Wake the engine for the requested tags.
        tag_names = req_dict["tag_names"]
        await sampler.engine.wake_up(tags=tag_names)
        clear_memory()
        return {"ret": 'ok'}

    @app.post("/flush_cache")
    @once_rpc(**monitor_kwargs)
    async def flush_cache(req_dict):
        # Only the sglang engine is flushed; other engines are a no-op here.
        if sampler.engine.infer_engine_impl == 'sglang':
            await sampler.engine.get_engine().tokenizer_manager.flush_cache()
        return {"ret": 'ok'}

    @app.post("/generate")
    @once_rpc(**monitor_kwargs)
    async def generate(req_dict):
        # Run rollouts for one batch and return the rollout dict.
        # The reads below are unused afterwards but make a malformed request
        # fail early with KeyError, so they are kept.
        sampler_rank = torch.distributed.get_rank()
        actor_dp_rank = req_dict['actor_dp_rank']
        sampler_dp_rank = req_dict['sampler_dp_rank']
        ppo_step = req_dict['ppo_step']
        sample_idx = req_dict['sample_idx']
        sampling_repeat = req_dict['sampling_repeat']
        assert sampler_dp_rank == mpu.get_data_parallel_rank(
        ), f'{sampler_dp_rank=} {mpu.get_data_parallel_rank()=}'

        resp_dict = await sampler.generate_rollouts(gen_func, req_dict['batch'], sampling_repeat)
        assert 'ready' not in resp_dict, '`ready` is a reserved name'

        if args.ppo_wecube_report:
            report_data = {}
            if 'batch' in req_dict:
                if 'prompt_token_ids' in req_dict['batch']:
                    sampler_num_datas = len(req_dict['batch']['prompt_token_ids'])
                    report_data["sampler_num_datas"] = sampler_num_datas
                if 'tokens' in req_dict:
                    # Must be an assignment: the previous `x[k]: value` form was
                    # an annotation-only statement and never stored the metric.
                    report_data["sampler_num_samples"] = len(req_dict['tokens'])
            report_ppo_metrics(report_data)

        print_memory_tracking(f"Memory tracking: sampler generate {torch.cuda.memory_allocated() / (1024**3)} GB")
        return resp_dict

    @app.post("/mark_sampler_ppo_step_begin")
    @once_rpc(**monitor_kwargs)
    async def mark_sampler_ppo_step_begin(req_dict):
        # Start of a PPO step: notify workers, wake the engine, log memory.
        ppo_step = req_dict['ppo_step']
        actor_dp_rank = req_dict['actor_dp_rank']
        tag_names = req_dict['tag_names']

        cmd_obj = [{'cmd': 'mark_sampler_ppo_step_begin'}]
        torch.distributed.broadcast_object_list(cmd_obj,
                                                src=get_model_parallel_src_rank_gloo(),
                                                group=get_model_parallel_group_gloo())

        async with lock:
            await sampler.engine.wake_up(tags=tag_names)
        torch.cuda.synchronize()
        print_memory_tracking("Memory tracking: sampler after wake_up", verbose=True, rank=0)

        cmd_obj = [{'cmd': 'wakeup_log_memory'}]
        torch.distributed.broadcast_object_list(cmd_obj,
                                                src=get_model_parallel_src_rank_gloo(),
                                                group=get_model_parallel_group_gloo())
        return {"ret": 'ok'}

    @app.post("/mark_sampler_ppo_step_end")
    @once_rpc(**monitor_kwargs)
    async def mark_sampler_ppo_step_end(req_dict):
        # End of a PPO step: notify workers, put the engine to sleep, log memory.
        ppo_step = req_dict['ppo_step']
        actor_dp_rank = req_dict['actor_dp_rank']

        cmd_obj = [{'cmd': 'mark_sampler_ppo_step_end'}]
        torch.distributed.broadcast_object_list(cmd_obj,
                                                src=get_model_parallel_src_rank_gloo(),
                                                group=get_model_parallel_group_gloo())

        async with lock:
            await sampler.engine.sleep()
        clear_memory()
        print_memory_tracking("Memory tracking: sampler after sleep", verbose=True, rank=0)
        cmd_obj = [{'cmd': 'sleep_log_memory'}]
        torch.distributed.broadcast_object_list(cmd_obj,
                                                src=get_model_parallel_src_rank_gloo(),
                                                group=get_model_parallel_group_gloo())
        return {"ret": 'ok'}

    @app.post("/update_weights")
    @once_rpc(**monitor_kwargs)
    async def update_weights(req_dict):
        # Apply one weight chunk. Requests are served strictly in
        # (pp, ep) order: the request is queued, then this handler waits until
        # its own (pp, ep) pair is the current turn.
        nonlocal update_pp_idx, weight_update_finished, update_ep_idx, actor_pp_size, actor_ep_size

        torch.cuda.synchronize()
        t1 = time.time()

        ret = True
        actor_pp_rank = req_dict["actor_pp_rank"]
        actor_ep_rank = req_dict["actor_ep_rank"]
        name_lst = req_dict["name"]

        sampler.weights_per_actor_pp_rank[actor_pp_rank][actor_ep_rank].put(req_dict)

        # Poll until it is this actor rank's turn and its queue has an entry.
        while actor_pp_rank != update_pp_idx or actor_ep_rank != update_ep_idx \
            or sampler.weights_per_actor_pp_rank[update_pp_idx][update_ep_idx].qsize() <= 0:
            await asyncio.sleep(0.01)

        torch.cuda.synchronize()
        t2 = time.time()
        if sampler.engine.infer_engine_impl == 'vllm':
            ret = await sampler.update_weights_vllm(update_pp_idx, update_ep_idx)
        else:
            ret = await sampler.update_weights_sglang(update_pp_idx, update_ep_idx)
        # A request whose name list is exactly ["update_finish"] marks this
        # (pp, ep) pair as done for the current round.
        if len(name_lst) == 1 and name_lst[0] == "update_finish":
            async with lock:
                weight_update_finished[update_pp_idx][update_ep_idx] = True

        # Advance the turn: ep index first, wrapping into the next pp index.
        update_ep_idx += 1
        if update_ep_idx == actor_ep_size:
            update_ep_idx = update_ep_idx % actor_ep_size
            update_pp_idx = (update_pp_idx + 1) % actor_pp_size

        # Once every pair has finished, reset the bookkeeping for the next round.
        async with lock:
            if all(all(ep_updated_lst) for ep_updated_lst in weight_update_finished):
                for i in range(len(weight_update_finished)):
                    for j in range(len(weight_update_finished[0])):
                        weight_update_finished[i][j] = False
                update_pp_idx = 0
                update_ep_idx = 0

        # Skip pairs that are already finished. The reset above guarantees
        # at least one unfinished pair, so this loop terminates.
        while weight_update_finished[update_pp_idx][update_ep_idx] is True:
            update_ep_idx += 1
            if update_ep_idx == actor_ep_size:
                update_ep_idx = update_ep_idx % actor_ep_size
                update_pp_idx = (update_pp_idx + 1) % actor_pp_size

        torch.cuda.synchronize()
        t3 = time.time()
        resp_dict = {"update_success": ret}
        print_with_rank_and_datetime(
            f"update app {actor_pp_rank} aep {actor_ep_rank} using time {t3 - t1}, detail {t3 - t2} {t2 - t1}"
        )
        return resp_dict

    @app.post("/check_zeros_weight")
    @once_rpc(**monitor_kwargs)
    async def check_zeros_weight(req_dict):
        # Debug endpoint: run check_weight_all_zeros through collective_rpc.
        weights_updated = await sampler.engine.get_engine().collective_rpc(check_weight_all_zeros,
                                                                           args=())
        resp_dict = {"all_zeros": weights_updated}
        print(f"{resp_dict=}")
        return resp_dict

    @app.post("/init_weight_update_group")
    @once_rpc(**monitor_kwargs)
    async def init_weight_update_group(req_dict):
        # Delegate to the sampler's init_weight_update_group.
        print_with_rank_and_datetime("call init_weight_update_group")
        init_ret = await sampler.init_weight_update_group(req_dict)
        resp_dict = {"init_group_ret": init_ret}
        print_with_rank_and_datetime(f"{resp_dict=}")
        return resp_dict

    @app.post("/setup_head_and_src_rank")
    @once_rpc(**monitor_kwargs)
    async def setup_head_and_src_rank(req_dict):
        # Register head / source rank for one actor (pp, ep) pair. On vllm the
        # requests are served in the same (pp, ep) turn order as /update_weights
        # and share its turn counters.
        nonlocal update_pp_idx, actor_pp_size, update_ep_idx, actor_ep_size

        actor_pp_rank = req_dict["actor_pp_rank"]
        actor_ep_rank = req_dict["actor_ep_rank"]
        sampler.weights_setup_head_queue[actor_pp_rank][actor_ep_rank].put(req_dict)

        if sampler.engine.infer_engine_impl == 'vllm':
            while actor_pp_rank != update_pp_idx or actor_ep_rank != update_ep_idx \
                or sampler.weights_setup_head_queue[update_pp_idx][update_ep_idx].qsize() <= 0:
                await asyncio.sleep(0.05)

            ret, src_rank = await sampler.setup_head_and_src_rank(update_pp_idx, update_ep_idx)
            update_ep_idx += 1
            if update_ep_idx == actor_ep_size:
                update_ep_idx = update_ep_idx % actor_ep_size
                update_pp_idx = (update_pp_idx + 1) % actor_pp_size
        else:
            # Non-vllm engines: the queued request is not consumed here.
            src_rank = 0
            ret = True

        resp_dict = {"setup_ret": ret}
        print_with_rank_and_datetime(f"setup_head_and_src_rank {resp_dict=} {src_rank=}")
        return resp_dict

    @app.post("/save_infer_engine_ckpt")
    @once_rpc(**monitor_kwargs)
    async def save_infer_engine_ckpt(req_dict):
        # Save the inference engine's weights into `checkpoint_dir`.
        checkpoint_dir = req_dict["checkpoint_dir"]
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir, exist_ok=True)
        print_with_rank_and_datetime(f"{checkpoint_dir=}")
        if sampler.engine.infer_engine_impl == 'vllm':
            ret = await sampler.engine.get_engine().collective_rpc(gcore_save_vllm_checkpoint,
                                                                   args=(checkpoint_dir, ))
        else:
            ret = sampler.engine.get_engine().save_sharded_model(
                path=checkpoint_dir, pattern=None, max_size=(16 * 1024**3),
            )
        resp_dict = {"save_ckpt": ret}
        print_with_rank_and_datetime(f"save_infer_engine_ckpt {ret=}")
        return resp_dict

    def serve_forever_fn():
        uvicorn.run(app,
                    host=ep_ip,
                    port=ep_port,
                    log_level='error',
                    use_colors=False,
                    timeout_keep_alive=args.ppo_sampler_server_timeout_keep_alive,
                    ssl_keyfile=None,
                    ssl_certfile=None,
                    ssl_ca_certs=None,
                    ssl_cert_reqs=None)

    find_process_using_port(ep_ip, ep_port)
    print_with_rank_and_datetime(f'run_grpo_sampler_server http://{ep_ip}:{ep_port}')
    serve_forever_fn()


def run_grpo_sampler_worker(sampler, model_provider, gen_func):
    """Loop forever on commands broadcast over the model-parallel gloo group.

    Each iteration receives one ``{'cmd': ...}`` object via
    ``broadcast_object_list`` and reacts to it:

    * ``'setup'`` builds the engine through ``model_provider`` only when the
      engine implementation is sglang and this rank's tensor-parallel rank is
      a multiple of ``args.num_gpus_per_node``;
    * the three ``*_log_memory`` commands print a memory-tracking line;
    * any other command is ignored.

    ``gen_func`` is accepted for signature parity with the server entry point
    and is not used here. The function never returns.
    """
    args = get_args()

    # Command name -> memory-tracking message to print.
    memory_messages = (
        ('wakeup_log_memory', "Memory tracking: sampler after wake_up"),
        ('sleep_log_memory', "Memory tracking: sampler after sleep"),
        ('setup_log_memory', "Memory tracking: sampler after setup"),
    )

    while True:
        received = [None]
        torch.distributed.broadcast_object_list(received,
                                                src=get_model_parallel_src_rank_gloo(),
                                                group=get_model_parallel_group_gloo())
        command = received[0]['cmd']

        if command == 'setup':
            builds_engine = (args.infer_engine_impl == 'sglang'
                             and mpu.get_tensor_model_parallel_rank() % args.num_gpus_per_node == 0)
            if builds_engine:
                sampler.engine = model_provider()
            continue

        message = next((text for name, text in memory_messages if command == name), None)
        if message is not None:
            print_memory_tracking(message)


def update_engine_weight(worker_wrap, name_lst, key_size, key_numel, actor_pp_rank, actor_ep_rank,
                         ipc_handle, flatten_shape, flatten_dtype, sampler_dp_rank):
    """Load one flattened weight buffer into the worker's model.

    The head worker for ``(actor_pp_rank, actor_ep_rank)`` rebuilds the tensor
    from ``ipc_handle``; every other worker allocates an empty buffer of
    ``flatten_shape`` / ``flatten_dtype``. The buffer is then broadcast from
    the stored source rank, split per name using ``key_numel`` / ``key_size``
    and passed to ``model.load_weights``.

    Progress and errors are appended to a per-rank file under ``./debug-tmp``.
    Exceptions raised while updating are logged there and swallowed.

    Returns
    -------
    bool
        ``True`` when the weights were loaded, ``False`` when anything raised
        (including the head attributes not having been set on ``worker_wrap``).
    """
    debug_dir = "./debug-tmp"
    os.makedirs(debug_dir, exist_ok=True)
    log_path = f'{debug_dir}/update_sampler_dp{sampler_dp_rank}_actor_pp_{actor_pp_rank}_ep_{actor_ep_rank}.txt'

    succeeded = False
    with open(log_path, 'a') as outf:

        def log(text):
            # Flush after every record so the file is useful even after a crash.
            outf.write(text)
            outf.flush()

        log(f"update_engine_weight {datetime.now()}\n")
        try:
            # All per-pair attributes on the worker share this suffix.
            pair = f'{actor_pp_rank}_{actor_ep_rank}'
            if not hasattr(worker_wrap, f'_gcore_mp_head_{pair}'):
                log(f"update_engine_weight: There is no mo_head. You must set head early. {datetime.now()=}\n")
                raise RuntimeError(
                    f"update_engine_weight must set head early. {actor_pp_rank=} {actor_ep_rank=}")

            is_head = getattr(worker_wrap, f'_gcore_mp_head_{pair}')
            head_device_uuid = getattr(worker_wrap, f'_gcore_mp_head_device_uuid_{pair}')
            src_rank = getattr(worker_wrap, f'_gcore_src_rank_for_{pair}')
            current_rank = worker_wrap._gcore_rank

            torch.cuda.synchronize()
            started = time.time()

            if is_head is True:
                assert head_device_uuid is not None
                handle = ipc_handle[head_device_uuid]
                local_device_index = worker_wrap.device.index
                rebuild_fn, rebuild_args = handle
                rebuild_args = list(rebuild_args)
                # Slot 6 of the rebuild arguments is overwritten with this
                # worker's device index (presumably the device slot of the
                # CUDA rebuild function — confirm against the sender).
                rebuild_args[6] = local_device_index
                weight = rebuild_fn(*rebuild_args).to(device=worker_wrap.device)
            else:
                weight = torch.empty(flatten_shape, dtype=flatten_dtype, device=worker_wrap.device)

            worker_wrap._gcore_model_update_group.broadcast(weight,
                                                            src=src_rank,
                                                            stream=torch.cuda.current_stream())
            model = worker_wrap.worker.model_runner.model

            # Carve the flat buffer into per-parameter views, in name order.
            named_views = []
            cursor = 0
            for name in name_lst:
                shape, count = key_size[name], key_numel[name]
                named_views.append((name, weight.narrow(0, cursor, count).view(shape)))
                cursor += count

            model.load_weights(weights=named_views)

            torch.cuda.synchronize()
            elapsed = time.time() - started
            log(f"update weight time {elapsed} {src_rank=} {current_rank=} datetime {datetime.now()} num of weights {len(named_views)}\n")
            succeeded = True
        except Exception as exc:
            log(f"exception {exc} {datetime.now()=}\n")
            succeeded = False
    return succeeded


def check_weight_all_zeros(worker_wrap):
    """Report whether every parameter of the worker's model is exactly zero.

    Each parameter's sum and full value are appended to
    ``./debug-tmp/zero_check.txt``. Any exception is logged to that file and
    turned into a ``False`` result.

    Returns
    -------
    bool
        ``True`` only when all parameters compare equal to zeros (also for a
        model without parameters); ``False`` otherwise or on error.
    """
    debug_dir = "./debug-tmp"
    os.makedirs(debug_dir, exist_ok=True)

    result = False
    with open(f'{debug_dir}/zero_check.txt', 'a') as outf:
        outf.write("enter zeros check\n")
        try:
            all_zero = True
            for _name, param in worker_wrap.worker.model_runner.model.named_parameters():
                # Stop comparing after the first non-zero parameter, but keep
                # dumping every parameter to the log.
                if all_zero and not torch.equal(param, torch.zeros_like(param)):
                    all_zero = False
                outf.write(f"weight {param.sum()} {param}\n")
            result = all_zero
            outf.write(f"check_zeros_weight {result}\n")
            outf.flush()
        except Exception as exc:
            outf.write(f"exception {exc} {datetime.now()=}\n")
            outf.flush()
            result = False
    return result


class GrpoSamplerV3:
    """Sampler-side state plus the weight-sync helpers for a vllm / sglang engine.

    ``engine`` starts as ``None`` and is assigned from outside the class.
    ``post_init`` must run (after the global args exist) before any method that
    touches the per-(pp, ep) queues.
    """

    def __init__(self):
        # Cached `inspect.signature(gen_func).parameters`, filled on first rollout.
        self.gen_func_params = None
        self.engine = None

    def post_init(self):
        """Read rendezvous settings from the global args and create the request queues."""
        args = get_args()
        self.sampler_master_address = args.ppo_sampler_ips[0]
        self.sampler_master_port = args.ppo_sampler_ports[0] + args.world_size
        self.sampler_world_size = args.world_size

        self.actor_pp_size = args.ppo_actor_pipeline_model_parallel_size
        self.actor_ep_size = args.ppo_actor_expert_model_parallel_size

        def new_queue_grid():
            # One SimpleQueue per actor (pp, ep) pair, indexed [pp][ep].
            grid = []
            for _ in range(self.actor_pp_size):
                grid.append([queue.SimpleQueue() for _ in range(self.actor_ep_size)])
            return grid

        self.weights_per_actor_pp_rank = new_queue_grid()
        self.weights_setup_head_queue = new_queue_grid()

        print_with_rank_and_datetime(
            f"post_init {self.sampler_master_address=} {self.sampler_master_port=} "
            f"{self.sampler_world_size=} {self.actor_pp_size=}")

    @torch.no_grad()
    def generate_rollouts(self, gen_func, batch, sampling_repeat):
        """Call ``gen_func(self.engine, batch[, sampling_repeat=...])`` and return its result.

        ``sampling_repeat`` is forwarded only when ``gen_func`` declares a
        parameter of that name; the signature is inspected once and cached.
        NOTE(review): the server handler awaits this return value, so
        ``gen_func`` is presumably a coroutine function — confirm.
        """
        args = get_args()
        timers = get_timers()
        assert is_mp_head()

        if self.gen_func_params is None:
            self.gen_func_params = inspect.signature(gen_func).parameters

        optional_kwargs = {}
        if 'sampling_repeat' in self.gen_func_params:
            optional_kwargs['sampling_repeat'] = sampling_repeat
        return gen_func(self.engine, batch, **optional_kwargs)

    async def init_weight_update_group(self, req_dict):
        """Create the weight-update process group inside the inference engine.

        Returns ``all(resp)`` for vllm and ``resp[0]`` for sglang; raises
        ``ValueError`` for any other engine implementation. ``req_dict`` is
        not read.
        """
        print_with_rank_and_datetime(
            f"call init weight update group dp_rank {mpu.get_data_parallel_rank()}")

        impl = self.engine.infer_engine_impl
        if impl == "vllm":
            resp = await self.engine.get_engine().collective_rpc(
                init_model_update_group,
                args=(
                    self.sampler_master_address,
                    self.sampler_master_port,
                    mpu.get_data_parallel_rank(),
                    self.sampler_world_size,
                ),
            )
            return all(resp)

        if impl == "sglang":
            tp_world = mpu.get_tensor_model_parallel_world_size()
            pp_world = mpu.get_pipeline_model_parallel_world_size()
            mp_size = tp_world * pp_world
            obj = InitWeightsUpdateGroupReqInput(
                master_address=self.sampler_master_address,
                # The sglang port is shifted by the world size relative to the
                # one used on the vllm path.
                master_port=self.sampler_master_port + self.sampler_world_size,
                rank_offset=(mpu.get_data_parallel_rank() * mp_size),
                world_size=self.sampler_world_size,
                group_name="sglang_weight_init_group",
                backend="nccl",
            )
            tokenizer_manager = self.engine.get_engine().tokenizer_manager
            resp = await tokenizer_manager.init_weights_update_group(obj, None)
            print_with_rank_and_datetime(f"init weight update {resp=}")
            return resp[0]

        raise ValueError(f"value error: {impl}")

    @staticmethod
    def _assert_ranks_match(data, actor_pp_rank, actor_ep_rank):
        # The dequeued request must belong to the (pp, ep) pair being served.
        assert actor_pp_rank == data[
            "actor_pp_rank"], f"{actor_pp_rank=} missmatch {data['actor_pp_rank']=}"
        assert actor_ep_rank == data[
            "actor_ep_rank"], f"{actor_ep_rank=} missmatch {data['actor_ep_rank']=}"

    @staticmethod
    def _report_update_perf(resp, ret, name_lst, actor_pp_rank, actor_ep_rank, elapsed):
        # Timing line, printed by global rank 0 only.
        if torch.distributed.get_rank() == 0:
            print(f"update perf {resp=} {ret=} {name_lst[:2]=} {actor_pp_rank=} "
                  f"{actor_ep_rank=} using_time {elapsed}")

    async def setup_head_and_src_rank(self, actor_pp_rank, actor_ep_rank):
        """Dequeue one setup request for the pair and apply it via ``collective_rpc``.

        Returns ``(ret, src_rank)``: ``ret`` is ``False`` if any response's
        first element is falsy; ``src_rank`` is the second element of the last
        response (``None`` when there are no responses).
        """
        pending = self.weights_setup_head_queue[actor_pp_rank][actor_ep_rank]
        print_with_rank_and_datetime(
            f"current {actor_pp_rank=} {actor_ep_rank=} "
            f"{pending.qsize()}"
        )
        data = pending.get()
        head_device_uuid = data["head_device_uuid"]
        self._assert_ranks_match(data, actor_pp_rank, actor_ep_rank)

        resps = await self.engine.get_engine().collective_rpc(
            setup_head_and_src_rank,
            args=(actor_pp_rank, actor_ep_rank, head_device_uuid, mpu.get_data_parallel_rank()),
        )

        all_ok = True
        src_rank = None
        for resp in resps:
            all_ok = all_ok and bool(resp[0])
            src_rank = resp[1]
        return all_ok, src_rank

    async def update_weights_vllm(self, actor_pp_rank, actor_ep_rank):
        """Dequeue one weight request for the pair and apply it on every vllm worker.

        Requests whose first name is ``"update_finish"`` or ``"moe_finish"``
        carry no weights and succeed without calling the engine.
        Returns ``True`` in that case, otherwise ``all(resp)``.
        """
        data = self.weights_per_actor_pp_rank[actor_pp_rank][actor_ep_rank].get()
        torch.cuda.synchronize()
        started = time.time()
        name_lst, key_size, key_numel, ipc_handle, flatten_shape, flatten_dtype = (
            data[field] for field in ("name", "key_size", "key_numel", "ipc_handle",
                                      "flatten_shape", "flatten_dtype"))

        self._assert_ranks_match(data, actor_pp_rank, actor_ep_rank)
        sampler_dp_rank = mpu.get_data_parallel_rank()

        resp = None
        is_marker = name_lst[0] in ("update_finish", "moe_finish")
        if not is_marker:
            resp = await self.engine.get_engine().collective_rpc(
                update_engine_weight,
                args=(name_lst, key_size, key_numel, actor_pp_rank, actor_ep_rank, ipc_handle,
                      flatten_shape, flatten_dtype, sampler_dp_rank))

        ret = True if resp is None else all(resp)
        torch.cuda.synchronize()
        elapsed = time.time() - started
        self._report_update_perf(resp, ret, name_lst, actor_pp_rank, actor_ep_rank, elapsed)
        return ret

    async def update_weights_sglang(self, actor_pp_rank, actor_ep_rank):
        """Dequeue one weight request for the pair and push it through sglang.

        Must run on the rank with tensor-parallel rank 0 and pipeline rank 0.
        Marker requests (``"update_finish"`` / ``"moe_finish"``) succeed without
        calling the engine. Returns ``True`` in that case, otherwise ``resp[0]``.
        """
        data = self.weights_per_actor_pp_rank[actor_pp_rank][actor_ep_rank].get()
        torch.cuda.synchronize()
        started = time.time()
        (name_lst, key_size, key_numel, ipc_handle, flatten_shape, flatten_dtype,
         src_head_device_uuid) = (data[field]
                                  for field in ("name", "key_size", "key_numel", "ipc_handle",
                                                "flatten_shape", "flatten_dtype",
                                                "src_head_device_uuid"))

        self._assert_ranks_match(data, actor_pp_rank, actor_ep_rank)
        sampler_dp_rank = mpu.get_data_parallel_rank()
        sampler_tp_size = mpu.get_tensor_model_parallel_world_size()

        assert mpu.get_tensor_model_parallel_rank() == 0 and \
            mpu.get_pipeline_model_parallel_rank() == 0

        resp = None
        is_marker = name_lst[0] in ("update_finish", "moe_finish")
        if not is_marker:
            packed_name = pack_name_lst_to_str(name_lst)
            # The same ipc handle is listed once per tensor-parallel rank.
            named_tensors = [
                (packed_name, LocalSerializedTensor(values=[ipc_handle] * sampler_tp_size)),
            ]
            obj = GcoreUpdateWeightsFromTensorReqInput(
                serialized_named_tensors=[
                    MultiprocessingSerializer.serialize(named_tensors)
                    for _ in range(sampler_tp_size)
                ],
                load_format=None,
                flush_cache=True,
                head_device_uuid=src_head_device_uuid,
                flatten_weight_shape=flatten_shape,
                flatten_weight_dtype=flatten_dtype,
                key_size=key_size,
                key_numel=key_numel,
                name_list=name_lst,
            )
            tokenizer_manager = self.engine.get_engine().tokenizer_manager
            resp = await tokenizer_manager.update_weights_from_tensor(obj, None)

            print_with_rank_and_datetime(f"update weight {resp=}")

        ret = True if resp is None else resp[0]
        torch.cuda.synchronize()
        elapsed = time.time() - started
        self._report_update_perf(resp, ret, name_lst, actor_pp_rank, actor_ep_rank, elapsed)
        return ret


                                                             
def run_grpo_sampler_v3(
    grpo_sampler,
    model_provider,
    gen_func,
    extra_args_provider=None,
):
    '''
    Initialize Megatron for the sampler role and serve GRPO / PPO sampling
    (vllm / sglang). Model-parallel head ranks run the HTTP server; all other
    ranks run the command-following worker loop.

    Parameters
    ----------

    grpo_sampler : gpatch.training.v3.grpo_sampler.GrpoSamplerV3
        Sampler object; its ``post_init`` is called here.

    model_provider : Callable
        A callable to provide vllm / sglang

    gen_func : Callable
        A callable to call vllm / sglang for generation (given prompts from dataset).

    extra_args_provider : Callable, optional
        Forwarded to ``initialize_megatron``.

    Returns
    -------
    None. Both the server and the worker loop block.
    '''
    initialize_megatron(extra_args_provider=extra_args_provider)
    args = get_args()
    validate_rl_args(args)
    args.rl_role = 'sampler'
    parse_dataset_config(args)
    set_global_variables(args)
    init_pg(distributed_timeout_minutes=args.distributed_timeout_minutes)

    # Fetch the args again now that the global variables have been set.
    args = get_args()
    if args.ppo_wecube_report:
        init_wecube_reporter()

    grpo_sampler.post_init()

    # Configuration required by the sampler.
    assert args.expert_model_parallel_size == 1
    assert args.context_parallel_size == 1
    assert args.no_fused_kernel
    assert args.use_tp_pp_dp_mapping

    if args.infer_engine_impl == 'sglang':
        # Imported lazily so that sglang is only needed when it is selected.
        from gpatch.core.sampler_v3.sglang import sglang_hack
        sglang_hack()

    cpu_barrier()
    allocated_gb = torch.cuda.memory_allocated() / (1024**3)
    print_with_rank_and_datetime(f"Init: memory trace {allocated_gb} GB")

    entry_point = run_grpo_sampler_server if is_mp_head() else run_grpo_sampler_worker
    entry_point(grpo_sampler, model_provider, gen_func)
