import torch
from openai.types.chat.chat_completion import ChatCompletion
from tensordict import TensorDict

from verl.protocol import DataProto
from verl.workers.rollout.async_server import ChatCompletionScheduler
import verl.utils.torch_functional as verl_F
from typing import Dict, Any, Tuple, List
from verl.tools.async_tools import ToolType
from verl.workers.rollout.vllm_rollout.token_list_serving_chat import (
    TokenListChatCompletionResponse,
)
import asyncio
import importlib.util
import uuid
import aiohttp
from functools import partial
import numpy as np


def load_dict_from_py(file_path, dict_name):
    """Execute a Python source file as a one-off module and return one of its attributes.

    The file is loaded under a freshly generated module name, so loading the
    same path several times never collides in ``sys.modules``. Note that every
    call re-executes the file's top-level code.

    Args:
        file_path: Path to the Python source file (made absolute before loading).
        dict_name: Name of the module-level attribute to return, e.g. ``"func_map"``.

    Returns:
        The attribute ``dict_name`` of the loaded module.

    Raises:
        ImportError: If no import spec/loader can be built for ``file_path``
            (for example, the path has no recognised Python suffix).
        AttributeError: If the loaded module does not define ``dict_name``.
        Exception: Whatever the file's own top-level code raises.
    """
    import os
    import sys

    # Generate a unique module name to avoid conflicts in sys.modules
    module_name = f"_tmp_mod_{uuid.uuid4().hex}"
    file_path = os.path.abspath(file_path)
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when it cannot pick a loader;
        # fail with a clear message instead of an AttributeError on None.
        raise ImportError(f"Cannot load {file_path!r} as a Python module")

    mod = importlib.util.module_from_spec(spec)
    # Register before executing so code in the file that looks itself up in
    # sys.modules can find the module.
    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)  # Runs the file's top-level code
    except BaseException:
        # Do not leave a half-initialised module registered.
        sys.modules.pop(module_name, None)
        raise

    # Get the dictionary by name (e.g., "func_map") from the loaded module
    return getattr(mod, dict_name)


def get_mark(mark_data: List[torch.Tensor], response_len: int, mark_value: int):
    """Build a ``(batch, response_len)`` int tensor with ``mark_value`` at given columns.

    Args:
        mark_data: One 1-D integer tensor per batch row holding the column
            indices to mark in that row. Tensors may be empty.
        response_len: Number of columns of the returned tensor.
        mark_value: Value written at every marked position; all other
            positions are 0.

    Returns:
        A ``torch.int`` tensor of shape ``(len(mark_data), response_len)``.
        Column indices outside ``[0, response_len)`` are silently ignored.
    """
    mark = torch.zeros((len(mark_data), response_len), dtype=torch.int)
    if not mark_data:
        # torch.cat rejects an empty list; an empty batch has nothing to mark.
        return mark

    # Row index of every column index: row b is repeated once per entry of
    # mark_data[b].
    counts = torch.tensor([x.size(0) for x in mark_data], dtype=torch.long)
    row_index = torch.repeat_interleave(torch.arange(len(mark_data)), counts)
    # Long indices on the same device as `mark` are accepted by every torch
    # version for advanced indexing.
    col_index = torch.cat(mark_data).to(device=mark.device, dtype=torch.long)

    in_range = (col_index >= 0) & (col_index < response_len)
    mark[row_index[in_range], col_index[in_range]] = mark_value
    return mark


def get_tool_response_mask(
    tool_response_start: List[torch.tensor],
    tool_response_end: List[torch.tensor],
    response_len: int,
) -> torch.Tensor:
    """Return a per-token mask that is 0 inside tool responses and 1 elsewhere.

    A ``+1`` is placed at every start index and a ``-1`` at every end index;
    the running sum along the sequence is then 1 from each start up to (but
    not including) the matching end, and 0 elsewhere. The result is one minus
    that running sum. This assumes starts and ends come in matched,
    non-nested pairs per row.

    Args:
        tool_response_start: Per-row tensors of start indices (inclusive).
        tool_response_end: Per-row tensors of end indices (exclusive).
        response_len: Number of columns of the mask.

    Returns:
        Integer tensor of shape ``(batch, response_len)``.
    """
    opens = get_mark(tool_response_start, response_len, 1)
    closes = get_mark(tool_response_end, response_len, -1)
    inside_tool_response = torch.cumsum(opens + closes, dim=-1)
    return 1 - inside_tool_response


def get_overlong_filtering_mask(batch_inputs):
    """Flag which rollouts finished every turn with ``"stop"``.

    Args:
        batch_inputs: Iterable of per-sample dicts, each with a
            ``"finish_reason"`` entry listing the finish reason of every turn.

    Returns:
        A list of booleans, one per sample: ``True`` when no turn reported a
        finish reason other than ``"stop"`` (also for an empty list),
        ``False`` otherwise.
    """
    return [
        not any(reason != "stop" for reason in sample["finish_reason"])
        for sample in batch_inputs
    ]


class ToolChatCompletionScheduler(ChatCompletionScheduler):
    def __init__(
        self,
        config,
        model_path,
        server_addresses,
        tool_app_addresses: Dict = None,
        **kwargs,
    ):
        """Set up tool routing state, stop tokens and tool helper callbacks.

        Args:
            config: Rollout config. Read here: ``extra_eos_token``,
                ``tools.<tool_type.value>.eos_token`` and
                ``tools.tool_custom_utility_path``.
            model_path: Forwarded unchanged to the parent scheduler.
            server_addresses: Forwarded unchanged to the parent scheduler.
            tool_app_addresses: Mapping from tool type to the addresses of the
                services running that tool. Keys must expose ``.value``.
                ``None`` (the default) means no tools.
            **kwargs: Forwarded unchanged to the parent scheduler.
        """
        super().__init__(config, model_path, server_addresses, **kwargs)

        # A `{}` default would be one dict object shared by every instance
        # created without the argument; create a fresh one instead.
        if tool_app_addresses is None:
            tool_app_addresses = {}

        self.tool_app_addresses = tool_app_addresses
        self.tool_types = list(tool_app_addresses.keys())
        # Despite the name, this is keyed by tool type: it holds the index of
        # the next address to use for each tool.
        self.tool_address_to_round_robin_index = {
            ta: 0 for ta in self.tool_app_addresses
        }

        # Extra stop strings: the optional global one plus each tool's EOS
        # token, so generation halts as soon as a tool call is complete.
        extra_eos_tokens = []
        if config.extra_eos_token is not None:
            extra_eos_tokens.append(config.extra_eos_token)
        for tool_type in self.tool_types:
            extra_eos_tokens.append(getattr(config.tools, tool_type.value).eos_token)
        self.extra_eos_tokens = extra_eos_tokens

        # User-supplied helpers, both read from the same utility file.
        self.resp_parsers = load_dict_from_py(
            config.tools.tool_custom_utility_path, "RESP_PARSERS"
        )
        self.tool_call_res_encapsulation = load_dict_from_py(
            config.tools.tool_custom_utility_path, "RESP_ENCAPSULATION"
        )

    def parse_tool(self, content: str) -> Tuple[ToolType, Dict]:
        for tool_type in self.tool_types:
            tool_parser = self.resp_parsers[tool_type]
            tool_type, tool_args = tool_parser(content)
            if tool_type is not None:
                return tool_type, tool_args
        return None, None

    async def tool_execution(
        self,
        tool_type: ToolType,
        tool_args: Dict,
        retries: int = 5,
        last_token: int = 0,
    ) -> Tuple[List[int], Any]:
        """POST a tool call to one of the tool's services and encapsulate the reply.

        Addresses registered for ``tool_type`` are used in round-robin order;
        every attempt (successful or not) advances to the next address.

        Args:
            tool_type: Tool to invoke; ``tool_type.value`` selects the URL path
                and the config section.
            tool_args: JSON-serialisable request body.
            retries: Maximum number of attempts.
            last_token: Forwarded to the encapsulation function.

        Returns:
            Whatever the tool's ``RESP_ENCAPSULATION`` function returns. The
            caller in this file unpacks it as ``(token_ids, message)`` —
            presumably a list of token ids and the tool message content;
            verify against the utility file.

        Raises:
            RuntimeError: If all attempts failed; the last error is chained
                as ``__cause__``.
        """
        addresses = self.tool_app_addresses[tool_type]
        # Loop-invariant: bind the encapsulation arguments once.
        encapsulate = partial(
            self.tool_call_res_encapsulation[tool_type],
            tokenizer=self.tokenizer,
            max_len=getattr(self.config.tools, tool_type.value).max_tool_resp_len,
            last_token=last_token,
        )

        last_error = None
        for attempt in range(retries):
            # Pick the address and advance the cursor *before* awaiting.
            # Updating it only after the response arrives would let every
            # concurrently running call read the same index and hit the same
            # service.
            index = self.tool_address_to_round_robin_index[tool_type]
            self.tool_address_to_round_robin_index[tool_type] = (index + 1) % len(
                addresses
            )
            url = f"http://{addresses[index]}/tool/{tool_type.value}"
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(url=url, json=tool_args) as resp:
                        response = await resp.json()
                return encapsulate(response)
            except Exception as e:
                # Best effort: log and retry on the next address.
                last_error = e
                print(f"REQUEST ERROR: [{attempt}, {e}]")

        raise RuntimeError("Failed tool call!") from last_error

    def _split_thought_and_tool(self, tool_type, response: str) -> Tuple[str]:
        tool_bos_token = getattr(self.config.tools, tool_type.value).bos_token

        idx = response.rfind(tool_bos_token)
        if idx == -1:
            raise ValueError(f"Invalid response: {response}")
        else:
            return (response[:idx], response[idx:])

    async def generate_sequences(
        self, batch: DataProto, **sampling_params
    ) -> DataProto:
        """Run multi-turn, tool-augmented generation for every prompt in ``batch``.

        Each prompt is submitted via ``submit_chat_completions``. After every
        completion the callback appends the generated tokens to the sample's
        running token list and then either stops (answer parsed, no tool call
        parsed, request failed, or turn budget used up) or executes the tool,
        appends the encapsulated tool output and resubmits the extended token
        list for the next turn.

        Args:
            batch: Input prompts. Reads ``meta_info["do_sample"]``,
                ``meta_info["validate"]``, ``batch["input_ids"]``,
                ``batch["attention_mask"]`` and, if present,
                ``non_tensor_batch["raw_prompt"]``. Unless validating, the
                batch is repeated ``self.config.n`` times.
            **sampling_params: Overrides for the default sampling arguments.

        Returns:
            The ``DataProto`` assembled by ``_postprocess``.
        """
        # NOTE: Since tools may cause repeated invocations, we force `n=1`
        # in such cases and then repeat the batch `n` times ourselves.
        do_sample = batch.meta_info.get("do_sample", True)
        is_validate = batch.meta_info.get("validate", False)

        if not is_validate:
            batch = batch.repeat(self.config.n)
        kwargs = dict(
            n=1,
            max_completion_tokens=self.config.response_length,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            include_stop_str_in_output=True,
            stop=self.extra_eos_tokens + [self.tokenizer.eos_token],
            skip_special_tokens=False,  # NOTE: we include eos here. Hence, we do not add any eos manually when apply_chat_template
        )

        if not do_sample or is_validate:
            # Greedy decoding (`n` is already 1).
            kwargs["temperature"] = 0

        kwargs.update(sampling_params)
        print(
            f"[ToolChatCompletionScheduler] generate_sequences sampling params: {kwargs}"
        )

        max_turns = self.config.tools.max_turns

        async def callback(
            completions: TokenListChatCompletionResponse,
            info: Dict[str, Any],
            exception: Exception,
        ):
            """Handle one finished completion for one sample; may start the next turn."""
            batch_inputs, batch_index, turn = (
                info["batch_inputs"],
                info["batch_index"],
                info["turn"],
            )
            # Per-sample rollout state; mutated in place.
            state = batch_inputs[batch_index]

            if exception is not None:
                print(
                    f"vllm exception: {exception}, {self.tokenizer.decode(state['prompt_token_ids'])}"
                )
                state["finish_reason"].append("exception")
                return

            response_token_ids, finish_reason, content = (
                completions.choices[0].response_token_ids,
                completions.choices[0].finish_reason,
                completions.choices[0].content,
            )

            # NOTE(review): this is read *before* the new response tokens are
            # appended, so it is the last token of the previous context, not
            # of this turn's response — confirm this is what the
            # encapsulation function expects.
            last_token = state["prompt_token_ids"][-1]
            state["prompt_token_ids"] += response_token_ids
            state["finish_reason"].append(finish_reason)
            if "messages" in state:
                state["messages"].append({"role": "assistant", "content": content})

            # STEP 1: check if we got answer
            answer = self.resp_parsers["answer"](content)
            if answer:
                print(f"[id={completions.id},turn={turn}] Got answer: {answer}, done!")
                return

            # STEP 2: check if we met a tool call
            tool_type, tool_args = self.parse_tool(content)
            if not tool_type:
                print(
                    f"[id={completions.id},turn={turn}] No tool found, finish reason: {finish_reason}, done!"
                )
                return

            # STEP 3: call tool and splice its output into the token stream.
            # The [start, end) offsets are recorded so the tool tokens can be
            # masked out later.
            result, tool_resp_msg = await self.tool_execution(
                tool_type, tool_args, last_token=last_token
            )
            state["tool_response_start"].append(len(state["prompt_token_ids"]))
            state["prompt_token_ids"] += result
            state["tool_response_end"].append(len(state["prompt_token_ids"]))

            if "messages" in state:
                # split the tool call response
                # NOTE(review): the assistant message content becomes a
                # (thought, tool_call) tuple here rather than a string.
                state["messages"][-1]["content"] = self._split_thought_and_tool(
                    tool_type,
                    state["messages"][-1]["content"],
                )
                state["messages"].append({"role": "tool", "content": tool_resp_msg})

            print(
                f"[id={completions.id},turn={turn}] {tool_type.value} executed, continue..."
            )

            # STEP 4: check if we reach max turns
            # `>=` also terminates when max_turns < 1, where `==` would never
            # fire; a None limit keeps meaning "no limit".
            if max_turns is not None and turn >= max_turns:
                print(
                    f"[id={completions.id},turn={turn}] Reach max turns {max_turns}, done!"
                )
                return

            # STEP 5: resubmit chat completions with code block output
            extra_headers = {"x-request-id": completions.id}
            await self.submit_chat_completions(
                callback=callback,
                callback_additional_info={
                    "batch_inputs": batch_inputs,
                    "batch_index": batch_index,
                    "turn": turn + 1,
                },
                prompt_token_ids=state["prompt_token_ids"],
                extra_headers=extra_headers,
                **kwargs,
            )

        tasks, batch_inputs = [], [None] * len(batch)
        input_ids = batch.batch["input_ids"].tolist()
        prompt_lens = batch.batch["attention_mask"].sum(dim=-1).tolist()
        for batch_index, (row, prompt_len) in enumerate(zip(input_ids, prompt_lens)):
            # remove padding (prompts are left padded, so the real tokens are
            # the trailing `prompt_len` ones). Slicing from an explicit start
            # keeps an empty prompt empty; `row[-0:]` would return the whole
            # padded row.
            prompt_token_ids = row[len(row) - prompt_len :]
            batch_inputs[batch_index] = {
                "prompt_token_ids": prompt_token_ids,
                "tool_response_start": [],
                "tool_response_end": [],
                "finish_reason": [],
            }

            if "raw_prompt" in batch.non_tensor_batch:
                batch_inputs[batch_index]["messages"] = list(
                    batch.non_tensor_batch["raw_prompt"][batch_index]
                )

            tasks.append(
                asyncio.create_task(
                    self.submit_chat_completions(
                        callback=callback,
                        callback_additional_info={
                            "batch_inputs": batch_inputs,
                            "batch_index": batch_index,
                            "turn": 1,
                        },
                        prompt_token_ids=prompt_token_ids,
                        **kwargs,
                    )
                )
            )
        await asyncio.gather(*tasks)
        print("[ToolChatCompletionScheduler] generate_sequences done")

        return self._postprocess(batch, batch_inputs)

    def _postprocess(
        self,
        batch: DataProto,
        batch_inputs: List[Dict[str, Any]],
    ) -> DataProto:
        """Pack the per-sample rollout state into a padded ``DataProto``.

        Args:
            batch: The (already repeated) input batch; its ``attention_mask``
                gives each prompt's length and its ``raw_prompt``, if any, is
                passed through.
            batch_inputs: One dict per sample with ``"prompt_token_ids"``
                (prompt followed by all generated and tool tokens),
                ``"tool_response_start"`` / ``"tool_response_end"`` (absolute
                offsets into that list), ``"finish_reason"`` and optionally
                ``"messages"``.

        Returns:
            A ``DataProto`` with tensors ``prompts``, ``responses``,
            ``input_ids``, ``attention_mask``, ``loss_mask``,
            ``position_ids``, ``tool_call_num`` and non-tensor entries
            ``tool_call_positions`` plus, when available, ``messages`` and
            ``raw_prompt``.
        """
        # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
        # prompts: left pad
        # responses: right pad
        # input_ids: prompt + response
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,0,0,0,0,0]
        #   (padded positions are zeroed by the multiplication with the mask)

        def _as_object_array(items):
            # np.array(list_of_sequences, dtype=object) builds a 2-D array
            # whenever all sequences happen to share a length and a 1-D array
            # otherwise. Filling a preallocated 1-D array always yields one
            # object per sample.
            arr = np.empty(len(items), dtype=object)
            for i, item in enumerate(items):
                arr[i] = item
            return arr

        tool_call_num = [len(example["tool_response_start"]) for example in batch_inputs]

        if "raw_prompt" in batch.non_tensor_batch:
            raw_prompt = batch.non_tensor_batch["raw_prompt"]
        else:
            raw_prompt = None

        prompt_lens = batch.batch["attention_mask"].sum(dim=-1).tolist()
        (
            prompts_unpad,
            responses_unpad,
            tool_call_positions,
            tool_response_start,
            tool_response_end,
            messages,
        ) = ([], [], [], [], [], [])
        for inputs, prompt_len in zip(batch_inputs, prompt_lens):
            token_ids = inputs["prompt_token_ids"]
            prompts_unpad.append(token_ids[:prompt_len])
            responses_unpad.append(token_ids[prompt_len:])
            # Shift absolute offsets so they index into the response part.
            tool_response_start.append(
                torch.IntTensor(inputs["tool_response_start"]) - prompt_len
            )
            # Index of the token right before each tool response.
            tool_call_positions.append(tool_response_start[-1] - 1)
            tool_response_end.append(
                torch.IntTensor(inputs["tool_response_end"]) - prompt_len
            )

            if "messages" in inputs:
                messages.append(inputs["messages"])

        prompts, prompts_attention_mask = verl_F.pad_2d_list_to_length(
            prompts_unpad,
            self.tokenizer.pad_token_id,
            left_pad=True,
            return_attention_mask=True,
        )
        responses, responses_attention_mask = verl_F.pad_2d_list_to_length(
            responses_unpad,
            self.tokenizer.pad_token_id,
            left_pad=False,
            return_attention_mask=True,
        )
        input_ids = torch.cat([prompts, responses], dim=-1)
        attention_mask = torch.cat(
            [prompts_attention_mask, responses_attention_mask], dim=-1
        )
        position_ids = (torch.cumsum(attention_mask, dim=-1) - 1) * attention_mask

        # construct loss mask for responses
        # mask out tool responses
        tool_response_mask = get_tool_response_mask(
            tool_response_start, tool_response_end, responses.shape[1]
        )
        loss_mask = responses_attention_mask * tool_response_mask

        # overlong filtering: zero the whole loss mask of every rollout that
        # had a turn finishing with anything other than "stop"
        if self.config.tools.apply_overlong_filtering:
            overlong_filtering_mask = get_overlong_filtering_mask(batch_inputs)
            print(
                f"[{len(overlong_filtering_mask) - sum(overlong_filtering_mask)}/{len(overlong_filtering_mask)}] rollouts are invalid!"
            )
            loss_mask = loss_mask * torch.tensor(
                overlong_filtering_mask,
                dtype=loss_mask.dtype,
                device=loss_mask.device,
            ).unsqueeze(-1)

        batch = TensorDict(
            {
                "prompts": prompts,
                "responses": responses,
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "loss_mask": loss_mask,
                "position_ids": position_ids,
                "tool_call_num": torch.LongTensor(tool_call_num),
            },
            batch_size=len(input_ids),
        )
        non_tensor_batch = {
            "tool_call_positions": _as_object_array(tool_call_positions),
        }

        if messages:
            non_tensor_batch["messages"] = _as_object_array(messages)
        if raw_prompt is not None:
            non_tensor_batch["raw_prompt"] = raw_prompt

        return DataProto(
            batch=batch,
            non_tensor_batch=non_tensor_batch,
        )
