from jinja2 import Environment
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.messages import BaseMessage, SystemMessage
import os

from src.util.logger import BaseLogger


def load_template_file(file_name: str) -> str:
    """Return the text of a prompt template stored under ``src/prompt``.

    Args:
        file_name: Path of the template relative to ``src/prompt``
            (e.g. ``"system/basic_human.txt"``). The base directory is resolved
            against the current working directory.

    Returns:
        The full file contents decoded as UTF-8.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    path = os.path.join("src/prompt", file_name)
    # `with` guarantees the handle is closed; explicit UTF-8 avoids depending on
    # the platform's locale encoding for template text.
    with open(path, "r", encoding="utf-8") as handle:
        return handle.read()


class BasePrompt(BaseLogger):
    """Chat prompt assembled from a system template file and a human template file.

    The human template is suffixed with a ``{format_instructions}`` placeholder,
    which is pre-filled (as a partial variable) with the format instructions
    produced by a ``JsonOutputParser`` built for ``output_pydantic``.
    """

    def __init__(self, output_pydantic: type[BaseModel] = BaseModel, human_template_file: str = "system/basic_human.txt",
                 system_template_file: str = "system/basic_system.txt", input_variables: list = None,
                 is_jinja: bool = False, messages: list[BaseMessage] = None):
        """Load both templates and build ``self.template``.

        Args:
            output_pydantic: Pydantic model class handed to ``JsonOutputParser``
                as ``pydantic_object``.
            human_template_file: Human template path, relative to ``src/prompt``.
            system_template_file: System template path, relative to ``src/prompt``.
                Only used to build the default system message when ``messages``
                is None (it is always read from disk).
            input_variables: Input variable names passed to ``ChatPromptTemplate``.
                Defaults to an empty list.
            is_jinja: Whether the human template should be rendered with Jinja2
                by ``update_template_with_input``.
            messages: Optional leading messages. When None, a single
                ``SystemMessage`` built from the system template is used. The
                list is copied, so the caller's list is never modified.
        """
        super().__init__()

        self.output_pydantic = output_pydantic
        # Misspelled name kept as an alias so existing code reading it keeps working.
        self.output_pyantic = output_pydantic
        self.parser = JsonOutputParser(pydantic_object=output_pydantic)

        self.human_template_file = human_template_file
        self.system_template_file = system_template_file
        self.human_message = load_template_file(self.human_template_file)
        self.system_message = load_template_file(self.system_template_file)

        # Fresh list per instance: avoids the shared-mutable-default pitfall.
        self.input_variables = [] if input_variables is None else list(input_variables)
        self.is_jinja = is_jinja

        if messages is None:
            self.messages = [SystemMessage(content=self.system_message)]
        else:
            # Copy so appending the human prompt below does not mutate the caller's list.
            self.messages = list(messages)

        # The human prompt is always the LAST message; update_template_with_input relies on this.
        self.messages.append(self._build_human_prompt())
        self.template = self._build_template()

    def update_template_with_input(self, input: dict):
        """Render the human template with Jinja2 and rebuild ``self.template``.

        Does nothing except log a warning when the prompt was not created with
        ``is_jinja=True``.

        ``self.human_message`` is overwritten with the rendered text. A second
        call therefore renders the already-rendered text, not the original
        template.

        NOTE(review): the rendered text is then passed to
        ``HumanMessagePromptTemplate.from_template``. With LangChain's default
        f-string format, literal ``{``/``}`` in rendered values would be parsed
        as placeholders. Confirm that inputs never contain braces.

        Args:
            input: Variables passed to the Jinja2 template as keyword arguments.
        """
        if not self.is_jinja:
            self.logger.warning("This prompt is not jinja template. Skipping update_template_with_input")
            return

        template = Environment().from_string(self.human_message)
        self.human_message = template.render(**input)

        # Copy so the previous ChatPromptTemplate's message list is not mutated in place.
        self.messages = list(self.template.messages)
        # The human prompt was appended last in __init__. The old hard-coded index 1
        # was only right for the default [system, human] layout: it overwrote a
        # custom message, or raised IndexError when messages=[] was passed.
        self.messages[-1] = self._build_human_prompt()
        self.template = self._build_template()

    def _build_human_prompt(self):
        """Wrap the current human text plus the format-instructions placeholder as a prompt template."""
        return HumanMessagePromptTemplate.from_template(self.human_message + "\n\n{format_instructions}")

    def _build_template(self):
        """Assemble a ChatPromptTemplate from the current messages, input variables, and parser."""
        return ChatPromptTemplate(
            messages=self.messages,
            input_variables=self.input_variables,
            output_parser=self.parser,
            partial_variables={"format_instructions": self.parser.get_format_instructions()},
        )
