import jax.numpy as jnp
from jax import vmap, jit, value_and_grad
import numpy as np
import igl
import os
from joblib import Parallel, delayed
import multiprocessing
import polyscope as ps


@jit
def normalize(x):
    """Return ``x`` rescaled to (almost) unit Euclidean length along its last axis.

    A small constant (1e-6) is added to the norm before dividing, so zero
    vectors map to zero instead of NaN and non-zero vectors come out very
    slightly shorter than unit length.

    Args:
        x: array whose last axis holds the vector components.

    Returns:
        An array with the same shape as ``x``.
    """
    magnitude = jnp.linalg.norm(x, axis=-1, keepdims=True)
    safe_magnitude = magnitude + 1e-6
    return x / safe_magnitude


class ShapeSDFSampler:
    """Draw point samples and signed distances from a normalized triangle mesh.

    The mesh is loaded and normalized once in ``__init__``; each ``sample_*``
    method then draws points (on the surface, near it, or in free space) and
    evaluates their signed distance to the mesh with libigl.
    """

    def __init__(self, surface_ratio=0.6, close_ratio=0.3, sigma=1e-4,
                 model_path='./lucy.ply'):
        """Load the mesh at ``model_path`` and normalize it.

        Args:
            surface_ratio: fraction of samples drawn on the mesh surface.
            close_ratio: fraction of samples drawn near the surface (used only
                by ``sample_importance``).
            sigma: scale of the Gaussian noise added to surface and
                near-surface samples in ``sample_importance`` (5e-2 is noted
                in the source as an alternative value).
            model_path: triangle mesh file to load; defaults to the previously
                hard-coded ``./lucy.ply``.
        """
        V, F = igl.read_triangle_mesh(model_path)

        # Swap the y and z axes.
        # NOTE(review): det(R) == -1, so this is a reflection, not a rotation.
        # F is not re-wound, so each triangle's orientation is reversed
        # relative to its geometry: if the file's face normals pointed
        # outward, igl.per_face_normals now yields inward ones (and vice
        # versa). Confirm this is intended.
        R = np.array([[1, 0, 0], [0, 0, 1], [0, 1, 0]])
        V = (R @ V.T).T

        # Center on the bounding-box midpoint and divide by the largest
        # half-extent, so the longest axis spans exactly [-1, 1].
        max_corner = np.max(V, axis=0)
        min_corner = np.min(V, axis=0)
        center = (max_corner + min_corner) / 2
        scale = np.max(max_corner - center)
        self.V = (V - center) / scale

        self.F = F
        self.surface_ratio = surface_ratio
        self.close_ratio = close_ratio
        self.sigma = sigma

    @staticmethod
    def _num_chunks(n_samples, chunk_size=100000):
        """Return how many pieces to split ``n_samples`` queries into (>= 1).

        ``np.array_split`` rejects 0 sections, which the bare
        ``n_samples // chunk_size`` produced for fewer than ``chunk_size``
        samples (e.g. ``res < 47`` in ``sample_dense``).
        """
        return max(1, n_samples // chunk_size)

    @staticmethod
    def _num_workers():
        """Return the joblib worker count: all CPUs but two, at least one.

        joblib raises on ``n_jobs=0``, which ``cpu_count() - 2`` gives on a
        2-CPU machine.
        """
        return max(1, multiprocessing.cpu_count() - 2)

    def _random_surface_points(self, n):
        """Draw ``n`` random points on the mesh surface.

        Returns:
            (points, f_id): one 3D point per row, and the index of the face
            each point was drawn on.
        """
        bary, f_id, _ = igl.random_points_on_mesh(n, self.V, self.F)
        points = np.sum(bary[..., None] * self.V[self.F[f_id]], 1)
        return points, f_id

    def _face_normals(self):
        """Return per-face normals; degenerate faces get normalize([1, 1, 1])."""
        degen_n = normalize(np.array([1., 1., 1.]))[None, ...]
        return igl.per_face_normals(self.V, self.F, np.float64(degen_n))

    def sample_sdf_igl(self, x):
        """Return the signed distance from each query point in ``x`` to the mesh."""
        return igl.signed_distance(x, self.V, self.F)[0]

    def sample_uniform(self, sample_size, bound=1.):
        """Sample points on the surface and uniformly in a cube, shuffled.

        ``surface_ratio`` of the points lie exactly on the mesh surface; the
        rest are uniform in ``[-bound, bound]^3``.

        Returns:
            samples: the shuffled sample points.
            sdf_vals: signed distance of each sample.
            normals: face normal for surface samples, zeros for free samples.
            samples_full: the samples before shuffling (surface points first).
        """
        n_total = int(sample_size)
        n_surface = int(n_total * self.surface_ratio)
        n_free = n_total - n_surface

        surface_samples, f_id = self._random_surface_points(n_surface)
        surface_normals = self._face_normals()[f_id]
        free_samples = np.random.uniform(low=-bound, high=bound,
                                         size=(n_free, 3))

        samples_full = np.vstack([surface_samples, free_samples])
        normals_full = np.zeros_like(samples_full)
        normals_full[:len(surface_samples)] = surface_normals

        order = np.random.permutation(n_total)
        samples = samples_full[order]
        normals = normals_full[order]

        sdf_vals, _, _ = igl.signed_distance(samples, self.V, self.F)
        return samples, np.array(sdf_vals), normals, samples_full

    def sample_importance(self, sample_size, multiplier=10., beta=1.5,
                          bound=0.5):
        """Importance-sample points, favouring those close to the surface.

        Draws ``int(sample_size * multiplier)`` candidates -- surface points
        jittered along their face normal, near-surface points with isotropic
        noise of scale ``2 * sigma``, and uniform points in
        ``[-bound, bound]^3`` -- then keeps ``sample_size`` of them without
        replacement, with probability proportional to
        ``exp(-beta * distance_to_mesh)``.

        NOTE(review): the mesh is normalized to [-1, 1] on its longest axis,
        but the default ``bound`` of 0.5 only covers the inner half of that
        range -- confirm this is intended.

        Returns:
            samples: the selected points.
            sdf_vals: signed distance of each selected point.
            normals: the (unperturbed) face normal for surface-derived
                points, zeros otherwise.
            mask: 1.0 for surface-derived points, 0.0 otherwise.
        """
        # Reference: https://github.com/nmwsharp/neural-implicit-queries/blob/c17e4b54f216cefb02d00ddba25c4f15b9873278/src/geometry.py#LL43C1-L43C1
        n_total = int(sample_size * multiplier)
        n_surface = int(n_total * self.surface_ratio)
        n_close = int(n_total * self.close_ratio)
        n_free = n_total - (n_surface + n_close)

        surface_samples, f_id = self._random_surface_points(n_surface)
        surface_normals = self._face_normals()[f_id]
        # Jitter each surface point along its own face normal.
        surface_samples += self.sigma * np.random.normal(
            size=(n_surface, 1)) * surface_normals

        close_samples, _ = self._random_surface_points(n_close)
        close_samples = close_samples + 2. * self.sigma * np.random.normal(
            size=(n_close, 3))

        free_samples = np.random.uniform(low=-bound, high=bound,
                                         size=(n_free, 3))

        samples_full = np.vstack([surface_samples, close_samples, free_samples])
        n_surf_rows = len(surface_samples)
        normals_full = np.zeros_like(samples_full)
        normals_full[:n_surf_rows] = surface_normals
        mask = np.zeros_like(samples_full[:, 0])
        mask[:n_surf_rows] = 1

        dist_sq, _, _ = igl.point_mesh_squared_distance(samples_full, self.V,
                                                        self.F)
        weight = np.exp(-beta * np.sqrt(dist_sq))
        weight = weight / np.sum(weight)

        # An int population behaves exactly like np.arange(n_total); size must
        # be an int, so cast in case a float like 1e7 was passed in.
        chosen = np.random.choice(n_total, size=int(sample_size), p=weight,
                                  replace=False)
        samples = samples_full[chosen]
        normals = normals_full[chosen]
        mask = mask[chosen]

        sdf_vals, _, _ = igl.signed_distance(samples, self.V, self.F)
        return samples, np.array(sdf_vals), normals, mask

    def sample_surface(self, sample_size, offset=1e-3):
        """Sample surface points and copies shifted against their face normal.

        Returns:
            (points, shifted): random surface points, and the same points
            moved by ``-offset`` along their face normal.
        """
        points, f_id = self._random_surface_points(sample_size)
        face_normals = self._face_normals()
        return points, points - offset * face_normals[f_id]

    def sample_dense(self, res=512, bound=0.5):
        """Evaluate the SDF on a regular ``res^3`` grid over ``[-bound, bound]^3``.

        Queries are split into chunks of roughly 100000 points and evaluated
        in parallel with joblib's multiprocessing backend.

        NOTE(review): the default ``bound`` of 0.5 covers only the inner half
        of the [-1, 1] range the mesh is normalized to -- confirm.

        Returns:
            samples: the ``res**3`` grid points (from ``np.meshgrid`` with its
                default 'xy' indexing), one per row.
            sdf_vals: signed distance of each grid point.
        """
        line = np.linspace(-bound, bound, res)
        samples = np.stack(np.meshgrid(line, line, line), -1).reshape(-1, 3)

        chunks = np.array_split(samples, self._num_chunks(len(samples)), axis=0)
        sdf_parts = Parallel(n_jobs=self._num_workers(),
                             backend='multiprocessing')(
                                 delayed(self.sample_sdf_igl)(chunk)
                                 for chunk in chunks)

        return samples, np.concatenate(sdf_parts)

if __name__ == '__main__':
    # Draw 10M uniformly distributed samples from the normalized mesh and save
    # the shuffled points, their signed distances and their normals as .npy
    # files in the current directory.
    # sampler.sample_importance(...) is an alternative strategy; it returns a
    # surface mask as its fourth output instead of the unshuffled samples.
    sampler = ShapeSDFSampler()
    samples, sample_sdfs, normals, _ = sampler.sample_uniform(10000000)
    np.save('./samplesd_lucy_.npy', samples)
    np.save('./samplesd_lucy_sdfs_.npy', sample_sdfs)
    np.save('./samplesd_lucy_normals.npy', normals)