import collections
import concurrent.futures
import multiprocessing
import os.path

import cv2
from tqdm import tqdm

import embodied
import numpy as np


def store_observations(observations, segmentations, episode_dir_path):
  """Write one episode's sampled frames to ``episode_dir_path`` as PNG files.

  Each observation is converted from RGB to BGR (OpenCV's channel order) and
  saved as ``obs-<i>.png``. When ``segmentations`` is given,
  ``segmentations[i]`` is saved unchanged as ``seg-<i>.png``.

  Args:
    observations: Sequence of RGB images (presumably uint8 HxWx3 -- confirm
      against the environment's observation space).
    segmentations: Sequence indexed in parallel with ``observations``, or None.
    episode_dir_path: Directory to create; it must not already exist.

  Returns:
    The number of observations written.

  Raises:
    FileExistsError: If ``episode_dir_path`` already exists.
    OSError: If OpenCV reports that an image could not be written.
  """
  os.makedirs(episode_dir_path, exist_ok=False)

  def write(path, image):
    # cv2.imwrite signals failure through its return value rather than an
    # exception; surface it so the caller's future.result() reports it instead
    # of counting a missing file as stored.
    if not cv2.imwrite(path, image):
      raise OSError(f'Failed to write image to {path}.')

  for i, observation in enumerate(observations):
    write(os.path.join(episode_dir_path, f'obs-{i}.png'),
          cv2.cvtColor(observation, cv2.COLOR_RGB2BGR))
    if segmentations is not None:
      write(os.path.join(episode_dir_path, f'seg-{i}.png'), segmentations[i])

  return len(observations)


def collect_observations(agent, env, args, checkpoint_path, dataset_size, output_path, max_workers, greedy_epsilon, n_samples=None):
  """Roll out a checkpointed agent and store sampled observations on disk.

  Episodes are generated with ``embodied.Driver``. From every finished episode,
  a random sorted subset of ``ep['source_observation']`` frames (and the
  matching ``ep['segmentation']`` frames, if that key is present) is handed to
  ``store_observations`` in a process pool. Each episode's frames are written
  to ``<output_path>/<episode_index>/``. Collection stops once
  ``dataset_size`` observations have been submitted. The function returns
  after every submitted write has completed.

  Args:
    agent: Agent whose ``policy`` is called with ``mode='eval'``; its state is
      restored from the checkpoint under the ``'agent'`` key.
    env: Environment passed to ``embodied.Driver``.
    args: Config object; ``args.seed`` seeds the frame sampler and
      ``args.from_checkpoint`` is passed to ``Checkpoint.load``.
    checkpoint_path: Directory containing ``checkpoint.ckpt``.
    dataset_size: Total number of observations to store.
    output_path: Root directory for the per-episode subdirectories.
    max_workers: Number of worker processes used for writing images.
    greedy_epsilon: Forwarded to ``embodied.Driver``.
    n_samples: Frames to sample per episode; ``None`` keeps every frame.
      Values larger than an episode's length are clamped to that length.

  Raises:
    ValueError: If ``n_samples`` is given and is not positive (sampling zero
      frames per episode would never reach ``dataset_size``).
  """
  if n_samples is not None and n_samples < 1:
    raise ValueError(f'n_samples must be a positive integer or None, got {n_samples}.')

  checkpoint_path = embodied.Path(checkpoint_path)
  print('Observation space:', embodied.format(env.obs_space), sep='\n')
  print('Action space:', embodied.format(env.act_space), sep='\n')

  timer = embodied.Timer()
  timer.wrap('agent', agent, ['policy'])
  timer.wrap('env', env, ['step'])

  n_episodes = 0
  n_submitted_observations = 0
  rng = np.random.default_rng(args.seed)
  executor = concurrent.futures.ProcessPoolExecutor(
      max_workers=max_workers,
      mp_context=multiprocessing.get_context('forkserver'))
  # Pending writes, in submission order.
  futures = collections.deque()
  pbar_submitted_observations = tqdm(total=dataset_size, position=0, desc='Collected observations')
  pbar_stored_observations = tqdm(total=dataset_size, position=1, desc='Stored observations')

  try:
    def per_episode(ep):
      # Only the counters are rebound here; the deque, executor and progress
      # bars are merely mutated and need no nonlocal declaration.
      nonlocal n_episodes, n_submitted_observations

      if n_submitted_observations < dataset_size:
        episode_dir_path = os.path.join(output_path, str(n_episodes))
        episode_observations = ep['source_observation']
        episode_length = episode_observations.shape[0]
        # Sampling without replacement cannot draw more frames than exist.
        if n_samples is None:
          size = episode_length
        else:
          size = min(n_samples, episode_length)
        # Never overshoot the requested dataset size.
        size = min(size, dataset_size - n_submitted_observations)
        selected_ids = np.sort(rng.choice(episode_length, size=size, replace=False))
        observations = episode_observations[selected_ids]
        segmentations = None
        if 'segmentation' in ep:
          segmentations = ep['segmentation'][selected_ids]

        n_submitted_observations += len(observations)
        n_episodes += 1
        pbar_submitted_observations.update(len(observations))
        futures.append(executor.submit(store_observations, observations, segmentations, episode_dir_path))

      # Opportunistically collect finished writes (in order) so errors surface
      # early and the stored-progress bar stays current.
      while futures and futures[0].done():
        pbar_stored_observations.update(futures.popleft().result())

    driver = embodied.Driver(env, greedy_epsilon=greedy_epsilon)
    driver.on_episode(lambda ep, worker: per_episode(ep))

    checkpoint = embodied.Checkpoint(checkpoint_path / 'checkpoint.ckpt')
    checkpoint.agent = agent
    checkpoint.load(args.from_checkpoint, keys=['agent'])

    def policy(*policy_args):
      return agent.policy(*policy_args, mode='eval')

    print('Start observation collection loop.')
    while n_submitted_observations < dataset_size:
      driver(policy, steps=100)

    pbar_submitted_observations.close()

    # Wait for the remaining writes; result() re-raises any worker error.
    while futures:
      pbar_stored_observations.update(futures.popleft().result())
  finally:
    # tqdm.close() is safe to call more than once.
    pbar_submitted_observations.close()
    pbar_stored_observations.close()
    # On the normal path every future is already resolved, so this only joins
    # the worker processes; on error it lets in-flight writes finish cleanly.
    executor.shutdown(wait=True)
