"""An example of using pylic over a marble simulation for trajectory
optimization over a maze.
This script benchmarks against pure gradient descent and CMA-ES.

The script takes a long time to run (20 hours on a 20-core CPU).
"""

import os
os.environ['OPENBLAS_NUM_THREADS'] = '1'

from pathlib import Path
from tasks import get_tasks
from tasks import RandomTaskImpulseParameters
from tasks import TaskImpulse as Task
from trajectory_dynamics import get_trajectory
from plotting import plot_animation
from examples.plotting import plot_results
from examples.plotting import plot_quality_results
from examples.plotting import Results
from collections import defaultdict
import time
import torch
import json
import traceback
from trajectory_pylic import pylic_grad_solve
from trajectory_pylic import pylic_cma_solve
from trajectory_pylic import marble_grad_solver
from trajectory_pylic import marble_cma_solver
from trajectory_pylic import get_reach_goal_predicate
from policy_pylic import pylic_evotorch_solve
from rl_baseline import custom_maze_rl_solver
from trajectory_baseline import custom_cem_solver
from stable_baselines3 import SAC
from examples.plotting import SolverTime
from examples.plotting import QualityResults
from pylic.predicates import SolverFailedException
from tasks import get_quality
import traceback

# When running on CPU (as this script does), PyTorch uses threads to
# distribute matrix operations. We do not want that here, because the
# benchmark is strictly restricted to the single-core setting.
# https://pytorch.org/docs/stable/generated/torch.set_num_threads.html
torch.set_num_threads(1)


# Solver outputs are tensors; the alias keeps the solver signatures readable.
Parameters = torch.Tensor
episode_timestep_n = 50
# Forwarded to the solvers as their `worker_n` argument.
worker_n = None

# RL
rl_train_timestep_n = 1_000_000
rl_timestep_train_freq_n = 1000
rl_timeout_check_freq = 10  # check solution every this number of parallel time-steps
rl_grad_steps_per_train = 10  # -1 is an update for each collected transition
rl_log_freq = 1000  # in time-steps

# CMA
cma_max_f_eval_n = 3000

# Gradient descent
grad_random_restart_n = 4

# CEM
cem_sample_n = 16
cem_elite_n = 4
cem_horizon_len = 10
cem_init_stdev = 0.1
cem_inner_iter_n = 8

# Experiment parameters
tasks_per_obstacle = 1
attempts_per_task = 1
max_planning_time_s = 60*60*3  # 3 hours
experiment_timeout_s = 60*60*24  # 24 hours

# Evotorch
evotorch_max_generations_per_episode = 100

# Set to true to make the script finish quickly
quick_debug = False
if quick_debug:
    # Every name assigned here must match a constant defined above exactly;
    # a misspelled name would only create a new, unused global and leave the
    # full-size setting in effect.
    rl_train_timestep_n = 100
    rl_timestep_train_freq_n = 30
    rl_grad_steps_per_train = 2
    rl_log_freq = 30
    cma_max_f_eval_n = 100
    grad_random_restart_n = 2
    tasks_per_obstacle = 1
    attempts_per_task = 2
    evotorch_max_generations_per_episode = 3
    max_planning_time_s = 10
    experiment_timeout_s = 60


def rl_solver(
        task: Task,
        log_path: Path,
        seed: int,
        ) -> tuple[list[tuple[SolverTime, Parameters]], bool]:
    """Solve the task with SAC on the RL environment custom made for the maze.

    `log_path` is part of the common solver signature but is not used here.
    Returns the solver log and whether the task was solved, exactly as
    reported by `custom_maze_rl_solver`.
    """
    def make_sac(env) -> SAC:
        # Factory handed to the RL solver; it builds a CPU-only SAC agent
        # seeded with this attempt's seed.
        sac_options = dict(
            use_sde=True,
            verbose=2,
            seed=seed,
            device='cpu',
        )
        return SAC("MlpPolicy", env, **sac_options)

    solver_options = dict(
        timeout_s=max_planning_time_s,
        worker_n=worker_n,
        timeout_check_freq=rl_timeout_check_freq,
        verbose=True,
    )
    solver_log, solved = custom_maze_rl_solver(
        task,
        task.max_timesteps,
        rl_train_timestep_n,
        make_sac,
        **solver_options,
    )
    return solver_log, solved


def cem_solver(
        task: Task,
        log_path: Path,
        seed: int,
        ) -> tuple[list[tuple[SolverTime, Parameters]], bool]:
    """Solve the task with the custom CEM solver.

    `log_path` and `seed` are part of the common solver signature but are
    not forwarded to `custom_cem_solver`. Returns the solver log and
    whether the task was solved.
    """
    # This solver is only run without parallel workers.
    assert worker_n in (None, 1)
    cem_options = dict(
        sample_n=cem_sample_n,
        elite_n=cem_elite_n,
        horizon_len=cem_horizon_len,
        init_stdev=cem_init_stdev,
        cem_inner_iter_n=cem_inner_iter_n,
        timeout_s=max_planning_time_s,
        verbose=True,
    )
    solver_log, solved = custom_cem_solver(task, **cem_options)
    return solver_log, solved


def pylic_cma_solver(
        task: Task,
        log_path: Path,
        seed: int,
        ) -> tuple[list[tuple[SolverTime, Parameters]], bool]:
    """Solve the task using Pylic with CMA, retrying until the time budget
    `max_planning_time_s` is used up.

    Returns the merged log of all rounds (times shifted to be relative to
    the start of this call) and whether some round succeeded.
    """
    history = []
    began_at = time.time()
    while time.time() - began_at < max_planning_time_s:
        # Offset of this round within the overall attempt, and the time
        # budget still available to it.
        round_offset_s = time.time() - began_at
        budget_s = max_planning_time_s - (time.time() - began_at)
        round_log, solved = pylic_cma_solve(
            task=task,
            episode_timestep_n=episode_timestep_n,
            cma_max_f_eval_n=cma_max_f_eval_n,
            worker_n=worker_n,
            seed=seed,
            timeout_s=budget_s,
            log_path=log_path,
        )
        history.extend(
            (round_t + round_offset_s, params) for round_t, params in round_log
        )
        if solved:
            return history, True
    return history, False


def pylic_grad_solver(
        task: Task,
        log_path: Path,
        seed: int,
        ) -> tuple[list[tuple[SolverTime, Parameters]], bool]:
    """Solve the task using Pylic with gradient descent, retrying until the
    time budget `max_planning_time_s` is used up.

    Returns the merged log of all rounds (times shifted to be relative to
    the start of this call) and whether some round succeeded.
    """
    history = []
    began_at = time.time()
    while time.time() - began_at < max_planning_time_s:
        # Offset of this round within the overall attempt, and the time
        # budget still available to it.
        round_offset_s = time.time() - began_at
        budget_s = max_planning_time_s - (time.time() - began_at)
        round_log, solved = pylic_grad_solve(
            task=task,
            episode_timestep_n=episode_timestep_n,
            random_restart_n=grad_random_restart_n,
            worker_n=worker_n,
            seed=seed,
            timeout_s=budget_s,
            log_path=log_path,
        )
        history.extend(
            (round_t + round_offset_s, params) for round_t, params in round_log
        )
        if solved:
            return history, True
    return history, False


def grad_solver(
        task: Task,
        log_path: Path,
        seed: int,
        ) -> tuple[list[tuple[SolverTime, Parameters]], bool]:
    """Solve the given task using plain gradient descent.

    The whole trajectory is optimized at once; failed runs are retried
    until `max_planning_time_s` has elapsed. `log_path` is part of the
    common solver signature but is not used here.
    """
    began_at = time.time()

    def elapsed_s() -> float:
        return time.time() - began_at

    # All-zero (1, 2) parameter tensor used as the starting point
    initial_guess = torch.zeros((1, 2))
    history = [(elapsed_s(), initial_guess)]

    while elapsed_s() < max_planning_time_s:
        try:
            solution = marble_grad_solver(
                predicate=get_reach_goal_predicate(task),
                starting_parameters=initial_guess,
                initial_state=task.initial_state,
                episode_timestep_n=task.max_timesteps,
                random_restart_n=grad_random_restart_n,
                seed=seed,
                worker_n=worker_n,
            )
        except SolverFailedException as failure:
            traceback.print_exc()
            # Keep the parameters the failed run ended on, when it has any
            if (partial := failure.final_parameters) is not None:
                history.append((elapsed_s(), partial))
        else:
            history.append((elapsed_s(), solution))
            return history, True
    return history, False


def cma_solver(
        task: Task,
        log_path: Path,
        seed: int,
        ) -> tuple[list[tuple[SolverTime, Parameters]], bool]:
    """Solve the given task using CMA-ES.

    The solver is re-run from an all-zero trajectory until it succeeds or
    `max_planning_time_s` has elapsed, mirroring `grad_solver`.

    Args:
        task: Task to solve.
        log_path: Unused; part of the common solver signature.
        seed: Seed forwarded to `marble_cma_solver`.

    Returns:
        A `(log, is_solved)` pair. `log` holds `(elapsed seconds,
        parameters)` entries: the starting parameters of every run, the
        final parameters of failed runs (when available) and, on success,
        the solution as last entry.
    """
    log = list()

    start_t = time.time()
    while time.time() - start_t < max_planning_time_s:
        starting_parameters = torch.zeros((task.max_timesteps, 2))
        log.append((time.time()-start_t, starting_parameters))
        try:
            parameters = marble_cma_solver(
                predicate=get_reach_goal_predicate(task),
                starting_parameters=starting_parameters,
                initial_state=task.initial_state,
                cma_max_f_eval_n=cma_max_f_eval_n,
                episode_timestep_n=task.max_timesteps,
                worker_n=worker_n,
                seed=seed,
            )
            log.append((time.time()-start_t, parameters))
            return log, True
        except SolverFailedException as e:
            traceback.print_exc()
            final_parameters = e.final_parameters
            if final_parameters is not None:
                log.append((time.time()-start_t, final_parameters))
    # Report failure only once the time budget is exhausted. Returning from
    # inside the loop would stop after the first failed run, and would make
    # the function return None when the loop body never executes.
    return log, False


def pylic_cosyne_solver(
        task: Task,
        log_path: Path,
        seed: int,
        ) -> tuple[list[tuple[SolverTime, Parameters]], bool]:
    """Solve the given task using Pylic and CoSyNE (single run, no retries).

    `log_path` is part of the common solver signature but is not used here.
    On success the log holds one `(elapsed seconds, parameters)` entry; on
    failure it is empty.
    """
    began_at = time.time()
    try:
        solution = pylic_evotorch_solve(
            task=task,
            seed=seed,
            episode_timestep_n=episode_timestep_n,
            worker_n=worker_n,
            max_generations_per_episode=evotorch_max_generations_per_episode,
            timeout_s=max_planning_time_s,
        )
    except SolverFailedException:
        traceback.print_exc()
        return [], False
    else:
        return [(time.time() - began_at, solution)], True


ObstacleNumber = int
SolverID = str
SolverTime = int | None


def _save_animation(
        parameters: Parameters,
        task: Task,
        output_animation: Path,
        ) -> None:
    """Roll out `parameters` from the task's initial state and save the
    resulting trajectory as an animation at `output_animation`."""
    plot_animation(
        states=get_trajectory(parameters, task.initial_state),
        fps=60,
        goal_position=tuple(task.goal_circle.position.tolist()),
        goal_radius=task.goal_circle.radius,
        output_path=output_animation,
    )
    print(f"Wrote {output_animation}")


def main():
    """Run the benchmark comparing the solvers enabled in `solvers` below.

    Command-line arguments (all required):
        --maze_len: maze length passed to the task generator.
        --experiment_seed: seed used to derive the task-generation seed.
        --output_dir: directory to create for all outputs. It is created
            with a plain `Path.mkdir()`, so it must not exist yet and its
            parent must exist.

    For every task and solver, the solve time (or None when unsolved) and
    the quality of every logged parameter set are recorded. The JSON result
    files are rewritten after each solver run, so partial results survive
    an interrupted experiment. Plots are produced at the end.
    """
    # Parse arguments
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--maze_len',
        type=int,
        required=True,
    )
    parser.add_argument(
        '--experiment_seed',
        type=int,
        required=True,
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        required=True,
    )
    args = parser.parse_args()
    maze_lens = [args.maze_len]
    experiment_seed = args.experiment_seed

    output_path = args.output_dir
    output_path.mkdir()

    # Create tasks
    tasks = list()
    for maze_len in maze_lens:
        task_parameters = RandomTaskImpulseParameters(
            goal_radius=0.2,
            marble_radius=0.05,
            segment_radius=0.03,
            dt=0.1,
            impulse_scale=0.4,
            drag_constant=0.3,
            coefficient_of_restitution=0.99,
            corridor_width=0.4,
            maze_len=maze_len,
            )
        new_tasks = get_tasks(
            tasks_per_obstacle,
            task_parameters,
            f"seed_{maze_len}_{experiment_seed}",
            )
        tasks.extend([(maze_len, task) for task in new_tasks])

    # Try each task multiple times
    original_tasks = list(tasks)
    for _ in range(attempts_per_task-1):
        tasks.extend(original_tasks)

    # Write results as JSON
    def write_results(results: (Results | QualityResults | list[str]), path: Path):
        with open(path, "wt") as fp:
            json.dump(results, fp, indent=2)
        print(f"Wrote {path}")

    time_results_path = output_path/"results.json"
    quality_results_path = output_path/"quality_results.json"
    fatal_errors = list()
    fatal_errors_path = output_path/"fatal_errors.json"

    # (solver id, solver function) pairs; commented entries are disabled.
    solvers = [
        ("Pylic[GD]", pylic_grad_solver),
        #pylic_cma_solver,
        #grad_solver,
        #cma_solver,
        ("SAC", rl_solver),
        ("MPC[CEM]", cem_solver),
        #pylic_cosyne_solver,
    ]

    # Run tasks
    experiment_start_t = time.time()
    results: Results = defaultdict(lambda: defaultdict(list))
    quality_results = list()
    experiment_id = 0

    for i, (maze_len, task) in enumerate(tasks):
        # The experiment budget is only checked between tasks
        if time.time() - experiment_start_t > experiment_timeout_s:
            break
        # Create task directory
        local_output_path = output_path/f"task_{i}"
        local_output_path.mkdir()

        # Render a demo of the task under a constant (1, 1) input
        demo_parameters = torch.tensor([1.0, 1.0]).expand(
            task.max_timesteps, -1
            )
        _save_animation(demo_parameters, task, local_output_path/"demo.mp4")

        # Benchmark each solver
        for solve_id, solve in solvers:
            experiment_id += 1
            log_path = local_output_path/f"{solve_id}_{i}_log"
            log_path.mkdir()

            # Try to solve with current solver
            solve_start_t = time.time()
            try:
                log, is_solved = solve(task, seed=i, log_path=log_path)
            except Exception:
                # Keep the experiment going: record the failure and treat
                # the attempt as unsolved.
                print(traceback.format_exc())
                print("THIS SHOULD NEVER HAPPEN!!!!")
                fatal_errors.append(f"Fatal error on {solve_id} i={i}")
                log = list()
                is_solved = False
            # Stop the clock right away so that the quality evaluation below
            # is not counted as solver time.
            solve_elapsed_s = time.time() - solve_start_t

            # Always start with 0 quality
            quality_results.append((
                maze_len,
                solve_id,
                0.0,  # time
                0.0,  # quality
                experiment_id,
            ))

            # Log quality results
            for (solver_time, parameters) in log:
                quality_results.append((
                    maze_len,
                    solve_id,
                    solver_time,
                    get_quality(parameters, task),
                    experiment_id,
                ))

            # Log time result
            if is_solved:
                # Store solver time
                results[maze_len][solve_id].append(solve_elapsed_s)

                # Save animation if successful; the last log entry holds the
                # solution. An empty log cannot be animated, so skip it
                # rather than crash the whole experiment.
                if log:
                    _save_animation(
                        log[-1][1], task, log_path/f"{solve_id}.mp4"
                    )
            else:
                results[maze_len][solve_id].append(None)

            # Update results
            write_results(results, time_results_path)
            write_results(quality_results, quality_results_path)
            write_results(fatal_errors, fatal_errors_path)

    # Update results (also covers the case where no solver was run)
    write_results(results, time_results_path)
    write_results(quality_results, quality_results_path)
    write_results(fatal_errors, fatal_errors_path)

    # Write results as plot
    plot_dir = output_path/"results_figs"
    plot_dir.mkdir()
    plot_results(results, plot_dir)

    # Plot quality results
    output_quality_plots_dir = output_path/"figures_quality"
    output_quality_plots_dir.mkdir()
    plot_quality_results(quality_results, output_quality_plots_dir)

    # Output total time
    end_t = time.time()
    elapsed_s = end_t - experiment_start_t
    print(f"Total experiment wall-clock time (s): {elapsed_s}")


if __name__ == "__main__":
    main()
