import numpy as np
import os
import pickle
from collect_data_mini import rollin, generate_histories_for_env_ids


if __name__ == '__main__':
    # Script: roll out trajectories in MiniWorld tasks and pickle them.
    #
    # Env ids 0..9999 are split 80/20: the first 8000 ids are train tasks and
    # the remaining 2000 are test tasks. The first `--envs` ids of the split
    # selected by `--collect_in_train_tasks` are passed to
    # generate_histories_for_env_ids, and the returned object is pickled to
    # datasets/<prefix>_<envname>_envs<N>_H<H>_small<bool>.pkl.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--envs", type=int, required=False, default=100,
                        help="Number of env (task) ids to collect from the selected split")
    parser.add_argument("--H", type=int, required=False, default=10, help="Context horizon")
    parser.add_argument("--envname", type=str, required=False, default="mini", help="Environment name")
    parser.add_argument('--rollin', type=str, required=False, default="uniform",
                        help="Roll-in type passed to generate_histories_for_env_ids (e.g. 'uniform' or 'expert')")
    parser.add_argument('--collect_in_train_tasks', default=False, action='store_true', help="Whether to collect eval trajs in train tasks")
    parser.add_argument("--small", action='store_true', help="Use small images")

    args = vars(parser.parse_args())
    print("Args:")
    print(args)

    n_envs = args['envs']
    H = args['H']
    envname = args['envname']
    rollin_type = args['rollin']
    train = args['collect_in_train_tasks']
    small = args['small']

    print("Save as small: ", small)

    # Map the short dataset name to the registered MiniWorld env id.
    # 'mini' (exact match) and any 'mini_two_boxes*' name share the same env id.
    if envname == 'mini' or envname.startswith('mini_two_boxes'):
        env_name = 'MiniWorld-OneRoomS6FastMultiFixedInit-v0'
    elif envname.startswith('mini_three_boxes'):
        env_name = 'MiniWorld-OneRoomS6FastMultiThreeBoxesFixedInit-v0'
    elif envname.startswith('mini_four_boxes'):
        env_name = 'MiniWorld-OneRoomS6FastMultiFourBoxesFixedInit-v0'
    elif envname.startswith('mini_blue'):
        env_name = 'MiniWorld-OneRoomS6FastMultiBlueFixedInit-v0'
    else:
        raise ValueError(f"Invalid envname: {envname!r}")

    env_ids = np.arange(10000)
    train_test_split = int(.8 * len(env_ids))
    # First 80% of ids are train tasks; the rest are held-out test tasks.
    split_env_ids = env_ids[:train_test_split] if train else env_ids[train_test_split:]

    # Reject counts that would mis-slice: a negative value slices from the
    # end, and an oversized one yields fewer envs than the file name claims.
    if not 1 <= n_envs <= len(split_env_ids):
        parser.error(
            f"--envs must be between 1 and {len(split_env_ids)} for the "
            f"{'train' if train else 'test'} split, got {n_envs}")
    selected_env_ids = split_env_ids[:n_envs]

    traj_str = 'expert_trajs' if rollin_type == 'expert' else 'trajs'
    traj_str += '_train' if train else '_test'
    filepath = f'datasets/{traj_str}_{envname}_envs{n_envs}_H{H}_small{small}.pkl'
    # Strip only the extension; str.split('.') would also cut at any '.'
    # appearing inside envname.
    image_dir = os.path.splitext(filepath)[0]

    # Create the output directory before rolling out: both the pickle and
    # image_dir live under it. exist_ok makes a separate exists() check
    # unnecessary.
    os.makedirs(os.path.dirname(filepath), exist_ok=True)

    # NOTE(review): the two positional `1, 1` arguments are forwarded
    # unchanged from the original script; their meaning is defined in
    # collect_data_mini.generate_histories_for_env_ids.
    trajs = generate_histories_for_env_ids(
        env_name, selected_env_ids, 1, 1, H,
        small=small,
        rollin_type=rollin_type,
        image_dir=image_dir)

    with open(filepath, 'wb') as file:
        pickle.dump(trajs, file)

    print(f"Saved to {filepath}.")
