"""Pretrain an RL policy."""
from rl_baseline import pretrain_policy
from rl_baseline import HierarchicalButtonPreTrainingEnv
from main import actuator_n
from main import sub_step_s
from main import rl_episode_timestep_n
from main import rl_grad_steps_per_train
from main import rl_timestep_train_freq_n
from main import rl_train_timestep_n
from main import rl_eval_freq
from main import rl_password_so_far_encoding_size
from pathlib import Path
import json
from plotting import frames_to_video
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from stable_baselines3.common.vec_env import DummyVecEnv

# Set this to True to make the script finish quickly (for debugging)
quick = False


def next_run_index(existing_names):
    """Return the index to use for the next ``run_<index>`` directory.

    The result is one more than the highest numeric suffix found among
    ``existing_names``. Names that are not exactly ``run_<digits>`` are
    ignored. Using the highest index (rather than the number of entries)
    avoids colliding with an existing directory when an earlier run has been
    deleted or when unrelated ``run_*`` entries are present.

    Args:
        existing_names: Iterable of directory/file names (strings).

    Returns:
        The next free run index as an int; ``1`` when no run exists yet.
    """
    import re

    highest = 0
    for name in existing_names:
        match = re.fullmatch(r"run_([0-9]+)", name)
        if match is not None:
            highest = max(highest, int(match.group(1)))
    return highest + 1


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--timeout-h',
        type=float,
        required=True,
    )
    parser.add_argument(
        '--max-button-n',
        type=int,
        required=True,
    )
    parser.add_argument(
        '--worker-n',
        type=int,
        required=True,
    )
    parser.add_argument(
        '--output-dir',
        type=Path,
        required=True,
    )

    # Parse arguments
    args = parser.parse_args()
    # The timeout is given in hours on the command line; convert to seconds
    train_timeout_s = args.timeout_h*60*60
    output_dir = args.output_dir
    max_button_n = args.max_button_n
    worker_n = args.worker_n

    if quick:
        # Debug mode: override the requested timeout with a very short one
        train_timeout_s = 15

    # Create output directory (including missing parents)
    output_dir.mkdir(parents=True, exist_ok=True)
    policy_path = output_dir/"pretrained_policy.zip"
    replay_buffer_path = output_dir/"replay_buffer.pickle"
    vec_normalize_path = output_dir/"vec_normalize.pickle"

    # Artifacts left in the output directory by a previous invocation are
    # handed to pretrain_policy; anything that does not exist is passed as
    # None.
    pretrained_model_path = policy_path if policy_path.exists() else None
    if not replay_buffer_path.exists():
        replay_buffer_path = None
    if not vec_normalize_path.exists():
        vec_normalize_path = None

    # Every invocation gets its own run_<index> sub-directory. mkdir() is
    # deliberately called without exist_ok so that a clash is reported
    # instead of silently overwriting an earlier run.
    run_index = next_run_index(
        path.name for path in output_dir.glob("run_*")
    )
    experiment_dir = output_dir/f"run_{run_index}"
    experiment_dir.mkdir()

    # Training logs
    rl_log_path = experiment_dir/"log"
    rl_log_path.mkdir()

    # Parameters
    training_parameters = dict(
        pretrained_model_path=pretrained_model_path,
        replay_buffer_path=replay_buffer_path,
        vec_normalize_path=vec_normalize_path,
        episode_timestep_n=rl_episode_timestep_n,
        actuator_n=actuator_n,
        max_button_n=max_button_n,
        rl_timestep_train_freq_n=rl_timestep_train_freq_n,
        rl_grad_steps_per_train=rl_grad_steps_per_train,
        train_timestep_n=rl_train_timestep_n,
        worker_n=worker_n,
        seed=0,
        eval_freq=rl_eval_freq,
        timeout_s=train_timeout_s,
        sub_step_s=sub_step_s,
        password_so_far_encoding_size=rl_password_so_far_encoding_size,
        log_path=rl_log_path,
        verbose=True,
    )

    # Pre-train policy
    model = pretrain_policy(**training_parameters)

    # Save the artifacts twice: once in output_dir (the copy picked up by the
    # existence checks above on the next invocation) and once in this run's
    # own directory.
    save_replay_buffer = getattr(model, "save_replay_buffer", None)
    vec_normalize = model.get_vec_normalize_env()
    for policy_dir in (output_dir, experiment_dir):
        # Save policy weights
        saved_policy_path = policy_dir/"pretrained_policy.zip"
        model.save(saved_policy_path)
        print(f"Wrote {saved_policy_path}")

        # Save replay buffer if the model provides a way to do so
        if callable(save_replay_buffer):
            saved_replay_buffer_path = policy_dir/"replay_buffer.pickle"
            save_replay_buffer(saved_replay_buffer_path)
            print(f"Wrote {saved_replay_buffer_path}")

        # Save normalization statistics
        if vec_normalize is not None:
            saved_vec_normalize_path = policy_dir/"vec_normalize.pickle"
            vec_normalize.save(str(saved_vec_normalize_path))
            print(f"Wrote {saved_vec_normalize_path}")
            # The video rollout below loads whichever statistics file was
            # written last (the copy in experiment_dir). If nothing is
            # written, vec_normalize_path keeps the value determined before
            # training.
            vec_normalize_path = saved_vec_normalize_path

    # Save training parameters (Path values are not JSON serializable, so
    # they are stored as strings)
    parameters_path = experiment_dir/"parameters.json"
    with open(parameters_path, "wt") as fp:
        json.dump(
            {
                key: str(value) if isinstance(value, Path) else value
                for key, value in training_parameters.items()
            },
            fp,
            indent=2,
        )
    print(f"Wrote {parameters_path}")

    # Write video
    video_path = experiment_dir/"video.mp4"
    video_length = rl_episode_timestep_n*10

    def make_env():
        """Build an environment configured like the training parameters."""
        return HierarchicalButtonPreTrainingEnv(
            max_button_n=max_button_n,
            actuator_n=actuator_n,
            episode_timestep_n=rl_episode_timestep_n,
            sub_step_s=sub_step_s,
            password_so_far_encoding_size=rl_password_so_far_encoding_size,
        )

    # Wrap in vec normalize
    if vec_normalize_path is not None:
        env = VecNormalize.load(
            str(vec_normalize_path),
            DummyVecEnv([make_env]),
        )
        # Do not update the running statistics while recording the video
        env.training = False
    else:
        # NOTE(review): this branch unpacks reset()/step() as a single
        # observation / 4-tuple; confirm the environment follows that API.
        env = make_env()

    try:
        obs = env.reset()
        frames = [env.render()]
        for _ in range(video_length + 1):
            action, _ = model.predict(obs, deterministic=True)
            obs, _, is_done, _ = env.step(action)
            frames.append(env.render())
            if is_done:
                obs = env.reset()
    finally:
        # Release the environment even if the rollout or rendering fails
        env.close()
    frames_to_video(frames, fps=60, output_path=video_path)
    print(f"Wrote {video_path}")
