# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import collections
import json
import os.path as op

import numpy as np
import torch

from .tsv import TSVYamlDataset, find_file_path_in_yaml
from .box_label_loader import BoxLabelLoader
from maskrcnn_benchmark.data.datasets.coco_dt import CocoDetectionTSV


class VGDetectionTSV(CocoDetectionTSV):
    """Visual Genome detection dataset.

    Adds nothing of its own: every attribute and method is inherited
    unchanged from ``CocoDetectionTSV``.
    """


def sort_key_by_val(dic):
    """Return the keys of ``dic`` as a list, ordered by ascending value.

    Keys that share a value keep the order in which ``dic`` iterates them,
    because ``sorted`` is stable.
    """
    return sorted(dic, key=dic.__getitem__)


def bbox_overlaps(anchors, gt_boxes):
    """Compute the pairwise intersection-over-union of two box sets.

    Both inputs are used through the torch tensor API (``.size``/``.view``).
    Columns 0/1 are treated as the lower corner and columns 2/3 as the upper
    corner of each box, and every extent is measured inclusively (``+ 1``).

    anchors: (N, 4) tensor
    gt_boxes: (K, 4) tensor
    returns: (N, K) tensor where entry [n, k] is the overlap between
        ``anchors[n]`` and ``gt_boxes[k]``
    """
    num_anchors = anchors.size(0)
    num_gt = gt_boxes.size(0)

    # Per-box areas, shaped so that their sum broadcasts to (N, K).
    gt_w = gt_boxes[:, 2] - gt_boxes[:, 0] + 1
    gt_h = gt_boxes[:, 3] - gt_boxes[:, 1] + 1
    gt_boxes_area = (gt_w * gt_h).view(1, num_gt)

    anchor_w = anchors[:, 2] - anchors[:, 0] + 1
    anchor_h = anchors[:, 3] - anchors[:, 1] + 1
    anchors_area = (anchor_w * anchor_h).view(num_anchors, 1)

    # Line both sets up on a common (N, K, 4) grid.
    grid_shape = (num_anchors, num_gt, 4)
    lhs = anchors.view(num_anchors, 1, 4).expand(*grid_shape)
    rhs = gt_boxes.view(1, num_gt, 4).expand(*grid_shape)

    # Width/height of the intersection; disjoint pairs are clipped to zero.
    inter_w = (torch.min(lhs[:, :, 2], rhs[:, :, 2]) -
               torch.max(lhs[:, :, 0], rhs[:, :, 0]) + 1).clamp(min=0)
    inter_h = (torch.min(lhs[:, :, 3], rhs[:, :, 3]) -
               torch.max(lhs[:, :, 1], rhs[:, :, 1]) + 1).clamp(min=0)

    intersection = inter_w * inter_h
    union = anchors_area + gt_boxes_area - intersection
    return intersection / union


# VG data loader for Danfei Xu's scene-graph-focused format.
# TODO: if the ordering of classes, attributes, or relations changes,
# make sure to re-write the obj_classes.txt/rel_classes.txt files.

def _box_filter(boxes, must_overlap=False):
    """ Only include boxes that overlap as possible relations.
    If no overlapping boxes, use all of them."""
    overlaps = bbox_overlaps(boxes, boxes).numpy() > 0
    np.fill_diagonal(overlaps, 0)

    all_possib = np.ones_like(overlaps, dtype=np.bool)
    np.fill_diagonal(all_possib, 0)

    if must_overlap:
        possible_boxes = np.column_stack(np.where(overlaps))

        if possible_boxes.size == 0:
            possible_boxes = np.column_stack(np.where(all_possib))
    else:
        possible_boxes = np.column_stack(np.where(all_possib))
    return possible_boxes


class VGTSVDataset(TSVYamlDataset):
    """
    Visual Genome TSV dataset for object detection, with optional
    scene-graph relation labels.

    Construction has side effects on disk, all next to the label file:
    ``<label>.obj_classes.txt`` and ``<label>.rel_classes.txt`` are written
    when missing, and, when ``relation_on`` is set on a train split,
    ``<label>.freq_prior.npy`` is computed and saved when missing.
    """

    def __init__(self, yaml_file, extra_fields=None, transforms=None,
                 is_load_label=True, filter_duplicate_rels=True,
                 relation_on=False, cv2_output=False, **kwargs):
        """
        Args:
            yaml_file: dataset YAML config, forwarded to the base class.
            extra_fields: extra label fields forwarded to ``BoxLabelLoader``
                (defaults to an empty list).
            transforms: optional callable ``(img, target) -> (img, target)``
                used by ``apply_transforms``.
            is_load_label: whether to build a label loader and produce targets.
            filter_duplicate_rels: on the train split only, keep a single
                randomly chosen predicate per (subject, object) pair.
            relation_on: whether to load relation labels and prepare the
                relation frequency prior.
            cv2_output: forwarded to the base class; also selects how the
                image size is read in ``get_groundtruth``.
            **kwargs: accepted for call-site compatibility and ignored.

        Raises:
            ValueError: if the linelist file name contains none of
                'train', 'test' or 'val'.
        """
        if extra_fields is None:
            extra_fields = []
        self.transforms = transforms
        self.is_load_label = is_load_label
        self.relation_on = relation_on
        super(VGTSVDataset, self).__init__(yaml_file, cv2_output=cv2_output)

        ignore_attrs = self.cfg.get("ignore_attrs", None)
        # construct those maps
        jsondict_file = find_file_path_in_yaml(self.cfg.get("jsondict", None), self.root)
        # Context manager so the handle is closed deterministically.
        with open(jsondict_file, 'r') as f:
            jsondict = json.load(f)

        # The split is inferred from the linelist file name.
        linelist_name = op.basename(self.linelist_file)
        if 'train' in linelist_name:
            self.split = "train"
        elif 'test' in linelist_name or 'val' in linelist_name:
            # 'val' also matches names containing 'valid'.
            self.split = "test"
        else:
            raise ValueError("Split must be one of [train, test], but get {}!".format(self.linelist_file))
        self.filter_duplicate_rels = filter_duplicate_rels and self.split == 'train'

        label_base = op.splitext(self.label_file)[0]

        self.class_to_ind = jsondict['label_to_idx']
        self.ind_to_class = jsondict['idx_to_label']
        self.classes = self._build_vocab(
            self.class_to_ind, self.ind_to_class, '__background__')

        # writing obj classes to disk for Neural Motif model building.
        self._write_names_if_missing(label_base + ".obj_classes.txt", self.classes)

        self.attribute_to_ind = jsondict['attribute_to_idx']
        self.ind_to_attribute = jsondict['idx_to_attribute']
        self.attributes = self._build_vocab(
            self.attribute_to_ind, self.ind_to_attribute, '__no_attribute__')

        self.relation_to_ind = jsondict['predicate_to_idx']
        self.ind_to_relation = jsondict['idx_to_predicate']
        self.relations = self._build_vocab(
            self.relation_to_ind, self.ind_to_relation, '__no_relation__')

        # writing rel classes to disk for Neural Motif Model building.
        self._write_names_if_missing(label_base + '.rel_classes.txt', self.relations)

        # label map: minus one because we will add one in BoxLabelLoader
        self.labelmap = {key: val - 1 for key, val in self.class_to_ind.items()}
        # NOTE(review): the resolved path is currently unused; the call is
        # kept in case find_file_path_in_yaml validates the path - confirm.
        labelmap_file = find_file_path_in_yaml(self.cfg.get("labelmap_dec"), self.root)
        # self.labelmap_dec = load_labelmap_file(labelmap_file)
        if self.is_load_label:
            self.label_loader = BoxLabelLoader(
                labelmap=self.labelmap,
                extra_fields=extra_fields,
                ignore_attrs=ignore_attrs
            )

        # get frequency prior for relations
        if self.relation_on:
            self.freq_prior_file = label_base + ".freq_prior.npy"
            if self.split == 'train' and not op.exists(self.freq_prior_file):
                print("Computing frequency prior matrix...")
                fg_matrix, bg_matrix = self._get_freq_prior()
                prob_matrix = fg_matrix.astype(np.float32)
                # Slot 0 (the no-relation predicate) holds the background
                # pair counts, plus one so every row has a non-zero sum.
                prob_matrix[:, :, 0] = bg_matrix
                prob_matrix[:, :, 0] += 1
                prob_matrix /= np.sum(prob_matrix, 2)[:, :, None]
                np.save(self.freq_prior_file, prob_matrix)

    @staticmethod
    def _build_vocab(name_to_idx, idx_to_name, null_name):
        """Reserve index 0 for ``null_name`` in both maps (in place) and
        return the names ordered by index."""
        name_to_idx[null_name] = 0
        idx_to_name['0'] = null_name
        names = sort_key_by_val(name_to_idx)
        # Both directions of the mapping must agree on every index.
        assert all(name == idx_to_name[str(i)] for i, name in enumerate(names))
        return names

    @staticmethod
    def _write_names_if_missing(path, names):
        """Write one name per line to ``path`` unless that file already exists."""
        if op.isfile(path):
            return
        with open(path, 'w') as f:
            for item in names:
                f.write("%s\n" % item)

    def _get_freq_prior(self, must_overlap=False):
        """Count relation statistics over the whole dataset.

        Args:
            must_overlap: forwarded to ``_box_filter`` when enumerating the
                candidate box pairs.

        Returns:
            ``(fg_matrix, bg_matrix)`` of int64 counts:
            ``fg_matrix[o1, o2, r]`` counts ground-truth relations ``r``
            between subject class ``o1`` and object class ``o2``;
            ``bg_matrix[o1, o2]`` counts candidate box pairs with those
            classes. Examples without relations contribute to neither.

        Raises:
            IndexError: if indexing fails for an example that does have
                relations.
        """
        num_classes = len(self.classes)
        fg_matrix = np.zeros(
            (num_classes, num_classes, len(self.relations)), dtype=np.int64)
        bg_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)

        num_examples = len(self)
        for ex_ind in range(num_examples):
            target = self.get_groundtruth(ex_ind)
            gt_classes = target.get_field('labels').numpy()
            gt_relations = target.get_field('relation_labels').numpy()
            gt_boxes = target.bbox

            # For the foreground, we'll just look at everything
            try:
                o1o2 = gt_classes[gt_relations[:, :2]]
                for (o1, o2), gtr in zip(o1o2, gt_relations[:, 2]):
                    fg_matrix[o1, o2, gtr] += 1

                # For the background, get all of the things that overlap.
                o1o2_total = gt_classes[np.array(
                    _box_filter(gt_boxes, must_overlap=must_overlap), dtype=int)]
                for (o1, o2) in o1o2_total:
                    bg_matrix[o1, o2] += 1
            except IndexError:
                # An example without relations yields a relation array that
                # cannot be sliced along two axes; that case is skipped.
                # Anything else is a real error and must surface even when
                # asserts are disabled.
                if len(gt_relations) != 0:
                    raise

            if ex_ind % 20 == 0:
                print("processing {}/{}".format(ex_ind, num_examples))

        return fg_matrix, bg_matrix

    def relation_loader(self, relation_triplets, target):
        """Attach relation labels to ``target``.

        Args:
            relation_triplets: list of ``(subject_idx, object_idx, predicate)``
                triplets (M x 3).
            target: BoxList from the label loader.

        Returns:
            ``target`` with two added fields: ``relation_labels`` (the
            triplets as a tensor) and ``pred_labels`` (a
            ``len(target) x len(target)`` int64 matrix holding the predicate
            of each (subject, object) pair, 0 where there is none).
        """
        if self.filter_duplicate_rels:
            # Filter out dupes!
            assert self.split == 'train'
            all_rel_sets = collections.defaultdict(list)
            for (o0, o1, r) in relation_triplets:
                all_rel_sets[(o0, o1)].append(r)
            # Keep one randomly drawn predicate per (subject, object) pair.
            relation_triplets = [
                (pair[0], pair[1], np.random.choice(preds))
                for pair, preds in all_rel_sets.items()
            ]

        # get M*M pred_labels
        num_boxes = len(target)
        relations = torch.zeros([num_boxes, num_boxes], dtype=torch.int64)
        for triplet in relation_triplets:
            relations[triplet[0], triplet[1]] = int(triplet[2])

        relation_triplets = torch.tensor(relation_triplets)
        target.add_field("relation_labels", relation_triplets)
        target.add_field("pred_labels", relations)
        return target

    def get_target_from_annotations(self, annotations, img_size, idx):
        """Build the target for one example from its annotations.

        Args:
            annotations: mapping with an ``'objects'`` entry and, when
                relations are enabled, a ``'relations'`` entry.
            img_size: image size passed to the label loader.
            idx: example index (unused here).

        Returns:
            The target produced by the label loader, or ``None`` when labels
            are not loaded or ``annotations`` is empty.
        """
        if not (self.is_load_label and annotations):
            return None

        target = self.label_loader(annotations['objects'], img_size)
        # make sure no boxes are removed
        assert (len(annotations['objects']) == len(target))
        if self.split in ["val", "test"]:
            # add the difficult field
            target.add_field("difficult", torch.zeros(len(target), dtype=torch.int32))
        # load relations
        if self.relation_on:
            target = self.relation_loader(annotations["relations"], target)
        return target

    def get_groundtruth(self, idx, call=False):
        """Load the target of example ``idx`` without applying transforms.

        Args:
            idx: example index.
            call: if True, also return the image and the raw annotations.

        Returns:
            ``target`` or, when ``call`` is True,
            ``(img, target, annotations)``.
        """
        # similar to __getitem__ but without transform
        img = self.get_image(idx)
        if self.cv2_output:
            img_size = img.shape[:2][::-1]  # h, w -> w, h
        else:
            img_size = img.size  # w, h
        annotations = self.get_annotations(idx)
        target = self.get_target_from_annotations(annotations, img_size, idx)
        if call:
            return img, target, annotations
        return target

    def apply_transforms(self, img, target=None):
        """Apply ``self.transforms`` to ``(img, target)`` if any are set and
        return the resulting pair."""
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def map_class_id_to_class_name(self, class_id):
        """Return the object class name stored at index ``class_id``."""
        return self.classes[class_id]

    def map_attribute_id_to_attribute_name(self, attribute_id):
        """Return the attribute name stored at index ``attribute_id``."""
        return self.attributes[attribute_id]

    def map_relation_id_to_relation_name(self, relation_id):
        """Return the relation name stored at index ``relation_id``."""
        return self.relations[relation_id]
