"""Facade on Pylic for solving marble drop puzzles."""
from examples.marble_drop.environment import Puzzle
from examples.marble_drop.environment import Action
from examples.marble_drop.environment import simulation
from examples.marble_drop.environment import get_puzzle_button_n
from examples.plotting import SolverTime
from pylic.tape import Tape
import pylic.predicates
from pylic.predicates import Predicate
from pylic.predicates import Constants
from pylic.predicates import SolverFailedException
from pylic.solvers.cma_es import solver as cma_solver
from examples.marble_drop.environment import get_puzzle_width
from examples.marble_drop.environment import get_puzzle_height
from examples.marble_drop.environment import BUTTON_WIDTH
from examples.marble_drop.environment import get_pymunk_space
from examples.marble_drop.trajectory_pylic import FPS_PER_ACTION
import torch
import cma
import time
import random
import pymunk


def get_score(tape: Tape) -> float:
    """Return the fraction of the puzzle's buttons pressed by the actions.

    The actions (``"raw_actions"``) and the puzzle (``"puzzle"``) are read
    from the program state of the first tape entry and replayed through
    ``simulation``. The number of items ``simulation`` returns is divided
    by the puzzle's button count; that count is clamped to at least 1 so a
    puzzle without buttons yields 0.0 rather than a division by zero.
    """
    state = tape[0].program_state
    actions = state["raw_actions"]
    layout = state["puzzle"]

    # Simulate one action's worth of frames more than the actions given.
    frame_budget = int(FPS_PER_ACTION*(len(actions)+1))
    hit_buttons = simulation(actions, layout, None, frame_budget)

    button_total = max(1, get_puzzle_button_n(layout))
    return len(hit_buttons)/button_total


def cma_facade(
        puzzle: Puzzle,
        timeout_s: float,
        worker_n: int,
        random_restart_n: int,
        ) -> tuple[list[tuple[list[Action], SolverTime]], bool]:
    """Use pylic's CMA-ES solver to search for actions solving the puzzle.

    One action (four floats, initialised to zero) is allotted per button
    of the puzzle. The solver is restarted in batches of
    ``random_restart_n`` runs until either a run succeeds or ``timeout_s``
    seconds have elapsed. Each run is handed only the time that is still
    left, so the overall call honours ``timeout_s`` instead of granting
    every restart a fresh full budget.

    Args:
        puzzle: The puzzle to solve.
        timeout_s: Overall time budget, in seconds, for the whole call.
        worker_n: Number of multiprocessing workers handed to the solver;
            the CMA population size is twice this number.
        random_restart_n: Number of solver runs per batch. If it is not
            positive no run can be attempted and the call returns
            immediately without a solution.

    Returns:
        A pair ``(log, solved)``. ``log`` is a list of
        ``(actions, elapsed_seconds)`` entries: an initial empty
        candidate, then for every run its starting parameters followed by
        the parameters the run ended with. ``solved`` is True when a run
        returned without raising ``SolverFailedException`` and False when
        the time budget ran out first.
    """
    button_n = get_puzzle_button_n(puzzle)
    max_actions = button_n
    timestep_n = int(FPS_PER_ACTION*max_actions)
    action_dim = 4
    initial_stdev = 1.0

    # A monotonic clock keeps the deadline immune to wall-clock jumps.
    start_t = time.monotonic()

    def elapsed_s() -> float:
        return time.monotonic()-start_t

    # Keep track of candidate solutions
    log = [([], elapsed_s())]

    # Without restarts nothing can be attempted; return right away rather
    # than busy-waiting until the timeout expires.
    if random_restart_n <= 0:
        return log, False

    # Optimization predicate built from 0.99999 and the get_score call.
    # NOTE(review): presumably satisfied when 0.99999 < get_score, i.e.
    # every button is pressed - confirm against LessThan's argument order.
    predicate = pylic.predicates.LessThan(
        torch.tensor(0.99999),
        pylic.predicates.FunctionCall(
            "get_score",
            trace=Constants.input_trace
        ),
    )

    def preprocessor(p: torch.Tensor) -> torch.Tensor:
        # Candidates are used as-is.
        return p

    while elapsed_s() <= timeout_s:
        for _ in range(random_restart_n):
            # Re-check the budget before every run, not only per batch.
            remaining_s = timeout_s-elapsed_s()
            if remaining_s <= 0:
                return log, False

            # Every run starts from all-zero actions.
            starting_parameters = [
                [0.0]*action_dim
                for _ in range(max_actions)
            ]
            log.append((starting_parameters, elapsed_s()))

            try:
                result, _unused = cma_solver(
                    predicate=predicate,
                    starting_parameters=torch.tensor(starting_parameters),
                    f=simulation,
                    max_value=torch.tensor(10.0),
                    custom_functions=dict(get_score=get_score),
                    custom_filters=dict(),
                    initial_stdev=initial_stdev,
                    max_f_eval_n=100000,
                    verbose=True,
                    multiprocessing_workers=worker_n,
                    opts=cma.evolution_strategy.CMAOptions(
                        popsize=worker_n*2,
                        # Only the time still left, so that the restarts
                        # together cannot exceed the overall budget.
                        timeout=remaining_s,
                        tolflatfitness=2,
                    ),
                    candidate_processor=preprocessor,
                    puzzle=puzzle,
                    timestep_n=timestep_n,
                    output_animation_path=None,
                )
            except SolverFailedException as e:
                # Record where the failed run ended, then try again.
                actions = preprocessor(e.final_parameters).tolist()
                log.append((actions, elapsed_s()))
            else:
                actions = preprocessor(result).tolist()
                log.append((actions, elapsed_s()))
                return log, True
    return log, False
