"""Wrapper around CMA to solve Pylic predicates."""
from typing import Callable, Concatenate, ParamSpec, TypeVar, Union
from typing import Any
from pylic.code_transformations import get_tape
from pylic.predicates import Predicate
from pylic.predicates import predicate_interpreter
from pylic.predicates import SolverFailedException
from pylic.predicates import CustomFunction
from pylic.predicates import CustomFilter
from concurrent.futures import ProcessPoolExecutor
import torch
import math
import cma


T = TypeVar("T")
K = TypeVar("K")
RuntimeData = ParamSpec("RuntimeData")
Parameters = torch.Tensor


def get_robustness(
        predicate: Predicate,
        parameters: list[float],
        f: Callable[Concatenate[Parameters, RuntimeData], Any],
        max_value: torch.Tensor,
        custom_functions: dict[str, CustomFunction],
        custom_filters: dict[str, CustomFilter],
        parameters_shape: list[int],
        *args: RuntimeData.args,
        **kwargs: RuntimeData.kwargs,
        ) -> float:
    """Score one flat candidate against `predicate`.

    The native-float list `parameters` (used so candidates can be sent
    across process boundaries) is rebuilt into a tensor of shape
    `parameters_shape`. A tape of `f` is obtained with `get_tape`,
    forwarding the runtime `args`/`kwargs`, and that tape is evaluated
    by `predicate_interpreter` with the given `max_value`,
    `custom_functions` and `custom_filters`.

    Returns the interpreter's result converted with `.item()`.
    """
    shaped_parameters = torch.tensor(parameters).reshape(parameters_shape)
    recorded_tape = get_tape(f, None, shaped_parameters, *args, **kwargs)
    robustness = predicate_interpreter(
        predicate=predicate,
        input_tape=recorded_tape,
        max_value=max_value,
        custom_functions=custom_functions,
        custom_filters=custom_filters,
    )
    return robustness.item()


def _evaluate_candidates(
        solutions: list[list[float]],
        predicate: Predicate,
        f: Callable[..., Any],
        max_value: torch.Tensor,
        custom_functions: dict[str, CustomFunction],
        custom_filters: dict[str, CustomFilter],
        parameters_shape: list[int],
        candidate_processor: Callable[[Parameters], Parameters],
        executor: Union[ProcessPoolExecutor, None],
        args: tuple,
        kwargs: dict[str, Any],
        ) -> list[float]:
    """Return the cost (negated robustness) of each raw CMA-ES candidate.

    Each candidate is reshaped to `parameters_shape` and passed through
    `candidate_processor`. It is then flattened back to native floats, so
    it can be sent across process boundaries, and scored with
    `get_robustness`. Scoring happens inline, or on `executor` when one is
    given. The returned list is aligned with `solutions`.
    """

    def prepare(solution: list[float]) -> tuple[list[float], list[int]]:
        # Apply the processing function; the processed candidate may
        # have a different shape, so its shape is returned as well.
        processed = candidate_processor(
            torch.tensor(solution).reshape(parameters_shape)
        )
        return processed.flatten().tolist(), list(processed.shape)

    # Arguments are passed positionally so that the runtime `*args` can
    # follow them; mixing keywords with `*args` would misbind parameters.
    if executor is None:
        robustnesses = []
        for solution in solutions:
            flat, shape = prepare(solution)
            robustnesses.append(get_robustness(
                predicate, flat, f, max_value, custom_functions,
                custom_filters, shape, *args, **kwargs,
            ))
    else:
        # Submit every candidate first so the pool evaluates them in
        # parallel, then collect the results in submission order.
        futures = []
        for solution in solutions:
            flat, shape = prepare(solution)
            futures.append(executor.submit(
                get_robustness, predicate, flat, f, max_value,
                custom_functions, custom_filters, shape, *args, **kwargs,
            ))
        robustnesses = [future.result() for future in futures]
    return [-robustness for robustness in robustnesses]


def solver(
        predicate: Predicate,
        starting_parameters: Parameters,
        f: Callable[Concatenate[Parameters, RuntimeData], K],
        max_value: torch.Tensor,
        custom_functions: dict[str, CustomFunction],
        custom_filters: dict[str, CustomFilter],
        initial_stdev: float,
        max_f_eval_n: int,
        verbose: bool,
        multiprocessing_workers: Union[int, None],
        opts: Union[cma.evolution_strategy.CMAOptions, None],
        candidate_processor: Callable[[Parameters], Parameters],
        *args: RuntimeData.args,
        **kwargs: RuntimeData.kwargs,
        ) -> tuple[Parameters, K]:
    """Search for parameters satisfying `predicate` using CMA-ES.

    CMA-ES starts from the flattened `starting_parameters` with step size
    `initial_stdev`. `opts` is given to `CMAEvolutionStrategy` as its
    options dictionary. Each generation's candidates are evaluated
    sequentially, or on a `ProcessPoolExecutor` with
    `multiprocessing_workers` workers when that argument is not `None`.
    In the pooled case, `f`, `predicate`, the custom functions/filters and
    the runtime arguments must be picklable. A candidate's cost is its
    negated robustness, and the first candidate with a strictly negative
    cost is accepted.

    `candidate_processor` is applied to every candidate solution before
    function evaluation.

    `max_value` is the maximum value in the predicate interpreter.

    `max_f_eval_n` bounds the number of candidate evaluations. Whole
    generations are always evaluated, so the final count may exceed the
    budget by less than one population size.

    Returns the solution (without applying `candidate_processor`) and the
    output of `f` on the processed solution, i.e.
    `f(candidate_processor(solution), *args, **kwargs)`.

    Raises `SolverFailedException` if no solution was found within the
    budget or before CMA-ES stopped. Its `final_parameters` field holds
    the lowest-cost candidate seen (also without `candidate_processor`
    applied), or `None` if no candidate was evaluated.
    """
    # Initialize CMA-ES on the flattened starting point
    x = starting_parameters.detach().clone().flatten()
    es = cma.CMAEvolutionStrategy(x, initial_stdev, opts)

    parameters_shape = list(starting_parameters.shape)
    executor = (
        ProcessPoolExecutor(max_workers=multiprocessing_workers)
        if multiprocessing_workers is not None
        else None
    )
    eval_n = 0
    best_cost = math.inf
    best_solution = None
    try:
        # CMA-ES ask/evaluate/tell loop
        while eval_n < max_f_eval_n and not es.stop():
            # Sample solutions in native Python floats
            solutions = [solution.tolist() for solution in es.ask()]
            costs = _evaluate_candidates(
                solutions, predicate, f, max_value, custom_functions,
                custom_filters, parameters_shape, candidate_processor,
                executor, args, kwargs,
            )
            # The whole generation has been evaluated, so count all of it
            # and check every candidate for success, not just those within
            # the budget.
            eval_n += len(costs)

            for solution, cost in zip(solutions, costs):
                if cost < best_cost:
                    best_cost = cost
                    best_solution = solution
                if cost < 0.0:
                    p = torch.tensor(solution).reshape(parameters_shape)
                    pp = candidate_processor(p)
                    if verbose:
                        print("CMA-ES solver: success!")
                    return p, f(pp, *args, **kwargs)

            # Update the search distribution (mean, step size, covariance)
            es.tell(solutions, costs)

            if verbose:
                es.disp()
    finally:
        # Release worker processes on success, failure and exceptions
        # alike; any still-pending futures are pointless at this point.
        if executor is not None:
            executor.shutdown(cancel_futures=True)

    if verbose:
        if eval_n >= max_f_eval_n:
            print(f"CMA-ES solver: max eval reached ({eval_n} >= {max_f_eval_n})")
        print(f"CMA-ES exit without solution: {es.stop()}")
    final_parameters = (
        torch.tensor(best_solution).reshape(parameters_shape)
        if best_solution is not None
        else None
    )
    raise SolverFailedException(
        "Ran out of budget and did not find solution!",
        final_parameters=final_parameters,
    )
