import random
from abc import ABC, abstractmethod
from collections import deque
from agents.insight_agent import InsightAgentAPI
from agents.debug_agent import DebugAgentAPI
from agents.code_agent import CodeAgentAPI
from utils.utils import collect_sort_nodes_into_phases, calculate_softmax_weights


class BaseMergingAlgorithm(ABC):
    """Shared state and helpers for the idea-merging strategies.

    Keeps references to the collaborating agents and loggers, the insight
    tree taken from ``insighter``, and two bookkeeping stores consulted by
    :meth:`greedily_choose_the_nodes`:

    * ``memory`` - set of ``"a-b"`` pair keys that are never offered again;
    * ``memory_long`` - dict mapping a pair key (larger index first) to the
      number of times that pair has already been offered.
    """

    def __init__(self, config, is_higher_better, debugger: "DebugAgentAPI",
                 coder: "CodeAgentAPI", insighter: "InsightAgentAPI", sub_logger, debug_logger):
        self.config = config
        self.insight_tree = insighter.insight_tree
        self.is_higher_better = is_higher_better
        self.debugger = debugger
        self.coder = coder
        self.insighter = insighter
        self.sub_logger = sub_logger
        self.debug_logger = debug_logger
        # Pair keys ("a-b") that greedily_choose_the_nodes must skip.
        self.memory = set()
        # Pair key (larger index first) -> number of times the pair was offered.
        self.memory_long = {}

    def _log_memory(self):
        """Dump both memories to the debug logger, framed by dashed rules."""
        rule = '-' * 84
        self.debug_logger.info(f"{rule}\nMemory:\n{self.memory}\n{rule}")
        self.debug_logger.info(f"{rule}\nLong memory:\n{self.memory_long}\n{rule}")

    @abstractmethod
    def merge(self, all_idea_indexes: list, pre_path: str, baseline_score: float) -> list:
        """Run the merging strategy.

        Args:
            all_idea_indexes: Node indexes grouped per phase (one entry per phase).
            pre_path: Path forwarded to the debugger when code is generated.
            baseline_score: Reference score supplied by the caller.

        Returns:
            A list with one list per phase containing the node indexes that
            the strategy added to the insight tree.
        """
        pass

    @staticmethod
    def _merged_filename(filename, first_index, second_index, current_index):
        """Build the file name used for code produced from two nodes.

        Only a trailing ``.py`` is removed from ``filename``. ``str.strip('.py')``
        would instead strip any of the characters ``.``, ``p`` and ``y`` from
        both ends and mangle names such as ``python_model.py``.
        """
        stem = filename.removesuffix('.py')
        return f"{stem}_merging_{first_index}-{second_index}_{current_index}.py"

    def merge_2_ideas(self, node_index_1, node_index_2, phase_index, parent_index, pre_path):
        """Ask the coder to merge two ideas and run the result through the debugger.

        Args:
            node_index_1: Index of the first node in the insight tree.
            node_index_2: Index of the second node in the insight tree.
            phase_index: Position of the phase in ``config['phases']``.
            parent_index: Parent node index forwarded to the coder (may be None).
            pre_path: Path forwarded to the debugger.

        Returns:
            Tuple ``(code, extra, score)``: the second, third and fourth items
            returned by ``debugger.generate_and_debug_code``. Callers in this
            file treat ``code is None`` as a failed merge.
        """
        current_index = self.insight_tree.current_index
        phase, filename, _ = self.config['phases'][phase_index]
        new_filename = self._merged_filename(filename, node_index_1, node_index_2, current_index)

        first_node = self.insight_tree.nodes[node_index_1]
        second_node = self.insight_tree.nodes[node_index_2]

        _, code, extra, score = self.debugger.generate_and_debug_code(
            filename=new_filename,
            pre_path=pre_path,
            submission_name=f"my_submission_{current_index}.csv",
            coder_implement_func=self.coder.merge_insights,
            code_params=dict(
                data_from_the_first_idea=(first_node.idea, first_node.code, first_node.mean_score),
                data_from_the_second_idea=(second_node.idea, second_node.code, second_node.mean_score),
                current_task=phase,
                node_index=current_index,
                parent_index=parent_index
            ),
            agent_name="Coder",
            needs_invalid=bool(phase == "Model training")
        )
        return code, extra, score

    def greedily_choose_the_nodes(self, idea_indexes):
        """Return the first not-yet-exhausted pair of nodes, in list order.

        Pairs are enumerated as (i, j) with i before j in ``idea_indexes``.
        A pair is skipped when either ordering of its key is in ``memory`` or
        when its ``memory_long`` counter equals ``config['max_memory_long']``.
        Returning a pair increments (or creates) its ``memory_long`` counter.

        Args:
            idea_indexes: Indexable sequence of node indexes, best first.

        Returns:
            ``(first_index, second_index)`` or ``(None, None)`` when no pair is left.
        """
        total = len(idea_indexes)
        for first_position in range(total):
            first_index = idea_indexes[first_position]
            for second_position in range(first_position + 1, total):
                second_index = idea_indexes[second_position]
                forward_key = f"{first_index}-{second_index}"
                # Swap the operands rather than reversing the string: reversing
                # "12-3" yields "3-21", which is the key of a different pair.
                backward_key = f"{second_index}-{first_index}"
                if forward_key in self.memory or backward_key in self.memory:
                    continue

                # Long memory: canonical key puts the larger index first.
                long_key = forward_key if first_index > second_index else backward_key

                if long_key in self.memory_long:
                    if self.memory_long[long_key] == self.config['max_memory_long']:
                        continue
                    self.memory_long[long_key] += 1
                else:
                    self.memory_long[long_key] = 1

                return first_index, second_index
        return None, None

    def add_edge_between_nodes_from_different_phases(self, parent_index, children_index, pre_path):
        """Re-implement a child's idea on top of another parent's code.

        The coder's ``implement_task`` is called with the idea of
        ``children_index`` and the code of ``parent_index`` for phase 0 of
        ``config['phases']``. On success a new node carrying the child's idea
        is added under ``parent_index`` with the obtained score.

        Args:
            parent_index: Node whose code is used as the starting point.
            children_index: Node whose idea is implemented.
            pre_path: Path forwarded to the debugger.

        Returns:
            Index of the newly added node, or None when no code was produced.
        """
        current_index = self.insight_tree.current_index
        phase_index = 0
        phase, filename, _ = self.config['phases'][phase_index]
        new_filename = self._merged_filename(filename, parent_index, children_index, current_index)
        child_idea = self.insight_tree.nodes[children_index].idea

        _, code, _, score = self.debugger.generate_and_debug_code(
            filename=new_filename,
            pre_path=pre_path,
            submission_name=f"my_submission_{current_index}.csv",
            coder_implement_func=self.coder.implement_task,
            code_params=dict(
                idea=child_idea,
                previous_code=self.insight_tree.nodes[parent_index].code,
                task_name=phase,
                eda_output=None, eda_images=None,
                node_index=current_index,
                parent_index=parent_index
            ),
            agent_name="Coder",
            needs_invalid=True
        )

        if code is None:
            self.sub_logger.info(
                f"❌ Failed to add edge from node {parent_index} to node "
                f"{children_index}"
            )
            return None

        idea_index = self.insighter.insight_tree.add_idea(
            self.insight_tree.nodes[children_index].idea, code, parent_index
        )
        new_node = self.insight_tree.nodes[idea_index]
        new_node.mean_score = score
        new_node.num_of_eval = 1
        self.sub_logger.info(
            f"✅ An edge was added from node {parent_index} to node {children_index}\n"
            f"Score: {score}"
        )
        self.debug_logger.info(f"Parent: {self.insight_tree.nodes[parent_index].idea}")
        self.debug_logger.info(f"Child: {self.insight_tree.nodes[children_index].idea}")
        return idea_index


class GreedyMergingAlgorithm(BaseMergingAlgorithm):
    """Strategy that folds the nodes of each phase into a running merge node."""

    def merge(self, all_idea_indexes, pre_path, baseline_score):
        """
        Greedy node merging algorithm:
        1. Take the best node on the phase
        2. Add each node to it one by one
        (starting with the most successful and ending with the least successful).
        3. If there has been an increase in score, the remaining nodes will connect to the new merging node.
        4. Otherwise, we remove the added node from the ones we're looking at and try to join the rest of the nodes.

        Phases are processed from the last index down to 0. Only phase 0
        compares the merge score against the best score so far; for the other
        phases every successfully generated merge is kept.

        Args:
            all_idea_indexes: One sequence of node indexes per phase; the first
                entry of each sequence is used as that phase's best node.
            pre_path: Path forwarded to ``merge_2_ideas``.
            baseline_score: Not used by this strategy.

        Returns:
            A list with one list per phase holding the index of the last merge
            node created in that phase (empty when none was created).
        """
        self.sub_logger.info("🔷 Start greedy merging algorithm")
        phase_count = len(all_idea_indexes)
        added_node_indexes_merging = [[] for _ in range(phase_count)]
        patience = self.config['max_iteration_without_update_best_score']

        # Going from top (FE) to bottom (to modeling)
        for phase_index in range(phase_count - 1, -1, -1):
            phase, _, _ = self.config['phases'][phase_index]
            self.sub_logger.info(f"\t 🔷 Merging. Stage: {phase}")
            selected_node_indexes_in_phase = deque(all_idea_indexes[phase_index])
            if not selected_node_indexes_in_phase:
                # No candidates in this phase: nothing to merge and no best
                # node to read a score from.
                continue

            parent_index = self._pick_parent_index(
                all_idea_indexes, added_node_indexes_merging, phase_index
            )
            best_score = self.insight_tree.nodes[selected_node_indexes_in_phase[0]].mean_score
            limit_iter = patience
            add_node_index = None

            while len(selected_node_indexes_in_phase) > 1:
                node_index_1 = selected_node_indexes_in_phase.popleft()
                node_index_2 = selected_node_indexes_in_phase.popleft()

                idea1 = self.insight_tree.nodes[node_index_1].idea
                idea2 = self.insight_tree.nodes[node_index_2].idea

                debug_code, _, score = self.merge_2_ideas(
                    node_index_1=node_index_1, node_index_2=node_index_2,
                    phase_index=phase_index,
                    parent_index=parent_index,
                    pre_path=pre_path
                )

                if debug_code is None:
                    # NOTE(review): a failed merge drops BOTH nodes, including
                    # the leading one; confirm this is intended rather than
                    # re-queuing node_index_1 as the non-improvement path does.
                    continue

                # Scores are only compared for phase 0, so the comparison is
                # not evaluated for other phases (their scores may not be
                # comparable - TODO confirm).
                if phase_index == 0 and not self._is_improvement(score, best_score):
                    self.sub_logger.info(
                        f"\tMerge score: {score} is not an improvement over best score ({best_score}). Don't save merge"
                    )
                    limit_iter -= 1
                    if limit_iter == 0:
                        break
                    # Keep the leading node; only the second candidate is discarded.
                    selected_node_indexes_in_phase.appendleft(node_index_1)
                    continue

                index = self.insight_tree.add_idea(f'Merge: {idea1}\n{idea2}', debug_code, parent_index)
                if phase_index == 0:
                    self.insight_tree.nodes[index].mean_score = score
                    self.insight_tree.nodes[index].num_of_eval = 1
                else:
                    self.insight_tree.set_group_for_node("Merge ideas", index)

                # The merge node becomes the leading node for the next round.
                selected_node_indexes_in_phase.appendleft(index)
                add_node_index = index
                best_score = score
                limit_iter = patience

                self.sub_logger.info(f"\tMerge\n1: ({node_index_1}) {idea1}\n2: ({node_index_2}) {idea2}")
                if phase_index == 0:
                    self.sub_logger.info(
                        f"\tMerge score: {best_score}. Update best_score to {best_score}. Save merge")

            if add_node_index is not None:
                added_node_indexes_merging[phase_index].append(add_node_index)

        return added_node_indexes_merging

    def _is_improvement(self, score, best_score):
        """Return True when ``score`` strictly beats ``best_score`` in the configured direction."""
        if self.is_higher_better:
            return score > best_score
        return score < best_score

    @staticmethod
    def _pick_parent_index(all_idea_indexes, added_node_indexes_merging, phase_index):
        """Choose the parent node for merges created in ``phase_index``.

        The parent comes from the phase processed just before this one
        (``phase_index + 1``): its merge node if one was created, otherwise its
        first original node. Returns None for the first processed phase or
        when the upstream phase has no nodes at all.
        """
        upstream_phase = phase_index + 1
        if upstream_phase >= len(all_idea_indexes):
            return None
        merged_upstream = added_node_indexes_merging[upstream_phase]
        if merged_upstream:
            return merged_upstream[0]
        upstream_nodes = all_idea_indexes[upstream_phase]
        if len(upstream_nodes) == 0:
            return None
        return upstream_nodes[0]


class MergerByPartsAlgorithm(BaseMergingAlgorithm):
    """Strategy that merges feature-engineering nodes first, then modeling nodes.

    The modeling stage has two parts: the best child of each merged FE
    parent is re-implemented on top of the new merge node, and then pairs of
    children under randomly chosen FE parents are merged with each other.
    """

    # Positions in config['phases'] / all_idea_indexes used by this strategy.
    _FE_PHASE_INDEX = 1
    _MODELING_PHASE_INDEX = 0

    def merge(self, all_idea_indexes, pre_path, baseline_score):
        """Run the FE stage followed by both modeling sub-stages.

        Args:
            all_idea_indexes: Node indexes grouped per phase; index 1 is used
                as the FE phase and index 0 as the modeling phase.
            pre_path: Path forwarded to the debugger when code is generated.
            baseline_score: Not used by this strategy.

        Returns:
            A list with one list per phase containing the indexes of the nodes
            added to the insight tree. All lists are empty when the very first
            FE iteration finds no pair to merge.
        """
        self.sub_logger.info("🔷 Start merger by parts")
        added_node_indexes_merging = [[] for _ in range(len(all_idea_indexes))]

        # FE (Feature Engineering)
        fe_add_indexes = self._merge_feature_engineering(
            all_idea_indexes[self._FE_PHASE_INDEX], pre_path,
            added_node_indexes_merging[self._FE_PHASE_INDEX]
        )
        if fe_add_indexes is None:
            return added_node_indexes_merging

        # Modeling
        phase, _, _ = self.config['phases'][self._MODELING_PHASE_INDEX]
        self.sub_logger.info(f"\t 🔷 Merging. Stage: {phase}")
        modeling_added = added_node_indexes_merging[self._MODELING_PHASE_INDEX]

        # 1. Adding children of merging ideas from the previous stage
        self._attach_best_children(fe_add_indexes, pre_path, modeling_added)

        # 2. Merge ideas in the remaining nodes
        self._merge_modeling_children(
            all_idea_indexes[self._FE_PHASE_INDEX], pre_path, modeling_added
        )

        return added_node_indexes_merging

    def _merge_feature_engineering(self, fe_node_indexes, pre_path, fe_added):
        """Merge pairs of FE nodes; return {merge node: (node_1, node_2)} or None.

        None signals that not a single pair was available on the first
        iteration, in which case the caller stops the whole strategy.
        Successful merges are appended to ``fe_added`` and recorded in
        ``memory`` in both orders.
        """
        phase_index = self._FE_PHASE_INDEX
        phase, _, _ = self.config['phases'][phase_index]
        fe_add_indexes = {}
        self.sub_logger.info(f"\t 🔷 Merging. Stage: {phase}")
        for id_merge in range(self.config['number_of_iterations_parents']):
            node_index_1, node_index_2 = self.greedily_choose_the_nodes(fe_node_indexes)

            if node_index_1 is None:
                if id_merge == 0:
                    return None
                break

            idea1 = self.insight_tree.nodes[node_index_1].idea
            idea2 = self.insight_tree.nodes[node_index_2].idea
            debug_code, _, _ = self.merge_2_ideas(
                node_index_1=node_index_1, node_index_2=node_index_2,
                phase_index=phase_index, parent_index=None, pre_path=pre_path
            )
            if debug_code is None:
                continue

            add_node_index = self.insight_tree.add_idea(f'Merge: {idea1}\n{idea2}', debug_code, None)
            self.insight_tree.set_group_for_node("Merge ideas", add_node_index)
            fe_added.append(add_node_index)
            fe_add_indexes[add_node_index] = (node_index_1, node_index_2)
            self.sub_logger.info(f"\tMerge\n1: ({node_index_1}) {idea1}\n2: ({node_index_2}) {idea2}")
            self.memory.add(f"{node_index_1}-{node_index_2}")
            self.memory.add(f"{node_index_2}-{node_index_1}")
        return fe_add_indexes

    def _attach_best_children(self, fe_add_indexes, pre_path, modeling_added):
        """Re-implement the best child of each merged FE parent under the merge node."""
        self.sub_logger.info(f"\t\t 🔷 Merging. Adding children of merging ideas from the previous stage")
        for new_node_index, source_parents in fe_add_indexes.items():
            # Same order as the pair was merged: first source parent, then second.
            for source_parent_index in source_parents:
                best_child_index = self.insight_tree.select_the_best_children(
                    parent_index=source_parent_index, is_higher_better=self.is_higher_better
                )
                if best_child_index is None:
                    # NOTE(review): select_the_best_children is not visible
                    # here; presumably it can return None for a childless
                    # parent, which would otherwise fail on the node lookup.
                    self.sub_logger.info(
                        f"\t\tNo child found for node {source_parent_index}. Skip adding edge"
                    )
                    continue
                idea_index = self.add_edge_between_nodes_from_different_phases(
                    parent_index=new_node_index, children_index=best_child_index, pre_path=pre_path
                )
                if idea_index is not None:
                    modeling_added.append(idea_index)

    def _merge_modeling_children(self, fe_node_indexes, pre_path, modeling_added):
        """Merge pairs of children under FE parents sampled by score weight.

        Parents are drawn with ``random.choices`` (i.e. with replacement), so
        the same parent may be picked more than once. A merge is stored only
        when it strictly beats the better of its two source nodes.
        """
        phase_index = self._MODELING_PHASE_INDEX
        self.sub_logger.info(f"\t\t 🔷 Merging. Merge ideas in the remaining nodes")
        if len(fe_node_indexes) == 0:
            # random.choices raises IndexError on an empty population.
            return

        nodes_score = [self.insight_tree.nodes[idx].mean_score for idx in fe_node_indexes]
        weighted_scores = calculate_softmax_weights(nodes_score, self.is_higher_better)

        chosen_parent_indexes = random.choices(
            fe_node_indexes,
            k=min(self.config['number_of_selected_node_merging'], len(fe_node_indexes)),
            weights=weighted_scores
        )

        nodes = self.insight_tree.nodes
        for parent_index in chosen_parent_indexes:
            children = [child.index for child in nodes[parent_index].children]
            node_index_1, node_index_2 = self.greedily_choose_the_nodes(children)
            self.sub_logger.info(f"Nodes selected: {node_index_1}, {node_index_2}")
            if node_index_1 is None:
                continue

            idea1 = nodes[node_index_1].idea
            idea2 = nodes[node_index_2].idea

            debug_code, _, score = self.merge_2_ideas(
                node_index_1=node_index_1, node_index_2=node_index_2,
                phase_index=phase_index, parent_index=parent_index, pre_path=pre_path
            )
            if debug_code is None:
                # Failed merges are not written to `memory`; only the
                # long-memory counter limits how often they are retried.
                continue

            score1 = nodes[node_index_1].mean_score
            score2 = nodes[node_index_2].mean_score
            best_score = max(score1, score2) if self.is_higher_better else min(score1, score2)

            is_improvement = (self.is_higher_better and score > best_score) or \
                             (not self.is_higher_better and score < best_score)

            if is_improvement:
                add_node_index = self.insight_tree.add_idea(f'Merge: {idea1}\n{idea2}', debug_code, parent_index)
                modeling_added.append(add_node_index)
                self.insight_tree.nodes[add_node_index].mean_score = score
                self.insight_tree.nodes[add_node_index].num_of_eval = 1
                self.sub_logger.info(f"\tMerge\n1: ({node_index_1}) {idea1}\n2: ({node_index_2}) {idea2}")
                comparison_symbol = ">" if self.is_higher_better else "<"
                self.sub_logger.info(f"\tMerge score: {score} {comparison_symbol} {best_score}. Save merge")
            else:
                comparison_symbol = "<=" if self.is_higher_better else ">="
                self.sub_logger.info(
                    f"\tMerge score: {score} {comparison_symbol} best score ({best_score}). Don't save merge"
                )

            # Evaluated pairs are never offered again, whether saved or not.
            self.memory.add(f"{node_index_1}-{node_index_2}")
            self.memory.add(f"{node_index_2}-{node_index_1}")


class MergingManager:
    """Facade that instantiates the configured merging algorithm and runs it."""

    def __init__(self, config, is_higher_better, debugger: DebugAgentAPI,
                 coder: CodeAgentAPI, insighter: InsightAgentAPI, sub_logger, debug_logger):
        """Select the algorithm named by ``config['algorithm_type_merging']``.

        Raises:
            ValueError: If the configured algorithm type is not recognised.
        """
        self.config = config
        self.insight_tree = insighter.insight_tree
        self.is_higher_better = is_higher_better

        algorithm_type = self.config['algorithm_type_merging']
        if algorithm_type == "greedy_algorithm":
            algorithm_cls = GreedyMergingAlgorithm
        elif algorithm_type == "merger_by_parts":
            algorithm_cls = MergerByPartsAlgorithm
        else:
            raise ValueError(f"Merging algorithm type ‘{config['algorithm_type_merging']}’ not found")

        # Every algorithm takes the same collaborators as this manager.
        self.algorithm = algorithm_cls(
            config=config,
            is_higher_better=is_higher_better,
            debugger=debugger,
            coder=coder,
            insighter=insighter,
            sub_logger=sub_logger,
            debug_logger=debug_logger,
        )

    def merge_ideas(self, pre_path: str, indexes_for_merging_modelling: list, baseline_score: float) -> list:
        """Rank the candidate nodes, group them into phases and merge them.

        Entries whose tree node is a plain string are dropped; the rest are
        ordered by ``mean_score`` (descending when higher is better) before
        being grouped by ``collect_sort_nodes_into_phases``. The algorithm's
        memory is logged before and after the merge.

        Returns:
            Whatever the selected algorithm's ``merge`` returns.
        """
        def mean_score_of(node_index):
            return self.insight_tree.nodes[node_index].mean_score

        ranked_indexes = sorted(
            (
                node_index for node_index in indexes_for_merging_modelling
                if not isinstance(self.insight_tree.nodes[node_index], str)
            ),
            key=mean_score_of,
            reverse=self.is_higher_better,
        )

        all_idea_indexes = collect_sort_nodes_into_phases(
            ranked_indexes, self.insight_tree, self.is_higher_better
        )

        algorithm = self.algorithm
        algorithm._log_memory()
        merged_node_indexes = algorithm.merge(all_idea_indexes, pre_path, baseline_score)
        algorithm._log_memory()

        return merged_node_indexes