import itertools
import numpy as np
import re
import os
import torch
from tqdm import tqdm
from flashrag.utils import get_retriever, get_generator
from flashrag.prompt import PromptTemplate
from flashrag.pipeline import BasicPipeline

# OpenAI API key, read from the environment ("" when OPENAI_API_KEY is unset).
# It reaches the generator through default_config["openai_setting"];
# presumably only relevant when "framework" is "openai".
api_key = os.environ.get("OPENAI_API_KEY", "")

# Default configuration handed to flashrag's get_generator/get_retriever and to
# PromptTemplate when WikiRAG is constructed without an explicit config.
# NOTE(review): the absolute /home/ckddls1321/... paths are machine-specific;
# override corpus_path, index_path, generator_model_path and the checkpoint
# entries in model2path when running on another host.
default_config = {
    "data_dir": "dataset/",
    # Active retriever: E5 embeddings over the wiki18 100-word corpus.
    "corpus_path": "/home/ckddls1321/.cache/indexes/wiki18_100w.jsonl",
    "index_path": "/home/ckddls1321/.cache/indexes/wiki18_100w_e5_flat.index",
    "retrieval_method": "e5",
    "retrieval_model_path": "intfloat/e5-base-v2",
    # Alternative retriever setups (disabled). Each group of commented keys
    # below is a drop-in replacement for the four retrieval keys above.
    # "corpus_path": "indexes/wiki18_100w.jsonl",
    # "index_path": "indexes/wiki18_100w_drama_flat.index",
    # "retrieval_method": "drama",
    # "retrieval_model_path": "facebook/drama-large",
    # "corpus_path": "indexes/wiki18_100w.jsonl",
    # "index_path": "indexes/wiki18_100w_gte_modernbert_flat.index",
    # "retrieval_method": "gte",
    # "retrieval_model_path": "Alibaba-NLP/gte-modernbert-base",
    # "index_path": "indexes/wiki18_fulldoc_bgem3_index.index",
    # "corpus_path": "indexes/wiki18_fulldoc.jsonl",
    # "retrieval_method": "bge",
    # "retrieval_model_path": "BAAI/bge-m3",
    # "index_path": "indexes/wiki22_128w_drama_flat.index",
    # "corpus_path": "indexes/wiki22_128w.jsonl",
    # "retrieval_method": "drama",
    # "retrieval_model_path": "facebook/drama-large",
    # "index_path": "indexes/wiki22_128w_gte_modernbert_flat.index",
    # "corpus_path": "indexes/wiki22_128w.jsonl",
    # "retrieval_method": "gte",
    # "retrieval_model_path": "Alibaba-NLP/gte-modernbert-base",
    # Model name -> local path / hub id lookup table.
    "model2path": {
        "e5": "intfloat/e5-base-v2",
        "bge": "BAAI/bge-m3",
        "drama": "facebook/drama-large",
        "llama3-8B-instruct": "/home/ckddls1321/.cache/checkpoints/Llama-3.1-8B-Instruct",
        "llama3-3B-instruct": "/home/ckddls1321/.cache/checkpoints/Llama-3.2-3B-Instruct",
        "qwen2.5-3B-instruct": "/home/ckddls1321/.cache/checkpoints/Qwen2.5-3B-Instruct",
        "qwen2.5-7B-instruct": "/home/ckddls1321/.cache/checkpoints/Qwen2.5-7B-Instruct",
    },
    "openai_setting": {"api_key": api_key},
    "gpu_id": "0,1",
    "gpu_num": 2,
    "device": "cuda",
    # Alternative generator setups (disabled); the active one is the "hf"
    # Qwen2.5-7B group right after these three commented entries.
    # "framework": "fschat",
    # "generator_model": "llama3-8B-instruct",
    # "generator_model_path": "models/Llama-3.1-8B-Instruct",
    # "framework": "hf",
    # "generator_model": "llama3-3B-instruct",
    # "generator_model_path": "models/Llama-3.2-3B-Instruct",
    "framework": "hf",
    "generator_model": "qwen2.5-7B-instruct",
    "generator_model_path": "/home/ckddls1321/.cache/checkpoints/Qwen2.5-7B-Instruct",
    # "framework": "vllm",
    # "generator_model": "llama3-3B-instruct",
    # "generator_model_path": "meta-llama/Llama-3.2-3B-Instruct",
    # "framework": "openai",
    # "generator_model": "gpt-4o-mini",
    "generator_batch_size": 1,
    # NOTE(review): max_new_tokens (16K) is larger than generator_max_input_len
    # (4K); confirm this generation budget is intended.
    "generation_params": {
        "max_new_tokens": 16 * 1024,
    },
    "generator_max_input_len": 4 * 1024,
    "metrics": ["em", "f1", "sub_em"],
    "use_multi_retriever": False,
    # "instruction": "Query: ",
    "instruction": "",
    # Presumably the number of passages retriever.search returns per query;
    # that count also fills {N} in the candidate-generation prompt.
    "retrieval_topk": 15,
    "retrieval_batch_size": 1,
    "faiss_gpu": False,
    "use_sentence_transformer": True,
    "use_retrieval_cache": "",
    "use_reranker": False,
    "retrieval_pooling_method": "",
    "retrieval_use_fp16": True,
    # NOTE(review): False is used where a length/path is presumably expected;
    # confirm flashrag treats False as "unset" for these two keys.
    "retrieval_query_max_length": False,
    "retrieval_cache_path": False,
    "save_retrieval_cache": False,
    "save_intermediate_data": False,
}


class WikiRAG:
    """Retrieval-augmented helper that turns a query into candidate knowledge.

    ``retrieve_query`` runs the main flow:

    1. Retrieve passages with the configured retriever.
    2. Ask the generator for lettered knowledge candidates, e.g. ``(a) ... (b) ...``.
    3. Return those candidates as a numbered text block.

    The class also provides helpers to summarize, validate and pairwise-rank
    candidates.
    """

    def __init__(
        self,
        config=None,
        prompt_template=None,
        retriever=None,
        generator=None,
    ):
        """Build the pipeline, creating a generator/retriever when not supplied.

        Args:
            config: Configuration mapping for flashrag's ``get_generator`` /
                ``get_retriever`` and ``PromptTemplate``. ``None`` (the
                default) selects the module-level ``default_config``. The
                lookup happens at call time rather than at definition time.
            prompt_template: Accepted for interface compatibility; unused.
            retriever: Pre-built retriever; created via ``get_retriever``
                when ``None``.
            generator: Pre-built generator; created via ``get_generator``
                when ``None``.
        """
        if config is None:
            config = default_config
        self.config = config
        if generator is None:
            generator = get_generator(config)
        if retriever is None:
            retriever = get_retriever(config)

        self.generator = generator
        self.retriever = retriever
        self.load_prompts()

    def load_prompts(self):
        """Create the four ``PromptTemplate`` objects used by the pipeline.

        Adjacent string literals below are concatenated with no separator.
        Every continued sentence therefore ends its literal with a space,
        so words do not fuse across line breaks.
        """
        P_CAN_INSTRUCT = (
            "Below are {N} passages related to the question at the end. After reading the passages, provide three candidates knowledge with confidence score which is relevant to answer question. \n"
            "Each answer should be in the following format: \n"
            "(a) Knowledge 1, (b) Knowledge 2, (c) Knowledge 3\n"
            "For the factual knowledge, you can provide knowledge related to entity of question. \n"
            "For the complex reasoning, you can provide solving procedure and related concept, principles to address question.\n"
            "Each knowledge candidate should be clear and concise, providing relevant context that is needed to solve the question.\n\n"
            "{reference}"
            "Question: {question}\n"
            "Answer:"
        )

        P_SUM_INSTRUCT = (
            "Reference:\n{reference}\n"
            "Your job is to act as a professional writer. You need to write a "
            "good-quality passage that can support the given prediction about the "
            "question only based on the information in the provided supporting passages.\n"
            "Now, let's start. After you write, please write [DONE] to indicate you "
            "are done. Do not write a prefix (e.g., 'Response:') while writing a passage.\n"
            "Question: {question}\n"
            "Prediction: {pred}\n"
            "Passage:"
        )

        P_VAL_INSTRUCT = (
            "Question: {question}\n"
            "Prediction: {pred}\n"
            "Passage: {summary}\n"
            "Does the passage correctly support the prediction? Choices: [True,False].\n"
            "Answer:"
        )

        P_RANK_INSTRUCT = (
            "Question: Given the following passages, determine which one provides a "
            "more informative answer to the subsequent question.\n"
            "Passage 1: {summary1}\n"
            "Passage 2: {summary2}\n"
            "Target Question: {question}\n"
            "Your Task:\n"
            "Identify which passage (Passage 1 or Passage 2) is more relevant and "
            "informative to answer the question at hand. Choices: [Passage 1,Passage 2].\n"
            "Answer:"
        )

        self.P_CAN_TEMPLATE = PromptTemplate(self.config, "", P_CAN_INSTRUCT)
        self.P_SUM_TEMPLATE = PromptTemplate(self.config, "", P_SUM_INSTRUCT)
        self.P_VAL_TEMPLATE = PromptTemplate(self.config, "", P_VAL_INSTRUCT)
        self.P_RANK_TEMPLATE = PromptTemplate(self.config, "", P_RANK_INSTRUCT)

    def compute_query(self, query):
        """Normalize a query: lowercase it and strip surrounding whitespace.

        Subclasses may override this to customize query preprocessing.
        """
        return query.lower().strip()

    def format_ref(self, titles, texts):
        """Format parallel ``titles``/``texts`` sequences as numbered passages.

        The output layout is identical to ``format_passages``.
        """
        return self.format_passages(
            {"title": title, "text": text} for title, text in zip(titles, texts)
        )

    @staticmethod
    def parse_candidates(model_response):
        """Extract lettered candidates such as ``(a) ... (b) ...`` from a response.

        Each candidate is the text after a ``(label)`` marker, up to the next
        parenthesis or newline. Surrounding whitespace and separator commas
        are stripped from both ends of each candidate. Commas inside a
        candidate are kept, and empty candidates are dropped.

        >>> WikiRAG.parse_candidates("(a) Kasaya, a robe, (b) Tokyo")
        ['Kasaya, a robe', 'Tokyo']
        """
        # Raw string: "\(" in a normal literal is an invalid escape sequence.
        matches = re.findall(r"\((\w+)\)\s*([^()]+)", model_response.strip())
        candidates = []
        for _label, body in matches:
            text = body.split("\n")[0].strip(" ,\t\r\f\v")
            if text:
                candidates.append(text)
        return candidates

    @staticmethod
    def parse_validation(model_response):
        """Return 1 if the response answers ``True``, otherwise 0.

        Whichever of "true"/"false" appears first decides the result, so
        "False, it is not true" yields 0. A response mentioning neither
        yields 0.
        """
        text = model_response.strip().lower()
        pos_true = text.find("true")
        if pos_true == -1:
            return 0
        pos_false = text.find("false")
        return 0 if 0 <= pos_false < pos_true else 1

    @staticmethod
    def parse_ranking(model_response):
        """Score a pairwise-ranking response from Passage 1's point of view.

        Returns:
            1 if "passage 1" is mentioned before "passage 2".
            0 if "passage 2" comes first.
            0.5 if neither is mentioned.
        Using the first mention keeps answers such as "Passage 2 is better
        than Passage 1" from counting as a Passage 1 win.
        """
        text = model_response.strip().lower()
        pos1 = text.find("passage 1")
        pos2 = text.find("passage 2")
        if pos1 == -1 and pos2 == -1:
            return 0.5
        if pos2 == -1 or (pos1 != -1 and pos1 < pos2):
            return 1
        return 0

    def format_passages(self, passages):
        """Format passage dicts (``title``/``text`` keys) as numbered blocks.

        Each passage becomes ``Passage #i Title: ...`` and
        ``Passage #i Text: ...`` lines followed by a blank line. Numbering
        starts at 1.
        """
        return "".join(
            f"Passage #{idx} Title: {passage['title']}\n"
            f"Passage #{idx} Text: {passage['text']}\n\n"
            for idx, passage in enumerate(passages, 1)
        )

    def retrieve_and_format(self, query):
        """Retrieve documents for ``query`` and format them for prompting.

        Documents lacking ``title`` or ``text`` are filled in place from
        ``contents``: the first line becomes the title and the rest becomes
        the text.

        Returns:
            ``(retrieval_results, formatted_ref, scores)``.
        """
        retrieval_results, scores = self.retriever.search(query, return_score=True)
        for doc_item in retrieval_results:
            if "title" not in doc_item or "text" not in doc_item:
                title, _, text = doc_item["contents"].partition("\n")
                doc_item["title"] = title
                doc_item["text"] = text

        formatted_ref = self.format_passages(retrieval_results)
        return retrieval_results, formatted_ref, scores

    def format_candidates(self, query, formatted_ref, retrieval_results):
        """Generate and parse knowledge candidates for ``query``.

        ``retrieval_results`` is only used for its length, which fills the
        ``{N}`` slot of the candidate prompt.
        """
        input_prompt = self.P_CAN_TEMPLATE.get_string(
            N=len(retrieval_results),
            reference=formatted_ref,
            question=query,
        )
        output = self.generator.generate([input_prompt])[0]
        return self.parse_candidates(output)

    def summarize_candidates(self, query, formatted_ref, candidates):
        """Write one supporting passage per candidate.

        The summarization prompt tells the model to end with ``[DONE]``.
        Everything from that marker on is cut off, and whitespace is
        stripped. Returns ``[]`` when there are no candidates.
        """
        if not candidates:
            return []

        input_prompts = [
            self.P_SUM_TEMPLATE.get_string(
                question=query,
                pred=cand,
                reference=formatted_ref,
            )
            for cand in candidates
        ]
        summaries = self.generator.generate(input_prompts)
        return [summary.split("[DONE]", 1)[0].strip() for summary in summaries]

    def validate_summary(self, query, candidates, summaries):
        """Ask whether each summary supports its candidate; return 1/0 flags."""
        input_prompts = [
            self.P_VAL_TEMPLATE.get_string(question=query, pred=cand, summary=summary)
            for cand, summary in zip(candidates, summaries)
        ]
        val_results = self.generator.generate(input_prompts)
        return [self.parse_validation(res) for res in val_results]

    def rerank_summary(self, query, summaries):
        """Score summaries by pairwise comparison; return one float per summary.

        Every ordered pair ``(i, j)`` with ``i != j`` is judged once. Both
        orderings are asked, which offsets position bias. ``parse_ranking``
        stores the result at ``[i, j]``, so summary ``i``'s score is the sum
        of its wins while shown as Passage 1. Fewer than two summaries need
        no comparison, so each gets ``0.0``.
        """
        summary_num = len(summaries)
        if summary_num < 2:
            return [0.0] * summary_num

        score_matrix = np.zeros((summary_num, summary_num))
        iter_idxs = list(itertools.permutations(range(summary_num), 2))
        input_prompts = [
            self.P_RANK_TEMPLATE.get_string(
                question=query,
                summary1=summaries[first],
                summary2=summaries[second],
            )
            for first, second in iter_idxs
        ]
        ranking_output = self.generator.generate(input_prompts)
        for (first, second), res in zip(iter_idxs, ranking_output):
            score_matrix[first, second] = self.parse_ranking(res)
        # No squeeze(): it would collapse a length-1 result into a bare float.
        return score_matrix.sum(axis=1).tolist()

    def retrieve_query(self, query, options=None):
        """Retrieve passages, generate candidates and return them as text.

        Args:
            query: The question to answer.
            options: Accepted for interface compatibility; unused.

        Returns:
            ``"Passage 1: <cand>\\nPassage 2: <cand>\\n..."``, or ``""`` when
            no candidate could be parsed.
        """
        retrieval_results, formatted_ref, _scores = self.retrieve_and_format(query)
        candidates = self.format_candidates(query, formatted_ref, retrieval_results)
        if not candidates:
            print("No candidates has been searched")
            return ""
        # NOTE: summarize_candidates / validate_summary / rerank_summary could be
        # combined here to pick a single best candidate; that step is disabled.
        return "".join(
            f"Passage {i}: {cand}\n" for i, cand in enumerate(candidates, 1)
        )


if __name__ == "__main__":
    # Smoke test: run each pipeline stage on a sample image-description query
    # using the default configuration.
    image_description = "young Buddhist monk standing outdoors on rocky terrain"
    query = "What is the outfit this man is wearing called?"
    wiki_rag = WikiRAG()

    print("Testing Retriever...")
    # retrieve_and_format returns (documents, formatted_reference, scores).
    # Join description and question with a space so the words don't fuse.
    retrieval_results, formatted_ref, _scores = wiki_rag.retrieve_and_format(
        f"{image_description} {query}"
    )
    print("Retrieved Documents:")
    for doc in retrieval_results:
        print(f"Title: {doc['title']}")
        print(f"Text: {doc['text'][:300]}...\n")  # Print only first 300 characters
    print("=" * 50)

    print("Generating Candidates...")
    candidates = wiki_rag.format_candidates(query, formatted_ref, retrieval_results)
    print("Generated Candidates:")
    print(candidates)
    print("=" * 50)

    print("Summarizing Candidates...")
    summaries = wiki_rag.summarize_candidates(query, formatted_ref, candidates)
    print("Summaries for Each Candidate:")
    for cand, summary in zip(candidates, summaries):
        print(f"Candidate: {cand}\nSummary: {summary}\n")
    print("=" * 50)

    print("Validating Summaries...")
    val_scores = wiki_rag.validate_summary(query, candidates, summaries)
    print("Validation Scores:")
    for cand, score in zip(candidates, val_scores):
        print(f"Candidate: {cand}, Valid: {bool(score)}")
    print("=" * 50)

    # Optional (disabled): pairwise reranking of the summaries.
    # print("Reranking Summaries...")
    # ranking_scores = wiki_rag.rerank_summary(query, summaries)
    # print("Ranking Scores:")
    # for cand, score in zip(candidates, ranking_scores):
    #     print(f"Candidate: {cand}, Rank Score: {score}")
    # print("=" * 50)
