'''
Functions for computing the metrics reported in the paper: PSNR, LPIPS, SSIM, FID, and Chamfer distance (CD).
'''
import numpy as np
import subprocess
import lpips
from torchmetrics.functional import structural_similarity_index_measure as ssim
from pytorch3d.loss import chamfer_distance
import open3d as o3d
import torch

def psnr(pred, truth, data_range=1.0):
    """Compute the peak signal-to-noise ratio (in dB) between two images.

    Args:
        pred: predicted image (array-like supporting ``-`` and ``**``).
        truth: ground-truth image with the same shape as ``pred``.
        data_range: maximum possible pixel value. The default of 1.0 matches
            images normalized to [0, 1]; pass 255.0 for 8-bit images.

    Returns:
        10 * log10(data_range**2 / MSE). Identical images (MSE == 0) give +inf.
    """
    mse = np.mean((pred - truth) ** 2)
    # Written as 20*log10(R) - 10*log10(MSE) so that the default R = 1 reduces
    # exactly to -10*log10(MSE). errstate silences numpy's divide-by-zero
    # warning for identical images, where the result is +inf.
    with np.errstate(divide='ignore'):
        return 20 * np.log10(data_range) - 10 * np.log10(mse)

def structure_sim(norm_img, norm_pred_img, data_range=None):
    """Compute the SSIM between ground-truth and predicted images, per image.

    Args:
        norm_img: ground-truth image tensor. Presumably (N, C, H, W) with
            values in [0, 1]; TODO confirm with callers.
        norm_pred_img: predicted image tensor with the same shape as ``norm_img``.
        data_range: value range of the inputs. ``None`` lets torchmetrics
            infer it from the data; pass ``1.0`` for images normalized to [0, 1].

    Returns:
        A tensor of SSIM values, one per image (no reduction over the batch).
    """
    # torchmetrics has no ``size_average`` argument (that belongs to the
    # pytorch_msssim API) and would raise TypeError; ``reduction='none'`` is
    # its equivalent of "do not average over the batch".
    return ssim(
        preds=norm_pred_img,
        target=norm_img,
        data_range=data_range,
        reduction='none',
    )


def clpips(fn, truth, pred, device='cpu'):
    """Compute the LPIPS perceptual distance between a ground-truth and a predicted image.

    Args:
        fn: an LPIPS network (a ``torch.nn.Module``) to reuse across calls.
            Any other value (e.g. ``None``) makes this function build a fresh
            ``lpips.LPIPS(net='alex')``. That loads pretrained weights and is
            slow, so pass a model when scoring many images.
        truth: ground-truth image tensor. Because ``normalize=True`` is
            passed, LPIPS treats the inputs as lying in [0, 1].
        pred: predicted image tensor with the same shape as ``truth``.
        device: device that the network and both images are moved to.

    Returns:
        The LPIPS distance as a Python float. ``.item()`` is used, so the
        network output must contain a single element.
    """
    if not isinstance(fn, torch.nn.Module):
        fn = lpips.LPIPS(net='alex')
    # eval() disables dropout so the score is deterministic. Note this also
    # sets a caller-supplied model to eval mode and moves it to ``device``.
    fn = fn.eval().to(device)
    # A metric needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        dist = fn(truth.to(device), pred.to(device), normalize=True)
    return dist.item()


def fid(truth_dir, pred_dir, device='cuda:0'):
    """Compute the FID between two image directories using ``pytorch_fid``.

    Runs ``python -m pytorch_fid`` in a subprocess and prints the score. The
    score is also returned as a float.

    Args:
        truth_dir: directory of ground-truth images.
        pred_dir: directory of predicted images.
        device: torch device string forwarded to ``pytorch_fid --device``.

    Returns:
        The FID score as a float.

    Raises:
        RuntimeError: if the subprocess exits with an error or its output
            contains no ``FID:`` value.
    """
    import re
    import sys

    # Pass an argument list (no shell) so directory names containing spaces or
    # shell metacharacters are used verbatim. sys.executable ensures the
    # pytorch_fid of the current environment is the one that runs.
    cmd = [
        sys.executable, '-m', 'pytorch_fid',
        str(pred_dir), str(truth_dir),
        '--device', str(device),
    ]
    result = subprocess.run(cmd, stdout=subprocess.PIPE, check=False)
    output = result.stdout.decode()

    # pytorch_fid reports its result on a line of the form "FID:  <value>".
    match = re.search(r'FID:\s*(\S+)', output)
    if result.returncode != 0 or match is None:
        raise RuntimeError(
            'pytorch_fid failed (exit code {}); stdout was: {!r}'.format(
                result.returncode, output))

    fid_text = match.group(1)
    print(fid_text)
    return float(fid_text)


def _centered_surface_points(mesh_path, num_points, device):
    """Load a mesh, Poisson-disk sample its surface and center the samples.

    Returns a (1, P, 3) tensor on ``device``, where P is the number of points
    open3d returned (requested: ``num_points``). The mean point is the origin.
    """
    mesh = o3d.io.read_triangle_mesh(mesh_path)
    mesh.compute_vertex_normals()
    pcd = mesh.sample_points_poisson_disk(number_of_points=num_points)
    points = torch.from_numpy(np.asarray(pcd.points)).to(device).unsqueeze(0)
    # Subtract the centroid so a pure translation between meshes does not
    # contribute to the distance.
    return points - points.mean(dim=-2, keepdim=True)


def cd(pred_mesh_path, gt_mesh_path, num_points=4096, device=None):
    """Compute the Chamfer distance between predicted and ground-truth meshes.

    Both meshes are sampled with Poisson-disk sampling and each point cloud
    is centered on its own centroid. The clouds are then compared with
    pytorch3d's ``chamfer_distance``.

    Args:
        pred_mesh_path: path to the predicted mesh file.
        gt_mesh_path: path to the ground-truth mesh file.
        num_points: number of surface points to sample from each mesh.
        device: torch device for the computation. ``None`` picks ``'cuda'``
            when available, otherwise ``'cpu'``.

    Returns:
        The Chamfer distance as a Python float.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    gt_points = _centered_surface_points(gt_mesh_path, num_points, device)
    pred_points = _centered_surface_points(pred_mesh_path, num_points, device)

    # chamfer_distance returns (distance, normal_distance); only the distance is used.
    chamfer_loss, _ = chamfer_distance(gt_points, pred_points)
    return chamfer_loss.item()