import logging
import os
from typing import Any, List
import json

from operation_variants import BertScoreOperation_Variant, SelfCheckGPTOperation_Variant, CheckEmbedOperation_Variant
from CheckEmbed import language_models
from CheckEmbed import embedding_models
from CheckEmbed.plotters import BertPlot
from CheckEmbed.plotters import CheckEmbedPlot
from CheckEmbed.plotters import SelfCheckGPTPlot
from CheckEmbed.plotters import RawEmbeddingHeatPlot
from CheckEmbed.parser import Parser
from CheckEmbed.scheduler import Scheduler, StartingPoint

class CustomParser(Parser):
    """
    The CustomParser class handles the dataset parsing.

    Inherits from the Parser class and implements its abstract methods.
    """

    def __init__(self, dataset_path: str, prompt_scheme_path: str, error_number: int, final_responses_path: str) -> None:
        """
        Initialize the parser.

        :param dataset_path: The path to the dataset.
        :type dataset_path: str
        :param prompt_scheme_path: The path to the prompt scheme file.
        :type prompt_scheme_path: str
        :param error_number: Number of errors that the LLM is asked in the prompts to generate.
        :type error_number: int
        :param final_responses_path: The path to the hallucinations.json file.
        :type final_responses_path: str
        """
        super().__init__(dataset_path)
        self.prompt_scheme_path = prompt_scheme_path
        self.error_number = error_number
        self.final_responses_path = final_responses_path

    def prompt_generation(self, custom_inputs: Any = None) -> List[str]:
        """
        Parse the dataset and generate the prompts for the model.

        :param custom_inputs: The custom inputs to the parser. Defaults to None.
        :type custom_inputs: Any
        :return: List of prompts.
        :rtype: List[str]
        """
        # Getting the input data from the dataset
        input_data = []
        with open(self.dataset_path) as f:
            json_data = json.load(f)

        data_array = json_data['data']
        for data in data_array:
            input_data.append(data['chunk_txt'])

        # Getting the prompt scheme
        with open(self.prompt_scheme_path) as f:
            prompt_complete = f.read()

        # Use input data as context inside the prompts
        prompts = []
        for data in input_data:
            prompt_copy = prompt_complete
            prompts.append(prompt_copy.replace("[###REPLACE WITH CONTEXT###]", data).replace("### NUMBER ###", str(self.error_number)))

        return prompts

    def ground_truth_extraction(self, custom_inputs: Any = None) -> List[str]:
        """
        Parse the dataset and extract the ground truth.

        :param custom_inputs: The custom inputs to the parser. Defaults to None.
        :type custom_inputs: Any
        :return: List of ground truths.
        :rtype: List[str]
        """
        pass

    def answer_parser(self, responses: List[List[str]], custom_inputs: Any = None) -> List[List[str]]:
        """
        Parse the responses from the model.

        :param responses: The responses from the model.
        :type responses: List[List[str]]
        :param custom_inputs: The custom inputs to the parser. Defaults to None.
        :type custom_inputs: Any
        :return: The parsed responses. The output is going to be saved in a .json file using the following structure:
        {
            "data": [
                {
                    "index": "1",
                    "hallucination": "parsed_response",
                },
                {
                    "index": "2",
                    "hallucination": "parsed_response",
                },
                ...
            ]
        }
        :rtype: List[List[str]]
        """
        new_responses = []
        hallucinations = []
        for response in responses:
            new_response = []
            hallucination = []
            for res in response:
                index = res.find("### SUMMARY ###")
                new_response.append(res[index:])
                hallucination.append(res[:index])
            new_responses.append(new_response)
            hallucinations.append(hallucination)
        
        with open(self.final_responses_path + "/hallucinations.json", "w") as f:
            hallucinations_json = [{"index": i, "hallucination": hallucinations[i]} for i in range(len(hallucinations))]
            json.dump({"data": hallucinations_json}, f, indent=4)
        return new_responses

def start(current_dir: str, ground_truth_gen: bool = False, error_number: int = 0, start: int = StartingPoint.PROMPT) -> None:
    """
    Execute the incremental forced hallucination use case with a specific error number.

    :param current_dir: Directory path from the the script is called.
    :type current_dir: str
    :param ground_truth_gen: Generate ground truth. Defaults to False.
    :type ground_truth_gen: bool
    :param error_number: Number of errors that the LLM is asked in the prompts to generate. Defaults to 0.
    :type error_number: int
    :param start: Starting point indicator. Defaults to StartingPoint.PROMPT.
    :type start: int
    """
    config_path = os.path.join(
        current_dir,
        "../../../../CheckEmbed/config.json",
    )

    # Initialize the parser and the embedder
    customParser = CustomParser(
        dataset_path = os.path.join(current_dir, "../dataset/legal_definitions.json"),
        prompt_scheme_path = os.path.join(current_dir, "../prompt_scheme.txt" if not ground_truth_gen else "../prompt_scheme_ground_truth.txt"),
        error_number = error_number,
        final_responses_path = current_dir,
    )

    # Initialize the language models
    gpt3 = language_models.ChatGPT(
        config_path,
        model_name = "chatgpt",
        cache = True,
    )

    gpt4_o = language_models.ChatGPT(
        config_path,
        model_name = "chatgpt4-o",
        cache = True,
        )

    embedd_large = embedding_models.EmbeddingGPT(
        config_path,
        model_name = "gpt-embedding-large",
        cache = False,
    )

    sfrEmbeddingMistral = embedding_models.SFREmbeddingMistral(
        config_path,
        model_name = "Salesforce/SFR-Embedding-Mistral",
        cache = False,
        batch_size=8,
    )

    e5mistral7b = embedding_models.E5Mistral7b(
        config_path,
        model_name = "intfloat/e5-mistral-7b-instruct",
        cache = False,
        batch_size=8,
    )

    gteQwen157bInstruct = embedding_models.GteQwenInstruct(
        config_path=config_path,
        model_name= "Alibaba-NLP/gte-Qwen1.5-7B-instruct",
        cache = False,
        access_token = "", # Add your access token here
        batch_size=1,
    )


    # Initialize BERTScore, SelfCheckGPT and CheckEmbedOperation operations
    bertOperation = None if ground_truth_gen else BertScoreOperation_Variant(
            os.path.join(current_dir, "BertScore"),
            os.path.join(current_dir, "../ground_truth"),
            current_dir,
        )

    selfCheckGPTOperation = None if ground_truth_gen else SelfCheckGPTOperation_Variant(
            os.path.join(current_dir, "SelfCheckGPT"),
            os.path.join(current_dir, "../ground_truth"),
            current_dir,
        )
    
    checkEmbedOperation = None if ground_truth_gen else CheckEmbedOperation_Variant(
            os.path.join(current_dir, "CheckEmbed"),
            os.path.join(current_dir, "../ground_truth/embeddings"),
            os.path.join(current_dir, "embeddings"),
        )

    # Initialize the plot operations
    bertPlot = BertPlot(
        os.path.join(current_dir, "plots", "BertScore"),
        os.path.join(current_dir, "BertScore"),
    )

    selfCheckGPTPlot = SelfCheckGPTPlot(
        os.path.join(current_dir, "plots", "SelfCheckGPT"),
        os.path.join(current_dir, "SelfCheckGPT"),
    )

    rawEmbeddingHeatPlot = RawEmbeddingHeatPlot(
        os.path.join(current_dir, "plots", "CheckEmbed"),
        os.path.join(current_dir, "embeddings"),
    )

    checkEmbedPlot = CheckEmbedPlot(
        os.path.join(current_dir, "plots", "CheckEmbed"),
        os.path.join(current_dir, "CheckEmbed"),
    )

    # Initialize the scheduler
    scheduler = Scheduler(
        current_dir,
        logging_level = logging.DEBUG,
        budget = 10,
        parser = customParser,
        lm = [gpt4_o, gpt3],
        embedding_lm = [embedd_large, sfrEmbeddingMistral, e5mistral7b, gteQwen157bInstruct],
        operations = [bertPlot, selfCheckGPTPlot, rawEmbeddingHeatPlot, checkEmbedPlot],
        bertScoreOperation = bertOperation,
        selfCheckGPTOperation = selfCheckGPTOperation,
        checkEmbedOperation = checkEmbedOperation,
    )

    # The order of lm_names and embedding_lm_names should be the same 
    # as the order of the language models and embedding language models respectively.
    scheduler.run(
        startingPoint = StartingPoint.OPERATIONS,
        bertScore = False,
        selfCheckGPT = False,
        ground_truth = False,
        spacy_separator = True,
        num_samples = 10,
        lm_names = ["gpt4-o", "gpt"],
        embedding_lm_names = ["gpt-embedding-large", "sfr-embedding-mistral", "e5-mistral-7b", "gte-qwen-1.5-7b-instruct"],
        bertScore_model = "microsoft/deberta-xlarge-mnli",
        batch_size = 64, # it may be necessary to reduce the batch size if the model is too large
        device = "cuda" # or "cpu" "mps" ...
    )

if __name__ == "__main__":
    current_dir = os.path.dirname(os.path.abspath(__file__)) + "/ground_truth"
    if not os.path.exists(current_dir):
        os.makedirs(current_dir)
    start(current_dir, ground_truth_gen=True, error_number=0)

    current_dir = os.path.dirname(os.path.abspath(__file__)) + "/error_1"
    if not os.path.exists(current_dir):
        os.makedirs(current_dir)
    start(current_dir, ground_truth_gen=False, error_number=1)

    current_dir = os.path.dirname(os.path.abspath(__file__)) + "/error_2"
    if not os.path.exists(current_dir):
        os.makedirs(current_dir)
    start(current_dir, ground_truth_gen=False, error_number=2)

    current_dir = os.path.dirname(os.path.abspath(__file__)) + "/error_3"
    if not os.path.exists(current_dir):
       os.makedirs(current_dir)
    start(current_dir, ground_truth_gen=False, error_number=3)

    current_dir = os.path.dirname(os.path.abspath(__file__)) + "/error_4"
    if not os.path.exists(current_dir):
       os.makedirs(current_dir)
    start(current_dir, ground_truth_gen=False, error_number=4)

    current_dir = os.path.dirname(os.path.abspath(__file__)) + "/error_5"
    if not os.path.exists(current_dir):
        os.makedirs(current_dir)
    start(current_dir, ground_truth_gen=False, error_number=5)

    current_dir = os.path.dirname(os.path.abspath(__file__)) + "/error_6"
    if not os.path.exists(current_dir):
        os.makedirs(current_dir)
    start(current_dir, ground_truth_gen=False, error_number=6)

    current_dir = os.path.dirname(os.path.abspath(__file__)) + "/error_7"
    if not os.path.exists(current_dir):
        os.makedirs(current_dir)
    start(current_dir, ground_truth_gen=False, error_number=7)

    current_dir = os.path.dirname(os.path.abspath(__file__)) + "/error_8"
    if not os.path.exists(current_dir):
        os.makedirs(current_dir)
    start(current_dir, ground_truth_gen=False, error_number=8)

    current_dir = os.path.dirname(os.path.abspath(__file__)) + "/error_9"
    if not os.path.exists(current_dir):
        os.makedirs(current_dir)
    start(current_dir, ground_truth_gen=False, error_number=9)

    current_dir = os.path.dirname(os.path.abspath(__file__)) + "/error_10"
    if not os.path.exists(current_dir):
        os.makedirs(current_dir)
    start(current_dir, ground_truth_gen=False, error_number=10)
