import argparse
import numpy as np
import torch
import os
import random
import glob
from tqdm import tqdm
import kaolin as kal
import point_cloud_utils as pcu
from PIL import Image
import lpips
import torch.nn as nn
from torchvision import utils
from torchvision import transforms, utils
from torch.utils import data



def seed_everything(seed):
    """Seed PyTorch's, NumPy's and Python's global random number generators.

    Args:
        seed: integer seed applied to all three generators. A negative value
            means "do not touch the generators" and makes the call a no-op.
    """
    if seed >= 0:
        # Same order as before: torch first, then NumPy, then stdlib random.
        for apply_seed in (torch.manual_seed, np.random.seed, random.seed):
            apply_seed(seed)


def sample_point_with_mesh_name(name, n_sample=2048, normalized_scale=1.0):
    """Load a shape file and return its point cloud as a CUDA float tensor.

    The loader is selected by the file extension (case-insensitive):

    * ``.npz``: a pre-sampled point cloud stored under key ``'pcd'``. A 2-D
      array gets a leading batch axis, then the first ``n_sample`` points are
      kept (no resampling).
    * ``.ply``: mesh vertices are loaded with point_cloud_utils, randomly
      permuted, truncated to ``n_sample`` and rescaled with a path-dependent
      factor.
    * anything else (e.g. ``.obj``): the mesh is loaded with kaolin and scaled
      so its largest axis-aligned extent equals ``normalized_scale``. Then
      ``n_sample`` points are sampled on its surface.

    Args:
        name: path of the shape file.
        n_sample: number of points to return per shape.
        normalized_scale: target size used when rescaling the shape.

    Returns:
        A float32 CUDA tensor whose first axis is a batch axis (size 1 for
        ``.ply`` and mesh inputs), or ``None`` if a mesh has no vertices or a
        zero bounding-box extent.
    """
    # Dispatch on the real extension. A substring test such as
    # ``'ply' in name`` also matches directory names like ``deploy/``.
    ext = os.path.splitext(name)[1].lower()

    if ext == '.npz':
        with np.load(name) as data:  # close the archive handle promptly
            pcd = data['pcd']
        if pcd.ndim == 2:  # (points, coords) -> add the leading batch axis
            pcd = pcd[np.newaxis, :, :]
        pcd = pcd[:, :n_sample, :]
        return torch.from_numpy(pcd).float().cuda()

    if ext == '.ply':
        vertices = pcu.load_mesh_v(name)
        point_clouds = np.random.permutation(vertices)[:n_sample, :]
        # NOTE(review): these factors are chosen by substring match on the
        # whole path (e.g. any path containing "car"). Presumably they reflect
        # how each category was normalized upstream; confirm.
        scale = 0.9
        if 'chair' in name or 'animal' in name:
            scale = 0.7
        if 'car' in name:
            normalized_scale = 0.9  # car shapes were sampled with a 0.9 surface
        point_clouds = point_clouds / scale * normalized_scale  # bring all inputs to a common scale
        return torch.from_numpy(point_clouds).float().cuda().unsqueeze(dim=0)

    mesh = kal.io.obj.import_mesh(name)
    if mesh.vertices.shape[0] == 0:
        return None
    vertices = mesh.vertices.cuda()
    extent = (vertices.max(dim=0)[0] - vertices.min(dim=0)[0]).max()
    if extent <= 0:
        # A degenerate (single-point) mesh cannot be normalized or sampled.
        return None
    mesh_v = vertices / extent * normalized_scale
    mesh_f = mesh.faces.cuda()
    points, _ = kal.ops.mesh.sample_points(mesh_v.unsqueeze(dim=0), mesh_f, n_sample)
    return points.cuda()

def pairwise_chamfer_distance(sample_pcs, n_sample, batch_size=50):
    """Summarize the Chamfer distance over all unordered pairs of shapes.

    Args:
        sample_pcs: paths of the shape files to compare. Files that fail to
            load (``sample_point_with_mesh_name`` returns ``None``) are skipped.
        n_sample: number of points taken from each shape.
        batch_size: how many partner shapes are compared per Chamfer call.
            It only controls how the work is chunked; the set of compared
            pairs is the same for any value.

    Returns:
        ``(mean, std)`` tensors over the distances of all pairs ``i < j``,
        scaled by 1e3.

    Raises:
        ValueError: if fewer than two point clouds could be loaded.
    """
    normalized_scale = 1.0
    clouds = [sample_point_with_mesh_name(name, n_sample, normalized_scale=normalized_scale)
              for name in tqdm(sample_pcs)]
    clouds = [p for p in clouds if p is not None]
    if not clouds:
        raise ValueError("pairwise_chamfer_distance: no point clouds could be loaded")
    all_sample_pcs = torch.cat(clouds, dim=0)
    n_clouds = all_sample_pcs.size(0)
    if n_clouds < 2:
        raise ValueError("pairwise_chamfer_distance: need at least two point clouds, got %d" % n_clouds)

    pairwise_cd = []
    for i in tqdm(range(n_clouds - 1)):
        anchor = all_sample_pcs[i:i + 1]
        # Compare shape i against every later shape, a chunk at a time.
        for start in range(i + 1, n_clouds, batch_size):
            partners = all_sample_pcs[start:start + batch_size]
            chamfer = kal.metrics.pointcloud.chamfer_distance(
                anchor.expand(partners.size(0), -1, -1), partners) * 1e+3
            pairwise_cd.append(chamfer)
    pairwise_cd = torch.cat(pairwise_cd, dim=0)
    return torch.mean(pairwise_cd), torch.std(pairwise_cd)

def intra_chamfer_distance(ref_pcs, sample_pcs, n_sample, group_size=100, batch_size=50):
    """Measure diversity of generated shapes around each reference shape.

    Each generated shape is assigned to the reference shape with the smallest
    Chamfer distance to it (first minimum on ties). For every reference that
    received at least two generated shapes, the first ``group_size`` of them
    (in input order) are compared pairwise. All those pairwise distances are
    pooled together.

    Args:
        ref_pcs: paths of the reference (training) shape files.
        sample_pcs: paths of the generated shape files.
        n_sample: number of points taken from each shape.
        group_size: maximum number of generated shapes kept per reference.
        batch_size: how many shapes are compared per Chamfer call. It only
            controls how the work is chunked.

    Returns:
        ``(mean, std)`` tensors of the pooled pairwise distances (scaled by 1e3).

    Raises:
        ValueError: if no generated or no reference shape could be loaded, or
            if no reference received two or more generated shapes.
    """
    normalized_scale = 1.0
    samples = [sample_point_with_mesh_name(name, n_sample, normalized_scale=normalized_scale)
               for name in tqdm(sample_pcs)]
    samples = [p for p in samples if p is not None]
    refs = [sample_point_with_mesh_name(name, n_sample, normalized_scale=normalized_scale)
            for name in tqdm(ref_pcs)]
    refs = [p for p in refs if p is not None]
    if not samples or not refs:
        raise ValueError("intra_chamfer_distance: no generated or reference shapes could be loaded")
    all_sample_pcs = torch.cat(samples, dim=0)
    all_ref_pcs = torch.cat(refs, dim=0)

    # 1) Nearest reference for every generated shape. argmin returns the first
    #    minimum, so ties are broken exactly like a strict "<" scan.
    nearest_ref = []
    for i in tqdm(range(all_sample_pcs.size(0))):
        sample = all_sample_pcs[i:i + 1]
        distances = []
        for start in range(0, all_ref_pcs.size(0), batch_size):
            ref_batch = all_ref_pcs[start:start + batch_size]
            distances.append(kal.metrics.pointcloud.chamfer_distance(
                sample.expand(ref_batch.size(0), -1, -1), ref_batch) * 1e+3)
        nearest_ref.append(int(torch.cat(distances, dim=0).argmin()))

    # 2) Group generated shapes by their nearest reference. Keep at most
    #    group_size members per group, in input order.
    groups = {}
    for sample_idx, ref_idx in enumerate(nearest_ref):
        members = groups.setdefault(ref_idx, [])
        if len(members) < group_size:
            members.append(sample_idx)

    # 3) Pairwise distances inside every group that has at least two members.
    pooled = []
    for ref_idx in tqdm(sorted(groups)):
        members = groups[ref_idx]
        if len(members) < 2:
            continue
        group = all_sample_pcs[members]
        for m in range(group.size(0) - 1):
            anchor = group[m:m + 1]
            for start in range(m + 1, group.size(0), batch_size):
                partners = group[start:start + batch_size]
                pooled.append(kal.metrics.pointcloud.chamfer_distance(
                    anchor.expand(partners.size(0), -1, -1), partners) * 1e+3)

    if not pooled:
        raise ValueError("intra_chamfer_distance: no reference received two or more generated shapes")
    pooled = torch.cat(pooled, dim=0)
    return torch.mean(pooled), torch.std(pooled)

def chamfer_distance(ref_pcs, sample_pcs, batch_size, n_sample=2048):
    """Compute the reference-by-sample Chamfer distance matrix.

    Args:
        ref_pcs: paths of the reference (training) shape files.
        sample_pcs: paths of the generated shape files.
        batch_size: number of generated shapes compared per Chamfer call.
        n_sample: number of points taken from each shape.

    Returns:
        A ``(R, S)`` tensor scaled by 1e3. R and S count the point clouds that
        actually loaded, since files returning ``None`` are skipped. Entry
        ``[i, j]`` is the distance between loaded reference ``i`` and loaded
        sample ``j``.

    Raises:
        ValueError: if no reference or no generated shape could be loaded.
    """
    normalized_scale = 1.0
    refs = [sample_point_with_mesh_name(name, n_sample, normalized_scale=normalized_scale)
            for name in tqdm(ref_pcs)]
    # These are the generated shapes.
    samples = [sample_point_with_mesh_name(name, n_sample, normalized_scale=normalized_scale)
               for name in tqdm(sample_pcs)]
    refs = [p for p in refs if p is not None]
    samples = [p for p in samples if p is not None]
    if not refs or not samples:
        raise ValueError("chamfer_distance: no reference or generated shapes could be loaded")
    all_ref_pcs = torch.cat(refs, dim=0)
    all_sample_pcs = torch.cat(samples, dim=0)

    # Iterate over the loaded tensors, not the input path lists. Skipped files
    # make the two lengths differ.
    n_refs = all_ref_pcs.size(0)
    n_samples = all_sample_pcs.size(0)
    rows = []
    for i in tqdm(range(n_refs)):
        ref_p = all_ref_pcs[i]
        row = []
        for start in range(0, n_samples, batch_size):
            sample_batch = all_sample_pcs[start:start + batch_size]
            row.append(kal.metrics.pointcloud.chamfer_distance(
                ref_p.unsqueeze(dim=0).expand(sample_batch.size(0), -1, -1),
                sample_batch) * 1e+3)
        rows.append(torch.cat(row, dim=0).unsqueeze(dim=0))
    return torch.cat(rows, dim=0)


def compute_all_metrics(sample_pcs, ref_pcs, batch_size, save_name=None):
    """Compute the reference-by-sample Chamfer matrix and optionally pickle it.

    Args:
        sample_pcs: paths of the generated shape files.
        ref_pcs: paths of the reference (training) shape files.
        batch_size: number of generated shapes compared per Chamfer call.
        save_name: if given, the matrix is written there as a pickled NumPy
            array. If ``None``, nothing is written.

    Returns:
        The Chamfer distance matrix returned by ``chamfer_distance``.
    """
    import pickle

    M_rs_cd = chamfer_distance(ref_pcs, sample_pcs, batch_size)
    if save_name is not None:
        # Use a context manager so the file is flushed and closed.
        with open(save_name, 'wb') as f:
            pickle.dump(M_rs_cd.data.cpu().numpy(), f)
    return M_rs_cd


def evaluate(args):
    """Run the Chamfer-distance evaluation selected by ``args``.

    Generated shapes are taken from ``args.gen_path`` and training shapes from
    ``args.train_data_path``. Both use ``*.npz`` files when ``args.use_npz`` is
    set and ``*.obj`` otherwise. At most ``args.n_shape`` generated shapes are
    used. With ``args.eval_cd`` the mean of the reference-by-sample Chamfer
    matrix is printed. Otherwise the intra-cluster and pairwise Chamfer
    statistics of the generated shapes are printed.

    Raises:
        FileNotFoundError: if ``args.train_data_path`` is not a directory.
    """
    # Fixed seed so random point sampling is repeatable between runs.
    seed_everything(41)
    pattern = '*.npz' if args.use_npz else '*.obj'

    train_path = args.train_data_path
    if not os.path.isdir(train_path):
        raise FileNotFoundError("training data directory not found: %s" % train_path)
    train_models = sorted(glob.glob(os.path.join(train_path, pattern)))

    # glob() returns files in filesystem order. Sort before truncating so the
    # same n_shape generated shapes are evaluated on every run.
    gen_models = sorted(glob.glob(os.path.join(args.gen_path, pattern)))[:args.n_shape]

    with torch.no_grad():
        if args.eval_cd:
            M_rs_cd = chamfer_distance(train_models, gen_models, args.batch_size)
            print("CD:", np.mean(M_rs_cd.data.cpu().numpy()))
        else:
            intra_cd, intra_cd_var = intra_chamfer_distance(train_models, gen_models, args.n_sample)
            print("Intra Chamfer Distance:", intra_cd)
            print("Intra Chamfer Distance Std:", intra_cd_var)
            pairwise_cd, pairwise_cd_var = pairwise_chamfer_distance(gen_models, args.n_sample)
            print("Pairwise Chamfer Distance:", pairwise_cd)
            print("Pairwise Chamfer Distance Std:", pairwise_cd_var)
        

def _load_lpips_image(path, preprocess, device):
    """Open one image, apply ``preprocess`` and return it on ``device`` with a batch axis of 1."""
    with Image.open(path) as img:  # close the file handle once the pixels are read
        return preprocess(img).unsqueeze(0).to(device)


def _load_lpips_views(shape_dir, view_ids, preprocess, device, exclude=()):
    """Load the images at positions ``view_ids`` in the sorted listing of ``shape_dir``.

    Directory entries whose names appear in ``exclude`` are removed from the
    listing before it is indexed.
    """
    image_files = sorted(f for f in os.listdir(shape_dir) if f not in exclude)
    return [_load_lpips_image(os.path.join(shape_dir, image_files[v]), preprocess, device)
            for v in view_ids]


def lpips_eval(args):
    """Print intra-cluster and pairwise LPIPS diversity of rendered generated shapes.

    Expected layout (inferred from the paths used below):
    ``args.gen_images_path/img/<shape>/<images>`` for generated renders, and
    ``args.train_images_path/<shape>/<images>`` for training renders. The
    training directories also contain ``transforms.json`` and ``mask``
    entries, which are ignored. Views are addressed by their position in
    sorted file order.

    1. Every generated shape is matched to the training shape whose view
       ``match_view`` has the lowest LPIPS to the generated shape's view
       ``match_view``.
    2. Intra LPIPS: within each training shape's group (at most
       ``max_group_size`` generated shapes), every pair of shapes is scored by
       the mean LPIPS over ``intra_view_ids``. Each group contributes the mean
       of its pair scores.
    3. Pairwise LPIPS: the first ``n_pairwise_shapes`` generated shapes are
       compared pairwise, each pair scored by the mean LPIPS over
       ``pairwise_view_ids``.
    """
    device = 'cuda'
    match_view = 6                         # view used to match generated -> training shapes
    intra_view_ids = range(0, 12, 3)       # views 0, 3, 6, 9
    max_group_size = 50
    pairwise_view_ids = range(0, 24, 3)    # views 0, 3, ..., 21
    n_pairwise_shapes = 100
    ref_exclude = ('transforms.json', 'mask')

    gen_path = os.path.join(args.gen_images_path, 'img')
    data_path = args.train_images_path

    # All LPIPS work is inference only. Without no_grad every call would
    # build an autograd graph and hold its activations.
    with torch.no_grad():
        lpips_fn = lpips.LPIPS(net='vgg').to(device)
        preprocess = transforms.Compose([
            transforms.Resize([1024, 1024]),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        gen_files = sorted(os.listdir(gen_path))

        # Only the matching view of each training shape is ever compared, so
        # load just that RGB image. The previous channel-concatenated stack
        # made rgb_data[j, 6] a single color channel instead of an image.
        ref_views = [
            _load_lpips_views(os.path.join(data_path, ref_dir), [match_view],
                              preprocess, device, exclude=ref_exclude)[0]
            for ref_dir in tqdm(sorted(os.listdir(data_path)))
        ]

        # 1) Nearest training shape for every generated shape (first minimum wins).
        nearest_ref = []
        for gen_file in tqdm(gen_files):
            query = _load_lpips_views(os.path.join(gen_path, gen_file), [match_view],
                                      preprocess, device)[0]
            best_dist, best_idx = float('inf'), 0
            for j, ref_view in enumerate(ref_views):
                dist = lpips_fn(query, ref_view).item()
                if dist < best_dist:
                    best_dist, best_idx = dist, j
            nearest_ref.append(best_idx)

        # 2) Intra LPIPS inside each group of generated shapes.
        groups = {}
        for gen_idx, ref_idx in enumerate(nearest_ref):
            members = groups.setdefault(ref_idx, [])
            if len(members) < max_group_size:
                members.append(gen_idx)

        intra_lpips = []
        for ref_idx in sorted(groups):
            members = groups[ref_idx]
            if len(members) < 2:
                continue
            member_views = [
                _load_lpips_views(os.path.join(gen_path, gen_files[g]), intra_view_ids,
                                  preprocess, device)
                for g in members
            ]
            pair_scores = []
            for m in range(len(member_views)):
                for n in range(m + 1, len(member_views)):
                    per_view = [lpips_fn(a, b).item()
                                for a, b in zip(member_views[m], member_views[n])]
                    pair_scores.append(torch.tensor(per_view).mean().item())
            intra_lpips.append(torch.tensor(pair_scores).mean().item())

        intra_lpips = torch.tensor(intra_lpips)
        print("Intra LPIPS:", intra_lpips.mean())
        print("Intra LPIPS Std:", intra_lpips.std())

        # 3) Pairwise LPIPS over a fixed-size prefix of the generated shapes.
        pair_files = gen_files[:n_pairwise_shapes]
        pairwise_lpips = []
        for i in tqdm(range(len(pair_files))):
            # Load shape i's views once instead of once per partner j.
            views_i = _load_lpips_views(os.path.join(gen_path, pair_files[i]),
                                        pairwise_view_ids, preprocess, device)
            for j in tqdm(range(i + 1, len(pair_files))):
                views_j = _load_lpips_views(os.path.join(gen_path, pair_files[j]),
                                            pairwise_view_ids, preprocess, device)
                per_view = [lpips_fn(a, b).item() for a, b in zip(views_i, views_j)]
                pairwise_lpips.append(torch.tensor(per_view).mean().item())

        pairwise_lpips = torch.tensor(pairwise_lpips)
        print("Pairwise LPIPS:", pairwise_lpips.mean())
        print("Pairwise LPIPS Std", pairwise_lpips.std())
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_name", type=str, required=False, help="path to the save results")
    parser.add_argument("--gen_path", type=str, required=False, help="path to the generated shapes")
    parser.add_argument("--train_data_path", type=str, required=False, help="path to the original training data")
    parser.add_argument("--gen_images_path", type=str, required=False, help="path to rendered generated images")
    parser.add_argument("--train_images_path", type=str, required=False, help="path to rendered training images")
    parser.add_argument("--n_points", type=int, default=2048, help="Number of points used for evaluation")
    parser.add_argument("--batch_size", type=int, default=50, help="batch size to compute chamfer distance")
    parser.add_argument("--n_shape", type=int, default=1000, help="number of shapes for evaluations")
    parser.add_argument("--use_npz", type=bool, default=False, help="whether the generated shape is npz or not")
    parser.add_argument("--n_sample", type=int, default=2048, help="n_sample")
    parser.add_argument("--eval_cd", type=bool, default=False, help="eval chamfer distance or eval chamfer distance diversity")
    args = parser.parse_args()

    if args.eval_cd:
        evaluate(args)
    else:
        #evaluate(args)
        lpips_eval(args)
    