import os

os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import argparse
import gc
import numpy as np
import open3d as o3d
import torch
import torch.utils.data as Data

from dataset import ShapeNet_Heart_Slice, ShapeNet_Heart_Slice_components
from models import PCN2Brunch,PCN6Brunch,PCNNoBrunch
from visualization import plot_pcd_one_view
from metrics.metric import l1_cd, l2_cd, f_score  # , emd
import pandas as pd

# Category name lists. The names suggest the seen / novel ShapeNet categories
# of the PCN benchmark (TODO confirm); neither list is referenced by the code
# visible in this file.
CATEGORIES_PCN = ['airplane', 'cabinet', 'car', 'chair', 'lamp', 'sofa', 'table', 'vessel']
CATEGORIES_PCN_NOVEL = ['bus', 'bed', 'bookshelf', 'bench', 'guitar', 'motorbike', 'skateboard', 'pistol']


def make_dir(dir_path):
    """Create ``dir_path`` (including missing parents) if it does not exist.

    Args:
        dir_path: Directory path to create.

    ``exist_ok=True`` replaces the former exists-then-create check, so the
    call no longer fails when another process creates the directory between
    the check and the creation.
    """
    os.makedirs(dir_path, exist_ok=True)


def export_ply(filename, points):
    """Write ``points`` to ``filename`` as an ASCII point-cloud file.

    Args:
        filename: Destination path handed to Open3D's point-cloud writer.
        points: Point coordinates, converted with ``Vector3dVector``.
    """
    cloud = o3d.geometry.PointCloud()
    coordinates = o3d.utility.Vector3dVector(points)
    cloud.points = coordinates
    o3d.io.write_point_cloud(filename, cloud, write_ascii=True)

def random_sample(pc, n):
    """Return ``n`` rows of ``pc`` selected at random.

    All rows are first shuffled. When ``pc`` has fewer than ``n`` rows, the
    shortfall is filled with additional row indices drawn at random (with
    replacement) and appended after the shuffled ones.

    Args:
        pc: Array indexed by row along its first axis.
        n: Number of rows to return.

    Returns:
        ``pc`` indexed by the first ``n`` selected row indices.
    """
    total = pc.shape[0]
    order = np.random.permutation(total)
    shortfall = n - total
    if shortfall > 0:
        padding = np.random.randint(total, size=shortfall)
        order = np.concatenate([order, padding])
    chosen = order[:n]
    return pc[chosen]


def test_single_category(category, model, params, save=True):
    """Run ``model`` over the test split of one category and optionally save results.

    Args:
        category: Category name; passed to the dataset and used as the
            sub-directory name under ``params.result_dir``.
        model: Callable invoked as ``model(p_slice)``; it must return the
            12-tuple unpacked inside the batch loop.
        params: Namespace providing ``result_dir``, ``batch_size`` and
            ``device``.
        save: When true, write one PNG per sample to
            ``<result_dir>/<category>/image`` and two PLY files per sample to
            ``<result_dir>/<category>/output``.

    Returns:
        A 9-tuple ``(l1_cd, l2_cd, f_score)`` x (shape, component, slice) of
        per-sample averages. NOTE: every metric accumulation below is
        currently commented out, so all nine values are ``0.0``.

    Raises:
        ZeroDivisionError: If the dataset is empty.
    """
    if save:
        cat_dir = os.path.join(params.result_dir, category)
        image_dir = os.path.join(cat_dir, 'image')
        output_dir = os.path.join(cat_dir, 'output')
        for directory in (cat_dir, image_dir, output_dir):
            make_dir(directory)

    # NOTE(review): the split used to be chosen with
    # `'test' if params.novel else 'test'`, i.e. 'test' in both cases, so
    # `params.novel` has no effect on the data that is loaded.
    test_dataset = ShapeNet_Heart_Slice_components('../data/CTA/cta_normal/pointcloud', 'test', category)
    test_dataloader = Data.DataLoader(test_dataset, batch_size=params.batch_size, shuffle=False)

    total_l1_cd, total_l2_cd, total_f_score = 0.0, 0.0, 0.0
    total_l1_cd_component, total_l2_cd_component, total_f_score_component = 0.0, 0.0, 0.0
    total_l1_cd_slice, total_l2_cd_slice, total_f_score_slice = 0.0, 0.0, 0.0

    def to_numpy(tensor):
        # Detach from autograd and copy to host memory as a NumPy array.
        return tensor.detach().cpu().numpy()

    # Per-batch bookkeeping for the (currently disabled) feature-vector CSV
    # export at the end of this function.
    # NOTE(review): only the first sample of every batch is recorded, which
    # covers every sample only when batch_size == 1.
    feature_dim = 1024
    nameL = [''] * len(test_dataloader)
    fvL = np.zeros((len(test_dataloader), feature_dim))

    with torch.no_grad():
        for batch_idx, data_ in enumerate(test_dataloader):
            (p_slice, c_slice, c_shape,
             lv_pc, rv_pc, aro_pc, la_pc,
             ra_pc, myo_pc, path) = data_

            p_slice, c_slice, c_shape = p_slice.to(params.device), c_slice.to(params.device), c_shape.to(params.device)
            lv_pc, rv_pc, aro_pc = lv_pc.to(params.device), rv_pc.to(params.device), aro_pc.to(params.device)
            la_pc, ra_pc, myo_pc = la_pc.to(params.device), ra_pc.to(params.device), myo_pc.to(params.device)

            # NOTE(review): rec_slice_pred is returned by the model but not
            # used anywhere below.
            (rec_slice_pred,
             rotate_slice_pred,
             coarse_shape_pred,
             coarse_lv_pred, coarse_rv_pred, coarse_aro_pred,
             coarse_la_pred, coarse_ra_pred, coarse_myo_pred,
             dense_component_pred, dense_shape_pred, feature_vector) = model(p_slice)

            nameL[batch_idx] = path[0].split('/')[-1]
            # One device-to-host copy per batch instead of one per element.
            fvL[batch_idx] = to_numpy(feature_vector)[0][:feature_dim]

            # --- Disabled metric accumulation (kept for experiments) ---
            # whole shape / dense component / slice
            # total_l1_cd += l1_cd(dense_shape_pred, c_shape).item()
            # total_l2_cd += l2_cd(dense_shape_pred, c_shape).item()
            # total_l1_cd_component += l1_cd(dense_component_pred, c_shape).item()
            # total_l2_cd_component += l2_cd(dense_component_pred, c_shape).item()
            # total_l1_cd_slice += l1_cd(rotate_slice_pred, c_slice).item()
            # total_l2_cd_slice += l2_cd(rotate_slice_pred, c_slice).item()

            # lv, la, myo
            # total_l1_cd += l1_cd(coarse_lv_pred, lv_pc).item()
            # total_l2_cd += l2_cd(coarse_lv_pred, lv_pc).item()
            # total_l1_cd_component += l1_cd(coarse_la_pred, la_pc).item()
            # total_l2_cd_component += l2_cd(coarse_la_pred, la_pc).item()
            # total_l1_cd_slice += l1_cd(coarse_myo_pred, myo_pc).item()
            # total_l2_cd_slice += l2_cd(coarse_myo_pred, myo_pc).item()

            # rv, ra, aro
            # total_l1_cd += l1_cd(coarse_rv_pred, rv_pc).item()
            # total_l2_cd += l2_cd(coarse_rv_pred, rv_pc).item()
            # total_l1_cd_component += l1_cd(coarse_ra_pred, ra_pc).item()
            # total_l2_cd_component += l2_cd(coarse_ra_pred, ra_pc).item()
            # total_l1_cd_slice += l1_cd(coarse_aro_pred, aro_pc).item()
            # total_l2_cd_slice += l2_cd(coarse_aro_pred, aro_pc).item()

            # Panels of the saved figure, in plotting order.
            # NOTE(review): 'Coarse Slice' and 'Coarse Shape' both show the
            # same tensor, exactly as in the original code — confirm whether
            # 'Coarse Slice' was meant to show a slice prediction instead.
            panels = [
                ('Input Slice', p_slice),
                ('Coarse Slice', coarse_shape_pred),
                ('Rotate Slice', rotate_slice_pred),
                ('Ground Truth Slice', c_slice),
                ('Coarse Shape', coarse_shape_pred),
                ('lv', coarse_lv_pred),
                ('rv', coarse_rv_pred),
                ('aro', coarse_aro_pred),
                ('la', coarse_la_pred),
                ('ra', coarse_ra_pred),
                ('myo', coarse_myo_pred),
                ('Dense component', dense_component_pred),
                ('Dense Shape', dense_shape_pred),
                ('Ground Truth Shape', c_shape),
            ]

            for b in range(len(c_shape)):
                # --- Disabled F-score accumulation (kept for experiments) ---
                # whole shape / dense component / slice
                # gt_pc = to_numpy(c_shape[b])
                # total_f_score += f_score(to_numpy(dense_shape_pred[b]), gt_pc)
                # total_f_score_component += f_score(to_numpy(dense_component_pred[b]), gt_pc)
                # total_f_score_slice += f_score(to_numpy(rotate_slice_pred[b]), to_numpy(c_slice[b]))

                # lv, la, myo
                # total_f_score += f_score(to_numpy(coarse_lv_pred[b]), to_numpy(lv_pc[b]))
                # total_f_score_component += f_score(to_numpy(coarse_la_pred[b]), to_numpy(la_pc[b]))
                # total_f_score_slice += f_score(to_numpy(coarse_myo_pred[b]), to_numpy(myo_pc[b]))

                # rv, ra, aro
                # total_f_score += f_score(to_numpy(coarse_rv_pred[b]), to_numpy(rv_pc[b]))
                # total_f_score_component += f_score(to_numpy(coarse_ra_pred[b]), to_numpy(ra_pc[b]))
                # total_f_score_slice += f_score(to_numpy(coarse_aro_pred[b]), to_numpy(aro_pc[b]))

                if save:
                    # Name the outputs after this sample's own path; using
                    # path[0] for every sample made all samples of a batch
                    # overwrite the files of the first one.
                    sample_name = os.path.basename(path[b])
                    stem = sample_name.split('.')[0]
                    plot_pcd_one_view(os.path.join(image_dir, sample_name + '.png'),
                                      [to_numpy(tensor[b]) for _, tensor in panels],
                                      [title for title, _ in panels],
                                      xlim=(-0.35, 0.35), ylim=(-0.35, 0.35),
                                      zlim=(-0.35, 0.35))
                    export_ply(os.path.join(output_dir, stem + '_slice.ply'),
                               to_numpy(rotate_slice_pred[b]))
                    export_ply(os.path.join(output_dir, stem + '_componentshape.ply'),
                               to_numpy(dense_shape_pred[b]))

    # --- Disabled feature-vector CSV export (kept for experiments) ---
    # df_loss = pd.DataFrame()
    # df_loss['name'] = nameL
    # for i_fv in range(feature_dim):
    #     df_loss['fv '+str(i_fv)] = fvL[:,i_fv]
    # fv_csv_path=os.path.join(params.result_dir, 'fv_'+category+'.csv')
    # df_loss.to_csv(fv_csv_path)

    num_samples = len(test_dataset)
    avg_l1_cd = total_l1_cd / num_samples
    avg_l2_cd = total_l2_cd / num_samples
    avg_f_score = total_f_score / num_samples
    avg_l1_cd_com = total_l1_cd_component / num_samples
    avg_l2_cd_com = total_l2_cd_component / num_samples
    avg_f_score_com = total_f_score_component / num_samples
    avg_l1_cd_slice = total_l1_cd_slice / num_samples
    avg_l2_cd_slice = total_l2_cd_slice / num_samples
    avg_f_score_slice = total_f_score_slice / num_samples

    return (avg_l1_cd, avg_l2_cd, avg_f_score,
            avg_l1_cd_com, avg_l2_cd_com, avg_f_score_com,
            avg_l1_cd_slice, avg_l2_cd_slice, avg_f_score_slice)


# Column titles of the result table: label column plus 3 metrics for each of
# the whole shape, the dense component output and the slice.
_TABLE_TITLES = ('Category',
                 'L1_CD(1e-3)', 'L2_CD(1e-4)', 'FScore-0.01(%)',
                 'COM L1_CD(1e-3)', 'COM L2_CD(1e-4)', 'COM FScore-0.01(%)',
                 'SLICE L1_CD(1e-3)', 'SLICE L2_CD(1e-4)', 'SLICE FScore-0.01(%)')
_TABLE_RULE = ('--------',) + ('-----------', '-----------', '--------------') * 3
# Display scaling applied to (l1_cd, l2_cd, f_score) in every metric group.
_METRIC_SCALES = (1e3, 1e4, 1e2) * 3


def _print_text_row(cells):
    """Print ``cells`` as yellow, 20-character-wide text columns."""
    print(('\033[33m' + '{:20s}' * len(cells) + '\033[0m').format(*cells))


def _format_metrics_row(label, metrics):
    """Return one table row: ``label`` followed by the scaled ``metrics``."""
    scaled = [scale * value for scale, value in zip(_METRIC_SCALES, metrics)]
    return ('{:20s}' + '{:<20.4f}' * len(scaled)).format(label, *scaled)


def test(params, save=False):
    """Load the checkpoint, evaluate it and print a table of metrics.

    Args:
        params: Namespace providing ``exp_name``, ``result_dir``,
            ``ckpt_path``, ``device``, ``category`` plus whatever
            ``test_single_category`` reads from it.
        save: Forwarded to ``test_single_category``; when true the result
            directory is created and per-sample outputs are written.

    When ``params.category == 'all'`` every entry of the hard-coded category
    list is evaluated and an average row is printed; otherwise only
    ``params.category`` is evaluated.
    """
    if save:
        make_dir(params.result_dir)

    print(params.exp_name)

    # load pretrained model
    model = PCN6Brunch(num_dense=16384, latent_dim=1024, grid_size=4).to(params.device)
    # model = PCNNoBrunch(num_dense=16384, latent_dim=1024, grid_size=4).to(params.device)

    # map_location makes the load independent of the device the checkpoint
    # was saved from (e.g. a GPU index that is not visible in this process).
    model.load_state_dict(torch.load(params.ckpt_path, map_location=params.device))
    model.eval()

    _print_text_row(_TABLE_TITLES)
    _print_text_row(_TABLE_RULE)

    if params.category == 'all':
        categories = ['a2c']  # ['a5c','a4c','a2c'] #['lax'
        per_category = []
        for category in categories:
            metrics = test_single_category(category, model, params, save)
            print(_format_metrics_row(category.title(), metrics))
            per_category.append(metrics)

        _print_text_row(_TABLE_RULE[:4])
        # Column-wise mean over the evaluated categories.
        averages = [np.mean(column) for column in zip(*per_category)]
        print('\033[32m' + _format_metrics_row('Average', averages) + '\033[0m')
    else:
        # test_single_category returns nine values; unpacking only three of
        # them (as this branch used to do) raised ValueError.
        metrics = test_single_category(params.category, model, params, save)
        print(_format_metrics_row(params.category.title(), metrics))
    gc.collect()


# Experiment tag: selects the ./log/<TYPE> directory used for the default
# result and checkpoint paths. Alternatives used in earlier runs:
# TYPE='PCN-yindao-Slice-240830'
# TYPE='PCN-Slice-yindao-240902'
# TYPE = 'PCN-Rec-Slice-component-240903'

# TYPE='Rec-Slice-6sigcomponent-compare-0929'
# TYPE='Slice-6component-0929'
# TYPE='Slice-nosig-0929'
TYPE='Rec-Slice-6sigcomponent-compare-0929'
# TYPE='Rec-Slice-6component-0929'


def str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` in argparse calls ``bool()`` on the raw string, so every
    non-empty value (including ``'False'``) becomes ``True``. This parser
    maps the usual spellings explicitly instead.

    Args:
        value: A bool (returned unchanged) or a string such as ``'true'``,
            ``'0'``, ``'no'``. The empty string is treated as false.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


def build_arg_parser():
    """Return the argument parser for the testing script."""
    parser = argparse.ArgumentParser('Point Cloud Completion Testing')
    parser.add_argument('--exp_name', type=str, help='Tag of experiment')
    parser.add_argument('--result_dir', type=str,
                        default='./log/'+TYPE+'/results_component', help='Results directory')
    parser.add_argument('--ckpt_path', type=str,
                        default='./log/'+TYPE+'/all/checkpoints/best_l1_cd.pth',
                        help='The path of pretrained model.')
    parser.add_argument('--category', type=str, default='all', help='Category of point clouds')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size for data loader')
    parser.add_argument('--num_workers', type=int, default=1, help='Num workers for data loader')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device for testing')
    # The former default was the truthy string 'best_l1_cd', i.e. saving was
    # enabled unless an empty string was passed; True keeps that default.
    parser.add_argument('--save', type=str2bool, default=True, help='Saving test result')
    parser.add_argument('--novel', type=str2bool, default=False, help='unseen categories for testing')
    return parser


if __name__ == '__main__':
    params = build_arg_parser().parse_args()

    test(params, params.save)
