"""Pascal VOC object detection dataset."""
from __future__ import absolute_import, division

import json
import logging
import os
import pickle
import warnings
from collections import Counter

import mxnet as mx
import numpy as np
from gluoncv.data.base import VisionDataset
from gluoncv.data.transforms.presets.rcnn import (
    FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform)

import dgl


class VGRelation(VisionDataset):
    """Visual Genome relation dataset producing a ``(graph, image)`` pair.

    The annotation file maps an image id to a list of relations; every
    relation has a ``"subject"`` and an ``"object"`` (each with a
    ``"category"`` and a ``"bbox"``) plus a ``"predicate"``. For one image,
    each distinct (category, bbox) combination becomes a graph node, each
    annotated relation becomes an edge labelled ``predicate + 1``, and every
    remaining ordered pair of distinct nodes becomes an edge labelled ``0``.

    Parameters
    ----------
    root : str
        Dataset directory. It must contain ``VG_100K/`` (images),
        ``rel_annotations_train.json``, ``rel_annotations_val.json``,
        ``predicates.json`` and ``objects.json``. A leading ``~`` is
        expanded.
    split : str
        ``"train"`` or ``"val"``. Selects both the annotation file and the
        Faster R-CNN image transform (train or validation preset).

    Attributes
    ----------
    rel_classes, obj_classes
        Contents of ``predicates.json`` and ``objects.json``.
    num_rel_classes : int
        ``len(rel_classes) + 1``; the extra class is label ``0``, used for
        node pairs without an annotated relation.
    num_obj_classes : int
        ``len(obj_classes)``.

    Raises
    ------
    NotImplementedError
        If ``split`` is neither ``"train"`` nor ``"val"``.
    """

    def __init__(
        self,
        root=os.path.join("~", ".mxnet", "datasets", "visualgenome"),
        split="train",
    ):
        super(VGRelation, self).__init__(root)
        self._root = os.path.expanduser(root)
        # "{}" is filled with the image id (the annotation dict key).
        self._img_path = os.path.join(self._root, "VG_100K", "{}")

        if split == "train":
            self._dict_path = os.path.join(
                self._root, "rel_annotations_train.json"
            )
        elif split == "val":
            self._dict_path = os.path.join(
                self._root, "rel_annotations_val.json"
            )
        else:
            raise NotImplementedError(
                "Unsupported split {!r}; expected 'train' or 'val'.".format(
                    split
                )
            )
        with open(self._dict_path) as f:
            self._dict = json.load(f)
        # Cache the key order once so that integer indexing in __getitem__
        # does not have to rebuild the full key list for every sample.
        self._img_ids = list(self._dict)

        self._predicates_path = os.path.join(self._root, "predicates.json")
        with open(self._predicates_path, "r") as f:
            self.rel_classes = json.load(f)
        # +1 reserves label 0 for "no annotated relation" edges.
        self.num_rel_classes = len(self.rel_classes) + 1

        self._objects_path = os.path.join(self._root, "objects.json")
        with open(self._objects_path, "r") as f:
            self.obj_classes = json.load(f)
        self.num_obj_classes = len(self.obj_classes)

        if split == "val":
            self.img_transform = FasterRCNNDefaultValTransform(
                short=600, max_size=1000
            )
        else:
            self.img_transform = FasterRCNNDefaultTrainTransform(
                short=600, max_size=1000
            )
        self.split = split

    def __len__(self):
        """Return the number of images in the loaded annotation file."""
        return len(self._dict)

    def _hash_bbox(self, object):
        """Return a string key ``"<category>_<b0>_<b1>_<b2>_<b3>"``.

        ``object`` is one annotation entry with a ``"category"`` value and a
        ``"bbox"`` list. Two entries with the same category and box map to
        the same key and are therefore merged into one graph node.
        """
        num_list = [object["category"]] + object["bbox"]
        return "_".join(str(num) for num in num_list)

    def __getitem__(self, idx):
        """Load image ``idx`` and build its relation graph.

        Returns
        -------
        g : dgl.DGLGraph
            Node data: ``"bbox"`` (transformed boxes), ``"node_class"``
            (category id, shape ``(n_nodes, 1)``) and ``"node_class_vec"``
            (smoothed one-hot category vector). Edge data: ``"rel_class"``,
            which is ``predicate + 1`` for annotated relations and ``0`` for
            every other ordered pair of distinct nodes.
        img
            The image returned by ``self.img_transform``.
        """
        img_id = self._img_ids[idx]
        img = mx.image.imread(self._img_path.format(img_id))

        item = self._dict[img_id]

        # Map every relation endpoint to a node id. Node ids follow the
        # sorted order of the hash keys so the numbering is deterministic.
        sub_node_hash = [self._hash_bbox(rel["subject"]) for rel in item]
        ob_node_hash = [self._hash_bbox(rel["object"]) for rel in item]
        node_set = sorted(set(sub_node_hash + ob_node_hash))
        n_nodes = len(node_set)
        node_to_id = {node: i for i, node in enumerate(node_set)}
        sub_id = [node_to_id[key] for key in sub_node_hash]
        ob_id = [node_to_id[key] for key in ob_node_hash]

        # Node features: fill each node once, from its first occurrence.
        bbox = mx.nd.zeros((n_nodes, 4))
        node_class_ids = mx.nd.zeros((n_nodes, 1))
        node_visited = [False] * n_nodes
        for rel, s_ind, o_ind in zip(item, sub_id, ob_id):
            for ind, entry in ((s_ind, rel["subject"]), (o_ind, rel["object"])):
                if node_visited[ind]:
                    continue
                node_class_ids[ind] = entry["category"]
                # y1y2x1x2 to x1y1x2y2
                box = entry["bbox"]
                bbox[ind, 0] = box[2]
                bbox[ind, 1] = box[0]
                bbox[ind, 2] = box[3]
                bbox[ind, 3] = box[1]
                node_visited[ind] = True

        # Smoothed one-hot class vector: the true class gets
        # 1 - eta + eta / C and every other class gets eta / C.
        eta = 0.1
        node_class_vec = node_class_ids[:, 0].one_hot(
            self.num_obj_classes,
            on_value=1 - eta + eta / self.num_obj_classes,
            off_value=eta / self.num_obj_classes,
        )

        # augmentation; the validation transform returns a third value that
        # is not needed here.
        if self.split == "val":
            img, bbox, _ = self.img_transform(img, bbox)
        else:
            img, bbox = self.img_transform(img, bbox)

        # build the graph: annotated relations first, labelled predicate + 1
        g = dgl.DGLGraph()
        g.add_nodes(n_nodes)
        predicate = mx.nd.array(
            [rel["predicate"] for rel in item]
        ).expand_dims(1)
        g.add_edges(sub_id, ob_id, {"rel_class": predicate + 1})

        # Then every ordered pair (i, j), i != j, without an annotated
        # relation gets a label-0 edge. np.nonzero scans in row-major order,
        # i.e. the same (i, j) order as a nested "for i / for j" loop.
        no_relation = np.ones((n_nodes, n_nodes), dtype=bool)
        no_relation[sub_id, ob_id] = False
        np.fill_diagonal(no_relation, False)
        src, dst = np.nonzero(no_relation)
        if len(src) > 0:
            g.add_edges(
                src.tolist(),
                dst.tolist(),
                {"rel_class": mx.nd.zeros((len(src), 1))},
            )

        # assign features
        g.ndata["bbox"] = bbox
        g.ndata["node_class"] = node_class_ids
        g.ndata["node_class_vec"] = node_class_vec

        return g, img
