import logging
import time
from operator import attrgetter, itemgetter

import mxnet as mx
import numpy as np
from gluoncv.data.batchify import Pad
from gluoncv.model_zoo import get_model
from mxnet import gluon, nd
from mxnet.gluon import nn

import dgl
from dgl.nn.mxnet import GraphConv
from dgl.utils import toindex


def iou(boxA, boxB):
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is indexed as ``[x_min, y_min, x_max, y_max]`` (positions 0-3).
    Returns ``0`` when the boxes do not overlap (intersection below 1e-7) or
    when the union is degenerate (below 1e-7); otherwise a float in (0, 1].
    """
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    overlap = max(0, right - left) * max(0, bottom - top)
    if overlap < 1e-7:
        return 0

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = area_a + area_b - overlap
    # guard against division by a (near-)zero union for degenerate boxes
    if union < 1e-7:
        return 0
    return overlap / float(union)


def object_iou_thresh(gt_object, pred_object, iou_thresh=0.5):
    """Return True when two object rows' boxes overlap with IoU >= ``iou_thresh``.

    Only columns 1-4 (the box) of each row are compared; column 0 is ignored.
    """
    overlap = iou(gt_object[1:5], pred_object[1:5])
    return bool(overlap >= iou_thresh)


def triplet_iou_thresh(pred_triplet, gt_triplet, iou_thresh=0.5):
    """Return True when both the subject and object boxes of two triplet rows match.

    Columns 5-8 hold the subject box and columns 9-12 the object box. Each
    pair must reach IoU >= ``iou_thresh``. The object box is only checked
    once the subject box has passed.
    """
    subject_ok = iou(gt_triplet[5:9], pred_triplet[5:9]) >= iou_thresh
    if not subject_ok:
        return False
    return bool(iou(gt_triplet[9:13], pred_triplet[9:13]) >= iou_thresh)


@mx.metric.register
@mx.metric.alias("auc")
class AUCMetric(mx.metric.EvalMetric):
    """Rank-based ROC AUC for binary labels, averaged over ``update`` calls.

    ``eps`` is stored but not used by ``update``.
    """

    def __init__(self, name="auc", eps=1e-12):
        super(AUCMetric, self).__init__(name)
        self.eps = eps

    def update(self, labels, preds):
        """Add the AUC of one batch.

        ``labels[0]`` holds the targets (assumed 0/1 — TODO confirm). Column 1
        of ``preds[0]`` is used as the positive-class score. Batches that
        contain only one class are skipped entirely. Ties are ordered by input
        position (stable sort), not averaged.
        """
        mx.metric.check_label_shapes(labels, preds)
        truth = labels[0].asnumpy()
        scores = preds[0].asnumpy()
        ranked = sorted(
            [(truth[k], scores[k][1]) for k in range(scores.shape[0])],
            key=itemgetter(1),
            reverse=True,
        )

        positives = truth.sum()
        # AUC is undefined when every sample has the same label
        if positives == 0 or positives == truth.size:
            return

        n_pos = np.count_nonzero(truth)
        n_neg = len(truth) - n_pos
        pair_count = n_neg * n_pos

        # for each negative, count how many positives were ranked above it
        pos_above = 0
        area = 0
        for target, _ in ranked:
            if target == 1.0:
                pos_above += 1.0
            else:
                area += pos_above

        self.sum_metric += area / pair_count
        self.num_inst += 1


@mx.metric.register
@mx.metric.alias("predcls")
class PredCls(mx.metric.EvalMetric):
    """Metric with ground truth object location and label

    Recall@k of predicate classification. A ground-truth triplet counts as
    recalled when one of the ``topk`` predictions (ranked by column 0,
    descending) has the same predicate class (column 2). That prediction's
    subject and object boxes must also both reach IoU >= ``iou_thresh``.
    """

    def __init__(self, topk=20, iou_thresh=0.99):
        super(PredCls, self).__init__("predcls@%d" % (topk))
        self.topk = topk
        self.iou_thresh = iou_thresh

    def update(self, labels, preds):
        """Accumulate recall for one image; ``None`` inputs add a zero score."""
        if labels is None or preds is None:
            self.num_inst += 1
            return

        order = preds[:, 0].argsort()[::-1]
        ranked = preds[order]
        hits = [False] * len(labels)
        for rank in range(min(self.topk, ranked.shape[0])):
            candidate = ranked[rank]
            for idx, gt in enumerate(labels):
                if hits[idx]:
                    continue
                same_predicate = int(gt[2]) == int(candidate[2])
                if same_predicate and triplet_iou_thresh(
                    candidate, gt, self.iou_thresh
                ):
                    hits[idx] = True

        # each ground-truth triplet is flagged at most once, so sum(hits) is the match count
        self.sum_metric += sum(hits) / labels.shape[0]
        self.num_inst += 1


@mx.metric.register
@mx.metric.alias("phrcls")
class PhrCls(mx.metric.EvalMetric):
    """Metric with ground truth object location and predicted object label from detector

    Recall@k. The ``topk`` predictions are ranked by column 1, descending. A
    ground-truth triplet is recalled when some prediction agrees on the
    predicate, subject and object classes (columns 2-4). Both of that
    prediction's boxes must also reach IoU >= ``iou_thresh``.
    """

    def __init__(self, topk=20, iou_thresh=0.99):
        super(PhrCls, self).__init__("phrcls@%d" % (topk))
        self.topk = topk
        self.iou_thresh = iou_thresh

    def update(self, labels, preds):
        """Accumulate recall for one image; ``None`` inputs add a zero score."""
        if labels is None or preds is None:
            self.num_inst += 1
            return

        ranked = preds[preds[:, 1].argsort()[::-1]]
        covered = [False] * len(labels)
        limit = min(self.topk, ranked.shape[0])
        for rank in range(limit):
            candidate = ranked[rank]
            for idx, gt in enumerate(labels):
                if covered[idx]:
                    continue
                classes_agree = all(
                    int(gt[col]) == int(candidate[col]) for col in (2, 3, 4)
                )
                if classes_agree and triplet_iou_thresh(
                    candidate, gt, self.iou_thresh
                ):
                    covered[idx] = True

        self.sum_metric += sum(covered) / labels.shape[0]
        self.num_inst += 1


@mx.metric.register
@mx.metric.alias("sgdet")
class SGDet(mx.metric.EvalMetric):
    """Metric with predicted object information by the detector

    Recall@k with the same matching rule as phrase classification. The
    predicate, subject and object classes (columns 2-4) must agree, and both
    boxes must reach IoU >= ``iou_thresh``. Predictions are ranked by
    column 1, descending.
    """

    def __init__(self, topk=20, iou_thresh=0.5):
        super(SGDet, self).__init__("sgdet@%d" % (topk))
        self.topk = topk
        self.iou_thresh = iou_thresh

    def update(self, labels, preds):
        """Accumulate recall for one image; ``None`` inputs add a zero score."""
        if labels is None or preds is None:
            self.num_inst += 1
            return

        by_score = preds[preds[:, 1].argsort()[::-1]]
        found = [False] * len(labels)
        for rank in range(min(self.topk, len(by_score))):
            guess = by_score[rank]
            for pos, truth in enumerate(labels):
                if found[pos]:
                    continue
                if not all(int(truth[c]) == int(guess[c]) for c in (2, 3, 4)):
                    continue
                if triplet_iou_thresh(guess, truth, self.iou_thresh):
                    found[pos] = True

        self.sum_metric += sum(found) / labels.shape[0]
        self.num_inst += 1


@mx.metric.register
@mx.metric.alias("sgdet+")
class SGDetPlus(mx.metric.EvalMetric):
    """Metric proposed by `Graph R-CNN for Scene Graph Generation`

    Per image, this counts three kinds of matches:

    * ground-truth objects matched by any predicted object;
    * ground-truth predicates matched by a top-k triplet;
    * ground-truth triplets matched by a top-k triplet.

    The score is that count divided by ``num_gt_objects + 2 * num_gt_triplets``.
    """

    def __init__(self, topk=20, iou_thresh=0.5):
        super(SGDetPlus, self).__init__("sgdet+@%d" % (topk))
        self.topk = topk
        self.iou_thresh = iou_thresh

    def update(self, labels, preds):
        """Accumulate the SGDet+ score of one image.

        ``labels`` is ``(gt_objects, gt_triplets)`` and ``preds`` is
        ``(pred_objects, pred_triplets)``. Object rows are
        ``[class, box(4)]``. Triplet rows are ``[score_pred, score_phr,
        predicate, sub_class, ob_class, sub_box(4), ob_box(4)]``.
        Missing inputs (``None``) add a zero score for this image.
        """
        # consistent with the other metrics: a missing image scores zero
        if labels is None or preds is None:
            self.num_inst += 1
            return
        label_objects, label_triplets = labels
        pred_objects, pred_triplets = preds
        if label_objects is None or pred_objects is None:
            self.num_inst += 1
            return

        count = 0

        # objects: class match + box IoU, each GT object matched at most once
        gt_obj_num = label_objects.shape[0]
        object_matched = [False] * gt_obj_num
        for pred in pred_objects:
            for j in range(gt_obj_num):
                if object_matched[j]:
                    continue
                label = label_objects[j]
                if int(label[0]) == int(pred[0]) and object_iou_thresh(
                    label, pred, self.iou_thresh
                ):
                    count += 1
                    object_matched[j] = True

        # predicates and triplets among the top-k by phrase score (column 1)
        pred_triplets = pred_triplets[pred_triplets[:, 1].argsort()[::-1]]
        m = min(self.topk, len(pred_triplets))
        gt_triplet_num = label_triplets.shape[0]
        predicate_matched = [False] * gt_triplet_num
        triplet_matched = [False] * gt_triplet_num
        for i in range(m):
            pred = pred_triplets[i]
            for j in range(gt_triplet_num):
                if predicate_matched[j] and triplet_matched[j]:
                    continue
                label = label_triplets[j]
                # both kinds of match need the predicate class and both boxes
                if int(label[2]) != int(pred[2]) or not triplet_iou_thresh(
                    pred, label, self.iou_thresh
                ):
                    continue
                # predicate match ignores the subject/object classes
                if not predicate_matched[j]:
                    count += 1
                    predicate_matched[j] = True
                # triplet match additionally requires subject and object classes
                if (
                    not triplet_matched[j]
                    and int(label[3]) == int(pred[3])
                    and int(label[4]) == int(pred[4])
                ):
                    count += 1
                    triplet_matched[j] = True

        # normalizer: every GT object plus every GT triplet counted twice
        N = gt_obj_num + 2 * gt_triplet_num
        self.sum_metric += count / N
        self.num_inst += 1


def extract_gt(g, img_size):
    """extract prediction from ground truth graph

    Returns ``(gt_objects, gt_triplets)``, or ``(None, None)`` when the graph
    is missing, has no nodes, or has no edge with ``rel_class > 0``.

    * ``gt_objects`` rows are ``[node_class, box(4)]``. Box columns 0 and 2
      are divided by ``img_size[1]``, and columns 1 and 3 by ``img_size[0]``.
    * ``gt_triplets`` rows are ``[1, 1, rel_class - 1, sub_class, ob_class,
      sub_box(4), ob_box(4)]``, one per edge with ``rel_class > 0``.
    """
    if g is None or g.number_of_nodes() == 0:
        return None, None
    rel_all = g.edata["rel_class"]
    positive_eids = np.where(rel_all.asnumpy() > 0)[0]
    if len(positive_eids) == 0:
        return None, None

    node_cls = g.ndata["node_class"][:, 0].asnumpy()
    boxes = g.ndata["bbox"].asnumpy()
    # (box column, img_size index) pairs used for normalization
    for col, dim in ((0, 1), (1, 0), (2, 1), (3, 0)):
        boxes[:, col] /= img_size[dim]

    objects = np.column_stack([node_cls, boxes])

    endpoints = g.find_edges(positive_eids)
    src = endpoints[0].asnumpy()
    dst = endpoints[1].asnumpy()
    predicate = rel_all[positive_eids, 0].asnumpy() - 1

    count = len(positive_eids)
    triplets = np.column_stack(
        [
            np.ones(count),
            np.ones(count),
            predicate,
            node_cls[src],
            node_cls[dst],
            boxes[src],
            boxes[dst],
        ]
    )
    return objects, triplets


def extract_pred(g, topk=100, joint_preds=False):
    """extract prediction from prediction graph for validation and visualization

    Args:
        g: prediction graph, or ``None``.
        topk: number of highest-scoring edges to keep per score. Edges are
            ranked separately by ``score_pred`` and by ``score_phr``, and the
            union of the two top-k sets is kept.
        joint_preds: when True, column 0 of ``edata["preds"]`` is excluded
            from the argmax, so predicate ids are shifted down by one
            (presumably column 0 is a "no relation" class — confirm with the
            model).

    Returns:
        ``(pred_objects, pred_triplets)``, or ``(None, None)`` for an empty or
        missing graph.

        * ``pred_objects`` rows are ``[node_class_pred, pred_bbox[:4]]``.
        * ``pred_triplets`` rows are ``[score_pred, score_phr,
          predicate, sub_class, ob_class, sub_box(4), ob_box(4)]``, one per
          kept edge, in ascending edge-id order.
    """
    if g is None or g.number_of_nodes() == 0:
        return None, None

    # node class probabilities are not part of the output, so they are not read
    pred_class = g.ndata["node_class_pred"].asnumpy()
    pred_bbox = g.ndata["pred_bbox"][:, 0:4].asnumpy()

    pred_objects = np.vstack([pred_class, pred_bbox.transpose(1, 0)]).transpose(
        1, 0
    )

    score_pred = g.edata["score_pred"].asnumpy()
    score_phr = g.edata["score_phr"].asnumpy()
    score_pred_topk_eids = (-score_pred).argsort()[0:topk].tolist()
    score_phr_topk_eids = (-score_phr).argsort()[0:topk].tolist()
    topk_eids = sorted(set(score_pred_topk_eids + score_phr_topk_eids))

    pred_rel_prob = g.edata["preds"][topk_eids].asnumpy()
    if joint_preds:
        pred_rel_class = pred_rel_prob[:, 1:].argmax(axis=1)
    else:
        pred_rel_class = pred_rel_prob.argmax(axis=1)

    pred_node_ids = g.find_edges(topk_eids)
    pred_node_sub = pred_node_ids[0].asnumpy()
    pred_node_ob = pred_node_ids[1].asnumpy()

    pred_triplets = np.vstack(
        [
            score_pred[topk_eids],
            score_phr[topk_eids],
            pred_rel_class,
            pred_class[pred_node_sub],
            pred_class[pred_node_ob],
            pred_bbox[pred_node_sub].transpose(1, 0),
            pred_bbox[pred_node_ob].transpose(1, 0),
        ]
    ).transpose(1, 0)
    return pred_objects, pred_triplets
