"""A tape is a list of the program states that were encountered while executing
a piece of Python code.

The act of recording an execution into a tape requires transforming the code,
and this happens automatically when evaluating a predicate. Should you wish to
extract a tape outside of the context of predicate evaluation, check
pylic.code_transformations.get_tape.
"""

from collections import UserList
from typing import Any
from typing import Union
from typing import TypeVar
from typing import Generic
from dataclasses import field
from dataclasses import dataclass
from typing import SupportsFloat
from enum import Enum
from enum import auto


# Alias for a recorded program state: a tuple of (variable name, value) pairs.
VariablesInScope = tuple[tuple[str, Any], ...]
"""The program state recorded in the nodes of a trace consists of tuples
including variable names and their values at the moment of recording."""

# Type parameter for the value carried by ReturnNode; Node and Tape are
# parameterised over it as well.
T = TypeVar("T")


class NodeTypes(Enum):
    """Enumeration of node types for filter operations."""

    # Explicit values, equal to what ``auto()`` assigns in declaration order.
    # NOTE(review): presumably FOR / RETURN / IF correspond to ForNode,
    # ReturnNode and IfNode -- confirm against the filtering code.
    FOR = 1
    RETURN = 2
    IF = 3


NodeID = (int|str)
"""Node identifiers are either derived from the position of the node in the
source code, in which case they are numbers, or supplied by the user, in which
case they are strings."""


@dataclass(frozen=True, eq=True)
class TraceNode:
    """Base class for nodes in a trace.

    Nodes are immutable. Two nodes compare equal when their `id` and
    `for_vars` match; `variables_in_scope` takes part in neither comparison
    nor hashing.

    :param id: identifier of the node (see `NodeID`).
    :param for_vars: variables used as targets of all for-loops *in this program path*.
    :param variables_in_scope: `(var, value)` pairs recorded with the node.
    """
    id: NodeID
    # `for_vars` is annotated as a list, which is unhashable. Keep it out of
    # the generated __hash__ so that hashing a node does not raise TypeError;
    # it still participates in __eq__, so hash/eq stay consistent.
    for_vars: list[str] = field(hash=False)
    variables_in_scope: VariablesInScope = field(compare=False, hash=False)

    @property
    def program_state(self) -> dict[str, Any]:
        """Return the recorded variables as a `{name: value}` dictionary.

        A new dictionary is built on every access.
        """
        return dict(self.variables_in_scope)


@dataclass(frozen=True, eq=True)
class IfNode(TraceNode):
    """Represents the execution of an if-statement, storing the program state
    before executing the if-statement:

    :param id: identifies the if-statement in the program code.
    :param for_vars: variables used as targets of all for-loops *in this program path*.
    :param variables_in_scope: list of `(var, value)` *before* the if-statement.
    :param value: represents the truth value of the condition. It can be a boolean, a float, or a torch tensor. An if-statement evaluates to True if its value is greater than zero.

    Fields are declared in the order the dataclass machinery actually uses:
    fields inherited from `TraceNode` keep their base-class position and new
    fields are appended, so the generated ``__init__`` takes
    ``(id, for_vars, variables_in_scope, value)``.
    """
    id: NodeID
    # `for_vars` is annotated as a list, which is unhashable. Keep it out of
    # the generated __hash__ so that hashing a node does not raise TypeError;
    # it still participates in __eq__.
    for_vars: list[str] = field(hash=False)
    variables_in_scope: VariablesInScope = field(compare=False, hash=False)
    value: SupportsFloat


@dataclass(frozen=True, eq=True)
class ForNode(TraceNode):
    """Represents the start of the execution of the inner body of a for-loop,
    storing the program state at the start of each loop:

    :param id: identifies the for-statement in the program code. Matches a corresponding `ForNodeEnd`.
    :param for_vars: variables used as targets of all for-loops *in this program path*, including the for-loop represented by the node.
    :param variables_in_scope: program variables before execution of the for-loop body, and for-loop variables.
    """
    id: NodeID
    # `for_vars` is annotated as a list, which is unhashable. Keep it out of
    # the generated __hash__ so that hashing a node does not raise TypeError;
    # it still participates in __eq__.
    for_vars: list[str] = field(hash=False)
    variables_in_scope: VariablesInScope = field(compare=False, hash=False)


@dataclass(frozen=True, eq=True)
class ForNodeEnd(TraceNode):
    """Represents the end of the execution of the inner body of a for-loop,
    storing the program state at the end of each loop:

    :param id: identifies the for-statement in the program code. Matches a corresponding `ForNode`.
    :param for_vars: variables used as targets of all for-loops *in this program path*, including the for-loop represented by the node.
    :param variables_in_scope: program state, as `(var, value)` pairs, recorded at the end of the for-loop body.
    """
    id: NodeID
    # `for_vars` is annotated as a list, which is unhashable. Keep it out of
    # the generated __hash__ so that hashing a node does not raise TypeError;
    # it still participates in __eq__.
    for_vars: list[str] = field(hash=False)
    variables_in_scope: VariablesInScope = field(compare=False, hash=False)


@dataclass(frozen=True, eq=True)
class ReturnNode(Generic[T], TraceNode):
    """Represents the execution of a return-statement:

    :param id: identifies the return-statement in the program code.
    :param for_vars: variables used as targets of all for-loops *in this program path*.
    :param variables_in_scope: `(var, value)` pairs recorded with the node.
    :param value: the value that was returned

    Fields are declared in the order the dataclass machinery actually uses:
    fields inherited from `TraceNode` keep their base-class position and new
    fields are appended, so the generated ``__init__`` takes
    ``(id, for_vars, variables_in_scope, value)``.
    """
    id: NodeID
    # `for_vars` is annotated as a list, which is unhashable. Keep it out of
    # the generated __hash__ so that hashing a node does not raise TypeError;
    # it still participates in __eq__.
    for_vars: list[str] = field(hash=False)
    variables_in_scope: VariablesInScope = field(compare=False, hash=False)
    value: T


# Union of the concrete node classes; generic in T, the type of the value
# carried by ReturnNode.
Node = Union[IfNode, ForNode, ForNodeEnd, ReturnNode[T]]


class Tape(UserList[Node[T]]):
    """A tape is the list of program states encountered while executing a
    piece of Python code.

    Recording the execution of a Python program happens behind the scenes at
    `pylic.code_transformations.get_tape`. This class defines no members of
    its own; all behaviour comes from `UserList`."""
