import jax

import jax.numpy as jnp
import pickle
import numpy as np
from models import Siren
import jax.random as random

def loadState(path):
    """Return the object pickled in the file ``path + "_model"``.

    ``path`` is used as a filename prefix; the ``"_model"`` suffix is appended
    by string concatenation.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file, so
    only point this at checkpoints you produced yourself or otherwise trust.
    """
    checkpoint_file = path + "_model"
    with open(checkpoint_file, 'rb') as handle:
        return pickle.load(handle)


import flax
import optax
from jax import vmap

# The seed only drives the initialisation below, whose parameters are
# overwritten by the checkpoint loaded right after. dtype=np.int64 keeps the
# 2**32 upper bound legal where NumPy's default integer is 32-bit (e.g. Windows
# with NumPy < 2, which raises "high is out of bounds for int32"), and int()
# hands JAX a plain Python int.
seed = int(np.random.randint(2**32, dtype=np.int64))

layers = 5
flax.config.update('flax_return_frozendict', True)
key = random.PRNGKey(seed)
print("Random initial seed:", seed)

# Random 3-D input used only to initialise the module; its values do not
# affect the weights that end up being used.
x = random.normal(key, shape=(3,))
mlp = Siren(num_layers=layers, output_dim=1, w0=30, w0_first_layer=30, use_bias=True)
params = mlp.init(key, x)
params = params.unfreeze()['params']

# Overwrite the fresh initialisation with the parameter tree pickled at
# path + "_model".
path = './savings'
params = loadState(path)


def func_mlp_(params, x):
    """Evaluate the SIREN at ``x`` with an explicitly supplied parameter tree."""
    return mlp.apply({'params': params}, x)


def func_mlp(x):
    """Evaluate the SIREN at ``x`` using the module-level ``params``.

    ``params`` is looked up at call time, not definition time. This function is
    defined after ``loadState`` so that reading order matches what it actually
    evaluates: the loaded checkpoint.
    """
    return func_mlp_(params, x)

res = 512

# Regular res**3 grid over [-1, 1]^3. np.meshgrid's default 'xy' indexing puts
# y on axis 0 and x on axis 1 of the stacked grid (and therefore of `sdf`).
line = np.linspace(-1, 1, res)
samples = jnp.array(np.stack(np.meshgrid(line, line, line), -1).reshape(-1, 3))

# Evaluate the network in n_chunks batches to bound peak memory. The step is
# rounded up and the loop strides over the whole array, so every sample is
# covered even when res**3 is not a multiple of n_chunks. A floor-divided step
# would silently drop the tail and break the reshape below (e.g. for res=100).
n_chunks = 128
n_samples = samples.shape[0]
step = -(-n_samples // n_chunks)  # ceil division

# Compile the batched evaluator once. The parameters are a traced argument
# rather than a closure, so they are not embedded in the program as constants.
# At most two shapes are compiled: full chunks plus a possibly shorter last one.
batched_sdf = jax.jit(vmap(func_mlp_, in_axes=(None, 0)))

sdf = [batched_sdf(params, samples[start:start + step])
       for start in range(0, n_samples, step)]
sdf = jnp.concatenate(sdf, axis=0).reshape(res, res, res)


from skimage import measure


# Triangulate the zero level set of the sampled SDF. With the default spacing,
# skimage reports `verts` and `normals` in voxel-index coordinates that follow
# the array axis order of `sdf`.
verts, faces, normals, values = measure.marching_cubes(np.array(sdf), level=0)



import polyscope as ps

import os
import igl
def _normalize_to_unit_box(points):
    """Center ``points`` on their bounding-box midpoint and scale uniformly.

    The largest bounding-box half-extent is mapped to 1, so the result fits in
    [-1, 1] along every axis.

    Returns ``(normalized_points, center, scale)``.
    """
    max_p = np.max(points, axis=0)
    min_p = np.min(points, axis=0)
    center = (max_p + min_p) / 2
    scale = np.max(max_p - center)
    return (points - center) / scale, center, scale


# Reference mesh, normalized the same way as the reconstruction. It is only
# used by the commented-out comparison view at the bottom of the script.
model_path = './Armadillo.ply'
V, F = igl.read_triangle_mesh(model_path)
V, _, _ = _normalize_to_unit_box(V)

# Marching-cubes output is in voxel-index units; bring it into the same frame.
verts, center, scale = _normalize_to_unit_box(verts)

# Permutation that swaps the first two coordinates. It undoes the default 'xy'
# indexing of np.meshgrid used to build the sample grid (volume axis 0 holds y,
# axis 1 holds x). verts @ R.T is the row-wise form of (R @ verts.T).T.
R = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
verts = verts @ R.T
# The normals share the vertices' axis order, so they need the same swap or the
# exported normals no longer correspond to the exported vertices.
# (Translation and uniform scaling do not change normal directions.)
normals = normals @ R.T

# import open3d as o3d
# samples = np.load('./samplesd_.npy')[:6000000]
# samples_normals = np.load('./samplesd_normals.npy')[:6000000]
# pcd = o3d.geometry.PointCloud()
# pcd.points = o3d.utility.Vector3dVector(samples)
# pcd.normals = o3d.utility.Vector3dVector(samples_normals)




# with o3d.utility.VerbosityContextManager(
#         o3d.utility.VerbosityLevel.Debug) as cm:
#     mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
#         pcd, depth=9)

# v_p = np.asarray(mesh.vertices)
# f_p = np.asarray(mesh.triangles)

# Split the mesh into vertex-connected components. C holds one component label
# per vertex (n and K are presumably the component count and per-component
# sizes - confirm against the libigl docs).
a = igl.adjacency_matrix(faces)
n, C, K = igl.connected_components(a)

# Label of the component to export. It is hard-coded for this reconstruction.
# NOTE(review): component numbering depends on the mesh and is not guaranteed
# stable across runs/resolutions; np.argmax(K) (largest component) may be a
# more robust choice - confirm before changing.
keep_component = 2
keep_vertex = C == keep_component

# Keep faces touching the chosen component. A face's vertices are mutually
# adjacent and therefore share a label, so these faces lie entirely inside it.
faces = faces[keep_vertex[faces].any(axis=1)]

# Old vertex index -> position among the kept vertices, so faces index into the
# compacted vertex array saved below.
k = np.zeros(verts.shape[0], dtype=np.int64)
k[keep_vertex] = np.arange(np.count_nonzero(keep_vertex), dtype=np.int64)
faces_p3 = k[faces]

np.save('ad_verts.npy', verts[keep_vertex])
np.save('ad_faces.npy', faces_p3)
np.save('ad_normals.npy', normals[keep_vertex])

# print(faces_p3)
# Display the extracted component (vertices labelled 2 and their re-indexed
# faces) in an interactive polyscope window.
ps.init()
inference_verts = verts[C == 2]
ps.register_surface_mesh("inference", inference_verts, faces_p3)
# Alternative views, disabled:
# ps.register_point_cloud("inference", inference_verts)
# ps.register_surface_mesh("gt_2", V, F)
# ps.register_surface_mesh("gt_p", v_p, f_p)
ps.show()
