"""Predicates are the means by which desirable properties of the
execution of a program are described. For instance, a control problem is
described through predicates.

Because predicates encode properties of the execution of a program, their
truth value is a function of the execution of the program. The execution of
a program is stored as a `pylic.tape.Tape`.

Often, a predicate only needs to account for a subset of an execution trace.
For example, "collide with obstacle 2 after timestep 100"
only concerns the nodes in the trace where `obstacle==2` and `t>100`.
This is expressed by defining a `pylic.predicates.Filter`.
"""
from dataclasses import dataclass
from typing import Callable, Union
from pylic.tape import Tape
from pylic.tape import IfNode
from enum import Enum
from enum import auto
import torch
import torch.nn.functional
from typing import Literal
import dill


CustomFunction = Callable[[Tape], torch.Tensor]
"""Custom functions take a tape and return a `torch.Tensor`."""
CustomFilter = Callable[[Tape, int], bool]
"""Custom filters take a tape, a node index and return `True` if and only if
the node of the given index should be in the output trace of the filter."""


class SolverFailedException(Exception):
    """Raised when a predicate solver fails to find a solution.

    Args:
        message: human-readable description of the failure.
        final_parameters: optional parameters the solver ended with; stored
            unchanged on the exception as `final_parameters`.
    """

    def __init__(self, message: str, final_parameters=None):
        # Forward the message to `Exception` so that `str(exc)` and
        # `exc.args` are populated. Unpickling an exception re-invokes
        # `cls(*exc.args)`, which would fail with an empty `args` tuple since
        # `message` is a required argument (relevant for multiprocessing).
        super().__init__(message)
        self.message = message
        self.final_parameters = final_parameters


class Constants(Enum):
    """Named constants of the predicate language."""
    input_trace = 1
    """Stands for the original execution trace, before any filtering."""
    Top = 2
    """Stands for the predicate that always holds."""


@dataclass(frozen=True)
class Negation:
    """Logical 'not' applied to a single predicate."""
    operand: "Predicate"


@dataclass(frozen=True)
class Conjunction:
    """Logical 'and' over a list of predicates.

    Note: while `operands` is a list, instances are not hashable (the
    generated `__hash__` hashes the list).
    """
    operands: list["Predicate"]


@dataclass(frozen=True)
class Disjunction:
    """Logical 'or' over a list of predicates.

    Note: while `operands` is a list, instances are not hashable (the
    generated `__hash__` hashes the list).
    """
    operands: list["Predicate"]


@dataclass(frozen=True)
class IfOr:
    """Holds exactly when at least one `pylic.tape.IfNode` `node` of `trace`
    satisfies `node.value > 0`."""
    trace: "Trace"

@dataclass
class Filter:
    """Filter trace with a filter. The filter can be specified as a
    string ID, or as a Python callable.

    If the filter is identified with a string, then the implementation
    is provided on the call to `pylic.predicates.predicate_interpreter`.

    Otherwise, it is recommended the callable has no bound variables, as it
    will be serialized using `dill`. This is due to limitations of the Python
    interpreter in pickling lambda expressions, which is necessary for
    multiprocessing.

    Two filters are equal if they are the same object, or if their custom
    filters are equal strings and their input traces are equal. Distinct
    filters holding callables are never equal (in general, equality of
    arbitrary functions is undecidable). Because `__eq__` is overridden
    without a `__hash__`, instances are not hashable.
    """
    custom_filter: str | CustomFilter
    trace: "Trace"

    def __eq__(self, o):
        if self is o:
            # Keep equality reflexive even when the filter is a callable.
            return True
        if not isinstance(o, Filter):
            # Let Python try the reflected comparison / identity fallback.
            return NotImplemented
        if not (isinstance(self.custom_filter, str)
                and isinstance(o.custom_filter, str)):
            return False
        return self.custom_filter == o.custom_filter and self.trace == o.trace

    def __post_init__(self):
        # Serializing lambdas is hell in Python. We use dill
        # which has unexpected behaviour when computing closures
        # (serializing lambdas defined in a loop leads to inconsistent
        # lambdas). So we serialize as soon as they are passed to this
        # class.
        self._filter_is_lambda = not isinstance(self.custom_filter, str)
        if self._filter_is_lambda:
            self._custom_filter = dill.dumps(self.custom_filter)
        else:
            self._custom_filter = self.custom_filter

    def __getstate__(self):
        # String filters are pickled as-is. Callable filters are pickled as
        # the `dill` bytes captured in `__post_init__`, not the live object.
        return dict(
            trace=self.trace,
            custom_filter=self._custom_filter,
            filter_is_lambda=self._filter_is_lambda,
        )

    def __setstate__(self, d):
        self.trace = d["trace"]
        self._custom_filter = d["custom_filter"]
        self._filter_is_lambda = d["filter_is_lambda"]
        if self._filter_is_lambda:
            # Rebuild the callable from the bytes produced by `dill.dumps`.
            self.custom_filter = dill.loads(self._custom_filter)
        else:
            self.custom_filter = self._custom_filter


@dataclass
class FunctionCall:
    """Represents a call to a custom function. The function is identified
    with an ID string, or as a Python callable.

    If the function is identified with a string, then the implementation
    is provided on the call to `pylic.predicates.predicate_interpreter`.

    Otherwise, it is recommended the callable has no bound variables, as it
    will be serialized using `dill`. This is due to limitations of the Python
    interpreter in pickling lambda expressions, which is necessary for
    multiprocessing.

    Two function calls are equal if they are the same object, or if their
    custom functions are equal strings and their input traces are equal.
    Distinct calls holding callables are never equal (in general, equality
    of arbitrary functions is undecidable). Because `__eq__` is overridden
    without a `__hash__`, instances are not hashable.
    """
    custom_function: str | CustomFunction
    trace: "Trace"

    def __eq__(self, o):
        if self is o:
            # Keep equality reflexive even when the function is a callable.
            return True
        if not isinstance(o, FunctionCall):
            # Let Python try the reflected comparison / identity fallback.
            return NotImplemented
        if not (isinstance(self.custom_function, str)
                and isinstance(o.custom_function, str)):
            return False
        return (
            self.custom_function == o.custom_function
            and self.trace == o.trace
        )

    def __post_init__(self):
        # Serializing lambdas is hell in Python. We use dill
        # which has unexpected behaviour when computing closures
        # (serializing lambdas defined in a loop leads to inconsistent
        # lambdas). So we serialize as soon as they are passed to this
        # class.
        self._function_is_lambda = not isinstance(self.custom_function, str)
        if self._function_is_lambda:
            self._custom_function = dill.dumps(self.custom_function)
        else:
            self._custom_function = self.custom_function

    def __getstate__(self):
        # String functions are pickled as-is. Callable functions are pickled
        # as the `dill` bytes captured in `__post_init__`, not the live object.
        return dict(
            trace=self.trace,
            custom_function=self._custom_function,
            function_is_lambda=self._function_is_lambda,
        )

    def __setstate__(self, d):
        self.trace = d["trace"]
        self._custom_function = d["custom_function"]
        self._function_is_lambda = d["function_is_lambda"]
        if self._function_is_lambda:
            # Rebuild the callable from the bytes produced by `dill.dumps`.
            self.custom_function = dill.loads(self._custom_function)
        else:
            self.custom_function = self._custom_function


# Real-valued expressions: a literal tensor, or a call to a custom function
# evaluated on a trace expression (see `real_interpreter`).
Real = Union[
    torch.Tensor,
    FunctionCall,
]
"""Values that represent a real number."""


# Trace expressions: the unfiltered input trace, or a `Filter` applied to
# another trace expression (see `trace_interpreter`).
Trace = Union[
    Filter,
    Literal[Constants.input_trace],
]


@dataclass(frozen=True)
class LessThan:
    """Comparison predicate that holds when the real expression `left` is
    strictly smaller than the real expression `right` (`left < right`)."""
    left: Real
    right: Real


Predicate = Union[
    Conjunction,
    Disjunction,
    IfOr,
    Negation,
    LessThan,
    Literal[Constants.Top],
]
"""Predicate represent desired properties of the execution of a program."""


def less_than(a, b):
    """Robustness of `a < b`: the margin `b - a`, which is positive if and
    only if `a < b`."""
    margin = b - a
    return margin


def greater_than(a, b):
    """Robustness of `a > b`: positive if and only if `a > b`.

    Scores `a > b` as `b < a` by delegating to `less_than` with the
    arguments swapped.
    """
    swapped = less_than(b, a)
    return swapped


def negation(value):
    """Flip the sign of a robustness value (multiply it by -1)."""
    flipped = -value
    return flipped


def conjunction(*values):
    """Combine robustness values with a gradient-aware minimum.

    The result is greater than 0 if and only if every value is greater than
    zero. A single value is returned unchanged.

    When the minimum is not positive, the result is taken instead among the
    non-positive values that are `torch.Tensor`s requiring grad, if there
    are any: the only such value, or the minimum of them. Some parts of a
    formula only become a function of the parameters (and so only require
    grad) once another part of it is satisfied. Consider "collide with box
    AND move box to the right": the second clause depends on the control
    signal only if the first clause holds. Returning a value that requires
    grad keeps the optimization informative.

    Raises:
        ValueError: if no values are given.
    """
    if not values:
        raise ValueError("Cannot evaluate conjunction on zero arguments!")
    if len(values) == 1:
        return values[0]

    lowest = min(values)
    if not lowest <= 0.0:
        # The minimum is not <= 0: return it as-is.
        return lowest

    trainable = [
        v for v in values
        if v <= 0.0 and isinstance(v, torch.Tensor) and v.requires_grad
    ]
    if not trainable:
        return lowest
    if len(trainable) == 1:
        return trainable[0]
    return torch.min(torch.stack(trainable))


def disjunction(*values):
    """Combine robustness values with a gradient-aware maximum.

    The result is greater than 0 if and only if at least one value is
    greater than zero. A single value is returned unchanged.

    When the maximum is not positive, the result is taken instead among the
    non-positive values that are `torch.Tensor`s requiring grad, if there
    are any: the only such value, or the maximum of them. This mirrors
    `conjunction`: it keeps a value that requires grad in the output while
    the disjunction does not hold.

    Raises:
        ValueError: if no values are given.
    """
    if not values:
        raise ValueError("Cannot evaluate disjunction on zero arguments!")
    if len(values) == 1:
        return values[0]

    highest = max(values)
    if not highest <= 0.0:
        # The maximum is not <= 0: return it as-is.
        return highest

    trainable = [
        v for v in values
        if v <= 0.0 and isinstance(v, torch.Tensor) and v.requires_grad
    ]
    if not trainable:
        return highest
    if len(trainable) == 1:
        return trainable[0]
    return torch.max(torch.stack(trainable))


def trace_interpreter(
        program: Trace,
        input_tape: Tape,
        custom_filters: dict[str, CustomFilter],
        ) -> Tape:
    """Returns the output of the trace expression.

    `Constants.input_trace` evaluates to `input_tape` itself. A `Filter`
    first evaluates its inner trace expression, then keeps, in order, the
    nodes of that trace for which the filter returns `True`. The filter is
    called as `custom_filter(inner_tape, index)`.

    Args:
        program: the trace expression to evaluate.
        input_tape: the original, unfiltered execution trace.
        custom_filters: implementations of the filters referenced by string
            ID in `program`.

    Raises:
        KeyError: if a filter ID is missing from `custom_filters`.
        ValueError: if `program` is not a supported trace expression.
    """
    if isinstance(program, Filter):
        if isinstance(program.custom_filter, str):
            custom_filter = custom_filters[program.custom_filter]
        else:
            custom_filter = program.custom_filter
        tape = trace_interpreter(
            program=program.trace,
            input_tape=input_tape,
            custom_filters=custom_filters,
        )
        return Tape([
            node
            for i, node in enumerate(tape)
            if custom_filter(tape, i)
        ])
    if program is Constants.input_trace:
        return input_tape
    # Previously this fell through and returned None, which surfaced later
    # as an unrelated error inside a custom function.
    raise ValueError(f"Unsupported trace expression: {program!r}")


def real_interpreter(
        program: Real,
        input_tape: Tape,
        custom_functions: dict[str, CustomFunction],
        custom_filters: dict[str, CustomFilter],
        ) -> torch.Tensor:
    """Returns the output of the real-valued expression.

    A `torch.Tensor` evaluates to itself. A `FunctionCall` evaluates its
    trace expression with `trace_interpreter`, then applies the custom
    function to the resulting tape.

    Args:
        program: the real-valued expression to evaluate.
        input_tape: the original, unfiltered execution trace.
        custom_functions: implementations of the functions referenced by
            string ID in `program`.
        custom_filters: implementations of the filters referenced by string
            ID in `program`'s trace expression.

    Raises:
        KeyError: if a function or filter ID is missing from its registry.
        ValueError: if `program` is not a supported real-valued expression.
    """
    if isinstance(program, torch.Tensor):
        return program
    if isinstance(program, FunctionCall):
        if isinstance(program.custom_function, str):
            custom_function = custom_functions[program.custom_function]
        else:
            custom_function = program.custom_function
        current_trace = trace_interpreter(
            program=program.trace,
            input_tape=input_tape,
            custom_filters=custom_filters,
        )
        return custom_function(current_trace)
    # Previously this fell through and returned None, which surfaced later
    # as an unrelated arithmetic error.
    raise ValueError(f"Unsupported real-valued expression: {program!r}")


def predicate_interpreter(
        predicate: Predicate,
        input_tape: Tape,
        max_value: torch.Tensor,
        custom_functions: dict[str, CustomFunction],
        custom_filters: dict[str, CustomFilter],
        ) -> torch.Tensor:
    """Returns the robustness measure of the predicate on the given tape,
    according to quantitative semantics. A predicate is said to be 'True' for a
    given trace if its robustness measure is greater than zero.

    Semantics per predicate kind:

    - `Conjunction` / `Disjunction`: `conjunction` / `disjunction` of the
      operands' robustness values.
    - `Negation`: the operand's robustness multiplied by -1.
    - `IfOr`: `disjunction` of the values of the `pylic.tape.IfNode`s in the
      evaluated trace. If that trace contains no `IfNode`, the predicate
      cannot hold and evaluates to `-max_value` (the negation of
      `Constants.Top`).
    - `LessThan`: `right - left`.
    - `Constants.Top`: `max_value`.

    Args:
        predicate: the predicate to evaluate.
        input_tape: the original, unfiltered execution trace.
        max_value: robustness returned for `Constants.Top`.
        custom_functions: implementations of functions referenced by ID.
        custom_filters: implementations of filters referenced by ID.

    Raises:
        KeyError: if a function or filter ID is missing from its registry.
        ValueError: if `predicate` is not a supported predicate, or if a
            `Conjunction`/`Disjunction` has no operands.
    """
    def recurse(operand):
        # Sub-predicates share the tape, `max_value` and the registries.
        return predicate_interpreter(
            predicate=operand,
            input_tape=input_tape,
            max_value=max_value,
            custom_functions=custom_functions,
            custom_filters=custom_filters,
        )

    def evaluate_real(program):
        return real_interpreter(
            program=program,
            input_tape=input_tape,
            custom_functions=custom_functions,
            custom_filters=custom_filters,
        )

    if isinstance(predicate, Conjunction):
        return conjunction(*[recurse(op) for op in predicate.operands])
    if isinstance(predicate, Disjunction):
        return disjunction(*[recurse(op) for op in predicate.operands])
    if isinstance(predicate, Negation):
        return negation(recurse(predicate.operand))
    if isinstance(predicate, IfOr):
        tape = trace_interpreter(
            program=predicate.trace,
            input_tape=input_tape,
            custom_filters=custom_filters,
        )
        values = [
            # Nodes only guarantee `SupportsFloat`, so cast to a tensor
            # when necessary.
            node.value if isinstance(node.value, torch.Tensor)
            else torch.tensor(float(node.value))
            for node in tape
            if isinstance(node, IfNode)
        ]
        if not values:
            # "Some IfNode has value > 0" is false over an empty set of
            # nodes (common with filters). `disjunction` would raise on zero
            # arguments, so return the robustness of "not Top" instead.
            return negation(max_value)
        return disjunction(*values)
    if isinstance(predicate, LessThan):
        left = evaluate_real(predicate.left)
        right = evaluate_real(predicate.right)
        return less_than(left, right)
    if predicate is Constants.Top:
        return max_value
    # Previously this fell through and returned None.
    raise ValueError(f"Unsupported predicate: {predicate!r}")
