              
                                                      
                                                                 

import os
import sys
import time
import copy
import uuid
import asyncio
import traceback

import torch

from megatron.core import mpu
from megatron.training.global_vars import get_tokenizer
from megatron.training.utils import print_rank_0
from megatron.training.global_vars import get_args

from gpatch.training.v3.grpo_sampler import run_grpo_sampler_v3, GrpoSamplerV3
from gpatch.core.parallel_state import is_mp_and_cp_head, get_mp_and_cp_size
from gpatch.core.sampler_v3.infer_engine import InferEngine
from gpatch.patch_mcore import init_gpatch_for_mcore
from gpatch.training.v3.default_model_provider import default_sampler_model_provider

from tasks.math_rl_v3.args import get_tasks_args

# Model factory passed to run_grpo_sampler_v3 in the __main__ block; this task reuses
# gpatch's default sampler model provider unchanged.
model_provider = default_sampler_model_provider


def get_sampling_params(engine):
    """Build the single-sample rollout sampling parameters from the global args.

    The stop token is the id of the literal ``<|endoftext|>`` string for the
    ``qwen2.5-math-1.5b`` arch, and the underlying tokenizer's ``eos_token_id``
    for every other arch. Temperature, top-p, max new tokens and seed come from
    ``args.ppo_rollout_temperature``, ``args.ppo_rollout_top_p``,
    ``args.ppo_resp_seq_len`` and ``args.seed``; a non-positive
    ``args.ppo_rollout_top_k`` is forwarded as ``-1``.

    Args:
        engine: inference engine whose ``get_sampling_params`` builds the object.

    Returns:
        Whatever ``engine.get_sampling_params`` returns for these settings.

    Raises:
        AssertionError: if ``<|endoftext|>`` does not encode to exactly one int id.
    """
    hf_tokenizer = get_tokenizer()._tokenizer
    args = get_args()

    if args.model_arch not in ('qwen2.5-math-1.5b',):
        stop_id = hf_tokenizer.eos_token_id
    else:
        # Encode the literal stop string without BOS/EOS; it must map to one id.
        encoded = hf_tokenizer("<|endoftext|>", add_special_tokens=False).input_ids
        assert len(encoded) == 1 and isinstance(encoded[0], int)
        stop_id = encoded[0]

    # Non-positive top-k is mapped to -1 (presumably the engine's "disabled" value).
    top_k = -1
    if args.ppo_rollout_top_k > 0:
        top_k = args.ppo_rollout_top_k

    return engine.get_sampling_params(
        n=1,
        temperature=args.ppo_rollout_temperature,
        top_k=top_k,
        top_p=args.ppo_rollout_top_p,
        max_tokens=args.ppo_resp_seq_len,
        stop_token_ids=[stop_id],
        seed=args.seed,
    )


                                                                
                                                          
                                                                
                      
                                                 
 
                                                   
                                                                                                 
                                    
           
                                               
                                                          
                                                                             
                                          
     
async def gen_rollouts(engine, batch, sampling_repeat):
    """Generate ``sampling_repeat`` rollouts per prompt and collect them into one batch.

    Args:
        engine: inference engine exposing ``get_sampling_params``,
            ``async_generate`` and ``wait_and_get_async_generate_output``.
        batch: mapping with keys ``"prompt_token_ids"``, ``"lpad_lens"`` and
            ``"gt_label"``, each indexed per prompt. Every
            ``batch["prompt_token_ids"][i]`` is itself a mapping whose
            ``'prompt_token_ids'`` entry is the prompt's token-id list; the whole
            mapping is what gets sent to ``engine.async_generate``.
        sampling_repeat: number of samples drawn for each prompt.

    Returns:
        dict of per-sample lists in prompt-major order (``sampling_repeat``
        consecutive entries per prompt):

        - ``'tokens'``: 1-D ``torch.long`` tensor of prompt ids + response ids.
        - ``'sequence_lengths'``: 0-d ``torch.long`` tensor, ``len`` of ``tokens``.
        - ``'prompt_lengths'``: the prompt's ``batch["lpad_lens"]`` entry
          (NOTE(review): stored as the prompt length; confirm this is intended).
        - ``'gt_label'``: the prompt's ``batch["gt_label"]`` entry.

    Raises:
        AssertionError: if an output does not hold exactly one completion, or a
            prompt + response exceeds ``args.seq_length``.
    """
    args = get_args()
    prompt_token_ids = batch["prompt_token_ids"]
    lpad_lens = batch["lpad_lens"]
    gt_label = batch["gt_label"]

    # Issue one request per (prompt i, repeat j) before awaiting any of them.
    sampling_params = get_sampling_params(engine)
    gens = []
    for i, prompt in enumerate(prompt_token_ids):
        for j in range(sampling_repeat):
            tmp_sampling_params = copy.deepcopy(sampling_params)
            # Offset the base seed so every sample gets a distinct, reproducible seed.
            tmp_sampling_params.seed += i * sampling_repeat + j
            gens.append(engine.async_generate(prompt, tmp_sampling_params, uuid.uuid4().hex))

    # Assumes outputs come back in the same order as ``gens`` -- TODO confirm.
    gen_outputs = await engine.wait_and_get_async_generate_output(gens)

    tokens = []
    sequence_lengths = []
    prompt_lens = []
    gt_label_list = []
    # NOTE: ``@torch.no_grad()`` used as a decorator on an ``async def`` only covers
    # creation of the coroutine object, not the awaited body. Grad mode is therefore
    # disabled here instead, scoped to this synchronous post-processing so it is not
    # left disabled for other coroutines while this one is suspended at an ``await``.
    with torch.no_grad():
        for gi in range(len(gens)):
            # Requests were issued prompt-major, so sample gi belongs to this prompt.
            i = gi // sampling_repeat
            output = gen_outputs[gi]
            assert len(output.outputs) == 1
            token = prompt_token_ids[i]['prompt_token_ids'] + list(output.outputs[0].token_ids)
            assert len(token) <= args.seq_length
            tokens.append(torch.tensor(token, dtype=torch.long))
            sequence_lengths.append(torch.tensor(len(token), dtype=torch.long))
            prompt_lens.append(lpad_lens[i])
            gt_label_list.append(gt_label[i])

    return {
        'tokens': tokens,
        'sequence_lengths': sequence_lengths,
        'prompt_lengths': prompt_lens,
        'gt_label': gt_label_list,
    }


if __name__ == "__main__":
    # Initialize gpatch for Megatron-core first, then run the v3 GRPO sampler with this
    # task's model provider, rollout function and extra-args provider.
    init_gpatch_for_mcore()
    run_grpo_sampler_v3(
        GrpoSamplerV3(),
        model_provider,
        gen_rollouts,
        extra_args_provider=get_tasks_args,
    )
