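"""A simple example of using pylic for exploiting program structure in a
locomotion task.

The task is to reach a target location, which requires pushing buttons (in
the order given by a password) to remove obstacles. This structure is
reflected in the simulator. We use the simulator code to instruct pylic how
to optimize parameters.

The solution is found by iteratively solving the conjunction of "push button
AND reach the goal".

This script benchmarks the pylic-based solver against a pretrained-RL
baseline and a CEM baseline on every password over up to three buttons.
Timing and solution-quality results, animations and plots are written to
./planning_buttons.
"""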
from pathlib import Path
from plotting import plot_animation
from examples.plotting import plot_results
from examples.plotting import plot_quality_results
from examples.plotting import Results
from examples.plotting import QualityResults
from examples.plotting import SolverTime
from collections import defaultdict
from rl_baseline import rl_tight_guidance_solver
from trajectory_pylic import pylic_cma_solver
from trajectory_baseline import cem_solver
from pylic.predicates import SolverFailedException
from quality import get_quality
import traceback
import torch
import random
import time
import json
import argparse
import itertools


# Type of the solution returned by every solver.
Parameters = torch.Tensor

# Shared task / simulator settings.
actuator_n = 8
max_episodes_per_predicate = 2
worker_n = 32
# Time budget in seconds passed to the CEM and RL solvers: 960 minutes
# (previously 240 minutes).
timeout_s = 60*960
sub_step_s = 0.1  # length of one simulation timestep, in seconds
# Timesteps per sub-goal episode: 2 simulated seconds (previously 4 s).
episode_timestep_n = int((1/sub_step_s)*2)

# CMA
cma_max_f_eval_n = 16000  # passed to pylic_cma_solver as max_f_eval_n

# RL
rl_train_timestep_n = 1_000_000_000
rl_episode_timestep_n = int((1/sub_step_s)*12)  # 12 seconds for entire RL episode
rl_grad_steps_per_train = 1
rl_timestep_train_freq_n = 1
rl_eval_freq = 5000
# Size of the "password so far" encoding used by the RL and CEM solvers; it
# must be at least the largest button count (checked below).
rl_password_so_far_encoding_size = 5


# CEM
cem_horizon_len = 10
cem_init_stdev = 0.2
cem_inner_iter_n = 16
cem_sample_n = worker_n*2
cem_elite_n = 8


# Evaluate on all (full) tasks up to length 3, including the empty password:
# every ordering of buttons 0..button_n-1 for button_n in 1..3.
tasks: list[tuple[int, tuple[int, ...]]] = [(0, tuple())]
tasks += [
    (button_n, password)
    for button_n in [1, 2, 3]
    for password in itertools.permutations(range(button_n))
]
task_iteration_n = 1  # number of times to evaluate each task
max_button_n = max(button_n for button_n, _ in tasks)
# Explicit check instead of `assert`, which is stripped under `python -O`.
if max_button_n > rl_password_so_far_encoding_size:
    raise ValueError(
        f"max_button_n={max_button_n} exceeds "
        f"rl_password_so_far_encoding_size={rl_password_so_far_encoding_size}"
    )

# Repeat tasks: the extra copies are appended after all original tasks.
tasks.extend([task for task in tasks for _ in range(task_iteration_n-1)])

# Set to true to make the script finish early
quick = False
if quick:
    worker_n = 3
    cma_max_f_eval_n = 30
    rl_train_timestep_n = 100
    timeout_s = 10
    # cem_sample_n was derived from the full worker_n above; re-derive it so
    # the CEM sample count follows the reduced worker count in quick mode.
    cem_sample_n = worker_n*2


def pylic_cma(
        target_password: tuple[int, ...],
        actuator_n: int,
        button_n: int,
        ) -> tuple[list[tuple[SolverTime, Parameters]], bool]:
    """Run ``pylic_cma_solver`` on one button-password task.

    Returns:
        ``(solutions, is_successful)``. On success, ``solutions`` is the
        solver's own progress log followed by ``(elapsed_seconds,
        final_parameters)``. If the solver raises ``SolverFailedException``,
        the traceback is printed and ``solutions`` holds a single entry with
        the exception's ``final_parameters``.
    """
    started_at = time.time()
    history = []
    try:
        best_parameters, progress_log = pylic_cma_solver(
            target_password=target_password,
            episode_timestep_n=episode_timestep_n,
            actuator_n=actuator_n,
            button_n=button_n,
            max_f_eval_n=cma_max_f_eval_n,
            max_episodes_per_predicate=max_episodes_per_predicate,
            sub_step_s=sub_step_s,
            worker_n=worker_n,
        )
    except SolverFailedException as failure:
        elapsed = time.time() - started_at
        history.append((elapsed, failure.final_parameters))
        print(traceback.format_exc())
        return history, False

    history.extend(progress_log)
    history.append((time.time() - started_at, best_parameters))
    return history, True


def custom_cem(
        target_password: tuple[int, ...],
        actuator_n: int,
        button_n: int,
        ) -> tuple[list[tuple[SolverTime, Parameters]], bool]:
    """Run the CEM baseline (``cem_solver``) on one button-password task.

    Thin adapter that fills in the module-level CEM settings. The timestep
    budget is ``episode_timestep_n`` times ``(len(target_password) + 1)``.
    """
    timestep_budget = episode_timestep_n * (len(target_password) + 1)
    return cem_solver(
        target_password=target_password,
        actuator_n=actuator_n,
        button_n=button_n,
        worker_n=worker_n,
        sub_step_s=sub_step_s,
        password_so_far_encoding_size=rl_password_so_far_encoding_size,
        horizon_len=cem_horizon_len,
        init_stdev=cem_init_stdev,
        cem_inner_iter_n=cem_inner_iter_n,
        sample_n=cem_sample_n,
        elite_n=cem_elite_n,
        timeout_s=timeout_s,
        max_timesteps=timestep_budget,
        verbose=True,
    )


def rl_tight_guidance(
        pretrained_dir: Path,
        target_password: tuple[int, ...],
        actuator_n: int,
        button_n: int,
        ) -> tuple[list[tuple[SolverTime, Parameters]], bool]:
    """Fine-tune pretrained RL snapshots, in order, until one solves the task.

    Snapshots are the ``run_<id>`` subdirectories of ``pretrained_dir``, tried
    in increasing numeric ``<id>``. Entries whose suffix is not a decimal
    number are ignored. A snapshot is skipped unless it contains
    ``pretrained_policy.zip``, ``vec_normalize.pickle`` and ``parameters.json``.
    The ``"timeout_s"`` value in each usable snapshot's ``parameters.json`` is
    charged as pretraining time.

    Returns:
        ``(solutions, is_successful)``. ``solutions`` has one
        ``(elapsed_time, parameters)`` entry per attempted snapshot.
        ``elapsed_time`` is the pretraining time of all usable snapshots so far
        plus the wall-clock time of every fine-tuning attempt so far, so
        entries are non-decreasing in time. ``parameters`` is the solver's
        result, or the exception's ``final_parameters`` when it raised
        ``SolverFailedException``. The search stops at the first success, so
        ``is_successful`` is True iff the last entry is a success.
    """
    def snapshot_suffix(path: Path) -> str:
        return path.name.removeprefix("run_")

    # Identify pretrained snapshots, ordered by numeric ID. Each directory is
    # visited exactly once, even if IDs alias (e.g. "run_1" and "run_01").
    snapshot_dirs = sorted(
        (
            path
            for path in pretrained_dir.glob("run_*")
            if path.is_dir() and snapshot_suffix(path).isdecimal()
        ),
        key=lambda path: int(snapshot_suffix(path)),
    )

    pretrain_time = 0
    # Fine-tuning time spent on all snapshots tried so far; a later snapshot's
    # time-to-solution includes the earlier failed attempts.
    finetune_time = 0.0
    is_successful = False
    solutions = list()
    for snapshot_dir in snapshot_dirs:
        # Load model and environment statistics
        model_path = snapshot_dir/"pretrained_policy.zip"
        vec_normalize_path = snapshot_dir/"vec_normalize.pickle"
        parameters_path = snapshot_dir/"parameters.json"
        required_paths = (model_path, vec_normalize_path, parameters_path)
        if not all(path.exists() for path in required_paths):
            continue
        with open(parameters_path, "rt") as fp:
            pretrain_config = json.load(fp)

        # Add pretrain time
        pretrain_time += pretrain_config["timeout_s"]

        start_t = time.time()
        try:
            final_parameters = rl_tight_guidance_solver(
                model_path=model_path,
                vec_normalize_path=vec_normalize_path,
                target_password=target_password,
                episode_timestep_n=rl_episode_timestep_n,
                actuator_n=actuator_n,
                button_n=button_n,
                rl_grad_steps_per_train=rl_grad_steps_per_train,
                rl_timestep_train_freq_n=rl_timestep_train_freq_n,
                train_timestep_n=rl_train_timestep_n,
                eval_freq=rl_eval_freq,
                timeout_s=timeout_s,
                worker_n=worker_n,
                sub_step_s=sub_step_s,
                password_so_far_encoding_size=rl_password_so_far_encoding_size,
                seed=random.randint(0, 2**16-1),
                verbose=True,
            )
        except SolverFailedException as e:
            # Failure is expected here; record the best effort and try the
            # next snapshot.
            final_parameters = e.final_parameters
        else:
            is_successful = True
        finetune_time += time.time()-start_t
        solutions.append((pretrain_time+finetune_time, final_parameters))
        if is_successful:
            break

    return solutions, is_successful


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Evaluate algorithms on the button password task.'
    )
    parser.add_argument(
        '--pretrained_dir',
        type=Path,
        help='Directory with data from pre-training',
        required=True,
    )
    args = parser.parse_args()

    # mkdir() without exist_ok raises if the directory already exists, so a
    # previous run's results are never silently overwritten.
    output_path = Path()/"planning_buttons"
    output_path.mkdir()

    pretrained_dir = args.pretrained_dir

    def pretrained_rl(*solver_args, **solver_kwargs):
        """Adapt rl_tight_guidance to the common solver call signature."""
        return rl_tight_guidance(pretrained_dir, *solver_args, **solver_kwargs)

    # List of solvers to benchmark; each function's __name__ is its id.
    solvers = [
        pylic_cma,
        pretrained_rl,
        custom_cem,
    ]

    # Track solver statistics
    time_results: Results = defaultdict(lambda: defaultdict(list))
    quality_results: QualityResults = list()

    # Write current results regularly
    time_results_path = output_path/"results.json"
    quality_results_path = output_path/"quality_results.json"

    def write_results(results: (Results | QualityResults), path: Path):
        """Dump results as indented JSON to `path` and report it."""
        with open(path, "wt") as fp:
            json.dump(results, fp, indent=2)
        print(f"Wrote {path}")

    # Benchmark each solver
    experiment_id = 0
    for solver in solvers:
        # Create a directory for this solver's results
        solver_id = solver.__name__
        solver_output_path = output_path/solver_id
        solver_output_path.mkdir()

        for button_n, password in tasks:
            # Execute solver
            solutions, is_successful = solver(
                target_password=password,
                actuator_n=actuator_n,
                button_n=button_n,
            )

            # Log time-to-solution (time of the last entry); None marks a
            # failure. Checkpoint after every experiment, like the quality
            # results below, so a crash does not lose recorded failures.
            total_t = solutions[-1][0] if is_successful else None
            time_results[button_n][solver_id].append(total_t)
            write_results(time_results, time_results_path)

            if solutions:
                # Plot the final solution
                final_parameters = solutions[-1][1]
                outcome = "SUCCESS" if is_successful else "FAIL"
                output_animation = (
                    solver_output_path/f"ant_{str(password)}_{outcome}.mp4"
                )
                plot_animation(
                    parameters=final_parameters,
                    fps=60,
                    output_path=output_animation,
                    password=password,
                    num_buttons=button_n,
                    sub_step_s=sub_step_s,
                )
                print(f"Wrote {output_animation}")

            # Log quality of each returned solution as a running maximum, so
            # each entry records the best quality reached by that time.
            max_quality_so_far = None
            for total_t, parameters in solutions:
                quality = get_quality(
                    parameters,
                    password,
                    button_n,
                    sub_step_s,
                )
                if max_quality_so_far is None or max_quality_so_far < quality:
                    max_quality_so_far = quality
                quality_result = (
                    button_n,
                    solver_id,
                    total_t,
                    max_quality_so_far,
                    experiment_id,
                )
                quality_results.append(quality_result)
            write_results(quality_results, quality_results_path)

            experiment_id += 1

    # Write experiment results
    write_results(time_results, time_results_path)
    write_results(quality_results, quality_results_path)

    # Plot time results
    output_plots_dir = output_path/"figures_time"
    output_plots_dir.mkdir()
    plot_results(time_results, output_plots_dir)
    print(f"Wrote {output_plots_dir}")

    # Plot quality results
    output_quality_plots_dir = output_path/"figures_quality"
    output_quality_plots_dir.mkdir()
    plot_quality_results(quality_results, output_quality_plots_dir)
    print(f"Wrote {output_quality_plots_dir}")
