# Standard Imports
import json
import re
import sys
from datetime import datetime

# App Imports
from src.constants import Task, LLM
from src.subjects import VARIABLE_MAPPING, EVALUATION_MAPPING
from src.llm_manager import LLMManager


class TaskManager:
    """Route a task identifier to the manager class that executes it."""

    @staticmethod
    def process(llm, model_type, task, dataset):
        """Build the manager matching ``task`` and run its ``process`` method.

        Args:
            llm: LLM identifier handed through to the manager.
            model_type: Model-type key handed through to the manager.
            task: Compared against ``Task.Verbalization.value`` and then
                ``Task.Generation.value`` to choose the manager.
            dataset: Dataset identifier handed through to the manager.

        Returns:
            None. A ``task`` matching neither value results in no action.
        """
        manager_args = (llm, model_type, task, dataset)

        if task == Task.Verbalization.value:
            manager = VerbalizationManager(*manager_args)
            manager.process()
            return

        if task == Task.Generation.value:
            manager = GenerationManager(*manager_args)
            manager.process()


class VerbalizationManager:
    """Run a verbalization task: query the LLM, print the reply, and log it.

    The prompts and dataset parameters are obtained from the callable stored
    at ``VARIABLE_MAPPING[model_type][task]``, invoked with the dataset.
    """

    def __init__(self, llm, model_type, task, dataset):
        """Store the run configuration and resolve the prompts.

        Args:
            llm: LLM identifier; forwarded to ``LLMManager`` and logged.
            model_type: First-level key into ``VARIABLE_MAPPING``.
            task: Second-level key into ``VARIABLE_MAPPING``; also forwarded
                to ``LLMManager`` and logged.
            dataset: Passed to the variable factory and logged.
        """
        self.llm = llm
        self.model_type = model_type
        self.task = task
        self.dataset = dataset

        # Get Variables
        get_verbalization_variables = VARIABLE_MAPPING[self.model_type][self.task]
        variables = get_verbalization_variables(self.dataset)

        # Get Prompts
        self.system_prompt = variables.get("system_prompt")
        self.user_prompt = variables.get("user_prompt")
        self.dataset_parameters = variables.get("dataset_parameters")

    def process(self):
        """Call the LLM, print its response, and append it to the log file."""
        # Get LLM Response
        llm_response = self.process_llm()

        # Display LLM Response
        print(llm_response)

        # Log LLM Response
        self.log_response(llm_response)

    def process_llm(self):
        """Print the prompts, send them through ``LLMManager``, return the reply."""
        # Display Prompts
        self.display_prompts(self.system_prompt, self.user_prompt)

        # Call LLM Manager
        llm_manager = LLMManager(
            self.llm, self.task, self.system_prompt, self.user_prompt
        )
        return llm_manager.call_llm()

    def log_response(self, response):
        """Append a log entry for ``response`` to the JSON log file.

        The log file path and dataset label come from ``dataset_parameters``
        (keys ``"logfile"`` and ``"dataset_label"``). The user is asked
        interactively whether the entry should be flagged as important.

        Args:
            response: The LLM response to store; must be JSON-serializable.

        Raises:
            TypeError: If the entry cannot be serialized. The existing log
                file is left untouched in that case.
        """
        logfile_path = self.dataset_parameters.get("logfile")
        dataset_label = self.dataset_parameters.get("dataset_label")
        important = self.process_important()

        # First read the existing logs; a missing file means no logs yet, so
        # the response that was just obtained is not lost on a first run.
        try:
            with open(logfile_path, "r", encoding="utf-8") as logfile:
                logs = json.load(logfile)
        except FileNotFoundError:
            logs = []

        # Add the new log entry; its id is its position in the list.
        entry_id = len(logs)
        log = {
            "id": entry_id,
            "date": datetime.now().isoformat(),
            "llm": self.llm,
            "models": self.model_type,
            "task": self.task,
            "dataset": self.dataset,
            "dataset_label": dataset_label,
            "system_prompt": self.system_prompt,
            "query_prompt": self.user_prompt,
            "response": response,
            "important": important,
            "generation_logs": [],
        }
        logs.append(log)

        # Serialize BEFORE opening in "w" mode: opening truncates the file,
        # so a serialization error afterwards would destroy earlier logs.
        serialized_logs = json.dumps(logs, indent=2)
        with open(logfile_path, "w", encoding="utf-8") as logfile:
            logfile.write(serialized_logs)

    # Helper Functions
    def display_prompts(self, system_prompt, user_prompt):
        """Print the system and user prompts between separator banners."""
        print("------------System Prompt------------")
        print(system_prompt)
        print("-------------------------------------")
        print("-------------User Prompt-------------")
        print(user_prompt)
        print("-------------------------------------")

    def process_important(self):
        """Ask the user whether the entry is important until the answer is valid.

        Returns:
            bool: True for "yes"/"y"/"ye"; False for "no"/"n" or an empty
            answer. Any other answer re-prompts instead of returning None.
        """
        yes = {"yes", "y", "ye"}
        no = {"no", "n", ""}

        while True:
            choice = input("------------Mark as Important? y/N?------------")
            choice = choice.strip().lower()
            if choice in yes:
                return True
            if choice in no:
                return False
            print("Please respond with 'yes' or 'no'")


class GenerationManager:
    """Run a generation task in batches, evaluate the output, and log it.

    The prompt templates, dataset parameters and the verbalization this run
    belongs to are obtained from the callable stored at
    ``VARIABLE_MAPPING[model_type][task]``, invoked with the dataset.
    """

    def __init__(self, llm, model_type, task, dataset):
        """Store the run configuration and resolve the prompt templates.

        Args:
            llm: LLM identifier; forwarded to ``LLMManager`` and logged.
            model_type: First-level key into ``VARIABLE_MAPPING``; also the
                key into ``EVALUATION_MAPPING``.
            task: Second-level key into ``VARIABLE_MAPPING``; also forwarded
                to ``LLMManager`` and logged.
            dataset: Passed to the variable factory and logged.
        """
        self.llm = llm
        self.model_type = model_type
        self.task = task
        self.dataset = dataset

        # Get Variables
        get_generation_variables = VARIABLE_MAPPING[self.model_type][self.task]
        variables = get_generation_variables(self.dataset)

        # Get Prompts
        self.system_prompt = variables.get("system_prompt")
        self.user_prompt = variables.get("user_prompt")
        self.dataset_parameters = variables.get("dataset_parameters")
        self.verbalization_id = variables.get("verbalization_id")
        self.verbalization = variables.get("verbalization")

    def process(self):
        """Generate outputs, print and evaluate them, then log the run."""
        # Process LLM call
        responses, parsed_responses = self.process_llm()

        synthetic_output = self.flatten_parsed_response(parsed_responses)
        true_output = self.dataset_parameters.get("dataset_evaluation_sample")

        print("-------------------SYNTHETIC OUTPUT-----------------------")
        print(synthetic_output)
        print("----------------------------------------------------------")
        print("---------------------TRUE OUTPUT--------------------------")
        print(true_output)
        print("----------------------------------------------------------")

        # Evaluation - This should happen before we log the response
        evaluation = EVALUATION_MAPPING[self.model_type](synthetic_output, true_output)
        evaluation_scores = evaluation.get_evaluation_scores()
        print("--------------------EVALUATION SCORES---------------------")
        print(evaluation_scores)
        print("----------------------------------------------------------")

        # Log LLM response
        self.log_response(responses, synthetic_output, evaluation_scores)

    def process_llm(self, batch_size=16):
        """Send the dataset sample to the LLM in batches and parse each reply.

        For every batch the prompt templates are formatted with a copy of
        ``dataset_parameters`` whose ``"dataset_sample"`` is replaced by the
        batch. When ``self.llm`` equals ``LLM.Gemini.value``, a batch whose
        parsed reply does not match the batch is re-requested up to two times.

        Args:
            batch_size: Number of samples per LLM request (default 16).

        Returns:
            tuple: ``(all_responses, all_parsed_responses)`` - the last raw
            response of every batch, and the parsed replies of the batches
            whose reply could be parsed.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        dataset_sample = self.dataset_parameters.get("dataset_sample")

        # Split the dataset into batches
        batches = [
            dataset_sample[start : start + batch_size]
            for start in range(0, len(dataset_sample), batch_size)
        ]

        all_responses = []
        all_parsed_responses = []

        for batch_index, batch in enumerate(batches):
            # Format the prompts for the entire batch
            batch_params = self.dataset_parameters.copy()
            batch_params["dataset_sample"] = batch

            formatted_system_prompt = self.system_prompt.format(**batch_params)
            formatted_user_prompt = self.user_prompt.format(
                **batch_params, verbalization=self.verbalization
            )

            # Display Prompts
            self.display_prompts(formatted_system_prompt, formatted_user_prompt)

            # Call LLM Manager for the batch
            response, parsed_response = self._call_and_parse(
                formatted_system_prompt, formatted_user_prompt
            )

            # Retry mismatching replies (Gemini only)
            if self.llm == LLM.Gemini.value:
                attempts_left = 2
                while attempts_left > 0 and self.retry_if_not_match(parsed_response, batch):
                    print(f"Trying for {attempts_left}")
                    response, parsed_response = self._call_and_parse(
                        formatted_system_prompt, formatted_user_prompt
                    )
                    attempts_left -= 1

            print("------------------PARSED RESPONSE--------------------")
            print(parsed_response)
            print("-----------------------------------------------------")

            all_responses.append(response)

            if parsed_response is not None:
                all_parsed_responses.append(parsed_response)
                print(f"Successfully processed batch {batch_index + 1} with {len(batch)} samples.")
            else:
                print(f"Unable to parse JSON for batch {batch_index + 1}. Moving to the next batch.")

        print(f"Total batches processed: {len(batches)}. Total successfully parsed responses: {len(all_parsed_responses)}")
        return all_responses, all_parsed_responses

    def _call_and_parse(self, system_prompt, user_prompt):
        """Call the LLM once, print the raw reply, and return (raw, parsed)."""
        llm_manager = LLMManager(self.llm, self.task, system_prompt, user_prompt)
        response = llm_manager.call_llm()
        print("------------------LLM RESPONSE----------------------")
        print(response)
        print("----------------------------------------------------")
        return response, self.parse_llm_response(response)

    def log_response(self, responses, synthetic_output, evaluation_scores):
        """Append a generation log entry under its verbalization log entry.

        The log file is read, the entry at index ``self.verbalization_id`` is
        looked up, and the new entry is appended to its ``"generation_logs"``
        list (created when missing or null). The user is asked interactively
        whether the entry should be flagged as important.

        Args:
            responses: Raw LLM responses, one per batch.
            synthetic_output: Flattened parsed output of all batches.
            evaluation_scores: Scores produced by the evaluation step.

        Raises:
            TypeError: If the entry cannot be serialized to JSON. The
                existing log file is left untouched in that case.
        """
        logfile_path = self.dataset_parameters.get("logfile")
        dataset_label = self.dataset_parameters.get("dataset_label")
        important = self.process_important()

        # First, read the existing logs
        with open(logfile_path, "r", encoding="utf-8") as logfile:
            verb_logs = json.load(logfile)

        verb_log = verb_logs[self.verbalization_id]
        gen_logs = verb_log.get("generation_logs")
        if gen_logs is None:
            # Entry has no (or a null) generation list yet: start one.
            gen_logs = []
            verb_log["generation_logs"] = gen_logs

        # Add the new log entry; its id is its position in the list.
        entry_id = len(gen_logs)
        log = {
            "id": entry_id,
            "date": datetime.now().isoformat(),
            "llm": self.llm,
            "models": self.model_type,
            "task": self.task,
            "dataset": self.dataset,
            "dataset_label": dataset_label,
            "system_prompt": self.system_prompt,
            "query_prompt": self.user_prompt,
            "important": important,
            "responses": responses,
            "synthetic_output": synthetic_output,
            "evaluation_scores": evaluation_scores,
        }
        gen_logs.append(log)

        # Serialize BEFORE opening in "w" mode: opening truncates the file,
        # so a serialization error afterwards would destroy earlier logs.
        serialized_logs = json.dumps(verb_logs, indent=2)
        with open(logfile_path, "w", encoding="utf-8") as logfile:
            logfile.write(serialized_logs)

    def retry_if_not_match(self, parsed_response, batch_sample):
        """Tell whether a parsed reply fails to line up with its batch.

        Args:
            parsed_response: Parsed LLM reply, or None when parsing failed.
            batch_sample: The batch of samples that was sent to the LLM.

        Returns:
            bool: True when the reply is empty, is not a list, differs in
            length from the batch, contains an item without an ``"input"``
            key, or has an ``"input"`` differing from the sample at the same
            position; False when every item matches.
        """
        if not parsed_response or not isinstance(parsed_response, (list, tuple)):
            return True
        if len(parsed_response) != len(batch_sample):
            return True

        for generated, expected in zip(parsed_response, batch_sample):
            # A malformed item is a mismatch, not a reason to crash.
            if not isinstance(generated, dict) or "input" not in generated:
                return True
            if generated["input"] != expected["input"]:
                return True
        return False

    # Helper Functions
    def flatten_parsed_response(self, parsed_responses):
        """Concatenate the per-batch parsed replies into one flat list."""
        return [item for response in parsed_responses for item in response]

    def parse_llm_response(self, llm_response):
        """Extract and decode the first ```json fenced block of a reply.

        The block is decoded as-is first. Only when that fails is it retried
        with single quotes replaced by double quotes, so valid JSON that
        contains apostrophes is no longer corrupted.

        Args:
            llm_response: Raw text returned by the LLM.

        Returns:
            The decoded JSON value, or None when the reply is not a string,
            contains no fenced JSON block, or the block cannot be decoded.
        """
        if not isinstance(llm_response, str):
            print("No JSON found in the text")
            return None

        # Find the JSON part using regular expressions
        json_match = re.search(r"```json\s*([\s\S]*?)\s*```", llm_response)
        if not json_match:
            print("No JSON found in the text")
            return None

        json_str = json_match.group(1)
        # Strict text first, quote-normalized text as a fallback.
        candidates = (json_str, json_str.replace("'", '"'))
        last_error = None
        for candidate in candidates:
            try:
                return json.loads(candidate)
            except json.JSONDecodeError as error:
                last_error = error

        print(f"Error decoding JSON: {last_error}")
        return None

    def display_prompts(self, system_prompt, user_prompt):
        """Print the system and user prompts between separator banners."""
        print("------------System Prompt------------")
        print(system_prompt)
        print("-------------------------------------")
        print("-------------User Prompt-------------")
        print(user_prompt)
        print("-------------------------------------")

    def process_important(self):
        """Ask the user whether the entry is important until the answer is valid.

        Returns:
            bool: True for "yes"/"y"/"ye"; False for "no"/"n" or an empty
            answer. Any other answer re-prompts instead of returning None.
        """
        yes = {"yes", "y", "ye"}
        no = {"no", "n", ""}

        while True:
            choice = input("------------Mark as Important? y/N?------------")
            choice = choice.strip().lower()
            if choice in yes:
                return True
            if choice in no:
                return False
            print("Please respond with 'yes' or 'no'")
