from __future__ import print_function, division
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.multiprocessing as mp
import lossFunction
from dataLoader import scannet_Dataset, scannet_Merge, modelnet_Dataset, modelnet_Merge, S3DIS
import torch.distributed as dist
import train
import torch.optim as optim
# from examples.minkunet import MinkUNet34C
# from examples.classification_modelnet40 import MinkowskiFCNN
from models import MinkUNet14A, MinkowskiFCNN, URS, global_encoder, GMMNnetwork
import os
import argparse
from examples.common import seed_all
import MinkowskiEngine as ME
from torch.utils.data.distributed import DistributedSampler
import gensim.downloader as api
import numpy as np


def _home_path(*parts):
    """Return ``parts`` joined onto the current user's home directory."""
    # expanduser("~") uses $HOME when it is set and otherwise falls back to the
    # platform's user database, so importing this module does not raise
    # KeyError in environments where HOME is not exported.
    return os.path.join(os.path.expanduser("~"), *parts)


# Class names shared by the defaults of --model_list and
# --model_list_validation. The order is kept exactly as originally written.
_BASE_CLASSES = ('chair', 'table', 'bed', 'sink', 'bathtub', 'door',
                 'curtain', 'desk', 'bookshelf', 'sofa', 'toilet')

# Command-line configuration. The parsed namespace is handed to every spawned
# worker process and to train.train().
parser = argparse.ArgumentParser()
parser.add_argument("--voxel_size", type=float, default=0.05)
parser.add_argument("--epochs", type=int, default=1000)
parser.add_argument("--batch_size_scannet", default=2, type=int)
parser.add_argument("--batch_size_modelnet", default=2, type=int)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--seed", type=int, default=777)
parser.add_argument("--saveName", type=str, default="Full_final_64")
parser.add_argument("--testName", type=str, default="Full_final_64")
parser.add_argument("--world_size", type=int, default=2)
# NOTE: main_worker overwrites config.device with the index of its own GPU.
parser.add_argument("--device", type=int, default=0)
parser.add_argument("--pretrained_modelnet", type=str,
                    default=_home_path("project", "unsupervised_segmentation", "pretrained_modelnet.pth"))
parser.add_argument("--dataRoot_scannet", type=str,
                    default=_home_path("dataset", "scannet", "scans"))
parser.add_argument("--dataRoot_modelnet", type=str,
                    default=_home_path("dataset", "modelnet40_ply_hdf5_2048"))
parser.add_argument("--train_file", type=str,
                    default=_home_path("project", "unsupervised_segmentation", "splits", "scannet",
                                       "scannetv2_train.txt"))
# Alternative --val_file default used previously: .../splits/scannet/scannetv2_train.txt
parser.add_argument("--val_file", type=str,
                    default=_home_path("project", "unsupervised_segmentation", "splits", "scannet",
                                       "scannetv2_val.txt"))

# Previously tried --model_list defaults:
#   ['toilet', 'door', 'curtain']
#   ['chair', 'table', 'bed', 'sink', 'bathtub', 'door', 'curtain']
#   ['chair', 'bookshelf', 'sofa', 'table']   (also tried for --model_list_validation)
#
# The class-list options take one or more space-separated names
# (nargs="+"). The former ``type=list`` split a single command-line value into
# its individual characters.
parser.add_argument("--model_list", type=str, nargs="+", default=list(_BASE_CLASSES),
                    help="space-separated class names")

parser.add_argument("--model_list_validation", type=str, nargs="+", default=list(_BASE_CLASSES),
                    help="space-separated class names")

parser.add_argument("--model_list_zero_shot", type=str, nargs="+",
                    default=list(_BASE_CLASSES) + ['wall', 'floor',
                                                   'cabinet', 'window', 'picture', 'counter', 'refrigerator',
                                                   'shower curtain', 'other furniture'],
                    help="space-separated class names (quote names that contain spaces)")

# Label list kept for reference:
# ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture']

parser.add_argument("--word2vec", type=str, default='word2vector.npy')

parser.add_argument(
    "--lr",
    type=float,
    default=0.007,
    metavar="LR",
    help="learning rate (default: %(default)s)",
)

# Previously tried --valida_model_list default (22 classes):
#   ['bathtub', 'bed', 'bench', 'bookshelf', 'curtain', 'desk', 'door',
#    'dresser', 'flower_pot', 'night_stand', 'piano', 'plant',
#    'sink', 'sofa', 'stairs', 'table', 'tent', 'toilet', 'tv_stand',
#    'vase', 'wardrobe', 'chair']
parser.add_argument("--valida_model_list", type=str, nargs="+",
                    default=['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle',
                             'bowl', 'car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door',
                             'dresser', 'flower_pot', 'glass_box', 'guitar', 'keyboard', 'lamp',
                             'laptop', 'mantel', 'monitor', 'night_stand', 'person', 'piano',
                             'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 'stool',
                             'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'],
                    help="space-separated class names")
parser.add_argument("--feature_dim", type=int, default=96)
parser.add_argument("--noise_dim", type=int, default=300)
parser.add_argument("--embed_dim", type=int, default=300)
parser.add_argument("--hidden_size", type=int, default=256)
parser.add_argument("--lr_generator", type=float, default=0.0002)
parser.add_argument("--batch_size_generator", type=int, default=128)
parser.add_argument("--saved_validation_images", type=int, default=10)
parser.add_argument("--generator_model", type=str, default="gmmn", choices=["gmmn", "dae"])

# Threshold presets recorded from earlier runs:
# 1_(0,1): posi_threshold=0.9997, nega_threshold=0.00005
# 1_(2,3): posi_threshold=0.9997, nega_threshold=0.0001
# 2_(0,1): posi_threshold=0.9997, nega_threshold=0.00001

# The default is a list of floats, so values are parsed as floats and several
# may be given. The former ``type=int`` rejected any fractional value.
parser.add_argument("--threshold", type=float, nargs="+", default=[0.5, 0.5, 0.5, 0.5, 0.5],
                    help="space-separated float thresholds")
parser.add_argument("--posi_threshold", type=float, default=0.9998)
parser.add_argument("--nega_threshold", type=float, default=0.00005)

def _distributed_loader(dataset, collate_fn, batch_size, config, rank):
    """Build a DataLoader whose DistributedSampler shards ``dataset`` for ``rank``."""
    # NOTE(review): DistributedSampler shuffles by default, for the validation
    # set as well, and only reshuffles per epoch if set_epoch() is called —
    # confirm whether train.train does so.
    sampler = DistributedSampler(dataset, num_replicas=config.world_size, rank=rank)
    return DataLoader(dataset,
                      collate_fn=collate_fn,
                      batch_size=batch_size,
                      shuffle=False,  # ordering is delegated to the sampler
                      sampler=sampler,
                      num_workers=config.num_workers)


def main_worker(gpu, ngpus_per_node, config):
    """Set up models, data and optimizers in one spawned process and run training.

    Args:
        gpu: Process index supplied by ``mp.spawn``. It is used both as the
            CUDA device index and as the distributed rank.
        ngpus_per_node: Number of processes per node; it only appears in the
            rank formula, where the node index is fixed at 0.
        config: Parsed command-line namespace. ``config.device`` is overwritten
            with ``gpu``.
    """
    config.device = gpu
    if config.device is not None:
        print("Use GPU: {} for training".format(config.device))
    # Single-node launch: the node index is hard-coded to 0, so rank == gpu.
    rank = 0 * ngpus_per_node + gpu

    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:45551",
        world_size=config.world_size,
        rank=rank,
    )
    # create model
    torch.cuda.set_device(config.device)

    generator = GMMNnetwork(config.noise_dim, config.embed_dim,
                            config.hidden_size, config.feature_dim,
                            embed_feature_size=0).to(config.device)
    # One output channel per class in model_list plus one extra.
    model_scannet = MinkUNet14A(1, len(config.model_list) + 1, config).to(config.device)
    # model_scannet = MinkUNet14A(1, 20).cuda()

    # The ModelNet encoder and the URS model are currently disabled; integer
    # placeholders keep the tuple passed to train.train() the same length.
    model_modelnet = 0
    # model_modelnet = global_encoder(1, 5).to(config.device)
    # model_modelnet = MinkowskiFCNN(1, 40).to(config.device)
    model_URS = 0
    # model_URS = URS().to(config.device)
    # model_URS = URS().cuda()
    # model_scannet.load_state_dict(torch.load('model_scannet.pth'))

    model_scannet = torch.nn.parallel.DistributedDataParallel(
        model_scannet, find_unused_parameters=True, device_ids=[config.device])
    model_scannet = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model_scannet)

    # model_scannet.load_state_dict(torch.load('model_scannet.pth'))
    # model_URS.load_state_dict(torch.load('model_URS.pth'))

    # Disabled: partial loading of the pretrained ModelNet weights.
    # model_dict = model_modelnet.state_dict()
    # pretrained_dict = torch.load(config.pretrained_modelnet)['state_dict']
    # # 1. filter out unnecessary keys
    # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # # 2. overwrite entries in the existing state dict
    # model_dict.update(pretrained_dict)
    # model_modelnet.load_state_dict(model_dict)
    #
    # model_modelnet.load_state_dict(torch.load(config.pretrained_modelnet)['state_dict'])

    # if model == 'test':
    #     fine_LSTM = torch.load('output/' + "730" + testName + "fine_LSTM.pkl", map_location=lambda storage, loc:storage.cuda(use_gpu))
    scannet_train = scannet_Dataset(config=config, phase="train")
    scannet_val = scannet_Dataset(config=config, phase="val")
    modelnet_train = modelnet_Dataset(config=config)

    # scannet_train = S3DIS(test_area_idx=1,config=config,phase="train")
    # scannet_val = S3DIS(test_area_idx=1,config=config,phase="val")

    train_dataloader = _distributed_loader(
        scannet_train, scannet_Merge, config.batch_size_scannet, config, rank)
    # Validation runs one scene at a time.
    val_dataloader = _distributed_loader(scannet_val, scannet_Merge, 1, config, rank)
    modelnet_dataloader = _distributed_loader(
        modelnet_train, modelnet_Merge, config.batch_size_modelnet, config, rank)

    scannet_dataloader = {'train': train_dataloader, 'val': val_dataloader}
    # criterion = LossFunction.coarse_heatmap(use_gpu, batchSize, landmarkNum, image_scale)
    criterion = lossFunction.contrastive(config.device).to(config.device)
    # NOTE(review): ``cuda`` receives the device index, which is 0 (falsy) on
    # the first GPU — confirm GMMNLoss does not treat it as a boolean flag.
    criterion_generator = lossFunction.GMMNLoss(
        sigma=[2, 5, 10, 20, 40, 80], cuda=config.device).build_loss()
    # params = list(model_URS.parameters()) + list(model_modelnet.parameters()) + list(model_scannet.parameters())
    params = list(model_scannet.parameters())
    # optimizer_D = optim.Adam(list(model_modelnet.parameters()))

    # train_params = [{"params": generator.get_1x_lr_params(), "lr": config.lr}, \
    #                 {"params": generator.get_10x_lr_params(), "lr": config.lr * 10}]

    # Define the optimizer and the generator optimizer.
    # NOTE(review): config.lr and config.lr_generator are not passed here, so
    # both optimizers start with Adam's default learning rate unless
    # train.train adjusts it — confirm this is intended.
    optimizer = torch.optim.Adam(params)
    optimizer_G = optim.Adam(generator.parameters())

    # model_scannet.load_state_dict(torch.load('model_scannet_test.pth'))
    # model_URS.load_state_dict(torch.load('model_URS_test.pth'))
    # optimizer.load_state_dict(torch.load('optimizer_test.pth'))
    models = (model_modelnet, model_scannet, model_URS, generator)
    dataloaders = (scannet_dataloader, modelnet_dataloader)

    train.train(dataloaders, models, optimizer, optimizer_G, criterion, criterion_generator, config)

    # Release the process group once training has finished normally. This is
    # deliberately not in a ``finally`` block, so a crashing rank still
    # propagates its exception to mp.spawn without waiting on its peers.
    if dist.is_initialized():
        dist.destroy_process_group()

def main():
    """Parse the command-line options and spawn one training worker per process."""
    config = parser.parse_args()

    # Disabled reference snippet: builds 'word2vector.npy' from the gensim
    # "word2vec-google-news-300" vectors for the class labels.
    #
    # model = api.load("word2vec-google-news-300")
    # CLASS_LABELS = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf',
    #                 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink',
    #                 'bathtub', 'other furniture', 'ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
    #                 'table', 'chair', 'sofa', 'bookshelf', 'board', 'clutter']
    #
    # CLASS_LABELS = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf',
    #                 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'shower_curtain', 'toilet', 'sink',
    #                 'bathtub', 'furniture', 'ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
    #                 'table', 'chair', 'sofa', 'bookshelf', 'board', 'clutter']
    #
    # word2vector = {}
    # for item in CLASS_LABELS:
    #     word2vector[item] = model[item]
    #
    # word2vector['shower curtain'] = word2vector['shower_curtain']
    # word2vector['other furniture'] = word2vector['furniture']
    #
    # np.save('word2vector.npy', word2vector)
    # word2vector = np.load('word2vector.npy', allow_pickle=True).item()
    # print(word2vector)

    n_procs = config.world_size
    # Effective ScanNet batch size summed over all worker processes.
    total_batch = n_procs * config.batch_size_scannet
    print("Testing ", n_procs, "batch size ", total_batch)

    # mp.spawn prepends the process index, so each worker is called as
    # main_worker(index, world_size, config).
    worker_args = (n_procs, config)
    mp.spawn(main_worker, nprocs=n_procs, args=worker_args)
    # main_worker(0, 1, config)


# Launch training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()