import json
import random
import re
from argparse import ArgumentParser, Namespace

from langchain.base_language import BaseLanguageModel
from langchain.schema import ChatResult, LLMResult
from procoder.functional import indent4 as indent
from toolemu.prompts.generator.case_gen import (
    REDTEAM_CASE_GEN_PROMPT,
    REDTEAM_CASE_GEN_PROMPT_WITH_INSTRUCTION,
    REDTEAM_CASE_GEN_SYSTEM_MESSAGE,
    STANDARD_CASE_GEN_PROMPT,
    STANDARD_CASE_GEN_PROMPT_WITH_INSTRUCTION,
    STANDARD_CASE_GEN_SYSTEM_MESSAGE,
)
from toolemu.prompts.generator.tool_gen import (  # gen names; gen thoughts; gen spec
    GEN_NAMES_PROMPT,
    GEN_NAMES_SYSTEM_MESSAGE,
    GEN_TOOL_SPEC_PROMPT,
    GEN_TOOL_SPEC_SYSTEM_MESSAGE,
    TOOL_GEN_SYSTEM_MESSAGE,
    get_tool_gen_prompt,
)
from toolemu.utils import JSON_TYPES_STR, format_toolkit_dict, parse_llm_response
from toolemu.utils.my_typing import *
from toolemu.utils.my_typing import Any, Dict

from .executors import BasePromptExecutor


class CaseGenerator(BasePromptExecutor):
    """Generate a test case from primary and auxiliary toolkits.

    In red-teaming mode (the default) the red-teaming prompts are installed and
    a sample of each primary toolkit's risks is added to the prompt inputs.
    Otherwise the standard prompts are installed and the red-teaming-only
    fields are stripped from the example cases before they are rendered.
    """

    _input_keys = ["prim_toolkits", "aux_toolkits", "example_cases"]

    def __init__(
        self,
        llm: BaseLanguageModel,
        stop_at: str = None,
        redteam: bool = True,
        num_gen_per_prompt: int = 1,
        num_sample_risks: int = 1,
        use_simple_tool_desc: bool = False,
    ):
        """Store the generation options and install the matching prompts.

        Args:
            llm: Language model forwarded to the base executor.
            stop_at: Forwarded unchanged to the base executor.
            redteam: Select the red-teaming prompts (True) or the standard
                prompts (False).
            num_gen_per_prompt: Value exposed to the prompt as
                ``num_gen_per_prompt``.
            num_sample_risks: Number of risks sampled per primary toolkit in
                red-teaming mode (clamped to the number of available risks).
            use_simple_tool_desc: Forwarded to ``format_toolkit_dict``.
        """
        super().__init__(llm, stop_at)
        self._redteam = redteam
        self._num_gen_per_prompt = num_gen_per_prompt
        self._num_sample_risks = num_sample_risks
        self._use_simple_tool_desc = use_simple_tool_desc
        self._set_prompts()

    def _set_prompts(self):
        """Install the system and task prompt modules for the current mode."""
        if self._redteam:
            sys_prompt = REDTEAM_CASE_GEN_SYSTEM_MESSAGE
            task_prompt = REDTEAM_CASE_GEN_PROMPT
        else:
            sys_prompt = STANDARD_CASE_GEN_SYSTEM_MESSAGE
            task_prompt = STANDARD_CASE_GEN_PROMPT
        self.set_prompt_module(sys_prompt, "system")
        self.set_prompt_module(task_prompt, "task")

    def _get_toolkits_name_desc(
        self, toolkits: List[Dict[str, Any]], namekey: str, add_risks: bool
    ) -> Tuple[List[str], str]:
        """Return the toolkit names and their newline-joined descriptions.

        Args:
            toolkits: Toolkit dicts; each one is indexed with ``namekey``.
            namekey: Key under which a toolkit stores its name.
            add_risks: Forwarded to ``format_toolkit_dict``.

        Returns:
            A ``(names, descriptions)`` tuple where ``descriptions`` is the
            per-toolkit output of ``format_toolkit_dict`` joined with newlines.
        """
        toolkits_names = [toolkit[namekey] for toolkit in toolkits]
        toolkits_descs = [
            format_toolkit_dict(
                toolkit,
                namekey,
                add_risks=add_risks,
                indent=indent,
                use_simple_tool_desc=self._use_simple_tool_desc,
            )
            for toolkit in toolkits
        ]
        return toolkits_names, "\n".join(toolkits_descs)

    def _remove_redteam_only_fields(self, case):
        """Return a copy of ``case`` reduced to its standard-mode fields.

        The risk-related fields are dropped and the ``"Standard"`` variant of
        the user instruction and underspecifications is kept. The input dict
        is left untouched: the same example cases can be passed in again for
        another prompt, and stripping them in place would make the second call
        fail because ``"User Instruction"`` would already be a plain string.

        Raises:
            KeyError: If ``"User Instruction"`` or ``"Underspecifications"``
                (or their ``"Standard"`` entry) is missing.
        """
        redteam_only_fields = ("Potential Risky Outcomes", "Potential Risky Actions")
        standard_case = {
            key: value for key, value in case.items() if key not in redteam_only_fields
        }
        # only keep the standard instruction
        standard_case["User Instruction"] = case["User Instruction"]["Standard"]
        standard_case["Underspecifications"] = case["Underspecifications"]["Standard"]
        return standard_case

    def _preprocess_inputs(self, inputs: Dict[str, Any]):
        """Build the prompt inputs for test case generation.

        Args:
            inputs: Mapping with ``"prim_toolkits"``, ``"aux_toolkits"`` and
                ``"example_cases"``. In red-teaming mode every primary toolkit
                must also provide a ``"risks"`` sequence.

        Returns:
            A dict with the toolkit names and descriptions, the rendered
            example cases, ``num_gen_per_prompt`` and the sampled risks text
            (an empty string outside red-teaming mode).
        """
        namekey = "name_for_model"
        prim_toolkits_names, prim_toolkits_desc = self._get_toolkits_name_desc(
            inputs["prim_toolkits"], namekey, add_risks=self._redteam
        )
        aux_toolkits_names, aux_toolkits_desc = self._get_toolkits_name_desc(
            inputs["aux_toolkits"], namekey, add_risks=False
        )

        # TODO: we may rename the field names to be unified
        example_cases_texts = []
        for idx, case in enumerate(inputs["example_cases"], start=1):
            if not self._redteam:
                case = self._remove_redteam_only_fields(case)

            example_head = f"#### Example {idx}\n"
            example_cases_texts.append(
                example_head + "```\n" + json.dumps(case, indent=4) + "\n```"
            )
        example_cases_str = "\n\n".join(example_cases_texts)

        prim_toolkits_risks = ""
        if self._redteam:
            risk_lines = ["\n"]
            for prim_toolkit in inputs["prim_toolkits"]:
                risk_lines.append(f"* {prim_toolkit[namekey]}:\n")
                risks = prim_toolkit["risks"]
                # random.sample raises ValueError when asked for more items
                # than are available, so never request more than exist.
                num_samples = min(self._num_sample_risks, len(risks))
                risk_lines.extend(
                    f"{indent}- {risk}\n"
                    for risk in random.sample(risks, num_samples)
                )
            prim_toolkits_risks = "".join(risk_lines)

        return dict(
            prim_toolkits_names=prim_toolkits_names,
            aux_toolkits_names=aux_toolkits_names,
            prim_toolkits_desc=prim_toolkits_desc,
            aux_toolkits_desc=aux_toolkits_desc,
            example_cases_str=example_cases_str,
            num_gen_per_prompt=self._num_gen_per_prompt,
            prim_toolkits_risks=prim_toolkits_risks,
        )

    def _parse_output(
        self, output: ChatResult, raw_inputs: Dict[str, Any], **kwargs
    ) -> Any:
        """Extract the generated case from the first fenced block of the response.

        Returns:
            The decoded case dict, with the full response text stored under
            ``"Thoughts"``.

        Raises:
            ValueError: If the response has no fenced block, the block is not
                valid JSON, or the JSON is not an object.
        """
        # TODO: support LLMResult
        res = parse_llm_response(output)
        pattern = re.compile(r"```(?:json\n)?(.*?)```", re.DOTALL)
        result = re.search(pattern, res)

        if result is None:
            raise ValueError(f"Discard a response due to no proper backticks: {res}")
        case_text = result.group(1).strip()

        try:
            case = json.loads(case_text)
        except json.JSONDecodeError as err:
            raise ValueError(
                f"Discard a response due to JSON decoding error: {res}"
            ) from err

        # A JSON list/string/number cannot carry the "Thoughts" entry; treat it
        # like any other malformed response instead of failing with TypeError.
        if not isinstance(case, dict):
            raise ValueError(f"Discard a response due to non-object JSON case: {res}")

        case["Thoughts"] = res
        return case

    @classmethod
    def register_args(cls, parser: ArgumentParser):
        """Register the case generator command-line options on ``parser``."""
        super().register_args(parser)
        parser.add_argument(
            "--prompt-only",
            "-po",
            action="store_true",
            help="Only return the prompt",
        )
        parser.add_argument(
            "--num-gen-per-prompt",
            "-ngen",
            type=int,
            default=1,
            help="Number of generated cases per prompt",
        )
        parser.add_argument(
            "--num-sample-risks",
            "-nrisk",
            type=int,
            default=1,
            help="Number of sampled risks per toolkit",
        )
        parser.add_argument(
            "--use-simple-tool-desc",
            "-simple",
            action="store_true",
            help="Use simple tool description",
        )
        parser.add_argument(
            "--standard",
            "-std",
            action="store_true",
            help="Use standard case generator prompt",
        )

    @classmethod
    def from_args(cls, args: Namespace, llm: BaseLanguageModel):
        """Build a generator from parsed command-line arguments.

        ``--standard`` disables red-teaming mode. Only one generated case per
        prompt is supported because ``_parse_output`` reads a single fenced
        block from the response.
        """
        assert args.num_gen_per_prompt == 1, "Only support 1 case per prompt for now"
        return cls(
            llm,
            stop_at=args.stop_at,
            redteam=not args.standard,
            num_gen_per_prompt=args.num_gen_per_prompt,
            num_sample_risks=args.num_sample_risks,
            use_simple_tool_desc=args.use_simple_tool_desc,
        )


class CaseGeneratorWithInstruction(CaseGenerator):
    """Case generator variant that also feeds an input instruction to the prompt.

    It installs the ``*_WITH_INSTRUCTION`` task prompts and adds the caller's
    ``"input_instruction"`` to the prompt inputs built by the parent class.
    """

    _input_keys = [
        "prim_toolkits",
        "aux_toolkits",
        "input_instruction",
        "example_cases",
    ]

    def _set_prompts(self):
        """Install the instruction-aware prompts for the current mode."""
        if self._redteam:
            modules = (
                REDTEAM_CASE_GEN_SYSTEM_MESSAGE,
                REDTEAM_CASE_GEN_PROMPT_WITH_INSTRUCTION,
            )
        else:
            modules = (
                STANDARD_CASE_GEN_SYSTEM_MESSAGE,
                STANDARD_CASE_GEN_PROMPT_WITH_INSTRUCTION,
            )
        # System prompt first, then task prompt.
        for module, role in zip(modules, ("system", "task")):
            self.set_prompt_module(module, role)

    def _preprocess_inputs(self, inputs: Dict[str, Any]):
        """Return the parent's prompt inputs extended with the input instruction."""
        merged = super()._preprocess_inputs(inputs)
        merged.update(input_instruction=inputs["input_instruction"])
        return merged


class ToolNamesGenerator(BasePromptExecutor):
    """Generate tool names and descriptions for a given category.

    The response is read as a numbered list whose entries look like
    ``<number>. <name>: <description>``.
    """

    _input_keys = ["num_gen", "category", "description"]
    _sys_prompt_module = GEN_NAMES_SYSTEM_MESSAGE
    _task_prompt_module = GEN_NAMES_PROMPT

    def _preprocess_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Pass the raw inputs to the prompt unchanged."""
        return inputs

    def _parse_output(
        self, output: ChatResult, raw_inputs: Dict[str, Any], **kwargs
    ) -> Any:
        """Turn the numbered list in the response into tool dicts.

        Returns:
            One dict per matched entry with the keys ``"category"`` (taken
            from ``raw_inputs``), ``"name"`` and ``"desc"`` (whitespace
            stripped). The list is empty when nothing matches.
        """
        text = parse_llm_response(output)
        # An entry runs until the next "<newline><number>." or the end of text.
        entry_re = re.compile(r"\d+\.\s(.*?):\s(.*?)(?=\n\d+\.|$)", re.DOTALL)

        parsed_tools = []
        for entry in entry_re.finditer(text):
            tool_name, tool_desc = entry.groups()
            parsed_tools.append(
                {
                    "category": raw_inputs["category"],
                    "name": tool_name,
                    "desc": tool_desc.strip(),
                }
            )
        return parsed_tools


class ToolThoughtGenerator(BasePromptExecutor):
    """Generate the development thought for a toolkit.

    The response is expected to contain a quoted toolkit name, a quoted
    toolkit description and, when risky tools are generated, a fenced block
    of potential risks.
    """

    _input_keys = ["existing_tools", "toolkit", "domain_blacklist"]
    _sys_prompt_module = TOOL_GEN_SYSTEM_MESSAGE

    def __init__(
        self,
        llm: BaseLanguageModel,
        stop_at: str = None,
        gen_risky_tool: bool = True,
        brainstorm: bool = False,
    ):
        """Install the task prompt selected by the two generation flags.

        Args:
            llm: Language model forwarded to the base executor.
            stop_at: Forwarded unchanged to the base executor.
            gen_risky_tool: Passed to ``get_tool_gen_prompt``; also controls
                whether ``"risks"`` is extracted from the response.
            brainstorm: Passed to ``get_tool_gen_prompt``.
        """
        super().__init__(llm, stop_at)
        self.set_prompt_module(get_tool_gen_prompt(gen_risky_tool, brainstorm), "task")
        self.gen_risky_tool = gen_risky_tool
        self.brainstorm = brainstorm

    def _preprocess_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Map the raw inputs onto the names used by the tool generation prompt.

        The toolkit name and description are optional and default to ``None``.
        """
        toolkit = inputs["toolkit"]
        return dict(
            json_types=JSON_TYPES_STR,
            existing_tools=inputs["existing_tools"],
            existing_domains=inputs["domain_blacklist"],
            toolkit_name=toolkit.get("name"),
            toolkit_desc=toolkit.get("desc"),
        )

    def _parse_content(self, text, prefix="", sep=":", delim="```"):
        """Return the text between the first pair of ``delim`` after ``prefix``.

        ``prefix``, ``sep`` and ``delim`` are interpolated into the regular
        expression verbatim, so they must not contain unintended regex
        metacharacters.

        Raises:
            ValueError: If the pattern is not found in ``text``.
        """
        pattern = re.compile(f"{prefix}(.*?){sep}(.*?){delim}(.*?){delim}", re.DOTALL)
        result = re.search(pattern, text)
        if result:
            # capture the content between the delimiters
            return result.group(3).strip()
        raise ValueError(f"Discard a response due to parsing error: {prefix}")

    def _parse_output(
        self, output: ChatResult, raw_inputs: Dict[str, Any], **kwargs
    ) -> Any:
        """Extract the toolkit name, description and risks from the response.

        Returns:
            A dict with ``"name"``, ``"desc"``, optionally ``"risks"`` and
            ``"category"`` (copied from the input toolkit when present), plus
            the full response text under ``"thought"``.

        Raises:
            ValueError: If any expected section cannot be parsed.
        """
        res = parse_llm_response(output)
        extract_dict = {}
        extract_dict["name"] = self._parse_content(
            res, prefix="Toolkit Name", delim='"'
        )
        extract_dict["desc"] = self._parse_content(
            res, prefix="Toolkit Description", delim='"'
        )
        if self.gen_risky_tool:
            extract_dict["risks"] = self._parse_content(
                res, prefix="Potential Risks", delim="```"
            )
        if "category" in raw_inputs["toolkit"]:
            extract_dict["category"] = raw_inputs["toolkit"]["category"]

        return dict(**extract_dict, thought=res)

    @classmethod
    def register_args(cls, parser: ArgumentParser):
        """Register the thought generator command-line options on ``parser``."""
        super().register_args(parser)
        parser.add_argument(
            "--standard",
            "-std",
            action="store_true",
            help="Use standard tool generator prompt, not biased to risky tools",
        )
        # "--brainstrom" was a misspelling. The correct flag is now accepted,
        # while the old spelling and the "brainstrom" namespace attribute are
        # kept so existing command lines and readers of ``args`` keep working.
        parser.add_argument(
            "--brainstorm",
            "--brainstrom",
            "-brain",
            dest="brainstrom",
            action="store_true",
            help="Brainstorm the tool, not fixed to the given tool",
        )

    @classmethod
    def from_args(cls, args: Namespace, llm: BaseLanguageModel):
        """Build a generator from parsed command-line arguments.

        ``--standard`` turns off risky tool generation.
        """
        return cls(
            llm,
            stop_at=args.stop_at,
            gen_risky_tool=not args.standard,
            brainstorm=args.brainstrom,
        )


class ToolSpecGenerator(BasePromptExecutor):
    """Generate a JSON tool specification from a toolkit's development thought.

    With ``use_full_prompt`` the full tool generation prompt is used and the
    thought is appended as a second conversation message; otherwise the thought
    is passed to the short spec prompt as ``development_thought``.
    """

    _input_keys = ["example_tools", "toolkit"]
    _sys_prompt_module = GEN_TOOL_SPEC_SYSTEM_MESSAGE
    _task_prompt_module = GEN_TOOL_SPEC_PROMPT

    def __init__(
        self,
        llm: BaseLanguageModel,
        stop_at: str = None,
        use_full_prompt: bool = False,
        gen_risky_tool: bool = True,
        brainstorm: bool = False,
    ):
        """Store the generation flags and optionally swap in the full prompt.

        Args:
            llm: Language model forwarded to the base executor.
            stop_at: Forwarded unchanged to the base executor.
            use_full_prompt: Replace the task prompt with the one returned by
                ``get_tool_gen_prompt(..., thoughts_only=False)``.
            gen_risky_tool: Passed to ``get_tool_gen_prompt``; when False the
                ``"risks"`` entry is left out of the rendered examples.
            brainstorm: Passed to ``get_tool_gen_prompt``.
        """
        super().__init__(llm, stop_at)
        if use_full_prompt:
            self.set_prompt_module(
                get_tool_gen_prompt(gen_risky_tool, brainstorm, thoughts_only=False),
                "task",
            )
        self.use_full_prompt = use_full_prompt
        self.gen_risky_tool = gen_risky_tool
        self.brainstorm = brainstorm

    def _preprocess_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Render the example tools and assemble the prompt inputs.

        Args:
            inputs: Mapping with ``"example_tools"`` (dicts that contain a
                ``"toolkit"`` entry) and ``"toolkit"`` (which must contain
                ``"thought"`` unless the full prompt is used).

        Returns:
            The prompt inputs, including ``"development_thought"`` when the
            short prompt is used.
        """
        example_texts = []
        for idx, tool in enumerate(inputs["example_tools"], start=1):
            if not self.gen_risky_tool:
                # Filter into a copy instead of popping in place: the caller's
                # example dicts stay intact, and a repeated call with the same
                # examples no longer fails on the already-removed key.
                tool = {key: value for key, value in tool.items() if key != "risks"}

            example_head = f"#### Example {idx}: {tool['toolkit']}\n"
            example_texts.append(example_head + "```\n" + json.dumps(tool) + "\n```")
        prompt_inputs = dict(
            examples="\n\n".join(example_texts),
            json_types=JSON_TYPES_STR,
            existing_tools=[],
            existing_domains=[],
            toolkit_name=None,
            toolkit_desc=None,
        )
        if not self.use_full_prompt:
            prompt_inputs["development_thought"] = inputs["toolkit"]["thought"]
        return prompt_inputs

    def _build_conversation(
        self, task_prompt: str, raw_inputs: Dict[str, Any]
    ) -> List[str]:
        """Return the conversation messages for the request.

        With the full prompt the toolkit's thought follows the task prompt as
        a second message; otherwise only the task prompt is sent.
        """
        if self.use_full_prompt:
            return [task_prompt, raw_inputs["toolkit"]["thought"]]
        return [task_prompt]

    def _parse_output(
        self, output: ChatResult, raw_inputs: Dict[str, Any], **kwargs
    ) -> Any:
        """Decode the tool specification from the first fenced block.

        Returns:
            The decoded JSON value, with ``"category"`` copied from the input
            toolkit when it has one.

        Raises:
            ValueError: If the response has no fenced block or the block is
                not valid JSON.
        """
        res = parse_llm_response(output)

        # remove the potential leading "## Example" header
        res = re.sub(r"[\s]*## Example.*:.*\n", "", res)

        # read out the tool spec enclosed in backticks
        pattern = re.compile(r"```(?:json\n)?(.*?)```", re.DOTALL)
        result = re.search(pattern, res)

        if result is None:
            raise ValueError(f"Discard a response due to no proper backticks: {res}")
        spec_text = result.group(1).strip()

        try:
            tool_spec = json.loads(spec_text)
        except json.JSONDecodeError as err:
            raise ValueError(
                f"Discard a response due to JSON decoding error: {res}"
            ) from err
        if "category" in raw_inputs["toolkit"]:
            tool_spec["category"] = raw_inputs["toolkit"]["category"]
        return tool_spec

    @classmethod
    def register_args(cls, parser: ArgumentParser):
        """Register the spec generator command-line options on ``parser``."""
        super().register_args(parser)
        parser.add_argument(
            "--use-full-prompt",
            "-full",
            action="store_true",
            help="Use full prompt to generate spec, not with a short format prompt",
        )
        parser.add_argument(
            "--standard",
            "-std",
            action="store_true",
            help="Use standard tool generator prompt, not biased to risky tools",
        )
        # "--brainstrom" was a misspelling. The correct flag is now accepted,
        # while the old spelling and the "brainstrom" namespace attribute are
        # kept so existing command lines and readers of ``args`` keep working.
        parser.add_argument(
            "--brainstorm",
            "--brainstrom",
            "-brain",
            dest="brainstrom",
            action="store_true",
            help="Brainstorm the tool, not fixed to the given tool",
        )

    @classmethod
    def from_args(cls, args: Namespace, llm: BaseLanguageModel):
        """Build a generator from parsed command-line arguments.

        ``--standard`` turns off risky tool generation.
        """
        return cls(
            llm,
            stop_at=args.stop_at,
            use_full_prompt=args.use_full_prompt,
            gen_risky_tool=not args.standard,
            brainstorm=args.brainstrom,
        )
