import argparse
import logging
import time

import mxnet as mx
import numpy as np
from data import *
from gluoncv.data.batchify import Pad
from model import RelDN, faster_rcnn_resnet101_v1d_custom
from mxnet import gluon, nd
from utils import *

import dgl


def parse_args(argv=None):
    """Parse the command-line options of the RelDN validation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what running the script from the shell uses.

    Returns:
        argparse.Namespace: parsed options with the attributes ``gpus``,
        ``batch_size``, ``metric``, ``pretrained_faster_rcnn_params``,
        ``reldn_params``, ``faster_rcnn_params``, ``log_dir``,
        ``freq_prior`` and ``verbose_freq``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            an option value is invalid (including an unknown ``--metric``).
    """
    parser = argparse.ArgumentParser(
        description="Validate Pre-trained RelDN Model."
    )
    parser.add_argument(
        "--gpus",
        type=str,
        default="0",
        help="GPUs to validate on, you can specify 1,3 for example.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Total batch-size; must be divisible by the number of GPUs.",
    )
    parser.add_argument(
        "--metric",
        type=str,
        default="sgdet",
        # Restrict to the modes the script knows how to evaluate; any other
        # value would otherwise run the whole validation with no metric.
        choices=("predcls", "phrcls", "sgdet", "sgdet+"),
        help="Evaluation metric, could be 'predcls', 'phrcls', 'sgdet' or 'sgdet+'.",
    )
    parser.add_argument(
        "--pretrained-faster-rcnn-params",
        type=str,
        required=True,
        help="Path to saved Faster R-CNN model parameters for the object detector.",
    )
    parser.add_argument(
        "--reldn-params",
        type=str,
        required=True,
        help="Path to saved RelDN model parameters.",
    )
    parser.add_argument(
        "--faster-rcnn-params",
        type=str,
        required=True,
        help="Path to saved Faster R-CNN feature parameters for the relation feature extractor.",
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default="reldn_output.log",
        help="Path of the file to save validation logs to.",
    )
    parser.add_argument(
        "--freq-prior",
        type=str,
        default="freq_prior.pkl",
        help="Path to saved frequency prior data.",
    )
    parser.add_argument(
        "--verbose-freq",
        type=int,
        default=100,
        help="Frequency of log printing in number of iterations.",
    )
    return parser.parse_args(argv)


args = parse_args()

# Send log records both to the file named by --log-dir and to the console,
# using the root logger.
filehandler = logging.FileHandler(args.log_dir)
streamhandler = logging.StreamHandler()
logger = logging.getLogger("")
logger.setLevel(logging.INFO)
logger.addHandler(filehandler)
logger.addHandler(streamhandler)

# Hyperparams
ctx = [mx.gpu(int(i)) for i in args.gpus.split(",") if i.strip()]
if ctx:
    num_gpus = len(ctx)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.batch_size % num_gpus != 0:
        raise ValueError(
            "batch size (%d) must be divisible by the number of GPUs (%d)"
            % (args.batch_size, num_gpus)
        )
    per_device_batch_size = args.batch_size // num_gpus
else:
    # No GPU id given (e.g. --gpus ""): fall back to a single CPU context.
    ctx = [mx.cpu()]
    num_gpus = 0
    per_device_batch_size = args.batch_size
batch_size = args.batch_size
N_relations = 50
N_objects = 150
batch_verbose_freq = args.verbose_freq

# Build one metric object per top-k cut-off for the selected evaluation mode.
mode = args.metric
topk_list = [20, 50, 100]
metric_classes = {
    "predcls": PredCls,
    "phrcls": PhrCls,
    "sgdet": SGDet,
    "sgdet+": SGDetPlus,
}
metric_cls = metric_classes.get(mode)
if metric_cls is None:
    # Keep running as before (with no metrics), but make the problem visible.
    logger.warning(
        "Unknown metric '%s'; no metrics will be computed.", mode
    )
    metric_list = []
else:
    metric_list = [metric_cls(topk=topk) for topk in topk_list]
for metric in metric_list:
    metric.reset()

semantic_only = False
net = RelDN(
    n_classes=N_relations,
    prior_pkl=args.freq_prior,
    semantic_only=semantic_only,
)
net.load_parameters(args.reldn_params, ctx=ctx)

# dataset and dataloader
vg_val = VGRelation(split="val")
logger.info("data loaded!")
val_data = gluon.data.DataLoader(
    vg_val,
    # One sample per context in every loaded batch.
    batch_size=len(ctx),
    shuffle=False,
    # len(ctx) equals num_gpus when GPUs are used and is 1 on the CPU
    # fallback, where the worker count was previously undefined.
    num_workers=16 * len(ctx),
    batchify_fn=dgl_mp_batchify_fn,
)
n_batches = len(val_data)

# Detector used to produce object ids, scores, boxes and spatial features.
detector = faster_rcnn_resnet101_v1d_custom(
    classes=vg_val.obj_classes,
    pretrained_base=False,
    pretrained=False,
    additional_output=True,
)
params_path = args.pretrained_faster_rcnn_params
detector.load_parameters(
    params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)

# Second detector with the same architecture, used for relation (edge)
# features; its `features` sub-network is overwritten below with the
# parameters given by --faster-rcnn-params.
detector_feat = faster_rcnn_resnet101_v1d_custom(
    classes=vg_val.obj_classes,
    pretrained_base=False,
    pretrained=False,
    additional_output=True,
)
detector_feat.load_parameters(
    params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)

detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)


def get_data_batch(g_list, img_list, ctx_list):
    """Split a loaded batch into one shard per context and move it there.

    The graphs and images are cut into ``len(ctx_list)`` contiguous shards of
    ``len(g_list) // len(ctx_list)`` items each; the last shard additionally
    takes any remainder. The ``bbox``, ``node_class`` and ``node_class_vec``
    node data and the ``rel_class`` edge data of every graph are moved
    (in place on the graph) to the context of the shard it belongs to, and
    each image shard is moved to that same context.

    Args:
        g_list: sequence of graphs exposing ``ndata`` / ``edata`` mappings
            whose values support ``as_in_context``; may be ``None`` or empty.
        img_list: sliceable container of images aligned with ``g_list``;
            each slice must support ``as_in_context``.
        ctx_list: sequence of target contexts, one per shard.

    Returns:
        tuple: ``(G_list, img_list)`` where both are lists with one shard per
        context, or ``(None, None)`` when ``g_list`` is ``None`` or empty.

    Raises:
        ValueError: if there are fewer graphs than contexts.
    """
    if g_list is None or len(g_list) == 0:
        return None, None
    n_gpu = len(ctx_list)
    size = len(g_list)
    if size < n_gpu:
        raise ValueError("too small batch")
    step = size // n_gpu
    # [start, stop) of every shard; the last one runs to the end of the batch.
    bounds = [
        (k * step, (k + 1) * step if k < n_gpu - 1 else size)
        for k in range(n_gpu)
    ]
    G_list = [g_list[start:stop] for start, stop in bounds]
    img_list = [img_list[start:stop] for start, stop in bounds]

    for G_slice, device in zip(G_list, ctx_list):
        for G in G_slice:
            G.ndata["bbox"] = G.ndata["bbox"].as_in_context(device)
            G.ndata["node_class"] = G.ndata["node_class"].as_in_context(device)
            G.ndata["node_class_vec"] = G.ndata["node_class_vec"].as_in_context(
                device
            )
            G.edata["rel_class"] = G.edata["rel_class"].as_in_context(device)
    # Pair every image shard with its own context. Reusing the loop variable
    # of the loop above would place all shards on the last context.
    img_list = [
        img.as_in_context(device) for img, device in zip(img_list, ctx_list)
    ]
    return G_list, img_list


def _log_metrics(n_done):
    """Log the current value of every metric after ``n_done`` batches."""
    print_txt = "Batch[%d/%d] " % (n_done, n_batches)
    for metric in metric_list:
        metric_name, metric_val = metric.get()
        print_txt += "%s=%.4f " % (metric_name, metric_val)
    logger.info(print_txt)


# A single padder (constructed with axis=0) is reused for every batch.
bbox_pad = Pad(axis=(0))

for i, (G_list, img_list) in enumerate(val_data):
    G_list, img_list = get_data_batch(G_list, img_list, ctx)
    if G_list is None or img_list is None:
        # Nothing to evaluate in this batch; still honour the log schedule.
        if (i + 1) % batch_verbose_freq == 0:
            _log_metrics(i + 1)
        continue

    G_batch = []
    for G_slice, img in zip(G_list, img_list):
        cur_ctx = img.context
        if mode == "predcls":
            # feed the ground truth bbox and ground truth node classes
            bbox_list = [G.ndata["bbox"] for G in G_slice]
            bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
            ids, scores, bbox, spatial_feat = detector(
                img, None, None, bbox_stack
            )

            node_class_list = [G.ndata["node_class"] for G in G_slice]
            node_class_stack = bbox_pad(node_class_list).as_in_context(cur_ctx)
            g_pred_batch = build_graph_validate_gt_obj(
                img,
                node_class_stack,
                bbox,
                spatial_feat,
                bbox_improvement=True,
                overlap=False,
            )
        elif mode == "phrcls":
            # use ground truth bbox
            bbox_list = [G.ndata["bbox"] for G in G_slice]
            bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
            ids, scores, bbox, spatial_feat = detector(
                img, None, None, bbox_stack
            )

            g_pred_batch = build_graph_validate_gt_bbox(
                img,
                ids,
                scores,
                bbox,
                spatial_feat,
                bbox_improvement=True,
                overlap=False,
            )
        else:
            # use predicted bbox
            ids, scores, bbox, feat, feat_ind, spatial_feat = detector(img)
            g_pred_batch = build_graph_validate_pred(
                img,
                ids,
                scores,
                bbox,
                feat_ind,
                spatial_feat,
                bbox_improvement=True,
                scores_top_k=75,
                overlap=False,
            )
        if not semantic_only:
            # Attach a feature to every edge of the predicted graph, taken
            # from the fourth output of `detector_feat` for the edge boxes.
            rel_bbox = g_pred_batch.edata["rel_bbox"]
            batch_id = g_pred_batch.edata["batch_id"].asnumpy()
            n_graph = len(G_slice)
            # Edge ids belonging to each graph of the slice, computed once.
            eids_per_graph = [
                np.where(batch_id == j)[0] for j in range(n_graph)
            ]
            bbox_rel_list = [
                rel_bbox[eids] for eids in eids_per_graph if len(eids) > 0
            ]
            bbox_rel_stack = bbox_pad(bbox_rel_list).as_in_context(cur_ctx)
            _, _, _, spatial_feat_rel = detector_feat(
                img, None, None, bbox_rel_stack
            )
            # Drop the padded rows again by keeping only the first
            # len(eids) features of each graph.
            # NOTE(review): row `j` is used as in the original code even
            # though graphs without edges are skipped when stacking; this
            # only lines up when no earlier graph in the slice is empty -
            # confirm against the graph builders.
            spatial_feat_rel_list = [
                spatial_feat_rel[j, 0 : len(eids)]
                for j, eids in enumerate(eids_per_graph)
                if len(eids) > 0
            ]
            g_pred_batch.edata["edge_feat"] = nd.concat(
                *spatial_feat_rel_list, dim=0
            )

        G_batch.append(g_pred_batch)

    G_batch = [net(G) for G in G_batch]

    for G_slice, G_pred, img_slice in zip(G_list, G_batch, img_list):
        # Zipping with a one-element list pairs the prediction with the
        # first ground-truth graph of the slice only.
        for G_gt, G_pred_one in zip(G_slice, [G_pred]):
            if G_pred_one is None or G_pred_one.number_of_nodes() == 0:
                continue
            gt_objects, gt_triplet = extract_gt(G_gt, img_slice.shape[2:4])
            pred_objects, pred_triplet = extract_pred(G_pred, joint_preds=True)
            for metric in metric_list:
                if isinstance(metric, (PredCls, PhrCls, SGDet)):
                    metric.update(gt_triplet, pred_triplet)
                else:
                    # Remaining metrics also receive the object lists.
                    metric.update(
                        (gt_objects, gt_triplet), (pred_objects, pred_triplet)
                    )
    # Report the number of batches processed so far (i + 1), matching both
    # the schedule test and the final "Batch[n/n]" line.
    if (i + 1) % batch_verbose_freq == 0:
        _log_metrics(i + 1)

# Final report over the whole validation set.
_log_metrics(n_batches)
