import random

from agents.base_agent import BaseAgent_openai
from agents.code_agent import CodeAgentAPI
from agents.rag_agent import RagAgentAPI
from prompts.insight_agent_prompts import *
from utils.insight_tree import InsightTree
from utils.extractors import *
from algorithms.rag import RAG
import utils


class InsightAgentAPI(BaseAgent_openai):
    def __init__(self, config, agent_name,
                 main_logger, sub_logger, debug_logger,
                 coder, checker,
                 rag_agent, retrieve_model):
        super().__init__(config, agent_name, debug_logger, main_logger, sub_logger, checker=checker)

        # init agents and tree
        self.rag_agent: RagAgentAPI = rag_agent
        self.coder: CodeAgentAPI = coder
        self.insight_tree = InsightTree(retrieve_model=retrieve_model)
        self.rag = RAG(tree=self.insight_tree, retrieve_model=retrieve_model)
        self.memory_context = None

        # init paths
        self.data_dir_path = self.config['save_path']
        background_data_path = self.config['background_data_path']
        with open(background_data_path, 'r') as f:
            self.background_data = '\n'.join(f.readlines())

        # init loggers
        self.main_logger = main_logger
        self.sub_logger = sub_logger

        # init number_of_ideas
        self.number_of_ideas_eda = self.config['number_of_ideas_eda']
        self.number_of_ideas_data = self.config['number_of_ideas_data']
        self.number_of_ideas_modelling = self.config['number_of_ideas_modelling']

    def execute_insight_prompt(self, prompt_type: str, input_text: str = '', eda_images: list | None = None, **params):
        """
        Generate a response from the model based on the prompt

        :param prompt_type: A label indicating which specific prompt can be taken
        :param input_text: additional text to the instructions
        :param eda_images: array of images in base64 format
        :param params: parameters for substitution in the prompt
        :return: result: generated model text
        """
        prompt_mapping = {
            "generate_insight": UPDATE_INSTRUCTION_PROMPT,
            "add_insights": ADD_INSIGHTS_PROMPT,
            "merge_insights": MERGE_INSIGHTS_PROMPT,
            "generate_insight_eda": EDA_PROMPT,
            "add_insights_no_parent_phase": ADD_INSIGHTS_NO_PARENT_PHASE_PROMPT,
            "generate_insight_for_complex_training": GENERATE_IDEAS_FOR_COMPLEX_TRAINING,
            "split_ideas_into_groups": SPLIT_IDEAS_INTO_GROUPS,
            "select_group": SELECT_GROUP,
            "select_nodes": SELECT_NODES,
            "add_ideas_to_groups": ADD_IDEAS_TO_GROUPS
        }
        self.clear_context()
        prompt_template = prompt_mapping.get(prompt_type)
        if prompt_template is None:
            raise ValueError(f"Unknown prompt type: {prompt_type}")

        self.instructions = prompt_template.format(**params)

        if prompt_type == "generate_insight":
            result = self.code_generation_from_several_attempts(
                number_of_attempts=self.config["number_of_attempts_insight"],
                input_text=input_text,
                task_name=params["current_task"],
                images=eda_images,
                number_of_ideas=params['number_of_ideas'],

            )
        else:
            result = self.generate_response(input_text, images=eda_images)
        return result

    def get_number_of_ideas(self, task_name: str) -> int:
        """
        Returns the maximum possible number of ideas for this stage

        :param task_name: stage name
        :return: maximum number of ideas
        """
        match task_name:
            case "Model training":
                number_of_ideas = self.number_of_ideas_modelling
            case "Data preparation and feature engineering":
                number_of_ideas = self.number_of_ideas_data
            case _:
                raise KeyError(f"Task name ‘{task_name}’ not found")
        return number_of_ideas

    def generate_insight(self, current_task: str, previous_ideas: list | None = None,
                         eda_output: str | None = None, eda_images: list | None = None):
        """
        Generates ideas at this stage

        :param current_task: the stage for which ideas need to be generated
        :param previous_ideas: parenting ideas from the previous stage
        :param eda_output: text output from the EDA stage
        :param eda_images: images generated during the EDA phase
        :return: N generated ideas (N is defined in the configuration)
        """
        number_of_ideas = self.get_number_of_ideas(current_task)
        use_rag = False
        if self.rag_agent:
            use_rag = self.rag_agent.function_calling(self.background_data,
                                                      current_task, previous_ideas)

        if use_rag:
            rag_context = self.rag_agent.retrieve_rag_ideas(self.config['number_rag_ideas'])
            if not rag_context:
                rag_context = 'RAG was not used'
                self.sub_logger.info("RAG was not used")
            else:
                self.sub_logger.info("RAG was used")
        else:
            self.sub_logger.info("RAG was not used")
            rag_context = 'RAG was not used'

        match self.config['memory_algorithm']:
            case 'random_nodes':
                self.memory_context = self.rag.sample_random_nodes(
                    exclude_nodes=previous_ideas, n=self.config['memory_size'], add_context=True
                )
            case 'distant_nodes':
                if previous_ideas:
                    parent_idea = previous_ideas[-1]
                    self.memory_context = self.rag.get_most_distant_nodes(
                        parent_idea, top_n=self.config['memory_size'], distant_ideas=True,
                        add_context=True
                    )
            case 'nearest_nodes':
                if previous_ideas:
                    parent_idea = previous_ideas[-1]
                else:
                    nodes = self.insight_tree.nodes
                    neighbours_node = [
                        {'idea': nodes[node_ind].idea, 'code': nodes[node_ind].code} for node_ind in nodes
                    ]
                    if neighbours_node:
                        parent_idea = random.choice(neighbours_node)
                    else:
                        parent_idea = None
                        self.memory_context = 'Empty'

                if parent_idea is not None:
                    self.memory_context = self.rag.get_most_distant_nodes(
                        parent_idea, top_n=self.config['memory_size'], distant_ideas=False,
                        add_context=True
                    )
            case _:
                self.memory_context = ''

        return self.execute_insight_prompt(
            prompt_type="generate_insight",
            background_data=self.background_data,
            eda_images=eda_images,
            current_task=current_task,
            previous_ideas=previous_ideas,
            number_of_ideas=number_of_ideas,
            eda_output=eda_output,
            memory_context=self.memory_context,
            memory_algorithm=self.config['memory_algorithm'],
            rag_context=rag_context
        )

    def generate_competitive_insights(self, eda_output: str | None = None, eda_images: list | None = None):
        """
        Generates diverse modeling ideas and a cross-dataset insight based on EDA results
        for use in a machine learning competition setting.

        :param eda_output: text output from the EDA stage
        :param eda_images: images generated during the EDA phase
        :return: Tuple:
            - ideas (list[str]): Diverse modeling ideas.
            - number_of_ideas (int): Count of generated ideas.
            - cross_dataset_insight (str): General insight about feature impact across datasets.
        :raises ValueError: If the response format is invalid or JSON extraction fails.
        """
        for attempt in range(3):
            ideas_json = extract_json(self.execute_insight_prompt(
                prompt_type="generate_insight_for_complex_training",
                background_data=self.background_data,
                eda_images=eda_images,
                previous_ideas='',
                number_of_ideas_min=self.config['number_of_ideas_min'],
                number_of_ideas_max=self.config['number_of_ideas_max'],
                max_minutes_to_run=self.config['max_minutes_to_run_for_complex_training'],
                eda_output=eda_output,
            ))

            if ideas_json is None:
                self.sub_logger.info("❌ The insighter was unable to generate the correct ideas")
                if attempt != 2:
                    self.sub_logger.info("Retrying")
                continue
            try:
                ideas = [idea["description"] for idea in ideas_json['complex_training_strategies']]
                cross_dataset_insight = ideas_json['reference_strategy_for_feature_comparison']["description"]
                return ideas, len(ideas), cross_dataset_insight
            except KeyError:
                self.sub_logger.info("❌ The insighter was unable to generate the correct ideas")
                if attempt != 2:
                    self.sub_logger.info("Retrying")
        raise KeyError("❌ The insighter was unable to generate the correct ideas")

    def add_insights(self, code: str, score: str, all_ideas: list, current_task: str,
                     eda_output: str | None = None, eda_images: list | None = None):
        """
        Generates new ideas based on past ideas

        :param code: code for this specific idea
        :param score: code for this specific idea
        :param all_ideas: all ideas are on the same level as this particular idea
        :param eda_output: text output from the EDA stage
        :param eda_images: images generated during the EDA phase
        :param current_task: the stage for which ideas need to be generated
        :return: generated ideas
        """
        result = extract_json(self.execute_insight_prompt(
            "add_insights",
            input_text='',
            solution_code=code,
            score=score,
            background_info=self.background_data,
            previous_approaches=all_ideas,
            eda_output=eda_output,
            eda_images=eda_images,
            current_task=current_task,
            device_info=utils.device_info.get_system_info(),
            max_add_idea=self.config['max_add_idea']
        ))
        return result

    def add_insights_no_parent_phase(self, task_name: str, all_ideas_and_scores, group_name: str | None,
                                     eda_output: str | None = None, eda_images: list | None = None):
        """
        Add ideas for a phase that has no parent vertices

        :param group_name:
        :param task_name: the stage for which ideas need to be generated
        :param all_ideas_and_scores: all ideas at this stage and their scores
        :param eda_output: text output from the EDA stage
        :param eda_images: images generated during the EDA phase
        :return: generated ideas
        """
        result = extract_json(self.execute_insight_prompt(
            "add_insights_no_parent_phase",
            background_info=self.background_data,
            eda_images=eda_images,
            task_name=task_name,
            all_ideas_and_scores=all_ideas_and_scores,
            eda_output=eda_output,
            group_name=group_name,
            max_add_idea=self.config['max_add_idea']
        ))
        return result

    def generate_eda_ideas(self):
        """
        Generate and debug code for the EDA stage, including images and text information based on data

        :return: eda_ideas: generated ideas
        """

        # Generate ideas
        eda_ideas_json = extract_json(self.execute_insight_prompt(
            prompt_type="generate_insight_eda",
            background_data=self.background_data,
            number_of_ideas=self.number_of_ideas_eda
        )
        )
        if eda_ideas_json is None or 'insights' not in eda_ideas_json:
            self.main_logger.info("❌ The insighter was unable to generate the correct ideas")
            return []

        # Сut out unnecessary ideas
        eda_ideas = eda_ideas_json['insights'][:self.number_of_ideas_eda]
        ideas_text = ""

        # Logs
        for num_idea, idea in enumerate(eda_ideas):
            ideas_text += f"\t{num_idea + 1}. {idea}\n"
        self.sub_logger.info(f"⬛\tEDA ideas:\n{ideas_text}")

        return eda_ideas

    def generate_phase_ideas(self, task_name: str, previous_ideas,
                             eda_output: str | None = None, eda_images: list | None = None):
        """
        Generate ideas for this stage

        :param task_name: the stage for which ideas need to be generated
        :param previous_ideas: parenting ideas
        :param eda_output: text output from the EDA stage
        :param eda_images: images generated during the EDA phase
        :return: generated ideas, number of ideas generated
        """
        # Generate ideas
        phase_ideas_json = extract_json(
            self.generate_insight(
                current_task=task_name, previous_ideas=previous_ideas,
                eda_output=eda_output, eda_images=eda_images,
            )
        )
        if phase_ideas_json is None:
            self.main_logger.info("❌ The insighter was unable to generate the correct ideas")
            return None, None

        number_of_ideas = self.get_number_of_ideas(task_name)

        # Сut out unnecessary ideas
        ideas = phase_ideas_json['insights'][:number_of_ideas]

        # Logs
        self.debug_logger.info(ideas)
        return ideas, len(ideas)

    def split_ideas_into_groups(self, ideas):
        result = extract_json(self.execute_insight_prompt(
            "split_ideas_into_groups",
            ideas=ideas
        ))
        return result

    def add_ideas_to_groups(self, group_names, ideas_without_group):
        result = extract_json(self.execute_insight_prompt(
            "add_ideas_to_groups",
            group_names=group_names,
            ideas_without_group=ideas_without_group
        ))
        return result

    def select_group(self, groups: dict, task_name: str):
        groups_to_prompt = ""
        for group in groups:
            groups_to_prompt += f"{group}\n"
            for keys in groups[group]:
                if keys not in ("node_indexes", "scores"):
                    groups_to_prompt += f"\t{keys}: {groups[group][keys]}\n"
            groups_to_prompt += "\n"
        result = extract_json(self.execute_insight_prompt(
            "select_group",
            groups=groups_to_prompt,
            task_name=task_name
        ))
        return result['select_group']

    def select_nodes(self, all_ideas_and_scores):
        result = extract_json(self.execute_insight_prompt(
            "select_nodes",
            all_ideas_and_scores=all_ideas_and_scores,
            number_of_selected_node=self.config['number_of_selected_node']
        ))
        return result['select_nodes']
