from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
import os
import sys
import numpy as np
import argparse
import pprint
import pdb
import time
from operator import itemgetter

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim

import torchvision.transforms as transforms
from torch.utils.data.sampler import Sampler
from torch.distributions.categorical import Categorical

from roi_data_layer.roidb import combined_roidb
from roi_data_layer.roibatchLoader import roibatchLoader
from model.utils.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from model.utils.net_utils import weights_normal_init, save_net, load_net, \
    adjust_learning_rate, save_checkpoint, clip_gradient

from model.VK_faster_rcnn.vgg16 import vgg16
from model.VK_faster_rcnn.resnet import resnet


def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Train the JM-SGG model')
    parser.add_argument('--dataset', dest='dataset',
                        help='training dataset',
                        default='visual_genome', type=str)
    parser.add_argument('--process', dest='process',
                        help='whether to use processed dataset',
                        action='store_true', default=False)
    parser.add_argument('--net', dest='net',
                        help='vgg16, res50, res101',
                        default='vgg16', type=str)
    parser.add_argument('--start_epoch', dest='start_epoch',
                        help='starting epoch',
                        default=1, type=int)
    parser.add_argument('--epochs', dest='max_epochs',
                        help='number of epochs to train',
                        default=30, type=int)
    parser.add_argument('--pretrain_epochs', dest='pretrain_epochs',
                        help='the number of epochs to pre-train the model',
                        default=20, type=int)
    parser.add_argument('--disp_interval', dest='disp_interval',
                        help='number of iterations to display',
                        default=100, type=int)
    parser.add_argument('--checkpoint_interval', dest='checkpoint_interval',
                        help='number of iterations to display',
                        default=10000, type=int)

    parser.add_argument('--save_dir', dest='save_dir',
                        help='directory to save models', default="models",
                        type=str)
    parser.add_argument('--model_config', dest='model_config',
                        help='the config of model to save',
                        default='jm_sgg', type=str)
    parser.add_argument('--nw', dest='num_workers',
                        help='number of worker to load data',
                        default=0, type=int)
    parser.add_argument('--cuda', dest='cuda',
                        help='whether use CUDA',
                        action='store_true')
    parser.add_argument('--ls', dest='large_scale',
                        help='whether use large imag scale',
                        action='store_true')
    parser.add_argument('--mGPUs', dest='mGPUs',
                        help='whether use multiple GPUs',
                        action='store_true')
    parser.add_argument('--bs', dest='batch_size',
                        help='batch_size',
                        default=4, type=int)
    parser.add_argument('--cag', dest='class_agnostic',
                        help='whether perform class_agnostic bbox regression',
                        action='store_true')

    # config optimization
    parser.add_argument('--o', dest='optimizer',
                        help='training optimizer',
                        default="sgd", type=str)
    parser.add_argument('--lr', dest='lr',
                        help='starting learning rate',
                        default=0.001, type=float)
    parser.add_argument('--lr_decay_step', dest='lr_decay_step',
                        help='step to do learning rate decay, unit is epoch',
                        default=10, type=int)
    parser.add_argument('--lr_decay_gamma', dest='lr_decay_gamma',
                        help='learning rate decay ratio',
                        default=0.1, type=float)

    # set training session
    parser.add_argument('--s', dest='session',
                        help='training session',
                        default=1, type=int)

    # resume trained model
    parser.add_argument('--r', dest='resume',
                        help='resume checkpoint or not',
                        default=False, type=bool)
    parser.add_argument('--checksession', dest='checksession',
                        help='checksession to load model',
                        default=1, type=int)
    parser.add_argument('--checkepoch', dest='checkepoch',
                        help='checkepoch to load model',
                        default=1, type=int)
    parser.add_argument('--checkpoint', dest='checkpoint',
                        help='checkpoint to load model',
                        default=0, type=int)

    # log and diaplay
    parser.add_argument('--use_tfb', dest='use_tfboard',
                        help='whether use tensorboard',
                        action='store_true')

    # visual knowledge training
    parser.add_argument('--tau', dest='tau',
                        help='the temperature parameter',
                        default=0.1, type=float)
    parser.add_argument('--beta', dest='beta',
                        help='the exponential moving average ratio',
                        default=0.8, type=float)
    parser.add_argument('--likelihood_weight', dest='likelihood_weight',
                        help='the weight of maximum likelihood learning',
                        default=1.0, type=float)
    parser.add_argument('--neg_size', dest='neg_size',
                        help='the negative sampling size',
                        default=3, type=int)
    parser.add_argument('--bias_box', dest='bias_box',
                        help='whether to use biased gt boxes during training',
                        action='store_true', default=False)
    parser.add_argument('--newcfg', dest='newcfg',
                        help='whether to use new config for anchors',
                        action='store_true', default=False)

    args = parser.parse_args()
    return args


class sampler(Sampler):
    """Yield dataset indices in shuffled, contiguous batches.

    The first ``int(train_size / batch_size) * batch_size`` indices are cut into
    consecutive chunks of ``batch_size``. On every pass the order of the chunks is
    permuted, while the indices inside a chunk stay in ascending order. Any
    remaining tail indices are appended, unshuffled, after the permuted chunks.
    """

    def __init__(self, train_size, batch_size):
        self.num_data = train_size
        self.batch_size = batch_size
        self.num_per_batch = int(train_size / batch_size)
        # Row vector [0, 1, ..., batch_size - 1]; added to every chunk start.
        self.range = torch.arange(0, batch_size).view(1, batch_size).long()
        remainder = train_size % batch_size
        self.leftover_flag = bool(remainder)
        if self.leftover_flag:
            first_tail_index = self.num_per_batch * batch_size
            self.leftover = torch.arange(first_tail_index, train_size).long()

    def __iter__(self):
        # Start index of each full chunk, in random order, as a column vector.
        chunk_starts = torch.randperm(self.num_per_batch).view(-1, 1) * self.batch_size
        self.rand_num = chunk_starts.expand(self.num_per_batch, self.batch_size) + self.range
        flat_order = self.rand_num.view(-1)
        if self.leftover_flag:
            flat_order = torch.cat((flat_order, self.leftover), 0)
        self.rand_num_view = flat_order
        return iter(self.rand_num_view)

    def __len__(self):
        """Number of indices produced per pass (equal to ``train_size``)."""
        return self.num_data

class ProjectNet(torch.nn.Module):
    """Two-layer MLP projection: Linear(in_dim, out_dim) -> ReLU -> Linear(out_dim, out_dim).

    The layers live in ``self.proj`` (an ``nn.Sequential``). Its state_dict keys are
    ``proj.0.*`` and ``proj.2.*``, which saved checkpoints rely on.
    """

    def __init__(self, in_dim, out_dim):
        super(ProjectNet, self).__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        layers = [
            torch.nn.Linear(in_dim, out_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(out_dim, out_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the projection MLP to ``x`` (Linear layers act on its last dimension)."""
        return self.proj(x)

# Update the entity/relation embeddings with instance embeddings
def update_emb(feats, labels, emb, beta=0.8, epsilon=1e-6):
    """Exponential-moving-average update of class embeddings from instance features.

    For every class ``c`` that occurs in ``labels``, row ``c`` becomes
    ``beta * emb[c] + (1 - beta) * mean(feats[labels == c])``. Rows of classes that
    do not occur in ``labels`` are left unchanged.

    Args:
        feats: instance features, a 2-D tensor with one row per entry of ``labels``.
        labels: class index of each instance (converted with ``.long()``).
        emb: class embedding table, one row per class, same width as ``feats``.
        beta: fraction of the previous embedding kept by the EMA.
        epsilon: added to the per-class counts before dividing.

    Returns:
        A new detached leaf tensor with ``requires_grad=True`` holding the updated
        table. It is a different object from ``emb``, so anything that holds a
        reference to the input ``emb`` (e.g. an optimizer's param group) does not
        see the returned tensor.
    """
    num_classes = emb.shape[0]
    if labels.shape[0] == 0:
        batch_mean = torch.zeros_like(emb)
        class_count = torch.zeros((num_classes, 1), dtype=emb.dtype, device=emb.device)
    else:
        onehot_label = torch.zeros((labels.shape[0], num_classes), dtype=feats.dtype, device=emb.device)
        onehot_label.scatter_(1, labels.long().unsqueeze(-1).to(emb.device), 1)
        class_count = onehot_label.sum(0).unsqueeze(-1)  # instances per class, (C, 1)
        # One-hot matmul gives the per-class feature sums directly; the former
        # broadcast built an (N, C, D) tensor just to sum it over N.
        batch_sum = torch.mm(onehot_label.t(), feats)
        batch_mean = batch_sum / (class_count + epsilon)

    # A class is "present" when it has at least one instance. Testing the mean's sum
    # for non-zero would wrongly skip present classes whose mean features sum to 0.
    batch_mask = (class_count > 0).float()
    previous = emb.detach()
    updated = previous * (1 - batch_mask) + (previous * beta + batch_mean * (1 - beta)) * batch_mask

    updated = updated.detach()
    updated.requires_grad = True
    return updated

# the likelihood loss defined upon whole scene graphs
def likelihood_loss(gt_feat, gt_label, obj_img_num, head_feat_proj, tail_feat_proj, head_cls_feat_proj,
                    tail_cls_feat_proj, context_feat, relation_label, relation_img_num, entity_emb, entity_emb_proj,
                    relation_emb, gt_id, relations, tau=0.1, neg_size=1, epsilon=1e-6):
    # compute the energy of all objects
    obj_energy = ((gt_feat - entity_emb[gt_label.long()]) ** 2).mean(-1)

    # compute the energy of all relations
    context_energy = ((context_feat - relation_emb[relation_label.long()]) ** 2).mean(-1)
    obj_relation_energy = ((head_feat_proj + relation_emb[relation_label.long()] - tail_feat_proj) ** 2).mean(-1)
    cls_relation_energy = ((head_cls_feat_proj + relation_emb[relation_label.long()] - tail_cls_feat_proj) ** 2).mean(-1)
    relation_energy = context_energy + obj_relation_energy + cls_relation_energy

    # compute the object energies within each scene graph
    img_obj_energy_list = list()
    obj_num_cnt = 0
    for img_id in range(obj_img_num.shape[0]):
        tmp_obj_num = obj_img_num[img_id].long().item()
        tmp_img_obj_energy = obj_energy[obj_num_cnt:(obj_num_cnt + tmp_obj_num)].sum()
        img_obj_energy_list.append(tmp_img_obj_energy)
        obj_num_cnt += tmp_obj_num

    img_obj_energy = torch.stack(img_obj_energy_list, dim = 0)

    # compute the relation energies within each scene graph
    img_relation_energy_list = list()
    relation_num_cnt = 0
    for img_id in range(len(relation_img_num)):
        tmp_relation_num = relation_img_num[img_id]
        tmp_img_relation_energy = relation_energy[relation_num_cnt:(relation_num_cnt + tmp_relation_num)].sum()
        img_relation_energy_list.append(tmp_img_relation_energy)
        relation_num_cnt += tmp_relation_num

    img_relation_energy = torch.stack(img_relation_energy_list, dim = 0)

    # compute the likelihood of positive scene graph
    exp_energy = torch.exp(-(img_obj_energy + img_relation_energy) / tau)

    # define the variational distribution
    allo_energy = ((gt_feat.unsqueeze(1) - entity_emb.unsqueeze(0)) ** 2).mean(-1)
    allo_exp_energy = torch.exp(-allo_energy / tau)
    o_prob = allo_exp_energy / (allo_exp_energy.sum(-1).unsqueeze(-1) + epsilon)
    o_mask = (allo_exp_energy.sum(-1).unsqueeze(-1) == 0).float()
    o_prob = o_prob * (1 - o_mask) + o_mask / o_prob.shape[1]
    o_distribution = Categorical(o_prob)

    all_context_energy = ((context_feat.unsqueeze(1) - relation_emb.unsqueeze(0)) ** 2).mean(-1)
    all_trans_energy = (
            (head_feat_proj.unsqueeze(1) + relation_emb.unsqueeze(0) - tail_feat_proj.unsqueeze(1)) ** 2).mean(-1)
    allr_exp_energy = torch.exp(-(all_context_energy + all_trans_energy) / tau)
    r_prob = allr_exp_energy / (allr_exp_energy.sum(-1).unsqueeze(-1) + epsilon)
    r_mask = (allr_exp_energy.sum(-1).unsqueeze(-1) == 0).float()
    r_prob = r_prob * (1 - r_mask) + r_mask / r_prob.shape[1]
    r_distribution = Categorical(r_prob)

    # negative sampling from variational distribution
    neg_exp_energy_cnt = torch.zeros_like(exp_energy).to(exp_energy.device)
    for neg_idx in range(neg_size):
        # sample objects and relations from variational distribution
        o_index = o_distribution.sample().long()
        r_index = r_distribution.sample().long()

        # get the head and tail index
        label_dict = dict(zip(gt_id.long().tolist(), o_index.long().tolist()))
        h_index_list = list()
        t_index_list = list()

        for relation_id in range(relations.shape[0]):
            tmp_relation_index = relations[relation_id][0].long().item()
            tmp_subject_id = relations[relation_id][1].long().item()
            tmp_object_id = relations[relation_id][2].long().item()

            if tmp_subject_id in label_dict and tmp_object_id in label_dict:
                h_index_list.append(label_dict[tmp_subject_id])
                t_index_list.append(label_dict[tmp_object_id])

        h_index = torch.tensor(h_index_list).long().to(r_index.device)
        t_index = torch.tensor(t_index_list).long().to(r_index.device)

        # compute the object and relation energy for negative scene graph
        neg_obj_energy = ((gt_feat - entity_emb[o_index]) ** 2).mean(-1)

        neg_context_energy = ((context_feat - relation_emb[r_index]) ** 2).mean(-1)
        neg_obj_relation_energy = ((head_feat_proj + relation_emb[r_index] - tail_feat_proj) ** 2).mean(-1)
        neg_cls_relation_energy = (
                (entity_emb_proj[h_index] + relation_emb[r_index] - entity_emb_proj[t_index]) ** 2).mean(-1)
        neg_relation_energy = neg_context_energy + neg_obj_relation_energy + neg_cls_relation_energy

        # compute the object energies within each negative scene graph
        neg_img_obj_energy_list = list()
        obj_num_cnt = 0
        for img_id in range(obj_img_num.shape[0]):
            tmp_obj_num = obj_img_num[img_id].long().item()
            tmp_neg_img_obj_energy = neg_obj_energy[obj_num_cnt:(obj_num_cnt + tmp_obj_num)].sum()
            neg_img_obj_energy_list.append(tmp_neg_img_obj_energy)
            obj_num_cnt += tmp_obj_num

        neg_img_obj_energy = torch.stack(neg_img_obj_energy_list, dim=0)

        # compute the relation energies within each negative scene graph
        neg_img_relation_energy_list = list()
        relation_num_cnt = 0
        for img_id in range(len(relation_img_num)):
            tmp_relation_num = relation_img_num[img_id]
            tmp_neg_img_relation_energy = neg_relation_energy[relation_num_cnt:(relation_num_cnt + tmp_relation_num)].sum()
            neg_img_relation_energy_list.append(tmp_neg_img_relation_energy)
            relation_num_cnt += tmp_relation_num

        neg_img_relation_energy = torch.stack(neg_img_relation_energy_list, dim=0)

        # compute the likelihood of negative scene graph
        neg_exp_energy = torch.exp(-(neg_img_obj_energy + neg_img_relation_energy) / tau)
        neg_exp_energy_cnt += neg_exp_energy

    neg_exp_energy_mean = neg_exp_energy_cnt / neg_size

    # define the likelihood loss
    pos_ratio = exp_energy / (exp_energy + neg_exp_energy_mean + epsilon)
    if (pos_ratio < epsilon).float().sum() == 0:
        likelihood_loss = -torch.log(pos_ratio).sum() / pos_ratio.shape[0]
    else:
        likelihood_loss = torch.zeros_like(pos_ratio[0])

    return likelihood_loss

if __name__ == '__main__':

    args = parse_args()

    print('Called with args:')
    print(args)

    if args.dataset == "pascal_voc":
        args.imdb_name = "voc_2007_trainval"
        args.imdbval_name = "voc_2007_test"
        args.set_cfgs = ['ANCHOR_SCALES', '[8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'MAX_NUM_GT_BOXES', '20']
    elif args.dataset == "pascal_voc_0712":
        args.imdb_name = "voc_2007_trainval+voc_2012_trainval"
        args.imdbval_name = "voc_2007_test"
        args.set_cfgs = ['ANCHOR_SCALES', '[8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'MAX_NUM_GT_BOXES', '20']
    elif args.dataset == "coco":
        args.imdb_name = "coco_2014_train+coco_2014_valminusminival"
        args.imdbval_name = "coco_2014_minival"
        args.set_cfgs = ['ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'MAX_NUM_GT_BOXES', '50']
    elif args.dataset == "imagenet":
        args.imdb_name = "imagenet_train"
        args.imdbval_name = "imagenet_val"
        args.set_cfgs = ['ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'MAX_NUM_GT_BOXES', '30']
    elif args.dataset == "vg":
        # train sizes: train, smalltrain, minitrain
        # train scale: ['150-50-20', '150-50-50', '500-150-80', '750-250-150', '1750-700-450', '1600-400-20']
        args.imdb_name = "vg_150-50-50_minitrain"
        args.imdbval_name = "vg_150-50-50_minival"
        args.set_cfgs = ['ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'MAX_NUM_GT_BOXES', '50']
    elif args.dataset == "visual_genome":
        if args.process:
            args.imdb_name = "visual_genome_train_process"
            args.imdbval_name = "visual_genome_test_process"
        else:
            args.imdb_name = "visual_genome_train"
            args.imdbval_name = "visual_genome_test"

        if args.newcfg:
            args.set_cfgs = ['ANCHOR_SCALES', '[2.22152954,4.12315647,7.21692515,12.60263013,22.7102731]',
                             'ANCHOR_RATIOS', '[0.23232838,0.63365731,1.28478321,3.15089189]', 'MAX_NUM_GT_BOXES', '50']
        else:
            args.set_cfgs = ['ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'MAX_NUM_GT_BOXES', '50']

    args.cfg_file = "cfgs/{}_ls.yml".format(args.net) if args.large_scale else "cfgs/{}.yml".format(args.net)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)
    np.random.seed(cfg.RNG_SEED)

    # torch.backends.cudnn.benchmark = True
    if torch.cuda.is_available() and not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    # train set
    # -- Note: Use validation set and disable the flipped to enable faster loading.
    cfg.TRAIN.USE_FLIPPED = True
    cfg.USE_GPU_NMS = args.cuda
    imdb, roidb, ratio_list, ratio_index = combined_roidb(args.imdb_name)
    train_size = len(roidb)

    print('{:d} roidb entries'.format(len(roidb)))

    # output_dir = args.save_dir + "/" + args.net + "/" + args.dataset
    output_dir = args.save_dir + "/" + args.net + "/" + args.model_config
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    sampler_batch = sampler(train_size, args.batch_size)

    dataset = roibatchLoader(roidb, ratio_list, ratio_index, args.batch_size, \
                             imdb.num_obj_classes, training=True)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                             sampler=sampler_batch, num_workers=args.num_workers)

    # initilize the tensor holder here.
    # currently, the region bounding boxes are not used during training
    im_data = torch.FloatTensor(1)
    im_info = torch.FloatTensor(1)
    obj_num_boxes = torch.FloatTensor(1)
    obj_gt_boxes = torch.FloatTensor(1)
    region_num_boxes = torch.FloatTensor(1)
    region_gt_boxes = torch.FloatTensor(1)
    relation_triples = torch.FloatTensor(1)
    relation_num = torch.FloatTensor(1)

    # shift to cuda
    if args.cuda:
        im_data = im_data.cuda()
        im_info = im_info.cuda()
        obj_num_boxes = obj_num_boxes.cuda()
        obj_gt_boxes = obj_gt_boxes.cuda()
        region_num_boxes = region_num_boxes.cuda()
        region_gt_boxes = region_gt_boxes.cuda()
        relation_triples = relation_triples.cuda()
        relation_num = relation_num.cuda()

    # make variable
    im_data = Variable(im_data)
    im_info = Variable(im_info)
    obj_num_boxes = Variable(obj_num_boxes)
    obj_gt_boxes = Variable(obj_gt_boxes)
    region_num_boxes = Variable(region_num_boxes)
    region_gt_boxes = Variable(region_gt_boxes)
    relation_triples = Variable(relation_triples)
    relation_num = Variable(relation_num)

    # initialize the knowledge base
    if args.net == 'vgg16':
        emb_dim = 4096
    elif args.net in ['res101', 'res50', 'res152']:
        emb_dim = 2048
    else:
        raise ValueError("network is not defined")

    if args.resume:
        load_name = os.path.join(output_dir,
                                 'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        entity_emb = checkpoint['entity_emb']
        relation_emb = checkpoint['relation_emb']
        entity_relation_cnt = checkpoint['entity_relation_cnt']
        entity_relation_prob = checkpoint['entity_relation_prob']
    else:
        entity_emb = torch.randn((imdb.num_obj_classes, emb_dim))
        relation_emb = torch.randn((imdb.num_relation_classes + 1, emb_dim))
        entity_relation_cnt = torch.zeros((imdb.num_obj_classes, imdb.num_relation_classes))
        entity_relation_prob = torch.zeros((imdb.num_obj_classes, imdb.num_relation_classes))

    if args.cuda:
        entity_emb = entity_emb.cuda()
        relation_emb = relation_emb.cuda()

    entity_emb = Variable(entity_emb)
    relation_emb = Variable(relation_emb)
    entity_emb.requires_grad = True
    relation_emb.requires_grad = True

    # initilize the network here.
    if args.cuda:
        cfg.CUDA = True

    if args.net == 'vgg16':
        fasterRCNN = vgg16(imdb.obj_classes, pretrained=True, class_agnostic=args.class_agnostic,
                           bias_box=args.bias_box)
    elif args.net == 'res101':
        fasterRCNN = resnet(imdb.obj_classes, 101, pretrained=True, class_agnostic=args.class_agnostic,
                            bias_box=args.bias_box)
    elif args.net == 'res50':
        fasterRCNN = resnet(imdb.obj_classes, 50, pretrained=True, class_agnostic=args.class_agnostic,
                            bias_box=args.bias_box)
    elif args.net == 'res152':
        fasterRCNN = resnet(imdb.obj_classes, 152, pretrained=True, class_agnostic=args.class_agnostic,
                            bias_box=args.bias_box)
    else:
        print("network is not defined")
        pdb.set_trace()

    fasterRCNN.create_architecture()

    # the projection network
    proj_net = ProjectNet(emb_dim, emb_dim)
    context_proj_net = ProjectNet(emb_dim, emb_dim)

    lr = cfg.TRAIN.LEARNING_RATE
    lr = args.lr
    # tr_momentum = cfg.TRAIN.MOMENTUM
    # tr_momentum = args.momentum

    params = []
    for key, value in dict(fasterRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params += [{'params': [value], 'lr': lr * (cfg.TRAIN.DOUBLE_BIAS + 1), \
                            'weight_decay': cfg.TRAIN.BIAS_DECAY and cfg.TRAIN.WEIGHT_DECAY or 0}]
            else:
                params += [{'params': [value], 'lr': lr, 'weight_decay': cfg.TRAIN.WEIGHT_DECAY}]

    params += [{'params': proj_net.parameters(), 'lr': lr, 'weight_decay': cfg.TRAIN.WEIGHT_DECAY}]
    params += [{'params': context_proj_net.parameters(), 'lr': lr, 'weight_decay': cfg.TRAIN.WEIGHT_DECAY}]
    params += [{'params': entity_emb, 'lr': lr, 'weight_decay': 0}]
    params += [{'params': relation_emb, 'lr': lr, 'weight_decay': 0}]

    if args.cuda:
        fasterRCNN.cuda()
        proj_net.cuda()
        context_proj_net.cuda()

    if args.optimizer == "adam":
        lr = lr * 0.1
        optimizer = torch.optim.Adam(params)

    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)

    if args.resume:
        load_name = os.path.join(output_dir,
                                 'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        args.session = checkpoint['session']
        args.start_epoch = checkpoint['epoch']
        fasterRCNN.load_state_dict(checkpoint['model'])
        proj_net.load_state_dict(checkpoint['proj_net'])
        context_proj_net.load_state_dict(checkpoint['context_proj_net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr = optimizer.param_groups[0]['lr']
        if 'pooling_mode' in checkpoint.keys():
            cfg.POOLING_MODE = checkpoint['pooling_mode']
        print("loaded checkpoint %s" % (load_name))

    if args.mGPUs:
        fasterRCNN = nn.DataParallel(fasterRCNN)
        proj_net = nn.DataParallel(proj_net)
        context_proj_net = nn.DataParallel(context_proj_net)

    iters_per_epoch = int(train_size / args.batch_size)

    if args.use_tfboard:
        from tensorboardX import SummaryWriter

        logger = SummaryWriter("logs")

    for epoch in range(args.start_epoch, args.max_epochs + 1):
        # initialize the entity embeddings when the pre-training is done
        if epoch == (args.pretrain_epochs + 1):
            print("Initialize entity embedding.")

            fasterRCNN.train()
            data_iter = iter(dataloader)
            entity_cnt = torch.zeros(imdb.num_obj_classes)
            start = time.time()

            for step in range(iters_per_epoch):
                data = next(data_iter)
                with torch.no_grad():
                    im_data.resize_(data[0].size()).copy_(data[0])
                    im_info.resize_(data[1].size()).copy_(data[1])
                    obj_gt_boxes.resize_(data[2].size()).copy_(data[2])
                    obj_num_boxes.resize_(data[3].size()).copy_(data[3])
                    region_gt_boxes.resize_(data[4].size()).copy_(data[4])
                    region_num_boxes.resize_(data[5].size()).copy_(data[5])
                    relation_triples.resize_(data[6].size()).copy_(data[6])
                    relation_num.resize_(data[7].size()).copy_(data[7])

                fasterRCNN.zero_grad()
                rois, cls_prob, bbox_pred, \
                rpn_loss_cls, rpn_loss_box, \
                RCNN_loss_cls, RCNN_loss_bbox, rois_label, gt_feat, gt_label, \
                gt_id, context_dict = fasterRCNN(im_data, im_info, obj_gt_boxes, obj_num_boxes, use_context = True)

                for gt_index in range(gt_feat.shape[0]):
                    tmp_cls = gt_label[gt_index].long().item()
                    entity_emb.data[tmp_cls] = (entity_emb.data[tmp_cls] * entity_cnt.data[tmp_cls] + gt_feat.data[gt_index, :].detach()) / \
                                               (entity_cnt.data[tmp_cls] + 1)
                    entity_cnt[tmp_cls] += 1

                if step % args.disp_interval == 0:
                    end = time.time()
                    print("[session %d][epoch %2d][iter %4d/%4d] time cost: %f" \
                          % (args.session, epoch, step, iters_per_epoch, end - start))

        # setting to train mode
        fasterRCNN.train()
        proj_net.train()
        context_proj_net.train()
        running_loss_det = 0
        running_loss_likelihood = 0
        start = time.time()

        if epoch % (args.lr_decay_step + 1) == 0:
            adjust_learning_rate(optimizer, args.lr_decay_gamma)
            lr *= args.lr_decay_gamma

        data_iter = iter(dataloader)
        for step in range(iters_per_epoch):
            data = next(data_iter)
            with torch.no_grad():
                im_data.resize_(data[0].size()).copy_(data[0])
                im_info.resize_(data[1].size()).copy_(data[1])
                obj_gt_boxes.resize_(data[2].size()).copy_(data[2])
                obj_num_boxes.resize_(data[3].size()).copy_(data[3])
                region_gt_boxes.resize_(data[4].size()).copy_(data[4])
                region_num_boxes.resize_(data[5].size()).copy_(data[5])
                relation_triples.resize_(data[6].size()).copy_(data[6])
                relation_num.resize_(data[7].size()).copy_(data[7])

            fasterRCNN.zero_grad()
            rois, cls_prob, bbox_pred, \
            rpn_loss_cls, rpn_loss_box, \
            RCNN_loss_cls, RCNN_loss_bbox, rois_label, gt_feat, \
            gt_label, gt_id, context_dict = fasterRCNN(im_data, im_info, obj_gt_boxes, obj_num_boxes, use_context=True)

            # update the entity embeddings with instance embeddings
            entity_emb = update_emb(gt_feat, gt_label, entity_emb, beta=args.beta)

            # project the entity embeddings to relation space
            entity_emb_proj = proj_net(entity_emb)

            ### detection loss
            loss_det = rpn_loss_cls.mean() + rpn_loss_box.mean() \
                       + RCNN_loss_cls.mean() + RCNN_loss_bbox.mean()

            ### likelihood loss
            # extract all relation triples in a mini-batch
            relations_list = list()
            relation_img_id_list = list()

            for img_id in range(relation_num.shape[0]):
                for relation_id in range(relation_num[img_id].long().item()):
                    tmp_relation = relation_triples[img_id][relation_id]
                    relations_list.append(tmp_relation)
                    relation_img_id_list.append(img_id)

            relations = torch.stack(relations_list, dim=0)

            # get the embedding of heads, tails and relationships
            feat_dict = dict(zip(gt_id.long().tolist(), gt_feat))
            label_dict = dict(zip(gt_id.long().tolist(), gt_label.long().tolist()))
            head_feat_list = list()
            tail_feat_list = list()
            head_cls_feat_list = list()
            tail_cls_feat_list = list()
            context_feat_list = list()
            relation_label_list = list()
            valid_relation_num = [0 for img_id in range(relation_num.shape[0])]

            for relation_id in range(relations.shape[0]):
                tmp_relation_index = relations[relation_id][0].long().item()
                tmp_subject_id = relations[relation_id][1].long().item()
                tmp_object_id = relations[relation_id][2].long().item()
                tmp_relation_img_id = relation_img_id_list[relation_id]

                if tmp_subject_id in feat_dict and tmp_object_id in feat_dict:
                    head_feat_list.append(feat_dict[tmp_subject_id])
                    tail_feat_list.append(feat_dict[tmp_object_id])
                    head_cls_feat_list.append(entity_emb[label_dict[tmp_subject_id], :])
                    tail_cls_feat_list.append(entity_emb[label_dict[tmp_object_id], :])
                    context_feat_list.append(context_dict[(tmp_subject_id, tmp_object_id)])
                    relation_label_list.append(relations[relation_id][0].long())
                    entity_relation_cnt[label_dict[tmp_subject_id]][tmp_relation_index] += 1
                    valid_relation_num[tmp_relation_img_id] += 1

            # Likelihood objective for this step. It stays 0 when no relation passed the
            # feature-availability check, and while still inside the pre-training epochs.
            loss_likelihood = 0
            if head_feat_list:
                # Batch the per-relation entries: each list becomes one tensor whose leading
                # dimension indexes the kept relations.
                (head_feat, tail_feat, head_cls_feat, tail_cls_feat,
                 context_feat, relation_label) = (
                    torch.stack(per_relation, dim=0)
                    for per_relation in (head_feat_list, tail_feat_list,
                                         head_cls_feat_list, tail_cls_feat_list,
                                         context_feat_list, relation_label_list))

                # Map instance features and class embeddings into the relation space with the
                # shared projection network (evaluated in the order head, tail, head-cls, tail-cls).
                head_feat_proj, tail_feat_proj, head_cls_feat_proj, tail_cls_feat_proj = (
                    proj_net(to_project)
                    for to_project in (head_feat, tail_feat, head_cls_feat, tail_cls_feat))

                if epoch > args.pretrain_epochs:
                    # Refresh the relation embeddings from the projected context features,
                    # then score the maximum-likelihood objective against them.
                    context_feat_proj = context_proj_net(context_feat)
                    relation_emb = update_emb(context_feat_proj, relation_label, relation_emb,
                                              beta=args.beta)
                    loss_likelihood = likelihood_loss(
                        gt_feat, gt_label, obj_num_boxes, head_feat_proj, tail_feat_proj,
                        head_cls_feat_proj, tail_cls_feat_proj, context_feat_proj,
                        relation_label, valid_relation_num, entity_emb, entity_emb_proj,
                        relation_emb, gt_id, relations, tau=args.tau, neg_size=args.neg_size)

            # Assemble the training objective: detection loss alone during the pre-training
            # epochs, detection plus the weighted likelihood term afterwards.
            running_loss_det += loss_det.item()
            if epoch <= args.pretrain_epochs:
                loss = loss_det
            else:
                loss = loss_det + args.likelihood_weight * loss_likelihood
                # loss_likelihood is the plain int 0 when no relation was usable this step,
                # so only a real loss tensor is converted with .item().
                if loss_likelihood != 0:
                    running_loss_likelihood += loss_likelihood.item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping is only applied for the vgg16 backbone.
            if args.net == 'vgg16':
                clip_gradient(fasterRCNN, 10.)
                if epoch > args.pretrain_epochs and loss_likelihood != 0:
                    clip_gradient(proj_net, 10.)
                    clip_gradient(context_proj_net, 10.)
                    # Either embedding's .grad may be None (relation_emb is rebound to the
                    # output of update_emb, which may not be a leaf tensor). Clamp each
                    # existing gradient independently instead of swallowing every exception,
                    # which used to skip both clamps and hide unrelated errors.
                    if entity_emb.grad is not None:
                        entity_emb.grad.clamp_(-10., 10.)
                    if relation_emb.grad is not None:
                        relation_emb.grad.clamp_(-10., 10.)
            optimizer.step()

            if step % args.disp_interval == 0:
                end = time.time()
                # Turn the accumulated sums into per-step averages. Assuming step counts up
                # from 0 within the epoch, the window reported at step 0 holds a single step
                # and every later window holds exactly the disp_interval steps since the
                # previous report, so divide by disp_interval (disp_interval + 1 under-reported).
                if step > 0:
                    running_loss_det /= args.disp_interval
                    running_loss_likelihood /= args.disp_interval

                # Individual detection loss terms of the current step; in the multi-GPU case
                # each term is reduced with .mean() first (presumably one value per device).
                if args.mGPUs:
                    loss_rpn_cls = rpn_loss_cls.mean().item()
                    loss_rpn_box = rpn_loss_box.mean().item()
                    loss_rcnn_cls = RCNN_loss_cls.mean().item()
                    loss_rcnn_box = RCNN_loss_bbox.mean().item()
                else:
                    loss_rpn_cls = rpn_loss_cls.item()
                    loss_rpn_box = rpn_loss_box.item()
                    loss_rcnn_cls = RCNN_loss_cls.item()
                    loss_rcnn_box = RCNN_loss_bbox.item()
                # Foreground / background RoI counts (label 0 counts as background).
                fg_cnt = torch.sum(rois_label.data.ne(0))
                bg_cnt = rois_label.data.numel() - fg_cnt

                print("[session %d][epoch %2d][iter %4d/%4d] detection loss: %.4f, likelihood loss: %.4f, lr: %.2e"
                      % (args.session, epoch, step, iters_per_epoch, running_loss_det, running_loss_likelihood, lr))
                print("\t\t\tfg/bg=(%d/%d), time cost: %f" % (fg_cnt, bg_cnt, end - start))
                print("\t\t\trpn_cls: %.4f, rpn_box: %.4f, rcnn_cls: %.4f, rcnn_box %.4f"
                      % (loss_rpn_cls, loss_rpn_box, loss_rcnn_cls, loss_rcnn_box))
                if args.use_tfboard:
                    info = {
                        'loss_det': running_loss_det,
                        'loss_likelihood': running_loss_likelihood,
                        'loss_rpn_cls': loss_rpn_cls,
                        'loss_rpn_box': loss_rpn_box,
                        'loss_rcnn_cls': loss_rcnn_cls,
                        'loss_rcnn_box': loss_rcnn_box
                    }
                    # Global step index across epochs for the TensorBoard x-axis.
                    logger.add_scalars("logs_s_{}/losses".format(args.session), info,
                                       (epoch - 1) * iters_per_epoch + step)

                # Start a fresh reporting window.
                running_loss_det = 0
                running_loss_likelihood = 0
                start = time.time()

        # Row-normalise the co-occurrence counts (rows correspond to subject labels) into
        # per-row relation distributions. A row that never received a count would give
        # 0 / 0 = NaN; dividing such rows by 1 instead keeps them all-zero, so the table
        # written to the checkpoint stays finite. Rows with a positive total are unchanged.
        entity_relation_total = entity_relation_cnt.sum(1, keepdim=True)
        entity_relation_prob = entity_relation_cnt / torch.where(
            entity_relation_total > 0, entity_relation_total, torch.ones_like(entity_relation_total))

        # Checkpoint after every epoch. 'epoch' stores epoch + 1 (presumably the epoch to
        # resume from); '.module' unwraps the networks when args.mGPUs is set
        # (presumably nn.DataParallel wrappers -- confirm where they are built).
        save_name = os.path.join(output_dir, 'faster_rcnn_{}_{}_{}.pth'.format(args.session, epoch, step))
        save_checkpoint({
            'session': args.session,
            'epoch': epoch + 1,
            'model': fasterRCNN.module.state_dict() if args.mGPUs else fasterRCNN.state_dict(),
            'proj_net': proj_net.module.state_dict() if args.mGPUs else proj_net.state_dict(),
            'context_proj_net': context_proj_net.module.state_dict() if args.mGPUs else context_proj_net.state_dict(),
            'entity_emb': entity_emb.cpu().detach(),
            'relation_emb': relation_emb.cpu().detach(),
            'entity_relation_cnt': entity_relation_cnt,
            'entity_relation_prob': entity_relation_prob,
            'optimizer': optimizer.state_dict(),
            'pooling_mode': cfg.POOLING_MODE,
            'class_agnostic': args.class_agnostic,
        }, save_name)
        print('save model: {}'.format(save_name))

    # Training is complete: release the TensorBoard logger and let the process exit
    # normally. (Launching an interactive `watch nvidia-smi` here never returns, which
    # would keep the finished job alive indefinitely.)
    if args.use_tfboard:
        logger.close()
