from __future__ import print_function
from __future__ import absolute_import

import xml.dom.minidom as minidom

import os
# import PIL
import json
import numpy as np
import scipy.sparse
import subprocess
import math
import glob
import uuid
import scipy.io as sio
import xml.etree.ElementTree as ET
import pickle
from .imdb import imdb
from .imdb import ROOT_DIR
from . import ds_utils
from .voc_eval import voc_eval

# TODO: make fast_rcnn irrelevant
# >>>> obsolete, because it depends on something outside of this project
from model.utils.config import cfg

# Python 2/3 compatibility: on Python 3 the builtin `xrange` does not exist,
# so alias it to `range` (which is already lazy there).
try:
    xrange          # Python 2
except NameError:
    xrange = range  # Python 3

# <<<< obsolete

# define the Visual Genome dataset for visual-knowledge detection
class VisualGenome(imdb):
    """Visual Genome dataset for visual-knowledge detection.

    Each image carries three kinds of annotation (see
    ``_load_visual_genome_annotation``): objects, regions and relationships.

    All files are read from ``<base_dir>/Visual_Genome``. When ``process`` is
    True, every file/directory name below carries a ``_process`` suffix:

    * ``obj_num_dict[_process].json`` / ``relation_num_dict[_process].json``
      -- their keys define the object and relation vocabularies;
    * ``<image_set>[_process].txt`` -- one image index per line;
    * ``Images_<image_set>[_process]/<index>.jpg``;
    * ``Annotations_<image_set>[_process]/<index>.xml``.
    """

    def __init__(self, base_dir, image_set = 'train', process = True):
        """Load the class vocabularies and the image index of one split.

        Args:
            base_dir: directory that contains the ``Visual_Genome`` folder.
            image_set: ``'train'`` or ``'test'``.
            process: if True, use the ``*_process`` variants of all files.

        Raises:
            ValueError: if ``image_set`` is not ``'train'`` or ``'test'``.
            AssertionError: if ``<base_dir>/Visual_Genome`` does not exist.
        """
        if image_set not in ['train', 'test']:
            raise ValueError('Invalid split name')
        suffix = '_process' if process else ''
        imdb.__init__(self, 'visual_genome_' + image_set + suffix)

        self._base_dir = base_dir
        self._image_set = image_set
        self.process = process
        self._data_path = os.path.join(base_dir, 'Visual_Genome')
        # Checked before any file is opened, so a missing dataset fails with
        # this message rather than an IOError from the JSON files below.
        assert os.path.exists(self._data_path), \
            'Path does not exist: {}'.format(self._data_path)

        # Only the keys (class names) of these dictionaries are used.
        obj_dict = self._load_json('obj_num_dict' + suffix + '.json')
        relation_dict = self._load_json('relation_num_dict' + suffix + '.json')
        # Index 0 is reserved for background in the object and region
        # vocabularies; the relation vocabulary has no background entry.
        self._obj_classes = ['__background__'] + list(obj_dict.keys())
        self._region_classes = ['__background__', 'foreground']
        self._relation_classes = list(relation_dict.keys())

        self._obj_class_to_ind = {c: i for i, c in enumerate(self._obj_classes)}
        self._region_class_to_ind = {c: i for i, c in enumerate(self._region_classes)}
        self._relation_class_to_ind = {c: i for i, c in enumerate(self._relation_classes)}

        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default roidb handler: ground-truth boxes only.
        self._roidb_handler = self.gt_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = {'cleanup': True,
                       'use_salt': False,
                       'use_diff': False,
                       'matlab_eval': False,
                       'rpn_file': None,
                       'min_size': 2}

    # ------------------------------------------------------------------ paths

    def _process_suffix(self):
        """Return the suffix ('_process' or '') that selects the data variant."""
        return '_process' if self.process else ''

    def _annotation_dir(self):
        """Return the directory holding the per-image XML annotations."""
        return os.path.join(self._data_path,
                            'Annotations_' + self._image_set + self._process_suffix())

    def _load_json(self, filename):
        """Parse and return the JSON file ``filename`` from the dataset root."""
        with open(os.path.join(self._data_path, filename), 'r') as f:
            return json.load(f)

    def image_path_at(self, i):
        """Return the file path of the i-th image of the split."""
        return self.image_path_from_index(self._image_index[i])

    def image_id_at(self, i):
        """Return the id of the i-th image, which is simply its position ``i``."""
        return i

    def image_path_from_index(self, index):
        """Return the path of the image whose index string is ``index``.

        Raises:
            AssertionError: if the image file does not exist.
        """
        image_path = os.path.join(self._data_path,
                                  'Images_' + self._image_set + self._process_suffix(),
                                  index + self._image_ext)
        assert os.path.exists(image_path), \
            'Path does not exist: {}'.format(image_path)
        return image_path

    def _load_image_set_index(self):
        """Return the image indexes listed (one per line) in the split file.

        The file is ``<data_path>/<image_set>[_process].txt``.
        """
        image_set_file = os.path.join(self._data_path,
                                      self._image_set + self._process_suffix() + '.txt')
        assert os.path.exists(image_set_file), \
            'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            return [x.strip() for x in f]

    def _get_default_path(self):
        """
        Return the default path where Visual Genome is expected to be installed.
        """
        return os.path.join(cfg.DATA_DIR, 'Visual_Genome')

    # ------------------------------------------------------------------ roidbs

    def gt_roidb(self):
        """Return the ground-truth roidb (one annotation dict per image).

        The result is cached as ``<cache_path>/<name>_gt_roidb.pkl``; delete
        that file after changing the annotations.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                # `encoding=` is a Python 3-only argument of pickle.load.
                roidb = pickle.load(fid, encoding='iso-8859-1')
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        gt_roidb = [self._load_visual_genome_annotation(index)
                    for index in self.image_index]
        with open(cache_file, 'wb') as fid:
            pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))

        return gt_roidb

    def selective_search_roidb(self):
        """Return the selective-search roidb, merged with ground truth.

        Ground truth is merged in for every split except ``'test'``. The
        result is cached under ``cache_path``.
        """
        cache_file = os.path.join(self.cache_path,
                                  self.name + '_selective_search_roidb.pkl')

        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = pickle.load(fid, encoding='iso-8859-1')
            print('{} ss roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        # Visual Genome has no `_year` (that check was a PASCAL VOC 2007
        # special case), so only the split decides whether GT is merged.
        if self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            ss_roidb = self._load_selective_search_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, ss_roidb)
        else:
            roidb = self._load_selective_search_roidb(None)
        with open(cache_file, 'wb') as fid:
            pickle.dump(roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote ss roidb to {}'.format(cache_file))

        return roidb

    def rpn_roidb(self):
        """Return the RPN-proposal roidb, merged with ground truth except on 'test'."""
        if self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            rpn_roidb = self._load_rpn_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
        else:
            roidb = self._load_rpn_roidb(None)

        return roidb

    def _load_rpn_roidb(self, gt_roidb):
        """Build a roidb from the pickled box list at ``config['rpn_file']``."""
        filename = self.config['rpn_file']
        print('loading {}'.format(filename))
        assert os.path.exists(filename), \
            'rpn data not found at: {}'.format(filename)
        # NOTE: pickle executes arbitrary code; only load trusted proposal files.
        with open(filename, 'rb') as f:
            box_list = pickle.load(f)
        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def _load_selective_search_roidb(self, gt_roidb):
        """Build a roidb from ``<DATA_DIR>/selective_search_data/<name>.mat``."""
        filename = os.path.abspath(os.path.join(cfg.DATA_DIR,
                                                'selective_search_data',
                                                self.name + '.mat'))
        assert os.path.exists(filename), \
            'Selective search data not found at: {}'.format(filename)
        raw_data = sio.loadmat(filename)['boxes'].ravel()

        box_list = []
        for raw_boxes in raw_data:
            # Swaps columns (1, 0, 3, 2) and subtracts 1; presumably the .mat
            # stores 1-based (y1, x1, y2, x2) boxes as in the PASCAL
            # selective-search files -- TODO confirm for Visual Genome.
            boxes = raw_boxes[:, (1, 0, 3, 2)] - 1
            keep = ds_utils.unique_boxes(boxes)
            boxes = boxes[keep, :]
            keep = ds_utils.filter_small_boxes(boxes, self.config['min_size'])
            boxes = boxes[keep, :]
            box_list.append(boxes)

        return self.create_roidb_from_box_list(box_list, gt_roidb)

    # ------------------------------------------------------------ annotations

    @staticmethod
    def _parse_bndbox(bbox):
        """Return ``(x1, y1, x2, y2)`` as floats from a ``<bndbox>`` element."""
        return tuple(float(bbox.find(tag).text)
                     for tag in ('xmin', 'ymin', 'xmax', 'ymax'))

    @staticmethod
    def _parse_difficult(bbox):
        """Return the integer ``<difficult>`` flag of ``bbox``, or 0 if absent."""
        # NOTE(review): PASCAL VOC places <difficult> under <object>, not
        # <bndbox>; confirm where the Visual Genome XML puts it.
        diff = bbox.find('difficult')
        return 0 if diff is None else int(diff.text)

    def _load_visual_genome_annotation(self, index):
        """Parse the XML annotation of image ``index``.

        Returns:
            dict with keys ``'obj_anno'``, ``'region_anno'`` and
            ``'relation_anno'``:

            * obj_anno: ``boxes`` (N x 4 int32, columns x1 y1 x2 y2),
              ``gt_classes``, ``ids`` (XML ``object_id``), ``gt_overlaps``
              (sparse N x num_obj_classes one-hot), ``seg_areas``,
              ``gt_ishards`` and ``flipped``;
            * region_anno: same layout without ``ids``; every region gets
              the ``'foreground'`` class;
            * relation_anno: ``gt_classes`` (predicate index),
              ``subject_ids`` / ``object_ids`` (presumably matching
              obj_anno ``ids``) and ``flipped``.

        Raises:
            KeyError: if an object name or predicate is not in the vocabulary.
        """
        tree = ET.parse(os.path.join(self._annotation_dir(), index + '.xml'))

        # ---- objects ----
        objs = tree.findall('object')
        num_objs = len(objs)

        obj_boxes = np.zeros((num_objs, 4), dtype=np.int32)
        obj_gt_classes = np.zeros((num_objs), dtype=np.int32)
        obj_ids = np.zeros((num_objs), dtype=np.int32)
        obj_overlaps = np.zeros((num_objs, len(self._obj_classes)), dtype=np.float32)
        obj_seg_areas = np.zeros((num_objs), dtype=np.float32)
        obj_ishards = np.zeros((num_objs), dtype=np.int32)

        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            x1, y1, x2, y2 = self._parse_bndbox(bbox)
            # Float coordinates are truncated when stored in the int32 array.
            obj_boxes[ix, :] = [x1, y1, x2, y2]

            cls = self._obj_class_to_ind[obj.find('name').text.lower().strip()]
            obj_gt_classes[ix] = cls
            obj_ids[ix] = int(obj.find('object_id').text)

            obj_overlaps[ix, cls] = 1.0
            # Inclusive-pixel area convention (+1 per dimension).
            obj_seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
            obj_ishards[ix] = self._parse_difficult(bbox)

        obj_anno = {'boxes': obj_boxes,
                    'gt_classes': obj_gt_classes,
                    'ids': obj_ids,
                    'gt_overlaps': scipy.sparse.csr_matrix(obj_overlaps),
                    'seg_areas': obj_seg_areas,
                    'gt_ishards': obj_ishards,
                    'flipped': False}

        # ---- regions (single 'foreground' class) ----
        regions = tree.findall('region')
        num_regions = len(regions)

        region_boxes = np.zeros((num_regions, 4), dtype=np.int32)
        region_gt_classes = np.zeros((num_regions), dtype=np.int32)
        region_overlaps = np.zeros((num_regions, len(self._region_classes)), dtype=np.float32)
        region_seg_areas = np.zeros((num_regions), dtype=np.float32)
        region_ishards = np.zeros((num_regions), dtype=np.int32)

        fg_cls = self._region_class_to_ind['foreground']
        for ix, region in enumerate(regions):
            bbox = region.find('bndbox')
            x1, y1, x2, y2 = self._parse_bndbox(bbox)
            region_boxes[ix, :] = [x1, y1, x2, y2]

            region_gt_classes[ix] = fg_cls
            region_overlaps[ix, fg_cls] = 1.0
            region_seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
            region_ishards[ix] = self._parse_difficult(bbox)

        region_anno = {'boxes': region_boxes,
                       'gt_classes': region_gt_classes,
                       'gt_overlaps': scipy.sparse.csr_matrix(region_overlaps),
                       'seg_areas': region_seg_areas,
                       'gt_ishards': region_ishards,
                       'flipped': False}

        # ---- relationships ----
        relations = tree.findall('relation')
        num_relations = len(relations)

        relation_gt_classes = np.zeros((num_relations), dtype=np.int32)
        relation_subject_ids = np.zeros((num_relations), dtype=np.int32)
        relation_object_ids = np.zeros((num_relations), dtype=np.int32)

        for ix, relation in enumerate(relations):
            predicate = relation.find('predicate').text.lower().strip()
            relation_gt_classes[ix] = self._relation_class_to_ind[predicate]
            relation_subject_ids[ix] = int(relation.find('subject_id').text)
            relation_object_ids[ix] = int(relation.find('object_id').text)

        relation_anno = {'gt_classes': relation_gt_classes,
                         'subject_ids': relation_subject_ids,
                         'object_ids': relation_object_ids,
                         'flipped': False}

        return {'obj_anno': obj_anno, 'region_anno': region_anno, 'relation_anno': relation_anno}

    # ------------------------------------------------------------- evaluation

    def _classes_for(self, type):
        """Return the class list for ``type`` ('object'/'region'), else None."""
        if type == 'object':
            return self._obj_classes
        if type == 'region':
            return self._region_classes
        return None

    def _get_comp_id(self):
        """Return the competition id, salted when ``config['use_salt']`` is set."""
        comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt']
                   else self._comp_id)
        return comp_id

    def _get_voc_results_file_template(self):
        """Return the per-class results path template (format with the class name).

        Creates ``<data_path>/results`` if needed. The resulting path looks
        like ``<data_path>/results/<comp_id>_det_<image_set>_{cls}.txt``.
        """
        filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
        filedir = os.path.join(self._data_path, 'results')
        if not os.path.exists(filedir):
            os.makedirs(filedir)
        path = os.path.join(filedir, filename)
        return path

    def _write_voc_results_file(self, all_boxes, type = 'object'):
        """Write one detection file per non-background class.

        Args:
            all_boxes: indexed ``all_boxes[cls_ind][im_ind]``; each entry is
                empty or an array whose rows hold the box in columns 0-3 and
                the score in the last column.
            type: ``'object'`` or ``'region'``; selects the class list.

        Raises:
            ValueError: if ``type`` is neither ``'object'`` nor ``'region'``.
        """
        classes = self._classes_for(type)
        if classes is None:
            raise ValueError('The type of bounding boxes must be either object or region.')

        for cls_ind, cls in enumerate(classes):
            if cls == '__background__':
                continue
            print('Writing {} VOC results file'.format(cls))
            filename = self._get_voc_results_file_template().format(cls)
            with open(filename, 'wt') as f:
                for im_ind, index in enumerate(self.image_index):
                    dets = all_boxes[cls_ind][im_ind]
                    # len() handles both [] and (possibly empty) ndarrays;
                    # `dets == []` is an elementwise array comparison.
                    if len(dets) == 0:
                        continue
                    # NOTE(review): unlike the PASCAL VOC writer, no +1 shift is
                    # applied (and none is subtracted when loading annotations);
                    # confirm voc_eval uses the same convention.
                    for det in dets:
                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                                format(index, det[-1],
                                       det[0], det[1],
                                       det[2], det[3]))

    def _do_python_eval(self, output_dir='output', epoch = 10, type = 'object'):
        """Compute per-class AP with ``voc_eval`` and log it.

        Appends the per-class and mean AP for ``epoch`` to
        ``<output_dir>/results.txt`` and pickles each class's
        recall/precision/AP to ``<output_dir>/<cls>_pr.pkl``.

        Raises:
            ValueError: if ``type`` is neither ``'object'`` nor ``'region'``.
        """
        annopath = os.path.join(self._annotation_dir(), '{:s}.xml')
        imagesetfile = os.path.join(
            self._data_path,
            self._image_set + self._process_suffix() + '.txt')
        cachedir = os.path.join(self._data_path, 'annotations_cache')
        aps = []
        # The PASCAL VOC metric changed in 2010; use the newer one.
        use_07_metric = False
        print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        # Validate before touching results.txt so a bad `type` does not leave
        # a dangling "Epoch" header behind.
        classes = self._classes_for(type)
        if classes is None:
            raise ValueError('The evaluation type must be either object or region.')

        with open(os.path.join(output_dir, 'results.txt'), 'a') as results:
            results.write('Epoch {}:\n'.format(str(epoch)))
            for cls in classes:
                if cls == '__background__':
                    continue
                filename = self._get_voc_results_file_template().format(cls)
                rec, prec, ap = voc_eval(
                    filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
                    use_07_metric=use_07_metric, type=type)
                aps += [ap]
                print('AP for {} = {:.4f}'.format(cls, ap))
                results.write('AP for {} = {:.4f}\n'.format(cls, ap))
                with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
                    pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
            print('Mean AP = {:.4f}'.format(np.mean(aps)))
            results.write('Mean AP = {:.4f}\n'.format(np.mean(aps)))
            results.write('\n')
        print('~~~~~~~~')
        print('Results:')
        for ap in aps:
            print('{:.3f}'.format(ap))
        print('{:.3f}'.format(np.mean(aps)))
        print('~~~~~~~~')
        print('')
        print('--------------------------------------------------------------')
        print('Results computed with the **unofficial** Python eval code.')
        print('Results should be very close to the official MATLAB eval code.')
        print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
        print('-- Thanks, The Management')
        print('--------------------------------------------------------------')

    def _do_matlab_eval(self, output_dir='output'):
        """Run the MATLAB VOCdevkit evaluation through a shell command."""
        print('-----------------------------------------------------')
        print('Computing results with the official MATLAB eval code.')
        print('-----------------------------------------------------')
        path = os.path.join(cfg.ROOT_DIR, 'lib', 'datasets',
                            'VOCdevkit-matlab-wrapper')
        cmd = 'cd {} && '.format(path)
        cmd += '{:s} -nodisplay -nodesktop '.format(cfg.MATLAB)
        cmd += '-r "dbstop if error; '
        # NOTE(review): `_devkit_path` is never set by this class, so this
        # raises AttributeError unless the base class provides it -- confirm
        # before enabling config['matlab_eval'].
        cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\'); quit;"' \
            .format(self._devkit_path, self._get_comp_id(),
                    self._image_set, output_dir)
        print('Running:\n{}'.format(cmd))
        subprocess.call(cmd, shell=True)

    def evaluate_detections(self, all_boxes, output_dir, epoch = 10, type = 'object'):
        """Write, evaluate and (if ``config['cleanup']``) delete detection files.

        Args:
            all_boxes: detections indexed ``all_boxes[cls_ind][im_ind]``.
            output_dir: directory receiving ``results.txt`` and PR pickles.
            epoch: label written into ``results.txt``.
            type: ``'object'`` or ``'region'``.
        """
        self._write_voc_results_file(all_boxes, type = type)
        self._do_python_eval(output_dir, epoch = epoch, type = type)
        if self.config['matlab_eval']:
            self._do_matlab_eval(output_dir)
        if self.config['cleanup']:
            # Remove exactly the files written above; which classes they cover
            # depends on `type`, not on the base-class `_classes` list.
            for cls in self._classes_for(type):
                if cls == '__background__':
                    continue
                filename = self._get_voc_results_file_template().format(cls)
                os.remove(filename)

    def competition_mode(self, on):
        """Toggle competition mode: when on, disable salting and keep result files."""
        if on:
            self.config['use_salt'] = False
            self.config['cleanup'] = False
        else:
            self.config['use_salt'] = True
            self.config['cleanup'] = True


if __name__ == '__main__':
    # Manual smoke test: build the training split rooted at
    # ./data/Visual_Genome, materialise its roidb (`roidb` is not defined in
    # this file; presumably inherited from the imdb base class), then open an
    # IPython shell so `d` and `res` can be inspected interactively.
    d = VisualGenome('./data', 'train')
    res = d.roidb
    # IPython is only needed for this interactive session, hence the local import.
    from IPython import embed

    embed()

