import argparse
import collections
import concurrent.futures
import os
import random
from pathlib import Path

import cv2
import numpy as np
from gym.wrappers import TimeLimit
from omegaconf import OmegaConf
from tqdm import tqdm

from envs.cw_envs import CwTargetEnv


def make_env(config_env_path, seed):
    """Build a seeded, time-limited ``CwTargetEnv``.

    Args:
        config_env_path: Path to the environment config, loaded with
            ``OmegaConf.load``.
        seed: Seed applied to Python's ``random``, to NumPy's global RNG,
            to the environment constructor and to its action space.

    Returns:
        The environment wrapped in ``TimeLimit``; the limit is read from the
        environment's ``_max_episode_length`` attribute.
    """
    cfg = OmegaConf.load(config_env_path)

    # Seed the global RNGs before the environment is constructed.
    for seed_rng in (random.seed, np.random.seed):
        seed_rng(seed)

    base_env = CwTargetEnv(cfg, seed)
    base_env.action_space.seed(seed)
    return TimeLimit(base_env, base_env.unwrapped._max_episode_length)


def collect_observations(config, seed):
    """Run one episode with randomly sampled actions and return its renders.

    A fresh environment is created for the episode and is always closed
    before returning, so worker processes that run many episodes do not
    accumulate open environments.

    Args:
        config: Path to the environment config, forwarded to ``make_env``.
        seed: Seed forwarded to ``make_env``.

    Returns:
        A list with the result of ``env.render()`` after the reset and after
        every step, in order (at least two entries).
    """
    env = make_env(config, seed)
    try:
        env.reset()
        episode_observations = [env.render()]
        done = False
        while not done:
            # Only the termination flag matters here; the stored data comes
            # from render(), not from the step observation.
            _, _, done, _ = env.step(env.action_space.sample())
            episode_observations.append(env.render())
    finally:
        # Release the environment even if reset/step/render raised.
        env.close()

    return episode_observations


if __name__ == '__main__':
    # Script entry point: collect `--dataset_size` rendered frames from
    # random-action episodes run in a process pool, then write them as
    # zero-padded PNG files into `--dataset_path`.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_env_path', type=str, required=True)
    parser.add_argument('--max_workers', type=int, default=1)
    parser.add_argument('--dataset_size', type=int, default=110000)
    parser.add_argument('--dataset_path', type=str, required=True)
    args = parser.parse_args()

    observations = []
    seed = 0
    executor = concurrent.futures.ProcessPoolExecutor(max_workers=args.max_workers)
    try:
        futures = collections.deque()
        print('Collect dataset')
        # Prime the pool with one episode per worker; each episode gets its
        # own consecutive seed.
        for _ in range(args.max_workers):
            futures.append(executor.submit(collect_observations, args.config_env_path, seed))
            seed += 1

        with tqdm(total=args.dataset_size) as pbar:
            while len(observations) < args.dataset_size:
                # Consume results in submission (seed) order so the dataset
                # content does not depend on which worker finishes first.
                episode_observations = futures.popleft().result()

                # Keep only what is still needed, so neither the list nor
                # the progress bar runs past the requested size.
                remaining = args.dataset_size - len(observations)
                kept = episode_observations[:remaining]
                observations.extend(kept)
                pbar.update(len(kept))

                # Refill the queue only while more data is required.
                if len(observations) < args.dataset_size:
                    futures.append(executor.submit(collect_observations, args.config_env_path, seed))
                    seed += 1
    finally:
        # Stop scheduling work as soon as collection ends (or fails) so that
        # queued episodes are dropped instead of running during the save.
        executor.shutdown(wait=False, cancel_futures=True)

    Path(args.dataset_path).mkdir(parents=True, exist_ok=True)
    print('Save dataset')
    # Wrapping the list (not the enumerate object) lets tqdm know the total.
    for i, observation in enumerate(tqdm(observations)):
        path = os.path.join(args.dataset_path, f'{i:06d}.png')
        # Frames are converted RGB -> BGR because cv2.imwrite expects BGR.
        # imwrite reports failure through its return value, so check it
        # rather than silently producing an incomplete dataset.
        if not cv2.imwrite(path, cv2.cvtColor(observation, cv2.COLOR_RGB2BGR)):
            raise OSError(f'Failed to write image: {path}')







