# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A vLLM rollout that can be used with different training backends.
When working with FSDP:
- Use the DTensor weight loader (recommended) or the HF weight loader
- Use the state_dict from FSDP to synchronize the weights among the TP ranks in vLLM
When working with Megatron:
- Use the Megatron weight loader
- During training, only the current PP stage holds the parameters
- Before inference, broadcast the parameters of the current PP rank to all other PP ranks (so all PP ranks hold all the parameters)
- Bind the parameters to the inference engine
- Run inference with TP; PP is treated as additional DP
- After inference, all the parameters that don't belong to this PP rank are freed.
"""
import asyncio
from concurrent.futures import ProcessPoolExecutor
from itertools import chain

import numpy as np
from typing import List
from contextlib import contextmanager

import vllm.sequence
from omegaconf import DictConfig
import torch
import torch.distributed
from tensordict import TensorDict
from torch import nn
from typing import Any, Union
from verl import DataProto
from verl.utils.reward_score import _default_compute_score
from verl.utils.reward_score.prime_math import match_answer, math_normalize, _normalize
from verl.utils.torch_functional import get_eos_mask, pad_2d_list_to_length
from verl.workers.reward_manager import PrimeRewardManager
from verl.workers.reward_manager.prime import parallel_compute_score_async
from verl.workers.rollout.base import BaseRollout
from vllm.distributed import parallel_state as vllm_ps
from vllm import LLM, SamplingParams
from verl.third_party.vllm import vllm_version
import os

from verl.workers.rollout.vllm_rollout.vllm_rollout import summarize_prompt_format


# TODO
# 1. support pp in vllm
# 2. is passing the tokenizer necessary? Only its pad_token_id is used here
# 3. simplify init logic

def extract_math(sequence):
    """Extract the final math answer from ``sequence`` and normalize it.

    The answer found by ``match_answer`` goes through ``math_normalize.normalize_answer``
    and then ``_normalize``. Returns the normalized answer, or None when the result is
    None or empty.
    """
    _matched, raw_answer = match_answer(sequence)
    answer = _normalize(math_normalize.normalize_answer(raw_answer))
    if answer is not None and len(answer) > 0:
        return answer
    return None


# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding.
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
    """Strip the left padding from one prompt row and return it as a Python list.

    Everything from the first position that differs from ``pad_token_id`` to the end is kept,
    including any pad tokens that appear later. Raises IndexError when every position equals
    ``pad_token_id``.
    """
    real_positions = (prompt_token_ids != pad_token_id).nonzero()
    first_real = real_positions[0][0]
    return prompt_token_ids[first_real:].tolist()


def _repeat_interleave(value: Union[torch.Tensor, np.ndarray, List[Any]], repeats: int) -> Union[torch.Tensor, np.ndarray]:
    """Repeat every element along the first axis ``repeats`` times, keeping element order.

    ``[a, b]`` with ``repeats=2`` becomes ``[a, a, b, b]``.

    Args:
        value: a torch tensor, a numpy array, or any sequence ``np.repeat`` accepts
            (e.g. a plain Python list).
        repeats: number of copies of each element.

    Returns:
        A torch tensor for tensor input (via ``Tensor.repeat_interleave``); otherwise the
        ``np.ndarray`` produced by ``np.repeat``, never a Python list.
    """
    if isinstance(value, torch.Tensor):
        return value.repeat_interleave(repeats, dim=0)
    return np.repeat(value, repeats, axis=0)


class vLLMRollout(BaseRollout):

    def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs):
        """A vLLM rollout. It requires the model to be supported by vLLM.

        Args:
            model_path: path of the model that vLLM loads
            config: DictConfig with the rollout settings (tensor_model_parallel_size, prompt_length,
                response_length, dtype, enforce_eager, free_cache_engine, sampling params, ...)
            tokenizer: the task/model tokenizer; only its ``pad_token_id`` is used here
            model_hf_config: the huggingface config of the generating model; its
                ``max_position_embeddings`` is checked against the total sequence length
            **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
        """
        super().__init__()
        self.config = config
        # Freeing the cache engine requires eager mode, i.e. CUDA graphs disabled (enforce_eager=True).
        assert not (not config.enforce_eager and config.free_cache_engine), \
            "disable CUDA graph (enforce_eager = True) if free cache engine"

        tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
        assert tensor_parallel_size <= torch.distributed.get_world_size(), \
            "tensor parallel size should be less than or equal to the world size"
        max_num_batched_tokens = self.config.get('max_num_batched_tokens', 8192)

        train_tp = kwargs.get('train_tp', None)
        if train_tp is not None:
            # deployed with megatron
            os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0'
            os.environ['MEGATRON_IMPORT_TIMERS'] = '0'
            num_tp_per_train_tp = train_tp // tensor_parallel_size
            vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size,
                                              num_tp_per_train_tp=num_tp_per_train_tp)

        assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \
            "model context length should be greater than total sequence length"

        self.inference_engine = LLM(
            model=model_path,
            enable_sleep_mode=True,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend="external_launcher",
            dtype=config.dtype,
            enforce_eager=config.enforce_eager,
            gpu_memory_utilization=config.gpu_memory_utilization,
            disable_custom_all_reduce=True,
            skip_tokenizer_init=False,
            max_model_len=config.prompt_length + config.response_length,
            disable_log_stats=config.disable_log_stats,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=config.enable_chunked_prefill,
            enable_prefix_caching=True,
            seed=int(os.getenv("RANK", "0")) // tensor_parallel_size,
        )

        # Offload vllm model to reduce peak memory usage
        self.inference_engine.sleep(level=1)

        sampling_kwargs = dict(
            n=1,
            logprobs=1,  # can be set to 0 and let actor to recompute
            max_tokens=config.response_length,
        )

        # we may detokenize the result all together later
        if vllm_version != '0.3.1':
            sampling_kwargs['detokenize'] = False

        # Any config key that is also a SamplingParams attribute is forwarded as a sampling param.
        default_sampling_params = SamplingParams()  # built once, only used for attribute lookup
        for k in config.keys():
            if hasattr(default_sampling_params, str(k)):
                sampling_kwargs[k] = config.get(k)

        print(f"kwargs: {sampling_kwargs}")
        self.sampling_params = SamplingParams(**sampling_kwargs)

        self.pad_token_id = tokenizer.pad_token_id

    @contextmanager
    def update_sampling_params(self, **kwargs):
        """Temporarily override attributes of ``self.sampling_params``.

        Keys that are not attributes of the SamplingParams object are ignored. The previous
        values are restored on exit, including when the body of the ``with`` block raises, so a
        failed generation cannot leave overrides (e.g. ``temperature=0``) behind.
        """
        old_sampling_params_args = {}
        for key, value in kwargs.items():
            if hasattr(self.sampling_params, key):
                old_sampling_params_args[key] = getattr(self.sampling_params, key)
                setattr(self.sampling_params, key, value)
        try:
            yield
        finally:
            # roll back to previous sampling params
            for key, value in old_sampling_params_args.items():
                setattr(self.sampling_params, key, value)

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        """Generate responses for a batch of left-padded prompts.

        Args:
            prompts: batch with ``input_ids``, ``attention_mask`` and ``position_ids`` tensors,
                ``meta_info['eos_token_id']`` and optional ``meta_info['validate']`` /
                ``meta_info['do_sample']`` flags.
            **kwargs: per-call SamplingParams overrides. They are replaced by fixed settings
                when validating with sampling, or when ``do_sample`` is False.

        Returns:
            DataProto with ``prompts``, ``responses``, ``input_ids``, ``attention_mask`` and
            ``position_ids``. When sampling for training with ``config.n > 1``, every prompt row
            (and the per-prompt non-tensor fields handled below) is repeated ``n`` times.
        """
        # rebuild vllm cache engine
        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
            self.inference_engine.init_cache_engine()

        idx = prompts.batch['input_ids']  # (bs, prompt_length)
        # left-padded attention_mask
        attention_mask = prompts.batch['attention_mask']
        position_ids = prompts.batch['position_ids']

        # used to construct attention_mask
        eos_token_id = prompts.meta_info['eos_token_id']

        batch_size = idx.size(0)

        non_tensor_batch = prompts.non_tensor_batch
        if 'raw_prompt_ids' not in non_tensor_batch:
            non_tensor_batch['raw_prompt_ids'] = np.array(
                [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object)

        if batch_size != len(non_tensor_batch['raw_prompt_ids']):
            raise RuntimeError('vllm sharding manager is not work properly.')

        if 'multi_modal_data' in non_tensor_batch:
            vllm_inputs = []
            for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop('raw_prompt_ids'),
                                                        non_tensor_batch.pop('multi_modal_data')):
                vllm_inputs.append({'prompt_token_ids': raw_prompt_ids, 'multi_modal_data': multi_modal_data})
        else:
            vllm_inputs = [{
                'prompt_token_ids': raw_prompt_ids
            } for raw_prompt_ids in non_tensor_batch.pop('raw_prompt_ids')]

        is_validating = prompts.meta_info.get('validate', False)
        do_sample = prompts.meta_info.get('do_sample', True)
        if is_validating and do_sample:
            kwargs = {
                'best_of': 1,
                'top_p': 0.95,
                'top_k': -1,
                'min_p': 0.0,
                'temperature': 0.6,
                'n': 1
            }
        elif not do_sample:
            kwargs = {
                'best_of': 1,
                'top_p': 1.0,
                'top_k': -1,
                'min_p': 0.0,
                'temperature': 0,
                'n': 1  # if greedy, only 1 response
            }

        # Responses per prompt. Only training-time sampling expands to config.n; validation and
        # greedy decoding force n=1 above. Every per-response alignment below uses this value.
        samples_per_prompt = self.config.n if (self.config.n > 1 and do_sample and not is_validating) else 1

        # users can customize different sampling_params at different run
        with self.update_sampling_params(**kwargs):
            outputs = self.inference_engine.generate(
                prompts=vllm_inputs,  # because we have already convert it to prompt token id
                sampling_params=self.sampling_params,
                use_tqdm=False)

        # TODO(sgm): disable logprob when recompute_log_prob is enable
        # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
        # Samples are listed prompt by prompt, so the n responses of one prompt are contiguous.
        response = [sample.token_ids for output in outputs for sample in output.outputs]

        if self.config.get('dedup', True):
            # Collapse long runs of repeated token blocks in every response.
            response = [dedupe_tensor(torch.tensor(token_ids)).numpy().tolist() for token_ids in response]

        # The summarize step works per response, so align these per-prompt fields with the responses first.
        if samples_per_prompt > 1:
            for key in ('data_source', 'reward_model'):
                if key in non_tensor_batch:
                    non_tensor_batch[key] = _repeat_interleave(non_tensor_batch[key], samples_per_prompt)

        if self.config.get('summarize', False) and not is_validating:
            response = self._summarize_responses(response, non_tensor_batch, batch_size, samples_per_prompt, kwargs)

        response = pad_2d_list_to_length(response, self.pad_token_id,
                                         max_length=self.config.response_length).to(idx.device)

        if samples_per_prompt > 1:
            idx = _repeat_interleave(idx, samples_per_prompt)
            attention_mask = _repeat_interleave(attention_mask, samples_per_prompt)
            position_ids = _repeat_interleave(position_ids, samples_per_prompt)
            batch_size = batch_size * samples_per_prompt
            for key in ('multi_modal_inputs', 'raw_prompt'):
                if key in non_tensor_batch:
                    non_tensor_batch[key] = _repeat_interleave(non_tensor_batch[key], samples_per_prompt)

        seq = torch.cat([idx, response], dim=-1)

        response_length = response.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)
        if position_ids.dim() == 3:  # qwen2vl mrope
            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)

        # TODO(sgm): fix position_ids on right_pad
        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
        response_position_ids = position_ids[:, -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
        response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)

        # all the tp ranks should contain the same data here. data in all ranks are valid
        batch = TensorDict(
            {
                'prompts': idx,
                'responses': response,
                'input_ids': seq,  # here input_ids become the whole sentences
                # 'old_log_probs': log_probs, # we will recompute old log prob with actor
                'attention_mask': attention_mask,
                'position_ids': position_ids
            },
            batch_size=batch_size)

        # free vllm cache engine
        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
            self.inference_engine.free_cache_engine()

        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)

    def _summarize_responses(self, response, non_tensor_batch, num_prompts, samples_per_prompt, sampling_kwargs):
        """Try to replace correct responses with a summary generated by the same engine.

        Each response is wrapped in ``summarize_prompt_format`` together with its question (the
        content of the last message in ``non_tensor_batch['raw_prompt']``).

        Pre-check: a response is summarized only if both of these hold:
        - ``parallel_compute_score_async(_default_compute_score, ...)`` scores it > 0;
        - its summarize prompt is shorter than ``response_length - 10`` tokens.

        Post-check: a summary is accepted only if ``extract_math`` returns the same non-None
        answer for both the original response and the summary.

        Planned but not implemented:
        - only summarize correct answers that are not the shortest in their group;
        - require the summary to be at least as long as the shortest correct answer;
        - reject summaries whose minimum logprob is lower than the original's.

        Args:
            response: list of response token-id lists, ``num_prompts * samples_per_prompt`` long,
                with the samples of one prompt contiguous.
            non_tensor_batch: must contain ``raw_prompt`` (one entry per prompt) plus
                ``data_source`` and ``reward_model``, already repeated once per response.
            num_prompts: number of prompts in the batch.
            samples_per_prompt: number of responses generated for each prompt.
            sampling_kwargs: sampling overrides of the main generation. The summary pass reuses
                them with ``n=1`` and ``temperature=0``.

        Returns:
            ``response``, with accepted summaries substituted in place.
        """
        print('summarizing! ')
        tokenizer = self.inference_engine.get_tokenizer()
        original_answers = tokenizer.batch_decode(response, skip_special_tokens=True)
        raw_questions = [non_tensor_batch['raw_prompt'][i][-1]['content'] for i in range(num_prompts)]
        # One question per response: repeat each question for every sample of its prompt.
        original_prompts = _repeat_interleave(raw_questions, samples_per_prompt)
        summarize_prompt_list = [
            [{'role': 'user', 'content': summarize_prompt_format.format(question, original_answers[i])}]
            for i, question in enumerate(original_prompts)
        ]
        summarize_prompt_token_ids = tokenizer.apply_chat_template(summarize_prompt_list, tokenize=True,
                                                                   add_generation_prompt=True)

        data_sources = non_tensor_batch['data_source']
        ground_truth = [ntb['ground_truth'] for ntb in non_tensor_batch['reward_model']]

        try:
            scores = asyncio.run(
                parallel_compute_score_async(_default_compute_score,
                                             original_answers,
                                             ground_truth,
                                             data_sources,
                                             num_processes=64))
        except asyncio.TimeoutError:
            print('Global timeout in reward computing! Setting all as 0.')
            scores = [0. for _ in range(len(original_answers))]
        except Exception as e:
            print(f"Unexpected error in batched reward computing. Setting all as 0.: {e}")
            scores = [0. for _ in range(len(original_answers))]

        # Pre-check: correct responses whose summarize prompt is short enough.
        summarize_ids = [
            i for i, token_ids in enumerate(summarize_prompt_token_ids)
            if scores[i] > 0 and len(token_ids) < self.config.response_length - 10
        ]
        if not summarize_ids:
            return response

        # Same input format as the main generation instead of the deprecated prompt_token_ids= argument.
        summary_inputs = [{'prompt_token_ids': summarize_prompt_token_ids[i]} for i in summarize_ids]
        with self.update_sampling_params(**dict(sampling_kwargs, n=1, temperature=0)):
            summarize_outputs = self.inference_engine.generate(
                prompts=summary_inputs,
                sampling_params=self.sampling_params,
                use_tqdm=False)

        summarize_response = [sample.token_ids for output in summarize_outputs for sample in output.outputs]
        summarize_response_str = tokenizer.batch_decode(summarize_response, skip_special_tokens=True)

        # Post-check: the summary must yield the same extracted answer as the original response.
        for response_idx, summary_ids, summary_text in zip(summarize_ids, summarize_response, summarize_response_str):
            old_answer = extract_math(original_answers[response_idx])
            new_answer = extract_math(summary_text)
            if old_answer is not None and new_answer is not None and old_answer == new_answer:
                response[response_idx] = summary_ids
        return response


def dedupe_tensor(x: torch.Tensor, threshold: int = 5, max_block_len: int = 999) -> torch.Tensor:
    """Collapse long runs of repeated fixed-length blocks in a 1D tensor.

    Block lengths ``n`` are tried in increasing order, ``n = 1 .. max_block_len``. For each ``n``:
      - cut ``x[:m*n]`` into ``m = len(x) // n`` consecutive blocks of length ``n``;
      - find runs of adjacent, exactly equal blocks that contain at least ``threshold`` equal
        adjacent pairs (i.e. ``threshold + 1`` identical blocks in a row);
      - keep only the first block of every such run and drop the rest;
      - rebuild the tensor and append the leftover tail ``x[m*n:]``.
    The tensor produced for one ``n`` is the input for the next ``n``.

    Blocks are aligned to offset 0 for every ``n``. A repetition that does not start at a
    multiple of ``n`` is only caught at block lengths where it happens to align.

    Args:
        x: 1D tensor, e.g. the token ids of one response.
        threshold: minimum number of equal adjacent block pairs for a run to be collapsed.
        max_block_len: largest block length tried (default 999, the previous fixed range).

    Returns:
        The deduplicated tensor, or ``x`` itself if nothing was removed.
    """
    length = x.size(0)
    for n in range(1, max_block_len + 1):
        m = length // n
        if m <= threshold:
            # m never grows as n increases (and length only shrinks), so no larger n can qualify.
            break

        # Cut into blocks; the tail is re-appended at the end.
        main = x[: m * n].view(m, n)
        tail = x[m * n:]

        # eq[i] is True when block i equals block i+1; length m-1.
        eq = (main[1:] == main[:-1]).all(dim=1)
        if not eq.any():
            continue

        # Pad with a 0 on both sides so runs that touch either end still produce +1/-1 edges.
        eq_int = eq.int()
        padded = torch.cat([eq_int.new_zeros(1), eq_int, eq_int.new_zeros(1)])
        dif = torch.diff(padded)  # length m
        starts = (dif == 1).nonzero(as_tuple=True)[0]       # first index of each run in eq
        ends = (dif == -1).nonzero(as_tuple=True)[0] - 1    # last index of each run in eq

        keep = torch.ones(m, dtype=torch.bool, device=x.device)
        for s, e in zip(starts.tolist(), ends.tolist()):
            # eq[s..e] all True means blocks s..e+1 are identical; keep block s only.
            if e - s + 1 >= threshold:
                keep[s + 1: e + 2] = False

        if keep.all():
            continue

        x = torch.cat([main[keep].reshape(-1), tail], dim=0)
        length = x.size(0)

    return x