"""Implements code-transformations for dynamic program analysis."""
from pylic.tape import Tape
from pylic.tape import IfNode
from pylic.tape import ReturnNode
from pylic.tape import ForNode
from pylic.tape import ForNodeEnd
from pylic.tape import NodeID
from typing import Callable
from typing import Union
from typing import Optional
from typing import SupportsFloat
from typing import ParamSpec
from typing import TypeVar
from dataclasses import dataclass
from functools import lru_cache
import ast
import inspect
import warnings
import re


TAPE_PARAMETER_NAME = "tape"
FIXED_TAPE_PARAMETER_NAME = "fixed_tape"
# Parameters injected into transformed functions; they are skipped when
# collecting the user's function arguments as variables in scope.
SPECIAL_VARIABLES = [
    TAPE_PARAMETER_NAME,
    FIXED_TAPE_PARAMETER_NAME,
]

# Any amount of whitespace is accepted after `ID:` so that `# ID:  foo`
# does not silently fall back to the default (positional) ID.
node_id_regex = r'.*#\s*ID:\s*(\S+)'
"""Users can assign custom IDs to control-flow instructions in order to make it
easier to refer to them in predicates. A custom ID should be a sequence of
non-whitespace characters. To assign a custom ID, write a comment of the form
`ID: custom_id` after a control-flow instruction, on the line where the
instruction starts. For example:

```py
if x < 3.5:  # ID: my_check
```
"""

T = TypeVar("T")


@dataclass(frozen=True, eq=True, order=True)
class VariableShadowingException(Exception):
    """Represents an unsupported shadowing of variables.

    The AST node associated with the error is stored in `node`.

    NOTE(review): as a frozen dataclass, instances reject attribute
    assignment through `__setattr__`, and `order=True` compares `node`
    fields (plain `ast.AST` nodes define no ordering). Confirm neither
    matters to callers.
    """
    # AST node associated with the unsupported shadowing.
    node: ast.AST


@dataclass(frozen=True, eq=True, order=True)
class DifferentiableConditionTransformationException(Exception):
    """Represents a generic boolean expression transformation error.

    Raised by the `get_differentiable_*` transducers when a condition is
    outside the language they can translate; `node` is the expression that
    could not be translated.

    NOTE(review): as a frozen dataclass, instances reject attribute
    assignment through `__setattr__`, and `order=True` compares `node`
    fields (plain `ast.AST` nodes define no ordering). Confirm neither
    matters to callers.
    """
    # Expression node that could not be transformed.
    node: ast.AST


def get_differentiable_boolean_op(node: ast.AST) -> ast.AST:
    """Translate an `and`/`or` expression into a predicate call.

    Accepted language:

        COND := COND and COND | COND or COND

    `a and b and ...` becomes `pylic.predicates.conjunction(A, B, ...)` and
    `a or b or ...` becomes `pylic.predicates.disjunction(A, B, ...)`, where
    each operand is itself translated with
    `get_differentiable_condition_value_node`.

    The if-transformation tests the resulting value with `> 0.0`. The value
    is therefore expected to be greater than zero if and only if the
    condition holds; presumably `conjunction`/`disjunction` guarantee this.
    Confirm in `pylic.predicates`.

    Raises:
        DifferentiableConditionTransformationException: if `node` is not a
        boolean operation, or if any operand cannot be translated.
    """
    # Guard clause: anything that is not `and`/`or` is outside our language.
    is_and_or = isinstance(node, ast.BoolOp) and isinstance(node.op, (ast.And, ast.Or))
    if not is_and_or:
        raise DifferentiableConditionTransformationException(node)

    # Python folds `a or b or c` into a single BoolOp, so there may be any
    # number of operands. A failure on any operand propagates to the caller.
    translated_operands = [
        get_differentiable_condition_value_node(operand)
        for operand in node.values
    ]

    combinator_name = "conjunction" if isinstance(node.op, ast.And) else "disjunction"
    combinator = ast.Attribute(
        value=ast.Name(id="pylic.predicates"),
        attr=combinator_name,
    )
    return ast.Call(combinator, translated_operands, [])


def get_differentiable_compare_node(node: ast.Compare) -> ast.AST:
    """Translate a single strict comparison into a predicate call.

    Accepted language:

        S := VAL < VAL | VAL > VAL
        VAL := "anything"

    `a < b` becomes `pylic.predicates.less_than(a, b)` and `a > b` becomes
    `pylic.predicates.greater_than(a, b)`. The operand nodes are reused as
    the call's positional arguments.

    The returned node is an expression whose value is greater than zero if and
    only if the comparison evaluates to true.

    Raises:
        DifferentiableConditionTransformationException: if `node` is not a
        comparison, is a chained comparison (`a < b < c`), or uses an
        operator other than `<` / `>`.
    """
    # A chained comparison such as `a < b < c` carries several operators.
    if not isinstance(node, ast.Compare) or len(node.ops) != 1:
        raise DifferentiableConditionTransformationException(node)

    (operator,) = node.ops
    if isinstance(operator, ast.Lt):
        predicate_name = "less_than"
    elif isinstance(operator, ast.Gt):
        predicate_name = "greater_than"
    else:
        raise DifferentiableConditionTransformationException(node)

    predicate = ast.Attribute(
        value=ast.Name(id="pylic.predicates"),
        attr=predicate_name,
    )
    return ast.Call(predicate, [node.left, node.comparators[0]], [])


def get_differentiable_condition_value_node(node: ast.AST) -> ast.AST:
    """Translate a condition into a differentiable value expression.

    Accepted language:

        COND := VAL < VAL | VAL > VAL | COND and COND | COND or COND
        VAL := "anything"

    Each transducer (boolean operations first, then comparisons) is tried in
    turn, and the first successful translation is returned.

    The if-transformation tests the resulting value with `> 0.0`, so the
    value is expected to be greater than zero if and only if the original
    condition evaluates to true.

    Raises:
        DifferentiableConditionTransformationException: if no transducer
        accepts `node`. The exception carries `node` itself, not the inner
        sub-expression that failed.
    """
    for translate in (get_differentiable_boolean_op, get_differentiable_compare_node):
        try:
            return translate(node)
        except DifferentiableConditionTransformationException:
            # This transducer does not apply; fall through to the next one.
            continue
    raise DifferentiableConditionTransformationException(node)


def if_node_constructor(
        id: NodeID,
        value_var: str,
        variables_in_scope: list[str],
        for_vars: list[str],
        ) -> ast.Call:
    """Return the AST of the call

        pylic.code_transformations.IfNode(
            id=<id>,
            value=<value_var>,
            variables_in_scope=(('name', name), ...),
            for_vars=('name', ...),
        )

    `variables_in_scope` pairs each variable's name with a load of that
    variable, so the generated call snapshots the values at run time.
    `for_vars` is recorded by name only.
    """
    return ast.Call(
        func=ast.Attribute(
            value=ast.Name("pylic.code_transformations", ast.Load()),
            attr=IfNode.__name__,
            ctx=ast.Load(),
        ),
        args=[],
        keywords=[
            ast.keyword('id', ast.Constant(id)),
            ast.keyword('value', ast.Name(value_var, ast.Load())),
            ast.keyword('variables_in_scope', ast.Tuple([
                ast.Tuple([ast.Constant(var), ast.Name(var, ast.Load())], ast.Load())
                for var in variables_in_scope
            ], ast.Load())),
            ast.keyword('for_vars', ast.Tuple([
                ast.Constant(var)
                for var in for_vars
            ], ast.Load())),
        ],
        # NOTE: `ast.Call` has no `kwargs` field; passing one is deprecated
        # since Python 3.13, so it is intentionally omitted.
    )


def return_node_constructor(
        id: NodeID,
        value_var: str,
        variables_in_scope: list[str],
        for_vars: list[str],
        ) -> ast.Call:
    """Return the AST of the call

        pylic.code_transformations.ReturnNode(
            id=<id>,
            value=<value_var>,
            variables_in_scope=(('name', name), ...),
            for_vars=('name', ...),
        )

    `variables_in_scope` pairs each variable's name with a load of that
    variable, so the generated call snapshots the values at run time.
    `for_vars` is recorded by name only.
    """
    return ast.Call(
        func=ast.Attribute(
            value=ast.Name("pylic.code_transformations", ast.Load()),
            attr=ReturnNode.__name__,
            ctx=ast.Load(),
        ),
        args=[],
        keywords=[
            ast.keyword('id', ast.Constant(id)),
            ast.keyword('value', ast.Name(value_var, ast.Load())),
            ast.keyword('variables_in_scope', ast.Tuple([
                ast.Tuple([ast.Constant(var), ast.Name(var, ast.Load())], ast.Load())
                for var in variables_in_scope
            ], ast.Load())),
            ast.keyword('for_vars', ast.Tuple([
                ast.Constant(var)
                for var in for_vars
            ], ast.Load())),
        ],
        # NOTE: `ast.Call` has no `kwargs` field; passing one is deprecated
        # since Python 3.13, so it is intentionally omitted.
    )


def for_node_constructor(
        id: NodeID,
        for_vars: list[str],
        variables_in_scope: list[str]
        ) -> ast.Call:
    """Return the AST of the call

        pylic.code_transformations.ForNode(
            id=<id>,
            for_vars=('name', ...),
            variables_in_scope=(('name', name), ...),
        )

    `variables_in_scope` pairs each variable's name with a load of that
    variable, so the generated call snapshots the values at run time.
    `for_vars` is recorded by name only.
    """
    return ast.Call(
        func=ast.Attribute(
            value=ast.Name("pylic.code_transformations", ast.Load()),
            attr=ForNode.__name__,
            ctx=ast.Load(),
        ),
        args=[],
        keywords=[
            ast.keyword('id', ast.Constant(id)),
            ast.keyword('for_vars', ast.Tuple([
                ast.Constant(var)
                for var in for_vars
            ], ast.Load())),
            ast.keyword('variables_in_scope', ast.Tuple([
                ast.Tuple([ast.Constant(var), ast.Name(var, ast.Load())], ast.Load())
                for var in variables_in_scope
            ], ast.Load())),
        ],
        # NOTE: `ast.Call` has no `kwargs` field; passing one is deprecated
        # since Python 3.13, so it is intentionally omitted.
    )


def for_end_node_constructor(
        id: NodeID,
        for_vars: list[str],
        variables_in_scope: list[str]
        ) -> ast.Call:
    """Return the AST of the call

        pylic.code_transformations.ForNodeEnd(
            id=<id>,
            for_vars=('name', ...),
            variables_in_scope=(('name', name), ...),
        )

    `variables_in_scope` pairs each variable's name with a load of that
    variable, so the generated call snapshots the values at run time.
    `for_vars` is recorded by name only.
    """
    return ast.Call(
        func=ast.Attribute(
            value=ast.Name("pylic.code_transformations", ast.Load()),
            attr=ForNodeEnd.__name__,
            ctx=ast.Load(),
        ),
        args=[],
        keywords=[
            ast.keyword('id', ast.Constant(id)),
            ast.keyword('for_vars', ast.Tuple([
                ast.Constant(var)
                for var in for_vars
            ], ast.Load())),
            ast.keyword('variables_in_scope', ast.Tuple([
                ast.Tuple([ast.Constant(var), ast.Name(var, ast.Load())], ast.Load())
                for var in variables_in_scope
            ], ast.Load())),
        ],
        # NOTE: `ast.Call` has no `kwargs` field; passing one is deprecated
        # since Python 3.13, so it is intentionally omitted.
    )


class DifferentiableIfRecorderTransformer(ast.NodeTransformer):
    """AST transformer that transforms functions so that:
        1. The execution of (some) control-flow statements is recorded
           to a `tape` upon execution.
        2. A `fixed_tape` parameter can be given to "fix" the program
           path followed when the tape was recorded, so that every
           execution follows the same program path.
    """

    def __init__(
            self,
            tape_parameter_name: str,
            fixed_tape_parameter_name: str,
            source: str,
            ):
        self.tape_parameter_name = tape_parameter_name
        self.fixed_tape_parameter_name = fixed_tape_parameter_name
        self.ifs_read = 0
        self.returns_read = 0
        self.fors_read = 0
        self.variables_in_scope = list()
        self.for_vars = list()  # all for-vars in the current program path
        self.source = source

    def get_node_id(self, node: ast.AST) -> (int|str):
        """Return the ID for the given AST node. If the user annotated the
        control flow statement with an ID (see
        pylic.code_transformations.node_id_regex), that ID is returned,
        otherwise, the ID is assigned based on the control-flow statements
        previously visited."""
        # First, compute the "default" ID, which is simply the number
        # of nodes of this type that have been traversed already.
        if isinstance(node, ast.If):
            default_id = self.ifs_read
        elif isinstance(node, ast.For):
            default_id = self.fors_read
        elif isinstance(node, ast.Return):
            default_id = self.returns_read
        else:
            raise ValueError(f'Cannot assign ID to node of type {type(node)}')

        # Then, check if the user provided a custom ID. If not,
        # return the original node id.
        node_source = ast.get_source_segment(self.source, node)
        if node_source is None:
            return default_id
        node_source = node_source.split('\n')[0]
        # Complimentary stripping of trailing whitespace
        node_source = node_source.strip()
        node_id_match = re.match(node_id_regex, node_source)
        if node_id_match is None:
            return default_id
        node_id = node_id_match.group(1)
        return node_id

    def visit_Assign(self, node: ast.Assign):
        # Add variable to variables in scope
        for variable in node.targets:
            if isinstance(variable, ast.Name):
                self.variables_in_scope.append(variable.id)
            elif isinstance(variable, (ast.List, ast.Tuple)):
                self.variables_in_scope.extend(
                    (
                        t.id
                        for t in variable.elts
                    )
                )
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef):
        # Add function arguments to list of variables in scope
        arg_nodes = (
            arg_node for arg_node in (
                *node.args.posonlyargs,
                *node.args.args,
                *node.args.kwonlyargs,
                node.args.vararg,
                node.args.kwarg,
            )
            if arg_node is not None
        )
        for arg_node in arg_nodes:
            if arg_node.arg in SPECIAL_VARIABLES:
                continue
            self.variables_in_scope.append(arg_node.arg)
        return self.generic_visit(node)

    def visit_For(self, node: ast.For):
        # Add the iteration variables to the list of variables in scope
        # Get the variable name on the for
        target = node.target
        for_vars = list()
        if isinstance(target, ast.Name):
            for_vars.append(target.id)
        elif isinstance(target, (ast.Tuple, ast.List)):
            def flatten(l: Union[ast.Tuple, ast.List]) -> list[ast.Name]:
                y = []
                for i in l.elts:
                    if isinstance(i, (ast.Tuple, ast.List)):
                        y.extend(flatten(i))
                    else:
                        y.append(i)
                return y
            for_vars.extend((t.id for t in flatten(target)))
        else:
            ValueError(f"Failed to analyze {node}")

        # Add the variable to the variables in scope
        variables_in_scope_outside = list(self.variables_in_scope)
        for_vars_outside = list(self.for_vars)
        self.for_vars.extend(for_vars)
        self.variables_in_scope.extend(for_vars)

        # Get node ID
        for_id = self.get_node_id(node)

        # Update number of for loops visited
        self.fors_read += 1

        # Visit children
        self.generic_visit(node)

        # Add tracing statement at start of for body
        node_constructor = for_node_constructor(
            id=for_id,
            for_vars=self.for_vars,  # vars of all for-loops in program path
            variables_in_scope=variables_in_scope_outside+for_vars,
        )

        # Build a node for `for_id_var = for_id_node`
        for_node_var = ast.Name(f"_pylic_for_{for_id}_node")
        node_statement = ast.Assign([for_node_var], node_constructor)

        # Build a node for variable `tape`
        tape_node = ast.Name(self.tape_parameter_name)
        trace_statement = ast.Expr(ast.Call(
            ast.Attribute(tape_node, "append", ast.Load()),
            [for_node_var, ast.Load()],
            [],
        ))
        node.body.insert(0, node_statement)
        node.body.insert(1, trace_statement)

        # Do the same for the end-for node
        node_end_constructor = for_end_node_constructor(
            id=for_id,
            for_vars=self.for_vars,
            variables_in_scope=self.variables_in_scope,
        )
        for_end_node_var = ast.Name(f"pylic_for_end_{for_id}_node")
        node_end_statement = ast.Assign([for_end_node_var], node_end_constructor)
        tape_node = ast.Name(self.tape_parameter_name)
        end_trace_statement = ast.Expr(ast.Call(
            ast.Attribute(tape_node, "append", ast.Load()),
            [for_end_node_var, ast.Load()],
            [],
        ))
        node.body.append(node_end_statement)
        node.body.append(end_trace_statement)

        # Restore the variables in scope
        self.variables_in_scope = variables_in_scope_outside
        self.for_vars = for_vars_outside

        return node

    def visit_If(self, node: ast.If):
        # The node will be returned untouched if differentiable transformation
        # of the condition fails
        variables_in_scope_outside = list(self.variables_in_scope)
        output = node

        try:
            # Attempt to translate into a differentiable condition.
            # If successful, add a line to record its value into the tape
            value_node = get_differentiable_condition_value_node(node.test)
        except DifferentiableConditionTransformationException:
            # The translation into a differentiable condition may fail and
            # throw an exception, in which case nothing happens
            # Transform child nodes
            warnings.warn(
                f"Could not smooth condition test: {ast.unparse(node.test)})"
            )
            value_node = node.test

        # Obtain ID
        if_id = self.get_node_id(node)

        # Update counter for number of if statements read
        self.ifs_read += 1

        # Transform child nodes
        self.generic_visit(node)

        # Build a fresh variable
        if_value_var = ast.Name(f"_pylic_if_{if_id}_value")

        # Build a node for `if_value_var = value`
        statement1 = ast.Assign([if_value_var], value_node)

        # Build a node for the if statement node
        if_node = if_node_constructor(
            id=if_id,
            variables_in_scope=variables_in_scope_outside,
            value_var=if_value_var.id,
            for_vars=self.for_vars,
        )

        # Build a node for `if_id_var = if_id_node`
        if_node_var = ast.Name(f"_pylic_if_{if_id}_node")
        statement2 = ast.Assign([if_node_var], if_node)

        # Build a node for variable `tape`
        tape_node = ast.Name(self.tape_parameter_name)

        # Build a node for `tape.append(if_node)`
        statement3 = ast.Expr(ast.Call(
            ast.Attribute(tape_node, "append", ast.Load()),
            [if_node_var, ast.Load()],
            [],
        ))

        # Build a node for `if_id_value > 0.0`
        new_test = ast.Compare(if_value_var, [ast.Gt()], [ast.Constant(0.0)])

        # Now, add a node that checks if the value of the if-statement
        # should be overridden from a fixed tape
        # if fixed_tape is not None and if_id in fixed_tape: fresh_var = fixed_tape[if_id]
        # Build a node for variable `fixed_tape`
        fixed_tape_node = ast.Name(self.fixed_tape_parameter_name)

        # Build a node for `fixed_tape is not None and if_id_var in fixed_tape`
        fixed_test = ast.BoolOp(
            op=ast.And(),
            values=[
                ast.Compare(
                    fixed_tape_node, [ast.IsNot()], [ast.Constant(None)]
                ),
                ast.Compare(
                    if_node_var, [ast.In()], [fixed_tape_node]
                ),
            ]
        )

        # Build a node for `if_value_var = fixed_tape[if_id_var]`
        subscript = if_node_var
        fixed_assign = ast.Assign(
            [if_value_var],
            ast.Subscript(fixed_tape_node, subscript, ast.Load()),
        )

        # Build a node for
        # `if fixed_tape is not None: fresh_var = fixed_tape[id]`
        fixed_tape_if = ast.If(fixed_test, [fixed_assign], [])

        # Transform the if statement to `if fresh_var > 0:`
        node.test = new_test

        # The new statements are the output
        output = [statement1, statement2, statement3, fixed_tape_if, node]

        # Restore variables in scope
        self.variables_in_scope = variables_in_scope_outside
        return output

    def visit_Return(self, node: ast.Return):
        # Get node ID
        node_id = self.get_node_id(node)

        # Build `var = value`
        var = ast.Name(f"_pylic_return_{node_id}")
        var_assign = ast.Assign([var], node.value)

        # Build `ReturnNode(var)`
        return_node = return_node_constructor(
            id=self.returns_read,
            value_var=var.id,
            variables_in_scope=self.variables_in_scope,
            for_vars=self.for_vars,
        )

        # Build `tape.append(ReturnNode(var))`
        append_ast = ast.Expr(ast.Call(
            ast.Attribute(ast.Name(self.tape_parameter_name), "append", ast.Load()),
            [return_node, ast.Load()],
            [],
        ))

        # Build `return var`
        return_ast = ast.Return(var)

        # Update number of returns read
        self.returns_read += 1
        return [var_assign, append_ast, return_ast]

    def generic_visit_SaveRestoreVariablesInScope(self, node: ast.AST):
        # Save and restore variables in scope
        variables_in_scope_outside = list(self.variables_in_scope)
        self.generic_visit(node)
        self.variables_in_scope = variables_in_scope_outside
        return node

    def visit_While(self, node: ast.While):
        self.generic_visit_SaveRestoreVariablesInScope(node)

    def visit_Try(self, node: ast.Try):
        self.generic_visit_SaveRestoreVariablesInScope(node)

    def visit_With(self, node: ast.With):
        self.generic_visit_SaveRestoreVariablesInScope(node)


@lru_cache(maxsize=None)
def get_tracing_transformed_source(source: str) -> tuple[str, str]:
    """Rewrite the source of a single function into its tracing version.

    The returned function:
      * takes two extra required keyword-only parameters, `tape` and
        `fixed_tape`;
      * records its control flow through `DifferentiableIfRecorderTransformer`;
      * imports `pylic.code_transformations` and `pylic.predicates` at the
        top of its body;
      * is renamed to `<original name>_pylic_transform`.

    Results are memoized per source string.

    Returns:
        `(new_source, new_name)`: the transformed source code and the name of
        the function it defines.

    Raises:
        ValueError: if the parsed code does not start with a `def`.
    """
    module = ast.parse(source)

    # The source is that of one function, so the module is expected to be
    # `Module(body=[FunctionDef(...)])`.
    function_def = module.body[0]
    if not isinstance(function_def, ast.FunctionDef):
        raise ValueError("Parsed code does not start with function definition!")

    # Append the keyword-only parameters; a `None` default entry means the
    # parameter has no default value.
    for parameter_name in (TAPE_PARAMETER_NAME, FIXED_TAPE_PARAMETER_NAME):
        function_def.args.kwonlyargs.append(ast.arg(arg=parameter_name))
        function_def.args.kw_defaults.append(None)

    # Instrument the control-flow statements
    recorder = DifferentiableIfRecorderTransformer(
        tape_parameter_name=TAPE_PARAMETER_NAME,
        fixed_tape_parameter_name=FIXED_TAPE_PARAMETER_NAME,
        source=source,
    )
    module = recorder.visit(module)
    function_def = module.body[0]

    # The generated calls refer to `pylic.code_transformations.*` and
    # `pylic.predicates.*`, so import both inside the function itself.
    function_def.body[:0] = [
        ast.Import(names=[ast.alias("pylic.code_transformations", None)]),
        ast.Import(names=[ast.alias("pylic.predicates", None)]),
    ]

    new_name = f"{function_def.name}_pylic_transform"
    function_def.name = new_name

    # The inserted nodes carry no line information yet
    ast.fix_missing_locations(module)

    return ast.unparse(module), new_name


P = ParamSpec("P")


def get_tape(
        f: Callable[P, T],
        fixed_tape: Optional[list[tuple[IfNode, SupportsFloat]]],
        *args: P.args,
        **kwargs: P.kwargs,
        ) -> Tape[T]:
    """Record the execution of `f` when called with `*args` and `**kwargs`.

    The source of `f` is rewritten into a tracing version (see
    `get_tracing_transformed_source`), compiled in `f.__globals__`, and
    called as `f_transformed(*args, **kwargs, tape=..., fixed_tape=...)`.

    NOTE(review): the rewritten function is executed with only
    `f.__globals__` and a fresh locals dict. Names that `f` captures from an
    enclosing function are presumably unavailable; confirm before tracing
    closures.

    Args:
        f: function whose execution is recorded.
        fixed_tape: optional `(node, value)` pairs. When given, they are
            turned into a dict and forwarded, so that recorded `if`
            conditions use these values instead of the computed ones.
        *args: positional arguments forwarded to `f`.
        **kwargs: keyword arguments forwarded to `f`.

    Returns:
        The tape filled during the call.
    """
    source_code = inspect.getsource(f)
    transformed_source, transformed_name = get_tracing_transformed_source(source_code)

    # Compile the rewritten code against the same globals as `f`
    namespace = dict()
    exec(transformed_source, f.__globals__, namespace)
    traced_function = namespace[transformed_name]

    fixed_tape_lookup = dict(fixed_tape) if fixed_tape is not None else None

    recorded = Tape()
    traced_function(*args, **kwargs, tape=recorded, fixed_tape=fixed_tape_lookup)
    return recorded
