import argparse
import logging
import time

import mxnet as mx
import numpy as np
from data import *
from gluoncv.data.batchify import Pad
from gluoncv.utils import makedirs
from model import RelDN, faster_rcnn_resnet101_v1d_custom
from mxnet import gluon, nd
from utils import *

import dgl


def parse_args(argv=None):
    """Parse the command-line options for RelDN training.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, which is the behavior
            the script relies on.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Train RelDN Model.")
    parser.add_argument(
        "--gpus",
        type=str,
        default="0",
        help="Training with GPUs, you can specify 1,3 for example.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Total batch-size for training.",
    )
    parser.add_argument(
        "--epochs", type=int, default=9, help="Training epochs."
    )
    parser.add_argument(
        "--lr-reldn",
        type=float,
        default=0.01,
        help="Learning rate for RelDN module.",
    )
    parser.add_argument(
        "--wd-reldn",
        type=float,
        default=0.0001,
        help="Weight decay for RelDN module.",
    )
    parser.add_argument(
        "--lr-faster-rcnn",
        type=float,
        default=0.01,
        help="Learning rate for Faster R-CNN module.",
    )
    parser.add_argument(
        "--wd-faster-rcnn",
        type=float,
        default=0.0001,
        # Previously mislabelled as the RelDN weight decay.
        help="Weight decay for Faster R-CNN module.",
    )
    parser.add_argument(
        "--lr-decay-epochs",
        type=str,
        default="5,8",
        help="Learning rate decay points.",
    )
    parser.add_argument(
        "--lr-warmup-iters",
        type=int,
        default=4000,
        help="Learning rate warm-up iterations.",
    )
    parser.add_argument(
        "--save-dir",
        type=str,
        default="params_resnet101_v1d_reldn",
        help="Path to save model parameters.",
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default="reldn_output.log",
        help="Path to save training logs.",
    )
    parser.add_argument(
        "--pretrained-faster-rcnn-params",
        type=str,
        required=True,
        help="Path to saved Faster R-CNN model parameters.",
    )
    parser.add_argument(
        "--freq-prior",
        type=str,
        default="freq_prior.pkl",
        help="Path to saved frequency prior data.",
    )
    parser.add_argument(
        "--verbose-freq",
        type=int,
        default=100,
        help="Frequency of log printing in number of iterations.",
    )

    return parser.parse_args(argv)


args = parse_args()

# Route INFO-and-above records from the root logger both to the file named by
# --log-dir and to a StreamHandler (stderr by default).
logger = logging.getLogger("")
logger.setLevel(logging.INFO)
filehandler, streamhandler = (
    logging.FileHandler(args.log_dir),
    logging.StreamHandler(),
)
for _handler in (filehandler, streamhandler):
    logger.addHandler(_handler)

# Hyperparams
# --gpus is a comma-separated list of device ids; blank entries are ignored and
# an empty list falls back to training on the CPU.
ctx = [mx.gpu(int(i)) for i in args.gpus.split(",") if i.strip()]
if ctx:
    num_gpus = len(ctx)
    if args.batch_size % num_gpus != 0:
        raise ValueError(
            "--batch-size (%d) must be divisible by the number of GPUs (%d)"
            % (args.batch_size, num_gpus)
        )
    per_device_batch_size = args.batch_size // num_gpus
else:
    ctx = [mx.cpu()]
    # Later code uses num_gpus (data-loader worker count, per-iteration loss
    # averaging), so it must also be defined on the CPU path; previously it
    # was left unset and the script died with NameError.
    num_gpus = 1
    per_device_batch_size = args.batch_size

# With more than one sample per device per optimizer step, gradients are
# accumulated (grad_req="add") across iterations before each step.
aggregate_grad = per_device_batch_size > 1

nepoch = args.epochs
N_relations = 50
N_objects = 150
save_dir = args.save_dir
makedirs(save_dir)
batch_verbose_freq = args.verbose_freq
# Parsed --lr-decay-epochs values; blank entries are skipped so that an empty
# string yields an empty list instead of crashing on int("").
lr_decay_epochs = [int(e) for e in args.lr_decay_epochs.split(",") if e.strip()]

# Dataset and dataloader
vg_train = VGRelation(split="train")
logger.info("data loaded!")
# Each loader batch carries len(ctx) samples, which get_data_batch then splits
# across the devices. get_data_batch raises when a batch has fewer samples than
# devices, so the final short batch is dropped ("discard") instead of being
# kept (the DataLoader default), which aborted multi-GPU epochs whenever the
# dataset size was not a multiple of len(ctx).
# NOTE(review): assumes dgl_mp_batchify_fn yields one graph per sample.
train_data = gluon.data.DataLoader(
    vg_train,
    batch_size=len(ctx),
    shuffle=True,
    last_batch="discard",
    num_workers=8 * num_gpus,
    batchify_fn=dgl_mp_batchify_fn,
)
n_batches = len(train_data)

# Network definition
net = RelDN(n_classes=N_relations, prior_pkl=args.freq_prior)
# Only the spatial and visual branches are initialised here, each from a
# normal distribution with sigma 1e-4.
net.spatial.initialize(mx.init.Normal(1e-4), ctx=ctx)
net.visual.initialize(mx.init.Normal(1e-4), ctx=ctx)
net_params = net.collect_params()
# "add" keeps gradients accumulating between optimizer steps when a device
# handles more than one sample per step; otherwise each backward overwrites.
for param in net_params.values():
    param.grad_req = "add" if aggregate_grad else "write"
net_trainer = gluon.Trainer(
    net_params,
    "adam",
    {"learning_rate": args.lr_reldn, "wd": args.wd_reldn},
)

det_params_path = args.pretrained_faster_rcnn_params


def _load_frozen_detector(params_path):
    """Build the custom Faster R-CNN, load weights onto every device in `ctx`,
    and switch off gradients for all of its parameters."""
    model = faster_rcnn_resnet101_v1d_custom(
        classes=vg_train.obj_classes,
        pretrained_base=False,
        pretrained=False,
        additional_output=True,
    )
    model.load_parameters(
        params_path, ctx=ctx, ignore_extra=True, allow_missing=True
    )
    for param in model.collect_params().values():
        param.grad_req = "null"
    return model


# Fully frozen copy; the training loop calls it to obtain detections
# (ids, scores, boxes, features).
detector = _load_frozen_detector(det_params_path)

# Second copy whose backbone ("features") is trained; it computes the spatial
# features of the relation boxes.
detector_feat = _load_frozen_detector(det_params_path)
det_params = detector_feat.features.collect_params()
for param in det_params.values():
    param.grad_req = "add" if aggregate_grad else "write"
det_trainer = gluon.Trainer(
    det_params,
    "adam",
    {"learning_rate": args.lr_faster_rcnn, "wd": args.wd_faster_rcnn},
)


def get_data_batch(g_list, img_list, ctx_list):
    """Split one loader batch across devices and move each shard to its device.

    The batch is cut into ``len(ctx_list)`` consecutive shards of
    ``len(g_list) // len(ctx_list)`` samples; the last shard also takes any
    remainder. Shard ``k`` of both graphs and images is placed on
    ``ctx_list[k]``.

    Args:
        g_list: Sequence of DGL graphs, presumably one per sample (may be
            None or empty).
        img_list: Sliceable image batch (presumably a stacked NDArray) aligned
            with ``g_list`` along its first axis.
        ctx_list: Devices to distribute the shards over.

    Returns:
        ``(graph_shards, img_shards)``: a list of graph lists and a list of
        image slices, one entry per device; ``(None, None)`` when ``g_list``
        is None or empty.

    Raises:
        ValueError: if there are fewer samples than devices (a subclass of the
            bare ``Exception`` raised previously).
    """
    if g_list is None or len(g_list) == 0:
        return None, None
    n_dev = len(ctx_list)
    size = len(g_list)
    if size < n_dev:
        raise ValueError("too small batch")
    step = size // n_dev

    def _shard(seq):
        # Equal `step`-sized shards; the last device also gets the remainder.
        return [
            seq[k * step : (k + 1) * step] if k < n_dev - 1 else seq[k * step : size]
            for k in range(n_dev)
        ]

    graph_shards = _shard(g_list)
    img_shards = _shard(img_list)

    for graphs, dev in zip(graph_shards, ctx_list):
        for G in graphs:
            for key in ("bbox", "node_class", "node_class_vec"):
                G.ndata[key] = G.ndata[key].as_in_context(dev)
            G.edata["rel_class"] = G.edata["rel_class"].as_in_context(dev)
    # Each image shard goes to the same device as its graph shard. Previously
    # every shard was sent to the loop variable left over from above, i.e. the
    # LAST device, regardless of where its graphs lived.
    img_shards = [img.as_in_context(dev) for img, dev in zip(img_shards, ctx_list)]
    return graph_shards, img_shards


# Softmax cross-entropy on the relation predictions; the training loop passes
# per-edge sample weights as the third argument.
L_rel = gluon.loss.SoftmaxCELoss()

# Top-1 and top-5 relation accuracy; both are reset at the start of each epoch.
train_metric, train_metric_top5 = (
    mx.metric.Accuracy(name="rel_acc"),
    mx.metric.TopKAccuracy(top_k=5, name="rel_acc_top5"),
)
metric_list = [train_metric, train_metric_top5]


def batch_print(
    epoch, i, batch_verbose_freq, n_batches, btic, loss_rel_val, metric_list
):
    """Log running training statistics every ``batch_verbose_freq`` iterations.

    Args:
        epoch: Current epoch index.
        i: Zero-based iteration index within the epoch.
        batch_verbose_freq: Logging period in iterations.
        n_batches: Number of iterations per epoch (shown in the log line).
        btic: Timestamp at which the current logging window started.
        loss_rel_val: Sum of per-iteration relation losses since the last
            log line.
        metric_list: Metrics whose ``get()`` values are appended to the line.

    Returns:
        ``(btic, loss_rel_val)``: unchanged when nothing is logged; otherwise
        a fresh window start time and ``0`` so the next window starts empty.
    """
    if (i + 1) % batch_verbose_freq != 0:
        return btic, loss_rel_val
    # loss_rel_val is reset after every log line, so it only covers the last
    # batch_verbose_freq iterations. Averaging over i + 1 (all iterations so
    # far this epoch) under-reported the loss on every line after the first.
    print_txt = "Epoch[%d] Batch[%d/%d], time: %d, loss_rel=%.4f " % (
        epoch,
        i,
        n_batches,
        int(time.time() - btic),
        loss_rel_val / batch_verbose_freq,
    )
    for metric in metric_list:
        metric_name, metric_val = metric.get()
        print_txt += "%s=%.4f " % (metric_name, metric_val)
    logger.info(print_txt)
    return time.time(), 0


# Pads variable-length per-image box arrays along axis 0 so they can be
# stacked; built once instead of on every iteration.
bbox_pad = Pad(axis=0)

for epoch in range(nepoch):
    # loss_rel_val: sum of per-iteration losses in the current logging window
    # (batch_print resets it); epoch_loss_rel: the same sum over the epoch.
    loss_rel_val = 0
    epoch_loss_rel = 0.0
    n_iters = 0
    tic = time.time()
    btic = time.time()
    for metric in metric_list:
        metric.reset()
    if epoch == 0:
        # Configured learning rates; the warm-up below scales from these.
        net_trainer_base_lr = net_trainer.learning_rate
        det_trainer_base_lr = det_trainer.learning_rate
    # Step decay (x0.1) at the epochs given by --lr-decay-epochs. This was
    # hard-coded to epochs 5 and 8, silently ignoring the flag.
    if epoch in lr_decay_epochs:
        net_trainer.set_learning_rate(net_trainer.learning_rate * 0.1)
        det_trainer.set_learning_rate(det_trainer.learning_rate * 0.1)
    for i, (G_list, img_list) in enumerate(train_data):
        n_iters = i + 1
        if epoch == 0 and i < args.lr_warmup_iters:
            # Linear warm-up: the factor rises from 1/3 at i == 0 towards 1.
            alpha = i / args.lr_warmup_iters
            warmup_factor = 1 / 3 * (1 - alpha) + alpha
            net_trainer.set_learning_rate(net_trainer_base_lr * warmup_factor)
            det_trainer.set_learning_rate(det_trainer_base_lr * warmup_factor)
        G_list, img_list = get_data_batch(G_list, img_list, ctx)
        if G_list is None or img_list is None:
            btic, loss_rel_val = batch_print(
                epoch,
                i,
                batch_verbose_freq,
                n_batches,
                btic,
                loss_rel_val,
                metric_list,
            )
            continue

        loss = []
        G_batch = []
        with mx.autograd.record():
            for G_slice, img in zip(G_list, img_list):
                cur_ctx = img.context
                bbox_list = [G.ndata["bbox"] for G in G_slice]
                bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
                # The proposal detector is frozen; keep it off the autograd tape.
                with mx.autograd.pause():
                    ids, scores, bbox, feat, feat_ind, spatial_feat = detector(
                        img
                    )
                g_pred_batch = build_graph_train(
                    G_slice,
                    bbox_stack,
                    img,
                    ids,
                    scores,
                    bbox,
                    feat_ind,
                    spatial_feat,
                    scores_top_k=300,
                    overlap=False,
                )
                g_batch = l0_sample(g_pred_batch)
                if g_batch is None:
                    continue
                rel_bbox = g_batch.edata["rel_bbox"]
                batch_id = g_batch.edata["batch_id"].asnumpy()
                # Sampled edge ids belonging to each image of this slice.
                eids_list = [
                    np.where(batch_id == j)[0] for j in range(len(G_slice))
                ]
                if any(len(eids) == 0 for eids in eids_list):
                    # Row j of the padded relation-box batch must stay aligned
                    # with image j (and with row j of detector_feat's output).
                    # Dropping an image from the stack shifted those rows and
                    # indexed past the end, and an all-empty slice crashed
                    # bbox_pad([]), so such slices are skipped instead.
                    continue
                bbox_rel_list = [rel_bbox[eids] for eids in eids_list]
                bbox_rel_stack = bbox_pad(bbox_rel_list).as_in_context(cur_ctx)
                # Scale box coordinates by the image width/height (img is
                # indexed as (..., H, W)); presumably rel_bbox is normalised —
                # TODO confirm.
                img_size = img.shape[2:4]
                bbox_rel_stack[:, :, 0] *= img_size[1]
                bbox_rel_stack[:, :, 1] *= img_size[0]
                bbox_rel_stack[:, :, 2] *= img_size[1]
                bbox_rel_stack[:, :, 3] *= img_size[0]
                _, _, _, spatial_feat_rel = detector_feat(
                    img, None, None, bbox_rel_stack
                )
                # Keep only the real (non-padding) rows of each image.
                spatial_feat_rel_list = [
                    spatial_feat_rel[j, 0 : len(eids)]
                    for j, eids in enumerate(eids_list)
                ]
                # NOTE(review): assumes g_batch edges are ordered by batch_id,
                # so concatenating per-image features in image order matches
                # edge ids — confirm against build_graph_train / l0_sample.
                g_batch.edata["edge_feat"] = nd.concat(
                    *spatial_feat_rel_list, dim=0
                )

                G_batch.append(g_batch)

            G_batch = [net(G) for G in G_batch]

            for G_pred in G_batch:
                if G_pred is None or G_pred.number_of_nodes() == 0:
                    continue
                loss_rel = L_rel(
                    G_pred.edata["preds"],
                    G_pred.edata["rel_class"],
                    G_pred.edata["sample_weights"],
                )
                loss.append(loss_rel.sum())
                step_loss = loss_rel.mean().asscalar() / num_gpus
                loss_rel_val += step_loss
                epoch_loss_rel += step_loss

        if len(loss) == 0:
            btic, loss_rel_val = batch_print(
                epoch,
                i,
                batch_verbose_freq,
                n_batches,
                btic,
                loss_rel_val,
                metric_list,
            )
            continue
        for loss_term in loss:
            loss_term.backward()
        # Step every per_device_batch_size iterations (gradients accumulate in
        # between when grad_req == "add") and on the final batch of the epoch.
        if (i + 1) % per_device_batch_size == 0 or i == n_batches - 1:
            net_trainer.step(args.batch_size)
            det_trainer.step(args.batch_size)
            if aggregate_grad:
                for v in net_params.values():
                    v.zero_grad()
                for v in det_params.values():
                    v.zero_grad()
        # Accuracy is measured only on edges with a positive relation label.
        for G_pred in G_batch:
            if G_pred is None or G_pred.number_of_nodes() == 0:
                continue
            link_ind = np.where(G_pred.edata["rel_class"].asnumpy() > 0)[0]
            if len(link_ind) == 0:
                continue
            train_metric.update(
                [G_pred.edata["rel_class"][link_ind]],
                [G_pred.edata["preds"][link_ind]],
            )
            train_metric_top5.update(
                [G_pred.edata["rel_class"][link_ind]],
                [G_pred.edata["preds"][link_ind]],
            )
        btic, loss_rel_val = batch_print(
            epoch,
            i,
            batch_verbose_freq,
            n_batches,
            btic,
            loss_rel_val,
            metric_list,
        )
        if (i + 1) % batch_verbose_freq == 0:
            net.save_parameters("%s/model-%d.params" % (save_dir, epoch))
            detector_feat.features.save_parameters(
                "%s/detector_feat.features-%d.params" % (save_dir, epoch)
            )
    # Epoch summary: average over every iteration of the epoch. The old code
    # divided the last logging window's residual by i + 1, and referenced an
    # undefined `i` when the loader was empty.
    print_txt = "Epoch[%d], time: %d, loss_rel=%.4f," % (
        epoch,
        int(time.time() - tic),
        epoch_loss_rel / max(n_iters, 1),
    )
    for metric in metric_list:
        metric_name, metric_val = metric.get()
        print_txt += "%s=%.4f " % (metric_name, metric_val)
    logger.info(print_txt)
    net.save_parameters("%s/model-%d.params" % (save_dir, epoch))
    detector_feat.features.save_parameters(
        "%s/detector_feat.features-%d.params" % (save_dir, epoch)
    )
