from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
import os
import sys
import numpy as np
import argparse
import pprint
import pdb
import time

import cv2

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import pickle
from roi_data_layer.roidb import combined_roidb
from roi_data_layer.roibatchLoader import roibatchLoader
from model.utils.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from model.rpn.bbox_transform import clip_boxes
# from model.nms.nms_wrapper import nms
from model.roi_layers import nms
from model.rpn.bbox_transform import bbox_transform_inv
from model.utils.net_utils import save_net, load_net, vis_detections
from model.VK_faster_rcnn.vgg16 import vgg16
from model.VK_faster_rcnn.resnet import resnet

try:
    xrange  # Python 2
except NameError:
    xrange = range  # Python 3


def parse_args():
    """
    Build the command-line interface of the evaluation script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option, named after the option's ``dest``.
    """
    cli = argparse.ArgumentParser(description='Test the JM-SGG model')

    def option(flag, dest, help_text, default, kind):
        # option that consumes one value, converted with `kind`
        cli.add_argument(flag, dest=dest, help=help_text, default=default, type=kind)

    def switch(flag, dest, help_text):
        # boolean flag: False unless given on the command line
        cli.add_argument(flag, dest=dest, help=help_text, action='store_true', default=False)

    # NOTE: registration order is kept as-is because it determines the order of the --help listing.
    option('--dataset', 'dataset', 'evaluation dataset', 'visual_genome', str)
    switch('--process', 'process', 'whether to use processed dataset')
    option('--cfg', 'cfg_file', 'optional config file', 'cfgs/res50.yml', str)
    option('--net', 'net', 'vgg16, res50, res101, res152', 'vgg16', str)
    # everything after --set is collected verbatim into a list
    cli.add_argument('--set', dest='set_cfgs', help='set config keys', default=None,
                     nargs=argparse.REMAINDER)
    option('--load_dir', 'load_dir', 'directory to load models', "models", str)
    option('--model_config', 'model_config', 'the config of model to load', 'jm_sgg', str)
    switch('--cuda', 'cuda', 'whether use CUDA')
    switch('--ls', 'large_scale', 'whether use large imag scale')
    switch('--mGPUs', 'mGPUs', 'whether use multiple GPUs')
    switch('--cag', 'class_agnostic', 'whether perform class_agnostic bbox regression')
    option('--parallel_type', 'parallel_type',
           'which part of model to parallel, 0: all, 1: model before roi pooling', 0, int)
    option('--checksession', 'checksession', 'checksession to load model', 1, int)
    option('--checkepoch', 'checkepoch', 'checkepoch to load network', 30, int)
    option('--checkpoint', 'checkpoint', 'checkpoint to load network', 48409, int)
    switch('--vis', 'vis', 'visualization mode')
    option('--eval_mode', 'eval_mode', 'the mode of evaluation (sgg, sgc or predcls)', 'predcls', str)
    option('--top_k', 'top_k', 'evaluate top k recall', 50, int)
    switch('--gc', 'gc', 'whether to apply graph constraint')
    option('--tau', 'tau', 'the temperature parameter for similarity computation', 0.1, float)
    option('--bias_tau', 'bias_tau', 'the temperature parameter for biased similarity computation',
           0.1, float)
    switch('--newcfg', 'newcfg', 'whether to use new config for anchors')
    switch('--use_mean_field', 'use_mean_field', 'whether to use mean field inference')
    option('--num_iter', 'num_iter', 'the iteration number for factor update', 2, int)
    switch('--mean_recall', 'mean_recall', 'whether to use the mean recall metric')
    switch('--debias', 'debias', 'whether to derive unbiased prediction')

    return cli.parse_args()


class ProjectNet(nn.Module):
    """
    Two-layer perceptron: Linear(in_dim, out_dim) -> ReLU -> Linear(out_dim, out_dim).

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the hidden layer and of the output.
    """

    def __init__(self, in_dim, out_dim):
        super(ProjectNet, self).__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        layers = [
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        ]
        # The attribute name `proj` and the layer positions define the state-dict keys
        # (`proj.0.*`, `proj.2.*`), so they must stay as they are for saved weights to load.
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the projection to ``x`` (features on the last dimension)."""
        return self.proj(x)


def euclid_dist(x, y):
    """
    Mean squared difference between every vector in ``x`` and every row of ``y``.

    Uses the expansion ``mean((a - b)^2) = mean(a^2) + mean(b^2) - 2 * mean(a * b)`` so that the
    full ``(..., K, D)`` difference tensor is never materialised. Because of floating-point
    cancellation the result can be marginally below zero for (near-)identical vectors.

    Args:
        x: tensor of shape ``(..., D)``; the script calls it with ``(N, M, D)``.
        y: tensor of shape ``(K, D)``.

    Returns:
        Tensor of shape ``(..., K)`` where entry ``[..., k]`` is ``mean((x[...] - y[k]) ** 2)``.
    """
    feat_dim = x.size(-1)
    # Broadcasting replaces the former torch.stack([t] * n) replication: no extra full-size
    # copies are allocated, and empty or non-contiguous inputs no longer raise.
    x_sq = (x ** 2).mean(-1).unsqueeze(-1)        # (..., 1)
    y_sq = (y ** 2).mean(-1)                      # (K,)
    xy = torch.matmul(x, y.t()) / feat_dim        # (..., K)
    dist = x_sq + y_sq - 2 * xy

    return dist


def compute_iou(pos, pred_pos, epsilon = 1e-5):
    """
    Intersection-over-union of two axis-aligned boxes given as ``(x1, y1, x2, y2)``.

    The union is ``area(pos) + area(pred_pos) - intersection``. The earlier version divided by
    the area of the box enclosing both inputs, which underestimates IoU whenever the boxes
    only partially overlap.

    Args:
        pos: tensor indexable as ``pos[0..3]`` = x1, y1, x2, y2.
        pred_pos: tensor with the same layout as ``pos``.
        epsilon: small constant added to the denominator to avoid division by zero.

    Returns:
        Tensor holding the IoU; 0 when the boxes do not overlap or only touch.
    """
    area = (pos[2] - pos[0]) * (pos[3] - pos[1])
    pred_area = (pred_pos[2] - pred_pos[0]) * (pred_pos[3] - pred_pos[1])

    # overlap extent along each axis; clamping at zero covers disjoint and touching boxes
    inter_w = (torch.min(pos[2], pred_pos[2]) - torch.max(pos[0], pred_pos[0])).clamp(min=0)
    inter_h = (torch.min(pos[3], pred_pos[3]) - torch.max(pos[1], pred_pos[1])).clamp(min=0)
    intersection = inter_w * inter_h

    union = area + pred_area - intersection
    iou = intersection / (union + epsilon)

    return iou


def img_recall(obj_gt_boxes, obj_num_boxes, relation_triples, relation_num, pred_triples, nrelation,
               mean_recall = False, epsilon = 1e-5, iou_thr = None):
    """
    Count how many ground-truth relation triples of one image are recovered by the predictions.

    A ground-truth triple is recalled when some predicted triple has the same head label, tail
    label and relation, and both its head and tail boxes reach ``iou_thr`` IoU with the
    ground-truth boxes. Ground-truth triples that reference an object id not present among the
    first ``obj_num_boxes`` boxes are ignored.

    Args:
        obj_gt_boxes: 2-D tensor; columns 0-3 are the box, column 4 the label, column 5 the object id.
        obj_num_boxes: scalar tensor, number of valid rows in ``obj_gt_boxes``.
        relation_triples: 2-D tensor whose rows are ``(relation, head_id, tail_id)``.
        relation_num: scalar tensor, number of valid rows in ``relation_triples``.
        pred_triples: 2-D tensor with 11 columns: head box (0-3), head label (4), tail box (5-8),
            tail label (9), relation (10).
        nrelation: number of relation classes (size of the per-class counters).
        mean_recall: if True, keep one counter per relation class instead of a single total.
        epsilon: added to the denominator of the recall ratio.
        iou_thr: IoU needed for a box match. When None it is derived from the script-level
            ``args.eval_mode`` (0.5 for 'sgg', otherwise 0.99), which was the previous behaviour.

    Returns:
        ``(recall, recalled_count, gt_count)``; floats/ints by default, per-class numpy arrays
        when ``mean_recall`` is True.
    """
    if iou_thr is None:
        iou_thr = 0.5 if args.eval_mode == 'sgg' else 0.99

    if mean_recall:
        gt_cnt = np.zeros(nrelation)
        hit_cnt = np.zeros(nrelation)
    else:
        gt_cnt = 0
        hit_cnt = 0

    obj_num_boxes = obj_num_boxes.long().item()
    relation_num = relation_num.long().item()
    obj_ids = obj_gt_boxes[:obj_num_boxes, 5].long().tolist()
    obj_pos_dict = dict(zip(obj_ids, obj_gt_boxes[:obj_num_boxes, :4]))
    obj_label_dict = dict(zip(obj_ids, obj_gt_boxes[:obj_num_boxes, 4]))

    for relation_id in range(relation_num):
        tmp_relation, tmp_head, tmp_tail = relation_triples[relation_id]
        head_id = tmp_head.long().item()
        tail_id = tmp_tail.long().item()

        # skip triples whose head or tail is not among the valid boxes
        # (explicit check instead of a bare `except:` that hid every other error as well)
        if head_id not in obj_pos_dict or tail_id not in obj_pos_dict:
            continue

        head_pos, head_label = obj_pos_dict[head_id], obj_label_dict[head_id]
        tail_pos, tail_label = obj_pos_dict[tail_id], obj_label_dict[tail_id]
        relation_cls = tmp_relation.long().item()

        if mean_recall:
            gt_cnt[relation_cls] += 1
        else:
            gt_cnt += 1

        # cheap vectorised label/relation filter first; IoU is evaluated only for the survivors
        label_match = (pred_triples[:, 4] == head_label) & \
                      (pred_triples[:, 9] == tail_label) & \
                      (pred_triples[:, 10] == tmp_relation)

        for triple_id in torch.nonzero(label_match).view(-1).tolist():
            pred_triple = pred_triples[triple_id]
            if compute_iou(head_pos, pred_triple[:4]) >= iou_thr and \
                    compute_iou(tail_pos, pred_triple[5:9]) >= iou_thr:
                if mean_recall:
                    hit_cnt[relation_cls] += 1
                else:
                    hit_cnt += 1
                break  # one matching prediction is enough for this ground-truth triple

    if mean_recall:
        recall = np.float32(hit_cnt) / (np.float32(gt_cnt) + epsilon)
    else:
        recall = float(hit_cnt) / (float(gt_cnt) + epsilon)

    return recall, hit_cnt, gt_cnt


if __name__ == '__main__':
    # Evaluation entry point: load a trained checkpoint, predict the top-k relation triples for
    # every image of the evaluation split, save them to disk and print recall statistics.

    args = parse_args()

    print('Called with args:')
    print(args)

    if torch.cuda.is_available() and not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    np.random.seed(cfg.RNG_SEED)
    if args.dataset == "pascal_voc":
        args.imdb_name = "voc_2007_trainval"
        args.imdbval_name = "voc_2007_test"
        args.set_cfgs = ['ANCHOR_SCALES', '[8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]']
    elif args.dataset == "pascal_voc_0712":
        args.imdb_name = "voc_2007_trainval+voc_2012_trainval"
        args.imdbval_name = "voc_2007_test"
        args.set_cfgs = ['ANCHOR_SCALES', '[8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]']
    elif args.dataset == "coco":
        args.imdb_name = "coco_2014_train+coco_2014_valminusminival"
        args.imdbval_name = "coco_2014_minival"
        args.set_cfgs = ['ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]']
    elif args.dataset == "imagenet":
        args.imdb_name = "imagenet_train"
        args.imdbval_name = "imagenet_val"
        args.set_cfgs = ['ANCHOR_SCALES', '[8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]']
    elif args.dataset == "vg":
        args.imdb_name = "vg_150-50-50_minitrain"
        args.imdbval_name = "vg_150-50-50_minival"
        args.set_cfgs = ['ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]']
    elif args.dataset == "visual_genome":
        if args.process:
            args.imdb_name = "visual_genome_train_process"
            args.imdbval_name = "visual_genome_test_process"
        else:
            args.imdb_name = "visual_genome_train"
            args.imdbval_name = "visual_genome_test"

        if args.newcfg:
            args.set_cfgs = ['ANCHOR_SCALES', '[2.22152954,4.12315647,7.21692515,12.60263013,22.7102731]',
                             'ANCHOR_RATIOS', '[0.23232838,0.63365731,1.28478321,3.15089189]']
        else:
            args.set_cfgs = ['ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]']
    else:
        # fail early with a clear message instead of a later AttributeError on args.imdbval_name
        raise ValueError('unknown dataset: {}'.format(args.dataset))

    # NOTE: the config path is always derived from --net / --ls, so a value given via --cfg is overridden.
    args.cfg_file = "cfgs/{}_ls.yml".format(args.net) if args.large_scale else "cfgs/{}.yml".format(args.net)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)

    cfg.TRAIN.USE_FLIPPED = False
    imdb, roidb, ratio_list, ratio_index = combined_roidb(args.imdbval_name, False)
    imdb.competition_mode(on=True)

    print('{:d} roidb entries'.format(len(roidb)))

    # input_dir = args.load_dir + "/" + args.net + "/" + args.dataset
    input_dir = args.load_dir + "/" + args.net + "/" + args.model_config
    if not os.path.exists(input_dir):
        raise Exception('There is no input directory for loading network from ' + input_dir)
    load_name = os.path.join(input_dir,
                             'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))

    # initialize the network and pick the embedding size that goes with the backbone
    resnet_depths = {'res50': 50, 'res101': 101, 'res152': 152}
    if args.net == 'vgg16':
        fasterRCNN = vgg16(imdb.obj_classes, pretrained=False, class_agnostic=args.class_agnostic,
                           mode=args.eval_mode)
        emb_dim = 4096
    elif args.net in resnet_depths:
        fasterRCNN = resnet(imdb.obj_classes, resnet_depths[args.net], pretrained=False,
                            class_agnostic=args.class_agnostic, mode=args.eval_mode)
        emb_dim = 2048
    else:
        # previously this dropped into pdb and then failed with a NameError
        raise ValueError("network is not defined")

    fasterRCNN.create_architecture()

    # the projection networks
    proj_net = ProjectNet(emb_dim, emb_dim)
    context_proj_net = ProjectNet(emb_dim, emb_dim)

    print("load checkpoint %s" % (load_name))
    # Load onto the CPU first; everything is moved to the GPU below when --cuda is set. Without
    # this, a GPU-saved checkpoint leaves the embeddings on the GPU even when running on the CPU.
    # NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
    checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
    fasterRCNN.load_state_dict(checkpoint['model'])
    proj_net.load_state_dict(checkpoint['proj_net'])
    context_proj_net.load_state_dict(checkpoint['context_proj_net'])
    entity_emb = checkpoint['entity_emb']
    relation_emb = checkpoint['relation_emb']
    entity_relation_cnt = checkpoint['entity_relation_cnt']
    entity_relation_prob = checkpoint['entity_relation_prob']
    if 'pooling_mode' in checkpoint.keys():
        cfg.POOLING_MODE = checkpoint['pooling_mode']

    if args.cuda:
        fasterRCNN.cuda()
        proj_net.cuda()
        context_proj_net.cuda()
        entity_emb = entity_emb.cuda()
        relation_emb = relation_emb.cuda()
        entity_relation_cnt = entity_relation_cnt.cuda()
        entity_relation_prob = entity_relation_prob.cuda()

    # This script only evaluates: switch autograd off for everything that follows so that no
    # computation graph is kept alive during inference.
    torch.set_grad_enabled(False)

    entity_emb = Variable(entity_emb)
    entity_emb_proj = proj_net(entity_emb)
    relation_emb = Variable(relation_emb)
    entity_relation_cnt = Variable(entity_relation_cnt)
    entity_relation_prob = Variable(entity_relation_prob)

    print('load model successfully!')

    # initialize the tensor holders here; they are resized and refilled for every image
    im_data = torch.FloatTensor(1)
    im_info = torch.FloatTensor(1)
    obj_num_boxes = torch.FloatTensor(1)
    obj_gt_boxes = torch.FloatTensor(1)
    region_num_boxes = torch.FloatTensor(1)
    region_gt_boxes = torch.FloatTensor(1)
    relation_triples = torch.FloatTensor(1)
    relation_num = torch.FloatTensor(1)

    # ship to cuda
    if args.cuda:
        im_data = im_data.cuda()
        im_info = im_info.cuda()
        obj_num_boxes = obj_num_boxes.cuda()
        obj_gt_boxes = obj_gt_boxes.cuda()
        region_num_boxes = region_num_boxes.cuda()
        region_gt_boxes = region_gt_boxes.cuda()
        relation_triples = relation_triples.cuda()
        relation_num = relation_num.cuda()

    # make variable
    im_data = Variable(im_data)
    im_info = Variable(im_info)
    obj_num_boxes = Variable(obj_num_boxes)
    obj_gt_boxes = Variable(obj_gt_boxes)
    region_num_boxes = Variable(region_num_boxes)
    region_gt_boxes = Variable(region_gt_boxes)
    relation_triples = Variable(relation_triples)
    relation_num = Variable(relation_num)

    if args.cuda:
        cfg.CUDA = True

    start_time = time.time()
    save_name = 'faster_rcnn_' + str(args.checkepoch)
    num_images = len(imdb.image_index)
    all_triples = []

    output_dir = get_output_dir(imdb, save_name)
    dataset = roibatchLoader(roidb, ratio_list, ratio_index, 1, \
                             imdb.num_obj_classes, training=False, normalize=False)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                             shuffle=False, num_workers=0,
                                             pin_memory=True)

    det_file = os.path.join(output_dir, 'predicted_triples.pkl')

    fasterRCNN.eval()
    data_iter = iter(dataloader)

    # recall accumulators: per relation class for mean recall, plain totals otherwise
    if args.mean_recall:
        gt_triple_cnt = np.zeros(imdb.num_relation_classes)
        recall_triple_cnt = np.zeros(imdb.num_relation_classes)
        recall_cnt = np.zeros(imdb.num_relation_classes)
        skip_cnt = 0
    else:
        gt_triple_cnt = 0
        recall_triple_cnt = 0
        recall_cnt = 0
        skip_cnt = 0

    for i in range(num_images):

        data = next(data_iter)
        with torch.no_grad():
            im_data.resize_(data[0].size()).copy_(data[0])
            im_info.resize_(data[1].size()).copy_(data[1])
            obj_gt_boxes.resize_(data[2].size()).copy_(data[2])
            obj_num_boxes.resize_(data[3].size()).copy_(data[3])
            region_gt_boxes.resize_(data[4].size()).copy_(data[4])
            region_num_boxes.resize_(data[5].size()).copy_(data[5])
            relation_triples.resize_(data[6].size()).copy_(data[6])
            relation_num.resize_(data[7].size()).copy_(data[7])

        rois, cls_prob, bbox_pred, \
        rpn_loss_cls, rpn_loss_box, \
        RCNN_loss_cls, RCNN_loss_bbox, rois_label, \
        pred_feat, _pred_label, _pred_conf, pred_roi, _context_feat = fasterRCNN(im_data, im_info, obj_gt_boxes,
                                                                                 obj_num_boxes, use_context = True)
        context_feat = context_proj_net(_context_feat)

        if args.use_mean_field:
            # initialize the factors for objects
            all_energy = ((pred_feat.unsqueeze(1) - entity_emb.unsqueeze(0)) ** 2).mean(-1)
            all_exp_energy = torch.exp(-all_energy / args.tau)   # N_b x N_o

            if args.eval_mode == 'predcls':
                # labels are given: put each instance's confidence on its given class
                pred_label = _pred_label
                pred_conf = _pred_conf
                inst_prob = torch.zeros((pred_feat.shape[0], entity_emb.shape[0]), device=pred_feat.device)   # N_b x N_o
                for inst_id in range(inst_prob.shape[0]):
                    inst_prob[inst_id, pred_label.long()[inst_id]] = pred_conf.float()[inst_id]
            else:
                # predict the density of all instances
                inst_prob = all_exp_energy / all_exp_energy.sum(-1).unsqueeze(-1)   # N_b x N_o

            # project instance embeddings
            pred_feat_proj = proj_net(pred_feat)

            # initialize the factors for relations
            emb_diff = pred_feat_proj.unsqueeze(1) - pred_feat_proj.unsqueeze(0)
            all_inst_relation_energy = euclid_dist(emb_diff, relation_emb)  # N_b x N_b x (N_r + 1)
            context_feat = context_feat.view(pred_feat.shape[0], pred_feat.shape[0], -1)
            all_context_relation_energy = euclid_dist(context_feat, relation_emb)  # N_b x N_b x (N_r + 1)
            allr_exp_energy = torch.exp(-(all_inst_relation_energy + all_context_relation_energy) / args.tau)
            relation_prob = allr_exp_energy / allr_exp_energy.sum(-1).unsqueeze(-1)   # N_b x N_b x (N_r + 1)

            # compute all energies for ternary labels
            cls_emb_diff = entity_emb_proj.unsqueeze(1) - entity_emb_proj.unsqueeze(0)
            all_cls_relation_energy = euclid_dist(cls_emb_diff, relation_emb)  # N_o x N_o x (N_r + 1)

            # factor update; done on the CPU because of the 5-D N_b x N_b x N_o x N_o x (N_r + 1) tensor
            try:
                inst_prob = inst_prob.cpu()
                relation_prob = relation_prob.cpu()
                all_cls_relation_energy = all_cls_relation_energy.cpu()
                all_energy = all_energy.cpu()
                all_inst_relation_energy = all_inst_relation_energy.cpu()
                all_context_relation_energy = all_context_relation_energy.cpu()

                for iter_id in range(args.num_iter):
                    all_weighted_energy = all_cls_relation_energy.unsqueeze(0).unsqueeze(0) * \
                                          relation_prob.unsqueeze(2).unsqueeze(2) * inst_prob.unsqueeze(0).unsqueeze(2) * \
                                          inst_prob.unsqueeze(1).unsqueeze(3)  # N_b x N_b x N_o x N_o x (N_r + 1)

                    # update object factors
                    if not args.eval_mode == 'predcls':
                        allo_energy = all_energy + all_weighted_energy.sum(1).sum(2).sum(-1) + \
                                      all_weighted_energy.sum(0).sum(1).sum(-1)   # N_b x N_o
                        allo_exp_energy = torch.exp(-allo_energy / args.tau)   # N_b x N_o
                        inst_prob = allo_exp_energy / allo_exp_energy.sum(-1).unsqueeze(-1)  # N_b x N_o

                    # update relation factors
                    allr_energy_ = all_inst_relation_energy + all_context_relation_energy + \
                                  all_weighted_energy.sum(2).sum(2)   # N_b x N_b x (N_r + 1)
                    allr_exp_energy_ = torch.exp(-allr_energy_ / args.tau)   # N_b x N_b x (N_r + 1)
                    relation_prob = allr_exp_energy_ / allr_exp_energy_.sum(-1).unsqueeze(-1)   # N_b x N_b x (N_r + 1)

                inst_prob = inst_prob.to(pred_feat.device)
                relation_prob = relation_prob.to(pred_feat.device)
            except (RuntimeError, MemoryError):
                # best effort: fall back to the factors computed so far when the update does not fit in memory
                inst_prob = inst_prob.to(pred_feat.device)
                relation_prob = relation_prob.to(pred_feat.device)
                print('Skip update for large memory need')

            # get the confidence for each candidate triple
            relation_prob = relation_prob[:, :, :-1]  # N_b x N_b x N_r
            if not args.eval_mode == 'predcls':
                pred_conf, pred_label = torch.max(inst_prob, dim=1)

            pair_conf = pred_conf.unsqueeze(1) * pred_conf.unsqueeze(0)   # N_b x N_b
            conf = pair_conf.unsqueeze(-1) * relation_prob   # N_b x N_b x N_r
        else:
            if args.eval_mode == 'predcls':
                pred_label = _pred_label
                pred_conf = _pred_conf
                pred_cls_feat = entity_emb[pred_label.long(), :]
            else:
                # predict the density of all instances
                all_energy = ((pred_feat.unsqueeze(1) - entity_emb.unsqueeze(0)) ** 2).mean(-1)
                all_exp_energy = torch.exp(-all_energy / args.tau)
                inst_prob = all_exp_energy / all_exp_energy.sum(-1).unsqueeze(-1)   # N_b x N_o
                pred_conf, pred_label = torch.max(inst_prob, dim=1)
                pred_cls_feat = entity_emb[pred_label.long(), :]

            # project instance and entity embeddings
            pred_feat_proj = proj_net(pred_feat)
            pred_cls_feat_proj = proj_net(pred_cls_feat)

            try:
                # compute the pair-wise similarity with relationship embeddings
                emb_diff = pred_feat_proj.unsqueeze(1) - pred_feat_proj.unsqueeze(0)
                cls_emb_diff = pred_cls_feat_proj.unsqueeze(1) - pred_cls_feat_proj.unsqueeze(0)
                dist = euclid_dist(emb_diff, relation_emb)
                cls_dist = euclid_dist(cls_emb_diff, relation_emb)
                context_feat = context_feat.view(pred_feat.shape[0], pred_feat.shape[0], -1)
                context_dist = euclid_dist(context_feat, relation_emb)
                sim = torch.exp(-(dist + cls_dist + context_dist) / args.tau)   # N_b x N_b x (N_r + 1)
                sim = sim / sim.sum(-1).unsqueeze(-1)
                sim = sim[:, :, :-1]   # N_b x N_b x N_r
            except (RuntimeError, MemoryError):
                skip_cnt += 1
                print('Skip the sample with large memory need')
                continue

            # get the confidence for each candidate triple
            pair_conf = pred_conf.unsqueeze(1) * pred_conf.unsqueeze(0)
            conf = pair_conf.unsqueeze(-1) * sim   # N_b x N_b x N_r

        # predict based on dataset bias
        if args.debias:
            # project entity embeddings
            pred_cls_feat = entity_emb[pred_label.long(), :]
            pred_cls_feat_proj = proj_net(pred_cls_feat)

            try:
                # compute the biased pair-wise similarity with relationship embeddings
                cls_emb_diff = pred_cls_feat_proj.unsqueeze(1) - pred_cls_feat_proj.unsqueeze(0)
                cls_dist = euclid_dist(cls_emb_diff, relation_emb)
                bias_sim = torch.exp(-cls_dist / args.bias_tau)  # N_b x N_b x (N_r + 1)
                bias_sim = bias_sim / bias_sim.sum(-1).unsqueeze(-1)
                bias_sim = bias_sim[:, :, :-1]   # N_b x N_b x N_r
            except (RuntimeError, MemoryError):
                skip_cnt += 1
                print('Skip the sample with large memory need')
                continue

            # get the biased confidence for each candidate triple
            bias_conf = pair_conf.unsqueeze(-1) * bias_sim   # N_b x N_b x N_r

            # compute unbiased confidence
            conf = conf - bias_conf   # N_b x N_b x N_r

        # get the top k triples
        num_box = conf.shape[0]
        num_relation = conf.shape[-1]
        list_len = min(args.top_k, num_box * num_box * num_relation)
        img_triples = torch.zeros((args.top_k, 11)).to(im_data.device)

        # For a flat position in conf (N_b x N_b x N_r): the first axis is the tail box, the
        # second the head box, the third the relation. expand() avoids Python-side replication
        # and also works when there are no boxes.
        grid = (num_box, num_box, num_relation)
        box_ids = torch.arange(num_box).to(im_data.device)
        rel_ids = torch.arange(num_relation).to(im_data.device)
        tail_indices = box_ids.view(num_box, 1, 1).expand(*grid).reshape(-1)
        head_indices = box_ids.view(1, num_box, 1).expand(*grid).reshape(-1)
        relation_indices = rel_ids.view(1, 1, num_relation).expand(*grid).reshape(-1)

        if args.gc:
            # apply the graph constraint: one predicate for each head-tail pair.
            # Keep the best-scoring relation of every pair, then rank the pairs by that score;
            # this selects the same entries as walking the full sorted list and dropping
            # repeated pairs, without a Python loop over N_b * N_b * N_r candidates.
            best_conf, best_rel = conf.max(dim=-1)   # N_b x N_b
            _, pair_order = torch.sort(best_conf.view(-1), 0, True)
            order = pair_order * num_relation + best_rel.view(-1)[pair_order]
            list_len = min(list_len, order.shape[0])
        else:
            _, order = torch.sort(conf.view(-1), 0, True)

        keep = order[:list_len].long()

        keep_head_indices = head_indices[keep].long()
        head_positions = pred_roi[keep_head_indices, 1:5]
        head_labels = pred_label[keep_head_indices]
        img_triples[:list_len, :4] = head_positions
        img_triples[:list_len, 4] = head_labels

        keep_tail_indices = tail_indices[keep].long()
        tail_positions = pred_roi[keep_tail_indices, 1:5]
        tail_labels = pred_label[keep_tail_indices]
        img_triples[:list_len, 5:9] = tail_positions
        img_triples[:list_len, 9] = tail_labels

        keep_relation_indices = relation_indices[keep]
        img_triples[:list_len, 10] = keep_relation_indices
        all_triples.append(img_triples)

        # compute per image recall
        tmp_recall_cnt, tmp_recall_triple_cnt, tmp_gt_triple_cnt = img_recall(obj_gt_boxes[0, ...], obj_num_boxes[0],
                                                                              relation_triples[0, ...], relation_num[0],
                                                                              img_triples, imdb.num_relation_classes,
                                                                              mean_recall=args.mean_recall)
        # images without any usable ground-truth triple do not contribute to the totals
        if args.mean_recall:
            if np.sum(tmp_gt_triple_cnt) != 0:
                recall_cnt += tmp_recall_cnt
                recall_triple_cnt += tmp_recall_triple_cnt
                gt_triple_cnt += tmp_gt_triple_cnt
        else:
            if tmp_gt_triple_cnt != 0:
                recall_cnt += tmp_recall_cnt
                recall_triple_cnt += tmp_recall_triple_cnt
                gt_triple_cnt += tmp_gt_triple_cnt

        curr_time = time.time()
        pass_time = curr_time - start_time

        sys.stdout.write('im_detect: {:d}/{:d} {:.3f}s \r' \
                         .format(i + 1, num_images, pass_time))
        sys.stdout.flush()

    # torch.stack needs at least one tensor; fall back to an empty result if every image was skipped
    if all_triples:
        all_triples = torch.stack(all_triples, dim=0)
    else:
        all_triples = torch.zeros((0, args.top_k, 11))
    torch.save({'all_triples': all_triples.cpu().detach()}, det_file)
    print('Predicted triples save to: ', det_file)
    print('')

    if args.mean_recall:
        per_class_recall = recall_triple_cnt / (gt_triple_cnt + 1e-5)
        avg_set_recall = np.mean(per_class_recall)
        print('Set recall: ', per_class_recall)
        print('Average set recall: ', avg_set_recall)
    else:
        # guard the ratios against an empty evaluation (all images skipped / no ground truth)
        evaluated_images = num_images - skip_cnt
        avg_img_recall = float(recall_cnt) / float(evaluated_images) if evaluated_images > 0 else 0.0
        avg_set_recall = float(recall_triple_cnt) / float(gt_triple_cnt) if gt_triple_cnt > 0 else 0.0
        print('Average per-image recall: ', avg_img_recall)
        print('Average set recall: ', avg_set_recall)