import gymnasium as gym
import miniworld
import numpy as np
import os
import pickle
import argparse
from IPython import embed
import matplotlib.pyplot as plt
from skimage.transform import resize

# Shape handed to skimage's `resize` whenever observations are downsampled
# (`small=True`): a 25x25 spatial grid plus a trailing axis of length 3
# (presumably RGB channels — confirm against env.render_obs()).
target_shape = (25, 25) + (3,)


def _render(env, small):
    """Return ``env.render_obs()``, resized to ``target_shape`` when ``small``."""
    obs = env.render_obs()
    if small:
        obs = resize(obs, target_shape, anti_aliasing=True)
    return obs


def _one_hot(index, size):
    """Return a float vector of length ``size`` with a 1 at ``index``."""
    vec = np.zeros(size)
    vec[index] = 1
    return vec


def _planar(vec):
    """Return a copy of the first and last components of ``vec``.

    Presumably the ground-plane coordinates of a 3-D vector — confirm
    against the MiniWorld agent conventions.
    """
    return vec[[0, -1]].copy()


def rollin(env, H, small=False, rollin_type='uniform'):
    """Collect ``H`` single-step transitions from ``env``.

    With ``rollin_type='uniform'`` the agent is re-placed via
    ``env.place_agent()`` before every step and takes a uniformly random
    action, so each transition starts from a fresh placement. With
    ``rollin_type='expert'`` the agent is never re-placed and follows
    ``env.opt_a``, so the transitions form one continuous trajectory.

    Args:
        env: MiniWorld-style environment exposing ``place_agent``,
            ``render_obs``, ``opt_a``, ``step``, ``agent.pos``,
            ``agent.dir_vec`` and a discrete ``action_space``.
        H: number of transitions to collect.
        small: if true, observations are resized to ``target_shape``.
        rollin_type: ``'uniform'`` or ``'expert'``.

    Returns:
        Tuple of 8 arrays, each with a leading axis of length ``H``:
        observations, poses, angles, one-hot actions, next observations,
        next poses, next angles and rewards. Poses/angles hold
        ``[first, last]`` components of ``agent.pos`` / ``agent.dir_vec``.

    Raises:
        ValueError: if ``rollin_type`` is not ``'uniform'`` or ``'expert'``.
            This is checked before the environment is touched.
    """
    if rollin_type not in ('uniform', 'expert'):
        raise ValueError("Invalid rollin type")

    observations, poses, angles, actions = [], [], [], []
    next_observations, next_poses, next_angles, rewards = [], [], [], []
    for _ in range(H):
        if rollin_type == 'uniform':
            env.place_agent()
        obs = _render(env, small)
        observations.append(obs)
        poses.append(_planar(env.agent.pos))
        angles.append(_planar(env.agent.dir_vec))

        if rollin_type == 'uniform':
            action = np.random.randint(env.action_space.n)
        else:
            # The expert sees the (possibly resized) observation it was shown.
            action = env.opt_a(obs, env.agent.pos, env.agent.dir_vec)
        _, rew, _, _, _ = env.step(action)
        actions.append(_one_hot(action, env.action_space.n))

        next_observations.append(_render(env, small))
        next_poses.append(_planar(env.agent.pos))
        next_angles.append(_planar(env.agent.dir_vec))
        rewards.append(rew)

    return tuple(
        np.array(seq)
        for seq in (
            observations,
            poses,
            angles,
            actions,
            next_observations,
            next_poses,
            next_angles,
            rewards,
        )
    )


def generate_histories_for_env_ids(
        env_name,
        env_ids,
        n_hists,
        n_samples,
        H,
        small=False,
        rollin_type='uniform',
        image_dir=''):
    """Build roll-in histories plus expert-labelled query states per task id.

    For each entry of ``env_ids`` the task is selected with ``env.set_task``
    and the environment is reset. Then, ``n_hists`` times:

    * an ``H``-step roll-in is collected with :func:`rollin`;
    * its observation and next-observation stacks are saved as ``.npy``
      files in ``image_dir``, and only their paths are stored in the result;
    * ``n_samples`` query states are drawn by re-placing the agent, each
      labelled with the one-hot action returned by ``env.opt_a``.

    Args:
        env_name: id passed to ``gym.make``.
        env_ids: sequence of task ids passed to ``env.set_task``.
        n_hists: roll-in histories per task id.
        n_samples: query states per history.
        H: roll-in length.
        small: resize observations to ``target_shape``.
        rollin_type: forwarded to :func:`rollin`.
        image_dir: directory for the ``.npy`` observation files. It is
            created if missing. An empty string means the current directory.

    Returns:
        List of dicts, one per (task, history, sample). Each dict holds the
        query ``state``/``action``/``pose``/``angle``, the roll-in arrays,
        and the ``.npy`` paths under ``rollin_obs`` / ``rollin_next_obs``.
    """
    # os.makedirs('') raises, so only create a directory when one is named.
    if image_dir:
        os.makedirs(image_dir, exist_ok=True)

    n_envs = len(env_ids)
    env = gym.make(env_name)
    trajs = []
    try:
        env.reset()

        for i, env_id in enumerate(env_ids):
            print("Generating histories for env {}/{}".format(i, n_envs))
            env.set_task(env_id)
            env.reset()
            for j in range(n_hists):
                (
                    rollin_obs,
                    rollin_poses,
                    rollin_angles,
                    rollin_us,
                    rollin_next_obs,
                    rollin_next_poses,
                    rollin_next_angles,
                    rollin_rs,
                ) = rollin(
                    env, H, small=small, rollin_type=rollin_type)

                # File names use the enumeration index (not env_id) so that
                # repeated task ids still get distinct files.
                filepath = os.path.join(
                    image_dir, 'rollin{}_{}.npy'.format(i, j))
                np.save(filepath, rollin_obs)

                next_filepath = os.path.join(
                    image_dir, 'next_rollin{}_{}.npy'.format(i, j))
                np.save(next_filepath, rollin_next_obs)

                for _ in range(n_samples):
                    env.place_agent()
                    obs = env.render_obs()
                    if small:
                        obs = resize(obs, target_shape, anti_aliasing=True)
                    a = env.opt_a(obs, env.agent.pos, env.agent.dir_vec)
                    a_zero = np.zeros(env.action_space.n)
                    a_zero[a] = 1

                    trajs.append({
                        'state': obs,
                        'action': a_zero,
                        'rollin_obs': filepath,
                        'rollin_poses': rollin_poses,
                        'rollin_angles': rollin_angles,
                        'rollin_next_obs': next_filepath,
                        'rollin_next_poses': rollin_next_poses,
                        'rollin_next_angles': rollin_next_angles,
                        'rollin_us': rollin_us,
                        'rollin_rs': rollin_rs,
                        'pose': env.agent.pos[[0, -1]].copy(),
                        'angle': env.agent.dir_vec[[0, -1]].copy(),
                    })
    finally:
        # Release the environment's resources even if generation fails.
        env.close()

    return trajs


if __name__ == '__main__':

    # Output locations; exist_ok makes these idempotent.
    os.makedirs('datasets', exist_ok=True)
    os.makedirs('figs/images', exist_ok=True)

    parser = argparse.ArgumentParser()
    parser.add_argument("--envs", type=int, required=False, default=10000, help="Envs")
    parser.add_argument("--hists", type=int, required=False, default=1, help="Histories")
    parser.add_argument("--samples", type=int, required=False, default=1, help="Samples")
    parser.add_argument("--H", type=int, required=False, default=3, help="Context horizon")
    parser.add_argument("--envname", type=str, required=False, default='mini', help="Environment name")
    # Restrict to the roll-in modes the generator supports so a typo fails
    # at argument parsing rather than after the environment is built.
    parser.add_argument("--rollin_type", type=str, required=False, default='uniform',
                        choices=['uniform', 'expert'], help="Rollin type")
    parser.add_argument("--small", action='store_true', help="Use small images")

    args = vars(parser.parse_args())
    print("Args:")
    print(args)

    n_envs = args['envs']
    n_hists = args['hists']
    n_samples = args['samples']
    H = args['H']
    envname = args['envname']
    rollin_type = args['rollin_type']
    small = args['small']

    print("Save as small: ", small)

    traj_prefix = 'expert_' if rollin_type == 'expert' else ''
    filepath_tr = 'datasets/{}trajs_{}_envs{}_hists{}_samples{}_small{}_train.pkl'.format(traj_prefix, envname, n_envs, n_hists, n_samples, small)
    filepath_te = 'datasets/{}trajs_{}_envs{}_hists{}_samples{}_small{}_test.pkl'.format(traj_prefix, envname, n_envs, n_hists, n_samples, small)

    # The first 80% of task ids go to training and the rest to testing.
    env_ids = np.arange(n_envs)
    train_test_split = int(.8 * len(env_ids))
    train_env_ids = env_ids[:train_test_split]
    test_env_ids = env_ids[train_test_split:]

    if envname == 'mini' or envname.startswith('mini_two_boxes'):
        env_name = 'MiniWorld-OneRoomS6FastMultiFixedInit-v0'
    elif envname.startswith('mini_three_boxes'):
        env_name = 'MiniWorld-OneRoomS6FastMultiThreeBoxesFixedInit-v0'
    elif envname.startswith('mini_four_boxes'):
        env_name = 'MiniWorld-OneRoomS6FastMultiFourBoxesFixedInit-v0'
    elif envname.startswith('mini_blue'):
        env_name = 'MiniWorld-OneRoomS6FastMultiBlueFixedInit-v0'
    else:
        raise ValueError("Invalid envname")

    # Observation .npy files live in a directory named after the dataset
    # file minus its extension. splitext keeps any '.' inside envname intact.
    train_trajs = generate_histories_for_env_ids(
        env_name,
        train_env_ids,
        n_hists,
        n_samples,
        H,
        small=small,
        image_dir=os.path.splitext(filepath_tr)[0],
        rollin_type=rollin_type)
    test_trajs = generate_histories_for_env_ids(
        env_name,
        test_env_ids,
        n_hists,
        n_samples,
        H,
        small=small,
        image_dir=os.path.splitext(filepath_te)[0],
        rollin_type=rollin_type)

    # Persist the datasets first so that a failure in the preview below
    # cannot throw away the generated data.
    with open(filepath_tr, 'wb') as file:
        pickle.dump(train_trajs, file)
    with open(filepath_te, 'wb') as file:
        pickle.dump(test_trajs, file)

    # Preview up to 30 roll-in frames of the first training trajectory, if any.
    if train_trajs:
        preview = np.load(train_trajs[0]['rollin_obs'])
        for i in range(min(H, 30)):
            plt.imshow(preview[i])
            plt.savefig('figs/images/rollin_{}.png'.format(i))
            plt.clf()  # don't stack every frame onto the same axes
