from models import models_ae
from models import modeling_vqvae

from timm.models import create_model

from models.noise2noise_dataset import Noise2NoiseDatasetv2
import torch
import wandb
from tqdm import tqdm
from models.misc import MetricLogger
from scipy.spatial import cKDTree as KDTree
from models.utils import validate_tensors_require_grad, matching_cube_batch, matching_cube_batch_aligned
import json
import os
from tempfile import NamedTemporaryFile
import argparse
import numpy as np
import mcubes
import trimesh
from scipy.stats import wasserstein_distance
import math
import point_cloud_utils as pcu
from plyfile import PlyData, PlyElement


def normal_consistency_mesh(P, Q, normals_P, normals_Q):
    """Average normal agreement between ``P`` and its nearest neighbours in ``Q``.

    For every point of ``P`` the nearest point of ``Q`` is found with a KD-tree,
    and the dot product of the two associated normals is taken, clipped to
    [-1, 1]. The mean of these values is returned. The sign is kept, so
    opposite normals contribute -1. Normals are presumably unit length (TODO
    confirm at call sites); the clipping only guards the range of the result.

    Args:
        P: (n, d) array of query points.
        Q: (m, d) array of reference points.
        normals_P: (n, d) normals associated with ``P``.
        normals_Q: (m, d) normals associated with ``Q``.

    Returns:
        The mean clipped dot product as a float, or 0.0 when ``P`` is empty.
    """
    P = np.asarray(P)
    if len(P) == 0:
        return 0.0

    # One vectorised KD-tree query instead of a Python-level loop per point.
    _, nearest = KDTree(Q).query(P)
    paired_normals_Q = np.asarray(normals_Q)[nearest]
    dots = np.sum(np.asarray(normals_P) * paired_normals_Q, axis=-1)
    return float(np.mean(np.clip(dots, -1.0, 1.0)))


def earth_movers_distance(cloud1, cloud2):
    """Approximate the EMD between two point clouds by per-axis 1-D distances.

    Each of the first three coordinate columns (x, y, z) is compared
    independently with ``scipy.stats.wasserstein_distance``, and the three
    distances are averaged. Correlations between coordinates are therefore
    ignored.

    Args:
        cloud1: (n, >=3) array of points.
        cloud2: (m, >=3) array of points.

    Returns:
        The mean of the three per-axis Wasserstein distances.
    """
    per_axis = [
        wasserstein_distance(cloud1[:, axis], cloud2[:, axis])
        for axis in range(3)
    ]
    return sum(per_axis) / 3

def save_points_as_ply(points: np.ndarray, save_path: str) -> None:
    """Write the x, y, z coordinates of ``points`` to an ASCII PLY file.

    Args:
        points: array whose first three columns are x, y and z; any further
            columns are ignored. Values are stored as 32-bit floats.
        save_path: destination file path.
    """
    axis_names = ('x', 'y', 'z')
    vertex_dtype = [(name, 'f4') for name in axis_names]

    vertex_records = np.zeros(points.shape[0], dtype=vertex_dtype)
    for column, name in enumerate(axis_names):
        vertex_records[name] = points[:, column]

    vertex_element = PlyElement.describe(vertex_records, 'vertex')
    PlyData([vertex_element], text=True).write(save_path)

def evaluate(config):
    """Evaluate the ``models_ae.ae_d512_m512`` autoencoder on the test split.

    For every test shape this:
      1. computes the IoU of the predicted occupancy (logit > 0) on the
         dataset query points against ``labels``;
      2. extracts a mesh by marching cubes on a (density + 1)^3 grid spanning
         [-1, 1]^3 (following the 3DShape2VecSet evaluation logic);
      3. computes the Chamfer distance and the F-score (threshold 0.02)
         between 100k points sampled on that mesh and the dataset surface;
      4. computes the normal consistency between the normalised gradient of
         the predicted logits at the surface points and the dataset normals
         (absolute cosine, so normal orientation is ignored).

    The averaged metrics are printed. ``normal_consistency_mesh`` and ``emd``
    are not computed here and are always reported as 0.

    Args:
        config: dict providing ``model_path``, ``device``, ``dataset_config``
            (keyword arguments for ``Noise2NoiseDatasetv2``),
            ``use_gt_points`` and ``run_name``.

    Returns:
        list of dicts with keys ``'index'``, ``'mesh'`` (``trimesh.Trimesh``)
        and ``'model_name'``, one per test shape.

    Raises:
        RuntimeError: if the model output is not differentiable w.r.t. the
            surface points, so that normals cannot be computed.
    """
    device = config['device']

    # prepare the model (loaded straight onto the configured device)
    model = models_ae.ae_d512_m512()
    model.load_state_dict(torch.load(config['model_path'], map_location=device)['model'])
    model.to(device)
    model.eval()

    # prepare the dataset
    dataset = Noise2NoiseDatasetv2(**config['dataset_config'])
    metric_logger = MetricLogger()
    test_length = len(dataset.test_model_list)

    # The marching-cubes grid is the same for every shape, so build it once.
    density = 128
    gap = 2. / density
    axis = np.linspace(-1, 1, density + 1)
    xv, yv, zv = np.meshgrid(axis, axis, axis)
    grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(
        np.float32)).view(3, -1).transpose(0, 1)[None].to(device)

    fscore_threshold = 0.02
    num_mesh_samples = 100000
    mesh_results = []

    for i in tqdm(range(test_length)):
        # NOTE: near-surface points were used on ABC; ShapeNet uses volume points
        gt_points, noise1_points, queries, surface, labels, normals, model_name = dataset.get_eval_item(
            i, use_train_data=False, use_one_shape_only=False, use_test_data=True,
            use_vol_points=True, get_name=True)

        queries = queries.unsqueeze(0).to(device)
        labels = labels.to(device)

        point_cloud = gt_points if config['use_gt_points'] else noise1_points
        point_cloud = point_cloud.unsqueeze(0).to(device)

        with torch.no_grad():
            # 1. IoU of the predicted occupancy on the query points
            noise_logits = model(point_cloud, queries)['logits']
            noise_mask = (noise_logits > 0.0).float().squeeze(0)
            noise_intersection = (noise_mask * labels).sum()
            noise_union = (noise_mask + labels).gt(0).sum()
            # Plain float, like the other metrics, so no GPU tensor is kept alive.
            noise_iou = (noise_intersection * 1.0 / noise_union).item()

            # Mesh extraction: evaluate the field on the grid
            output = model(point_cloud, grid)['logits']
            volume = output.view(density + 1, density + 1, density + 1).permute(1, 0, 2).cpu().numpy()

        verts, faces = mcubes.marching_cubes(volume, 0)
        verts *= gap
        verts -= 1.
        mesh = trimesh.Trimesh(verts, faces)

        mesh_results.append({
            'index': i,
            'mesh': mesh,
            'model_name': model_name
        })

        # 2. Chamfer distance between mesh samples and the GT surface points
        pred = np.asarray(mesh.sample(num_mesh_samples))
        gt_surface = surface.cpu().numpy()
        d_gt_to_pred, _ = KDTree(pred).query(gt_surface)
        d_pred_to_gt, _ = KDTree(gt_surface).query(pred)
        cd = np.mean(d_gt_to_pred) + np.mean(d_pred_to_gt)

        # 3. F-score at the fixed distance threshold (vectorised counting)
        if len(d_gt_to_pred) and len(d_pred_to_gt):
            precision = float(np.mean(d_pred_to_gt < fscore_threshold))
            recall = float(np.mean(d_gt_to_pred < fscore_threshold))
            if recall + precision > 0:
                fscore = 2 * recall * precision / (recall + precision)
            else:
                fscore = 0.0
        else:
            fscore = 0.0

        # Placeholders: mesh-based normal consistency and EMD are not computed.
        nc_mesh = 0
        emd = 0

        # 4. normal consistency of the logit gradient at the surface points
        with torch.enable_grad():
            compare_points = surface.clone().requires_grad_(True).unsqueeze(0).to(device)
            validate_tensors_require_grad(compare_points)

            compare_logits = model(point_cloud, compare_points)['logits']
            if not compare_logits.requires_grad:
                # A detached output cannot be differentiated w.r.t. the inputs.
                raise RuntimeError(
                    "model logits do not require grad; cannot compute surface normals")

            compare_grad = torch.autograd.grad(
                compare_logits.sum(),
                compare_points,
                create_graph=False,
                retain_graph=False
            )[0]

            compare_normals = compare_grad / (compare_grad.norm(dim=-1, keepdim=True) + 1e-6)
            # abs(): orientation of the GT normals is not assumed to match
            normal_consistency = torch.abs(torch.sum(compare_normals.cpu() * normals, dim=-1)).mean().item()

        del noise_logits, output, compare_logits, compare_grad, compare_normals
        torch.cuda.empty_cache()
        metric_logger.update_metrics(metrics={
            'iou': noise_iou,
            'chamfer': cd,
            'f1': fscore,
            'normal_consistency': normal_consistency,
            'normal_consistency_mesh': nc_mesh,
            'emd': emd,
        })

    result = metric_logger.average()

    print("run_name: ", config['run_name'])
    print(f"IoU: {result['iou']:.4f}")
    print(f"Chamfer: {result['chamfer']:.4f}")
    print(f"F1: {result['f1']:.4f}")
    print(f"Normal Consistency: {result['normal_consistency']:.4f}")
    print(f"Normal Consistency Mesh: {result['normal_consistency_mesh']:.4f}")
    print(f"EMD: {result['emd']:.4f}")
    return mesh_results



def evaluate_3DILG(config):
    """Evaluate the 3DILG ``vqvae_512_1024_2048`` model on the test split.

    For every test shape this encodes the input point cloud once and then:
      1. computes the IoU of the decoded occupancy (logit > 0) on the
         dataset query points against ``labels``;
      2. extracts a mesh by marching cubes on a (density + 1)^3 grid spanning
         [-1, 1]^3;
      3. computes the Chamfer distance and the F-score (threshold 0.02)
         between 100k points sampled on that mesh and the dataset surface;
      4. computes the normal consistency between the normalised gradient of
         the decoded logits at the surface points and the dataset normals
         (absolute cosine).

    All decoder calls are chunked to ``chunk_size`` points to bound memory.
    ``emd`` is not computed and is always reported as 0.

    Args:
        config: dict providing ``model_path``, ``device``, ``dataset_config``
            (keyword arguments for ``Noise2NoiseDatasetv2``),
            ``use_gt_points`` and ``run_name``.

    Returns:
        list of dicts with keys ``'index'``, ``'mesh'`` (``trimesh.Trimesh``)
        and ``'model_name'``, one per test shape.

    Raises:
        RuntimeError: if the decoded logits are not differentiable w.r.t. the
            surface points, so that normals cannot be computed.
    """
    device = config['device']
    chunk_size = 50000

    # prepare the model
    model = create_model(
        "vqvae_512_1024_2048",
        pretrained=False,
    )

    # load the model straight onto the configured device
    model.load_state_dict(torch.load(config['model_path'], map_location=device)['model'])
    model.to(device)
    model.eval()

    def decode_in_chunks(latents, centers, points):
        """Decode logits for ``points`` (batch, P, 3) chunk by chunk, concatenated on dim 1.

        The number of chunks is derived from ``points`` itself, so every
        point is decoded exactly once whatever the size of the input.
        """
        num_chunks = max(1, math.ceil(points.shape[1] / chunk_size))
        return torch.cat(
            [model.decoder(latents, centers, points[:, k * chunk_size:(k + 1) * chunk_size])[0]
             for k in range(num_chunks)],
            dim=1,
        )

    # prepare the dataset
    dataset = Noise2NoiseDatasetv2(
        **config['dataset_config'],
    )
    metric_logger = MetricLogger()
    test_length = len(dataset.test_model_list)

    # The marching-cubes grid is the same for every shape, so build it once
    # (following the 3DShape2VecSet evaluation logic).
    density = 128
    gap = 2. / density
    axis = np.linspace(-1, 1, density + 1)
    xv, yv, zv = np.meshgrid(axis, axis, axis)
    grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(
        np.float32)).view(3, -1).transpose(0, 1)[None].to(device)

    fscore_threshold = 0.02
    num_mesh_samples = 100000
    mesh_results = []

    for i in range(test_length):
        if i % 100 == 0:
            print(f"evaluating {i} / {test_length}")

        # NOTE: near-surface points were used on ABC; ShapeNet uses volume points
        gt_points, noise1_points, queries, surface, labels, normals, model_name = dataset.get_eval_item(
            i, use_train_data=False, use_one_shape_only=False, use_test_data=True,
            use_vol_points=True, get_name=True)

        queries = queries.unsqueeze(0).to(device)
        labels = labels.to(device)

        point_cloud = gt_points if config['use_gt_points'] else noise1_points
        point_cloud = point_cloud.unsqueeze(0).to(device)

        with torch.no_grad():
            _, latents, centers_quantized, _, _, _ = model.encode(point_cloud)
            # map quantized centers (presumably integers in [0, 255]) back to [-1, 1]
            centers = centers_quantized.float() / 255.0 * 2 - 1

            # 1. IoU on the query points (chunks sized by the queries, not the grid)
            noise_logits = decode_in_chunks(latents, centers, queries)
            noise_mask = (noise_logits > 0.0).float().squeeze(0)
            noise_intersection = (noise_mask * labels).sum()
            noise_union = (noise_mask + labels).gt(0).sum()
            # Plain float, like the other metrics, so no GPU tensor is kept alive.
            noise_iou = (noise_intersection * 1.0 / noise_union).item()

            # Mesh extraction: decode the field on the grid
            output = decode_in_chunks(latents, centers, grid)
            volume = output.view(density + 1, density + 1, density + 1).permute(1, 0, 2).cpu().numpy()

        verts, faces = mcubes.marching_cubes(volume, 0)
        verts *= gap
        verts -= 1.
        mesh = trimesh.Trimesh(verts, faces)

        mesh_results.append({
            'index': i,
            'mesh': mesh,
            'model_name': model_name
        })

        # 2. Chamfer distance between mesh samples and the GT surface points
        pred = np.asarray(mesh.sample(num_mesh_samples))
        gt_surface = surface.cpu().numpy()
        d_gt_to_pred, _ = KDTree(pred).query(gt_surface)
        d_pred_to_gt, _ = KDTree(gt_surface).query(pred)
        cd = np.mean(d_gt_to_pred) + np.mean(d_pred_to_gt)

        # 3. F-score at the fixed distance threshold (vectorised counting)
        if len(d_gt_to_pred) and len(d_pred_to_gt):
            precision = float(np.mean(d_pred_to_gt < fscore_threshold))
            recall = float(np.mean(d_gt_to_pred < fscore_threshold))
            if recall + precision > 0:
                fscore = 2 * recall * precision / (recall + precision)
            else:
                fscore = 0.0
        else:
            fscore = 0.0

        # Placeholder: EMD is not computed.
        emd = 0

        # 4. normal consistency of the logit gradient at the surface points
        with torch.enable_grad():
            compare_points = surface.clone().requires_grad_(True).unsqueeze(0).to(device)
            validate_tensors_require_grad(compare_points)

            compare_logits = decode_in_chunks(latents, centers, compare_points)
            if not compare_logits.requires_grad:
                # A detached output cannot be differentiated w.r.t. the inputs.
                raise RuntimeError(
                    "decoder logits do not require grad; cannot compute surface normals")

            compare_grad = torch.autograd.grad(
                compare_logits.sum(),
                compare_points,
                create_graph=False,
                retain_graph=False
            )[0]

            compare_normals = compare_grad / (compare_grad.norm(dim=-1, keepdim=True) + 1e-6)
            # abs(): orientation of the GT normals is not assumed to match
            normal_consistency = torch.abs(torch.sum(compare_normals.cpu() * normals, dim=-1)).mean().item()

        del noise_logits, output, compare_logits, compare_grad, compare_normals
        torch.cuda.empty_cache()
        metric_logger.update_metrics(metrics={
            'iou': noise_iou,
            'chamfer': cd,
            'f1': fscore,
            'normal_consistency': normal_consistency,
            'emd': emd,
        })

    result = metric_logger.average()

    print("run_name: ", config['run_name'])
    print(f"IoU: {result['iou']:.4f}")
    print(f"Chamfer: {result['chamfer']:.4f}")
    print(f"F1: {result['f1']:.4f}")
    print(f"Normal Consistency: {result['normal_consistency']:.4f}")
    print(f"EMD: {result['emd']:.4f}")
    return mesh_results



if __name__ == "__main__":
    # Example usage:
    #   python evaluate.py --config config/test_config/test_orginal_noise.json --model ae -c 02691156
    #   python evaluate.py --config config/test_config/test_orginal_noise_3DL.json --model 3DILG -c 02691156

    def _str2bool(value):
        """Parse a command-line boolean (``type=bool`` would turn 'False' into True)."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/test_config/test_orginal_noise.json')
    parser.add_argument('-c','--category', nargs='+', help='<Required> Set flag', required=True)
    # Accepts a bare `--save_local` as well as `--save_local True/False`.
    parser.add_argument('--save_local', type=_str2bool, nargs='?', const=True, default=False,
                        help='Save meshes locally')
    parser.add_argument('--model', type=str, default='ae', choices=['ae', '3DILG'], help='<Required> Set flag', required=True)
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = json.load(f)

    if args.category:
        config['dataset_config']['categories'] = args.category
        # NOTE: categories are appended to run_name without a separator
        config['run_name'] = config['run_name'] +  '_'.join(args.category)

    wandb.init(project="sdf2sdf_evaluation_full", name=config['run_name'])

    mesh_result_path = None
    if args.save_local:
        mesh_result_path = os.path.join(config['result_path'], config['run_name'])
        os.makedirs(mesh_result_path, exist_ok=True)

    if args.model == 'ae':
        print("evaluate 3dshape2vec model")
        print(config)
        mesh_results = evaluate(config)
    else:  # '3DILG' — argparse restricts --model to these two choices
        print("evaluate 3DILG model")
        mesh_results = evaluate_3DILG(config)

    # wandb.Object3D is given a file path, so each mesh goes through a temp .obj
    for meshObject in tqdm(mesh_results, total=len(mesh_results), desc="save meshes"):
        mesh = meshObject['mesh']
        with NamedTemporaryFile(suffix='.obj', delete=True) as tmpfile:
            mesh.export(tmpfile.name)
            wandb.log({
                f"mesh/{meshObject['model_name']}_{config['run_name']}": wandb.Object3D(tmpfile.name)
            })

        if mesh_result_path is not None:
            mesh.export(os.path.join(mesh_result_path, f"{meshObject['model_name']}.ply"))
