import numpy as np
from mxnet import nd

import dgl


def bbox_improve(bbox):
    """Append each box's area as an extra trailing column.

    Args:
        bbox: 2-D NDArray whose first four columns are read as two corner pairs
            (columns 0/2 and 1/3) — presumably (x1, y1, x2, y2), matching how the
            graph builders in this module rescale those columns.

    Returns:
        ``bbox`` with one more column holding
        ``(bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1])``.
    """
    width = bbox[:, 2] - bbox[:, 0]
    height = bbox[:, 3] - bbox[:, 1]
    area_column = (width * height).expand_dims(1)
    return nd.concat(bbox, area_column, dim=1)


def extract_edge_bbox(g):
    """Return, for every edge, the box enclosing both endpoint boxes.

    Edges are visited in edge-id order. Columns 0-1 of the result take the
    element-wise minimum of the source and destination boxes; columns 2-3 take
    the maximum. Under the (x1, y1, x2, y2) layout used elsewhere in this module,
    that is the smallest box covering both endpoints.

    Args:
        g: graph whose ``ndata["pred_bbox"]`` has at least four columns; any
            extra columns (e.g. an appended area) are ignored.

    Returns:
        NDArray of shape (number_of_edges, 4), allocated with ``nd.zeros`` on the
        same context as ``pred_bbox``.
    """
    src_ids, dst_ids = g.edges(order="eid")
    num_edges = g.number_of_edges()
    node_boxes = g.ndata["pred_bbox"]
    src_box = node_boxes[src_ids.asnumpy()]
    dst_box = node_boxes[dst_ids.asnumpy()]

    union = nd.zeros((num_edges, 4), ctx=node_boxes.context)
    for col in range(4):
        pair = nd.stack(src_box[:, col], dst_box[:, col])
        # Columns 0-1 shrink toward the minimum, columns 2-3 grow toward the maximum.
        union[:, col] = pair.min(axis=0) if col < 2 else pair.max(axis=0)
    return union


def _greedy_class_match(ious, gt_classes, pred_classes, iou_thresh):
    """Greedily pair predicted boxes with ground-truth boxes by descending IoU.

    Repeatedly takes the highest remaining IoU entry. If it is below ``iou_thresh``
    the search stops; otherwise that ground-truth row and prediction column are
    both consumed. A pair is recorded only when the two class labels compare equal.
    So a prediction overlapping a ground-truth box of another class still consumes
    that box, but it stays unmatched.

    Args:
        ious: numpy array of shape (n_gt, n_pred); it is overwritten in place.
        gt_classes: indexable class labels, one per row of ``ious``.
        pred_classes: indexable class labels, one per column of ``ious``.
        iou_thresh: minimum IoU for a pair to be accepted.

    Returns:
        List of length n_pred holding the matched ground-truth row index, or None.
    """
    n_gt, n_pred = ious.shape
    matches = [None] * n_pred
    # Each iteration consumes one row and one column, so at most min(n_gt, n_pred) rounds.
    for _ in range(min(n_gt, n_pred)):
        row, col = divmod(int(ious.argmax()), n_pred)
        if ious[row, col] < iou_thresh:
            break
        if gt_classes[row] == pred_classes[col]:
            matches[col] = row
        ious[row, :] = -1
        ious[:, col] = -1
    return matches


def _label_pred_edges(matches, gt_edges, gt_rel_class):
    """Build (src, dst, rel_class) triplets for every ordered pair of distinct predictions.

    If both endpoints are matched to ground-truth objects that are connected in the
    ground-truth graph, one triplet is emitted per such ground-truth edge, carrying
    that edge's ``rel_class``. Otherwise a single triplet with class 0 is emitted.
    Class 0 is the label this code uses for every pair it cannot tie to a
    ground-truth relation.

    Args:
        matches: per-prediction matched ground-truth node index, or None.
        gt_edges: dict mapping (gt_src, gt_dst) to the list of ground-truth edge ids.
        gt_rel_class: numpy array of ground-truth relation labels indexed by edge id.
    """
    triplets = []
    n_pred = len(matches)
    for i in range(n_pred):
        for j in range(n_pred):
            if i == j:
                continue
            sub_gt, ob_gt = matches[i], matches[j]
            eids = None
            if sub_gt is not None and ob_gt is not None:
                # .get: a missing ground-truth edge is labelled 0 instead of raising KeyError.
                eids = gt_edges.get((sub_gt, ob_gt))
            if eids:
                triplets.extend((i, j, rel) for rel in gt_rel_class[eids])
            else:
                triplets.append((i, j, 0))
    return triplets


def build_graph_train(
    g_slice,
    gt_bbox,
    img,
    ids,
    scores,
    bbox,
    feat_ind,
    spatial_feat,
    iou_thresh=0.5,
    bbox_improvement=True,
    scores_top_k=50,
    overlap=False,
):
    """Given ground truth and predicted bboxes, assign labels to the predicted graph w.r.t. iou_thresh.

    For each image, the positive-score detections (at most ``scores_top_k``, highest
    score first) become nodes of a fully connected directed graph. Predictions are
    greedily matched to ground-truth objects by IoU and class. Each edge between two
    matched predictions copies the ground-truth relation label(s); every other edge
    gets label 0.

    Args:
        g_slice: sequence of ground-truth graphs, one per image. Each must provide
            ``ndata["bbox"]`` (only its context is used), ``ndata["node_class"]`` and
            ``edata["rel_class"]``. Assumes row k of ``gt_bbox[gi]`` is node k of
            ``g_slice[gi]`` — TODO confirm against the data loader.
        gt_bbox: ground-truth boxes indexed ``[image, box, coord]``. NOTE: its
            first four columns are divided by the image size **in place**.
        img: image batch; ``img.shape[2:4]`` is read as (height, width), presumably
            NCHW layout.
        ids: predicted class ids indexed ``[image, box, 0]``.
        scores: detection scores indexed ``[image, box, 0]``; only boxes with a
            score > 0 are kept.
        bbox: predicted boxes indexed ``[image, box, coord]``; also rescaled in place.
        feat_ind: per-box index into ``spatial_feat`` (``feat_ind[image, box]``).
        spatial_feat: per-ROI features indexed ``[image, roi]``.
        iou_thresh: minimum IoU for a prediction/ground-truth match.
        bbox_improvement: if True, append the box area as an extra node-bbox column.
        scores_top_k: maximum number of detections kept per image.
        overlap: if True, drop edges between node boxes whose IoU is <= 1e-7.

    Returns:
        None if any image has no positive-score detection, or if no image yields a
        graph with edges. Otherwise a batched graph when ``len(g_slice) > 1``, else
        the single graph. Images with fewer than two kept detections, or with no
        edges left after overlap filtering, are skipped; ``edata["batch_id"]`` keeps
        the original image index.
    """
    # Rescale coordinates in place: columns 0/2 by width, columns 1/3 by height.
    img_size = img.shape[2:4]
    for boxes in (gt_bbox, bbox):
        boxes[:, :, 0] /= img_size[1]
        boxes[:, :, 1] /= img_size[0]
        boxes[:, :, 2] /= img_size[1]
        boxes[:, :, 3] /= img_size[0]

    n_graph = len(g_slice)
    g_pred_batch = []
    for gi in range(n_graph):
        g = g_slice[gi]
        ctx = g.ndata["bbox"].context
        inds = np.where(scores[gi, :, 0].asnumpy() > 0)[0].tolist()
        if len(inds) == 0:
            # An image without any positive-score detection discards the whole batch.
            return None
        if len(inds) > scores_top_k:
            top_score_inds = (
                scores[gi, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
            )
            inds = np.array(inds)[top_score_inds].tolist()

        n_nodes = len(inds)
        if n_nodes < 2:
            # A single node has no pairs, hence no edges; it cannot form a training graph.
            continue
        roi_ind = feat_ind[gi, inds].squeeze(axis=1)
        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[gi, inds],
                "node_feat": spatial_feat[gi, roi_ind],
                "node_class_pred": ids[gi, inds, 0],
                "node_class_logit": nd.log(scores[gi, inds, 0] + 1e-7),
            },
        )

        # Match predictions to ground-truth objects (IoU + class agreement).
        ious = nd.contrib.box_iou(
            gt_bbox[gi], g_pred.ndata["pred_bbox"]
        ).asnumpy()
        matches = _greedy_class_match(
            ious, g.ndata["node_class"], g_pred.ndata["node_class_pred"], iou_thresh
        )

        # Index ground-truth edges by (src, dst) so their relation labels can be copied.
        gt_src, gt_dst = g.all_edges(order="eid")
        gt_edges = {}
        edge_keys = zip(gt_src.asnumpy().tolist(), gt_dst.asnumpy().tolist())
        for eid, key in enumerate(edge_keys):
            gt_edges.setdefault(key, []).append(eid)
        triplets = _label_pred_edges(
            matches, gt_edges, g.edata["rel_class"].asnumpy()
        )

        # n_nodes >= 2 guarantees at least one triplet, so the unpacking is safe.
        src, dst, rel_class = tuple(zip(*triplets))
        rel_class = nd.array(rel_class, ctx=ctx).expand_dims(1)
        g_pred.add_edges(src, dst, data={"rel_class": rel_class})

        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + gi

        # Remove edges between non-overlapping boxes.
        if overlap:
            overlap_ious = nd.contrib.box_iou(
                g_pred.ndata["pred_bbox"][:, 0:4],
                g_pred.ndata["pred_bbox"][:, 0:4],
            ).asnumpy()
            cols, rows = np.where(overlap_ious <= 1e-7)
            if cols.shape[0] > 0:
                eids = g_pred.edge_ids(cols, rows)[2].asnumpy().tolist()
                if eids:
                    g_pred.remove_edges(eids)
            if g_pred.number_of_edges() == 0:
                # Previously a None placeholder was appended here, which broke dgl.batch.
                continue
        g_pred_batch.append(g_pred)

    if not g_pred_batch:
        return None
    if n_graph > 1:
        return dgl.batch(g_pred_batch)
    return g_pred_batch[0]


def build_graph_validate_gt_obj(
    img, gt_ids, bbox, spatial_feat, bbox_improvement=True, overlap=False
):
    """Given ground truth bbox and label, build a fully connected graph for validation.

    Each image becomes one graph. Its nodes are the boxes with a positive
    coordinate sum (other rows are presumably padding), and its edges connect
    every ordered pair of distinct nodes.

    Args:
        img: image batch; ``img.shape[0]`` is the batch size and ``img.shape[2:4]``
            is read as (height, width), presumably NCHW layout.
        gt_ids: ground-truth class ids indexed ``[image, box, 0]``.
        bbox: ground-truth boxes indexed ``[image, box, coord]``. NOTE: its first
            four columns are divided by the image size **in place**.
        spatial_feat: per-box features indexed ``[image, box]``.
        bbox_improvement: if True, append the box area as an extra node-bbox column.
        overlap: accepted for signature compatibility; unused here.

    Returns:
        A batched graph if more than one image yields a graph, the single graph if
        exactly one does, otherwise None. Images with fewer than two valid boxes are
        skipped because they have no node pairs to connect.
    """
    n_batch = img.shape[0]
    img_size = img.shape[2:4]
    # Rescale in place: columns 0/2 by width, columns 1/3 by height.
    bbox[:, :, 0] /= img_size[1]
    bbox[:, :, 1] /= img_size[0]
    bbox[:, :, 2] /= img_size[1]
    bbox[:, :, 3] /= img_size[0]
    ctx = img.context

    g_batch = []
    for btc in range(n_batch):
        inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()
        n_nodes = len(inds)
        if n_nodes < 2:
            # No pairs means no edges; zip(*[]) below would fail to unpack.
            continue
        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[btc, inds],
                "node_feat": spatial_feat[btc, inds],
                "node_class_pred": gt_ids[btc, inds, 0],
                "node_class_logit": nd.zeros_like(
                    gt_ids[btc, inds, 0], ctx=ctx
                ),
            },
        )

        # Fully connect: all (i, j) with i < j first, then every reverse edge.
        pairs = [(i, j) for i in range(n_nodes - 1) for j in range(i + 1, n_nodes)]
        src, dst = tuple(zip(*pairs))
        g_pred.add_edges(src, dst)
        g_pred.add_edges(dst, src)

        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc

        g_batch.append(g_pred)

    if len(g_batch) == 0:
        return None
    if len(g_batch) > 1:
        return dgl.batch(g_batch)
    return g_batch[0]


def build_graph_validate_gt_bbox(
    img,
    ids,
    scores,
    bbox,
    spatial_feat,
    gt_ids=None,
    bbox_improvement=True,
    overlap=False,
):
    """Given ground truth bbox, build a fully connected graph for validation.

    Node classes and logits come from ``scores``: for each box, the argmax (class)
    and max (score) over axis 0 of ``scores[image][:, :, 0]``. Nodes are the boxes
    with a positive coordinate sum (other rows are presumably padding).

    Args:
        img: image batch; ``img.shape[0]`` is the batch size and ``img.shape[2:4]``
            is read as (height, width), presumably NCHW layout.
        ids: unused; kept for signature compatibility.
        scores: per-image class scores. Presumably shaped
            (batch, num_classes, num_boxes, 1) — TODO confirm with the caller.
        bbox: ground-truth boxes indexed ``[image, box, coord]``. NOTE: its first
            four columns are divided by the image size **in place**.
        spatial_feat: per-box features indexed ``[image, box]``.
        gt_ids: unused; kept for signature compatibility.
        bbox_improvement: if True, append the box area as an extra node-bbox column.
        overlap: unused; kept for signature compatibility.

    Returns:
        A batched graph if more than one image yields a graph, the single graph if
        exactly one does, otherwise None. Images with fewer than two valid boxes are
        skipped.
    """
    n_batch = img.shape[0]
    img_size = img.shape[2:4]
    # Rescale in place: columns 0/2 by width, columns 1/3 by height.
    bbox[:, :, 0] /= img_size[1]
    bbox[:, :, 1] /= img_size[0]
    bbox[:, :, 2] /= img_size[1]
    bbox[:, :, 3] /= img_size[0]
    ctx = img.context

    g_batch = []
    for btc in range(n_batch):
        id_btc = scores[btc][:, :, 0].argmax(0)
        score_btc = scores[btc][:, :, 0].max(0)
        inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()
        n_nodes = len(inds)
        if n_nodes < 2:
            # No pairs means no edges; zip(*[]) below would fail to unpack.
            continue
        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[btc, inds],
                "node_feat": spatial_feat[btc, inds],
                # Select the same rows as the other node fields. Without [inds] the
                # lengths disagree whenever padding rows are dropped.
                "node_class_pred": id_btc[inds],
                "node_class_logit": nd.log(score_btc[inds] + 1e-7),
            },
        )

        # Fully connect: all (i, j) with i < j first, then every reverse edge.
        pairs = [(i, j) for i in range(n_nodes - 1) for j in range(i + 1, n_nodes)]
        src, dst = tuple(zip(*pairs))
        g_pred.add_edges(src, dst)
        g_pred.add_edges(dst, src)

        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc

        g_batch.append(g_pred)

    if len(g_batch) == 0:
        return None
    if len(g_batch) > 1:
        return dgl.batch(g_batch)
    return g_batch[0]


def build_graph_validate_pred(
    img,
    ids,
    scores,
    bbox,
    feat_ind,
    spatial_feat,
    bbox_improvement=True,
    scores_top_k=50,
    overlap=False,
):
    """Given predicted bbox, build a fully connected graph for validation.

    For each image, the detections with score > 0 (at most ``scores_top_k``,
    highest score first) become nodes, and every ordered pair of distinct nodes
    is connected.

    Args:
        img: image batch; ``img.shape[0]`` is the batch size and ``img.shape[2:4]``
            is read as (height, width), presumably NCHW layout.
        ids: predicted class ids indexed ``[image, box, 0]``.
        scores: detection scores indexed ``[image, box, 0]``.
        bbox: predicted boxes indexed ``[image, box, coord]``. NOTE: its first four
            columns are divided by the image size **in place**.
        feat_ind: per-box index into ``spatial_feat`` (``feat_ind[image, box]``).
        spatial_feat: per-ROI features indexed ``[image, roi]``.
        bbox_improvement: if True, append the box area as an extra node-bbox column.
        scores_top_k: maximum number of detections kept per image.
        overlap: unused; kept for signature compatibility.

    Returns:
        A batched graph if more than one image yields a graph, the single graph if
        exactly one does, otherwise None. Images with fewer than two kept detections
        are skipped.
    """
    n_batch = img.shape[0]
    img_size = img.shape[2:4]
    # Rescale in place: columns 0/2 by width, columns 1/3 by height.
    bbox[:, :, 0] /= img_size[1]
    bbox[:, :, 1] /= img_size[0]
    bbox[:, :, 2] /= img_size[1]
    bbox[:, :, 3] /= img_size[0]
    ctx = img.context

    g_batch = []
    for btc in range(n_batch):
        inds = np.where(scores[btc, :, 0].asnumpy() > 0)[0].tolist()
        if len(inds) == 0:
            continue
        if len(inds) > scores_top_k:
            top_score_inds = (
                scores[btc, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]
            )
            inds = np.array(inds)[top_score_inds].tolist()
        n_nodes = len(inds)
        if n_nodes < 2:
            # No pairs means no edges; zip(*[]) below would fail to unpack.
            continue
        roi_ind = feat_ind[btc, inds].squeeze(axis=1)

        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[btc, inds],
                "node_feat": spatial_feat[btc, roi_ind],
                "node_class_pred": ids[btc, inds, 0],
                "node_class_logit": nd.log(scores[btc, inds, 0] + 1e-7),
            },
        )

        # Fully connect: all (i, j) with i < j first, then every reverse edge.
        pairs = [(i, j) for i in range(n_nodes - 1) for j in range(i + 1, n_nodes)]
        src, dst = tuple(zip(*pairs))
        g_pred.add_edges(src, dst)
        g_pred.add_edges(dst, src)

        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc

        g_batch.append(g_pred)

    if len(g_batch) == 0:
        return None
    if len(g_batch) > 1:
        return dgl.batch(g_batch)
    return g_batch[0]
