import numpy as np

import embodied
from . import base


class BatchEnv(base.Env):
  """Steps a list of unbatched environments together as one batched env.

  Actions and observations are dicts of arrays whose leading axis indexes the
  sub-environments. When ``parallel`` is true, each ``env.step`` call returns
  a callable (promise) that is invoked to obtain the observation.
  """

  def __init__(self, envs, parallel):
    """Wrap ``envs``.

    Args:
      envs: non-empty list of environments; each must report
        ``len(env) == 0`` (presumably: not itself batched).
      parallel: whether sub-env ``step`` calls return promises that must be
        resolved by calling them.
    """
    assert all(len(env) == 0 for env in envs)
    assert len(envs) > 0
    self._envs = envs
    self._parallel = parallel
    self._keys = list(self.obs_space.keys())

  @property
  def obs_space(self):
    # All sub-envs are assumed to share the first env's observation space.
    return self._envs[0].obs_space

  @property
  def act_space(self):
    # All sub-envs are assumed to share the first env's action space.
    return self._envs[0].act_space

  def __len__(self):
    return len(self._envs)

  def step(self, action):
    """Step every sub-environment with its slice of ``action``.

    Args:
      action: dict mapping action keys to arrays whose first axis has length
        ``len(self)``; entry ``i`` is sent to sub-env ``i``.

    Returns:
      ``(obs, infos)``: ``obs`` maps each observation key to an array stacking
      the sub-env observations; ``infos`` maps each key of the first sub-env's
      ``info`` to an array stacking that entry across sub-envs.
    """
    assert all(len(v) == len(self._envs) for v in action.values()), (
        len(self._envs), {k: v.shape for k, v in action.items()})
    obs = [
        env.step({k: v[i] for k, v in action.items()})
        for i, env in enumerate(self._envs)]
    if self._parallel:
      obs = [ob() for ob in obs]
    # Read ``info`` only after every step has completed: in parallel mode the
    # steps above are only guaranteed finished once their promises resolve,
    # so reading earlier may observe pre-step info.
    infos = [env.info for env in self._envs]
    if self._parallel:
      # NOTE(review): assumes attribute reads on parallel envs may return
      # promises like step() does; plain dicts pass through unchanged.
      infos = [info() if callable(info) else info for info in infos]
    batched_infos = {
        k: np.array([info[k] for info in infos]) for k in infos[0]}
    return {k: np.array([ob[k] for ob in obs]) for k in obs[0]}, batched_infos

  def render(self):
    """Return the sub-env renders stacked along a new leading axis."""
    return np.stack([env.render() for env in self._envs])

  def close(self):
    """Close every sub-env, best-effort.

    Errors are deliberately ignored so that one failing env does not prevent
    the remaining envs from being closed.
    """
    for env in self._envs:
      try:
        env.close()
      except Exception:
        pass


class OrderedSlotsBatchWrapper(base.Env):
  """Batched-env wrapper that replaces image observations with slot vectors.

  Each step, the wrapped env's 'image' is fed to ``slots_extractor`` and the
  resulting slots are emitted, flattened, under 'vector'. When ``is_ordered``
  is set, the slots are permuted so that output position ``i`` holds the slot
  that best overlaps ground-truth mask ``i`` (from the env's 'masks').
  """

  def __init__(self, batch_env, slots_extractor, is_ordered=True, return_image=False):
    """Wrap ``batch_env``.

    Args:
      batch_env: batched env whose observations contain 'image', 'masks' and
        'is_first'.
      slots_extractor: object providing ``get_slots_dim()`` (slot array shape,
        first entry = number of slots) and callable as
        ``slots_extractor(images, prev_slots=...)`` returning
        ``(slots, masks_predict)``.
      is_ordered: whether to reorder slots to follow the ground-truth masks.
      return_image: whether to also return the raw image under 'raw'.
    """
    self._batch_env = batch_env
    self._slots_extractor = slots_extractor
    self._is_ordered = is_ordered
    self._return_image = return_image
    # Slots from the previous step, in the extractor's own (unreordered) order.
    self._previous_slots = np.zeros((len(self), *self._slots_extractor.get_slots_dim()), dtype=np.float32)

  @property
  def obs_space(self):
    """Wrapped obs space with 'image'/'masks' replaced by a flat 'vector'.

    NOTE(review): the 'raw' key emitted when ``return_image`` is set is not
    advertised here; confirm whether consumers need it in the space.
    """
    obs_space = dict(self._batch_env.obs_space)
    del obs_space['image']
    # step() pops 'masks' from every observation, so never advertise it.
    obs_space.pop('masks', None)
    obs_space['vector'] = embodied.Space(np.float32, int(np.prod(self._slots_extractor.get_slots_dim())))
    return obs_space

  @property
  def act_space(self):
    return self._batch_env.act_space

  def __len__(self):
    return len(self._batch_env)

  @staticmethod
  def match(masks, masks_predict):
    """Assign predicted slots to ground-truth object positions.

    Greedily, for each ground-truth mask ``i`` in order, picks the
    not-yet-assigned slot ``j`` maximizing the fraction of the mask's area
    covered by ``masks_predict[j]``. Positions left unassigned (empty masks,
    masks no remaining slot overlaps, or no mask at all) receive the leftover
    slots in ascending order.

    Args:
      masks: sequence of ground-truth masks for one env; ``masks[i]`` is used
        as an index into each predicted mask (presumably a boolean array;
        TODO confirm, since a non-boolean mask would act as integer indices).
      masks_predict: array whose first axis indexes slots.

    Returns:
      List of length ``masks_predict.shape[0]`` where entry ``i`` is the slot
      index placed at output position ``i``; every slot index appears once.
    """
    num_slots = masks_predict.shape[0]
    free_positions = set(range(num_slots))
    free_slots = set(range(num_slots))
    matching = [None] * num_slots
    # Ground-truth objects beyond the number of slots have no output position.
    for i in range(min(len(masks), num_slots)):
      mask = masks[i]
      area = mask.sum()
      if area == 0:
        continue

      best_index = None
      best_score = 0
      # Ascending order keeps tie-breaking explicit: the first strictly
      # better score wins.
      for j in sorted(free_slots):
        score = masks_predict[j][mask].sum() / area
        if score > best_score:
          best_index = j
          best_score = score

      if best_index is None:
        # No remaining slot overlaps this object; leave position i to be
        # filled with a leftover slot below instead of crashing.
        continue
      matching[i] = best_index
      free_positions.remove(i)
      free_slots.remove(best_index)

    # Both sets always have equal size, so this completes the permutation.
    for i, j in zip(sorted(free_positions), sorted(free_slots)):
      matching[i] = j

    return matching

  def step(self, action):
    """Step the batched env and replace images with slot vectors.

    Pops 'image' and 'masks' from the batched observation and runs the slots
    extractor. Envs whose ``is_first`` entry is set start with
    ``prev_slots=None``; the others are conditioned on the previous slots.
    When ``is_ordered`` is set, the slots are reordered via :meth:`match`. The
    flattened slots are stored under 'vector'. When ``return_image`` is set,
    the popped image is returned under 'raw'.

    Returns:
      ``(obs, infos)`` from the wrapped env, with ``obs`` modified as above.
    """
    obs, infos = self._batch_env.step(action)
    image = obs.pop('image')
    batch_masks = obs.pop('masks')
    if self._return_image:
      obs['raw'] = image
    is_first = obs['is_first']
    slots = np.zeros_like(self._previous_slots)
    # Predicted masks share the ground-truth mask shape except along axis 1,
    # which has one entry per slot.
    masks_predict_shape = list(batch_masks.shape)
    masks_predict_shape[1] = self._slots_extractor.get_slots_dim()[0]
    batch_masks_predict = np.zeros(masks_predict_shape, dtype=np.float32)
    if is_first.any():
      slots[is_first], batch_masks_predict[is_first] = self._slots_extractor(
          image[is_first], prev_slots=None)

    if not is_first.all():
      slots[~is_first], batch_masks_predict[~is_first] = self._slots_extractor(
          image[~is_first], prev_slots=self._previous_slots[~is_first])

    # Keep the extractor's own slot order for conditioning the next step;
    # only the emitted observation is reordered.
    self._previous_slots = slots.copy()
    if self._is_ordered:
      matching = np.asarray([
          self.match(masks, masks_predict)
          for masks, masks_predict in zip(batch_masks, batch_masks_predict)])
      slots = np.take_along_axis(slots, matching[..., np.newaxis], axis=1)

    obs['vector'] = slots.reshape((slots.shape[0], -1))

    return obs, infos

  def render(self):
    return self._batch_env.render()

  def close(self):
    self._batch_env.close()
