import argparse
import sys
import os
import yaml

import h5py
import torch
import numpy as np
import gymnasium
sys.modules["gym"] = gymnasium
import stable_baselines3 as sb3

sys.path.append(f"{os.path.expanduser('~')}/d4mrl/")

# import gymnasium
# sys.modules["gym"] = gymnasium

import highway
import utils


def get_reset_data():
    """Return a fresh transition buffer with one new empty list per field.

    Every call builds brand-new lists, so buffers returned by separate calls
    never share storage.
    """
    fields = ('observations', 'next_observations', 'actions', 'rewards', 'terminals')
    return {field: [] for field in fields}


@torch.no_grad()
def rollout(policy, env, num_data, random=False):
    """Collect ``num_data`` transitions by running ``policy`` (or random actions) in ``env``.

    Transitions are buffered per episode and appended to the dataset only when
    the episode ends (``done`` is truthy). Collection stops once the dataset
    holds at least ``num_data`` transitions. Everything is then cut to the first
    ``num_data`` entries, so the last episode kept may be partial.

    Args:
        policy: object whose ``predict(obs)`` returns the action as its first
            element (a stable-baselines3 model in this script). Ignored when
            ``random`` is true.
        env: environment used with the 4-tuple ``step`` API
            (``obs, reward, done, info``); ``reset()`` is used directly as the
            next observation.
        num_data: number of transitions to return.
        random: if true, sample actions from ``env.action_space`` instead of
            querying ``policy``.

    Returns:
        dict with ``observations``, ``actions``, ``next_observations`` and
        ``rewards`` as float32 arrays and ``terminals`` as a bool array, each
        truncated to the first ``num_data`` entries along axis 0.
    """
    data = get_reset_data()
    traj_data = get_reset_data()

    _returns = 0
    t = 0
    s = env.reset()
    while len(data['rewards']) < num_data:
        if random:
            a = env.action_space.sample()
        else:
            a = policy.predict(s)[0]

        # Old gym-style 4-tuple step API, as exposed by the wrapped env.
        ns, rew, done, infos = env.step(a)

        _returns += rew
        t += 1

        # Observations are stored exactly as returned by the env; no
        # frame-stack trimming is performed here.
        traj_data['observations'].append(s)
        traj_data['next_observations'].append(ns)
        traj_data['actions'].append(a)
        traj_data['rewards'].append(rew)
        traj_data['terminals'].append(done)

        s = ns
        if done:
            print('Finished trajectory. Len=%d, Returns=%f. Progress:%d/%d' %
                  (t, _returns, len(data['rewards']), num_data))
            s = env.reset()
            t = 0
            _returns = 0
            # Commit the finished episode to the dataset and start a new buffer.
            for k in data:
                data[k].extend(traj_data[k])
            traj_data = get_reset_data()

    # Truncate before converting so the overshoot from the final episode is
    # never copied into the output arrays.
    data = {k: v[:num_data] for k, v in data.items()}
    return dict(
        observations=np.asarray(data['observations'], dtype=np.float32),
        actions=np.asarray(data['actions'], dtype=np.float32),
        next_observations=np.asarray(data['next_observations'], dtype=np.float32),
        rewards=np.asarray(data['rewards'], dtype=np.float32),
        terminals=np.asarray(data['terminals'], dtype=bool),
    )


if __name__ == "__main__":
    # Build an offline dataset by rolling out a trained SAC agent (or a
    # uniform-random policy with --random) and save it as
    # <output-dir>/dataset.hdf5.
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str)
    # Path to the saved SAC model; required unless --random is given.
    parser.add_argument('--agent-path', type=str, default=None)
    parser.add_argument('--output-dir', type=str, required=True)
    parser.add_argument('--num-data', type=int, default=100000)
    parser.add_argument('--random', action='store_true')
    parser.add_argument('--seed', type=int, default=0)

    # NOTE(review): parsed but never read by this script.
    parser.add_argument('--stacked-frames', type=int, default=0)

    # Read the transform list from the agent's config.yaml instead of the CLI.
    parser.add_argument('--transform-list-agent', action='store_true')
    # Alternating name/value tokens, e.g. `--transform-list name1 1 name2 0.5`.
    parser.add_argument('--transform-list', nargs='+')

    args = parser.parse_args()

    # Reject invalid argument combinations before doing any work. Otherwise
    # they surface later as opaque TypeErrors on None.
    if not args.random and args.agent_path is None:
        parser.error('--agent-path is required unless --random is set')
    if args.transform_list_agent and args.agent_path is None:
        parser.error('--transform-list-agent requires --agent-path')
    if args.transform_list is not None and len(args.transform_list) % 2 != 0:
        parser.error('--transform-list expects an even number of tokens (name/value pairs)')

    utils.set_seed(args.seed)

    if args.transform_list_agent:
        # config.yaml is looked up in the directory obtained by stripping three
        # path components from --agent-path.
        agent_config_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(args.agent_path))),
            'config.yaml',
        )
        # NOTE(review): FullLoader is kept because the config may contain
        # Python-specific tags that safe_load would reject; only load trusted configs.
        with open(agent_config_path, 'r') as stream:
            data_loaded = yaml.load(stream, Loader=yaml.FullLoader)
        transform_list = data_loaded['simulator']['transform_list']
    elif args.transform_list is None:
        transform_list = []
    else:
        # Pair tokens as (name, value). utils.num presumably parses the value
        # string into a number.
        tokens = args.transform_list
        transform_list = [(name, utils.num(value))
                          for name, value in zip(tokens[::2], tokens[1::2])]

    env = highway.make_env(args.env, wrap=True)
    env.reset()

    policy = None if args.random else sb3.SAC.load(args.agent_path, env)

    data = rollout(policy, env, num_data=args.num_data, random=args.random)

    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, 'dataset.hdf5')
    # The context manager closes (and flushes) the file even if a write fails.
    with h5py.File(output_path, 'w') as hfile:
        for k in data:
            hfile.create_dataset(k, data=data[k], compression='gzip')

        # Each transform entry is recorded as a file attribute entry[0] -> entry[1].
        for transformation in transform_list:
            hfile.attrs[transformation[0]] = transformation[1]

        # HDF5 attributes cannot hold None. A random-policy run without
        # --agent-path therefore records an empty string instead of crashing
        # after the whole rollout.
        hfile.attrs['agent_path'] = args.agent_path if args.agent_path is not None else ''

    print('--------------------------------------------------')
    print(f'Dataset creation complete. Saved in {output_path}')
    print('--------------------------------------------------')
