import os
import sys
import time
import warnings
from pathlib import Path
from typing import Optional

os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"

import d4rl
import gym
import imageio.v2
import moviepy.editor as mpy
from omegaconf import OmegaConf
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.utils import safe_mean

sys.path.append("../ours")
from ours.utils.visualize_episodes import put_infos
from ours.utils.utils import get_success


def _record_progress(model, iteration):
    """Record timing and rollout statistics on ``model.logger`` and dump them."""
    # Clamp the elapsed time so a zero reading (coarse clock, very fast
    # first rollout) cannot raise ZeroDivisionError in the FPS computation.
    elapsed = time.time() - model.start_time
    fps = int((model.num_timesteps - model._num_timesteps_at_start) /
              max(elapsed, sys.float_info.epsilon))

    model.logger.record("time/iterations", iteration, exclude="tensorboard")
    # Episode statistics are only reported once the buffer holds at least one
    # non-empty entry.
    if len(model.ep_info_buffer) > 0 and len(model.ep_info_buffer[0]) > 0:
        model.logger.record(
            "rollout/ep_rew_mean",
            safe_mean([ep_info["r"] for ep_info in model.ep_info_buffer]))
        model.logger.record(
            "rollout/ep_len_mean",
            safe_mean([ep_info["l"] for ep_info in model.ep_info_buffer]))
    model.logger.record("time/fps", fps)
    model.logger.record("time/time_elapsed",
                        int(elapsed),
                        exclude="tensorboard")
    model.logger.record("time/total_timesteps",
                        model.num_timesteps,
                        exclude="tensorboard")
    model.logger.dump(step=model.num_timesteps)


def learn(
    model,
    args,
    callback=None,
    eval_env=None,
    eval_freq: int = -1,
    n_eval_episodes: int = 5,
    tb_log_name: str = "PPO",
    eval_log_path: Optional[str] = None,
    reset_num_timesteps: bool = True,
):
    """Train ``model`` with an on-policy loop that checkpoints while logging.

    Each iteration collects ``model.n_steps`` rollout steps from ``model.env``
    into ``model.rollout_buffer`` and then calls ``model.train()``. Every
    ``args.log_interval`` iterations the training statistics are logged and
    the model is saved to ``args.expert_path``.

    NOTE(review): the loop looks adapted from Stable-Baselines3's
    ``OnPolicyAlgorithm.learn`` for a release whose ``_setup_learn`` still
    accepts the ``eval_*`` arguments -- confirm against the pinned version.

    Args:
        model: Algorithm object exposing ``_setup_learn``,
            ``collect_rollouts``, ``train``, ``save``, ``logger`` and the
            counters read below.
        args: Config object providing ``train_timesteps``, ``log_interval``
            and ``expert_path``.
        callback: Forwarded to ``model._setup_learn``.
        eval_env: Forwarded to ``model._setup_learn``.
        eval_freq: Forwarded to ``model._setup_learn``.
        n_eval_episodes: Forwarded to ``model._setup_learn``.
        tb_log_name: Forwarded to ``model._setup_learn``.
        eval_log_path: Forwarded to ``model._setup_learn``.
        reset_num_timesteps: Forwarded to ``model._setup_learn``.

    Returns:
        The same ``model`` instance after training.
    """
    # The local names below are kept stable on purpose: ``locals()`` is handed
    # to the callback, which may look values up by name.
    total_timesteps = args.train_timesteps
    log_interval = args.log_interval
    iteration = 0

    total_timesteps, callback = model._setup_learn(
        total_timesteps, eval_env, callback, eval_freq, n_eval_episodes,
        eval_log_path, reset_num_timesteps, tb_log_name)

    callback.on_training_start(locals(), globals())

    while model.num_timesteps < total_timesteps:
        continue_training = model.collect_rollouts(
            model.env,
            callback,
            model.rollout_buffer,
            n_rollout_steps=model.n_steps)

        # Rollout collection signals an early stop by returning False.
        if continue_training is False:
            break

        iteration += 1
        model._update_current_progress_remaining(model.num_timesteps,
                                                 total_timesteps)

        # A ``log_interval`` of None or 0 disables logging and checkpointing
        # (0 would otherwise fail on the modulo).
        if log_interval and iteration % log_interval == 0:
            _record_progress(model, iteration)
            # Checkpoint alongside every log dump so an interrupted run keeps
            # its latest weights.
            model.save(args.expert_path)

        model.train()

    callback.on_training_end()

    return model


if __name__ == "__main__":
    # Default configuration; the CLI arguments parsed below are merged on top
    # and override matching keys.
    base_args = OmegaConf.create({
        "env_id": "point-test-v1",
        "expert_path": "experts/point/point-test-v1_ppo.zip",
        "load_model": True,
        "visualize_demo": True,
        "train_timesteps": 1_000_000,
        "n_envs": 10,
        "log_interval": 1,
        "env_kwargs": {
            "eval": False,
            "return_direction": True,
            "non_zero_reset": True,
            "reward_type": "dense",
        }
    })
    cli_args = OmegaConf.from_cli()
    args = OmegaConf.merge(base_args, cli_args)
    expert_path = Path(args.expert_path)

    # Parallel environments
    env = make_vec_env(args.env_id,
                       n_envs=args.n_envs,
                       env_kwargs=args.env_kwargs)
    print(env.observation_space)

    # Resume from an existing checkpoint when requested; otherwise start from
    # a fresh policy. Building the fresh model only in the fallback branch
    # avoids constructing a network that is discarded immediately.
    if expert_path.exists() and args.load_model:
        model = PPO.load(expert_path, env=env)
        print(expert_path, "is loaded.")
    else:
        model = PPO(
            policy="MlpPolicy",
            env=env,
            use_sde=True,
            verbose=1,
        )
    learn(model, args)
    model.save(expert_path)

    if args.visualize_demo:
        demo_env = gym.make(args.env_id, **args.env_kwargs)

        skip = 5  # environment steps between two recorded frames
        repeat = 30  # number of episodes to record

        frames = []
        try:
            for i in range(repeat):
                # NOTE(review): the reset/render/reset sequence is kept from
                # the original; presumably the first render warms up the
                # renderer -- confirm. The observation must come from the
                # LAST reset so the policy sees the state the environment is
                # actually in.
                demo_env.reset()
                demo_env.render(mode="rgb_array")
                obs = demo_env.reset()
                done = False
                t = 1
                while not done:
                    for _ in range(skip):
                        action, _states = model.predict(obs)
                        obs, rew, done, info = demo_env.step(action)
                        t += 1
                        done |= get_success(obs, demo_env.get_target(),
                                            args.env_id)
                        # Never step an environment that already finished.
                        if done:
                            break
                    frame = demo_env.render(mode="rgb_array").astype("uint8")
                    put_infos(frame, {
                        "i": i,
                        "t": t,
                        "obs": obs,
                        "act": action,
                        "rew": rew
                    })
                    frames.append(frame)
        finally:
            demo_env.close()

        # Create the output directory up front so the rollouts are not lost
        # to a missing-directory error when the video is written.
        video_path = f"demo/ppo/{args.env_id}.mp4"
        Path(video_path).parent.mkdir(parents=True, exist_ok=True)
        imageio.mimsave(video_path, frames, fps=20)
