import json
import os
from utils.rag_retrieval import encode_text


class Node:
    """
    A single idea in an InsightTree.

    Stores the idea text and its code, a link to the parent node, the list of
    child nodes, and running evaluation statistics (mean score and number of
    evaluations).
    """

    def __init__(self, index, idea, code, parent=None, embedding_idea=None):
        """
        :param index: index of the node inside its tree
        :param idea: textual description of the idea
        :param code: code implementing the idea
        :param parent: parent Node, or None for a root node
        :param embedding_idea: embedding of the idea text (stored as-is)
        """
        self.index = index
        self.idea = idea
        self.code = code
        self.parent = parent
        self.children = []
        # Root nodes are at depth 1; each child is one level below its parent.
        self.depth = parent.depth + 1 if parent is not None else 1
        self.mean_score = None
        self.group = None
        self.num_of_eval = 0
        self.score_for_prediction = None
        self.embedding = embedding_idea
        self.predicted_score = None  # Needed for the scoring model test
        self.description = ""

    def __str__(self):
        # Show only the parent's index: formatting the parent object itself
        # would recursively embed every ancestor's full description.
        parent_index = self.parent.index if self.parent is not None else None
        return (
            f"\tIndex: {self.index};\n"
            f"\tIdea: {self.idea};\n"
            f"\tparent: {parent_index};\n"
            f"\tNum of eval: {self.num_of_eval};\n"
            f"\tDepth: {self.depth};\n"
            f"\tNum children: {len(self.children)};\n"
            f"\tScore: {self.mean_score}\n"
        )

    def predict_score_info(self):
        """
        Return a dict with the node's description, predicted score, actual
        mean score, index and idea.
        """
        return {
            'description': self.description,
            'predicted_score': self.predicted_score,
            'score': self.mean_score,
            "node_index": self.index,
            "idea": self.idea
        }

    def update_score(self, score: float):
        """
        Fold a new score into the node's running (cumulative) mean score.

        If the node has no score yet, the mean is set to ``score``.

        :param score: newly obtained score (float)
        :return: None
        """
        self.num_of_eval += 1
        if self.mean_score is None:
            self.mean_score = score
        else:
            self.mean_score = (self.mean_score * (self.num_of_eval - 1) + score) / self.num_of_eval


class InsightTree:
    """
    Tree of ideas (each with its code) kept as a dict of index -> Node.

    Indices are assigned sequentially by :meth:`add_idea`; root ideas are at
    depth 1 and the tree may not grow deeper than ``MAX_DEPTH``.
    """

    # Maximum number of levels in the tree (root ideas are at depth 1).
    MAX_DEPTH = 3

    def __init__(self, retrieve_model):
        """
        :param retrieve_model: model passed to ``encode_text`` to embed idea text
        """
        # Nodes are stored as a dictionary,
        # where keys are node indexes and values are references to the Node class.
        self.nodes = {}
        self.retrieve_model = retrieve_model
        # Index that will be assigned to the next successfully added node.
        self.current_index = 0

    def add_idea(self, idea: str, code: str, parent_index=None, embedding_idea=None):
        """
        Add a new idea to the tree.

        The index is consumed only once the node is actually stored, so a
        rejected call does not leave a gap in the index sequence.

        :param idea: textual description of the idea
        :param code: idea code
        :param parent_index: parent idea index, or None to add a root idea
        :param embedding_idea: embedding of ``idea``; computed with
            ``encode_text`` when omitted
        :return: Added idea index
        :raises ValueError: if the parent does not exist or the new node would
            exceed ``MAX_DEPTH``
        """
        parent_node = None
        if parent_index is not None:
            if parent_index not in self.nodes:
                raise ValueError("Parent index does not exist")

            parent_node = self.nodes[parent_index]

            # A child sits one level below its parent, so a parent already on
            # the last allowed level cannot accept children.
            if parent_node.depth >= self.MAX_DEPTH:
                raise ValueError(f"Maximum tree depth reached ({self.MAX_DEPTH} levels)")

        if embedding_idea is None:
            embedding_idea = encode_text(idea, self.retrieve_model)

        new_index = self.current_index
        new_node = Node(new_index, idea, code, parent_node, embedding_idea)
        if new_node.depth > self.MAX_DEPTH:
            raise ValueError(f"Cannot exceed maximum tree depth of {self.MAX_DEPTH}")

        self.nodes[new_index] = new_node
        self.current_index += 1

        if parent_node is not None:
            parent_node.children.append(new_node)

        return new_index

    def get_branch(self, target_node_index: int) -> list:
        """
        Get the idea and code of every node on the path from the root down to
        ``target_node_index``.

        :param target_node_index: node index at which the branch ends
        :return: list of ``{'idea': ..., 'code': ...}`` dicts, root first and
            the target node last
        :raises ValueError: if ``target_node_index`` is not in the tree
        """
        if target_node_index not in self.nodes:
            raise ValueError("Target index does not exist")

        branch = []
        current_node = self.nodes[target_node_index]

        # Walk up towards the root, then reverse so the root comes first.
        while current_node is not None:
            branch.append({
                'idea': current_node.idea,
                'code': current_node.code
            })
            current_node = current_node.parent

        return list(reversed(branch))

    def get_all_code_in_branch(self, target_node_index: int) -> str:
        """
        Get all the code on the path from the root to ``target_node_index`` as
        one string.

        :param target_node_index: node index at which the branch ends
        :return: code fragments joined root first, separated by
            ``'\\nprint("""""")\\n'``
        :raises KeyError: if ``target_node_index`` is not in the tree
        """
        node = self.nodes[target_node_index]

        code = []
        while node is not None:
            code.append(node.code)
            node = node.parent

        # The list always contains at least the target node's code.
        return '\nprint("""""")\n'.join(reversed(code))

    def select_the_best_children(self, parent_index, is_higher_better):
        """
        Return the index of the best-scoring child of a node.

        Children without a score yet (``mean_score is None``) are ignored.
        When several children share the best score, the earliest one wins.

        :param parent_index: index of the node whose children are compared
        :param is_higher_better: True if a higher mean score is better,
            False if a lower one is better
        :return: index of the best child
        :raises ValueError: if the node has no children or none is scored
        """
        children = self.nodes[parent_index].children
        if not children:
            raise ValueError(f"Node {parent_index} has no children")

        scored = [child for child in children if child.mean_score is not None]
        if not scored:
            raise ValueError(f"No child of node {parent_index} has a score yet")

        # max/min return the first of several equal extremes, so ties go to
        # the earliest child.
        pick = max if is_higher_better else min
        return pick(scored, key=lambda child: child.mean_score).index

    def set_code_for_node(self, code: str, node_index: int):
        """
        Sets the node code as the transmitted code

        :param code: the code that must be set for this node
        :param node_index: node index for which the code must be set
        :return: None
        :raises TypeError: if ``code`` is not a string
        """
        if not isinstance(code, str):
            raise TypeError('Code is not string')
        self.nodes[node_index].code = code

    def set_group_for_node(self, group: str, node_index: int):
        """
        Set the ``group`` attribute of the node with index ``node_index``.
        """
        self.nodes[node_index].group = group

    def remove_node(self, node_index: int):
        """
        Removes a node from the tree by index and detaches it from its parent.

        The removed node's own children are not touched: they stay in
        ``self.nodes`` and keep their reference to the removed node.
        Missing indices and detach failures are reported with ``print`` rather
        than raised.

        :param node_index: index of the node to be deleted
        :return: None
        """
        try:
            node_to_remove = self.nodes.pop(node_index)
        except KeyError:
            print(f"Warning: Node with index {node_index} not found. Nothing to remove.")
            return

        parent = node_to_remove.parent
        # Only detach from the parent if the parent itself is still in the tree.
        if parent is not None and parent.index in self.nodes:
            try:
                parent.children.remove(node_to_remove)
            except ValueError as e:
                print(f"Error during child removal: {e}")

    def backprop(self, node_index: int, score: float, debug_logger,
                 need_update_score_for_this_node: bool = True):
        """
        Update the average value of the score for all nodes
        that are in the same branch as the transmitted one.

        Every ancestor is updated and then logged with ``debug_logger.info``;
        the starting node itself is updated only if requested and is not logged.

        :param node_index: the index of the node from which the back propagation should be performed
        :param score: the resulting node score
        :param debug_logger: logger for debugging
        :param need_update_score_for_this_node: Is it necessary to update the score for this node? (bool)
        :return: None
        """
        node = self.nodes[node_index]

        if need_update_score_for_this_node:
            node.update_score(score)
        while node.parent is not None:
            node = node.parent
            node.update_score(score)
            debug_logger.info(node)

    def parse_final_nodes_to_json(self, path):
        """
        Write ``predict_score_info()`` of every depth-2 node to
        ``<path>/nodes_scores.json`` (UTF-8 encoded).

        :param path: directory in which the JSON file is created
        :return: None
        """
        selected_nodes = []
        for s in self.nodes.values():
            # Falsy entries (e.g. '' placeholders) are skipped.
            if s and s.depth == 2:
                selected_nodes.append(s.predict_score_info())

        # ensure_ascii=False writes raw non-ASCII characters, so the file
        # encoding must be explicit rather than locale-dependent.
        with open(os.path.join(str(path), "nodes_scores.json"), 'w', encoding='utf-8') as f:
            json.dump(selected_nodes, f, indent=4, ensure_ascii=False)

    def get_nodes_info(self, nodes=None):
        """
        Build a JSON-serializable summary of nodes.

        :param nodes: iterable of ``(index, node)`` pairs; defaults to
            ``self.nodes.items()``. Entries whose node equals ``''`` are skipped.
        :return: list of dicts with index, idea, parent/children indexes,
            depth, mean score, number of evaluations and description
        """
        if nodes is None:
            nodes = self.nodes.items()

        nodes_list = []
        for _, node_obj in nodes:
            if node_obj == '':
                continue

            node_info = {
                "index": node_obj.index,
                "idea": node_obj.idea,
                "parent_index": node_obj.parent.index if node_obj.parent else None,
                "children_indexes": [child.index for child in node_obj.children],
                "depth": node_obj.depth,
                "mean_score": node_obj.mean_score,
                "num_of_eval": node_obj.num_of_eval,
                "description": node_obj.description
            }
            nodes_list.append(node_info)
        return nodes_list

    def to_json(self, file_path: str):
        """
        Saves the InsightTree structure to a UTF-8 JSON file.

        The file contains ``"nodes"`` (the output of :meth:`get_nodes_info`:
        each node's index, idea, parent/children indexes, depth, mean score,
        number of evaluations and description; node code is not included)
        and ``"current_index"``.

        :param file_path: The file path where the JSON file will be saved.
        :return: None
        """
        tree_data = {}

        nodes_list = self.get_nodes_info()
        tree_data["nodes"] = nodes_list
        tree_data["current_index"] = self.current_index

        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(tree_data, f, indent=4, ensure_ascii=False)
