"""
A thought structure to support the thought rollback.
"""

import logging
from typing import List


from llmpebase.model.thought_structure import trees
from llmpebase.model.thought_structure.structure_generic import BasicNode


class ThoughtRollbackStructure(trees.DFGTreeThoughtStructure):
    """
    A thought structure to perform the adaptive reasoning process by
    continuously rolling back thoughts.
    To facilitate the reasoning, the basic structure is built upon a tree
    with a depth-wise growth manner.

    A rollback is recorded as an edge of type "Rollback" that points from a
    node back to an earlier node of its reasoning chain. Node identities are
    compared as integers: an edge whose source id is smaller than its
    destination id is treated as a forward (reasoning) edge, any other edge
    as a rollback edge.
    """

    def __init__(
        self,
        thought_model,
        model_config,
        logging_config,
        visualizer,
    ):
        super().__init__(thought_model, model_config, logging_config, visualizer)

        # Get the configuration
        config = model_config["thought_structure"]
        # Get the maximum number of rollbacks generated by this node.
        # This is to avoid that one node creates too many
        # rollbacks to other nodes.
        # The rollback here means rollback from the node 'to' other nodes
        self.num_max_rollbacks_to = config["num_max_rollbacks_to"]

        # Get the maximum number of rollbacks received by this node.
        # This is to avoid that one node receives too many
        # rollbacks from other nodes.
        # The rollback here means rollback from other nodes 'to' this node
        self.num_max_rollbacks_incoming = config["num_max_rollbacks_incoming"]

        # Get the maximum number of rollbacks that a node can receive from
        # the nodes of one chain. This is to avoid that one node in a chain
        # gets many rollbacks from other nodes of the same chain.
        # For example, 1 -> 2 -> 3 -> 4 -> 5... If 3 has an error, 4, 5 ... may
        # continuously roll back to 3, giving 3 many rollbacks.
        self.num_max_rollbacks_from_chain = config["num_max_rollbacks_from_chain"]

        # Get the maximum number of rollbacks allowed to exist in one chain.
        # This is to avoid that one chain has too many rollbacks.
        # For instance, 1 -> 2 -> 3 -> 4 -> 5.., If 3 is absolutely wrong, every
        # subsequent node may generate a rollback to 2 for revision. This
        # introduces unnecessary complexity to the reasoning.
        self.num_max_chain_rollbacks = config["num_max_chain_rollbacks"]

        # Once the number of sink nodes (solutions) reaches this value,
        # get_grow_node stops the reasoning by returning None
        self.num_max_solutions = config["num_max_solutions"]

        # Tracks which node is being rolled back to, thus creating a new
        # reasoning path from it. When set, it is a dict with keys "node"
        # (the node rolled back to) and "rollback_edge" (the rollback edge id).
        self.rolling_back_state = None

        # Whether to also roll back the experience, i.e., the error analysis,
        # to the rolled-back node to facilitate the subsequent reasoning.
        self.do_experience_rollback = config["do_experience_rollback"]

    def set_node_growth(self, node_id: str):
        """Update the growth status of one node.

        The node becomes un-growable once it has at least `num_next_steps`
        forward children. A sink node is then re-evaluated: it stays growable
        only while it has generated fewer than `num_max_rollbacks_to`
        rollbacks. Intermediate nodes are growable by default and are
        handled in `get_grow_node`.

        The position of the node must be determined before calling this.
        """
        # Forward edges lead from the node to the next reasoning steps
        forward_edges = [
            edge
            for edge in self.graph.edges(node_id)
            if int(edge[0]) < int(edge[1])
        ]
        # Rollbacks generated by the node, i.e., from the node 'to' other nodes
        rollback_edges = self.get_outgoing_rollbacks(node_id)

        # Close growth of the node as this node has enough children
        if len(forward_edges) >= self.num_next_steps:
            self.node_pool[node_id].set_growth("Un-growable")

        # A sink node keeps its growth open only while its #rollbacks has
        # not reached the limit; this overrides the decision above
        if self.is_node_sink(node_id):
            if len(rollback_edges) < self.num_max_rollbacks_to:
                self.node_pool[node_id].set_growth("Growable")
            else:
                self.node_pool[node_id].set_growth("Un-growable")

    def get_grow_node(self):
        """Get the node to be grown next relying on the reasoning condition.

        The candidate node comes from the depth-wise search of the parent
        structure. The thought model is asked for the bad steps of the chain
        from the root to that node; when an eligible bad step exists, a
        rollback edge is created from the node to the node preceding the bad
        step, and that earlier node is returned so that a new reasoning path
        is grown from it.

        Returns:
            The node to grow next, or None when the parent search yields no
            node or the number of sink nodes reaches `num_max_solutions`.
        """
        # Get the node based on depth-wise search
        node = super().get_grow_node()
        # Clean the rollback state; it is set again only when a rollback
        # is performed below
        self.rolling_back_state = None

        if node is None:
            return node

        # When the number of sink nodes reaches the limit, no further
        # reasoning will be performed
        sink_nodes = self.get_sink_nodes()
        logging.info("Number of solutions: %d", len(sink_nodes))
        if len(sink_nodes) >= self.num_max_solutions:
            return None

        # Get the reasoning path from the root to the node
        nodes = self.get_node_path(
            src_node_id=self.root.identity, dst_node_id=node.identity
        )
        # When there is no thought but only the root in the path,
        # no rollback is needed
        if len(nodes) <= 1:
            return node

        rollback_step_idx, rollback_result, analysis = self._find_rollback_step(
            node, nodes
        )

        # Once within the generated bad steps there is one step that
        # the node can roll back to, perform the rollback
        if rollback_step_idx is not None:
            return self._perform_rollback(
                node, nodes, rollback_step_idx, rollback_result, analysis
            )

        # Special case of the sink node: when it does not roll back, close
        # its growth so that no further reasoning is performed on it, and
        # search for another node to grow
        if self.is_node_sink(node.identity):
            self.node_pool[node.identity].set_growth("Un-growable")
            return self.get_grow_node()

        return node

    def _find_rollback_step(self, node, nodes):
        """Ask the thought model for bad steps and pick the first eligible one.

        :param node: The node that may generate a rollback.
        :param nodes: The reasoning path from the root (`nodes[0]`) to `node`.
        :return: A tuple `(rollback_step_idx, rollback_result, analysis)`,
         where `rollback_step_idx` is None when no rollback should be
         performed, and the latter two are the LLM outputs of
         `generate_rollback` (None if it was not called).
        """
        rollback_step_idx = None
        # The responses of the LLM
        rollback_result = None
        # The analysis of the bad steps
        analysis = None

        # Rollbacks generated by the node, i.e., from the node to other nodes
        outgoing_rollbacks = self.get_outgoing_rollbacks(node.identity)
        # Rollbacks generated by all nodes of the chain root -> node
        num_chain_rollbacks = self.get_chain_rollbacks(
            src_id=self.root.identity, dst_id=node.identity
        )

        # Only query the rollback condition when the node is below its
        # rollback limit and the chain is within its limit.
        # NOTE(review): the chain, from-chain and incoming limits are checked
        # with `<=`, i.e. one more rollback is still allowed when the count
        # already equals the configured maximum; confirm this is intended.
        if not (
            len(outgoing_rollbacks) < self.num_max_rollbacks_to
            and num_chain_rollbacks <= self.num_max_chain_rollbacks
        ):
            return rollback_step_idx, rollback_result, analysis

        logging.info("Try Rollback for Node %s", node.identity)
        (
            bad_step_idxes,
            rollback_result,
            analysis,
        ) = self.thought_model.generate_rollback(thought_chain=nodes)

        bad_step_idxes = [] if bad_step_idxes is None else bad_step_idxes
        logging.info("bad_step_idxes: %s", bad_step_idxes)
        logging.info("num_chain_rollbacks: %d", num_chain_rollbacks)

        for step_idx in bad_step_idxes:
            # nodes[0] is the root, so a bad step must refer to one of
            # nodes[1:] (step indexes are assumed 1-based, matching
            # `nodes[step_idx - 1]` below). Other indexes would wrap around
            # via negative indexing or raise IndexError, so skip them.
            if not 1 <= step_idx < len(nodes):
                logging.info("Skip invalid bad step index: %s", step_idx)
                continue

            # Rollback to one previous node of the bad node
            rollback_to = nodes[step_idx - 1]
            # Rollbacks that target rollback_to and are generated by nodes
            # of the chain rollback_to -> node
            node_chain_rollbacks = self.get_node_rollbacks_from_chain(
                node.identity, rollback_to_id=rollback_to.identity
            )
            # Rollbacks received by rollback_to from any node
            incoming_rollbacks = self.get_incoming_rollbacks(rollback_to.identity)
            logging.info(
                "Node %s receives %d rollbacks from the chain",
                rollback_to.identity,
                len(node_chain_rollbacks),
            )
            logging.info(
                "Node %s receives %d rollbacks",
                rollback_to.identity,
                len(incoming_rollbacks),
            )
            if (
                len(node_chain_rollbacks) <= self.num_max_rollbacks_from_chain
                and len(incoming_rollbacks) <= self.num_max_rollbacks_incoming
            ):
                rollback_step_idx = step_idx
                logging.info("Rollback to idx: %s", rollback_step_idx)
                break

        return rollback_step_idx, rollback_result, analysis

    def _perform_rollback(
        self, node, nodes, rollback_step_idx, rollback_result, analysis
    ):
        """Create a rollback edge from `node` to the node before the bad step.

        Opens the growth of the rolled-back node, registers the rollback edge
        in both the edge pool and the graph, records `rolling_back_state`,
        and optionally stores the rollback experience on the target node.

        :return: The node that is rolled back to.
        """
        # Rollback to one previous node of the bad node
        rollback_node = nodes[rollback_step_idx - 1]
        logging.info(
            "Roll back from node %s (Step %s) to node %s (Step %s)",
            node.identity,
            node.step_idx,
            rollback_node.identity,
            rollback_node.step_idx,
        )

        # Set the rollback node to be growable thereby allowing to
        # generate a new reasoning path from it
        self.node_pool[rollback_node.identity].set_growth("Growable")

        # Create a unique id for this rollback edge; the count is taken
        # before the new edge is added to the graph
        edge_id = self.generate_edge_id(node.identity, rollback_node.identity)
        n_rollbacks = len(self.get_outgoing_rollbacks(node.identity))
        edge_id = f"{edge_id}_R{n_rollbacks+1}"

        new_edge = self.create_edge(
            edge_id=edge_id,
            edge_type="Rollback",
            src_node_id=node.identity,
            dst_node_id=rollback_node.identity,
            reasoning=self.thought_model.prompter.rollback_controller_prompt,
            evaluation=self.thought_model.prompter.rollback_analysis_prompt,
            edge_score=1.0,
            auxiliary={
                "AnalysisSteps": self.thought_model.prompter.organize_chain_prompt(
                    nodes[1:],
                    with_step_idx=True,
                    with_flag=False,
                    with_evaluation_score=False,
                ),
                "RollbackAnalysis": analysis,
                "RollbackResult": rollback_result,
                "RollbackCondition": f"Error in Step {rollback_step_idx}",
            },
        )
        # Add the edge to the pool
        self.edge_pool[edge_id] = new_edge
        # Add the edge between the node and the rollback node
        self.graph.add_edge(
            node.identity,
            rollback_node.identity,
            edge_type="Rollback",
            edge_id=edge_id,
            weight=2.0,
        )
        self.rolling_back_state = {"node": rollback_node, "rollback_edge": edge_id}

        # Add the experience obtained from this rollback to the node,
        # saved as edge_id: experience
        if self.do_experience_rollback and edge_id not in rollback_node.auxiliary:
            rollback_node.auxiliary[edge_id] = new_edge.auxiliary

        return rollback_node

    def add_node(
        self,
        thought: str,
        prev_node_id: str,
        thought_score: float = 1.0,
        edge_weight: float = 1.0,
        **kwargs,
    ) -> str:
        """Add one node to the tree.

        When the new node is grown from the node currently being rolled back
        to, the reasoning edge is annotated with the originating rollback
        edge ("FromRollback") and, if enabled, the rollback experience.

        :return: The identity of the new node.
        """
        node_id = super().add_node(
            thought=thought,
            prev_node_id=prev_node_id,
            thought_score=thought_score,
            edge_weight=edge_weight,
            **kwargs,
        )

        # Check whether the new node is created from the node
        # that has been rolled back to
        if (
            self.rolling_back_state is not None
            and self.rolling_back_state["node"].identity == prev_node_id
        ):
            rollback_edge_id = self.rolling_back_state["rollback_edge"]
            # Record on the reasoning edge which rollback it comes from
            edge_id = self.generate_edge_id(prev_node_id, node_id)
            edge = self.edge_pool[edge_id]
            edge.auxiliary["FromRollback"] = rollback_edge_id

            if self.do_experience_rollback:
                edge.auxiliary["RollbackExperience"] = self.edge_pool[
                    rollback_edge_id
                ].auxiliary

            self.edge_pool[edge_id] = edge

            # Add the corresponding information to the edge of the graph.
            # graph[prev_node_id][node_id] is expected to be a mapping from
            # edge key to the attribute dict of that edge (multigraph
            # semantics, e.g. networkx MultiDiGraph — confirm); key 0 is the
            # first edge between the two nodes.
            graph_edge = self.graph[prev_node_id][node_id]
            graph_edge[0]["FromRollback"] = rollback_edge_id

        return node_id

    def generate_next_thoughts(self, thought_path: List[BasicNode]):
        """Generate the next thoughts for the last node of `thought_path`.

        The current `rolling_back_state` is forwarded to the thought model.

        :return: The generated thoughts and the prompt used to generate them.
        """
        # Get the edges of the path
        path_edges = self.get_path_edges(thought_path)

        thoughts, gen_prompt = self.thought_model.generate_thoughts(
            thought_chain=thought_path,
            num_thoughts=self.num_next_steps,
            thought_edges=path_edges,
            rollback_state=self.rolling_back_state,
        )
        return thoughts, gen_prompt

    def get_outgoing_rollbacks(self, node_id: str):
        """Get the rollbacks generated by one node.

        These are the node's edges whose destination id is not larger than
        the source id (self-loops included).
        """
        return [
            edge
            for edge in self.graph.edges(node_id)
            if int(edge[0]) >= int(edge[1])
        ]

    def get_incoming_rollbacks(self, node_id: str):
        """Get the rollbacks received by one node, from any source node."""
        target = int(node_id)
        # Scan all edges of the graph for rollback edges ending at the node
        return [
            edge
            for edge in self.graph.edges()
            if int(edge[1]) == target and int(edge[0]) >= int(edge[1])
        ]

    def get_node_rollbacks_from_chain(self, node_id: str, rollback_to_id: str):
        """Get a node's rollbacks received from a chain.

        :param node_id: The last node of the chain.
        :param rollback_to_id: The node receiving the rollbacks, which is
         also the first node of the chain.
        :return: The rollback edges generated by nodes of the chain
         `rollback_to_id -> node_id` whose destination is `rollback_to_id`.
        """
        chain = self.get_node_path(src_node_id=rollback_to_id, dst_node_id=node_id)
        # Node ids are strings while edge endpoints are compared as ints;
        # converting the target keeps the comparison int-to-int
        target = int(rollback_to_id)
        return [
            edge
            for chain_node in chain
            for edge in self.get_outgoing_rollbacks(chain_node.identity)
            if int(edge[1]) == target
        ]

    def get_chain_rollbacks(self, src_id: str, dst_id: str):
        """Get the #rollbacks generated by all nodes on the path src_id -> dst_id."""
        chain = self.get_node_path(src_node_id=src_id, dst_node_id=dst_id)
        return sum(
            len(self.get_outgoing_rollbacks(chain_node.identity))
            for chain_node in chain
        )