# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
import os
import numpy as np
from contextlib import contextmanager
from typing import TYPE_CHECKING, List
from omegaconf import DictConfig
from tensordict import TensorDict
from verl import DataProto
from verl.workers.rollout.base import BaseRollout
from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length, pad_2d_list_to_length
from sglang.srt.entrypoints.verl_engine import VerlEngine
from torch.distributed.device_mesh import init_device_mesh
from sglang.srt.sampling.sampling_params import SamplingParams
from verl.third_party.sglang import parallel_state as sglang_ps
import torch.distributed
from torch.nn.utils.rnn import pad_sequence
from sglang.srt.utils import broadcast_pyobj, get_ip
from sglang.srt.server_args import PortArgs, ServerArgs

if TYPE_CHECKING:
    from torch import nn


# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding.
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
    """Strip the leading (left) padding from one prompt and return its token ids as a list.

    Args:
        pad_token_id: id treated as padding; only the run of it at the start is removed.
        prompt_token_ids: the token ids of a single prompt (expected 1-D).

    Returns:
        The token ids from the first non-pad token to the end, as a Python list.

    Raises:
        IndexError: if every token equals ``pad_token_id``.
    """
    is_real_token = prompt_token_ids != pad_token_id
    first_real_pos = is_real_token.nonzero(as_tuple=False)[0][0]
    return prompt_token_ids[first_real_pos:].tolist()


# NOTE(linjunrong): adhoc
def _post_process_outputs(tokenizer, output):
    """Turn SGLang generation results into right-padded token-id and log-prob tensors.

    Args:
        tokenizer: object exposing ``pad_token_id`` and ``eos_token_id``. The pad id is used
            to right-pad the token ids; it falls back to ``eos_token_id`` when it is ``None``.
        output: list of SGLang result dicts. For each entry,
            ``entry["meta_info"]["output_token_logprobs"]`` is iterated as
            ``(log_prob, token_id, _)`` triples.

    Returns:
        ``(token_ids, log_probs)``:
        - ``token_ids`` is a ``torch.long`` tensor of shape ``(len(output), max_len)``,
          right-padded with the pad id.
        - ``log_probs`` is a ``torch.float32`` tensor of the same shape, right-padded with ``0.0``.
        Both are ``(0, 0)`` tensors when ``output`` is empty.
    """
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    batched_output_token_ids = []
    batched_logprobs = []
    for response in output:
        log_probs = []
        output_token_ids = []
        for log_prob, token_id, _ in response["meta_info"]["output_token_logprobs"]:
            log_probs.append(log_prob)
            output_token_ids.append(token_id)
        # Explicit dtypes: torch.tensor([]) defaults to float32, so an empty response would
        # otherwise be float and could turn the whole padded token-id batch into floats.
        batched_output_token_ids.append(torch.tensor(output_token_ids, dtype=torch.long))
        batched_logprobs.append(torch.tensor(log_probs, dtype=torch.float32))

    if not batched_output_token_ids:
        # pad_sequence raises on an empty list of sequences.
        return torch.empty((0, 0), dtype=torch.long), torch.empty((0, 0), dtype=torch.float32)

    batched_output_token_ids = pad_sequence(batched_output_token_ids, batch_first=True, padding_value=pad_token_id)
    # Log-probs are padded with a neutral 0.0 rather than a token id.
    batched_logprobs = pad_sequence(batched_logprobs, batch_first=True, padding_value=0.0)
    return batched_output_token_ids, batched_logprobs

def _ungroup_outputs(n2outputs, n2idx):
    """Flatten per-``n`` generation groups back into the original batch order.

    Args:
        n2outputs: maps a sample count ``n`` to the outputs generated for that group. Every
            input contributes ``n`` consecutive outputs, in the order listed in ``n2idx[n]``.
        n2idx: maps ``n`` to the original batch positions of the inputs in that group.

    Returns:
        A flat list holding, for original positions 0, 1, 2, ..., that input's ``n`` outputs.

    Raises:
        AssertionError: if a group's output count is not ``len(n2idx[n]) * n``.
        KeyError: if the original positions are not exactly ``0..k-1``.
    """
    position_to_chunk = {}
    for n, group_outputs in n2outputs.items():
        group_positions = n2idx[n]
        assert len(n2outputs[n]) == len(n2idx[n]) * n, f"len(n2outputs[n]): {len(n2outputs[n])}, len(n2idx[n]): {len(n2idx[n])}, n: {n}"
        for slot, original_pos in enumerate(group_positions):
            start = slot * n
            position_to_chunk[original_pos] = group_outputs[start:start + n]

    flat = []
    for original_pos in range(len(position_to_chunk)):
        flat.extend(position_to_chunk[original_pos])
    return flat


class SGLangRollout(BaseRollout):
    """Rollout that generates responses with an SGLang ``VerlEngine`` spanning a tensor-parallel group."""

    def __init__(
        self,
        actor_module: nn.Module | str,
        config: DictConfig,
        tokenizer,
        model_hf_config,
        **kwargs,
    ):
        """A SGLang rollout. It requires the module is supported by the SGLang.

        Args:
            actor_module: module here follows huggingface APIs; it is passed to SGLang as ``model_path``
            config: DictConfig
            tokenizer: the task/model tokenizer
            model_hf_config: the huggingface config to initialize the generating model in SGLang
            **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
        """
        super().__init__()
        self.config = config

        # free_cache_engine is only allowed in eager mode (CUDA graph disabled, enforce_eager=True).
        assert not (not config.enforce_eager and
                    config.free_cache_engine), "disable CUDA graph (enforce_eager = True) if free cache engine"

        tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
        assert (tensor_parallel_size <= torch.distributed.get_world_size()
               ), "tensor parallel size should be less than or equal to the world size"

        train_tp = kwargs.get("train_tp", None)
        if train_tp is not None:
            # deployed with megatron
            os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0"
            os.environ["MEGATRON_IMPORT_TIMERS"] = "0"
            num_tp_per_train_tp = train_tp // tensor_parallel_size
            sglang_ps.initialize_parallel_state(
                tensor_model_parallel_size=tensor_parallel_size,
                num_tp_per_train_tp=num_tp_per_train_tp,
            )

        assert (model_hf_config.max_position_embeddings >= config.prompt_length +
                config.response_length), "model context length should be greater than total sequence length"

        tp_size = tensor_parallel_size
        # Take the size of the already-initialized process group (used above). The former
        # WORLD_SIZE env lookup fell back to -1, which yields an invalid mesh shape.
        world_size = torch.distributed.get_world_size()

        # init device mesh
        device_mesh_kwargs = dict(
            mesh_shape=(world_size // tp_size, tp_size, 1),
            mesh_dim_names=["dp", "tp", "pp"],
        )
        device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)

        # get tp_rank of this process in this tp group
        tp_rank = device_mesh_cpu["tp"].get_local_rank()
        # Every rank of the tp group sets CUDA_VISIBLE_DEVICES to the union of the group's values.
        visible_devices = [None] * device_mesh_cpu.size(1)
        torch.distributed.all_gather_object(visible_devices, os.environ["CUDA_VISIBLE_DEVICES"],
                                            device_mesh_cpu.get_group("tp"))
        visible_devices_set = set(visible_devices)
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(list(visible_devices_set)))

        # ceil(tp_size / number of distinct CUDA_VISIBLE_DEVICES values in the group)
        nnodes = -(-tp_size // len(visible_devices_set))
        server_args = ServerArgs(model_path=actor_module, nnodes=nnodes)
        ip, port_args = get_ip(), PortArgs.init_new(server_args)
        # All tp ranks adopt the ip/ports of the first rank in the tp mesh.
        [ip, port_args] = broadcast_pyobj([ip, port_args],
                                          rank=tp_rank,
                                          dist_group=device_mesh_cpu.get_group("tp"),
                                          src=device_mesh_cpu["tp"].mesh[0].item())
        dist_init_addr = f"{ip}:{port_args.nccl_port}"
        load_format = 'dummy' if config.load_format.startswith('dummy') else config.load_format
        self.inference_engine = VerlEngine(
            model_path=actor_module,
            dtype=config.dtype,
            mem_fraction_static=config.gpu_memory_utilization,
            device_mesh_cpu=device_mesh_cpu["tp"],
            enable_memory_saver=True,
            base_gpu_id=0,
            gpu_id_step=1,
            load_format=load_format,
            dist_init_addr=dist_init_addr,
            nnodes=nnodes,
            max_prefill_tokens=8192,
            max_running_requests=config.micro_rollout_batch_size,
            # NOTE(Chenyang): if you want to debug the sglang engine
            # please set the following parameters
            # Otherwise, it will make the engine run too slow
            # log_level="INFO",
            # log_requests=True,
            # log_requests_level=2,
            # max_running_requests=1,
        )

        # offload
        self.inference_engine.release_memory_occupation()

        sampling_kwargs = dict(n=1,
                               max_new_tokens=config.response_length,
                               presence_penalty=0.0,
                               frequency_penalty=0.0,
                               repetition_penalty=1.0)
        # supporting adding any sampling params from the config file
        # (any config key that is an attribute of SGLang's SamplingParams is forwarded)
        default_sampling_params = SamplingParams()
        for k in config.keys():
            if hasattr(default_sampling_params, str(k)):
                sampling_kwargs[k] = config.get(k)
        print(f"kwargs: {sampling_kwargs}")
        self.sampling_params = sampling_kwargs

        self.tokenizer = tokenizer
        # Same fallback as _post_process_outputs: a None pad id would break prompt unpadding
        # and response padding.
        self.pad_token_id = (tokenizer.pad_token_id
                             if tokenizer.pad_token_id is not None else tokenizer.eos_token_id)

    @contextmanager
    def update_sampling_params(self, **kwargs):
        """Temporarily override entries of ``self.sampling_params``.

        Only keys already present in ``self.sampling_params`` are overridden; other keys are
        ignored. The previous values are restored on exit, even if the ``with`` body raises.
        """
        old_sampling_params_args = {}
        for key, value in kwargs.items():
            if key in self.sampling_params:
                old_sampling_params_args[key] = self.sampling_params[key]
                self.sampling_params[key] = value
        try:
            yield
        finally:
            # roll back to previous sampling params
            self.sampling_params.update(old_sampling_params_args)

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        """Generate responses for a batch of left-padded prompts with the SGLang engine.

        Args:
            prompts: batch with tensors ``input_ids``, ``attention_mask`` and ``position_ids``.
                ``meta_info`` supplies ``eos_token_id`` and an optional ``do_sample`` flag
                (default True). ``non_tensor_batch`` may carry ``raw_prompt_ids``,
                ``multi_modal_data``, ``multi_modal_inputs`` and a per-sample ``n``. Note:
                ``raw_prompt_ids``, ``multi_modal_data`` and ``n`` are popped from it in place.
            **kwargs: sampling-parameter overrides for this call. They are replaced by greedy
                settings when ``do_sample`` is False.

        Returns:
            DataProto with ``prompts``, ``responses``, ``input_ids`` (prompt + response),
            ``attention_mask`` and ``position_ids``. Prompt rows are repeated once per
            generated response.
        """
        idx = prompts.batch["input_ids"]  # (bs, prompt_length)
        # left-padded attention_mask
        attention_mask = prompts.batch["attention_mask"]
        position_ids = prompts.batch["position_ids"]

        # used to construct attention_mask
        eos_token_id = prompts.meta_info["eos_token_id"]

        # Optional per-sample number of responses; None means every prompt uses the default n.
        sample_specific_n = prompts.non_tensor_batch.get("n", None)

        batch_size = idx.size(0)

        # Extract non-tensor data
        non_tensor_batch = prompts.non_tensor_batch
        if 'raw_prompt_ids' not in non_tensor_batch:
            non_tensor_batch['raw_prompt_ids'] = np.array(
                [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object)

        if 'multi_modal_data' in non_tensor_batch:
            sglang_inputs = []
            for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop('raw_prompt_ids'),
                                                        non_tensor_batch.pop('multi_modal_data')):
                sglang_inputs.append({
                    'prompt_token_ids': raw_prompt_ids,
                    'multi_modal_data': multi_modal_data,
                    'image_data': multi_modal_data.get('image', None) if isinstance(multi_modal_data, dict) else None
                })
        else:
            sglang_inputs = [{
                'prompt_token_ids': raw_prompt_ids
            } for raw_prompt_ids in non_tensor_batch.pop('raw_prompt_ids')]

        # Ensure token IDs are lists
        for input_data in sglang_inputs:
            if isinstance(input_data['prompt_token_ids'], np.ndarray):
                input_data['prompt_token_ids'] = input_data['prompt_token_ids'].tolist()
            elif not isinstance(input_data['prompt_token_ids'], list):
                raise TypeError(
                    f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}")

        # Extract token IDs and image data for SGLang Engine
        idx_list = [input_data['prompt_token_ids'] for input_data in sglang_inputs]
        image_list = [input_data.get('image_data', None) for input_data in sglang_inputs]

        do_sample = prompts.meta_info.get("do_sample", True)
        if not do_sample:
            kwargs = dict(
                n=1,
                presence_penalty=0.0,
                frequency_penalty=0.0,
                repetition_penalty=1.0,
                temperature=0,
                top_p=1,
                top_k=-1,
                ignore_eos=False,
                min_new_tokens=0,
                max_new_tokens=self.config.response_length,
                skip_special_tokens=True,
                spaces_between_special_tokens=True,
            )

        # Group prompts by how many responses each needs; the engine is called once per group.
        if sample_specific_n is not None:
            assert len(sample_specific_n) == len(idx_list), f"sample_specific_n: {sample_specific_n}, idx_list: {idx_list}"
            # The column may be an object array; use plain ints as group keys and as SGLang's `n`.
            per_sample_n = [int(n) for n in sample_specific_n]
            unique_n = sorted(set(per_sample_n))
            n2idx = {n: [] for n in unique_n}
            n2ids_inputs = {n: [] for n in unique_n}
            n2image_inputs = {n: [] for n in unique_n}
            for i, n in enumerate(per_sample_n):
                n2idx[n].append(i)
                n2ids_inputs[n].append(idx_list[i])
                n2image_inputs[n].append(image_list[i])
        else:
            # Prompts are only repeated config.n times below when do_sample is True, so greedy
            # decoding must request exactly one response per prompt to keep the rows aligned.
            default_n = self.config.n if do_sample else 1
            unique_n = [default_n]
            n2idx = {default_n: list(range(len(idx_list)))}
            n2ids_inputs = {default_n: idx_list}
            n2image_inputs = {default_n: image_list}

        # generate
        n2outputs = {}
        for n in unique_n:
            with self.update_sampling_params(**{**kwargs, "n": n}):
                print(f"{self.sampling_params=}")
                n2outputs[n] = self.inference_engine.generate(
                    prompt=None,  # because we have already convert it to prompt token id
                    sampling_params=self.sampling_params,
                    return_logprob=True,
                    input_ids=n2ids_inputs[n],
                    image_data=n2image_inputs[n])

        # post process: restore the original prompt order (each prompt's responses stay contiguous)
        output = _ungroup_outputs(n2outputs, n2idx)
        out = _post_process_outputs(self.tokenizer, output)

        # out[1] (sampled-token log probs) is not returned: old log probs are recomputed by the actor.
        response = out[0].to(idx.device)

        if response.shape[1] < self.config.response_length:
            response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)

        # Repeat prompt-side data so every response row lines up with its prompt.
        if sample_specific_n is not None:
            # convert numpy.object to int64; keep a NumPy copy for np.repeat, which cannot take
            # a (possibly CUDA) torch tensor as `repeats`
            repeats_np = np.array(sample_specific_n, dtype=np.int64)
            repeats = torch.tensor(repeats_np).to(idx.device)
            idx = idx.repeat_interleave(repeats, dim=0)
            attention_mask = attention_mask.repeat_interleave(repeats, dim=0)
            position_ids = position_ids.repeat_interleave(repeats, dim=0)
            batch_size = int(repeats_np.sum())
            if 'multi_modal_inputs' in non_tensor_batch:
                non_tensor_batch['multi_modal_inputs'] = np.repeat(non_tensor_batch['multi_modal_inputs'],
                                                                    repeats_np,
                                                                    axis=0)
        elif self.config.n > 1 and do_sample:
            idx = idx.repeat_interleave(self.config.n, dim=0)
            attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0)
            position_ids = position_ids.repeat_interleave(self.config.n, dim=0)
            batch_size = batch_size * self.config.n
            if 'multi_modal_inputs' in non_tensor_batch:
                non_tensor_batch['multi_modal_inputs'] = np.repeat(non_tensor_batch['multi_modal_inputs'],
                                                                    self.config.n,
                                                                    axis=0)
        seq = torch.cat([idx, response], dim=-1)

        response_length = response.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)  # (bs, response_length)
        if position_ids.dim() == 3:
            # position_ids carry an extra middle dim (bs, k, seq_len); broadcast the offsets over it.
            delta_position_id = delta_position_id.unsqueeze(1)

        # TODO(sgm): fix position_ids on right_pad
        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
        response_position_ids = position_ids[..., -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
        response_attention_mask = get_response_mask(response_id=response,
                                                    eos_token=eos_token_id,
                                                    dtype=attention_mask.dtype)
        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)

        # all the tp ranks should contain the same data here. data in all ranks are valid
        batch = TensorDict(
            {
                "prompts": idx,
                "responses": response,
                "input_ids": seq,  # here input_ids become the whole sentences
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            },
            batch_size=batch_size,
        )

        # free cache engine
        if (self.config.free_cache_engine and self.inference_engine._engine is not None and
                self.inference_engine._engine.tokenizer_manager is not None):
            self.inference_engine._engine.tokenizer_manager.flush_cache()
        # the per-sample n has been consumed; drop it from the returned non-tensor data
        non_tensor_batch.pop("n", None)
        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)