import os

from pydantic import BaseModel, Field
from voyager.agents.llm import get_llm
import voyager.utils as U
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.schema import HumanMessage, SystemMessage
from voyager.prompts import load_prompt
from langchain.prompts import ChatPromptTemplate
from langchain.chains.openai_functions import create_structured_output_runnable
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate



from typing import List, Dict

# Schema for a single question/answer pair.
# KnowledgeManager.transform_message_into_qa passes this class to
# create_structured_output_runnable as the structured-output model.
# NOTE(review): the docstring and Field descriptions are kept byte-identical
# because they may be surfaced to the LLM as part of the schema -- confirm
# before rewording them.
class QAItem(BaseModel):
    """Model for a QA Item"""
    # Question text; KnowledgeManager.add_new_qa_pair uses it as the storage key.
    question: str = Field(..., description="The question part of the QA pair")
    # Answer text stored alongside the question.
    answer: str = Field(..., description="The answer part of the QA pair")

# Wrapper schema holding several QAItem entries.
# KnowledgeManager.generate_qa_from_beliefs passes this class to
# create_structured_output_runnable and reads `qa_list` from the result.
# NOTE(review): docstring/description text may be part of the schema shown to
# the LLM, so it is left unchanged -- confirm before rewording.
class QAList(BaseModel):
    """Model for a list of QA Items"""
    # The generated QA pairs.
    qa_list: List[QAItem] = Field(..., description="A list of QA items")

# Schema pairing a question identifier with a keep/discard flag.
# Not referenced by any other code visible in this file.
# NOTE(review): the docstring and the `question_id` description read like
# copies of QAItem's. KnowledgeManager.add_new_qa_pair uses the question text
# itself as the id, so the wording may be intentional -- confirm before
# changing, since this text may be surfaced to the LLM as part of the schema.
class QAItemCandidate(BaseModel):
    """Model for a QA Item"""
    # Identifier of the candidate question (typed as a string).
    question_id: str = Field(..., description="The question part of the QA pair")
    # True when the candidate should be kept.
    should_be_selected: bool = Field(..., description="Whether the QA pair should be selected")

# Wrapper schema holding several QAItemCandidate entries.
# Not referenced by any other code visible in this file.
# NOTE(review): the docstring repeats QAList's wording; left unchanged because
# it may be surfaced to the LLM as part of the schema -- confirm.
class QAListCandidate(BaseModel):
    """Model for a list of QA Items"""
    # The candidate entries with their selection flags.
    qa_list: List[QAItemCandidate] = Field(..., description="A list of candidate QA items")


# Schema carrying the ids of the questions that were selected.
# Not referenced by any other code visible in this file.
# NOTE(review): ids are integers here, while QAItemCandidate.question_id is a
# string, and this model uses the builtin `list[int]` where the other models
# use `typing.List` -- confirm which convention is intended.
class QAItemSelected(BaseModel):
    """Model for a QA Item"""
    # Integer ids of the chosen questions.
    selected_questions_id: list[int] = Field(..., description="The selected questions id")


class KnowledgeManager:
    def __init__(
        self,
        logger,
        model_name="gpt-3.5-turbo",
        temperature=0,
        retrieval_top_k=5,
        request_timeout=120,
        ckpt_dir="ckpt",
        resume=True,
    ):

        self.logger = logger

        self.logger.info(f"arguments: model_name={model_name}, temperature={temperature}, retrieval_top_k={retrieval_top_k}, request_timeout={request_timeout}, ckpt_dir={ckpt_dir}, resume={resume}")

        self.ckpt_dir = ckpt_dir
        self.logger.info(f"Checkpoint directory: {ckpt_dir}")

        U.f_mkdir(f"{ckpt_dir}/qa")

        # Vector database for QA pairs
        self.vectordb = Chroma(
            collection_name="qa_vectordb",
            embedding_function=OpenAIEmbeddings(),
            persist_directory=f"{ckpt_dir}/qa/vectordb",
        )
        print("resume", resume)
        print("os.path.exists(f'{ckpt_dir}/qa/qa_pairs.json')", os.path.exists(f"{ckpt_dir}/qa/qa_pairs.json"))

        if resume and os.path.exists(f"{ckpt_dir}/qa/qa_pairs.json"):

            print(f"\033[33mLoading QA Interaction Manager from {ckpt_dir}/qa\033[0m")
            self.qa_pairs = U.load_json(f"{ckpt_dir}/qa/qa_pairs.json")
        else:
            self.qa_pairs = {}

        self.logger.info(f"QA pairs: {self.qa_pairs}")

        assert self.vectordb._collection.count() == len(self.qa_pairs.keys()), (
            "QA Interaction Manager's vectordb is not synced with qa_pairs.json.\n"
            f"There are {self.vectordb._collection.count()} entries in vectordb but {len(self.qa_pairs)} entries in qa_pairs.json.\n"
            "Did you set resume=False when initializing the manager?\n"
            "You may need to manually delete the vectordb directory for running from scratch."
        )
        self.retrieval_top_k = retrieval_top_k
        self.llm = get_llm(model_name)

        self.print_database()

    def save_belief(self, belief: QAItem, task):
        self.add_new_qa_pair(belief.question, belief.answer, task)

    def process_beliefs(self, beliefs, task):
        # print in red we are consolidating interactions into long term memory
        print("\033[31mConsolidating interactions into long term memory\033[0m")

        qas = self.generate_qa_from_beliefs(beliefs)
        print("qas", qas)
        for qa in qas:
            self.add_new_qa_pair(qa.question, qa.answer, task)

    def format_qa(self, qa_list):
        formatted_str = ""
        for qa in qa_list:
            formatted_str += f"Question: {qa['question']}\nAnswer: {qa['answer']}\n\n"
        return formatted_str

    def is_context_correct_based_on_beliefs(self, context, task, partner_beliefs=None):

        qas = self.retrieve_qa_pairs(task)
        if partner_beliefs is not None:
            qas += partner_beliefs
        if qas == []:
            formated_qas = "No beliefs found"
        else:
            formated_qas = self.format_qa(qas)

        messages = [
            SystemMessage(content=load_prompt("interaction_is_context_correct")),
            HumanMessage(
                content=f"""
                AI generated context (to evaluate if it is true or flase):\n
                {context}\n
                ---\n
                Beliefs:\n
                {formated_qas}""",
            ),
        ]

        self.logger.info(f"PROMPT IS CONTEXT CORRECT: {messages[1].content}\n\n")

        # print("Messages:", messages)

        output = self.llm(messages).content

        # check if message contains "False" or "True" if contains both return False be case incensitive
        if "true" in output.lower() and "false" not in output.lower():
            result = True

        else:
            result = False

        self.logger.info(f" Is context correct based on QA? {result}")

        return result



    def update_context(self, context, task, partner_beliefs=None):
        # correctly format the partner beliefs to match the other ones used to update the context
        if partner_beliefs is not None:
            partner_beliefs = [{"question": belief.question, "answer": belief.answer} for belief in partner_beliefs]

        new_context = self.update_context_based_on_beliefs(context, task, partner_beliefs)

        # NOTE: the partner beliefs that come from the conversation are not added here
        new_context = self.add_relevant_qa_pairs_to_context(new_context, task)

        return new_context

    def add_relevant_qa_pairs_to_context(self, correct_context, task, partner_help=None):
        qas = self.retrieve_qa_pairs(task)
        if partner_help is not None:
            qas += partner_help
        if qas == []:
            return correct_context

        # TODO remove duplicates, use langchain to do this
        # add id to each qa pair

        # remove QAS with the same question, remove duplicates

        print(correct_context)

        question = self.extract_question_from_context(correct_context)


        # filter all qas that have the same question
        unique_questions = set()
        unique_questions.add(question)




        filtered_qas = []

        for qa in qas:
            if qa['question'] not in unique_questions:
                unique_questions.add(qa['question'])
                filtered_qas.append(qa)

        if filtered_qas == []:
            return correct_context

        formated_qas = self.format_qa(filtered_qas)
        new_context_with_qas = f"{correct_context}\n\nFrom memory:\n{formated_qas}"



        return new_context_with_qas


    def store_qa_pairs(self, qa_list, task):
        for qa in qa_list:
            print(qa.question, qa.answer)
            self.add_new_qa_pair(qa.question, qa.answer, task)

    def transform_message_into_qa(self, message, task):

        qas = self.retrieve_qa_pairs(task)

        formated_qas = self.format_qa(qas)

        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=load_prompt("interaction_new_context_from_beliefs")),
            HumanMessage(
                content=f"The task is {task}\n The example questions and answers are:\n{formated_qas}\n. You need to create a QA based on the following message:\n{message}",
            ),
        ])

        runnable = create_structured_output_runnable(QAItem, self.llm, prompt)
        output = runnable.invoke({"formated_qas": formated_qas, "message": message, "task": task})

        return output


    def extract_question_from_context(self, context):
        if "Question:" in context:
            question = context.split('Question: ')[1].split('Answer: ')[0]
            question = question.strip()



            return question

        if "Q:" in context:
            question = context.split('Q: ')[1].split('A: ')[0]
            question = question.strip()
            return question

        raise ValueError("No question found in context", context)
        # extcrat everything after Q: and before A:


    def update_context_based_on_beliefs(self, context, task, partner_help=None):
        # print("Context:", context)
        question = self.extract_question_from_context(context)

        if partner_help is not None:
            qas = self.retrieve_qa_pairs(task)
            qas += partner_help
        else:
            qas = self.retrieve_qa_pairs(task)
        if qas == []:
            return context
        formated_qas = self.format_qa(qas)

        messages = [
            SystemMessage(content=load_prompt("new_context_from_beliefs_and_prior_context")),
            HumanMessage(
                content=f"The questions and answers are:\n{formated_qas}\n. You need to find an answer to the following question:\n{question}",
            ),
        ]

        answer = self.llm(messages).content


        new_context = f"(Updated based on other beliefs)\nQuestion: {question}\nAnswer: {answer}"
        self.logger.info(f"FORMATTED QAS: {formated_qas} for new context: {new_context}")
        return new_context




    def add_new_qa_pair(self, question, answer, task):
        print("Adding new QA pair")
        # Generate a unique identifier for the QA pair
        qa_id = f"qa_{len(self.qa_pairs) + 1}"
        # self.qa_pairs[qa_id] = {"question": question, "answer": answer}

        # if question already exists, update the previously stored answer with the new one
        for qa_question in self.qa_pairs.keys():
            if qa_question == question:
                # TODO we should update question and answer
                return

        id = len(self.qa_pairs)


        self.qa_pairs[question] = {"question": question, "answer": answer, "task": task}

        self.vectordb.add_texts(
            texts=[question],
            ids=[question],
            metadatas=[{"question": question, "answer": answer, "task": task}],
        )

        U.dump_json(self.qa_pairs, f"{self.ckpt_dir}/qa/qa_pairs.json")
        self.vectordb.persist()
        self.logger.info(f"Added new QA pair: {question} -> {answer}")

        print(self.qa_pairs)



    def generate_qa_from_beliefs(self, beliefs: str) -> List[QAItem]:

        # parser = PydanticOutputParser(pydantic_object=QAList)

        # prompt = PromptTemplate(
        #     template="""Answer the user instruction.
        #     {format_instructions}
        #     System instructions:
        #     {system}
        #     User instructions:
        #     {query}
        #     """,
        #     input_variables=["system", "query"],
        #     partial_variables={"format_instructions": parser.get_format_instructions()},
        # )

        # chain = prompt | self.llm | parser

        # output = chain.invoke(
        #     {"query": "the beliefs are: {beliefs}",
        #      "system" : load_prompt("interaction_qa_from_beliefs"),
        #      })

        # print("output", output)

        qas = []

        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    load_prompt("interaction_qa_from_beliefs"),
                ),
                (
                    "human",
                        f"the beliefs are: {beliefs}",
                )
            ]
        )


        runnable = create_structured_output_runnable(QAList, self.llm, prompt)
        output = runnable.invoke({"input": beliefs})
        qas += output.qa_list

        return qas

    def remove_all_beliefs(self, task: str):
        """Remove all beliefs associated with a specific task from the database."""
        self.qa_pairs = {k: v for k, v in self.qa_pairs.items() if v["task"] != task}
        self.vectordb._collection.delete(where={"task": task})
        U.dump_json(self.qa_pairs, f"{self.ckpt_dir}/qa/qa_pairs.json")
        self.vectordb.persist()
        self.logger.info(f"All beliefs associated with task '{task}' have been removed from the database.")

    def retrieve_qa_pairs(self, query, all_pairs=False):
        if not all_pairs:
            k = min(self.vectordb._collection.count(), self.retrieval_top_k)
        else:
            k = self.vectordb._collection.count()
        if k == 0:
            return []
            self.logger.info(f"Retrieving top {k} similar QA pairs, for query {query}")
        docs_and_scores = self.vectordb.similarity_search_with_score(query, k=k)
        similar_qa_pairs = [
            {"question": doc.metadata["question"], "answer": doc.metadata["answer"]}
            for doc, _ in docs_and_scores
        ]

        # for i, qa in enumerate(similar_qa_pairs, 1):
        #     print(f"\033[33mQA Pair {i}:\033[0m Q: {qa['question']}\nA: {qa['answer']}\n")

        return similar_qa_pairs

    def print_database(self):

        self.logger.info("QA Database:")
        self.logger.info(f"dir: {self.ckpt_dir}")
        self.logger.info(f"Number of QA pairs: {len(self.qa_pairs)}")
        for i, (qa_id, qa) in enumerate(self.qa_pairs.items(), 1):
            self.logger.info(f"QA Pair {i}: Q: {qa['question']}\nA: {qa['answer']}\n")
