"""Wrapper around EvoTorch to solve Pylic predicates (neuroevolution)."""
from typing import Callable, Concatenate, ParamSpec, TypeVar
from typing import Any
from pylic.code_transformations import get_tape
from pylic.predicates import Predicate
from pylic.predicates import predicate_interpreter
from pylic.predicates import CustomFunction
from pylic.predicates import CustomFilter
from pylic.predicates import SolverFailedException
from evotorch.neuroevolution import NEProblem
from evotorch.algorithms import SearchAlgorithm
from evotorch.logging import StdOutLogger
import torch


# Generic type variables; not referenced by the definitions visible in this
# module section.
T = TypeVar("T")
K = TypeVar("K")
# Parameter specification for the extra positional/keyword arguments that are
# forwarded, after the network, when tracing the user function ``f``.
RuntimeData = ParamSpec("RuntimeData")


def get_robustness(
        predicate: Predicate,
        parameters: torch.nn.Module,
        f: Callable[Concatenate[torch.nn.Module, RuntimeData], Any],
        max_value: torch.Tensor,
        custom_functions: dict[str, CustomFunction],
        custom_filters: dict[str, CustomFilter],
        *args: RuntimeData.args,
        **kwargs: RuntimeData.kwargs,
        ) -> torch.Tensor:
    """Compute the robustness of ``predicate`` for the given network.

    A tape is requested from ``get_tape`` for ``f``, ``parameters`` and the
    extra ``args``/``kwargs``; the tape is then interpreted against
    ``predicate`` by ``predicate_interpreter`` and that result is returned
    unchanged.

    Args:
        predicate: Predicate to interpret.
        parameters: Network forwarded to ``get_tape`` ahead of ``args``.
        f: Function whose tape is recorded.
        max_value: Forwarded to ``predicate_interpreter`` as ``max_value``.
        custom_functions: Named custom functions for the interpreter.
        custom_filters: Named custom filters for the interpreter.
        *args: Extra positional arguments forwarded to ``get_tape``.
        **kwargs: Extra keyword arguments forwarded to ``get_tape``.

    Returns:
        Whatever ``predicate_interpreter`` returns for the recorded tape.
    """
    recorded_tape = get_tape(f, None, parameters, *args, **kwargs)

    # Interpreter settings that do not depend on the recorded tape.
    interpreter_options = {
        "max_value": max_value,
        "custom_functions": custom_functions,
        "custom_filters": custom_filters,
    }
    robustness = predicate_interpreter(
        predicate=predicate,
        input_tape=recorded_tape,
        **interpreter_options,
    )
    return robustness


class NEProblemWrapper(NEProblem):
    """EvoTorch ``NEProblem`` whose fitness is the robustness of a predicate.

    The problem is created with ``objective_sense="max"``, and each candidate
    network is scored by ``get_robustness`` using the predicate, function and
    interpreter settings stored on the instance.
    """

    def __init__(
            self,
            network: (str | torch.nn.Module | Callable[[], torch.nn.Module]),
            predicate: Predicate,
            f: Callable[Concatenate[torch.nn.Module, RuntimeData], Any],
            max_value: torch.Tensor,
            custom_functions: dict[str, CustomFunction],
            custom_filters: dict[str, CustomFilter],
            evotorch_kwargs: dict[str, Any],
            *args: RuntimeData.args,
            **kwargs: RuntimeData.kwargs,
            ):
        """Store the evaluation settings and initialise the base problem.

        Args:
            network: Network specification handed to ``NEProblem``.
            predicate: Predicate whose robustness is the fitness.
            f: Function evaluated with each candidate network.
            max_value: Forwarded to the predicate interpreter.
            custom_functions: Named custom functions for the interpreter.
            custom_filters: Named custom filters for the interpreter.
            evotorch_kwargs: Extra keyword arguments expanded into
                ``NEProblem.__init__``.
            *args: Extra positional arguments forwarded on every evaluation.
            **kwargs: Extra keyword arguments forwarded on every evaluation.
        """
        super().__init__(
            objective_sense="max",
            network=network,
            **evotorch_kwargs,
        )
        self.network = network
        self.predicate = predicate
        self.f = f
        self.max_value = max_value
        self.custom_functions = custom_functions
        self.custom_filters = custom_filters
        self.evotorch_kwargs = evotorch_kwargs
        self.args = args
        self.kwargs = kwargs

    def _evaluate_network(self, network: torch.nn.Module) -> torch.Tensor:
        """Return the robustness obtained with ``network`` as the fitness."""
        # The leading arguments are passed positionally on purpose: combining
        # keyword arguments for them with ``*self.args`` would bind the extra
        # positionals to ``predicate``, ``parameters``, ... and raise
        # "got multiple values for argument" whenever ``self.args`` is
        # non-empty.
        return get_robustness(
            self.predicate,
            network,
            self.f,
            self.max_value,
            self.custom_functions,
            self.custom_filters,
            *self.args,
            **self.kwargs,
        )


def solver(
        predicate: Predicate,
        network: (str | torch.nn.Module | Callable[[], torch.nn.Module]),
        f: Callable[Concatenate[torch.nn.Module, RuntimeData], Any],
        max_value: torch.Tensor,
        custom_functions: dict[str, CustomFunction],
        custom_filters: dict[str, CustomFilter],
        algorithm: Callable[[NEProblem], SearchAlgorithm],
        num_generations: int,
        verbose: bool,
        evotorch_kwargs,
        *args: RuntimeData.args,
        **kwargs: RuntimeData.kwargs,
        ) -> torch.nn.Module:
    """Wrap Evotorch as the solver of Pylic predicates.

    Args:
        predicate: Predicate to satisfy.
        network: Network specification handed to the EvoTorch problem.
        f: Function evaluated with each candidate network.
        max_value: Forwarded to the predicate interpreter.
        custom_functions: Named custom functions for the interpreter.
        custom_filters: Named custom filters for the interpreter.
        algorithm: Callable that builds the search algorithm from the problem.
        num_generations: Maximum number of ``searcher.step()`` calls.
        verbose: When true, a ``StdOutLogger`` is attached to the searcher.
        evotorch_kwargs: Extra keyword arguments for the EvoTorch problem.
        *args: Extra positional arguments forwarded on every evaluation.
        **kwargs: Extra keyword arguments forwarded on every evaluation.

    Returns:
        The network parameterised with the searcher's ``"best"`` solution.

    Raises:
        SolverFailedException: If the robustness re-evaluated on the returned
            network is not strictly positive.
    """
    # Instantiate NEProblem. The leading arguments are positional so that the
    # runtime ``*args`` land in the wrapper's ``*args`` instead of clashing
    # with keyword-bound parameters ("got multiple values for argument").
    problem = NEProblemWrapper(
        network,
        predicate,
        f,
        max_value,
        custom_functions,
        custom_filters,
        evotorch_kwargs,
        *args,
        **kwargs,
    )
    searcher = algorithm(problem)
    if verbose:
        StdOutLogger(searcher)

    # Search until finding solution or running out of budget
    for _ in range(num_generations):
        searcher.step()
        if searcher.status['best_eval'] > 0:
            break

    # Get best network
    best = problem.parameterize_net(searcher.status["best"])

    # Check that network satisfies solution (same positional-call convention
    # as above so non-empty ``args`` are forwarded correctly).
    final_robustness = get_robustness(
        predicate,
        best,
        f,
        max_value,
        custom_functions,
        custom_filters,
        *args,
        **kwargs,
    )
    if float(final_robustness) <= 0.0:
        raise SolverFailedException()
    return best
