import argparse
import gzip
import os
import pickle

import h5py
import numpy as np
import skvideo.io
import torch
from PIL import Image
from rlkit.torch.pytorch_util import set_gpu_mode

from d4rl_alt.locomotion import ant, maze_env, swimmer
from d4rl_alt.locomotion.wrappers import NormalizedBoxEnv


def reset_data():
    """Return a fresh, empty transition buffer.

    Returns:
        dict: maps each dataset field name to its own new empty list, in the
        order the fields are later written to the output file.
    """
    field_names = (
        "observations",
        "actions",
        "terminals",
        "rewards",
        "infos/goal",
        "infos/qpos",
        "infos/qvel",
    )
    # Build a distinct list per field so appends never alias each other.
    return {name: [] for name in field_names}


def append_data(data, s, a, r, tgt, done, env_data):
    """Append one transition to the buffer ``data`` in place.

    Args:
        data: dict of lists with the keys produced by ``reset_data``.
        s: observation to store.
        a: action to store.
        r: reward to store.
        tgt: goal to store under ``"infos/goal"``.
        done: terminal flag to store.
        env_data: object exposing ``qpos`` and ``qvel`` arrays; flattened
            copies of both are stored.
    """
    plain_fields = (
        ("observations", s),
        ("actions", a),
        ("rewards", r),
        ("terminals", done),
        ("infos/goal", tgt),
    )
    for field, value in plain_fields:
        data[field].append(value)

    # Snapshot the simulator state: copy so later changes to the source
    # arrays do not alter what was recorded.
    for field, state in (("infos/qpos", env_data.qpos), ("infos/qvel", env_data.qvel)):
        data[field].append(state.ravel().copy())


def npify(data):
    """Convert every list in ``data`` to a NumPy array, in place.

    The ``"terminals"`` entry becomes a boolean array; every other entry
    becomes ``float32``. Nothing is returned.
    """
    for key, values in data.items():
        target_dtype = np.bool_ if key == "terminals" else np.float32
        # Rebinding an existing key does not change the dict's size, so it
        # is safe while iterating.
        data[key] = np.array(values, dtype=target_dtype)


def load_policy(policy_file):
    """Load a saved policy and its environment from a torch checkpoint.

    The checkpoint must be a mapping with the keys ``"exploration/policy"``
    and ``"evaluation/env"``. When CUDA is available, GPU mode is enabled
    and the policy is moved to the GPU; otherwise the checkpoint is mapped
    to the CPU so the script still runs on machines without a GPU.

    Args:
        policy_file: path to the checkpoint written with ``torch.save``.

    Returns:
        tuple: ``(policy, env)`` as stored in the checkpoint.
    """
    use_gpu = torch.cuda.is_available()
    # NOTE: torch.load unpickles the file, which can execute arbitrary code.
    # Only load checkpoints from a trusted source.
    data = torch.load(policy_file, map_location=None if use_gpu else "cpu")
    policy = data["exploration/policy"]
    env = data["evaluation/env"]
    print("Policy loaded")
    if use_gpu:
        set_gpu_mode(True)
        policy.cuda()
    return policy, env


def save_video(save_dir, file_name, frames, episode_id=0):
    """Write the frames of one episode as numbered PNG images.

    Despite the name, no video file is produced: every frame is saved as
    ``frame_<i>.png`` inside ``<save_dir>/<file_name>_episode_<episode_id>``.
    The directory (including missing parents) is created if needed.

    Args:
        save_dir: parent directory for the per-episode folder.
        file_name: prefix of the per-episode folder name.
        frames: sequence of RGB images, indexed along the first axis.
        episode_id: episode number used in the folder name.
    """
    episode_dir = os.path.join(save_dir, file_name + "_episode_{}".format(episode_id))
    # exist_ok avoids the check-then-create race and makes repeated calls safe.
    os.makedirs(episode_dir, exist_ok=True)
    for i, frame in enumerate(frames):
        # Frames are flipped vertically before saving (presumably the
        # renderer returns images bottom-up; confirm against the renderer).
        img = Image.fromarray(np.flipud(frame), "RGB")
        img.save(os.path.join(episode_dir, "frame_{}.png".format(i)))


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--noisy", action="store_true", help="Noisy actions")
    parser.add_argument(
        "--maze",
        type=str,
        default="u-maze",
        help="Maze type: u-maze, big-maze or hardest-maze",
    )
    parser.add_argument(
        "--num_samples", type=int, default=int(1e6), help="Num samples to collect"
    )
    parser.add_argument("--env", type=str, default="Ant", help="Environment type")
    parser.add_argument(
        "--policy_file", type=str, default="policy_file", help="file_name"
    )
    parser.add_argument("--max_episode_steps", default=1000, type=int)
    parser.add_argument("--video", action="store_true")
    parser.add_argument("--multi_start", action="store_true")
    parser.add_argument("--multigoal", action="store_true")
    return parser.parse_args()


def _make_env(args):
    """Construct the normalized maze environment selected by ``args``.

    Raises:
        NotImplementedError: if ``args.maze`` or ``args.env`` is not one of
            the supported names.
    """
    maze_attr = {
        "u-maze": "U_MAZE",
        "big-maze": "BIG_MAZE",
        "hardest-maze": "HARDEST_MAZE",
    }.get(args.maze)
    if maze_attr is None:
        raise NotImplementedError("Unsupported maze type: {!r}".format(args.maze))
    maze = getattr(maze_env, maze_attr)

    if args.env == "Ant":
        env_cls = ant.AntMazeEnv
    elif args.env == "Swimmer":
        env_cls = swimmer.SwimmerMazeEnv
    else:
        # Previously an unknown env left ``env`` unbound and failed later
        # with a confusing error; fail early and explicitly instead.
        raise NotImplementedError("Unsupported env type: {!r}".format(args.env))

    # Both environments take the layout as ``maze_map`` (the Swimmer branch
    # used to pass the misspelled keyword ``mmaze_map``).
    return NormalizedBoxEnv(
        env_cls(maze_map=maze, maze_size_scaling=4.0, non_zero_reset=args.multi_start)
    )


def _dataset_filename(args):
    """Return the name of the HDF5 file the collected dataset is written to."""
    # NOTE(review): the non-noisy name has no underscore between the env name
    # and "maze" (e.g. "Antmaze_..."), unlike the noisy one ("Ant_maze_...").
    # Kept as is because existing consumers may rely on these exact names.
    if args.noisy:
        return args.env + "_maze_%s_noisy_multistart_%s_multigoal_%s.hdf5" % (
            args.maze,
            str(args.multi_start),
            str(args.multigoal),
        )
    return args.env + "maze_%s_multistart_%s_multigoal_%s.hdf5" % (
        args.maze,
        str(args.multi_start),
        str(args.multigoal),
    )


def main():
    """Collect maze-navigation transitions with a loaded policy and save them.

    Parses the command line, builds the maze environment, loads the
    goal-reaching policy from ``--policy_file`` and steps the environment
    ``--num_samples`` times using the environment's navigation policy.
    Every transition is recorded and the result is written as a
    gzip-compressed HDF5 file in the current directory. With ``--video``,
    the rendered frames of each finished episode are saved under
    ``./videos/``.

    In this script ``--multigoal`` only affects the output file name; a new
    target goal is requested after every episode regardless.
    """
    args = _parse_args()
    env = _make_env(args)

    env.set_target_goal()
    s = env.reset()
    print(s.shape)

    # Load the policy (the environment stored with it is not needed here).
    policy, _ = load_policy(args.policy_file)

    # Define goal reaching policy fn
    def _goal_reaching_policy_fn(obs, goal):
        """Return the policy action for ``obs`` and the rescaled goal position."""
        goal_x, goal_y = goal
        # Drop the first two and the last two observation entries before
        # appending the goal (presumably position and goal slots; confirm
        # against the environment's observation layout).
        obs_new = obs[2:-2]
        goal_tuple = np.array([goal_x, goal_y])

        # normalize the norm of the relative goals to in-distribution values
        norm = np.linalg.norm(goal_tuple)
        if norm > 0:
            # Guard against a zero-length goal, which would otherwise yield
            # NaNs and feed them to the policy.
            goal_tuple = goal_tuple / norm * 10.0

        new_obs = np.concatenate([obs_new, goal_tuple], -1)
        return policy.get_action(new_obs)[0], (
            goal_tuple[0] + obs[0],
            goal_tuple[1] + obs[1],
        )

    data = reset_data()

    # create waypoint generating policy integrated with high level controller
    data_collection_policy = env.create_navigation_policy(
        _goal_reaching_policy_fn,
    )

    frames = []
    ts = 0
    num_episodes = 0
    for _ in range(args.num_samples):
        act, _waypoint_goal = data_collection_policy(s)

        if args.noisy:
            act = act + np.random.randn(*act.shape) * 0.2
            act = np.clip(act, -1.0, 1.0)

        ns, r, done, info = env.step(act)
        # NOTE(review): this check runs before ``ts`` is incremented, so an
        # episode is cut after max_episode_steps + 1 transitions. Kept to
        # preserve the layout of previously generated datasets.
        if ts >= args.max_episode_steps:
            done = True

        # The last two observation entries are not stored.
        append_data(data, s[:-2], act, r, env.target_goal, done, env.physics.data)

        if len(data["observations"]) % 10000 == 0:
            print(len(data["observations"]))

        ts += 1

        if done:
            ts = 0
            s = env.reset()
            env.set_target_goal()
            if args.video:
                save_video(
                    "./videos/",
                    args.env + "_navigation",
                    np.array(frames),
                    num_episodes,
                )

            num_episodes += 1
            frames = []
        else:
            s = ns

        if args.video:
            curr_frame = env.physics.render(width=500, height=500, depth=False)
            frames.append(curr_frame)

    fname = _dataset_filename(args)
    npify(data)
    # Use a context manager so the HDF5 file is flushed and closed even if
    # writing one of the datasets fails.
    with h5py.File(fname, "w") as dataset:
        for k in data:
            dataset.create_dataset(k, data=data[k], compression="gzip")


# Run the data collection only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
