import argparse
import os
from pathlib import Path

import cv2
import gym
from tqdm import tqdm

import envs


def _parse_args():
    """Parse the command-line options for dataset generation."""
    parser = argparse.ArgumentParser(
        description='Generate a PNG image dataset from environment reset observations.')

    parser.add_argument('--dataset_size', type=int, default=100000,
                        help='Number of reset observations (images) to collect.')
    parser.add_argument('--dataset_path', type=str, default='shapes2d_dataset',
                        help='Output directory; created if it does not exist.')
    parser.add_argument('--env_id', type=str, required=True,
                        help='Gym environment id passed to gym.make.')

    return parser.parse_args()


def main():
    """Collect ``--dataset_size`` reset observations and save each one as a PNG.

    Observation ``i`` is written to ``<dataset_path>/<i:06d>.png``. Each image is
    written immediately after it is produced, instead of buffering every
    observation in memory first. Peak memory therefore does not grow with
    ``--dataset_size``.

    Raises:
        OSError: if OpenCV fails to write an image file.
    """
    args = _parse_args()

    env = gym.make(args.env_id)
    try:
        # Create the output directory only once the environment was built successfully.
        Path(args.dataset_path).mkdir(parents=True, exist_ok=True)

        print('Collect and save dataset')
        for i in tqdm(range(args.dataset_size)):
            # NOTE(review): assumes reset() returns only an RGB uint8 image array
            # (pre-0.26 gym API, no (obs, info) tuple) — confirm for the env_ids used.
            observation = env.reset()
            path = os.path.join(args.dataset_path, f'{i:06d}.png')
            # OpenCV writes BGR channel order, so convert from RGB first.
            image = cv2.cvtColor(observation, cv2.COLOR_RGB2BGR)
            # cv2.imwrite signals failure via its return value rather than raising.
            if not cv2.imwrite(path, image):
                raise OSError(f'Failed to write image to {path}')
    finally:
        env.close()


if __name__ == '__main__':
    main()