
from __future__ import annotations
import os
import numpy as np
from contextlib import contextmanager
from typing import TYPE_CHECKING, List
from omegaconf import DictConfig
from tensordict import TensorDict
from verl import DataProto
from verl.workers.rollout.base import BaseRollout
from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length, pad_2d_list_to_length
from sglang.srt.entrypoints.verl_engine import VerlEngine
from torch.distributed.device_mesh import init_device_mesh
from sglang.srt.sampling.sampling_params import SamplingParams
from verl.third_party.sglang import parallel_state as sglang_ps
import torch.distributed
from torch.nn.utils.rnn import pad_sequence
from sglang.srt.utils import broadcast_pyobj, get_ip
from sglang.srt.server_args import PortArgs, ServerArgs

if TYPE_CHECKING:
    from torch import nn



def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
    """Drop leading pad tokens from a prompt and return the remainder as a list.

    Everything from the first non-pad token onward is kept; pad tokens that
    appear after it are NOT removed. A prompt that consists only of pad tokens
    (or is empty) yields ``[]`` instead of raising an opaque ``IndexError``.

    Args:
        pad_token_id: token id treated as padding.
        prompt_token_ids: token ids of a single prompt (assumed 1-D).

    Returns:
        The token ids starting at the first non-pad position.
    """
    non_pad_positions = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)
    if non_pad_positions.numel() == 0:
        # All padding (or empty input): nothing survives the strip.
        return []
    first_non_pad = int(non_pad_positions[0][0])
    return prompt_token_ids[first_non_pad:].tolist()



def _post_process_outputs(tokenizer, output):
    """Convert SGLang generation results into right-padded id / log-prob tensors.

    Args:
        tokenizer: provides ``pad_token_id``; when that is ``None``,
            ``eos_token_id`` is used as the padding value instead.
        output: sequence of per-request results, each holding
            ``meta_info["output_token_logprobs"]`` as an iterable of
            ``(log_prob, token_id, _)`` triples.

    Returns:
        ``(token_ids, log_probs)``: a ``(batch, max_len)`` ``torch.long`` tensor
        and a ``(batch, max_len)`` floating tensor, both right-padded with the
        pad token id. An empty ``output`` yields two ``(0, 0)`` tensors.
    """
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    if len(output) == 0:
        # pad_sequence rejects an empty list; return well-typed empty batches.
        return torch.empty((0, 0), dtype=torch.long), torch.empty((0, 0))

    token_id_rows = []
    log_prob_rows = []
    for item in output:
        log_probs = []
        token_ids = []
        for log_prob, token_id, _ in item["meta_info"]["output_token_logprobs"]:
            log_probs.append(log_prob)
            token_ids.append(token_id)
        # Explicit long dtype: torch.tensor([]) would be float, and pad_sequence
        # allocates its output with the first row's dtype, so one empty response
        # could otherwise turn the whole id batch into floats.
        token_id_rows.append(torch.tensor(token_ids, dtype=torch.long))
        log_prob_rows.append(torch.tensor(log_probs))

    batched_output_token_ids = pad_sequence(token_id_rows, batch_first=True, padding_value=pad_token_id)
    # NOTE(review): log-probs are padded with the pad *token id* (kept for
    # compatibility); padded positions must be masked by the consumer.
    batched_logprobs = pad_sequence(log_prob_rows, batch_first=True, padding_value=pad_token_id)
    return batched_output_token_ids, batched_logprobs


class SGLangRollout(BaseRollout):
    """Rollout worker that generates responses with an SGLang ``VerlEngine``.

    The engine is created over the "tp" sub-mesh of a CPU device mesh shaped
    ``(dp, tp, pp=1)``. ``generate_sequences`` returns prompts, responses,
    per-token log-probs, attention mask and position ids as a ``DataProto``.
    """

    def __init__(
        self,
        actor_module: nn.Module | str,
        config: DictConfig,
        tokenizer,
        model_hf_config,
        **kwargs,
    ):
        """Build the SGLang engine and the default sampling parameters.

        Args:
            actor_module: forwarded unchanged to SGLang as ``model_path``.
            config: rollout config. Reads ``enforce_eager``,
                ``free_cache_engine``, ``tensor_model_parallel_size``,
                ``prompt_length``, ``response_length``, ``dtype`` and
                ``gpu_memory_utilization``; every key that is also an attribute
                of SGLang's ``SamplingParams`` is copied into the default
                sampling parameters.
            tokenizer: provides ``pad_token_id`` / ``eos_token_id``.
            model_hf_config: only ``max_position_embeddings`` is read.
            **kwargs: optional ``train_tp``; when given, SGLang's parallel state
                is initialised with ``train_tp // tensor_model_parallel_size``
                inference TP groups per training TP group.
        """
        super().__init__()
        self.config = config

        # free_cache_engine requires CUDA graphs to be disabled, i.e. enforce_eager=True.
        assert not (not config.enforce_eager and
                    config.free_cache_engine), "disable CUDA graph (enforce_eager = True) if free cache engine"

        tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
        # Use the initialised default process group for the world size so the
        # device mesh below always matches it (no env-var "-1" fallback).
        world_size = torch.distributed.get_world_size()
        assert tensor_parallel_size <= world_size, \
            "tensor parallel size should be less than or equal to the world size"

        train_tp = kwargs.get("train_tp", None)
        if train_tp is not None:
            os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0"
            os.environ["MEGATRON_IMPORT_TIMERS"] = "0"
            sglang_ps.initialize_parallel_state(
                tensor_model_parallel_size=tensor_parallel_size,
                num_tp_per_train_tp=train_tp // tensor_parallel_size,
            )

        assert (model_hf_config.max_position_embeddings >= config.prompt_length +
                config.response_length), "model context length should be greater than total sequence length"

        tp_size = tensor_parallel_size
        # CPU mesh (dp, tp, pp) with pipeline parallelism fixed to 1.
        device_mesh_cpu = init_device_mesh(
            "cpu",
            mesh_shape=(world_size // tp_size, tp_size, 1),
            mesh_dim_names=["dp", "tp", "pp"],
        )

        # Gather every TP rank's CUDA_VISIBLE_DEVICES and expose their union to
        # all ranks of the group.
        tp_rank = device_mesh_cpu["tp"].get_local_rank()
        visible_devices = [None] * device_mesh_cpu.size(1)
        torch.distributed.all_gather_object(visible_devices, os.environ["CUDA_VISIBLE_DEVICES"],
                                            device_mesh_cpu.get_group("tp"))
        visible_devices_set = set(visible_devices)
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(visible_devices_set))

        # ceil(tp_size / number of distinct device strings); presumably the
        # number of nodes one TP group spans — confirm.
        nnodes = -(-tp_size // len(visible_devices_set))
        server_args = ServerArgs(model_path=actor_module, nnodes=nnodes)
        ip, port_args = get_ip(), PortArgs.init_new(server_args)
        # Broadcast the address/ports from the first rank of the TP mesh so all
        # TP ranks agree on the same dist_init_addr.
        [ip, port_args] = broadcast_pyobj([ip, port_args],
                                          rank=tp_rank,
                                          dist_group=device_mesh_cpu.get_group("tp"),
                                          src=device_mesh_cpu["tp"].mesh[0].item())
        dist_init_addr = f"{ip}:{port_args.nccl_port}"

        self.inference_engine = VerlEngine(
            model_path=actor_module,
            dtype=config.dtype,
            mem_fraction_static=config.gpu_memory_utilization,
            device_mesh_cpu=device_mesh_cpu["tp"],
            enable_memory_saver=True,
            base_gpu_id=0,
            gpu_id_step=1,
            dist_init_addr=dist_init_addr,
            nnodes=nnodes,
        )
        # NOTE(review): memory is released right after construction, presumably
        # to be resumed elsewhere before generation — confirm with the caller.
        self.inference_engine.release_memory_occupation()

        sampling_kwargs = dict(n=1,
                               max_new_tokens=config.response_length,
                               presence_penalty=0.0,
                               frequency_penalty=0.0,
                               repetition_penalty=1.0)
        # Copy every config key that SGLang's SamplingParams also defines.
        reference_params = SamplingParams()
        for k in config.keys():
            if hasattr(reference_params, str(k)):
                sampling_kwargs[k] = config.get(k)
        print(f"kwargs: {sampling_kwargs}")
        self.sampling_params = sampling_kwargs

        self.tokenizer = tokenizer
        # Fall back to EOS when the tokenizer has no pad token (same rule as
        # _post_process_outputs) so None is never used as a padding value.
        self.pad_token_id = (tokenizer.pad_token_id
                             if tokenizer.pad_token_id is not None else tokenizer.eos_token_id)

    @contextmanager
    def update_sampling_params(self, **kwargs):
        """Temporarily override entries of ``self.sampling_params``.

        Only keys already present in ``self.sampling_params`` are overridden;
        others are ignored. The previous values are restored on exit, including
        when the body raises.
        """
        old_sampling_params_args = {}
        for key, value in kwargs.items():
            if key in self.sampling_params:
                old_sampling_params_args[key] = self.sampling_params[key]
                self.sampling_params[key] = value
        try:
            yield
        finally:
            self.sampling_params.update(old_sampling_params_args)

    def _build_engine_inputs(self, idx, non_tensor_batch):
        """Pop prompt ids (and optional multi-modal data) into engine-ready lists.

        If ``raw_prompt_ids`` is missing it is computed from ``idx`` by stripping
        leading pad tokens. ``raw_prompt_ids`` and, when present,
        ``multi_modal_data`` are popped from ``non_tensor_batch``.

        Returns:
            ``(idx_list, image_list)``: one ``list`` of token ids and one image
            entry (``None`` when absent) per prompt.

        Raises:
            TypeError: if a prompt's token ids are neither a list nor an ndarray.
        """
        if 'raw_prompt_ids' not in non_tensor_batch:
            non_tensor_batch['raw_prompt_ids'] = np.array(
                [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(idx.size(0))], dtype=object)

        raw_prompt_ids = non_tensor_batch.pop('raw_prompt_ids')
        if 'multi_modal_data' in non_tensor_batch:
            pairs = zip(raw_prompt_ids, non_tensor_batch.pop('multi_modal_data'))
        else:
            pairs = ((prompt_ids, None) for prompt_ids in raw_prompt_ids)

        idx_list = []
        image_list = []
        for prompt_ids, multi_modal_data in pairs:
            # np.array(..., dtype=object) yields ndarray rows when all prompts
            # have equal length; the engine receives plain lists.
            if isinstance(prompt_ids, np.ndarray):
                prompt_ids = prompt_ids.tolist()
            elif not isinstance(prompt_ids, list):
                raise TypeError(f"prompt_token_ids must be a list or numpy array, got {type(prompt_ids)}")
            idx_list.append(prompt_ids)
            image_list.append(multi_modal_data.get('image', None) if isinstance(multi_modal_data, dict) else None)
        return idx_list, image_list

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        """Generate responses for a batch of prompts.

        Args:
            prompts: batch with tensors ``input_ids``, ``attention_mask`` and
                ``position_ids``, plus ``meta_info["eos_token_id"]``. Optional
                ``meta_info["do_sample"]`` (default True); when False, greedy
                settings replace ``kwargs``. ``raw_prompt_ids`` and
                ``multi_modal_data`` are popped from ``non_tensor_batch``.
            **kwargs: temporary sampling overrides; only keys already present in
                ``self.sampling_params`` take effect.

        Returns:
            ``DataProto`` whose batch holds ``prompts``, ``responses``,
            ``input_ids`` (prompt + response), ``response_log_probs``,
            ``attention_mask`` and ``position_ids``.
        """
        idx = prompts.batch["input_ids"]
        attention_mask = prompts.batch["attention_mask"]
        position_ids = prompts.batch["position_ids"]
        eos_token_id = prompts.meta_info["eos_token_id"]
        batch_size = idx.size(0)

        non_tensor_batch = prompts.non_tensor_batch
        idx_list, image_list = self._build_engine_inputs(idx, non_tensor_batch)

        do_sample = prompts.meta_info.get("do_sample", True)
        sampling_overrides = kwargs
        if not do_sample:
            sampling_overrides = dict(
                n=1,
                presence_penalty=0.0,
                frequency_penalty=0.0,
                repetition_penalty=1.0,
                temperature=0,
                top_p=1,
                top_k=-1,
                ignore_eos=False,
                min_new_tokens=0,
                max_new_tokens=self.config.response_length,
                skip_special_tokens=True,
                spaces_between_special_tokens=True,
            )

        with self.update_sampling_params(**sampling_overrides):
            output = self.inference_engine.generate(
                prompt=None,
                sampling_params=self.sampling_params,
                return_logprob=True,
                input_ids=idx_list,
                image_data=image_list)

        response, log_probs = _post_process_outputs(self.tokenizer, output)
        response = response.to(idx.device)
        log_probs = log_probs.to(idx.device)

        # Right-pad responses (and log-probs) to the configured response length.
        if response.shape[1] < self.config.response_length:
            response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
            log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)

        if self.config.n > 1 and do_sample:
            # Repeat every prompt-side tensor n times, each copy adjacent to its source.
            idx = idx.repeat_interleave(self.config.n, dim=0)
            attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0)
            position_ids = position_ids.repeat_interleave(self.config.n, dim=0)
            batch_size = batch_size * self.config.n
            if 'multi_modal_inputs' in non_tensor_batch:
                non_tensor_batch['multi_modal_inputs'] = np.repeat(non_tensor_batch['multi_modal_inputs'],
                                                                   self.config.n,
                                                                   axis=0)
        seq = torch.cat([idx, response], dim=-1)

        response_length = response.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
        if position_ids.dim() == 3:
            # 3-D position ids (presumably (batch, rope_sections, seq), e.g.
            # mrope — confirm): offset every section by the same amount.
            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(-1, position_ids.size(1), -1)
        # Response positions continue from the last prompt position.
        response_position_ids = position_ids[..., -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
        response_attention_mask = get_response_mask(response_id=response,
                                                    eos_token=eos_token_id,
                                                    dtype=attention_mask.dtype)
        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)

        batch = TensorDict(
            {
                "prompts": idx,
                "responses": response,
                "input_ids": seq,
                'response_log_probs': log_probs,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            },
            batch_size=batch_size,
        )

        if (self.config.free_cache_engine and self.inference_engine._engine is not None and
                self.inference_engine._engine.tokenizer_manager is not None):
            self.inference_engine._engine.tokenizer_manager.flush_cache()

        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
