from dotenv import load_dotenv
import getpass
import logging
import os
import subprocess
import sys
import tiktoken
from typing import Literal, TypeVar

from langchain.output_parsers import OutputFixingParser
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
from langchain_core.messages import AIMessage, ToolMessage, HumanMessage
from langchain_experimental.utilities import PythonREPL
from langchain_openai import ChatOpenAI


from src.util.logger import BaseLogger
from src.prompt.prompt import BasePrompt

# flake8: noqa
from langchain_core.prompts.prompt import PromptTemplate

# Text of the "please repair your answer" prompt handed to the fixing model.
# It exposes three placeholders:
#   {instructions} - the constraints the completion was supposed to satisfy
#   {input}        - the completion that failed
#   {error}        - the error produced while handling that completion
# NOTE(review): the placeholders are presumably filled in by OutputFixingParser
# when a parse fails - confirm against the installed LangChain version.
# The string is sent to the model verbatim, so do not edit it for cosmetic reasons.
NAIVE_JSON_FIX = """Instructions:
<Instruction>
{instructions}
</Instruction>

<Completion>
{input}
</Completion>

Above, the Completion did not satisfy the constraints given in the Instructions.
<Error>
{error}
</Error>

<Recommendation>
- Check if you get any errors when parsing Json.
- I recommend checking if you have problems escaping special characters.
</Recommendation>

Please try again. 

Please only respond with an answer that satisfies the constraints laid out in the Instructions:"""

# Prompt template built from the text above; passed as `prompt=` to
# OutputFixingParser.from_llm in ToolAndFixingParser.
NAIVE_FIX_PROMPT = PromptTemplate.from_template(NAIVE_JSON_FIX)


def ToolAndFixingParser(parser, fixing_model):
    """Build a callable that parses an AI message and repairs bad output via an LLM.

    Args:
        parser: The output parser that defines the expected structure. It is
            wrapped in an ``OutputFixingParser``.
        fixing_model: The model asked to rewrite a completion that ``parser``
            rejected. It is piped into a function that extracts ``.content``,
            so it must support ``|`` composition and return a message object.

    Returns:
        A function ``output_parser(ai_message)`` that returns the parsed result
        with the original message stored under the ``"ai_message"`` key, or
        ``None`` when every attempt yields an empty result or a list.

    Note:
        Exceptions raised by the wrapped parser are not caught here and
        propagate to the caller.
    """
    # How many times the same completion is (re-)parsed before giving up.
    max_parse_attempts = 3

    def message_to_str(message: AIMessage):
        # The fixing parser is given plain text, so unwrap the model's reply.
        return message.content

    # Neither the fixing chain nor the wrapped parser depends on the message
    # being parsed, so build them once instead of on every call.
    llm = fixing_model | message_to_str
    real_parser = OutputFixingParser.from_llm(parser=parser, llm=llm, prompt=NAIVE_FIX_PROMPT, max_retries=3)

    def output_parser(ai_message: AIMessage):
        res = None
        for _ in range(max_parse_attempts):
            res = real_parser.parse(ai_message.content)
            # A usable result is non-empty and not a list; it is indexed by
            # string key below.
            if res and not isinstance(res, list):
                break
        else:
            # No attempt produced a usable result.
            return None
        # Keep the raw message alongside the parsed payload for the caller.
        res["ai_message"] = ai_message
        return res

    return output_parser


class BaseLLM(BaseLogger):
    """Base wrapper that chains a prompt template, a model and a self-repairing parser.

    The chain is ``prompt.template | model | final_parser``, where
    ``final_parser`` comes from ``ToolAndFixingParser`` and uses a separate
    OpenAI chat model to repair completions the prompt's parser rejects.
    """

    name = "base"
    vendor = "base"
    # Runs once, when the class body is evaluated: loads variables from a .env
    # file into os.environ, overriding values that are already set.
    load_dotenv(override=True)

    def __init__(self,  model, prompt: BasePrompt | None,  log_level: int = logging.WARNING, log_file: str | None = None, langsmith: bool = True, temperature: float = 1.0):
        """Configure credentials, logging and the processing chain.

        Args:
            model: The primary model placed between the prompt template and
                the parser in the chain.
            prompt: Supplies ``template``, ``parser`` and ``is_jinja``.
                NOTE(review): annotated as optional, but ``None`` fails below
                when ``prompt.parser`` is read - confirm whether ``None`` is
                ever passed.
            log_level: Level applied to the file handler when ``log_file`` is
                given.
            log_file: Optional path; when set, log records are also written
                to this file.
            langsmith: When true, make sure ``LANGSMITH_API_KEY`` is present in
                the environment, prompting on the terminal as a last resort.
            temperature: Stored on the instance; not otherwise used in the
                code visible here.
        """
        if langsmith and not os.getenv("LANGSMITH_API_KEY"):
            self.logger.info("LANGSMITH_API_KEY not set. Loading from .env file")
            # Actually retry the .env file, as the message above says. No
            # override here, so variables set after import are left alone.
            load_dotenv()
            if not os.getenv("LANGSMITH_API_KEY"):
                self.logger.warning("LANGSMITH_API_KEY not found in .env file. Please set it manually.")
                os.environ["LANGSMITH_API_KEY"] = getpass.getpass("Enter Langsmith API Key: ")

        if log_file:
            # Attach exactly one file handler; a second identical handler
            # would write every record to the file twice.
            fh = logging.FileHandler(log_file)
            fh.setLevel(log_level)
            self.logger.addHandler(fh)

        self.model = model
        self.prompt = prompt
        self.temperature = temperature

        # Secondary model used only to repair completions that fail parsing.
        # NOTE(review): api_key is an empty string - presumably the real key
        # is expected from the environment; confirm how ChatOpenAI treats "".
        self.fixing_model = ChatOpenAI(model="gpt-4o-mini-2024-07-18", temperature=0.1, max_retries=2, api_key="", request_timeout=30)
        self.final_parser = ToolAndFixingParser(self.prompt.parser, self.fixing_model)

        self.chain = self.prompt.template | self.model | self.final_parser

    def invoke(self, inputs: dict):
        """Run the chain on ``inputs``, retrying once if the result is empty.

        Args:
            inputs: Values passed to the chain (and to the prompt's template
                update when the prompt is Jinja-based).

        Returns:
            The first truthy chain result, or ``None`` if both attempts
            produce a falsy result.
        """
        if self.prompt.is_jinja:
            self.prompt.update_template_with_input(inputs)
        # Rebuilt on every call so the chain uses the prompt's current template.
        # NOTE(review): presumably update_template_with_input replaces
        # prompt.template - confirm in BasePrompt.
        self.chain = self.prompt.template | self.model | self.final_parser
        for _ in range(2):
            ret = self.chain.invoke(inputs)
            if ret:
                return ret
        # Both attempts came back empty.
        return None
