# coding=utf-8
import logging
logging.getLogger("httpx").setLevel(logging.WARNING)
import os
import sys
import time
import copy
import uuid
import asyncio
import traceback
from typing_extensions import override
import math

import torch
try:
    import vllm
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.v1.engine.async_llm import AsyncLLM
except ImportError:
    pass

try:
    import sglang as sgl
    from sglang.srt.managers.io_struct import (
        ReleaseMemoryOccupationReqInput,
        ResumeMemoryOccupationReqInput,
        UpdateWeightFromDiskReqInput,
    )
except ImportError:
    pass

from megatron.core import mpu
from megatron.training.global_vars import get_tokenizer
from megatron.training.utils import print_rank_0
from megatron.training.global_vars import get_args

from gpatch.training.v3.grpo_sampler import run_grpo_sampler_v3, GrpoSamplerV3
from gpatch.core.parallel_state import is_mp_and_cp_head, get_mp_and_cp_size
from gpatch.core.sampler_v3.infer_engine import InferEngine
from gpatch.patch_mcore import init_gpatch_for_mcore
from gpatch.training.v3.default_model_provider import default_sampler_model_provider

from tasks.math_rl_v3.args import get_tasks_args

class InferEngine(InferEngine):
    """Task-local subclass of the gpatch ``InferEngine`` that builds its own backend.

    The subclass reuses the name of the imported base class, so after this
    definition the module-level name ``InferEngine`` refers to this subclass.
    """

    @override
    @staticmethod
    def from_engine_args(
        infer_engine_impl,
        model=None,
        dtype='bfloat16',
        distributed_executor_backend='ray',
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
        gpu_memory_utilization=0.67,
        enforce_eager=True,
        tp_rank=None,
        pp_rank=None,
        dp_rank=None,
        dist_init_addr=None,
        num_gpus_per_node=8,
        infer_engine_role=None,
        use_fast=False,
        context_length=8*1024,
    ):
        """Create a vLLM or SGLang engine and wrap it in an ``InferEngine``.

        Args:
            infer_engine_impl: ``'vllm'`` selects the vLLM path; any other value
                must be ``'sglang'`` (asserted).
            model: model path, or a sequence whose first element is the path.
                Required.
            dtype: forwarded to vLLM only.
            distributed_executor_backend: forwarded to vLLM only.
            tensor_parallel_size: TP size for either backend.
            pipeline_parallel_size: forwarded to vLLM only.
            gpu_memory_utilization: vLLM ``gpu_memory_utilization`` or SGLang
                ``mem_fraction_static``.
            enforce_eager: forwarded to vLLM only.
            tp_rank: required; used to derive the SGLang ``node_rank``.
            pp_rank: required (asserted) but not otherwise used here.
            dp_rank: required; used to derive the SGLang ``base_gpu_id``.
            dist_init_addr: required for SGLang; unused for vLLM.
            num_gpus_per_node: used for the SGLang node/GPU layout computation.
            infer_engine_role: one of ``None``, ``"sampler"``, ``"gen-rm"``.
            use_fast: when true, vLLM receives
                ``mm_processor_kwargs={'use_fast': True}``.
            context_length: SGLang only; clamped to ``args.seq_length``.

        Returns:
            An ``InferEngine`` wrapping the created backend engine.

        Raises:
            AssertionError: if a required argument is missing or invalid.
        """
        assert infer_engine_role in [None, "sampler", "gen-rm"]
        infer_engine = None
        assert model is not None
        assert tp_rank is not None and pp_rank is not None and dp_rank is not None
        model_path = model if isinstance(model, str) else model[0]
        if infer_engine_impl == 'vllm':
            # NOTE(review): the vllm names come from a guarded import at the top
            # of the file; if that import failed this branch presumably raises
            # NameError - confirm.
            mm_processor_kwargs = dict(use_fast=True) if use_fast else {}
            engine_args = AsyncEngineArgs(
                model=model_path,
                dtype=dtype,
                distributed_executor_backend=distributed_executor_backend,
                tensor_parallel_size=tensor_parallel_size,
                pipeline_parallel_size=pipeline_parallel_size,
                gpu_memory_utilization=gpu_memory_utilization,
                enforce_eager=enforce_eager,
                trust_remote_code=True,
                enable_sleep_mode=True,
                mm_processor_kwargs=mm_processor_kwargs,
            )
            infer_engine = AsyncLLM.from_engine_args(engine_args)
        else:
            assert infer_engine_impl == 'sglang'
            assert dist_init_addr is not None
            # One TP group may span several nodes: derive how many nodes it
            # needs, which of them this rank lives on, and the first GPU index
            # used by this DP replica within its node.
            nnodes = int(math.ceil(tensor_parallel_size / num_gpus_per_node))
            node_rank = tp_rank // num_gpus_per_node
            base_gpu_id = (dp_rank * tensor_parallel_size) % num_gpus_per_node

            # Never request a context longer than the training sequence length.
            args = get_args()
            context_length = min(args.seq_length, context_length)

            server_args = sgl.ServerArgs(
                model_path=model_path,
                tp_size=tensor_parallel_size,
                dist_init_addr=dist_init_addr,
                nnodes=nnodes,
                node_rank=node_rank,
                base_gpu_id=base_gpu_id,
                mem_fraction_static=gpu_memory_utilization,
                trust_remote_code=True,
                enable_memory_saver=True,
                skip_tokenizer_init=True,
                context_length=context_length,
                allow_auto_truncate=True,
            )
            infer_engine = sgl.Engine(server_args=server_args)

        # Identity check: `!= None` would dispatch to the engine object's own
        # comparison operator, which is not what is meant here.
        assert infer_engine is not None, f"infer engine not initialized {infer_engine_impl=}"
        return InferEngine(infer_engine_impl, infer_engine, model, infer_engine_role)


# model_provider = default_sampler_model_provider
def model_provider():
    """Build the sampler ``InferEngine`` from the global args and parallel state.

    Reads the parallel sizes, engine implementation, checkpoint path and memory
    budget from ``get_args()``, and this process's TP/PP/DP ranks from ``mpu``.

    Returns:
        The object returned by ``InferEngine.from_engine_args`` with role
        ``"sampler"``.

    Raises:
        AssertionError: if the TP/PP sizes differ from the sampler TP/PP sizes.
    """
    cfg = get_args()
    tp = cfg.tensor_model_parallel_size
    pp = cfg.pipeline_model_parallel_size
    fast_tokenizer = cfg.px_use_fast_tokenizer
    # https://github.com/vllm-project/vllm/issues/17902
    assert tp == cfg.ppo_sampler_tensor_model_parallel_size
    assert pp == cfg.ppo_sampler_pipeline_model_parallel_size

    engine_impl = cfg.infer_engine_impl
    checkpoint = cfg.load
    mem_utilization = cfg.sampler_gpu_memory_utilization
    tp_rank = mpu.get_tensor_model_parallel_rank()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    dp_rank = mpu.get_data_parallel_rank()
    # One init address per data-parallel replica, indexed by DP rank.
    init_addr = cfg.sampler_dist_init_addrs[dp_rank]

    return InferEngine.from_engine_args(
        engine_impl,
        model=checkpoint,
        dtype='bfloat16',
        distributed_executor_backend='ray',
        tensor_parallel_size=tp,
        pipeline_parallel_size=pp,
        gpu_memory_utilization=mem_utilization,
        enforce_eager=True,
        tp_rank=tp_rank,
        pp_rank=pp_rank,
        dp_rank=dp_rank,
        dist_init_addr=init_addr,
        infer_engine_role="sampler",
        use_fast=fast_tokenizer,
        context_length=cfg.seq_length,
    )

def get_sampling_params(engine):
    """Return the rollout sampling parameters produced by ``engine``.

    The stop token is the single id of ``"<|endoftext|>"`` for the listed Qwen
    architectures and the tokenizer's ``eos_token_id`` for everything else.
    A non-positive ``ppo_rollout_top_k`` is passed on as ``-1``.

    Args:
        engine: object exposing ``get_sampling_params(**kwargs)``.

    Returns:
        Whatever ``engine.get_sampling_params`` returns.
    """
    tokenizer = get_tokenizer()
    args = get_args()
    endoftext_archs = ['qwen2-72b', 'qwen2.5-1.5b', 'qwen2.5-math-1.5b']

    if args.model_arch not in endoftext_archs:
        stop_id = tokenizer._tokenizer.eos_token_id
    else:
        encoded = tokenizer._tokenizer("<|endoftext|>", add_special_tokens=False).input_ids
        # The marker must map to exactly one integer token id.
        assert len(encoded) == 1 and isinstance(encoded[0], int)
        stop_id = encoded[0]

    top_k = args.ppo_rollout_top_k
    if not top_k > 0:
        top_k = -1

    # Create a sampling params object.
    return engine.get_sampling_params(
        n=1,  # by passing vllm async llm issues
        temperature=args.ppo_rollout_temperature,
        top_k=top_k,
        top_p=args.ppo_rollout_top_p,
        max_tokens=args.ppo_resp_seq_len,
        stop_token_ids=[stop_id],
        seed=args.seed,
    )


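# `tokens` contains both the prompt and the output; `sequence_lengths` likewise counts
# prompt plus output.
# If your RM and actor are not from the same model family (or are from the same family but
# with slightly different tokenizers), you obviously need to rewrite the prompt and then
# re-tokenize the rewritten prompt with the other tokenizer.
# If you have several RMs, just change the number in the middle of the key names.
# Below is a simplified sketch (business-specific if/else logic that runs on the CPU is
# really inconvenient to write here).
#
# Be very careful: the transformers tokenizer has many pitfalls you need to watch out for!
# For example: https://git.xxx.com/wepsdl/docker-images/blob/master/test-cases/case-1-transformers/train.py
# Do not trust your intuition: always print the results and check them by hand, and add tests!
# ```python
# rollout_batch['rm_0_tokens'] = tokens.flip(1)
# rollout_batch['rm_0_prompt_lengths'] = lpad_lens.clone()
# rollout_batch['rm_0_sequence_lengths'] = generated_sequence_lengths.clone()
# rollout_batch['rm_0_output_mask'] = None
# ```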
# NOTE(review): `torch.no_grad()` used as a decorator on an `async def` likely
# only covers creation of the coroutine, not its awaited body - confirm.
@torch.no_grad()
async def gen_rollouts(engine, batch, sampling_repeat):
    """Generate ``sampling_repeat`` responses per prompt and pack a rollout batch.

    Args:
        engine: inference engine exposing ``async_generate`` and
            ``wait_and_get_async_generate_output``.
        batch: the return value of the actor's ``rollout_get_batch``.
            ``prompt_token_ids`` is a list of dicts (each holding a
            ``'prompt_token_ids'`` token list); ``lpad_lens`` and ``gt_label``
            are tensors of shape ``[b,]``.
        sampling_repeat: number of responses to sample for every prompt.

    Returns:
        A dict of lists, each with ``len(prompts) * sampling_repeat`` entries in
        prompt-major order:
        ``'tokens'`` (1-D long tensors, prompt followed by response),
        ``'sequence_lengths'`` (0-d long tensors holding the total length),
        ``'prompt_lengths'`` (the matching ``lpad_lens`` entry) and
        ``'gt_label'`` (the matching ``gt_label`` entry).

    Raises:
        AssertionError: if an output does not hold exactly one completion or a
            full sequence exceeds ``args.seq_length``.
    """
    args = get_args()
    prompt_token_ids, lpad_lens, gt_label = batch["prompt_token_ids"], batch["lpad_lens"], batch[
        "gt_label"]

    # Debug aid: dump the batch when a prompt leaves no room for a full
    # response. The check is on the longest prompt's token count;
    # `len(prompt_token_ids)` is the batch size, not a prompt length.
    longest_prompt = max((len(p['prompt_token_ids']) for p in prompt_token_ids), default=0)
    if longest_prompt >= args.seq_length - args.ppo_resp_seq_len:
        torch.save(
            {
                "prompt_token_ids": prompt_token_ids,
                "lpad_lens": lpad_lens,
                "gt_label": gt_label,
            },
            f"save_long_prompt_{torch.distributed.get_rank()}_{longest_prompt}.pt"
        )

    # Every sequence gets its own seed, so repeated sampling of the same input
    # still produces different generations.
    # If rejection sampling is used (resampling until a requirement is met), the
    # seed has to be changed again; consider mixing in time.time().
    sampling_params = get_sampling_params(engine)
    gens = []
    for i, prompt in enumerate(prompt_token_ids):
        for j in range(sampling_repeat):
            tmp_sampling_params = copy.deepcopy(sampling_params)
            tmp_sampling_params.seed += i * sampling_repeat + j
            gens.append(
                engine.async_generate(prompt, tmp_sampling_params, str(uuid.uuid4().hex)))

    tokens = []
    sequence_lengths = []
    prompt_lens = []
    gt_label_list = []
    gen_outputs = await engine.wait_and_get_async_generate_output(gens)

    # Outputs are indexed like `gens`: `sampling_repeat` consecutive entries
    # belong to the same prompt.
    for gi in range(len(gens)):
        i = gi // sampling_repeat
        output = gen_outputs[gi]
        assert len(output.outputs) == 1
        output_tokens = list(output.outputs[0].token_ids)
        token = prompt_token_ids[i]['prompt_token_ids'] + output_tokens
        assert len(token) <= args.seq_length
        tokens.append(torch.tensor(token, dtype=torch.long))
        sequence_lengths.append(torch.tensor(len(token), dtype=torch.long))
        prompt_lens.append(lpad_lens[i])
        gt_label_list.append(gt_label[i])

    rollout_batch = {
        'tokens': tokens,
        'sequence_lengths': sequence_lengths,
        'prompt_lengths': prompt_lens,
        'gt_label': gt_label_list,
    }
    return rollout_batch


# Script entry point: only runs when this file is executed directly.
if __name__ == "__main__":
    # init_gpatch_for_mcore() is called before the sampler object is created.
    init_gpatch_for_mcore()
    grpo_sampler = GrpoSamplerV3()
    # Hand the sampler, the engine factory and the rollout coroutine defined in
    # this file to the runner; get_tasks_args is passed as extra_args_provider.
    run_grpo_sampler_v3(
        grpo_sampler,
        model_provider,
        gen_rollouts,
        extra_args_provider=get_tasks_args,
    )
