"""Often, control problems result in non-convex optimization problems that are
very difficult to solve in practice. Pylic naturally lends itself to tackling
problems that can be decomposed into discrete parts, defining a predicate for
each part and solving them incrementally. This is analogous to many traditional
planning approaches. In simple problems the problem decomposition is known in
advance, and so planning becomes as easy as solving a fixed sequence of
predicates.

For more complicated problems a particular decomposition may not be known in
advance, and this is where a planning approach comes in. In
this case, we can explore multiple sequences of predicates until we discover a
decomposition that solves the problem. In particular, a natural way to encode
strategies that decompose the problem is to frame the search for a
decomposition as a tree search. For this, you only need to define how to pick a
node from the tree, and how to expand the node if it can be solved. To follow
this approach, you might want to look at `concolic_planner`.
"""
import traceback
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Callable
from typing import Concatenate
from typing import Generic, ParamSpec, TypeVar
from typing import Union

import torch

from pylic.code_transformations import Tape
from pylic.code_transformations import get_tape
from pylic.predicates import Constants
from pylic.predicates import Predicate
from pylic.predicates import SolverFailedException
from pylic.tape import IfNode

# Parameter specification for the extra positional/keyword arguments of the
# traced program `f`; `concolic_planner` accepts them as *args/**kwargs and
# passes them through to `get_tape`.
RuntimeData = ParamSpec("RuntimeData")
# Type of the program input, i.e. the parameters the planner searches over.
Input = TypeVar("Input")


@dataclass(eq=True, frozen=True)
class SearchNode(Generic[Input]):
    """Represents a generic search node in a planning tree.

    The dataclass is frozen, so its fields cannot be rebound. However,
    ``children`` is a mutable list that the planner extends in place when
    the node is expanded.
    """
    # Sub-problem this node hands to the solver.
    predicate: Predicate
    # Parameters passed to the solver together with ``predicate``.
    parameters: Input
    # Child nodes, appended in place on expansion. Defaults to a fresh empty
    # list per node. Excluded from ``__hash__`` because lists are unhashable:
    # with ``eq=True`` and ``frozen=True`` the generated hash would otherwise
    # raise TypeError. It still takes part in ``__eq__``, so equal nodes
    # keep equal hashes.
    children: list["SearchNode"] = field(default_factory=list, hash=False)


# Type alias for a value that is either a tape `IfNode` or a `Predicate`.
# NOTE(review): not referenced in this part of the file; presumably used by
# code that filters tape nodes into predicates -- confirm with callers.
FilteredNode = Union[IfNode, Predicate]


class _NoNodeToExplore(ValueError):
    """Signal that a subtree holds no unexplored leaf within the depth bound.

    Subclassing ``ValueError`` keeps the public contract (callers catch
    ``ValueError``). It also lets the recursion below catch only this
    signal, rather than every ``ValueError``. Otherwise an unrelated error,
    e.g. from an ambiguous array truth value in a parameter comparison,
    would silently be treated as "subtree exhausted".
    """


def _parameters_equal(left, right) -> bool:
    """Return whether two node parameter values are equal.

    Tensors are compared with ``torch.Tensor.equal`` (same size and same
    elements) because ``==`` on tensors is element-wise. Everything else
    uses the truth value of ``==``.
    """
    if isinstance(left, torch.Tensor):
        return left.equal(right)
    return bool(left == right)


def _is_explored(node, explored_nodes) -> bool:
    """Return whether some explored node has the same predicate and parameters."""
    return any(
        node.predicate == explored.predicate
        and _parameters_equal(node.parameters, explored.parameters)
        for explored in explored_nodes
    )


def select_node_depth_first(
        search_tree: "SearchNode",
        explored_nodes: list["SearchNode"],
        depth_bound: int,  # longest branch length
        ) -> "SearchNode":
    """Select an unexplored leaf of the tree through depth-first search.

    This function is provided as an example showing how to design a
    tree-search with pylic. Children are visited in list order, and the
    first unexplored leaf found is returned. A leaf counts as explored when
    some node in ``explored_nodes`` has an equal predicate and equal
    parameters. The parameters must therefore support ``==``; tensors are
    compared with ``torch.Tensor.equal``. Nodes with children are never
    returned themselves.

    ``concolic_planner`` calls ``select_node`` with two arguments, so bind
    ``depth_bound`` first, e.g.
    ``functools.partial(select_node_depth_first, depth_bound=5)``.

    :param search_tree: Root of the (sub)tree to search.
    :param explored_nodes: Nodes that have already been handed to the solver.
    :param depth_bound: Maximum number of edges between ``search_tree`` and
        the returned leaf; ``0`` only allows ``search_tree`` itself.
    :returns: The first unexplored leaf in depth-first order.
    :raises ValueError: If no unexplored leaf exists within ``depth_bound``
        (always the case when ``depth_bound`` is negative).
    """
    # `< 0` (rather than `== -1`) also bounds searches started with a
    # negative depth_bound, which previously recursed without limit.
    if depth_bound < 0:
        raise _NoNodeToExplore("Cannot find valid node to explore")

    if not search_tree.children:
        # A leaf is selectable if and only if it has not been explored.
        if _is_explored(search_tree, explored_nodes):
            raise _NoNodeToExplore("Cannot find valid node to explore")
        return search_tree

    # Return the first unexplored leaf found below any child, in order. Only
    # the "exhausted" signal is swallowed; genuine errors propagate.
    for child in search_tree.children:
        try:
            return select_node_depth_first(
                search_tree=child,
                explored_nodes=explored_nodes,
                depth_bound=depth_bound - 1,
            )
        except _NoNodeToExplore:
            continue
    raise _NoNodeToExplore("Cannot find valid node to explore")


def concolic_planner(
        f: Callable[Concatenate[Input, RuntimeData], Any],
        is_sat: Callable[[Input], bool],
        select_node: Callable[[SearchNode, list[SearchNode]], SearchNode],
        get_next_predicates: Callable[
            [Tape, Input, Predicate],
            list[Predicate]
        ],
        solver: Callable[
            [Predicate, Input],
            Input
            ],
        starting_parameters: Input,
        get_child_parameters: Callable[
            [Tape, Input],
            Input
        ],
        verbose: bool,
        *args: RuntimeData.args,
        **kwargs: RuntimeData.kwargs,
        ) -> Input:
    """Generic template to search in a space of sequences of sub-problems that
    solve a satisfaction problem over the execution of a program.

    The search space for these problems is assumed to consist of input
    parameters to the program. That is, we want to find input parameters to
    a program that satisfy a constraint, and we have a way to guide the
    search by framing the problem as a tree search over easier constraint
    problems. Each constraint problem is defined by a predicate.

    Each iteration does the following:

    1. Select a node.
    2. Solve its predicate.
    3. Return the solution if it satisfies ``is_sat``.
    4. Otherwise, trace ``f`` on the solution and expand the node with one
       child per predicate returned by ``get_next_predicates``.

    :param f: Program for which we want to find input parameters.
    :param is_sat: Constraint that we want to satisfy.
    :param select_node: Function to pick the next node in the search tree to
        explore, given the list of previously explored nodes. That list
        includes nodes whose solve failed; those stay in the tree as leaves
        without children. It is called with exactly two arguments, so bind
        any extra ones beforehand (e.g. the ``depth_bound`` of
        ``select_node_depth_first``).
    :param get_next_predicates: Function that returns the predicates of a
        node's children. It receives the execution trace induced by the
        node's solved parameters, those parameters, and the node's predicate.
    :param solver: Solver which returns parameters that satisfy the given
        predicate. It is called with the node's predicate and parameters and
        signals failure by raising ``SolverFailedException``.
    :param starting_parameters: Input which defines the start of the search.
    :param get_child_parameters: Function that returns the starting
        parameters shared by all children of a node. It receives the
        execution trace induced by the node's solved parameters and those
        parameters.
    :param verbose: If ``True``, output status updates to stdout. Solver
        failures are printed regardless of this flag.
    :param args: Extra positional arguments for ``f``, passed through
        ``get_tape``.
    :param kwargs: Extra keyword arguments for ``f``, passed through
        ``get_tape``.
    :returns: The first solver output for which ``is_sat`` returns ``True``.
    :raises: Any exception raised by ``select_node``. This is the only way
        the search ends without a solution; for example,
        ``select_node_depth_first`` raises ``ValueError`` once no unexplored
        leaf remains. Also any exception other than
        ``SolverFailedException`` raised by ``solver``, ``is_sat``, tracing,
        or the expansion callbacks.
    """
    # Root of the search tree: `Constants.Top` paired with the user's
    # starting point. NOTE(review): presumably Top is the always-true
    # predicate -- confirm in pylic.predicates.
    search_tree = SearchNode(
        predicate=Constants.Top,
        parameters=starting_parameters,
        children=list(),
    )
    # Every node ever handed to the solver, whether or not solving succeeded.
    explored_nodes: list[SearchNode] = []

    # The loop only exits by returning a satisfying input or by an exception,
    # so termination relies on `select_node` raising once nothing is left.
    while True:
        n = select_node(search_tree, explored_nodes)
        if verbose:
            print(f"Search node predicate: {n.predicate}")

        # Record the node before solving so it is marked explored even if
        # the solve fails.
        explored_nodes.append(n)

        try:
            parameters = solver(n.predicate, n.parameters)
        except SolverFailedException:
            # A failed node stays a childless leaf; move on to another one.
            print(traceback.format_exc())
            print("Failed to solve node!")
            continue

        if verbose:
            print("Solved node")

        # Stop as soon as the solution satisfies the specification.
        if is_sat(parameters):
            return parameters

        # Trace the program on the new parameters.
        # NOTE(review): the meaning of the `None` argument is defined by
        # pylic.code_transformations.get_tape -- confirm.
        tape = get_tape(f, None, parameters, *args, **kwargs)

        # Derive the children's predicates and their shared starting point.
        predicates = get_next_predicates(tape, parameters, n.predicate)
        new_parameters = get_child_parameters(tape, parameters)

        # Expand the node in place (the frozen dataclass still holds a
        # mutable `children` list). All children share one `new_parameters`
        # object. NOTE(review): if `solver` mutates its input in place,
        # sibling nodes would observe each other's updates -- confirm it
        # does not.
        n.children.extend(
            SearchNode(
                predicate=predicate,
                parameters=new_parameters,
                children=list(),
            )
            for predicate in predicates
        )
