import argparse
import sys
import os
import yaml

import h5py
import torch
import numpy as np
import stable_baselines3 as sb3

sys.path.append(f"{os.path.expanduser('~')}/d4mrl/")
import sim
import utils


def get_reset_data():
    """Return a fresh, empty transition buffer.

    The buffer is a dict mapping each recorded field name to its own new
    empty list, so separate buffers (and separate fields) never share state.
    """
    field_names = (
        'observations',
        'next_observations',
        'actions',
        'rewards',
        'terminals',
        'timeouts',
        'qpos',
        'qvel',
    )
    return {name: [] for name in field_names}


@torch.no_grad()
def rollout(policy, env, max_path, num_data, random=False, stacked_frames=False):
    """Roll out a policy (or random actions) and collect a transition dataset.

    Episodes are run back to back until at least ``num_data`` transitions from
    *completed* trajectories have been gathered; the result is then truncated
    to exactly ``num_data`` entries. Because of that truncation the very last
    stored trajectory may be cut short, in which case its final transition
    carries neither a terminal nor a timeout flag.

    Args:
        policy: Object exposing ``predict(obs)`` whose first return element is
            the action. Unused when ``random`` is true.
        env: Environment with ``reset()`` returning an observation, ``step(a)``
            returning ``(obs, reward, done, info)``, an ``action_space``, an
            ``observation_space`` and a MuJoCo ``sim.data`` with ``qpos`` and
            ``qvel``.
        max_path: Step count at which a trajectory is cut and flagged as a
            timeout.
        num_data: Number of transitions to return.
        random: If true, sample actions from ``env.action_space`` instead of
            querying ``policy``.
        stacked_frames: Number of frames stacked (flattened) into each
            observation. When positive, only the trailing slice of size
            ``observation_space.shape[0] // stacked_frames`` is stored for
            each observation; otherwise observations are stored whole.

    Returns:
        dict of numpy arrays keyed by ``observations``, ``actions``,
        ``next_observations``, ``rewards`` (float32), ``terminals``,
        ``timeouts`` (bool), ``infos/qpos`` and ``infos/qvel`` (float32),
        each with ``num_data`` leading entries.
    """
    data = get_reset_data()
    traj_data = get_reset_data()

    # If the environment was trained on stacked frames there is no reason to
    # keep every stacked frame in the dataset; only the trailing slice is
    # saved. Its size is constant, so it is computed once outside the loop.
    keep_last_frame_only = stacked_frames > 0
    if keep_last_frame_only:
        frame_size = env.observation_space.shape[0] // stacked_frames

    _returns = 0
    t = 0
    s = env.reset()
    while len(data['rewards']) < num_data:

        if random:
            a = env.action_space.sample()
        else:
            a = policy.predict(s)[0]

        # mujoco only: snapshot the simulator state *before* stepping so that
        # qpos/qvel correspond to the observation `s` of this transition.
        qpos, qvel = env.sim.data.qpos.ravel().copy(), env.sim.data.qvel.ravel().copy()
        ns, rew, done, _ = env.step(a)

        _returns += rew

        t += 1
        # Reaching the step limit takes precedence: such a step is recorded as
        # a timeout even if the environment also reported `done`.
        timeout = t == max_path
        terminal = bool(done) and not timeout

        if keep_last_frame_only:
            traj_data['observations'].append(s[-frame_size:])
            traj_data['next_observations'].append(ns[-frame_size:])
        else:
            traj_data['observations'].append(s)
            traj_data['next_observations'].append(ns)

        traj_data['actions'].append(a)
        traj_data['rewards'].append(rew)
        traj_data['terminals'].append(terminal)
        traj_data['timeouts'].append(timeout)
        traj_data['qpos'].append(qpos)
        traj_data['qvel'].append(qvel)

        s = ns
        if terminal or timeout:
            # Merge the finished trajectory first so the progress counter
            # printed below already includes it.
            for k in data:
                data[k].extend(traj_data[k])
            print('Finished trajectory. Len=%d, Returns=%f. Progress:%d/%d' %
                  (t, _returns, len(data['rewards']), num_data))
            s = env.reset()
            t = 0
            _returns = 0
            traj_data = get_reset_data()

    new_data = dict(
        observations=np.array(data['observations']).astype(np.float32),
        actions=np.array(data['actions']).astype(np.float32),
        next_observations=np.array(data['next_observations']).astype(np.float32),
        rewards=np.array(data['rewards']).astype(np.float32),
        terminals=np.array(data['terminals']).astype(bool),
        timeouts=np.array(data['timeouts']).astype(bool)
    )
    new_data['infos/qpos'] = np.array(data['qpos']).astype(np.float32)
    new_data['infos/qvel'] = np.array(data['qvel']).astype(np.float32)

    # The last trajectory usually overshoots the target; trim to num_data.
    for k in new_data:
        new_data[k] = new_data[k][:num_data]
    return new_data


if __name__ == "__main__":
    # Command-line entry point: build the (optionally transformed / frame-
    # stacked) environment, roll out either a saved SB3 SAC agent or random
    # actions, and write the collected transitions to <output-dir>/dataset.hdf5.
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str)
    parser.add_argument('--agent-path', type=str, default=None)
    # Required: without it the script would only fail at os.makedirs(None),
    # after the whole rollout had already been computed.
    parser.add_argument('--output-dir', type=str, required=True)
    parser.add_argument('--max-path', type=int, default=1000)
    parser.add_argument('--num-data', type=int, default=1000000)
    parser.add_argument('--random', action='store_true')
    parser.add_argument('--seed', type=int, default=0)

    parser.add_argument('--stacked-frames', type=int, default=0)

    parser.add_argument('--transform-list-agent', action='store_true')
    parser.add_argument('--transform-list', nargs='+')

    args = parser.parse_args()

    # Fail fast with a usage error instead of a TypeError deep inside
    # os.path.dirname / the agent loader when the path is needed but missing.
    if args.agent_path is None and (args.transform_list_agent or not args.random):
        parser.error('--agent-path is required unless --random is used '
                     'without --transform-list-agent')

    utils.set_seed(args.seed)

    if args.transform_list_agent:
        # Read from agent's config, expected three directory levels above the
        # agent file.
        agent_config_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(args.agent_path))), 'config.yaml')
        with open(agent_config_path, 'r') as stream:
            # NOTE(review): FullLoader can construct more than plain data
            # types; prefer yaml.safe_load if the config is plain YAML and may
            # come from an untrusted location.
            data_loaded = yaml.load(stream, Loader=yaml.FullLoader)
        transform_list = data_loaded['simulator']['transform_list']
    elif args.transform_list is None:
        transform_list = []
    else:
        # Create a new transform list from alternating "name value" tokens.
        # A trailing name without a value is ignored.
        names = args.transform_list[0::2]
        values = args.transform_list[1::2]
        transform_list = [(name, utils.num(value)) for name, value in zip(names, values)]

    env, _ = sim.get_transformed_env(args.env, transform_list)

    if args.stacked_frames > 0:
        env = utils.stack_frames(env, args.stacked_frames, flatten=True)

    # Only Stable-Baselines3 SAC checkpoints are loaded here.
    policy = None if args.random else sb3.SAC.load(args.agent_path)

    # Create the output directory before the (potentially very long) rollout
    # so an unusable path is reported immediately.
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, 'dataset.hdf5')

    data = rollout(policy, env, max_path=args.max_path, num_data=args.num_data, random=args.random, stacked_frames=args.stacked_frames)

    # The context manager guarantees the file is closed even if a write fails.
    with h5py.File(output_path, 'w') as hfile:
        for k in data:
            hfile.create_dataset(k, data=data[k], compression='gzip')

        for transformation in transform_list:
            hfile.attrs[transformation[0]] = transformation[1]

        # HDF5 attributes cannot store None (the default in --random mode),
        # so an empty string marks "no agent" while keeping the key present.
        hfile.attrs['agent_path'] = args.agent_path if args.agent_path is not None else ''

    print('--------------------------------------------------')
    print(f'Dataset creation complete. Saved in {output_path}')
    print('--------------------------------------------------')
