import collections
import concurrent.futures
import multiprocessing
import os.path

import cv2
from tqdm import tqdm

import embodied
import numpy as np


def store_observations(observations, episode_dir_path):
  """Write each observation as `<episode_dir_path>/<i>.png` and return how many were written.

  Submitted as a process-pool task by `collect_observations`. Each observation is
  treated as an RGB image and converted to BGR before encoding, because that is the
  channel order `cv2.imwrite` expects.

  Args:
    observations: Sequence of RGB images (anything `cv2.cvtColor` accepts).
    episode_dir_path: Directory to create for this episode's frames.

  Returns:
    The number of observations written (`len(observations)`).

  Raises:
    FileExistsError: If `episode_dir_path` already exists (created with `exist_ok=False`
      so that an earlier run's frames are never silently mixed in).
    OSError: If OpenCV fails to encode or write a frame.
  """
  os.makedirs(episode_dir_path, exist_ok=False)
  for index, observation in enumerate(observations):
    frame_path = os.path.join(episode_dir_path, f'{index}.png')
    bgr_observation = cv2.cvtColor(observation, cv2.COLOR_RGB2BGR)
    # imwrite signals failure by returning False instead of raising; without this check
    # the caller would count frames as stored that never reached disk.
    if not cv2.imwrite(frame_path, bgr_observation):
      raise OSError(f'Failed to write observation to {frame_path}')

  return len(observations)


def collect_observations(agent, env, args, checkpoint_path, dataset_size, n_samples, output_path, max_workers):
  """Roll out a checkpointed agent and save sampled observations to disk.

  Each finished episode contributes at most `n_samples` observations, drawn without
  replacement from `ep['source_observation']`, until `dataset_size` observations have
  been submitted in total. Observations of the k-th contributing episode are written by
  a process pool into `<output_path>/<k>/` using `store_observations`.

  Args:
    agent: Agent whose `policy(..., mode='eval')` drives the environment; its state is
      loaded through `embodied.Checkpoint` from `args.from_checkpoint`.
    env: Environment stepped by `embodied.Driver`.
    args: Config object; `args.seed` seeds the sampler and `args.from_checkpoint` names
      the checkpoint to load.
    checkpoint_path: Directory for the `embodied.Checkpoint` (`<checkpoint_path>/checkpoint.ckpt`).
    dataset_size: Total number of observations to collect.
    n_samples: Maximum number of observations taken from a single episode.
    output_path: Root directory for the per-episode image folders.
    max_workers: Number of writer processes.

  Raises:
    ValueError: If `n_samples < 1` while `dataset_size > 0` (the loop could never finish).
  """
  if dataset_size > 0 and n_samples < 1:
    raise ValueError(f'n_samples must be positive, got {n_samples}')

  checkpoint_path = embodied.Path(checkpoint_path)
  print('Observation space:', embodied.format(env.obs_space), sep='\n')
  print('Action space:', embodied.format(env.act_space), sep='\n')

  timer = embodied.Timer()
  timer.wrap('agent', agent, ['policy'])
  timer.wrap('env', env, ['step'])

  n_episodes = 0
  n_submitted_observations = 0
  rng = np.random.default_rng(args.seed)
  # Pending write jobs, oldest first, so finished ones can be drained from the left.
  futures = collections.deque()

  executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers,
                                                    mp_context=multiprocessing.get_context('forkserver'))
  pbar_submitted_observations = tqdm(total=dataset_size, position=0, desc='Collected observations')
  pbar_stored_observations = tqdm(total=dataset_size, position=1, desc='Stored observations')

  def drain_finished():
    # Collect results of completed writes (re-raising worker errors) without blocking.
    while futures and futures[0].done():
      pbar_stored_observations.update(futures.popleft().result())

  def per_episode(ep):
    nonlocal n_episodes, n_submitted_observations

    episode_observations = ep['source_observation']
    # Sampling without replacement cannot exceed the episode length; capping here
    # keeps short episodes from raising ValueError inside rng.choice.
    size = min(n_samples, dataset_size - n_submitted_observations, len(episode_observations))
    if size > 0:
      observations = rng.choice(episode_observations, size=size, replace=False)
      episode_dir_path = os.path.join(output_path, str(n_episodes))
      n_submitted_observations += len(observations)
      n_episodes += 1
      pbar_submitted_observations.update(len(observations))
      futures.append(executor.submit(store_observations, observations, episode_dir_path))

    drain_finished()

  try:
    driver = embodied.Driver(env)
    driver.on_episode(lambda ep, worker: per_episode(ep))

    checkpoint = embodied.Checkpoint(checkpoint_path / 'checkpoint.ckpt')
    checkpoint.agent = agent
    checkpoint.load(args.from_checkpoint, keys=['agent'])

    print('Start observation collection loop.')
    policy = lambda *policy_args: agent.policy(*policy_args, mode='eval')
    while n_submitted_observations < dataset_size:
      driver(policy, steps=100)

    # Block until every outstanding write finishes, surfacing any worker error.
    while futures:
      pbar_stored_observations.update(futures.popleft().result())
  finally:
    pbar_submitted_observations.close()
    pbar_stored_observations.close()
    # On the success path all futures are already resolved, so not waiting is safe.
    executor.shutdown(wait=False)
