                                                      
                                                                 

from contextlib import asynccontextmanager
import asyncio
import uuid
import os
from typing import Dict, List, Union, Any, Optional

import fastapi
import torch
import uvicorn

from megatron.training.initialize import initialize_megatron
from megatron.core import mpu
from megatron.training.global_vars import (
    get_args,
    get_timers,
    get_tokenizer,
)

from gpatch.rpc import once_rpc
from gpatch.training import get_actor_tokenizer
from gpatch.training.arguments import validate_rl_args
from gpatch.training.global_vars import set_global_variables
from gpatch.training.utils import print_with_rank_and_datetime, find_process_using_port
from gpatch.core.utils import get_nvml_memory_info, print_memory_tracking
from gpatch.core.parallel_state import (
    init_pg,
    cpu_barrier,
    is_mp_head,
)
from gpatch.core.aligner_helper import (
    clear_memory,
)
from gpatch.core.parallel_state import (get_model_parallel_group_gloo,
                                        get_model_parallel_src_rank_gloo)
from megatron_datasets.args import parse_dataset_config
from gpatch.rpc.monitor import start_monitor_client_in_background

def run_grpo_gen_rm_server(trainer, model_provider, gen_rm_func):
    '''
    Build the gen-rm FastAPI application and serve it with uvicorn (blocking).

    The listening address is picked from ``args.ppo_gen_rm_ips`` /
    ``args.ppo_gen_rm_ports`` using this process' data-parallel rank. Several
    handlers broadcast a ``{'cmd': ...}`` object over the model-parallel gloo
    group before or after doing their own work; ``run_grpo_gen_rm_worker``
    is the loop that receives those objects on the other ranks.

    Parameters
    ----------

    trainer : GrpoGenRmTrainerV3
        Holder of the inference engine (``trainer.engine``), created lazily
        by the ``/setup`` endpoint.

    model_provider : Callable
        Zero-argument callable returning the inference engine.

    gen_rm_func : Callable
        Forwarded to ``trainer.generate_rewards`` together with the request
        batch.

    Returns
    -------
    None. The call only returns when ``uvicorn.run`` returns.
    '''
    args = get_args()
    dp_rank = mpu.get_data_parallel_rank()
    ep_ip = args.ppo_gen_rm_ips[dp_rank]
    ep_port = args.ppo_gen_rm_ports[dp_rank]

    monitor_kwargs = {
        "do_monitor": args.do_monitor,
        "monitor_server_ip": args.monitor_server_ip,
        "monitor_port": args.monitor_port,
    }

    # Serialises engine creation and the sleep / wake_up transitions issued by
    # the mark_gen_rm_ppo_step_* endpoints.
    # NOTE(review): on Python < 3.10 an asyncio.Lock binds to the event loop
    # that is current at construction time, while uvicorn runs the app on its
    # own loop -- confirm the deployment runs on Python >= 3.10.
    lock = asyncio.Lock()

    def broadcast_cmd(cmd):
        # Send one command object from the model-parallel source rank to the
        # rest of the model-parallel gloo group.
        cmd_obj = [{'cmd': cmd}]
        torch.distributed.broadcast_object_list(cmd_obj,
                                                src=get_model_parallel_src_rank_gloo(),
                                                group=get_model_parallel_group_gloo())

    @asynccontextmanager
    async def lifespan(app: fastapi.FastAPI):
        # No startup / shutdown work at the moment.
        yield

    app = fastapi.FastAPI(lifespan=lifespan)

    @app.post("/heartbeat")
    @once_rpc(**monitor_kwargs)
    async def heartbeat(req_dict):
        return {"ret": "ok"}

    @app.post("/exit")
    @once_rpc(**monitor_kwargs)
    async def exit(req_dict):
        # Only acknowledges the request; nothing is torn down here.
        return {'ret': 'ok'}

    @app.post("/setup")
    @once_rpc(**monitor_kwargs)
    async def setup(req_dict):
        assert is_mp_head()
        print(f'setup rank {torch.distributed.get_rank()}')

        async with lock:
            # Create the engine at most once, even with concurrent /setup calls.
            if trainer.engine is None:
                broadcast_cmd('setup')
                trainer.engine = model_provider()

        assert trainer.engine is not None
        print_with_rank_and_datetime("setup gen rm engine")

        print_memory_tracking("Memory tracking: gen_rm after setup", verbose=True, rank=0)
        # Sent on every /setup call, including the ones that found an engine.
        broadcast_cmd('setup_log_memory')
        return {"ret": 'ok'}

    @app.post("/test_generate")
    @once_rpc(**monitor_kwargs)
    async def test_generate(req_dict):
        prompts = req_dict.pop("prompt")
        if not isinstance(prompts, list):
            prompts = [prompts]

        # Accumulate the text of every prompt. Rebinding the list inside the
        # loop would keep only the last prompt's output and leave the name
        # unbound for an empty prompt list.
        text_output = []
        for prompt in prompts:
            sampling_params = trainer.engine.get_sampling_params(temperature=1.,
                                                                 top_k=10,
                                                                 seed=123,
                                                                 n=1)

            _prompt = prompt
            if trainer.engine.infer_engine_impl == 'sglang':
                # The sglang path is fed token ids instead of raw text.
                _prompt = {
                    'prompt_token_ids':
                    get_tokenizer()._tokenizer(prompt, add_special_tokens=False).input_ids,
                }
            res_gen = trainer.engine.async_generate(_prompt, sampling_params, str(uuid.uuid4().hex))

            outputs = await trainer.engine.wait_and_get_async_generate_output([res_gen])

            if trainer.engine.infer_engine_impl == 'sglang':
                # Fill in the text from the generated token ids.
                for output in outputs:
                    output.outputs[0].text = get_tokenizer()._tokenizer.decode(
                        output.outputs[0].token_ids, skip_special_tokens=False)
            text_output.extend(output.outputs[0].text for output in outputs)

        return {"text": text_output}

    @app.post("/sleep")
    @once_rpc(**monitor_kwargs)
    async def sleep(req_dict):
        await trainer.engine.sleep()
        return {"ret": 'ok'}

    @app.post("/wake_up")
    @once_rpc(**monitor_kwargs)
    async def wake_up(req_dict):
        tag_names = req_dict["tag_names"]
        await trainer.engine.wake_up(tags=tag_names)
        return {"ret": 'ok'}

    @app.post('/update_gen_rm_weight_by_model_idx')
    @once_rpc(**monitor_kwargs)
    async def update_gen_rm_weight_by_model_idx(req_dict):
        assert 'rm_idx' in req_dict
        await trainer.engine.update_gen_rm_weight_by_model_idx(req_dict['rm_idx'])
        return {"ret": 'ok'}

    @app.post("/generate_rewards")
    @once_rpc(**monitor_kwargs)
    async def generate_rewards(req_dict):
        # Apart from gen_rm_dp_rank these values are not used below; the
        # lookups are kept so a request lacking one of the keys still fails
        # before any generation work starts.
        actor_dp_rank = req_dict['actor_dp_rank']
        gen_rm_dp_rank = req_dict['gen_rm_dp_rank']
        ppo_step = req_dict['ppo_step']
        sample_idx = req_dict['sample_idx']
        prompt_idx = req_dict['prompt_idx']
        # The caller must have addressed the server of this data-parallel rank.
        assert gen_rm_dp_rank == mpu.get_data_parallel_rank(), \
            f'{gen_rm_dp_rank=} {mpu.get_data_parallel_rank()=}'

        resp_dict = await trainer.generate_rewards(gen_rm_func, req_dict['batch'])

        torch.cuda.synchronize()
        print_memory_tracking("Memory tracking: gen-rm generate", rank=0)
        return resp_dict

    @app.post("/flush_cache")
    @once_rpc(**monitor_kwargs)
    async def flush_cache(req_dict):
        # Only the sglang engine is flushed; other engines are a no-op here.
        if trainer.engine.infer_engine_impl == 'sglang':
            await trainer.engine.get_engine().tokenizer_manager.flush_cache()
        return {"ret": 'ok'}

    @app.post("/mark_gen_rm_ppo_step_begin")
    @once_rpc(**monitor_kwargs)
    async def mark_gen_rm_ppo_step_begin(req_dict):
        # ppo_step / actor_dp_rank are looked up only to require their presence.
        ppo_step = req_dict['ppo_step']
        actor_dp_rank = req_dict['actor_dp_rank']
        tag_names = req_dict['tag_names']

        broadcast_cmd('mark_gen_rm_ppo_step_begin')

        async with lock:
            await trainer.engine.wake_up(tags=tag_names)
        torch.cuda.synchronize()
        print_memory_tracking("Memory tracking: gen-rm after wake_up", verbose=True, rank=0)
        broadcast_cmd('wakeup_log_memory')
        return {"ret": 'ok'}

    @app.post("/mark_gen_rm_ppo_step_end")
    @once_rpc(**monitor_kwargs)
    async def mark_gen_rm_ppo_step_end(req_dict):
        # Looked up only to require their presence in the request.
        ppo_step = req_dict['ppo_step']
        actor_dp_rank = req_dict['actor_dp_rank']

        broadcast_cmd('mark_gen_rm_ppo_step_end')

        async with lock:
            await trainer.engine.sleep()
        clear_memory()
        print_memory_tracking("Memory tracking: gen-rm after sleep", verbose=True, rank=0)
        broadcast_cmd('sleep_log_memory')
        return {"ret": 'ok'}

    def serve_forever_fn():
        uvicorn.run(app,
                    host=ep_ip,
                    port=ep_port,
                    log_level='error',
                    use_colors=False,
                    timeout_keep_alive=args.ppo_sampler_server_timeout_keep_alive,
                    ssl_keyfile=None,
                    ssl_certfile=None,
                    ssl_ca_certs=None,
                    ssl_cert_reqs=None)

    # NOTE(review): presumably reports a process already bound to this
    # address -- confirm against gpatch.training.utils.
    find_process_using_port(ep_ip, ep_port)
    print_with_rank_and_datetime(f'run_grpo_gen_rm_server http://{ep_ip}:{ep_port}')
    serve_forever_fn()


def run_grpo_gen_rm_worker(trainer, model_provider, gen_rm_func):
    '''
    Command loop for ranks that do not run the HTTP server.

    Blocks forever: each iteration receives one ``{'cmd': ...}`` object
    broadcast over the model-parallel gloo group and reacts to it.

    * ``setup`` -- build the engine through ``model_provider``, but only when
      the engine implementation is sglang and the tensor-parallel rank is a
      multiple of ``args.num_gpus_per_node``.
    * ``wakeup_log_memory`` / ``sleep_log_memory`` / ``setup_log_memory`` --
      synchronise CUDA and print a memory-tracking line.
    * anything else -- ignored.

    Parameters
    ----------

    trainer : GrpoGenRmTrainerV3
        Receives the engine in ``trainer.engine`` on ``setup``.

    model_provider : Callable
        Zero-argument callable returning the inference engine.

    gen_rm_func : Callable
        Unused by the worker loop; accepted to mirror the server signature.
    '''
    args = get_args()

    # Memory-log commands and the line each one prints.
    memory_log_lines = {
        'wakeup_log_memory': "Memory tracking: gen_rm after wake_up",
        'sleep_log_memory': "Memory tracking: gen_rm after sleep",
        'setup_log_memory': "Memory tracking: gen_rm after setup",
    }

    while True:
        received = [None]
        torch.distributed.broadcast_object_list(received,
                                                src=get_model_parallel_src_rank_gloo(),
                                                group=get_model_parallel_group_gloo())
        cmd_name = received[0]['cmd']

        if cmd_name == 'setup':
            on_sglang = args.infer_engine_impl == 'sglang'
            if on_sglang and mpu.get_tensor_model_parallel_rank() % args.num_gpus_per_node == 0:
                trainer.engine = model_provider()
            continue

        log_line = memory_log_lines.get(cmd_name)
        if log_line is not None:
            torch.cuda.synchronize()
            print_memory_tracking(log_line)


class GrpoGenRmTrainerV3:
    '''
    Thin holder for the gen-rm inference engine.

    ``engine`` starts as ``None`` and is assigned from outside the class
    (the server's ``/setup`` handler and the worker loop both do so).
    '''

    def __init__(self):
        # No engine until something assigns one.
        self.engine = None

    def post_init(self):
        '''Hook invoked after global initialisation; only fetches the args.'''
        # NOTE(review): the result is unused; the call is kept as-is.
        _args = get_args()

    @torch.no_grad()
    def generate_rewards(self,
                         gen_rm_func,
                         batch: Optional[Dict[str, Union[int, List[Any]]]]):
        '''
        Call ``gen_rm_func(self.engine, batch)`` and hand back its result.

        Must run on a model-parallel head (asserted).

        NOTE(review): the server awaits the returned value, so ``gen_rm_func``
        presumably returns an awaitable; if so, the ``no_grad`` context only
        covers creating it, not running it -- confirm.
        '''
        # NOTE(review): neither lookup result is used below; both calls are
        # kept as they were.
        _args = get_args()
        _timers = get_timers()
        assert is_mp_head()
        return gen_rm_func(self.engine, batch)


                                                                
def run_grpo_gen_rm_v3(
    trainer,
    model_provider,
    gen_rm_func,
    extra_args_provider=None,
):
    '''
    Entry point of the GRPO / PPO gen-rm role (vllm / sglang).

    Initialises Megatron and the process groups, checks the parallelism
    settings this role requires, then runs the HTTP server on
    model-parallel heads and the command loop on every other rank.

    Parameters
    ----------

    trainer : gpatch.training.v3.grpo_gen_rm.GrpoGenRmTrainerV3

    model_provider : Callable
        A callable that provides the vllm / sglang engine.

    gen_rm_func : Callable
        A callable that drives vllm / sglang generation for the responses
        received from the actor.

    extra_args_provider : Callable, optional
        Passed through to ``initialize_megatron``.

    Returns
    -------
    None. Both the server and the worker loop block.
    '''
    initialize_megatron(extra_args_provider=extra_args_provider)
    args = get_args()
    validate_rl_args(args)
    # The role is tagged after validation and before the globals are set.
    args.rl_role = 'gen-rm'
    parse_dataset_config(args)
    set_global_variables(args)
    init_pg(distributed_timeout_minutes=args.distributed_timeout_minutes)

    trainer.post_init()

    # Settings required by this role.
    assert args.expert_model_parallel_size == 1
    assert args.context_parallel_size == 1
    assert args.use_gen_rm
    assert args.no_fused_kernel
    assert args.use_tp_pp_dp_mapping

    use_sglang = args.infer_engine_impl == 'sglang'
    if use_sglang:
        # Imported lazily so the sglang module is only loaded when selected.
        from gpatch.core.sampler_v3.sglang import sglang_hack
        sglang_hack()

    cpu_barrier()
    print_with_rank_and_datetime(f"Init: memory trace {torch.cuda.memory_allocated() / (1024**3)} GB")

    # Heads serve HTTP; the remaining ranks wait for broadcast commands.
    role_entry = run_grpo_gen_rm_server if is_mp_head() else run_grpo_gen_rm_worker
    role_entry(trainer, model_provider, gen_rm_func)
