from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import os.path as osp
import PIL
# from model.utils.cython_bbox import bbox_overlaps
import numpy as np
import scipy.sparse
from model.utils.config import cfg
import pdb

# Directory two levels above this module's directory (presumably the
# project root).
ROOT_DIR = osp.join(osp.dirname(__file__), os.pardir, os.pardir)

class imdb(object):
  """Image database with object, region and relationship annotations.

  Subclasses supply the dataset-specific pieces (``image_path_at``,
  ``image_id_at``, ``default_roidb``, ``evaluate_detections``); this base
  class provides lazy roidb construction, horizontal-flip augmentation,
  proposal-roidb creation and roidb merging.

  Each roidb entry handled here is a dict with three sub-dicts:
    obj_anno:      boxes, gt_classes, ids, gt_overlaps, seg_areas,
                   gt_ishards, flipped
    region_anno:   boxes, gt_classes, gt_overlaps, seg_areas, gt_ishards,
                   flipped
    relation_anno: gt_classes, subject_ids, object_ids, flipped
  """

  def __init__(self, name, obj_classes=None, region_classes=None, relation_classes=None):
    """Create an empty database.

    Args:
      name: dataset name, exposed through the ``name`` property.
      obj_classes: object class names; a falsy value becomes ``[]``.
      region_classes: region class names; a falsy value becomes ``[]``.
      relation_classes: relation class names; a falsy value becomes ``[]``.
    """
    self._name = name
    # Kept for backward compatibility only: the num_*_classes properties
    # derive their values from the class lists, not from these fields.
    self._num_obj_classes = 0
    self._num_region_classes = 0
    self._num_relation_classes = 0

    self._obj_classes = obj_classes or []
    self._region_classes = region_classes or []
    self._relation_classes = relation_classes or []

    self._image_index = []
    self._obj_proposer = 'gt'
    # The roidb is built lazily by the ``roidb`` property via this handler.
    self._roidb = None
    self._roidb_handler = self.default_roidb
    # Use this dict for storing dataset specific config options
    self.config = {}

  @property
  def name(self):
    """Dataset name given at construction."""
    return self._name

  @property
  def num_obj_classes(self):
    """Number of object classes (length of ``obj_classes``)."""
    return len(self._obj_classes)

  @property
  def num_region_classes(self):
    """Number of region classes (length of ``region_classes``)."""
    return len(self._region_classes)

  @property
  def num_relation_classes(self):
    """Number of relation classes (length of ``relation_classes``)."""
    return len(self._relation_classes)

  @property
  def obj_classes(self):
    """Object class names."""
    return self._obj_classes

  @property
  def region_classes(self):
    """Region class names."""
    return self._region_classes

  @property
  def relation_classes(self):
    """Relation class names."""
    return self._relation_classes

  @property
  def image_index(self):
    """List of image identifiers; its length defines ``num_images``."""
    return self._image_index

  @property
  def roidb_handler(self):
    """Zero-argument callable that builds the roidb."""
    return self._roidb_handler

  @roidb_handler.setter
  def roidb_handler(self, val):
    self._roidb_handler = val

  def set_proposal_method(self, method):
    """Use ``self.<method>_roidb`` (e.g. ``'gt'`` -> ``gt_roidb``) as handler.

    Raises:
      AttributeError: if no ``<method>_roidb`` attribute exists.
    """
    # getattr performs the same attribute lookup as the former eval() call
    # without evaluating arbitrary expression text.
    self.roidb_handler = getattr(self, method + '_roidb')

  @property
  def roidb(self):
    """The roidb list, built by ``roidb_handler`` on first access and cached."""
    # A roidb contains three dictionaries for objects, regions and relationships, respectively
    # objects: {boxes, gt_classes, ids, gt_overlaps, seg_areas, gt_ishards, flipped}
    # regions: {boxes, gt_classes, gt_overlaps, seg_areas, gt_ishards, flipped}
    # relationships: {gt_classes, subject_ids, object_ids, flipped}
    if self._roidb is not None:
      return self._roidb
    self._roidb = self.roidb_handler()
    return self._roidb

  @property
  def cache_path(self):
    """Absolute ``<cfg.DATA_DIR>/cache`` directory, created if missing."""
    cache_path = osp.abspath(osp.join(cfg.DATA_DIR, 'cache'))
    if not os.path.exists(cache_path):
      os.makedirs(cache_path)
    return cache_path

  @property
  def num_images(self):
    """Number of images (length of ``image_index``)."""
    return len(self.image_index)

  def image_path_at(self, i):
    """Return the file path of image ``i``; subclasses must implement."""
    raise NotImplementedError

  def image_id_at(self, i):
    """Return the id of image ``i``; subclasses must implement."""
    raise NotImplementedError

  def default_roidb(self):
    """Build the default roidb; subclasses must implement."""
    raise NotImplementedError

  def evaluate_detections(self, all_boxes, output_dir=None):
    """
    all_boxes is a list of length number-of-classes.
    Each list element is a list of length number-of-images.
    Each of those list elements is either an empty list []
    or a numpy array of detection.

    all_boxes[class][image] = [] or np.array of shape #dets x 5
    """
    raise NotImplementedError

  def _get_widths(self):
    """Return the pixel width of every image, in ``image_index`` order."""
    # ``import PIL`` alone does not load the ``PIL.Image`` submodule, so
    # import it explicitly; the ``with`` block closes each file handle.
    from PIL import Image
    widths = []
    for i in range(self.num_images):
      with Image.open(self.image_path_at(i)) as img:
        widths.append(img.size[0])
    return widths

  @staticmethod
  def _flip_boxes(boxes, width):
    """Return a horizontally mirrored copy of ``boxes`` within ``width``.

    Columns 0 and 2 are treated as x1 and x2; they are mapped to
    ``width - x2`` and ``width - x1`` so that x2 >= x1 still holds.
    """
    flipped = boxes.copy()
    flipped[:, 0] = width - boxes[:, 2]
    flipped[:, 2] = width - boxes[:, 0]
    assert (flipped[:, 2] >= flipped[:, 0]).all()
    return flipped

  def append_flipped_images(self):
    """Append a horizontally flipped copy of every image's annotations.

    Object and region boxes are mirrored with ``_flip_boxes``. All other
    fields are shared by reference with the source entry, and each sub-dict
    of the new entry has ``flipped=True``. The new entries are appended in
    image order, and ``image_index`` is doubled to match.
    """
    num_images = self.num_images
    widths = self._get_widths()
    roidb = self.roidb

    for i in range(num_images):
      obj_src = roidb[i]['obj_anno']
      region_src = roidb[i]['region_anno']
      relation_src = roidb[i]['relation_anno']

      obj_anno = {'boxes': self._flip_boxes(obj_src['boxes'], widths[i]),
                  'gt_classes': obj_src['gt_classes'],
                  'ids': obj_src['ids'],
                  'gt_overlaps': obj_src['gt_overlaps'],
                  'seg_areas': obj_src['seg_areas'],
                  'gt_ishards': obj_src['gt_ishards'],
                  'flipped': True}
      region_anno = {'boxes': self._flip_boxes(region_src['boxes'], widths[i]),
                     'gt_classes': region_src['gt_classes'],
                     'gt_overlaps': region_src['gt_overlaps'],
                     'seg_areas': region_src['seg_areas'],
                     'gt_ishards': region_src['gt_ishards'],
                     'flipped': True}
      # Relations reference objects by id only, so nothing needs mirroring.
      relation_anno = {'gt_classes': relation_src['gt_classes'],
                       'subject_ids': relation_src['subject_ids'],
                       'object_ids': relation_src['object_ids'],
                       'flipped': True}

      roidb.append({'obj_anno': obj_anno, 'region_anno': region_anno,
                    'relation_anno': relation_anno})

    self._image_index = self._image_index * 2

  @staticmethod
  def _bbox_overlaps(boxes, query_boxes):
    """Return the (N, K) IoU matrix between ``boxes`` and ``query_boxes``.

    Inputs are 2-D arrays whose columns 0-3 are x1, y1, x2, y2. Widths and
    heights are computed inclusively (``x2 - x1 + 1``). Pairs that do not
    intersect get 0.

    NOTE(review): intended as a drop-in replacement for the
    ``bbox_overlaps`` whose import is commented out at the top of the file.
    Confirm that implementation uses the same inclusive (+1) convention.
    """
    boxes = np.asarray(boxes, dtype=np.float64)
    query_boxes = np.asarray(query_boxes, dtype=np.float64)
    box_areas = ((boxes[:, 2] - boxes[:, 0] + 1) *
                 (boxes[:, 3] - boxes[:, 1] + 1))
    query_areas = ((query_boxes[:, 2] - query_boxes[:, 0] + 1) *
                   (query_boxes[:, 3] - query_boxes[:, 1] + 1))
    # (N, 1) columns against (K,) rows broadcast to (N, K).
    iw = (np.minimum(boxes[:, 2:3], query_boxes[:, 2]) -
          np.maximum(boxes[:, 0:1], query_boxes[:, 0]) + 1)
    ih = (np.minimum(boxes[:, 3:4], query_boxes[:, 3]) -
          np.maximum(boxes[:, 1:2], query_boxes[:, 1]) + 1)
    inter = np.maximum(iw, 0) * np.maximum(ih, 0)
    union = box_areas[:, None] + query_areas[None, :] - inter
    # Divide only where boxes actually intersect; all other pairs stay 0.
    return np.divide(inter, union, out=np.zeros_like(inter), where=inter > 0)

  def create_roidb_from_box_list(self, box_list, gt_roidb):
    """Build a proposal roidb from one box array per image.

    Args:
      box_list: sequence of per-image box arrays, ``len == num_images``.
        Columns 0-3 are read as x1, y1, x2, y2.
      gt_roidb: ground-truth roidb aligned with ``box_list``, or None.

    Returns:
      A list with one entry per image. ``obj_anno['gt_overlaps']`` is a
      sparse (num_boxes, num_obj_classes) matrix. For each proposal it
      stores its best IoU with any ground-truth box, placed in that gt
      box's class column. ``region_anno`` and ``relation_anno`` are the gt
      entry's dicts (shared by reference), or None when ``gt_roidb`` is None.
    """
    assert len(box_list) == self.num_images, \
      'Number of boxes must match number of ground-truth images'
    roidb = []
    for i in range(self.num_images):
      boxes = box_list[i]
      num_boxes = boxes.shape[0]
      overlaps = np.zeros((num_boxes, self.num_obj_classes), dtype=np.float32)
      gt_entry = gt_roidb[i] if gt_roidb is not None else None

      if gt_entry is not None and gt_entry['obj_anno']['boxes'].size > 0:
        gt_boxes = gt_entry['obj_anno']['boxes']
        gt_classes = gt_entry['obj_anno']['gt_classes']
        # float64 (the type the removed ``np.float`` alias stood for).
        gt_overlaps = self._bbox_overlaps(boxes.astype(np.float64),
                                          gt_boxes.astype(np.float64))
        argmaxes = gt_overlaps.argmax(axis=1)
        maxes = gt_overlaps.max(axis=1)
        pos = np.where(maxes > 0)[0]
        overlaps[pos, gt_classes[argmaxes[pos]]] = maxes[pos]

      overlaps = scipy.sparse.csr_matrix(overlaps)
      # NOTE(review): unlike gt entries, this obj_anno has no 'ids' or
      # 'gt_ishards' keys, which merge_roidbs and append_flipped_images read.
      obj_anno = {
        'boxes': boxes,
        'gt_classes': np.zeros((num_boxes,), dtype=np.int32),
        'gt_overlaps': overlaps,
        'flipped': False,
        'seg_areas': np.zeros((num_boxes,), dtype=np.float32)}
      region_anno = gt_entry['region_anno'] if gt_entry is not None else None
      relation_anno = gt_entry['relation_anno'] if gt_entry is not None else None
      roidb.append({'obj_anno': obj_anno, 'region_anno': region_anno, 'relation_anno': relation_anno})

    return roidb

  @staticmethod
  def merge_roidbs(a, b):
    """Concatenate the annotations of ``b`` onto ``a``, entry by entry.

    ``a`` is modified in place and returned. Array fields are stacked, and
    sparse fields are stacked with ``scipy.sparse.vstack``. The scalar
    ``flipped`` flags are left as in ``a``.

    NOTE(review): entries from ``create_roidb_from_box_list`` lack
    'ids'/'gt_ishards' (KeyError here). They also share their
    region/relation dicts with the gt roidb, so merging such a pair would
    duplicate those annotations. Confirm intended usage.
    """
    assert len(a) == len(b)
    for a_entry, b_entry in zip(a, b):
      # merge object annotations
      a_obj, b_obj = a_entry['obj_anno'], b_entry['obj_anno']
      a_obj['boxes'] = np.vstack((a_obj['boxes'], b_obj['boxes']))
      a_obj['gt_classes'] = np.hstack((a_obj['gt_classes'], b_obj['gt_classes']))
      a_obj['ids'] = np.vstack((a_obj['ids'], b_obj['ids']))
      a_obj['gt_overlaps'] = scipy.sparse.vstack([a_obj['gt_overlaps'], b_obj['gt_overlaps']])
      a_obj['seg_areas'] = np.hstack((a_obj['seg_areas'], b_obj['seg_areas']))
      a_obj['gt_ishards'] = scipy.sparse.vstack([a_obj['gt_ishards'], b_obj['gt_ishards']])

      # merge region annotations
      a_reg, b_reg = a_entry['region_anno'], b_entry['region_anno']
      a_reg['boxes'] = np.vstack((a_reg['boxes'], b_reg['boxes']))
      a_reg['gt_classes'] = np.hstack((a_reg['gt_classes'], b_reg['gt_classes']))
      a_reg['gt_overlaps'] = scipy.sparse.vstack([a_reg['gt_overlaps'], b_reg['gt_overlaps']])
      a_reg['seg_areas'] = np.hstack((a_reg['seg_areas'], b_reg['seg_areas']))
      a_reg['gt_ishards'] = scipy.sparse.vstack([a_reg['gt_ishards'], b_reg['gt_ishards']])

      # merge relation annotations
      a_rel, b_rel = a_entry['relation_anno'], b_entry['relation_anno']
      a_rel['gt_classes'] = np.hstack((a_rel['gt_classes'], b_rel['gt_classes']))
      a_rel['subject_ids'] = np.hstack((a_rel['subject_ids'], b_rel['subject_ids']))
      a_rel['object_ids'] = np.hstack((a_rel['object_ids'], b_rel['object_ids']))

      # 'flipped' is a per-image bool flag; stacking it would turn it into
      # an array (breaking ``if entry['flipped']``), so a's value is kept.

    return a

  def competition_mode(self, on):
    """Turn competition mode on or off."""
    pass
