from __future__ import print_function
import argparse
import os
import random
import numpy as np
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import sys
import time
from collections import Counter

sys.path.append('../')
from dataset.threedfront_dataset_incremental import ThreedFrontDatasetSceneGraphIncremental,LengthGroupedCurriculumSampler
from torch.utils.data import DataLoader


from model.AE import AE
from model.losses import bce_loss
from helpers.util import bool_flag, _CustomDataParallel

from model.losses import calculate_model_losses, fullbank_info_nce

import torch.nn.functional as F
import json

from tensorboardX import SummaryWriter


parser = argparse.ArgumentParser()
# standard hyperparameters, batch size, learning rate, etc
parser.add_argument('--batchSize', type=int, default=8, help='input batch size')
parser.add_argument('--auxlr', type=float, help='auxiliary learning rate', default=0.0001)
parser.add_argument('--nepoch', type=int, default=200, help='number of epochs to train for')

# paths and filenames
parser.add_argument('--outf', type=str, default='checkpoint', help='output folder')
parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--dataset', required=False, type=str, default="/media/xxx/FRONT",
                    help="dataset path")
parser.add_argument('--logf', default='logs', help='folder (inside --exp) to save tensorboard logs')
parser.add_argument('--exp', default='/home/xxx/data_weight/full_3d',
                    help='experiment directory; tensorboard logs, checkpoints and args.json are written below it')
parser.add_argument('--room_type', default='bedroom', help='room type [bedroom, livingroom, diningroom, library, all]')

# GCN parameters
parser.add_argument('--residual', type=bool_flag, default=False, help="residual in GCN")
parser.add_argument('--pooling', type=str, default='avg', help="pooling method in GCN")

# dataset related
parser.add_argument('--large', default=False, type=bool_flag,
                    help='large set of class labels. Use mapping.json when false')

# Adjacent string literals are concatenated verbatim, so the trailing space
# after the first sentence is required.
parser.add_argument('--with_feats', type=bool_flag, default=True,
                    help="reads latent point features. "
                         "If not existing, they get generated at the beginning.")  # TODO

parser.add_argument('--shuffle_objs', type=bool_flag, default=True, help="shuffle objs of a scene")

parser.add_argument('--num_box_params', default=7, type=int, help="number of the dimension of the bbox. [6,7]")

# training and architecture related
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)


parser.add_argument('--loadmodel', default=False, type=bool_flag)
parser.add_argument('--loadepoch', default=90, type=int, help='only valid when loadmodel is true')

parser.add_argument('--network_type', default='increment3dG', choices=['increment3dG', 'baseline1', 'baseline2','baseline3'], type=str)


# Parsed once at import time; train() reads this module-level namespace.
args = parser.parse_args()
print(args)


def to_cuda(x, device='cuda'):
    """Recursively move every tensor found in ``x`` to ``device``.

    Dicts and lists are rebuilt (as plain ``dict`` / ``list``) with their
    contents converted; tensors are copied with ``non_blocking=True``; any
    other value, including tuples, is returned unchanged.
    """
    if isinstance(x, dict):
        return {key: to_cuda(value, device) for key, value in x.items()}
    if isinstance(x, list):
        return [to_cuda(item, device) for item in x]
    if not torch.is_tensor(x):
        return x
    return x.to(device, non_blocking=True)


def _load_box_data(path):
    """Load a per-category JSON file and merge its 'stool' entries into 'chair'.

    Args:
        path: JSON file whose top level holds (at least) 'chair' and 'stool'
            entries with dict values. Judging by the ``cat_jid_*`` file names,
            these presumably map jid -> box data (TODO confirm).

    Returns:
        The loaded dict, with ``data['chair']`` updated in place from
        ``data['stool']``.
    """
    with open(path, "r") as read_file:
        box_data = json.load(read_file)
    box_data['chair'].update(box_data['stool'])
    return box_data


def _zero_nan_grads(optimizer, step):
    """Replace NaN gradient entries of every parameter in ``optimizer`` with 0.

    Must run after ``backward()`` and before gradient clipping /
    ``optimizer.step()``. Otherwise a single NaN makes the total norm used
    by ``clip_grad_norm_`` non-finite and ends up in the weights.

    Args:
        optimizer: optimizer whose ``param_groups`` are inspected.
        step: global step counter, only used in the log message.
    """
    for group in optimizer.param_groups:
        for p in group['params']:
            if p.grad is None or not p.requires_grad:
                continue
            nan_mask = torch.isnan(p.grad)
            if nan_mask.any():
                print('NaN grad in step {}.'.format(step))
                p.grad[nan_mask] = 0


def train():
    """ Train the network based on the provided argparse parameters.

    Reads the module-level ``args``. Each batch of scenes is processed one
    incremental step at a time. Every step's loss is back-propagated right
    away, so gradients accumulate over all steps of the batch before a single
    clipped optimizer update. Checkpoints are written every 15 epochs and
    after the final epoch.
    """
    args.manualSeed = random.randint(1, 10000)  # optionally fix seed 7494
    print("Random Seed: ", args.manualSeed)

    print(torch.__version__)

    random.seed(args.manualSeed)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)

    # Batches are moved to the same device as the model, so the script also
    # runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # instantiate scene graph dataset for training
    dataset = ThreedFrontDatasetSceneGraphIncremental(
        root=args.dataset,
        split='train_scans',
        shuffle_objs=args.shuffle_objs,
        with_feats=args.with_feats,
        large=args.large,
        seed=False,
        room_type=args.room_type,
        recompute_feats=False)

    # Batch composition is delegated to LengthGroupedCurriculumSampler (which
    # receives args.batchSize); the dataset's incremental collate function
    # merges the selected scenes.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=LengthGroupedCurriculumSampler(dataset, args.batchSize),
        collate_fn=dataset.collate_fn_inc,
        num_workers=int(args.workers))

    # Best effort: outf is also created relative to the working directory;
    # an already-existing folder (or any other OSError) is ignored here.
    try:
        os.makedirs(args.outf)
    except OSError:
        pass

    # instantiate the model
    print('datasize: ', len(dataset.vocab['object_idx_to_name']))
    print(sorted([n.strip() for n in dataset.vocab['object_idx_to_name']]))

    model = AE(root=args.dataset, type=args.network_type, vocab=dataset.vocab, residual=args.residual,
               gconv_pooling=args.pooling, num_box_params=args.num_box_params)

    if torch.cuda.is_available():
        model = model.cuda()

    if args.loadmodel:
        model.load_networks(exp=args.exp, epoch=args.loadepoch, restart_optim=False)

    # initialize tensorboard writer; it is closed in the finally clause below
    writer = SummaryWriter(args.exp + "/" + args.logf)
    try:
        # optimizer over the trainable parameters only
        params = [p for p in model.parameters() if p.requires_grad]
        optimizer_bl = optim.Adam(params, lr=args.auxlr)

        print("---- Model and Dataset built ----")

        os.makedirs(args.exp + "/" + args.outf, exist_ok=True)

        # save parameters so that we can read them later on evaluation
        with open(os.path.join(args.exp, 'args.json'), 'w') as f:
            json.dump(args.__dict__, f, indent=2)
        print("Saving all parameters under:")
        print(os.path.join(args.exp, 'args.json'))

        # Debugging aid: a backward pass producing NaN raises with the
        # traceback of the offending forward op, at a noticeable speed cost.
        torch.autograd.set_detect_anomaly(True)
        counter = model.counter if model.counter else 0

        # Only consumed by the optional full-bank InfoNCE shape loss (the
        # disabled call in the step loop); the active losses ignore them.
        bank_cache = {}
        box_data = _load_box_data("/media/xxx/xxx/FRONT/cat_jid_test_small.json")
        train_box_data = _load_box_data("/media/xxx/xxx/FRONT/cat_jid_trainval_small.json")

        print("---- Starting training loop! ----")

        start_epoch = model.epoch if model.epoch else 0
        for epoch in range(start_epoch, args.nepoch):
            print('Epoch: {}/{}'.format(epoch, args.nepoch))

            for batch in loader:  # a batch merges the data of several scenes
                model.reset_all_scene_cls_states()  # clear scene states for every batch of new scenes
                enc = to_cuda(batch['encoder'], device)

                obj_to_step = enc['obj_to_step']  # incremental step index of every object
                num_steps = obj_to_step.max().item() + 1 if obj_to_step.numel() else 0

                # Losses of the last processed step; they stay None for a batch
                # in which no step contains any object.
                lb = ls = None
                for k in range(num_steps):
                    obj_step_mask = obj_to_step == k
                    if not obj_step_mask.any():
                        continue
                    # batch-global indices of the objects that belong to step k
                    global_obj_index = torch.nonzero(obj_step_mask, as_tuple=False).squeeze(1)

                    new_m = enc['new_mask'][obj_step_mask]
                    objs = enc['objs'][obj_step_mask]
                    boxes = enc['boxes'][obj_step_mask]

                    feats = enc.get('feats')
                    if feats is None:
                        raise RuntimeError("batch has no 'feats' to supervise shapes "
                                           "(was the dataset built with with_feats=True?)")
                    feats = feats[obj_step_mask]

                    obj_scene_ids = enc['obj_to_scene'][obj_step_mask]  # (N_step_obj,)

                    if enc['triples'].numel():
                        tri_step_mask = enc['triple_to_step'] == k
                        triples_step = enc['triples'][tri_step_mask]
                        triple_scene_ids = enc['triple_to_scene'][tri_step_mask]  # (N_tri_step,)
                    else:
                        triples_step = enc['triples'].new_empty(0, 3)
                        triple_scene_ids = enc['triple_to_scene'].new_empty(0, dtype=torch.long)

                    pred_boxs, pred_shapes = model.forward_incremental_3D_(
                        obj_batch_scene_ids=obj_scene_ids,
                        objs=objs,
                        boxes=boxes,
                        triples=triples_step,
                        new_mask=new_m,
                        obj_indices=global_obj_index,
                        triple_scene_ids=triple_scene_ids)

                    # only the objects flagged by new_mask are supervised
                    gt_boxs = boxes[new_m]
                    gt_shapes = feats[new_m]

                    lb, _ = calculate_model_losses(args, pred_boxs, gt_boxs, name='box')
                    ls, _ = calculate_model_losses(args, pred_shapes, gt_shapes, name='shape')
                    # Alternative shape loss (disabled):
                    # ls = fullbank_info_nce(
                    #     pred=pred_shapes, target=gt_shapes, labels=objs[new_m].cpu(),
                    #     train_box_data=train_box_data, code_dict=model.code_dict,
                    #     label_classes=dataset.classes, bank_cache=bank_cache,
                    #     tau=0.12, topK=128)

                    # Box loss weighted 5x relative to the shape loss; gradients
                    # accumulate across all steps of this batch.
                    (5 * lb + 1 * ls).backward()

                # NaN gradients are zeroed before clipping and stepping; after
                # zero_grad() there would be nothing left to inspect.
                _zero_nan_grads(optimizer_bl, counter)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
                optimizer_bl.step()
                optimizer_bl.zero_grad()

                counter += 1
                if lb is not None:
                    lb_val, ls_val = float(lb), float(ls)
                    if counter % 50 == 0:
                        message = "loss at {}: box {:.4f}\tshape {:.4f}\t".format(
                            counter, lb_val, ls_val)
                        print(message)

                    writer.add_scalar('Train_Loss_BBox', lb_val, counter)
                    writer.add_scalar('Train_Loss_Shape', ls_val, counter)

            # periodic checkpoint, plus one after the final epoch
            if epoch % 15 == 0 or epoch == args.nepoch - 1:
                model.save(args.exp, args.outf, epoch, counter=counter)
                print('saved model_{}'.format(epoch))
    finally:
        writer.close()


if __name__ == "__main__":
    train()
