"""Use pylic to design a maze that takes a marble to touch all target
platforms."""
from pathlib import Path
from examples.marble_drop.environment import Puzzle
from examples.marble_drop.environment import Action
from examples.marble_drop.environment import simulation
from examples.marble_drop.environment import ButtonRow
from examples.marble_drop.environment import get_puzzle_button_n
from examples.marble_drop.trajectory_pylic import FPS_PER_ACTION
from examples.marble_drop.trajectory_pylic import pylic_cma_solver as pylic_cma
from examples.marble_drop.cma_facade import cma_facade as pure_cma
from examples.marble_drop.trajectory_cem import custom_cem_solver
from examples.plotting import Results
from examples.plotting import QualityResults
from examples.plotting import plot_results
from examples.plotting import plot_quality_results
from collections import defaultdict
from functools import reduce
from itertools import product
import traceback
import json
import time
import random

# Experiment parameters
worker_n = 48
solver_timeout_s = 60 * 60  # one hour allowed per task
PuzzleKey = list[int]  # A puzzle is encoded as a list of integers
experiment_timeout_s = 16 * 60 * 60  # sixteen hours for the whole experiment

# Task generation parameters
min_row_size, max_row_size = 3, 5
min_row_n, max_row_n = 1, 3

# Numerical search parameters
numerical_search_random_restart_n = 2

quick = False  # Flip to True for a quick debug run
if quick:
    # Debug mode: smaller row ranges and a one-minute solver timeout
    min_row_size, max_row_size = 1, 2
    min_row_n, max_row_n = 1, 1
    solver_timeout_s = 60


def get_difficulty(puzzle_key: "PuzzleKey") -> int:
    """Return a measure of difficulty of the puzzle.

    The difficulty is the product of ``n + 1`` over every entry ``n``
    of the key, so it grows exponentially with the number of entries.
    An empty key yields 1 (the empty product).

    Args:
        puzzle_key: The list of integers representing the puzzle.

    Returns:
        The product of ``n + 1`` over all entries of ``puzzle_key``.
    """
    # Fold from an explicit initial value of 1 and only add 1 to the
    # incoming entry. Adding 1 to the accumulator as well would inflate
    # keys with three or more entries and leave single-entry keys
    # unchanged.
    return reduce(lambda acc, n: acc * (n + 1), puzzle_key, 1)


def get_puzzle(puzzle_key) -> Puzzle:
    """Build the puzzle encoded by ``puzzle_key``.

    One ``ButtonRow`` is created per entry of the key, in key order,
    and each entry is passed as the sole argument to ``ButtonRow``.
    """
    # NOTE(review): each entry is presumably the row's button count;
    # confirm against ButtonRow.
    return list(map(ButtonRow, puzzle_key))


def get_quality(actions: list[Action], puzzle: Puzzle) -> float:
    """Score ``actions`` on ``puzzle`` as the fraction of buttons touched.

    The simulation is run for ``FPS_PER_ACTION`` timesteps per action,
    plus one extra action's worth of timesteps. ``None`` is passed in
    the slot where the main script supplies an animation output path.
    """
    step_count = int(FPS_PER_ACTION * (len(actions) + 1))
    touched = simulation(actions, puzzle, None, step_count)
    # Clamp the denominator to at least 1 so a puzzle reporting zero
    # buttons cannot cause a division by zero.
    button_total = max(1, get_puzzle_button_n(puzzle))
    return len(touched) / button_total


if __name__ == "__main__":
    # Benchmark driver: run every solver on every task, record solve
    # times and solution quality as JSON, render an animation for each
    # solved task, and finally plot the aggregated results.

    # All outputs go under ./marble_drop, relative to the working
    # directory. mkdir() is called without exist_ok, so it raises
    # FileExistsError if that directory is already present.
    output_path = Path()/"marble_drop"
    output_path.mkdir()

    experiment_start_t = time.time()

    # Instantiate tasks
    tasks: list[PuzzleKey] = [
        [1],
        [2],
        [3],
        [4],
        [2, 2],
        [3, 2],
        [3, 2, 1],
    ]

    # Shuffle tasks in case we are not able to go through all of them
    # before timeout (currently disabled)
    # random.shuffle(tasks)

    # List of solvers to benchmark, as (display name, callable) pairs
    solvers = [
        ("Pylic[CMA-ES]", pylic_cma),
        ("CMA-ES", pure_cma),
        ("MPC[CEM]", custom_cem_solver),
    ]

    # Track solver statistics.
    # time_results maps difficulty -> solver name -> list of solve
    # times, with None recorded for an unsuccessful run.
    time_results: Results = defaultdict(lambda: defaultdict(list))
    # quality_results holds one tuple per returned solution:
    # (difficulty, solver name, time, quality, experiment id).
    quality_results: QualityResults = list()

    # Write current results regularly
    time_results_path = output_path/"results.json"
    quality_results_path = output_path/"quality_results.json"

    def write_results(results: (Results | QualityResults), path: Path):
        """Dump ``results`` as indented JSON to ``path`` and log it."""
        with open(path, "wt") as fp:
            json.dump(results, fp, indent=2)
        print(f"Wrote {path}")

    # Benchmark each solver
    experiment_id = 0
    for i, puzzle_key in enumerate(tasks):
        # The experiment budget is only checked between tasks, so a
        # task that has started always runs with every solver.
        if time.time()-experiment_start_t > experiment_timeout_s:
            break

        # Measure puzzle difficulty (identical for every solver)
        difficulty = get_difficulty(puzzle_key)

        for solver_id, solver in solvers:
            experiment_id += 1

            # Create a directory for this solver's results
            solver_output_path = output_path/solver_id
            solver_output_path.mkdir(exist_ok=True)

            # Execute solver. It returns a sequence of
            # (actions, elapsed time) pairs and a success flag.
            solutions, is_successful = solver(
                puzzle=get_puzzle(puzzle_key),
                timeout_s=solver_timeout_s,
                worker_n=worker_n,
                random_restart_n=numerical_search_random_restart_n,
            )

            if is_successful:

                # The last returned pair is taken as the final solution
                actions, total_t = solutions[-1]

                # Log experiment time
                time_results[difficulty][solver_id].append(total_t)
                write_results(time_results, time_results_path)

                # Plot solution
                output_animation = solver_output_path/f"puzzle_{i}_{str(puzzle_key)}.mp4"
                timestep_n = int(FPS_PER_ACTION*(len(actions)+1))
                simulation(
                    actions,
                    get_puzzle(puzzle_key),
                    output_animation,
                    timestep_n,
                )
                print(f"Wrote {output_animation}")

            else:
                # Record the failure and persist it right away so it
                # is not lost if a later run crashes. No exception is
                # being handled here, so there is no traceback to
                # print; report the failure explicitly instead.
                time_results[difficulty][solver_id].append(None)
                write_results(time_results, time_results_path)
                print(f"{solver_id} did not solve puzzle {i} {puzzle_key}")

            # Log quality of each returned solution
            for parameters, total_t in solutions:
                quality = get_quality(
                    parameters,
                    get_puzzle(puzzle_key),
                )
                quality_result = (
                    difficulty,
                    solver_id,
                    total_t,
                    quality,
                    experiment_id,
                )
                quality_results.append(quality_result)
            write_results(quality_results, quality_results_path)

    # Write experiment results
    write_results(time_results, time_results_path)
    write_results(quality_results, quality_results_path)

    # Plot time results
    output_plots_dir = output_path/"figures_time"
    output_plots_dir.mkdir()
    plot_results(time_results, output_plots_dir)
    print(f"Wrote {output_plots_dir}")

    # Plot quality results
    output_quality_plots_dir = output_path/"figures_quality"
    output_quality_plots_dir.mkdir()
    plot_quality_results(quality_results, output_quality_plots_dir)
