from typing import ParamSpec
from typing import Callable
from typing import Concatenate
from typing import Any
from typing import Union
from dataclasses import dataclass
from pylic.code_transformations import Tape
from pylic.code_transformations import get_tape
from pylic.predicates import Predicate
from pylic.predicates import Constants
from pylic.predicates import SolverFailedException
from pylic.predicates import CustomInterpreter
from pylic.tape import IfNode
import torch

# ParamSpec standing for the extra runtime arguments of the function under
# analysis, i.e. everything it accepts after its leading parameter tensor
# (see how it is combined with Concatenate in ParametricFunction).
RuntimeData = ParamSpec("RuntimeData")
# Alias: the parameters handled by the planner are plain torch tensors.
Parameters = torch.Tensor


@dataclass(frozen=True, eq=True)
class SearchNode:
    """A single node of the search tree built by the planner.

    The dataclass is frozen, so the three attributes cannot be rebound after
    construction; the ``children`` list itself stays mutable and is grown in
    place as the tree is expanded.

    NOTE: ``eq=True`` together with ``frozen=True`` makes ``dataclass``
    generate a field-based ``__hash__``; because ``children`` is a list,
    calling ``hash()`` on a node raises ``TypeError``.
    """

    # Predicate attached to this node.
    predicate: Predicate
    # Parameter tensor stored alongside the predicate.
    parameters: Parameters
    # Child nodes hanging below this node in the tree.
    children: list["SearchNode"]


# A callable whose first positional argument is the parameter tensor, followed
# by whatever extra arguments RuntimeData captures; its return type is left
# unconstrained.
ParametricFunction = Callable[Concatenate[Parameters, RuntimeData], Any]
# Either a tape IfNode or a Predicate.
FilteredNode = Union[IfNode, Predicate]


def concolic_planner(
        f: ParametricFunction,
        is_sat: Callable[[Parameters], bool],
        select_node: Callable[[SearchNode, list[SearchNode]], SearchNode],
        get_next_predicates: Callable[[Tape, Parameters], list[Predicate]],
        solver: Callable[
            [Predicate, Parameters, ParametricFunction],
            Parameters
            ],
        starting_parameters: Parameters,
        max_value: torch.Tensor,
        custom_interpreter: CustomInterpreter,
        get_child_parameters: Callable[[Tape, Parameters], Parameters],
        *args: RuntimeData.args,
        **kwargs: RuntimeData.kwargs,
        ) -> Parameters:
    """Search for parameters of ``f`` that satisfy ``is_sat``.

    A search tree rooted at ``Constants.Top`` with ``starting_parameters`` is
    grown iteratively.  On every iteration a node is chosen by
    ``select_node``, recorded as explored, and handed to ``solver``.  If the
    solver succeeds and the resulting parameters satisfy ``is_sat`` they are
    returned; otherwise ``f`` is traced with those parameters and one child
    node is added to the chosen node for every predicate produced by
    ``get_next_predicates``.

    Args:
        f: Function being analysed; traced through ``get_tape``.
        is_sat: Returns True when the given parameters are an acceptable
            final answer.
        select_node: Called with the tree root and the list of nodes
            selected so far; returns the node to process next.
        get_next_predicates: Called with the tape and the solved parameters;
            its result is iterated and each element becomes the predicate of
            a new child node.
        solver: Called with the node's predicate, the node's parameters and
            ``f``; returns new parameters or raises
            ``SolverFailedException``.
        starting_parameters: Parameters stored in the root node.
        max_value: Not used by this function body.
        custom_interpreter: Not used by this function body.
        get_child_parameters: Called with the tape and the solved
            parameters; its result is stored in every child created during
            that iteration.
        *args: Extra positional arguments forwarded to ``get_tape``.
        **kwargs: Extra keyword arguments forwarded to ``get_tape``.

    Returns:
        The first parameters produced by ``solver`` for which ``is_sat``
        returns True.

    Note:
        The loop has no other exit condition: it only ends when ``is_sat``
        succeeds or when a callback raises something other than
        ``SolverFailedException``.
    """
    # Root of the search tree: the trivially-true predicate paired with the
    # caller-supplied starting point.
    search_tree = SearchNode(
            predicate=Constants.Top,
            parameters=starting_parameters,
            children=[],
        )
    explored_nodes: list[SearchNode] = []

    while True:
        # NOTE(review): what select_node returns once every node has been
        # explored is not visible here -- confirm it cannot hand back None.
        n = select_node(search_tree, explored_nodes)
        print(n)

        # Record the node before solving so that a failed solve still counts
        # as explored.
        explored_nodes.append(n)

        try:
            parameters = solver(n.predicate, n.parameters, f)
        except SolverFailedException:
            # Nothing to expand from this node; pick another one.
            continue
        print("Solved a node")

        # Stop as soon as the specification is met.
        if is_sat(parameters):
            return parameters

        # Trace f with the freshly solved parameters.
        # NOTE(review): the meaning of the second argument (None) depends on
        # get_tape's signature, which is defined outside this file.
        tape = get_tape(f, None, parameters, *args, **kwargs)

        # Predicates for the children of the current node.
        predicates = get_next_predicates(tape, parameters)

        # Parameters shared by every child created in this iteration.
        new_parameters = get_child_parameters(tape, parameters)

        n.children.extend(
            SearchNode(
                predicate=predicate,
                parameters=new_parameters,
                children=[],
            )
            for predicate in predicates
        )
