import traceback
from abc import abstractmethod
from argparse import ArgumentParser, Namespace

from langchain.base_language import BaseLanguageModel
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatResult

from procoder.functional import format_multiple_prompts
from procoder.prompt import Module as PromptModule
from toolemu.utils import convert_to_message, parse_llm_response
from toolemu.utils.my_typing import *


class BasePromptExecutor:
    """Base class for all prompt executors.

    An executor validates and preprocesses an input dict, formats the system
    and task prompt modules into prompts, sends the merged prompt to the LLM
    and parses the LLM output. Subclasses are expected to implement
    `_preprocess_inputs` and `_parse_output`, and to provide the prompt modules
    either as class attributes or through `set_prompt_module`.

    Note that this class does not use `ABCMeta`, so the `@abstractmethod`
    markers below are declarative only and are not enforced at instantiation.

    To naturally support multithreading, the executor should be stateless.
    That is, the executor should not have any attributes that are not passed in
    through the constructor.

    args:
        llm: the language model to use
        stop_at: stop at which stage, flexible for debugging. Can be "preprocess", "prompt", "llm"
    """

    # Exact set of keys accepted in the input dict; None disables validation.
    _input_keys: List[str] = None
    _llm: BaseLanguageModel = None
    # Prompt modules for the system message and the task message.
    _sys_prompt_module: PromptModule = None
    _task_prompt_module: PromptModule = None
    # Names of the prompt modules, in the same order as `prompt_modules`.
    _prompt_module_names: List[str] = ["system", "task"]

    def __init__(
            self,
            llm: BaseLanguageModel,
            stop_at: str = None,
    ):
        self._llm = llm
        # Chat models receive a list of messages, other LLMs a single string
        # (see `_convert_to_prompt`).
        self._is_chat_model = isinstance(llm, BaseChatModel)
        self._stop_at = stop_at

    @property
    def prompt_modules(self):
        """Return the prompt modules in the order [system, task]."""
        return [self._sys_prompt_module, self._task_prompt_module]

    def set_prompt_module(self, prompt_module: PromptModule, prompt_type: str = "task"):
        """Set the prompt module of the given type.

        args:
            prompt_module: the prompt module to use
            prompt_type: either "system" or "task"

        raises:
            ValueError: if `prompt_type` is not a known prompt type
        """
        if prompt_type == "system":
            self._sys_prompt_module = prompt_module
        elif prompt_type == "task":
            self._task_prompt_module = prompt_module
        else:
            raise ValueError(f"Unknown prompt type: {prompt_type}")

    def _validate_inputs(self, inputs: Dict[str, Any]):
        """Validate that the input dict has exactly the keys in `_input_keys`.

        raises:
            ValueError: if an expected key is missing or an unexpected key is present
        """
        if self._input_keys is None:
            # skip validation if input keys are not defined
            return
        for k in self._input_keys:
            if k not in inputs:
                raise ValueError(f"Missing input key: {k}")
        for k in inputs:
            if k not in self._input_keys:
                raise ValueError(f"Unexpected input key: {k}")

    @abstractmethod
    def _preprocess_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess inputs to the format that can be used by the prompt."""
        pass

    def _preprocess_prompts(
            self, prompts: List[PromptModule], prompt_inputs: Dict[str, Any]
    ) -> List[PromptModule]:
        """Hook to adjust the prompt modules before formatting; identity by default."""
        return prompts

    def _postprocess_prompts(
            self, prompts: List[str], raw_inputs: Dict[str, Any]
    ) -> List[str]:
        """Hook to adjust the formatted prompts; identity by default."""
        return prompts

    def _process_prompts(
            self, prompt_inputs: Dict[str, Any], raw_inputs: Dict[str, Any]
    ) -> List[str]:
        """Format all prompt modules with the preprocessed inputs.

        Runs the `_preprocess_prompts` hook, formats the modules with
        `format_multiple_prompts`, then runs the `_postprocess_prompts` hook.
        """
        prompt_modules = self._preprocess_prompts(self.prompt_modules, prompt_inputs)
        # One flag per prompt module: False for the "system" module, True otherwise.
        include_brackets = [name != "system" for name in self._prompt_module_names]
        prompts = format_multiple_prompts(
            prompt_modules, prompt_inputs, include_brackets=include_brackets
        )
        return self._postprocess_prompts(prompts, raw_inputs)

    def _convert_to_prompt(
            self,
            system_message: str,
            conversation: List[str],
            message_types: List[str] = None,
    ) -> List[Union[List[BaseMessage], str]]:
        """Convert system message and conversation to prompt that can be used by LangChain LLM.

        args:
            system_message: the system message, skipped when None
            conversation: the conversation turns, in order
            message_types: the message type of each turn; when None the turns
                alternate between "human" and "ai", starting with "human"

        returns:
            a one-element list holding the prompt: a list of messages for a
            chat model, otherwise the message contents joined by blank lines
        """
        prompt = []
        if system_message is not None:
            prompt.append(convert_to_message(system_message, "system"))
        if message_types is None:
            message_types = [
                "human" if i % 2 == 0 else "ai" for i in range(len(conversation))
            ]
        for i, msg in enumerate(conversation):
            prompt.append(convert_to_message(msg, message_types[i]))

        if not self._is_chat_model:
            prompt = "\n\n".join([t.content for t in prompt])
        return [prompt]  # need to be a list for llm.generate

    def _build_conversation(
            self, task_prompt: str, raw_inputs: Dict[str, Any]
    ) -> List[str]:
        """Build conversation using task message and raw inputs.

        By default the conversation is just the task prompt.
        """
        return [task_prompt]

    def _build_prompt(
            self,
            prompt_inputs: Dict[str, Any],
            raw_inputs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Build prompt using input dict.

        returns:
            a dict with keys "prompts" (all formatted prompts), "conversation"
            (the conversation turns) and "merged_prompt" (the input for
            `llm.generate`)

        raises:
            ValueError: if the task prompt module is not defined
        """
        if self._task_prompt_module is None:
            raise ValueError("task prompt is not defined")
        prompts = self._process_prompts(prompt_inputs, raw_inputs)
        # Only the first two prompts (system, task) are used here; subclasses
        # may format additional prompts and read them from "prompts".
        sys_prompt, task_prompt = prompts[:2]
        conversation = self._build_conversation(task_prompt, raw_inputs)
        return {
            "prompts": prompts,
            "conversation": conversation,
            "merged_prompt": self._convert_to_prompt(sys_prompt, conversation),
        }

    @abstractmethod
    def _parse_output(
            self, output: ChatResult, raw_inputs: Dict[str, Any], **kwargs
    ) -> Any:
        """Parse the output of LLM to the desired format."""
        # TODO: supports LLMResult
        pass

    def _run(self, inputs: Dict[str, Any], **kwargs) -> Any:
        """Execute the prompt with LLM given inputs, then parse the output.

        Input validation is skipped when `disable_input_validation=True` is
        passed. Depending on `stop_at`, the intermediate result of the
        "preprocess", "prompt" or "llm" stage is returned instead of the
        parsed output.
        """
        if not kwargs.get("disable_input_validation", False):
            self._validate_inputs(inputs)
        prompt_inputs = self._preprocess_inputs(inputs)
        if self._stop_at == "preprocess":
            return prompt_inputs
        prompt = self._build_prompt(prompt_inputs, inputs)["merged_prompt"]
        if self._stop_at == "prompt":
            return prompt
        output = self._llm.generate(prompt)
        if self._stop_at == "llm":
            return output
        return self._parse_output(output, inputs)

    def _call_impl(self, inputs: Dict[str, Any], **kwargs) -> Any:
        """Call the executor with inputs.

        Any exception is printed with its traceback and then re-raised.
        """
        try:
            return self._run(inputs, **kwargs)
        except Exception:
            traceback.print_exc()
            # ! changed from "return None" to re-raising during refactoring
            raise

    __call__: Callable[..., Any] = _call_impl

    @classmethod
    def register_args(cls, parser: ArgumentParser):
        """Register arguments for the executor."""
        parser.add_argument(
            "--stop-at",
            "-sa",
            type=str,
            default=None,
            choices=["preprocess", "prompt", "llm"],
        )

    @classmethod
    def from_args(cls, args: Namespace, llm: BaseLanguageModel):
        """Create an executor from arguments."""
        return cls(llm, stop_at=args.stop_at)


class BasePromptExecutorWithCritique(BasePromptExecutor):
    """Base class for all prompt executors with critique.

    After the regular run, the executor performs up to `critique_rounds`
    additional rounds. In each round the previous LLM response and the
    critique prompt are appended to the conversation, and the critique LLM is
    queried again.

    args:
        llm: the language model to use
        critique_llm: the language model used for the critique rounds, defaults to `llm`
        stop_at: stop at which stage, flexible for debugging. Can be "preprocess", "prompt", "llm"
        critique_rounds: the number of critique rounds
    """

    # Prompt module for the critique message.
    _critique_prompt_module: PromptModule = None
    # Names of the prompt modules, in the same order as `prompt_modules`.
    _prompt_module_names: List[str] = ["system", "task", "critique"]

    def __init__(
            self,
            llm: BaseLanguageModel,
            critique_llm: BaseLanguageModel = None,
            stop_at: str = None,
            critique_rounds=0,
    ):
        super().__init__(llm, stop_at)
        self._critique_llm = critique_llm or llm
        self._critique_rounds = critique_rounds

    def set_prompt_module(self, prompt_module: PromptModule, prompt_type: str = "task"):
        """Set the prompt module of the given type.

        Handles the "critique" type here and delegates every other type to the
        base class, which raises ValueError for unknown types.
        """
        if prompt_type == "critique":
            self._critique_prompt_module = prompt_module
        else:
            super().set_prompt_module(prompt_module, prompt_type)

    def set_prompt_modules(
            self, prompt_module: PromptModule, prompt_type: str = "task"
    ):
        """Backward-compatible alias of `set_prompt_module`."""
        self.set_prompt_module(prompt_module, prompt_type)

    @property
    def prompt_modules(self):
        """Return the prompt modules in the order [system, task, critique]."""
        return [
            self._sys_prompt_module,
            self._task_prompt_module,
            self._critique_prompt_module,
        ]

    def _run(self, inputs: Dict[str, Any], **kwargs) -> List[Any]:
        """Execute the prompt with LLM, then run the critique rounds.

        Input validation is skipped when `disable_input_validation=True` is
        passed. Depending on `stop_at`, the intermediate result of the
        "preprocess", "prompt" or "llm" stage is returned instead.

        returns:
            a list of parsed outputs: the initial output followed by one
            output per completed critique round. A failing critique round is
            reported on stdout and ends the critique loop; the results
            collected so far are still returned.
        """
        if not kwargs.get("disable_input_validation", False):
            self._validate_inputs(inputs)
        prompt_inputs = self._preprocess_inputs(inputs)
        if self._stop_at == "preprocess":
            return prompt_inputs

        prompt_dict = self._build_prompt(prompt_inputs, inputs)
        # The task prompt is already part of the conversation.
        sys_prompt, _, critique_prompt = prompt_dict["prompts"]
        # Work on a copy so that the list built by `_build_conversation` is
        # not modified when the critique turns are appended.
        conversation = list(prompt_dict["conversation"])
        prompt = prompt_dict["merged_prompt"]
        if self._stop_at == "prompt":
            return prompt
        raw_output = self._llm.generate(prompt)
        if self._stop_at == "llm":
            return raw_output
        results = [self._parse_output(raw_output, inputs, critique_round=0)]
        for i in range(self._critique_rounds):
            try:
                if critique_prompt is None:
                    raise ValueError("critique prompt is not defined")
                conversation += [parse_llm_response(raw_output), critique_prompt]
                # Keep the merged prompt in its own variable: `critique_prompt`
                # must remain the critique text for the following rounds.
                merged_prompt = self._convert_to_prompt(sys_prompt, conversation)
                raw_output = self._critique_llm.generate(merged_prompt)
                results.append(
                    self._parse_output(raw_output, inputs, critique_round=i + 1)
                )
            except Exception as e:
                print(f"Error encountered in critique round {i + 1}: {e}")
                break
        return results

    @classmethod
    def register_args(cls, parser: ArgumentParser):
        """Register arguments for the executor."""
        super().register_args(parser)
        parser.add_argument(
            "--critique-rounds",
            "-cr",
            type=int,
            default=0,
            help="The number of critique rounds",
        )

    @classmethod
    def from_args(
            cls,
            args: Namespace,
            llm: BaseLanguageModel,
            critique_llm: BaseLanguageModel = None,
    ):
        """Create an executor from arguments."""
        return cls(
            llm,
            critique_llm,
            stop_at=args.stop_at,
            critique_rounds=args.critique_rounds,
        )
