#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import time
import argparse
import sys
import numpy as np
import torch
from tqdm import tqdm

from network.BEV_Unet import BEV_Unet
from network.ptBEV import ptBEVnet
from data_loader.dataset import collate_fn_BEV, collate_fn_BEV_test, SemKITTI, SemKITTI_label_name,spherical_dataset,voxel_dataset
#ignore weird np warning
import warnings
warnings.filterwarnings("ignore")


def fast_hist(pred, label, n):
    k = (label >= 0) & (label < n)
    bin_count = np.bincount(n * label[k].astype(int) + pred[k], minlength=n ** 2)
    return bin_count[:n ** 2].reshape(n, n)


def per_class_iu(hist):
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))


def fast_hist_crop(output, target, unique_label):
    hist = fast_hist(output.flatten(), target.flatten(), np.max(unique_label)+1)
    hist = hist[unique_label,:]
    hist = hist[:,unique_label]
    return hist


def SemKITTI2train(label):
    if isinstance(label, list):
        return [SemKITTI2train_single(a) for a in label]
    else:
        return SemKITTI2train_single(label)


def SemKITTI2train_single(label):
    return label - 1 # uint8 trick


def train2SemKITTI(input_label):
    # delete 0 label (uses uint8 trick : 0 - 1 = 255 )
    return input_label + 1


def main(args):
    data_path = args.data_dir
    test_batch_size = args.test_batch_size
    model_save_path = args.model_save_path
    output_path = args.test_output_path
    compression_model = args.grid_size[2]
    grid_size = args.grid_size
    pytorch_device = torch.device('cuda:0')
    model = args.model
    if model == 'polar':
        fea_dim = 9
        circular_padding = True
    elif model == 'traditional':
        fea_dim = 7
        circular_padding = False

    # prepare miou fun
    unique_label = np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str = [SemKITTI_label_name[x] for x in unique_label+1]

    # prepare model
    my_BEV_model = BEV_Unet(
        n_class=len(unique_label), n_height=compression_model, input_batch_norm=True, dropout=0.5, circular_padding=circular_padding)
    my_model = ptBEVnet(
        my_BEV_model, pt_model='pointnet', grid_size=grid_size, fea_dim=fea_dim, max_pt_per_encode=256,
        out_pt_fea_dim=512, kernal_size=1, pt_selection='random', fea_compre=compression_model)
    
    if os.path.exists(model_save_path):
        my_model.load_state_dict(torch.load(model_save_path))
    my_model.to(pytorch_device)

    # prepare dataset
    test_pt_dataset = SemKITTI(data_path + '/sequences/', imageset='test', return_ref=True)
    val_pt_dataset = SemKITTI(data_path + '/sequences/', imageset='val', return_ref=True)
    if model == 'polar':
        test_dataset = spherical_dataset(test_pt_dataset, grid_size=grid_size, ignore_label=0, fixed_volume_space=True, return_test=True)
        val_dataset = spherical_dataset(val_pt_dataset, grid_size=grid_size, ignore_label=0, fixed_volume_space=True)
    elif model == 'traditional':
        test_dataset = voxel_dataset(test_pt_dataset, grid_size=grid_size, ignore_label=0, fixed_volume_space=True, return_test=True)
        val_dataset = voxel_dataset(val_pt_dataset, grid_size=grid_size, ignore_label=0, fixed_volume_space=True)
    
    test_dataset_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=test_batch_size, collate_fn=collate_fn_BEV_test, shuffle=False, num_workers=4)
    val_dataset_loader = torch.utils.data.DataLoader(
        dataset=val_dataset, batch_size=test_batch_size, collate_fn=collate_fn_BEV, shuffle=False, num_workers=4)

    # validation
    print('*'*80)
    print('Test network performance on validation split')
    print('*'*80)
    pbar = tqdm(total=len(val_dataset_loader))
    my_model.eval()
    hist_list = []
    time_list = []
    with torch.no_grad():
        for i_iter_val, (_, val_vox_label, val_grid, val_pt_labs, val_pt_fea) in enumerate(val_dataset_loader):
            val_vox_label = SemKITTI2train(val_vox_label)
            val_pt_labs = SemKITTI2train(val_pt_labs)
            val_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in val_pt_fea]
            val_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in val_grid]
            val_label_tensor = val_vox_label.type(torch.LongTensor).to(pytorch_device)

            torch.cuda.synchronize()
            start_time = time.time()
            predict_labels = my_model(val_pt_fea_ten, val_grid_ten)
            torch.cuda.synchronize()
            time_list.append(time.time()-start_time)

            predict_labels = torch.argmax(predict_labels,dim=1)
            predict_labels = predict_labels.cpu().detach().numpy()
            for count,i_val_grid in enumerate(val_grid):
                hist_list.append(fast_hist_crop(predict_labels[count, val_grid[count][:,0], val_grid[count][:,1], val_grid[count][:,2]], val_pt_labs[count], unique_label))
            pbar.update(1)
    
    iou = per_class_iu(sum(hist_list))
    print('Validation per class iou: ')
    for class_name, class_iou in zip(unique_label_str,iou):
        print('%s : %.2f%%' % (class_name, class_iou*100))
    val_miou = np.nanmean(iou) * 100
    del val_vox_label, val_grid,val_pt_fea, val_grid_ten
    pbar.close()
    print('Current val miou is %.3f ' % val_miou)
    print('Inference time per %d is %.4f seconds\n' % (test_batch_size,np.mean(time_list)))
    
    # test
    print('*'*80)
    print('Generate predictions for test split')
    print('*'*80)
    pbar = tqdm(total=len(test_dataset_loader))
    with torch.no_grad():
        for i_iter_test,(_,_,test_grid,_,test_pt_fea,test_index) in enumerate(test_dataset_loader):
            # predict
            test_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in test_pt_fea]
            test_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in test_grid]

            predict_labels = my_model(test_pt_fea_ten,test_grid_ten)
            predict_labels = torch.argmax(predict_labels,1).type(torch.uint8)
            predict_labels = predict_labels.cpu().detach().numpy()
            # write to label file
            for count,i_test_grid in enumerate(test_grid):
                test_pred_label = predict_labels[count,test_grid[count][:,0],test_grid[count][:,1],test_grid[count][:,2]]
                test_pred_label = train2SemKITTI(test_pred_label)
                test_pred_label = np.expand_dims(test_pred_label,axis=1)
                save_dir = test_pt_dataset.im_idx[test_index[count]]
                _,dir2 = save_dir.split('/sequences/',1)
                new_save_dir = output_path + '/sequences/' + dir2.replace('velodyne','predictions')[:-3]+'label'
                if not os.path.exists(os.path.dirname(new_save_dir)):
                    try:
                        os.makedirs(os.path.dirname(new_save_dir))
                    except OSError as exc:
                        if exc.errno != errno.EEXIST:
                            raise
                test_pred_label = test_pred_label.astype(np.uint32)
                test_pred_label.tofile(new_save_dir)
            pbar.update(1)
    del test_grid,test_pt_fea,test_index
    pbar.close()
    print('Predicted test labels are saved in %s. Need to be shifted to original label format before submitting to the Competition website.' % output_path)
    print('Remapping script can be found in semantic-kitti-api.')


if __name__ == '__main__':
    # Testing settings
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('-d', '--data_dir', default='/home/mount/Datasets/Semantic_KITTI/dataset')
    parser.add_argument('-p', '--model_save_path', default='pretrained_weight/SemKITTI_PolarSeg.pt')
    parser.add_argument('-o', '--test_output_path', default='out/SemKITTI_test')
    parser.add_argument('-m', '--model', choices=['polar','traditional'], default='polar', help='training model: polar or traditional (default: polar)')
    parser.add_argument('-s', '--grid_size', nargs='+', type=int, default=[480,360,32], help='grid size of BEV representation (default: [480,360,32])')
    parser.add_argument('--test_batch_size', type=int, default=1, help='batch size for training (default: 1)')
    
    args = parser.parse_args()
    if not len(args.grid_size) == 3:
        raise Exception('Invalid grid size! Grid size should have 3 dimensions.')

    print(' '.join(sys.argv))
    print(args)
    main(args)
