                                                      
                                                                 

from contextlib import asynccontextmanager
import asyncio
from optparse import Option
import os
import threading
import time
import io
import queue
from typing import Dict, List, Union, Any, Optional
from packaging.version import Version

from megatron.training.training import _TRAIN_START_TIME                                       

from fastapi import Request
from fastapi.responses import JSONResponse, Response
import fastapi
import torch
import uvicorn

from megatron.core import package_info
from megatron.core import mpu
from megatron.core import parallel_state
from megatron.core.utils import get_model_config
from megatron.training import get_tokenizer
from megatron.training.checkpointing import load_checkpoint
from megatron.training.initialize import initialize_megatron
from megatron.training.initialize import set_jit_fusion_options

_FP16_Module_Ver = None
try:
    from megatron.legacy.model import Float16Module
    _FP16_Module_Ver = "legacy_ver"
except ImportError:
    from megatron.core.transformer.module import Float16Module
    _FP16_Module_Ver = "transformer_ver"

from megatron.training.training import (
    print_datetime, )
from megatron.training.global_vars import (
    get_args,
    get_timers,
)
try:
    from megatron.training.global_vars import get_energy_monitor
except ImportError:
    get_energy_monitor = None

from gpatch.rpc import once_rpc
from gpatch.core.aligner_helper import (
    clear_memory,
    retrieve_model_state_dict_in_cpu,
)
from gpatch.training.arguments import validate_rl_args
from gpatch.training.utils import print_rank_0, print_with_rank_and_datetime
from gpatch.training.global_vars import set_global_variables
from gpatch.core.parallel_state import (
    init_pg,
    is_mp_and_cp_head,
    is_mp_head,
    cpu_barrier,
)
from gpatch.core.swap import (
    offload_megatron_model,
    onload_megatron_model,
    offload_megatron_optimizer,
    onload_megatron_optimizer,
)
from gpatch.core.utils import  print_memory_tracking
from gpatch.core.parallel_state import get_model_parallel_group_gloo, get_model_parallel_src_rank_gloo
from gpatch.core.wecube import report_ppo_metrics, init_wecube_reporter
from gpatch.rpc.monitor import start_monitor_client_in_background
from megatron_datasets.args import parse_dataset_config


def run_grpo_rm_server(trainer, model_provider, critic_provider, model_type):
    """Serve the reward-model control endpoints over HTTP and block until uvicorn exits.

    Builds a FastAPI app whose handlers drive ``trainer``. Handlers that need
    the other ranks to act first broadcast a ``{'cmd': ...}`` dict over the
    gloo model-parallel group (``run_grpo_rm_worker`` receives and mirrors
    these) and then perform the same action locally.

    The listen address is taken from ``args.ppo_critic_ips`` /
    ``args.ppo_critic_ports``, indexed by this rank's data-parallel rank.

    Parameters
    ----------
    trainer :
        Object exposing ``critic_model``, ``optimizer``, ``setup(...)`` and
        ``generate_rollouts(...)`` (e.g. ``GrpoRmTrainerV3``).
    model_provider, critic_provider, model_type :
        Forwarded unchanged to ``trainer.setup`` by the ``/setup`` endpoint.
    """
    args = get_args()
    # NOTE(review): `tokenizer` is fetched but not used anywhere in this function.
    tokenizer = get_tokenizer()
    ep_ip = args.ppo_critic_ips[mpu.get_data_parallel_rank()]
    ep_port = args.ppo_critic_ports[mpu.get_data_parallel_rank()]

    # Keyword arguments handed to every `once_rpc` decorator below.
    monitor_kwargs = {
        "do_monitor": args.do_monitor,
        "monitor_server_ip": args.monitor_server_ip,
        "monitor_port": args.monitor_port,
    }

    # NOTE(review): `checkpointing_context` is never read in this function.
    checkpointing_context = {}

    # State shared by the handlers (closed over, mutated via `nonlocal`):
    #   lock          - serialises the sections that broadcast and touch the model.
    #   computed      - True once the pending requests of the current step were inferred.
    #   batching_reqs - requests queued by /issue_infer_rm_critic, awaiting inference.
    lock = asyncio.Lock()
    computed = False
    batching_reqs: List[Dict[str, Union[int, List[Any]]]] = []
    # Results cache, nested as actor_dp_rank -> ppo_step -> sample_idx -> response dict.
    infer_rm_critic_results: Dict[int, Dict[int, Dict]]  = {}

    @asynccontextmanager
    async def lifespan(app: fastapi.FastAPI):
        # No startup or shutdown work; the hook only yields.
        yield

    app = fastapi.FastAPI(lifespan=lifespan)

    # Liveness probe: always answers ok.
    @app.post("/heartbeat")
    @once_rpc(**monitor_kwargs)
    async def heartbeat(req_dict):
        return {"ret": "ok"}

    # Acknowledges the request only; nothing is broadcast and the server keeps running.
    @app.post("/exit")
    @once_rpc(**monitor_kwargs)
    async def exit(req_dict):
        return {'ret': 'ok'}

    # Builds the critic model once. Later calls skip both the broadcast and the
    # local setup because `trainer.critic_model` is already set.
    @app.post("/setup")
    @once_rpc(**monitor_kwargs)
    async def setup(req_dict):

        async with lock:
            if trainer.critic_model is None:
                # Tell the other ranks first, then do the same setup locally.
                cmd_obj = [{'cmd': 'setup'}]
                torch.distributed.broadcast_object_list(cmd_obj,
                                                        src=get_model_parallel_src_rank_gloo(),
                                                        group=get_model_parallel_group_gloo())
                trainer.setup(model_provider, critic_provider, model_type)

        assert trainer.critic_model is not None
        return {"ret": 'ok'}

    # Offloads the model (and the optimizer, if any) on every rank.
    @app.post("/sleep")
    @once_rpc(**monitor_kwargs)
    async def sleep(req_dict):
        async with lock:
            cmd_obj = [{'cmd': 'sleep'}]
            torch.distributed.broadcast_object_list(cmd_obj,
                                                    src=get_model_parallel_src_rank_gloo(),
                                                    group=get_model_parallel_group_gloo())

            assert trainer.critic_model is not None
            offload_megatron_model(trainer.critic_model.model)
            if trainer.optimizer is not None:
                offload_megatron_optimizer(trainer.optimizer)
        clear_memory()
        torch.cuda.synchronize()
        print_memory_tracking(f"Memory tracking: rm after sleep", verbose=True, rank=0)
        return {"ret": 'ok'}

    # Start of a PPO step: re-arms `computed` and onloads the model on every rank.
    @app.post("/mark_rm_ppo_step_begin")
    @once_rpc(**monitor_kwargs)
    async def mark_rm_ppo_step_begin(req_dict):
        # Both keys must be present (KeyError otherwise); the values are not used below.
        ppo_step = req_dict['ppo_step']
        actor_dp_rank = req_dict['actor_dp_rank']

        nonlocal computed
        # Note: reset happens before the lock is taken.
        computed = False
        torch.cuda.synchronize()
        print_memory_tracking(f"Memory tracking: rm before onload", verbose=True, rank=0)

        async with lock:
            cmd_obj = [{'cmd': 'mark_rm_ppo_step_begin'}]
            torch.distributed.broadcast_object_list(cmd_obj,
                                                    src=get_model_parallel_src_rank_gloo(),
                                                    group=get_model_parallel_group_gloo())

            assert trainer.critic_model is not None
            onload_megatron_model(trainer.critic_model.model)
            if trainer.optimizer is not None:
                onload_megatron_optimizer(trainer.optimizer)

        clear_memory()
        torch.cuda.synchronize()
        print_memory_tracking(f"Memory tracking: rm after onload", verbose=True, rank=0)
        return {"ret": 'ok'}

    # End of a PPO step: drops all cached results and offloads the model on every rank.
    @app.post("/mark_rm_ppo_step_end")
    @once_rpc(**monitor_kwargs)
    async def mark_rm_ppo_step_end(req_dict):
        # Both keys must be present (KeyError otherwise); the values are not used below.
        ppo_step = req_dict['ppo_step']
        actor_dp_rank = req_dict['actor_dp_rank']

        # NOTE(review): `computed` is declared nonlocal but never assigned in this handler.
        nonlocal computed
        torch.cuda.synchronize()
        print_memory_tracking(f"Memory tracking: rm before offload", verbose=True, rank=0)

        async with lock:
            infer_rm_critic_results.clear()

            cmd_obj = [{'cmd': 'mark_rm_ppo_step_end'}]
            torch.distributed.broadcast_object_list(cmd_obj,
                                                    src=get_model_parallel_src_rank_gloo(),
                                                    group=get_model_parallel_group_gloo())

            assert trainer.critic_model is not None
            offload_megatron_model(trainer.critic_model.model)
            if trainer.optimizer is not None:
                offload_megatron_optimizer(trainer.optimizer)
        clear_memory()

        torch.cuda.synchronize()
        print_memory_tracking(f"Memory tracking: rm after offload", verbose=True, rank=0)
        return {"ret": 'ok'}

    # Queues one request for the next batched inference; no model work happens here.
    @app.post("/issue_infer_rm_critic")
    @once_rpc(**monitor_kwargs)
    async def issue_infer_rm_critic(req_dict):
        nonlocal computed
        # Requests may only be queued before the step's results have been computed.
        assert computed == False

        batching_reqs.append(req_dict)

        if args.ppo_wecube_report and req_dict.get('tokens', None) is not None:
            # Reported count is len(tokens) for a tensor, otherwise 1.
            report_data = {
                "rm_infer_inputs":
                len(req_dict['tokens']) if (torch.is_tensor(req_dict['tokens'])) else 1,
            }
            report_ppo_metrics(report_data)
        return {}

    # Returns the cached result for (actor_dp_rank, ppo_step, sample_idx). The
    # first call after requests were queued runs inference for ALL queued
    # requests at once; later calls only read the cache.
    @app.post("/get_infer_rm_critic_result")
    @once_rpc(**monitor_kwargs)
    async def get_infer_rm_critic_result(req_dict):
        actor_dp_rank = req_dict['actor_dp_rank']
        ppo_step = req_dict['ppo_step']
        sample_idx = req_dict['sample_idx']
        sampling_repeat = req_dict['sampling_repeat']

        async with lock:
            nonlocal computed, batching_reqs
            if not computed:
                computed = True
                # Every queued request must belong to the step being asked for.
                for batch in batching_reqs:
                    assert batch['ppo_step'] == ppo_step

                cmd_obj = [{'cmd': 'get_infer_rm_critic_result'}]
                torch.distributed.broadcast_object_list(cmd_obj,
                                                        src=get_model_parallel_src_rank_gloo(),
                                                        group=get_model_parallel_group_gloo())
                # `sampling_repeat` of the triggering request is applied to the whole batch.
                resp_dicts = trainer.generate_rollouts(trainer.critic_model, batching_reqs, sampling_repeat)

                # File each response under the identifiers of the request that produced it.
                for batch, b_resp_dict in zip(batching_reqs, resp_dicts, strict=True):
                    b_actor_dp_rank = batch['actor_dp_rank']
                    b_sample_idx = batch['sample_idx']

                    tmpd = infer_rm_critic_results.setdefault(b_actor_dp_rank, {})
                    tmpdd = tmpd.setdefault(ppo_step, {})
                    tmpdd[b_sample_idx] = b_resp_dict
                # Rebind (not clear) the queue so the next step starts empty.
                batching_reqs = []

        # Raises KeyError if no result was stored for this triple.
        tmpd = infer_rm_critic_results[actor_dp_rank]
        tmpdd = tmpd[ppo_step]
        resp_dict = tmpdd[sample_idx]

        return resp_dict

    def serve_forever_fn():
        # Plain HTTP: all SSL options are explicitly disabled.
        uvicorn.run(app,
                    host=ep_ip,
                    port=ep_port,
                    log_level='error',
                    use_colors=False,
                    timeout_keep_alive=args.ppo_rm_critic_server_timeout_keep_alive,
                    ssl_keyfile=None,
                    ssl_certfile=None,
                    ssl_ca_certs=None,
                    ssl_cert_reqs=None)

    print_with_rank_and_datetime(f'critic_server http://{ep_ip}:{ep_port}')
    serve_forever_fn()


def run_grpo_rm_worker(trainer, model_provider, critic_provider, model_type):
    """Execute commands received over the gloo model-parallel group, forever.

    Every iteration blocks in ``broadcast_object_list`` until the source rank
    publishes a ``{'cmd': <name>}`` dict, then performs the matching action on
    this rank:

    * ``'setup'`` - build the critic model through ``trainer.setup``.
    * ``'mark_rm_ppo_step_begin'`` - onload the model (and optimizer, if any).
    * ``'sleep'`` / ``'mark_rm_ppo_step_end'`` - offload them and clear memory.
    * ``'get_infer_rm_critic_result'`` - take part in one inference pass.

    The loop has no exit command, so this function does not return normally.

    Raises
    ------
    ValueError
        If the received command name is none of the above.
    """
    offload_cmds = ('sleep', 'mark_rm_ppo_step_end')

    while True:
        # One-slot receive buffer; the broadcast fills it in place.
        inbox = [None]
        torch.distributed.broadcast_object_list(
            inbox,
            src=get_model_parallel_src_rank_gloo(),
            group=get_model_parallel_group_gloo(),
        )
        cmd_obj = inbox[0]
        cmd_name = cmd_obj['cmd']

        if cmd_name == 'setup':
            # The model must not exist yet, and must exist afterwards.
            assert trainer.critic_model is None
            trainer.setup(model_provider, critic_provider, model_type)
            assert trainer.critic_model is not None
            continue

        if cmd_name == 'mark_rm_ppo_step_begin':
            assert trainer.critic_model is not None
            onload_megatron_model(trainer.critic_model.model)
            if trainer.optimizer is not None:
                onload_megatron_optimizer(trainer.optimizer)
            continue

        if cmd_name in offload_cmds:
            assert trainer.critic_model is not None
            offload_megatron_model(trainer.critic_model.model)
            if trainer.optimizer is not None:
                offload_megatron_optimizer(trainer.optimizer)
            clear_memory()
            continue

        if cmd_name == 'get_infer_rm_critic_result':
            # No inputs are passed on this rank; only the call itself is mirrored.
            trainer.generate_rollouts(trainer.critic_model, None, None)
            continue

        raise ValueError(f'unknown cmd obj {cmd_obj}')


class GrpoRmTrainerV3:
    """Owns the GRPO reward/critic model and runs reward inference on it.

    ``setup`` builds the model, loads its checkpoints and stores the result on
    the instance. ``generate_rollouts`` runs one inference pass over a list of
    request batches and splits the outputs back into one dict per batch.
    """

    def __init__(self):
        # Populated by `setup`.
        self.critic_model = None
        self.optimizer = None
        self.opt_param_scheduler = None
        # True until the first inference inputs have been (optionally) dumped to disk.
        self.save_first = True
        self.mtx = threading.Lock()
        self.cb_q = queue.SimpleQueue()
        self.first_time_of_batching = None
        self.batching_samples = []

    def setup_model_and_optimizer_and_rm(self,
                                         model_provider_func,
                                         critic_provider,
                                         model_type,
                                         no_wd_decay_cond=None,
                                         scale_lr_cond=None,
                                         lr_mult=1.0):
        """Build the reward model(s), load their checkpoints and wrap them.

        Args:
            model_provider_func: Callable building one model; called with
                ``pre_process`` / ``post_process`` keyword arguments.
            critic_provider: Callable wrapping the built model; called as
                ``critic_provider(reward_model=model)``.
            model_type: Not used by this method.
            no_wd_decay_cond: Not used by this method.
            scale_lr_cond: Not used by this method.
            lr_mult: Not used by this method.

        Returns:
            ``(critic_model, optimizer, opt_param_scheduler)``. No optimizer or
            scheduler is created here, so the last two are always ``None``.
        """
        args = get_args()
        rm_state_dicts = []
        ppo_reward_scalings = args.ppo_reward_scalings
        rm_ref_factors = args.rm_ref_factors
        # Defaults when not configured: equal weights and unit scaling per checkpoint.
        if len(rm_ref_factors) == 0 and len(args.load_ref) > 0:
            rm_ref_factors = [1.0 / len(args.load_ref)] * len(args.load_ref)
        if len(ppo_reward_scalings) == 0 and len(args.load_ref) > 0:
            ppo_reward_scalings = [1.0] * len(args.load_ref)
        assert len(args.load_ref) == len(rm_ref_factors)
        assert len(args.load_ref) == len(ppo_reward_scalings)
        if args.rm_output_sequence is not None:
            assert len(args.load_ref) == len(args.rm_output_sequence)

        if args.use_grpo and len(args.load_ref) == 1:
            # Single reward model: load it directly as the main model below
            # instead of keeping a CPU copy of its weights.
            rm_state_dicts.append(None)
            args.load = args.load_ref[0]
        else:
            # Load every reference checkpoint into a throw-away model and keep
            # only a CPU copy of its state dict.
            for ref_path in args.load_ref:
                rm_model = model_provider_func(pre_process=mpu.is_pipeline_first_stage(),
                                               post_process=mpu.is_pipeline_last_stage())
                # `load_checkpoint` reads the path from the args attribute named by `load_arg`.
                args.load_ref_tmp = ref_path
                load_checkpoint([rm_model], None, None, load_arg='load_ref_tmp')
                rm_state_dicts.append(retrieve_model_state_dict_in_cpu(rm_model))
                del rm_model
            # The temporary attribute only exists if the loop ran at least once;
            # an unconditional `del` raises AttributeError when `load_ref` is empty.
            if hasattr(args, 'load_ref_tmp'):
                del args.load_ref_tmp

        model, optimizer, opt_param_scheduler = None, None, None
        if args.ppo_grpo_reward_type != "rule_only":
            model = model_provider_func(pre_process=mpu.is_pipeline_first_stage(),
                                        post_process=mpu.is_pipeline_last_stage())
            model = [model]
            for model_module in model:
                model_module.cuda(torch.cuda.current_device())
            if args.fp16 or args.bf16:
                # The two Float16Module variants take their arguments in different order.
                assert _FP16_Module_Ver in ["legacy_ver", "transformer_ver"]
                if _FP16_Module_Ver == "legacy_ver":
                    model = [Float16Module(model_module, args) for model_module in model]
                else:
                    config = get_model_config(model[0])
                    model = [Float16Module(config, model_module) for model_module in model]

            args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(
                model, None, None, load_arg='load')
        else:
            # Rule-only rewards: no network is built, so there is nothing to load.
            args.iteration = 0
            args.num_floating_point_operations_so_far = 0

        critic_model = critic_provider(reward_model=model, )
        critic_model.rm_state_dicts = rm_state_dicts
        critic_model.rm_factors = rm_ref_factors
        critic_model.ppo_reward_scalings = ppo_reward_scalings
        return critic_model, optimizer, opt_param_scheduler

    def setup(self, model_provider, critic_provider, model_type):
        """Build the critic model and store it, with its config, on the instance.

        Sets ``critic_model``, ``optimizer``, ``opt_param_scheduler`` and
        ``mlm_model_config``. Asserts that grad-reduce / param-gather overlap
        and ``skip_train`` are disabled and that the model config carries no
        sync or finalize hooks.
        """
        args = get_args()
        timers = get_timers()

        torch.cuda.synchronize()
        print_memory_tracking(f"Memory tracking: rm before build model", verbose=True, rank=0)
        timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
        critic_model, optimizer, opt_param_scheduler = self.setup_model_and_optimizer_and_rm(
            model_provider, critic_provider, model_type)
        timers('model-and-optimizer-setup').stop()
        print_datetime('after model, optimizer, and learning rate scheduler are built')
        torch.cuda.synchronize()
        print_memory_tracking(f"Memory tracking: rm after build model", verbose=True, rank=0)
        config = get_model_config(critic_model)

        if optimizer is not None:
            config.grad_scale_func = optimizer.scale_loss
        config.timers = timers
        assert not args.overlap_grad_reduce and not args.overlap_param_gather
        assert config.no_sync_func is None
        assert config.grad_sync_func is None
        assert config.param_sync_func is None
        assert config.finalize_model_grads_func is None

        assert not args.skip_train
        self.critic_model = critic_model
        self.optimizer = optimizer
        self.opt_param_scheduler = opt_param_scheduler
        self.mlm_model_config = config

        print_rank_0('done with setup ...')
        timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'],
                   barrier=True)

    def decode_inputs_and_call_infer_fn(
        self,
        infer_fn,
        batches: List[Dict[str, Union[int, List[Any]]]],
        sampling_repeat,
    ):
        """Call ``infer_fn`` and move its outputs to the CPU.

        Args:
            infer_fn: Callable invoked as ``infer_fn(batches, sampling_repeat)``;
                must return the 5-tuple ``(rewards, values, per_token_rewards,
                exceeded, custom_rewards)``.
            batches: Request dicts, passed through to ``infer_fn`` untouched.
            sampling_repeat: Passed through to ``infer_fn`` untouched.

        Returns:
            The same 5-tuple with tensors moved to the CPU. Unless running GRPO
            with rule-only rewards, ``exceeded`` is converted to an int32
            tensor with a single trailing column.
        """
        args = get_args()

        rewards, values, per_token_rewards, exceeded, custom_rewards = infer_fn(batches, sampling_repeat)
        rewards = rewards.cpu()
        if values is not None:
            values = values.cpu()
        if per_token_rewards is not None:
            per_token_rewards = per_token_rewards.cpu()

        # Only existing keys are reassigned, so updating during iteration is safe.
        for k, v in custom_rewards.items():
            if torch.is_tensor(v):
                custom_rewards[k] = v.cpu()

        if not (args.use_grpo and args.ppo_grpo_reward_type == "rule_only"):
            exceeded = torch.tensor(exceeded, dtype=torch.int32).reshape(-1, 1)
        return rewards, values, per_token_rewards, exceeded, custom_rewards

    def run_inference(
        self,
        critic_model,
        batches: Optional[List[Dict[str, Union[int, List[Any]]]]] = None,
        sampling_repeat=None,
    ) -> List[Dict[str, List[Any]]]:
        """Run one reward-inference pass and split the outputs per batch.

        On the model-parallel head the inputs are fed to
        ``critic_model.infer_rm_only`` and every non-``None`` output is chunked
        into ``len(batches)`` pieces. Other ranks call ``infer_rm_only()``
        without arguments and return an empty list.

        Args:
            critic_model: Wrapper exposing ``infer_rm_only``.
            batches: Request dicts; only read on the model-parallel head.
            sampling_repeat: Expected number of entries per output and batch.

        Returns:
            One dict per batch mapping each output name to a list of
            ``sampling_repeat`` items (``None`` entries for absent outputs);
            empty on ranks that are not the model-parallel head.
        """
        args = get_args()
        timers = get_timers()

        print_with_rank_and_datetime(f'rm rollout run_inference')
        timers('infer', log_level=0).start()

        assert args.combine_rm_and_critic_server
        assert args.use_grpo
        infer_fn = critic_model.infer_rm_only

        gen_t0 = time.time()
        if is_mp_head():

            if args.ppo_save_first_rollout_data:
                # Dump only the very first set of inputs, for debugging.
                if self.save_first:
                    debug_dir = "./debug-tmp"
                    os.makedirs(debug_dir, exist_ok=True)
                    torch.save(batches, f"{debug_dir}/rm_input_{torch.distributed.get_rank()}.pt")
                self.save_first = False

            outputs = self.decode_inputs_and_call_infer_fn(infer_fn, batches=batches, sampling_repeat=sampling_repeat)
            rewards, values, per_token_rewards, exceeded, custom_rewards = outputs
        else:
            infer_fn()
            rewards, values, per_token_rewards, exceeded, custom_rewards = None, None, None, None, {}
        gen_t1 = time.time()

        resp_dict = {
            "values": values,
            "exceeded": exceeded,
            "rewards": rewards,
            "per_token_rewards": per_token_rewards,
        }
        for k, v in custom_rewards.items():
            resp_dict[k] = v
        chunked_resp_dicts = []

        # `.3f` (precision), not `3f` (minimum width), to print milliseconds.
        log_string = f'rm microbatch elapsed {gen_t1 - gen_t0:.3f}'
        print_with_rank_and_datetime(log_string)

        if is_mp_head():
            num_batch = len(batches)
            # Split every present output into one chunk per request batch.
            for k, v in resp_dict.items():
                if v is not None:
                    resp_dict[k] = v.chunk(num_batch)

            for b_i in range(num_batch):
                infer_ret: Dict[str, List[Any]] = {}
                # Present outputs first, absent ones after, which fixes the key order.
                for k, v in resp_dict.items():
                    if v is not None:
                        tmp_v = v[b_i]
                        if torch.is_tensor(tmp_v):
                            # One entry per row of the chunk, leading dim dropped.
                            tmp_v = [e.squeeze(0) for e in tmp_v.chunk(tmp_v.shape[0])]
                        assert isinstance(tmp_v, list)
                        infer_ret[k] = tmp_v
                        assert sampling_repeat == len(tmp_v)
                for k, v in resp_dict.items():
                    if v is None:
                        infer_ret[k] = [None] * sampling_repeat
                chunked_resp_dicts.append(infer_ret)

        timers('infer').stop()
        return chunked_resp_dicts

    @torch.no_grad()
    def generate_rollouts(
        self,
        critic_model,
        batches: Optional[List[Dict[str, Union[int, List[Any]]]]] = None,
        sampling_repeat=None,
    ) -> List[Dict[str, List[Any]]]:
        """Run ``run_inference`` between the model's prepare / finish hooks.

        ``critic_model.finish_inference()`` is called even when inference
        raises, so the model is not left in inference mode; the original
        exception still propagates.

        Returns:
            Whatever ``run_inference`` returns.
        """
        critic_model.prepare_for_inference()
        try:
            return self.run_inference(critic_model, batches, sampling_repeat)
        finally:
            critic_model.finish_inference()


                                                            
                    
def run_grpo_rm_v3(
    trainer,
    model_provider,
    critic_provider,
    model_type,
    process_non_loss_data_func=None,
    extra_args_provider=None,
    args_defaults=None,
    store=None,
):
    '''
    Serve GRPO / PPO RM.

    Initializes Megatron and the process groups, then runs the HTTP server on
    ranks where ``is_mp_and_cp_head()`` is true and the command-following
    worker loop on all other ranks. Both loops run indefinitely.

    Parameters
    ----------

    trainer : gpatch.training.v3.grpo_rm.GrpoRmTrainerV3

    model_provider : Callable
        A callable to provide megatron.GPTModel

    critic_provider : Callable
        A callable to provide GptPpoCriticModel wrapper for rm and critic model

    model_type : megatron.core.enums.ModelType
        megatron model type, use ModelType.encoder_or_decoder for causal model.

    process_non_loss_data_func : Callable, optional
        Not used by this function.

    extra_args_provider : Callable, optional
        Forwarded to ``initialize_megatron``.

    args_defaults : dict, optional
        Forwarded to ``initialize_megatron``; ``None`` means an empty dict.

    store : optional
        Forwarded to ``initialize_megatron`` as ``store=`` when megatron.core
        is at least version 0.13.0; ignored on older versions.

    Returns
    -------
    None
    '''
    # Fresh dict per call instead of a mutable default shared across calls.
    if args_defaults is None:
        args_defaults = {}

    mcore_version = Version(package_info.__version__)
    # `store` is only passed to `initialize_megatron` from 0.13.0 on.
    extra_args = {}
    if mcore_version >= Version("0.13.0"):
        extra_args = {"store": store}

    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults,
                        **extra_args)
    args = get_args()
    validate_rl_args(args)
    args.rl_role = 'rm'
    parse_dataset_config(args)
    set_global_variables(args)
    init_pg(distributed_timeout_minutes=args.distributed_timeout_minutes)

    args = get_args()
    if args.ppo_wecube_report:
        init_wecube_reporter()

    # `get_energy_monitor` is None when its import failed at module load, so
    # check it as well as the version before calling it.
    if mcore_version >= Version("0.13.0") and get_energy_monitor is not None:
        energy_monitor = get_energy_monitor()
        energy_monitor.setup()

    if args.ppo_grpo_reward_type == "rule_only":
        assert args.no_fused_kernel

    set_jit_fusion_options()

    # Use the earliest start time seen on any rank (MIN all-reduce) as the
    # common reference for the reported initialization time.
    global _TRAIN_START_TIME
    start_time_tensor = torch.tensor([_TRAIN_START_TIME], dtype=torch.double, device='cuda')
    torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(time.time() -
                                                                        _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    # Expert model parallelism is not supported here.
    assert mpu.get_expert_model_parallel_world_size() == 1

    cpu_barrier()
    if is_mp_and_cp_head():
        run_grpo_rm_server(trainer, model_provider, critic_provider, model_type)
    else:
        run_grpo_rm_worker(trainer, model_provider, critic_provider, model_type)
