import json
import os
import re
from datetime import datetime

from tqdm import tqdm
from llama_index.core.llms import LLM

from qdrant_vectordb import QdrantVectorDB
from neo4j_graphdb import Neo4jGraphDB
from data_loader import DataLoader
from evaluator import Evaluator

class RAGRouter:
    """Route each question to a vector RAG (Qdrant) or a graph RAG (Neo4j).

    A question the LLM classifies as needing <= 2 reasoning hops is first
    answered from the Qdrant vector store. That answer is kept when its
    confidence reaches the threshold; otherwise the question falls through to
    the Neo4j graph store. Questions classified as multi-hop (>= 3 hops) go
    straight to Neo4j.
    """

    # Confidence recorded for every answer produced by the graph route. The
    # graph backend's own confidence value is not used; presumably this fixed
    # value lets downstream analysis tell the two routes apart -- TODO confirm.
    GRAPH_ROUTE_CONFIDENCE = 0.66

    # Accept "true"/"false" as the first word of the (lower-cased) reply,
    # tolerating leading/trailing punctuation or markup such as "True." or
    # "**False**" that LLMs often add despite being told not to.
    _LABEL_RE = re.compile(r"^\W*(true|false)\b")

    def __init__(
            self,
            llm,
            embed_model,
            similarity_top_k: int = 5,
            dataset_name: str = "2wiki",
            result_dir: str = "../results/",
    ):
        """Create the router and both retrieval backends.

        Args:
            llm: LlamaIndex-style LLM; used for hop classification and passed
                to both backends.
            embed_model: Embedding model passed to both backends.
            similarity_top_k: Number of neighbours the Qdrant store retrieves.
            dataset_name: Used as the Qdrant collection name.
            result_dir: Directory where `run_on_dataset` writes its JSON output.
        """
        self.llm = llm
        self.embed_model = embed_model
        self.similarity_top_k = similarity_top_k
        self.dataset_name = dataset_name
        self.result_dir = result_dir

        self.qdrant_db = QdrantVectorDB(
            llm=self.llm,
            embed_model=self.embed_model,
            collection_name=self.dataset_name,
            similarity_top_k=self.similarity_top_k,
        )
        self.neo4j_db = Neo4jGraphDB(
            llm_model=self.llm,
            embed_model=self.embed_model,
        )

    def is_multi_hop(
            self,
            query: str,
    ) -> bool:
        """
        Classify whether a question likely requires >=3 reasoning hops.

        Sends a few-shot prompt to `self.llm.complete(prompt)` and reads the
        reply's `.text` (falling back to `str(reply)`).

        Returns:
            True when the first word of the reply is "true" (case-insensitive,
            surrounding punctuation ignored). False when it is "false", when
            the reply cannot be parsed, or when the LLM call raises.
        """
        prompt = f"""You are a strict binary classifier of question complexity for QA routing.
        Return EXACTLY one token: True or False. No punctuation. No explanation.

        Definition:
        - "True" if the question likely needs >= 3 reasoning hops.
        - "False" if it can be answered with 1–2 hops.

        Examples:
        Q: Who wrote 'Pride and Prejudice'?
        A: False
        Q: Capital of Spain?
        A: False
        Q: What is the currency of the country whose capital is Cairo?
        A: False
        Q: What is the capital of the country where New York is located?
        A: False
        Q: Which city has a larger population, the capital of France or the capital of Italy?
        A: False
        Q: Which scientist discovered the element isolated by the spouse of the founder of X-ray crystallography?
        A: True
        Q: Find the museum that houses the painting created by the student of the artist who co-founded Cubism, and then report the city of that museum.
        A: True

        Q: {query}
        A:
        """

        try:
            resp = self.llm.complete(prompt)
            text = (getattr(resp, "text", None) or str(resp)).strip().lower()
        except Exception as e:
            # Best effort: a classifier failure routes the question down the
            # default (non-multi-hop) path instead of aborting the whole run.
            print(f"multi-hop classification error: {e}")
            return False
        match = self._LABEL_RE.match(text)
        return match is not None and match.group(1) == "true"

    @staticmethod
    def _to_confidence(value) -> float:
        """Parse a backend confidence value as float; 0.0 if missing or non-numeric."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def _safe_generate(db, question: str, llm) -> dict:
        """Ask `db` for a structured answer; on any error return a zero-confidence stub.

        The stub carries the exception text as its "answer", as before.
        Assumes the backend returns a dict with "answer"/"confidence" keys --
        TODO confirm against QdrantVectorDB / Neo4jGraphDB.
        """
        try:
            return db.generate_structured_response(
                query=question,
                llm=llm,
            )
        except Exception as e:
            print(e)
            return {"reasoning": "Exception !!!", "answer": str(e), "confidence": 0.0}

    @staticmethod
    def _save_results(result_list: list, file_path: str) -> None:
        """Write `result_list` as indented UTF-8 JSON, creating parent dirs; report errors."""
        try:
            directory = os.path.dirname(file_path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(file_path, 'w', encoding='utf-8') as file:
                json.dump(result_list, file, ensure_ascii=False, indent=4)
            print(f"data has been saved to {file_path}")
        except (OSError, TypeError, ValueError) as e:
            print(f"file save error: {e}")

    def run_on_dataset(
            self,
            dataset_name: str,
            sample_size: int,
            llm: "LLM",
            confidence_threshold: float = 0.8,
    ):
        """Answer every question of a dataset through the router and save the results.

        Args:
            dataset_name: Dataset to load via `DataLoader`; also the prefix of
                the output file `<result_dir>/<dataset_name>RAGRouter.json`.
            sample_size: Number of examples to load.
            llm: LLM passed to the backends' `generate_structured_response`.
            confidence_threshold: Minimum vector-route confidence needed to
                keep its answer instead of falling back to the graph route.

        Each saved record is the original QA item plus "rag_answer" and
        "confidence".
        """
        qa, _ = DataLoader(dataset_name=dataset_name, sample_size=sample_size).load()

        result_list = []
        for item in tqdm(qa):
            question = item['question']

            if not self.is_multi_hop(question):
                rag_answer = self._safe_generate(self.qdrant_db, question, llm)
                # A malformed confidence no longer crashes the whole run; it
                # simply counts as 0.0 and falls through to the graph route.
                confidence = self._to_confidence(rag_answer.get('confidence'))
                if confidence >= confidence_threshold:
                    result_list.append({
                        **item,
                        "rag_answer": rag_answer.get('answer', ""),
                        "confidence": confidence,
                    })
                    continue

            graphrag_answer = self._safe_generate(self.neo4j_db, question, llm)
            result_list.append({
                **item,
                "rag_answer": graphrag_answer.get('answer', ""),
                "confidence": self.GRAPH_ROUTE_CONFIDENCE,
            })

        file_path = os.path.join(self.result_dir, f"{dataset_name}RAGRouter.json")
        self._save_results(result_list, file_path)


def main(llm, embed_model, dataset_name: str="2wiki", sample_size: int=1000):
    """Index a dataset, answer its questions with RAGRouter, and score the answers.

    Steps:
      1. Load the context for `dataset_name`.
      2. Recreate the Qdrant collection: `delete_collection(confirm=True)`,
         then `add_documents(context)`.
      3. Add the same context to Neo4j via `add_documents_in_batches`.
      4. Run the router over the dataset.
      5. Score the file `<dataset_name>RAGRouter` by exact match and by LLM judge.

    Args:
        llm: LLM used for routing, answer generation, and judging.
        embed_model: Embedding model passed to both stores.
        dataset_name: Dataset to load; also the Qdrant collection name and the
            result-file prefix.
        sample_size: Number of examples to load.

    Returns:
        `(em_score, llmjudge_score)` as produced by `Evaluator`.
    """
    # Only the context is used here; run_on_dataset reloads the QA pairs
    # itself. Assumes DataLoader sampling is deterministic, so both loads see
    # the same examples -- TODO confirm.
    _, context = DataLoader(dataset_name=dataset_name, sample_size=sample_size).load()
    rag_router = RAGRouter(
        llm=llm,
        embed_model=embed_model,
        similarity_top_k=5,
        dataset_name=dataset_name,
    )

    rag_router.qdrant_db.delete_collection(confirm=True)
    rag_router.qdrant_db.add_documents(context)

    # NOTE(review): unlike Qdrant, the Neo4j graph is not cleared first, so
    # repeated runs may accumulate documents -- confirm this is intended.
    rag_router.neo4j_db.add_documents_in_batches(context)

    rag_router.run_on_dataset(dataset_name, sample_size=sample_size, llm=llm)

    # Both metrics must read the file this run just wrote. The LLM-judge call
    # previously hard-coded "2wikiRAGRouter" and scored the wrong file for
    # any other dataset.
    eval_filename = f"{dataset_name}RAGRouter"
    evaluator = Evaluator()
    em_score, _ = evaluator.evaluate_file_by_em(
        json_file_name=eval_filename,
        answer1_name="answer",
        answer2_name="rag_answer",
    )
    llmjudge_score, _ = evaluator.evaluate_file_by_llm_judge(
        json_file_name=eval_filename,
        question_name="question",
        ground_truth_name="answer",
        prediction_name="rag_answer",
        llm=llm,
    )
    return em_score, llmjudge_score


if __name__ == "__main__":
    from llm_factory import LLMFactory

    # ZhipuAI GLM-4 LLM. The key is taken from the ZHIPUAI_API_KEY environment
    # variable when set, so a real key never has to be written into this file.
    zhipuai_llm = LLMFactory.create_llm(
        provider="zhipuai",
        model="glm-4",
        api_key=os.environ.get("ZHIPUAI_API_KEY", "your key"),
        temperature=0.0,
        max_tokens=1024
    )

    # Alternative LLM (disabled): OpenAI gpt-4o.
    # from llama_index.llms.openai import OpenAI
    # openai_llm = OpenAI(
    #     model="gpt-4o",
    #     api_key=os.environ.get("OPENAI_API_KEY", "your key"),
    #     temperature=0.0,
    # )

    # Embedding model: bge-base-en-v1.5 (alternative: "BAAI/bge-small-en-v1.5").
    from llama_index.embeddings.fastembed import FastEmbedEmbedding
    bge_base = FastEmbedEmbedding(model_name="BAAI/bge-base-en-v1.5")

    # Start and end timestamps bracket the full index + evaluate run.
    print(datetime.now())
    em, llmjudge = main(llm=zhipuai_llm, embed_model=bge_base, dataset_name="2wiki", sample_size=1000)
    print(f"em_score, llmjudge_score = {em}, {llmjudge}")
    print(datetime.now())


