import random
from agents.insight_agent import InsightAgentAPI
from utils.insight_tree import InsightTree
from utils.utils import collect_sort_nodes_into_phases, calculate_softmax_weights
import numpy as np


def get_parameters_to_base_adding(**params):
    """
    Pick the subset of adding parameters consumed by ``base_adding``.

    :param params: keyword arguments prepared for any adding algorithm
    :return: dict with the keys phase_index, all_idea_indexes, eda_output,
             eda_images and mode (in that order), copied from ``params``
    :raises KeyError: if any of those keys is missing from ``params``
    """
    selected = {}
    for key in ('phase_index', 'all_idea_indexes', 'eda_output', 'eda_images', 'mode'):
        selected[key] = params[key]
    return selected


def get_parameters_to_groups_split(**params):
    """
    Pick the subset of adding parameters consumed by ``add_with_groups_split``.

    Unlike the base-adding selection, ``mode`` is not forwarded.

    :param params: keyword arguments prepared for any adding algorithm
    :return: dict with the keys phase_index, all_idea_indexes, eda_output and
             eda_images (in that order), copied from ``params``
    :raises KeyError: if any of those keys is missing from ``params``
    """
    wanted = ['phase_index', 'all_idea_indexes', 'eda_output', 'eda_images']
    return dict(zip(wanted, [params[name] for name in wanted]))


class Adding_algorithms:
    def __init__(self, config, is_higher_better, insighter: InsightAgentAPI, sub_logger):
        self.config = config
        self.insight_tree: InsightTree = insighter.insight_tree
        self.is_higher_better = is_higher_better
        self.insighter = insighter
        self.sub_logger = sub_logger

        self.mode = None
        if config['algorithm_type_adding'] == 'base_random_adding':
            self.adding_algorithm = self.base_adding
            self.get_parameters = get_parameters_to_base_adding
            self.mode = 'random'
        elif config['algorithm_type_adding'] == 'top_n_adding':
            self.adding_algorithm = self.base_adding
            self.get_parameters = get_parameters_to_base_adding
            self.mode = 'top_n'
        elif config['algorithm_type_adding'] == "adding_probability_distribution":
            self.adding_algorithm = self.base_adding
            self.get_parameters = get_parameters_to_base_adding
            self.mode = 'probability_distribution'
        elif config['algorithm_type_adding'] == "adding_greedy_epsilon":
            self.adding_algorithm = self.base_adding
            self.get_parameters = get_parameters_to_base_adding
            self.mode = 'greedy_epsilon'
        elif config['algorithm_type_adding'] == 'adding_with_groups_split':
            self.adding_algorithm = self.add_with_groups_split
            self.get_parameters = get_parameters_to_groups_split
        else:
            raise ValueError(f"Adding algorithm type ‘{config['algorithm_type_adding']}’ not found")

    def add_new_ideas(
            self, task_name: str, node_indexes_on_the_phase: list,
            eda_output: str | None, eda_images: list | None
    ):
        """
        Sort ideas by phase and average score and call the adding algorithm.

        :param task_name: stage name
        :param node_indexes_on_the_phase: indexes of the last stage (modeling),
                              from which all available nodes can be recursively collected
        :param eda_output: text output from the EDA stage
        :param eda_images: images generated during the EDA phase
        :return: new idea indexes
        """
        if self.is_higher_better:
            # Sort by score descending
            node_indexes_on_the_phase.sort(key=lambda x: -self.insight_tree.nodes[x].mean_score)
        else:
            # Sort by score ascending
            node_indexes_on_the_phase.sort(key=lambda x: self.insight_tree.nodes[x].mean_score)

        # all_idea_indexes = [[Modeling node indexes], [Data preparation and feature engineering node indexes]]
        all_idea_indexes = collect_sort_nodes_into_phases(
            node_indexes_on_the_phase,
            self.insight_tree,
            self.is_higher_better
        )

        all_params = {
            'phase_index': None, 'all_idea_indexes': all_idea_indexes,
            'eda_output': eda_output, 'eda_images': eda_images,
            'mode': self.mode
        }
        match task_name:
            case "Model training":
                all_params['phase_index'] = 0
            case "Data preparation and feature engineering":
                all_params['phase_index'] = 1
            case _:
                raise KeyError(f"Task name ‘{task_name}’ not found")

        params = self.get_parameters(**all_params)
        new_idea_indexes = self.adding_algorithm(**params)
        return new_idea_indexes

    def base_adding(
            self, phase_index: int, all_idea_indexes: list, mode: str,
            eda_output: str | None, eda_images: list | None
    ):
        """
        If at the topmost stage, we select random nodes to which we add ideas,
        and then generate the rest of the subtree.

        Otherwise, we select random parents at the previous stage,
        add new nodes to each parent, and regenerate the remaining subtree.

        :param phase_index: Phase index
        :param all_idea_indexes: two-dimensional array of indices of all nodes grouped by phases
        :param eda_output: text output from the EDA stage
        :param eda_images: images generated during the EDA phase
        :return: a dictionary where the keys are parents
                 and the values are dictionaries with ideas and parent code
        """
        task_name, filename, _ = self.config['phases'][phase_index]
        idea_indexes_at_the_stage = all_idea_indexes[phase_index]

        self.sub_logger.info(f"➕ Adding. Stage: {task_name}")
        match task_name:
            case "Model training":
                if mode == 'random':
                    # Selecting random parents
                    k = min(self.config['number_of_selected_node'], len(all_idea_indexes[phase_index + 1]))
                    chosen_parent_indexes = random.choices(
                        all_idea_indexes[phase_index + 1],
                        k=k
                    )

                elif mode == "top_n":
                    chosen_parent_indexes = all_idea_indexes[phase_index + 1][:self.config['number_of_selected_node']]
                elif mode == "probability_distribution":
                    nodes_score = [
                            self.insight_tree.nodes[node_index].mean_score for node_index in all_idea_indexes[phase_index + 1]
                    ]
                    weighted_scores = calculate_softmax_weights(nodes_score, self.is_higher_better)
                    k = min(self.config['number_of_selected_node'], len(all_idea_indexes[phase_index + 1]))
                    chosen_parent_indexes = random.choices(
                        all_idea_indexes[phase_index + 1],
                        k=k,
                        weights=weighted_scores
                    )
                elif mode == "greedy_epsilon":
                    if self.config['adding_epsilon'] >= random.random():
                        chosen_parent_indexes = all_idea_indexes[phase_index + 1][:self.config['number_of_selected_node']]
                    else:
                        k = min(self.config['number_of_selected_node'], len(all_idea_indexes[phase_index + 1]))
                        chosen_parent_indexes = random.choices(
                            all_idea_indexes[phase_index + 1],
                            k=k
                        )
                else:
                    raise ValueError(f"Mode ‘{mode}’ not found")

                ideas = {}
                for parent_index in chosen_parent_indexes:
                    # We add ideas for every parent
                    code = self.insight_tree.get_all_code_in_branch(parent_index)
                    score = self.insight_tree.nodes[parent_index].mean_score
                    all_ideas = list(set(
                        [str(self.insight_tree.nodes[i].idea) for i in idea_indexes_at_the_stage]
                    ))
                    add_idea = self.insighter.add_insights(
                        code=code, score=score,
                        all_ideas=all_ideas,
                        eda_output=eda_output,
                        eda_images=eda_images,
                        current_task=task_name
                    )['insights']
                    ideas[parent_index] = {"idea": add_idea, "previous_code": code, "group": None}
            case "Data preparation and feature engineering":
                # Add ideas and regenerate the remaining subtree
                all_ideas_and_scores = [
                    (self.insight_tree.nodes[node_index].idea, self.insight_tree.nodes[node_index].mean_score)
                    for node_index in idea_indexes_at_the_stage
                ]
                add_idea = self.insighter.add_insights_no_parent_phase(
                    task_name,
                    all_ideas_and_scores,
                    eda_output=eda_output,
                    eda_images=eda_images,
                    group_name=None
                )['insights']
                ideas = {-1: {"idea": add_idea, "previous_code": None, "group": None}}
            case _:
                raise KeyError(f"Task name ‘{task_name}’ not found")

        return ideas

    def add_with_groups_split(self, phase_index, all_idea_indexes, eda_output, eda_images):
        task_name, filename, _ = self.config['phases'][phase_index]
        idea_indexes_at_the_stage = all_idea_indexes[phase_index]
        idea_indexes_at_the_FE = all_idea_indexes[1]

        self.sub_logger.info(f"➕ Adding. Stage: {task_name}")
        current_groups = []
        node_indexes_without_group = []
        number_idea_with_group = 0
        for node_index in idea_indexes_at_the_FE:
            node = self.insight_tree.nodes[node_index]
            if node.group is not None:
                if node.group not in current_groups:
                    current_groups.append(node.group)
                number_idea_with_group += 1
            else:
                node_indexes_without_group.append(node_index)

        if not current_groups:
            # current_groups is empty
            ideas = ""
            for node_index in idea_indexes_at_the_FE:
                node = self.insight_tree.nodes[node_index]
                ideas += f"\nIndex: {node.index}\nIdea: {node.idea}\n\n"

            groups = self.insighter.split_ideas_into_groups(ideas)
            for group in groups:
                if group == "thoughts":
                    continue
                for node_index in groups[group]:
                    self.insight_tree.nodes[node_index].group = group
        elif number_idea_with_group != len(idea_indexes_at_the_FE):
            group_names = "\n".join(current_groups)
            ideas_without_group = ""
            for node_index in node_indexes_without_group:
                node = self.insight_tree.nodes[node_index]
                ideas_without_group += f"\nIndex: {node.index}\nIdea: {node.idea}\n\n"
            groups = self.insighter.add_ideas_to_groups(group_names, ideas_without_group)
            for group in groups:
                if group == "thoughts":
                    continue
                for node_index in groups[group]:
                    self.insight_tree.nodes[node_index].group = group

        groups = {}

        for node_index in idea_indexes_at_the_FE:
            node = self.insight_tree.nodes[node_index]
            if node.group not in groups:
                groups[node.group] = {}
                groups[node.group]['node_indexes'] = [node_index]
                groups[node.group]['number_of_inputs'] = node.num_of_eval
                groups[node.group]['scores'] = np.array([node.mean_score])
            else:
                groups[node.group]['node_indexes'].append(node_index)
                groups[node.group]['number_of_inputs'] += node.num_of_eval
                np.append(groups[node.group]['scores'], node.mean_score)

        for group in groups:
            groups[group]['number of ideas in the group'] = len(groups[group]['scores'])
            groups[group]['max_score'] = np.max(groups[group]['scores'])
            groups[group]['min_score'] = np.min(groups[group]['scores'])
            groups[group]['mean_score'] = np.mean(groups[group]['scores'])
            groups[group]['median_score'] = np.median(groups[group]['scores'])
            groups[group]['std_score'] = np.std(groups[group]['scores'])

        select_group = self.insighter.select_group(groups, task_name)
        ideas = {}
        match task_name:
            case "Model training":
                all_ideas_and_scores = ""
                for node_index in groups[select_group]['node_indexes']:
                    node = self.insight_tree.nodes[node_index]
                    if node is not None:
                        all_ideas_and_scores += f"Idea index: {node.index}\nI" \
                                                f"Mean score: {node.mean_score}\n" \
                                                f"Number of inputs: {node.num_of_eval}\n\n"

                parent_indexes = self.insighter.select_nodes(all_ideas_and_scores)
                for parent_index in parent_indexes:
                    # We add ideas for every parent
                    code = self.insight_tree.get_all_code_in_branch(parent_index)
                    score = self.insight_tree.nodes[parent_index].mean_score
                    all_ideas = [self.insight_tree.nodes[i].idea for i in idea_indexes_at_the_stage]
                    add_idea = self.insighter.add_insights(
                        code=code, score=score,
                        all_ideas=all_ideas,
                        eda_output=eda_output,
                        eda_images=eda_images,
                        current_task=task_name
                    )['insights']
                    ideas[parent_index] = {"idea": add_idea, "previous_code": code, "group": None}

            case "Data preparation and feature engineering":
                if select_group in groups:
                    all_ideas_and_scores = [
                        (self.insight_tree.nodes[node_index].idea, self.insight_tree.nodes[node_index].mean_score)
                        for node_index in groups[select_group]['node_indexes']
                    ]
                else:
                    all_ideas_and_scores = "There are no ideas in this group yet."
                add_idea = self.insighter.add_insights_no_parent_phase(
                    task_name,
                    all_ideas_and_scores,
                    eda_output=eda_output,
                    eda_images=eda_images,
                    group_name=select_group
                )['insights']
                ideas = {-1: {"idea": add_idea, "previous_code": None, "group": select_group}}
            case _:
                raise KeyError(f"Task name ‘{task_name}’ not found")
        return ideas
