"""Compute minibatch blobs for training a Fast R-CNN network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import numpy.random as npr
from scipy.misc import imread
from model.utils.config import cfg
from model.utils.blob import prep_im_for_blob, im_list_to_blob
import pdb
def _select_gt_inds(anno):
  """Return the row indices of `anno` that are kept as ground truth.

  Rows whose class label is 0 are always dropped. When
  cfg.TRAIN.USE_ALL_GT is false, a row is additionally required to have
  every entry of its gt_overlaps row greater than -1.0.
  """
  use_all_gt = cfg.TRAIN.USE_ALL_GT
  keep = anno['gt_classes'] != 0
  if not use_all_gt:
    # For the COCO ground truth boxes, exclude the ones that are ''iscrowd''
    keep = keep & np.all(anno['gt_overlaps'].toarray() > -1.0, axis=1)
  return np.where(keep)[0]


def get_minibatch(roidb, num_classes):
  """Assemble the blob dictionary for a minibatch drawn from `roidb`.

  Only a single-image roidb is supported (enforced by assertions).
  `num_classes` is accepted but not used by this function.

  The returned dict contains, in this order:
    'data'             image blob produced by _get_image_blob
    'obj_gt_boxes'     float32 (N, 6): x1, y1, x2, y2, cls, id
    'region_gt_boxes'  float32 (M, 5): x1, y1, x2, y2, cls
    'relation_triples' float32 (R, 3): relation_cls, subject_id, object_id
    'im_info'          float32 (1, 3): im_blob.shape[1], im_blob.shape[2],
                       image scale
    'img_id'           copied from roidb[0]['img_id']
  Box coordinates are multiplied by the image scale.
  """
  num_images = len(roidb)
  # One randomly chosen index into cfg.TRAIN.SCALES per image.
  scale_choices = npr.randint(
    0, high=len(cfg.TRAIN.SCALES), size=num_images)
  assert cfg.TRAIN.BATCH_SIZE % num_images == 0, (
    'num_images ({}) must divide BATCH_SIZE ({})'.format(
      num_images, cfg.TRAIN.BATCH_SIZE))

  # Load and rescale the image(s).
  im_blob, im_scales = _get_image_blob(roidb, scale_choices)
  blobs = {'data': im_blob}

  assert len(im_scales) == 1, "Single batch only"
  assert len(roidb) == 1, "Single batch only"

  entry = roidb[0]
  scale = im_scales[0]

  # Object ground truth, one row per kept annotation.
  obj_anno = entry['obj_anno']
  obj_inds = _select_gt_inds(obj_anno)
  obj_rows = np.empty((len(obj_inds), 6), dtype=np.float32)
  obj_rows[:, :4] = obj_anno['boxes'][obj_inds, :] * scale
  obj_rows[:, 4] = obj_anno['gt_classes'][obj_inds]
  obj_rows[:, 5] = obj_anno['ids'][obj_inds]
  blobs['obj_gt_boxes'] = obj_rows

  # Region ground truth: same layout as objects but without the id column.
  region_anno = entry['region_anno']
  region_inds = _select_gt_inds(region_anno)
  region_rows = np.empty((len(region_inds), 5), dtype=np.float32)
  region_rows[:, :4] = region_anno['boxes'][region_inds, :] * scale
  region_rows[:, 4] = region_anno['gt_classes'][region_inds]
  blobs['region_gt_boxes'] = region_rows

  # Relationship ground truth; every relation is kept (no filtering).
  relation_anno = entry['relation_anno']
  triples = np.empty((len(relation_anno['gt_classes']), 3), dtype=np.float32)
  triples[:, 0] = relation_anno['gt_classes'][:]
  triples[:, 1] = relation_anno['subject_ids'][:]
  triples[:, 2] = relation_anno['object_ids'][:]
  blobs['relation_triples'] = triples

  blobs['im_info'] = np.array(
    [[im_blob.shape[1], im_blob.shape[2], scale]], dtype=np.float32)
  blobs['img_id'] = entry['img_id']

  return blobs

def _get_image_blob(roidb, scale_inds):
  """Builds an input blob from the images in the roidb at the specified
  scales.

  Args:
    roidb: sequence of entries. For each entry, entry['image'] is passed
      to imread and entry['obj_anno']['flipped'] selects horizontal
      mirroring of the loaded image.
    scale_inds: per-image indices into cfg.TRAIN.SCALES.

  Returns:
    (blob, im_scales): the output of im_list_to_blob over the processed
    images, and the list of per-image scale factors returned by
    prep_im_for_blob.
  """
  processed_ims = []
  im_scales = []
  for i, entry in enumerate(roidb):
    im = imread(entry['image'])

    # Normalise every decoded image to exactly three channels so the
    # channel flip below always yields a 3-channel BGR image.
    if im.ndim == 2:
      im = im[:, :, np.newaxis]
    if im.shape[2] < 3:
      # Grey (possibly with an extra alpha channel): replicate the first
      # channel three times.
      im = np.repeat(im[:, :, :1], 3, axis=2)
    elif im.shape[2] > 3:
      # e.g. RGBA: keep the first three channels and drop the rest.
      im = im[:, :, :3]

    # flip the channel, since the original one using cv2
    # rgb -> bgr
    im = im[:, :, ::-1]

    if entry['obj_anno']['flipped']:
      # Mirror left-right (reverse the column axis).
      im = im[:, ::-1, :]

    target_size = cfg.TRAIN.SCALES[scale_inds[i]]
    im, im_scale = prep_im_for_blob(im, cfg.PIXEL_MEANS, target_size,
                                    cfg.TRAIN.MAX_SIZE)
    im_scales.append(im_scale)
    processed_ims.append(im)

  # Create a blob to hold the input images
  blob = im_list_to_blob(processed_ims)

  return blob, im_scales
