# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The vllm_rollout that can be applied in different backend
When working with FSDP:
- Use DTensor weight loader (recommended) or HF weight loader
- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
When working with Megatron:
- Use Megatron weight loader
- During training, only the current pp stage holds the parameters
- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all the parameters)
- Bind the parameters to the inference engine
- Do inference in tp. pp is treated as additional dp
- After inference, all the parameters that doesn't belong to this pp rank is freed.
"""

from omegaconf import DictConfig
import torch
import torch.distributed
from tensordict import TensorDict
from torch import nn

from verl import DataProto
from verl.utils.py_functional import to_1d_np_array
from verl.workers.rollout.base import BaseRollout
from verl.third_party.vllm import LLM, vllm_version
from verl.third_party.vllm import parallel_state as vllm_ps
from vllm import SamplingParams

import random

from contextlib import contextmanager
from codetiming import Timer
from typing import Dict
@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
    """Time the enclosed block and record the result in ``timing_raw[name]``.

    The recorded value is codetiming's ``Timer.last`` for this run. If the block raises,
    the exception propagates and ``timing_raw`` is left unchanged.
    """
    stopwatch = Timer(name=name, logger=None)
    with stopwatch:
        yield
    timing_raw[name] = stopwatch.last

import requests
# Base URL of the chat generation service (generate_text defaults to this same value)
# BASE_URL = "http://g2:8000"
def generate_text(input_messages, max_tokens=128, temperature=0.7, base_url="http://g2:8000", timeout=None):
    """Request completions for a batch of conversations from a remote chat service.

    Sends ``POST {base_url}/chat`` with the JSON body
    ``{"input_messages": ..., "temperature": ..., "max_tokens": ...}`` and returns the
    ``"generated_texts"`` field of the JSON reply.

    Args:
        input_messages: forwarded as-is in the request body (presumably a list of chat
            message lists, one per conversation -- confirm against the server).
        max_tokens: forwarded to the server as ``max_tokens``.
        temperature: forwarded to the server as ``temperature``.
        base_url: root URL of the service.
        timeout: optional ``requests`` timeout in seconds. ``None`` (the default) waits
            indefinitely, matching the previous behaviour.

    Returns:
        The ``"generated_texts"`` value of the response JSON.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }

    data = {
        "input_messages": input_messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }

    response = requests.post(
        f"{base_url}/chat",
        headers=headers,
        json=data,
        timeout=timeout,
    )
    # Fail with a clear HTTP error instead of an obscure KeyError/JSON decode error
    # when the server reports a failure.
    response.raise_for_status()
    return response.json()["generated_texts"]


from verl.manager import MultiEnvManager, BufferManager

class vLLMMultiTurnViaEnv(BaseRollout):
    """Multi-turn rollout that alternates vLLM generation with environment feedback.

    ``config.train_actor_or_thinker`` selects the trained role:

    * ``"actor"``: the local vLLM engine generates every environment action.
    * ``"thinker"``: the local vLLM engine first writes deep-think responses, then the
      per-turn actions come from a remote fixed actor reached through ``generate_text``.
    """

    def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config, **kwargs):
        """Build the vLLM engine and the default sampling parameters.

        Args:
            actor_module: the model handed to the vLLM ``LLM`` wrapper; it must follow
                HuggingFace APIs and be supported by vLLM.
            config: rollout config. It must provide, among others, ``prompt_length``,
                ``response_length`` and ``environment.actor_length``. Any key that is also a
                ``SamplingParams`` attribute is forwarded as a sampling parameter.
            tokenizer: the task/model tokenizer.
            model_hf_config: the HuggingFace config used to initialize the generating model in vLLM.
            **kwargs: ``train_tp`` for the Megatron backend, used to initialize the hybrid
                engine (zero redundancy) process group.
        """
        super().__init__()
        self.config = config
        # Freeing the cache engine between rollouts requires eager mode (no CUDA graphs).
        assert not (not config.enforce_eager and config.free_cache_engine), \
            "disable CUDA graph (enforce_eager = True) if free cache engine"

        tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
        assert tensor_parallel_size <= torch.distributed.get_world_size(), \
            "tensor parallel size should be less than or equal to the world size"
        max_num_batched_tokens = self.config.get('max_num_batched_tokens', 8192)

        if kwargs.get('train_tp', None) is not None:
            # deployed with megatron
            import os
            os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0'
            os.environ['MEGATRON_IMPORT_TIMERS'] = '0'
            train_tp = kwargs.get('train_tp', None)
            num_tp_per_train_tp = train_tp // tensor_parallel_size
            if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
                vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size,
                                                  num_tp_per_train_tp=num_tp_per_train_tp)

        assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \
            "model context length should be greater than total sequence length"
        self.tokenizer = tokenizer
        self.total_length = config.prompt_length + config.response_length
        print("vLLM Total Length: ", self.total_length)
        self.inference_engine = LLM(
            actor_module,
            tokenizer=tokenizer,
            model_hf_config=model_hf_config,
            tensor_parallel_size=tensor_parallel_size,
            dtype=config.dtype,
            enforce_eager=config.enforce_eager,
            gpu_memory_utilization=config.gpu_memory_utilization,
            skip_tokenizer_init=False,
            max_model_len=self.total_length,
            load_format=config.load_format,
            disable_log_stats=config.disable_log_stats,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=config.enable_chunked_prefill,
        )

        # Offload vllm model to reduce peak memory usage
        self.inference_engine.offload_model_weights()

        # Named `sampling_kwargs` so it does not shadow the constructor's **kwargs.
        sampling_kwargs = dict(
            n=1,
            logprobs=1,  # can be set to 0 and let actor to recompute
            max_tokens=config.environment.actor_length,
        )

        # we may detokenize the result all together later
        if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
            sampling_kwargs['detokenize'] = False

        # Support adding any sampling param from the config file. The default instance is
        # built once and used only to probe which attribute names SamplingParams has.
        default_sampling_params = SamplingParams()
        for k in config.keys():
            if hasattr(default_sampling_params, str(k)):
                sampling_kwargs[k] = config.get(k)

        # Important: this is multi-turn, so sampling always uses n=1. The n samples per prompt
        # are batched manually because some trajectories may terminate earlier than others.
        sampling_kwargs['n'] = 1

        self.sampling_params = SamplingParams(**sampling_kwargs)

        self.pad_token_id = tokenizer.pad_token_id

    @contextmanager
    def update_sampling_params(self, **kwargs):
        """Temporarily override attributes of ``self.sampling_params``.

        Only keys that already exist on ``self.sampling_params`` are applied; unknown keys
        are ignored. The previous values are restored on exit, even if the body raises.
        """
        old_sampling_params_args = {}
        for key, value in kwargs.items():
            if hasattr(self.sampling_params, key):
                old_sampling_params_args[key] = getattr(self.sampling_params, key)
                setattr(self.sampling_params, key, value)
        try:
            yield
        finally:
            # roll back to previous sampling params
            for key, value in old_sampling_params_args.items():
                setattr(self.sampling_params, key, value)

    def get_n_tokens(self, prompt, add_generation_prompt=False):
        """Return the number of tokens the chat template produces for ``prompt``."""
        return len(self.tokenizer.apply_chat_template(prompt, tokenize=True, add_generation_prompt=add_generation_prompt))

    def tokenize_with_assistant_mask(self, messages):
        """Tokenize a conversation and mark the tokens attributed to assistant messages.

        For each prefix ``messages[:i + 1]`` the chat template is re-applied, and the tokens
        added relative to the previous prefix are attributed to message ``i``. When the next
        message is from the assistant, the prefix includes the generation prompt, so those
        tokens are attributed to message ``i``, not to the assistant reply.

        NOTE(review): this assumes the chat template is prefix-stable, i.e. a longer prefix
        only appends tokens. The final length assertion only catches some violations.

        Returns:
            ``(token_ids, mask)``: the token ids of the whole conversation and a bool tensor of
            the same length that is True on tokens attributed to assistant messages.
        """
        n_messages = len(messages)
        tokenized_messages = self.tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False)
        head = 0
        assistant_mask = []
        for i, message in enumerate(messages):
            is_next_assistant = i != n_messages - 1 and messages[i + 1]["role"] == "assistant"
            is_assistant = message["role"] == "assistant"
            n_tokens_with_message = self.get_n_tokens(messages[:i + 1], add_generation_prompt=is_next_assistant)
            n_add = n_tokens_with_message - head
            assistant_mask.append(torch.full((n_add,), is_assistant, dtype=torch.bool))
            head += n_add
        assistant_mask = torch.cat(assistant_mask, dim=0)
        assert len(assistant_mask) == len(tokenized_messages), "Bug: assistant mask length mismatch"
        return tokenized_messages, assistant_mask

    def _rollout_sampling_overrides(self, max_tokens, do_sample):
        """Return the per-call sampling overrides shared by the thinker and actor vLLM calls."""
        return {
            'top_p': 0.9,
            'top_k': -1,
            'min_p': 0.0,
            'temperature': 1.0 if do_sample else 0.0,
            'n': 1,
            "max_tokens": max_tokens,
        }

    def _chat_generate(self, messages, max_tokens, do_sample, device):
        """Run one batched vLLM chat call and return the response token ids on ``device``.

        NOTE(review): relies on the verl ``LLM.chat`` wrapper returning a sequence whose first
        element is a 2-D tensor of response token ids (rows = conversations). TODO: confirm.
        """
        with self.update_sampling_params(**self._rollout_sampling_overrides(max_tokens, do_sample)):
            assert self.sampling_params.n == 1, "n should be 1 for multi-turn"
            outputs = self.inference_engine.chat(
                messages=messages,
                sampling_params=self.sampling_params,
                use_tqdm=False)
        responses = outputs[0].to(device)
        assert responses.shape[1] <= max_tokens, "Bug: response too long from vllm: " + str(responses.shape)
        return responses

    def _decode_responses(self, responses):
        """Decode each response row to text, skipping special tokens.

        If decoding fails, the row is logged and decoding is retried once with the ids
        cast to int64.
        """
        response_texts = []
        for response in responses:
            try:
                response_text = self.tokenizer.decode(response, skip_special_tokens=True)
            except Exception:
                print("Decode Error: ", response)
                response_text = self.tokenizer.decode(response.to(torch.int64), skip_special_tokens=True)
            response_texts.append(response_text)
        return response_texts

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        """Roll out every environment ``n`` times for up to ``max_turns`` turns.

        Each entry of ``prompts.non_tensor_batch['extra_info']`` describes one environment
        and must carry ``uid`` and ``env_name``. Each entry is replicated ``n`` times, where
        ``n`` is ``config.n``, or 1 when ``meta_info['validate']`` is set.

        Environment objects cannot be passed in as arguments (see
        verl/single_controller/base/decorator), so the env and buffer managers are created
        here. The env workers are shut down, and the KV cache freed if configured, even when
        the rollout raises.

        Args:
            prompts: batch whose ``input_ids`` are only used for their batch size and device.
            **kwargs: unused; kept for interface compatibility.

        Returns:
            A DataProto with:
            * a tensor batch of placeholder ``input_ids`` (repeated ``n`` times);
            * ``batch_rollout_data``, taken from the buffer manager, in the non-tensor batch;
            * ``data_source``, the env name of each rollout, in the non-tensor batch.
        """
        print("ROLLOUT STAGE START!")
        # Rebuild the vLLM cache engine.
        if self.config.free_cache_engine:
            self.inference_engine.init_cache_engine()

        env_manager = None
        buffer_manager = None
        try:
            idx = prompts.batch['input_ids']  # only used for its batch size and device
            batch_size = idx.size(0)

            do_sample = prompts.meta_info.get('do_sample', True)

            # Ranks along the different parallel dimensions.
            rank = vllm_ps.get_tensor_model_parallel_src_rank()  # source rank of the TP group
            tp_rank = vllm_ps.get_tensor_model_parallel_rank()  # model-parallel rank
            global_size = torch.distributed.get_world_size()
            global_rank = torch.distributed.get_rank()  # global rank
            print(f"SRC RANK: {rank}; TP RANK: {tp_rank}; GLOBAL SIZE: {global_size}; GLOBAL RANK: {global_rank}")

            # Seeds the module-level RNG per rank, so the actor API choice below (and presumably
            # any use of `random` during env setup) is deterministic for a given rank.
            random.seed(global_rank)

            if prompts.meta_info.get('validate', False):
                print("******* TESTING *******")
                n = 1
            else:
                n = self.config.n

            # Validate the rollout config before launching any environment.
            max_turns = self.config.get('max_turns', 3)  # max-turns = 3 is just for testing
            update_role = self.config.get("train_actor_or_thinker")
            assert update_role in ['actor', 'thinker']

            # Env configs: one entry per (environment, sample) pair.
            total_env_infos = []
            total_env_names = []
            for extra_info in prompts.non_tensor_batch['extra_info']:
                uid = extra_info["uid"]
                env_name = extra_info["env_name"]
                for _ in range(n):
                    total_env_names.append(env_name)
                    total_env_infos.append({
                        "uid": uid,
                        "env_name": env_name,
                        "env_config": extra_info,
                        "special_settings": {},
                    })

            timing_raw = {}
            with _timer('init env', timing_raw):
                # Launches multiple env processes, locally or remotely.
                env_manager = MultiEnvManager(total_env_infos)
                # Initializes the envs in parallel; returns initial info such as task and init_obs.
                initial_feedbacks = env_manager.init_envs()
            if global_rank == 0:
                print(f"GLOBAL RANK {global_rank} TIMEING RAW: ", timing_raw)

            # Records rollout data and builds the prompts for each turn.
            buffer_manager = BufferManager(initial_feedbacks)

            if update_role == "thinker":
                base_url = random.choice(self.config.fixed_actor_api)  # pick a random fixed-actor API
                print("Using BASE URL: ", base_url)

                # Prepare the deep-think (summary) prompts for all environments.
                running_ids = list(range(len(initial_feedbacks)))
                summary_ids, messagess_todo_v2 = buffer_manager.build_prompts_for_deepthinks(running_ids, force=True)
                print("Deepthink number: {}".format(len(messagess_todo_v2)))

                timing_raw = {}
                with _timer('vllm sampling thinker', timing_raw):
                    responses = self._chat_generate(
                        messagess_todo_v2, self.config.environment.thinker_length, do_sample, idx.device)
                # Env-specific response processing could move into the env classes; decode here first.
                with _timer('response decoding v2', timing_raw):
                    response_texts = self._decode_responses(responses)
                with _timer('update buffer v2', timing_raw):
                    buffer_manager.update_trajectory_for_deepthinks(summary_ids, response_texts)
                if global_rank == 0:
                    print(f"GLOBAL RANK {global_rank}, THINKER STAGE, TIMEING RAW: ", timing_raw)

            # Interact with the environments until max_turns or until no env needs an action.
            while buffer_manager.step < max_turns:
                # Returns the todo list; the buffer manager records running_ids internally.
                if update_role == "thinker":
                    running_ids, messagess_todo = buffer_manager.build_prompts_earlystop_for_actors()
                elif update_role == "actor":
                    running_ids, messagess_todo = buffer_manager.build_prompts_for_actors()
                else:
                    raise ValueError("Unknown update_role value.")

                # Break when no tasks
                if len(messagess_todo) == 0:
                    break

                timing_raw = {}
                if update_role == "thinker":
                    # Actions come from the remote fixed actor.
                    with _timer('vllm sampling api (actor)', timing_raw):
                        response_texts = generate_text(messagess_todo, max_tokens=128, temperature=0.0, base_url=base_url)
                else:
                    # Actions come from the local vLLM engine being trained.
                    with _timer('vllm sampling actor', timing_raw):
                        responses = self._chat_generate(
                            messagess_todo, self.config.environment.actor_length, do_sample, idx.device)
                    with _timer('response decoding', timing_raw):
                        response_texts = self._decode_responses(responses)

                # Execute in the environments and collect feedback.
                with _timer('action executing', timing_raw):
                    feedbacks = env_manager.execute_actions(running_ids, response_texts)

                with _timer('postprocessing', timing_raw):
                    buffer_manager.update_trajectory(running_ids, response_texts, feedbacks)
                    buffer_manager.step += 1

                if global_rank == 0:
                    print(f"GLOBAL RANK {global_rank}, STEP: {buffer_manager.step}, TIMEING RAW: ", timing_raw)

            # Update the final score for all envs.
            timing_raw = {}
            with _timer('compute final score (format score, deepthink score)', timing_raw):
                buffer_manager.update_final_score()
            if global_rank == 0:
                print(timing_raw)

            idx = idx.repeat_interleave(n, dim=0)
            batch = TensorDict({"input_ids": idx}, batch_size=batch_size * n, device=idx.device)  # placeholder
            non_tensor_batch = {
                "batch_rollout_data": to_1d_np_array(buffer_manager.batch_rollout_data),
                "data_source": to_1d_np_array(total_env_names),
            }
        finally:
            # Always shut down the env workers and free the vLLM cache engine, so a failed
            # rollout does not leak env processes or GPU memory.
            try:
                if env_manager is not None:
                    env_manager.shutdown()
            finally:
                env_manager = None
                buffer_manager = None
                if self.config.free_cache_engine:
                    self.inference_engine.free_cache_engine()

        print("ROLLOUT STAGE DONE!")
        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
    
