"""Gradient-based solver for Pylic predicates."""
from typing import Callable, Concatenate, ParamSpec, TypeVar
from typing import Union
from pylic.code_transformations import get_tape
from pylic.predicates import Predicate
from pylic.predicates import predicate_interpreter
from pylic.predicates import SolverFailedException
from pylic.predicates import CustomFunction
from pylic.predicates import CustomFilter
import torch


T = TypeVar("T")
K = TypeVar("K")
RuntimeData = ParamSpec("RuntimeData")
Parameters = torch.Tensor


def solver(
        predicate: Predicate,
        starting_parameters: Parameters,
        f: Callable[Concatenate[Parameters, RuntimeData], K],
        max_value: torch.Tensor,
        custom_functions: dict[str, CustomFunction],
        custom_filters: dict[str, CustomFilter],
        grad_mask: torch.Tensor,
        iter_n: int,
        learning_rate: float,
        momentum_beta: Union[None, float],  # None disables momentum
        grad_norm_min: float,  # gradient norm to terminate
        normalize_gradient: bool,
        verbose: bool,
        dampening: float,  # <= 1.0 (1.0 no dampening)
        line_search_scale: Union[None, float],  # >0, < 1
        line_search_max_iter_n: int,  # Maximum number of iterations for line search
        max_consecutive_non_improvements: int,  # over the best so far
        max_consecutive_sat_improvements: int,  # continue optimizing after solving?
        *args: RuntimeData.args,
        **kwargs: RuntimeData.kwargs,
        ) -> tuple[Parameters, K]:
    """Solve a predicate by momentum gradient ascent on its robustness.

    The robustness of `predicate`, evaluated on the tape recorded while running
    `f`, is maximized. Parameters whose robustness is strictly positive are
    considered a solution. If no solution is found within the budget,
    `pylic.predicates.SolverFailedException` is raised.

    :param predicate: Predicate to solve.
    :param starting_parameters: Initial guess for the optimization.
    :param f: Function over which the predicate will be solved.
    :param max_value: Maximum value which a predicate can take (forwarded to `predicate_interpreter`).
    :param custom_functions: User-defined functions forwarded to `predicate_interpreter`.
    :param custom_filters: User-defined filters forwarded to `predicate_interpreter`.
    :param grad_mask: Tensor of the same shape as `starting_parameters` by which the gradient will be multiplied. Useful to 'freeze' a subset of the parameters.
    :param iter_n: Maximum number of iterations before failing.
    :param learning_rate: Learning rate for the gradient ascent algorithm.
    :param momentum_beta: Simple momentum coefficient between 0 and 1. If `None`, no momentum will be applied. Useful to escape small local minima.
    :param grad_norm_min: Norm threshold. The search halts when the step norm falls below it; the gradient is only normalized when its norm exceeds it; line search stops shrinking below it.
    :param normalize_gradient: Whether the gradient will be normalized to l2-norm of 1 before performing a step.
    :param verbose: Whether to print status updates to stdout.
    :param dampening: Exponential decay factor of the learning rate (the step at iteration `t` is scaled by `dampening**t`). A value of 1 means no decay.
    :param line_search_scale: Factor by which a non-improving step is shrunk during backtracking line search. If `None`, no line search will be performed.
    :param line_search_max_iter_n: Maximum number of shrinking iterations for line search.
    :param max_consecutive_non_improvements: Maximum consecutive steps without improving on the best robustness before the loop stops.
    :param max_consecutive_sat_improvements: Maximum steps taken after the predicate has been solved. Usually 0.
    :param args: Positional arguments for `f`.
    :param kwargs: Keyword arguments for `f`.
    :returns: The best parameters found and `f` evaluated at them.
    :raises SolverFailedException: If no parameters with positive robustness were found.
    """
    if verbose:
        print(f"Starting gradient descent on size {starting_parameters.size()}")

    lr = learning_rate
    # `None` is documented as "no momentum"; multiplying by None would raise.
    beta = 0.0 if momentum_beta is None else momentum_beta

    # Cost function executes program and evaluates node predicate
    def get_robustness(parameters: Parameters) -> torch.Tensor:
        tape = get_tape(f, None, parameters, *args, **kwargs)
        return predicate_interpreter(
            predicate=predicate,
            input_tape=tape,
            max_value=max_value,
            custom_functions=custom_functions,
            custom_filters=custom_filters,
        )

    def improves(candidate: torch.Tensor, incumbent: torch.Tensor) -> bool:
        """Whether `candidate` strictly beats `incumbent`; NaN never improves,
        but any non-NaN value improves on a NaN incumbent."""
        if bool(torch.isnan(candidate)):
            return False
        return bool(torch.isnan(incumbent)) or bool(candidate > incumbent)

    # Initialize gradient ascent variables
    x = starting_parameters.detach().clone()
    x.requires_grad_(True)
    z = torch.zeros_like(x)

    current_robustness = get_robustness(x)
    # Detached so the best-so-far value does not keep an autograd graph alive.
    best_robustness = current_robustness.detach()
    consecutive_non_improvements = 0
    consecutive_sat_improvements = 0
    best_x = x.detach().clone()
    for t in range(iter_n):
        if verbose:
            print(
                f"Iter {t}, robustness: {current_robustness.item()}"
            )
        if not current_robustness.requires_grad:
            # Sometimes the robustness does not requires_grad_
            # (i.e. it is not a function of the parameters).
            # In general, a hybrid gradient-based and gradient-free
            # optimization scheme should be implemented to handle
            # this case in its full generality.
            # In this code we simply set the gradient to zero. This will
            # generally mean that the solver will get stuck in a local
            # minimum and fail, unless there is enough momentum to escape.
            grad = torch.zeros_like(x)
        else:
            grad = torch.autograd.grad(current_robustness, [x])[0]

        # Apply gradient mask, then neutralize NaNs from the backward pass
        grad = (grad*grad_mask).nan_to_num()

        # Normalize gradient (skipped for tiny gradients to avoid blow-up)
        if normalize_gradient and grad.norm() > grad_norm_min:
            grad = grad/grad.norm()

        # Momentum with exponentially dampened learning rate
        new_z = beta*z + (dampening**t)*lr*grad

        if new_z.norm() < grad_norm_min:
            # Converged!
            if verbose:
                print("Gradient descent converged!")
            break

        # Candidate gradient ascent step
        new_x = (x + new_z).detach()

        # Backtracking line search
        step_size = 1.0
        if line_search_scale is not None:
            new_robustness = get_robustness(new_x)
            line_search_i = 0
            while (new_robustness < current_robustness
                   and (step_size*new_z).norm() > grad_norm_min
                   and line_search_i < line_search_max_iter_n):
                line_search_i += 1
                step_size = step_size*line_search_scale
                new_x = (x + step_size*new_z).detach()
                new_robustness = get_robustness(new_x)

            # If line search failed to find a better step...
            if new_robustness < current_robustness:
                # ...take the original full step. Both the position and the
                # momentum must be reset, otherwise `x` would move by the last
                # shrunken step while `z` records a full step.
                step_size = 1.0
                new_x = (x + new_z).detach()

        # Commit the step; x becomes a fresh leaf for the next backward pass
        z = (step_size*new_z).detach()
        x = new_x.detach().requires_grad_(True)
        current_robustness = get_robustness(x)

        if improves(current_robustness, best_robustness):
            best_robustness = current_robustness.detach()
            best_x = x.detach().clone()
            consecutive_non_improvements = 0
        else:
            consecutive_non_improvements += 1
            if consecutive_non_improvements > max_consecutive_non_improvements:
                if verbose:
                    print(
                        "Gradient descent maximum non-improvements reached!"
                    )
                break

        # Count iterations spent after a solution has been found
        if best_robustness > 0.0:
            consecutive_sat_improvements += 1
        if consecutive_sat_improvements > max_consecutive_sat_improvements:
            if verbose:
                print(
                    "Gradient descent maximum sat-improvements reached!"
                )
            break

    if best_robustness > 0.0:
        x = best_x.clone().detach()
        return x, f(x, *args, **kwargs)
    raise SolverFailedException(
        "Ran out of budget and did not find solution!",
        final_parameters=x,
    )
