"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import jax.numpy as jnp 
import jax.random as random 
import healpy as hp
import numpy as np
import jax
import igl
import trimesh
from jax import vmap,jacfwd
# Sampler classes that draw space-time points from
# 1. the interior of the domain,
# 2. the boundary, and
# 3. the initial-condition interval.
# Item 3 exists because we fuzz the initial condition: points are sampled at
# small non-zero times in [0, delta) instead of exactly at t = 0.
# In practice this seems to improve stability.


#base class for samplers (to be extended by subclasses for specific domains)
class Sampler(object):
    """Base class for space-time point samplers.

    Subclasses implement ``smpDom`` / ``smpBd`` for a concrete domain and set
    ``self.N``, the default batch size used whenever a method's ``N`` argument
    is omitted (this base class does not set it).

    Attributes:
        delta: width of the time window ``[0, delta)`` used by ``smpInit``.
        T: end time; ``smpTime`` samples in ``[0, T)``.
        bsize: stored but not read by this class.
    """

    #delta is the fuzzing timestep for the initial condition
    def __init__(self,delta=1e-2):
        self.delta = delta
        self.T = 1
        self.bsize = 100

    #N is always the batch size
    #samples from the interior of the domain
    def smpDom(self,key,t,N=None):
        """Return interior points with the times ``t`` prepended (subclass hook)."""
        raise NotImplementedError

    #samples from the boundary
    def smpBd(self,key,N=None):
        """Return boundary points (subclass hook)."""
        raise NotImplementedError

    #samples time
    #defaults to uniform but can be overriden
    def smpTime(self,key,N=None):
        """Return ``(N, 1)`` times drawn uniformly from ``[0, T)``."""
        if N is None:
            N = self.N
        return random.uniform(key,shape=(N,1))*self.T

    def smpfixTime(self,t,N=None):
        """Return an ``(N, 1)`` array filled with the constant time ``t``."""
        if N is None:
            N = self.N
        return jnp.ones((N,1))*t

    def smpInit(self,key,N=None):
        """Sample ``N`` interior points at fuzzed initial times in ``[0, delta)``.

        Sampling slightly after ``t = 0`` rather than exactly at it is the
        "fuzzing" of the initial condition.
        """
        if N is None:
            N = self.N
        # Three keys are split; only the first two are used.
        keys = random.split(key,3)
        t = random.uniform(keys[0],shape=(N,1))*self.delta
        # Forward N explicitly: otherwise smpDom falls back to self.N and its
        # point batch no longer lines up with the N rows of ``t``.
        return self.smpDom(keys[1],t,N=N)


#sampler class for the periodic square (flat 2-Tori)
class ToriSampler(Sampler):
    """Sampler for the periodic unit square (the flat 2-torus).

    Args:
        T: end time of the simulation interval.
        N: default batch size used when a method's ``N`` is omitted.
    """

    def __init__(self,T=1,N=100):
        super().__init__()
        self.T, self.N = T, N

    def smpDom(self,key,t,N=None):
        """Return ``[t, x, y]`` rows with ``(x, y)`` uniform on ``[0, 1)^2``."""
        batch = self.N if N is None else N
        # The key is split in two and only the first half is consumed.
        sub_key = random.split(key,2)[0]
        xy = random.uniform(sub_key,shape=(batch,2))
        return jnp.concatenate([t,xy],axis=1)

    def smpBd(self,key,t,N=None):
        """Placeholder boundary batch for a periodic (boundary-free) domain.

        ``key`` and ``t`` are ignored; an ``(N, 2)`` array of zeros is returned.
        """
        count = N if N is not None else self.N
        return jnp.zeros(shape=(count,2))
    
    
#Sampler for the 3d unit ball problem 
class BallSampler(Sampler):
    """Sampler for the 3d unit ball.

    Interior points are uniform in the ball and boundary points are uniform
    on the unit sphere; both are returned as ``[t, x, y, z]`` rows.

    Args:
        T: end time for the domain.
        N: default batch size.
    """

    #T is the end time for the domain
    #N is the default batch size
    def __init__(self,T=1,N=100):
        super().__init__()
        self.T = T
        self.N = N

    def smpDom(self,key,t,N=None):
        """Sample ``N`` points uniformly from the interior of the unit ball."""
        if N is None:
            N = self.N
        keys = random.split(key,2)

        # Uniform unit directions (drop the time column of the boundary sample).
        dirs = self.smpBd(keys[0],t,N=N)[:,1:]
        # For a uniform distribution in a 3d ball P(r < s) = s**3, so the radius
        # is the cube root of a uniform variate (sqrt is the 2d-disk formula and
        # over-samples the centre).
        radius = jnp.cbrt(random.uniform(keys[1],shape=(N,1)))
        return jnp.concatenate([t,dirs*radius],axis=1)

    def smpBd(self,key,t,N=None):
        """Sample ``N`` points uniformly on the unit sphere (the ball's boundary)."""
        if N is None:
            N = self.N
        # Normalised isotropic Gaussians are uniform on the sphere; normalising
        # points drawn uniformly from a cube would bias samples toward its corners.
        pts = random.normal(key,shape=(N,3))
        pts = pts / jnp.linalg.norm(pts,axis=1,keepdims=True)
        return jnp.concatenate([t,pts],axis=1)
    
    

class CubeSampler(Sampler):
    """Sampler for the axis-aligned cube ``[-width/2, width/2)^dim``.

    Args:
        dim: spatial dimension.
        width: edge length of the cube. The default of 2 gives ``[-1, 1)^dim``.
        N: default batch size.
    """

    def __init__(self,dim=3,width=2,N=500):
        super().__init__()
        self.dim = dim
        # Honour the caller's width (it used to be ignored and forced to 2).
        self.width = width #width of cube
        self.N = N

    def smpDom(self,key,t,N=None):
        """Return ``[t, x_1 .. x_dim]`` rows uniform in the cube."""
        if N is None:
            N = self.N
        pts = (random.uniform(key,shape=(N,self.dim)) - 0.5)*self.width

        return jnp.concatenate([t,pts],axis=1)

    def smpBd(self,key,N=None):
        """No boundary samples are produced for the cube; returns ``0.``."""
        return 0.
    


class SphereSampler(Sampler):
    """Sampler on the unit sphere parameterised by angles ``(theta, phi)``.

    Samples are returned as ``[t, theta, phi]`` rows, with ``theta`` the polar
    angle and ``phi`` the azimuth shifted to be centred on 0.

    Args:
        heal: if true, interior angles are drawn from the pixel centres of a
            HEALPix grid; otherwise ``theta`` is uniform in ``[0, pi)`` and
            ``phi`` uniform in ``[-pi, pi)`` (uniform in the parameters, not in
            surface area).
        T: end time for the domain.
        N: default batch size.
        nside: HEALPix resolution; the grid has ``12 * nside**2`` pixels.
    """

    #T is the end time for the domain
    #N is the default batch size
    def __init__(self,heal,T=1,N=100,nside=4000):
        super().__init__()
        self.T = T
        self.N = N
        self.heal = heal
        self.nside = nside
        self.npix = hp.nside2npix(nside)
        # NOTE(review): this table is built even when heal is false; at the
        # default nside=4000 it holds ~1.9e8 (theta, phi) pairs.
        self.ag_array = jnp.array(np.stack(hp.pix2ang(self.nside,np.arange(self.npix)),axis=-1))
        print("max",jnp.max(self.ag_array[:,0]),jnp.max(self.ag_array[:,1]))
        print("min",jnp.min(self.ag_array[:,0]),jnp.min(self.ag_array[:,1]))

    def smpDom(self,key,t,N=None):
        """Sample ``N`` ``(theta, phi)`` pairs and prepend ``t``."""
        if N is None:
            N = self.N
        keys = random.split(key,2)
        if not self.heal:
            pts = random.uniform(keys[0],shape=(N,2))
            pts = pts.at[:,1].set(pts[:,1]*2-1)
            pts = pts * jnp.pi
        else:
            pts = random.choice(keys[0],self.ag_array,shape=(N,),axis=0)
            # Shift the HEALPix azimuth down by pi so it is centred on 0.
            pts = pts.at[:,1].set(pts[:,1] - jnp.pi)
        return jnp.concatenate([t,pts],axis=1)

    def smpBd(self,key,t,N=None):
        """Sample the edges of the ``(theta, phi)`` parameter rectangle.

        Returns a dict of ``[t, theta, phi]`` batches:
            'l' / 'r': the seams ``phi = -pi`` / ``phi = pi`` with ``theta``
                uniform in ``[0, pi)``;
            'n': ``theta = 0`` with ``phi`` uniform in ``[-pi, pi)``;
            's': ``theta = pi - 1e-6`` (just short of the pole) with ``phi``
                uniform in ``[-pi, pi)``.
        """
        if N is None:
            N = self.N
        keys = random.split(key,4)
        pts_l = (random.uniform(keys[0],shape=(N,2))*2 - 1)* jnp.pi
        pts_r = (random.uniform(keys[1],shape=(N,2))*2 - 1)* jnp.pi
        pts_s = (random.uniform(keys[2],shape=(N,2))*2 - 1)* jnp.pi
        pts_n = (random.uniform(keys[3],shape=(N,2))*2 - 1)* jnp.pi

        bd = {}
        pts_l = pts_l.at[:,1].set(-jnp.pi)
        pts_r = pts_r.at[:,1].set(jnp.pi)
        # map theta from [-pi, pi) to [0, pi)
        pts_r = pts_r.at[:,0].set((pts_r[:,0]+jnp.pi)/2)
        pts_l = pts_l.at[:,0].set((pts_l[:,0]+jnp.pi)/2)

        pts_n = pts_n.at[:,0].set(0)
        pts_s = pts_s.at[:,0].set(jnp.pi-(1e-6))

        bd['l'] = jnp.concatenate([t,pts_l],axis=1)
        bd['r'] = jnp.concatenate([t,pts_r],axis=1)
        bd['n'] = jnp.concatenate([t,pts_n],axis=1)
        bd['s'] = jnp.concatenate([t,pts_s],axis=1)
        return bd

    def smpInit(self,key,N=None):
        """Sample ``N`` points at fuzzed initial times in ``[0, delta)``."""
        if N is None:
            N = self.N
        keys = random.split(key,3)
        t = random.uniform(keys[0],shape=(N,1))*self.delta
        # Forward N so the point batch matches the N rows of ``t``.
        return self.smpDom(keys[1],t,N=N)
    


class SphereSampler3d(Sampler):
    """Sampler on the unit sphere returning Cartesian points ``[t, x, y, z]``.

    Angles ``(theta, phi)`` are sampled as in the angle-based sphere sampler
    and mapped to 3d with ``(sin(theta) sin(phi), sin(theta) cos(phi), cos(theta))``.

    Args:
        heal: if true, angles come from HEALPix pixel centres and points get a
            small radial jitter; otherwise ``theta`` is uniform in ``[0, pi)``
            and ``phi`` uniform in ``[0, 2 pi)``.
        T: end time for the domain.
        N: default batch size.
        nside: HEALPix resolution; the grid has ``12 * nside**2`` pixels.
    """

    #T is the end time for the domain
    #N is the default batch size
    def __init__(self,heal,T=1,N=100,nside=4000):
        super().__init__()
        self.T = T
        self.N = N
        self.heal = heal
        self.nside = nside
        self.npix = hp.nside2npix(nside)
        # NOTE(review): built even when heal is false (~1.9e8 pairs at nside=4000).
        self.ag_array = jnp.array(np.stack(hp.pix2ang(self.nside,np.arange(self.npix)),axis=-1))
        print("max",jnp.max(self.ag_array[:,0]),jnp.max(self.ag_array[:,1]))
        print("min",jnp.min(self.ag_array[:,0]),jnp.min(self.ag_array[:,1]))

    @staticmethod
    def _sph2cart(theta,phi):
        """Map polar angle ``theta`` and azimuth ``phi`` to ``(M, 3)`` unit vectors."""
        return jnp.stack([jnp.sin(theta)*jnp.sin(phi),
                          jnp.sin(theta)*jnp.cos(phi),
                          jnp.cos(theta)],axis=1)

    def smpDom(self,key,t,delta=0.00001,N=None):
        """Sample ``N`` points on (or, with ``heal``, near) the unit sphere.

        Args:
            key: JAX PRNG key.
            t: times, prepended as column 0.
            delta: with ``heal``, radii are uniform in ``[1 - delta/2, 1 + delta/2)``;
                unused otherwise.
            N: number of points (defaults to ``self.N``).
        """
        if N is None:
            N = self.N
        keys = random.split(key,2)
        if not self.heal:
            ang = random.uniform(keys[0],shape=(N,2))
            ang = ang.at[:,1].set(ang[:,1]*2)
            ang = ang * jnp.pi
            pts = self._sph2cart(ang[:,0],ang[:,1])
        else:
            ang = random.choice(keys[0],self.ag_array,shape=(N,),axis=0)
            ang = ang.at[:,1].set(ang[:,1] - jnp.pi)
            r = random.uniform(keys[1],shape=(N,1))
            r = (r-0.5) * delta + 1
            pts = self._sph2cart(ang[:,0],ang[:,1]) * r
        return jnp.concatenate([t,pts],axis=1)

    def smpBd(self,key,t,N=None):
        """Sample the images of the ``(theta, phi)`` rectangle's edges.

        Returns a dict of ``[t, x, y, z]`` batches for the seams ``phi = -pi``
        ('l') and ``phi = pi`` ('r'), the north pole ``theta = 0`` ('n') and
        ``theta = pi - 1e-6`` ('s').
        """
        if N is None:
            N = self.N
        keys = random.split(key,4)
        pts_l = (random.uniform(keys[0],shape=(N,2))*2 - 1)* jnp.pi
        pts_r = (random.uniform(keys[1],shape=(N,2))*2 - 1)* jnp.pi
        pts_s = (random.uniform(keys[2],shape=(N,2))*2 - 1)* jnp.pi
        pts_n = (random.uniform(keys[3],shape=(N,2))*2 - 1)* jnp.pi

        pts_l = pts_l.at[:,1].set(-jnp.pi)
        pts_r = pts_r.at[:,1].set(jnp.pi)
        pts_r = pts_r.at[:,0].set((pts_r[:,0]+jnp.pi)/2)
        pts_l = pts_l.at[:,0].set((pts_l[:,0]+jnp.pi)/2)

        pts_n = pts_n.at[:,0].set(0)
        pts_s = pts_s.at[:,0].set(jnp.pi-(1e-6))

        # Each edge uses its own azimuths (the north edge previously read the
        # south edge's phi, which was masked by sin(0) == 0).
        bd = {}
        for name,ang in (('l',pts_l),('r',pts_r),('n',pts_n),('s',pts_s)):
            bd[name] = jnp.concatenate([t,self._sph2cart(ang[:,0],ang[:,1])],axis=1)
        return bd

    def smpInit(self,key,N=None):
        """Sample ``N`` points at fuzzed initial times in ``[0, delta)``."""
        if N is None:
            N = self.N
        keys = random.split(key,3)
        t = random.uniform(keys[0],shape=(N,1))*self.delta
        # Forward N so the point batch matches the N rows of ``t``.
        return self.smpDom(keys[1],t,N=N)
    




class PlanarSampler(Sampler):
    """Sampler for a rotated square patch of a plane through the origin in 3d.

    Interior points are drawn uniformly from ``[-pi, pi)^2`` in the xy-plane
    and rotated by a fixed orthogonal matrix built from ``n = (0.3, -0.5, 0.8)``.

    Args:
        heal, nside: only used to build ``self.ag_array`` (not read by this class).
        T: end time for the domain.
        N: default batch size.
    """

    #T is the end time for the domain
    #N is the default batch size
    def __init__(self,heal,T=1,N=100,nside=4000):
        super().__init__()
        self.T = T
        self.N = N
        self.heal = heal
        self.nside = nside
        self.npix = hp.nside2npix(nside)
        # NOTE(review): this HEALPix table is not used by any method of this
        # class and is large (~1.9e8 pairs at nside=4000).
        self.ag_array = jnp.array(np.stack(hp.pix2ang(self.nside,np.arange(self.npix)),axis=-1))
        print("max",jnp.max(self.ag_array[:,0]),jnp.max(self.ag_array[:,1]))
        print("min",jnp.min(self.ag_array[:,0]),jnp.min(self.ag_array[:,1]))

    @staticmethod
    def _rotation():
        """Orthogonal 3x3 matrix whose last row is the unit vector along (0.3, -0.5, 0.8)."""
        n = jnp.array([0.3,-0.5,0.8])
        n = n/jnp.linalg.norm(n)
        s = jnp.sqrt(n[0]**2+n[1]**2)
        return jnp.array([[n[1]/s,-n[0]/s,0],
                          [n[0]*n[2]/s,n[1]*n[2]/s,-s],
                          [n[0],n[1],n[2]]
                          ])

    @staticmethod
    def _sph2cart(theta,phi):
        """Map polar angle ``theta`` and azimuth ``phi`` to ``(M, 3)`` unit vectors."""
        return jnp.stack([jnp.sin(theta)*jnp.sin(phi),
                          jnp.sin(theta)*jnp.cos(phi),
                          jnp.cos(theta)],axis=1)

    def smpDom(self,key,t,delta=0.00001,N=None):
        """Sample ``N`` points on the rotated plane patch as ``[t, x, y, z]`` rows.

        ``delta`` is accepted for signature compatibility and unused.
        """
        if N is None:
            N = self.N

        keys = random.split(key,2)
        pts = random.uniform(keys[0],shape=(N,2))
        pts = (pts-0.5)*2*jnp.pi
        z = jnp.zeros_like(pts[:,0]).reshape([-1,1])
        pts = jnp.concatenate([pts,z],axis=1)
        # NOTE(review): R @ p maps the xy-plane onto span(R[:,0], R[:,1]), whose
        # normal is R[:,2] = (0, -s, n_z), not n. If the plane is meant to be
        # perpendicular to n, this should be R.T @ p -- confirm intent.
        pts = (self._rotation()@pts.transpose()).transpose()
        return jnp.concatenate([t,pts],axis=1)

    def smpBd(self,key,t,N=None):
        """Return a dict of ``[t, x, y, z]`` batches on the unit sphere ('l', 'r', 'n', 's').

        NOTE(review): these are the unit-sphere edge samples, not points on the
        plane patch -- confirm this is intended for the planar domain.
        """
        if N is None:
            N = self.N
        keys = random.split(key,4)
        pts_l = (random.uniform(keys[0],shape=(N,2))*2 - 1)* jnp.pi
        pts_r = (random.uniform(keys[1],shape=(N,2))*2 - 1)* jnp.pi
        pts_s = (random.uniform(keys[2],shape=(N,2))*2 - 1)* jnp.pi
        pts_n = (random.uniform(keys[3],shape=(N,2))*2 - 1)* jnp.pi

        pts_l = pts_l.at[:,1].set(-jnp.pi)
        pts_r = pts_r.at[:,1].set(jnp.pi)
        pts_r = pts_r.at[:,0].set((pts_r[:,0]+jnp.pi)/2)
        pts_l = pts_l.at[:,0].set((pts_l[:,0]+jnp.pi)/2)

        pts_n = pts_n.at[:,0].set(0)
        pts_s = pts_s.at[:,0].set(jnp.pi-(1e-6))

        # Each edge uses its own azimuths (north no longer borrows south's phi).
        bd = {}
        for name,ang in (('l',pts_l),('r',pts_r),('n',pts_n),('s',pts_s)):
            bd[name] = jnp.concatenate([t,self._sph2cart(ang[:,0],ang[:,1])],axis=1)
        return bd

    def smpInit(self,key,N=None):
        """Sample ``N`` points at fuzzed initial times in ``[0, delta)``."""
        if N is None:
            N = self.N
        keys = random.split(key,3)
        t = random.uniform(keys[0],shape=(N,1))*self.delta
        # Forward N so the point batch matches the N rows of ``t``.
        return self.smpDom(keys[1],t,N=N)
    


class ExplicitMeshSampler(Sampler):
    """Sampler for points on the surface of an explicit triangle mesh.

    Faces are chosen with probability proportional to their area and points
    are drawn uniformly inside the chosen triangle.

    Args:
        v_u: vertex positions, one row per vertex.
        f_u: integer vertex indices, three per triangle.
        normal: per-face normals, indexed by face id.
        T: end time for the domain.
        N: default batch size.
        nside: unused (kept for signature compatibility).
    """

    #T is the end time for the domain
    #N is the default batch size
    def __init__(self,v_u,f_u,normal,T=1,N=100,nside=4000):
        super().__init__()
        self.T = T
        self.N = N
        self.V = jnp.array(v_u)
        self.F = jnp.array(f_u)
        self.normal = jnp.array(normal)

    def _face_probabilities(self):
        """Per-face sampling probabilities proportional to area (computed once, cached).

        NOTE: the cache is not invalidated if ``self.V`` / ``self.F`` are reassigned.
        """
        prob = getattr(self,'_face_prob',None)
        if prob is None:
            # doublearea gives twice each face's area; only the ratio matters.
            dbl_area = igl.doublearea(np.array(self.V), np.array(self.F))
            prob = jnp.array(dbl_area / dbl_area.sum())
            self._face_prob = prob
        return prob

    def _surface_points(self,pts_b,f):
        """Interpolate the corners of faces ``f`` with barycentric weights ``pts_b``."""
        tri = self.F[f]
        corners = jnp.stack([self.V[tri[:,0]],self.V[tri[:,1]],self.V[tri[:,2]]],axis=-1)
        return jnp.sum(corners * pts_b[:,None,:],axis=-1)

    # Randomness comes from the JAX PRNG key ``keys``.
    def sample_surface_fixed_seed(self, keys, sample_size):
        """Draw ``sample_size`` (barycentric weights, face id) pairs.

        The same key always yields the same samples.

        Returns:
            sample_bary: ``(sample_size, 3)`` weights, non-negative, rows sum to 1.
            fid_pick: ``(sample_size,)`` face indices drawn proportionally to area.
        """
        prob = self._face_probabilities()
        fid = jnp.arange(len(self.F))
        key = jax.random.split(keys,2)
        fid_pick = jax.random.choice(key[0],fid, (sample_size,), p=prob)

        # https://mathworld.wolfram.com/TrianglePointPicking.html
        sample_bary = jax.random.uniform(key[1], (sample_size, 2))
        # https://mathworld.wolfram.com/TriangleInterior.html
        # Points with u + v > 1 lie outside the triangle; (u, v) -> (1 - u, 1 - v)
        # reflects them back inside.
        outside = (jnp.sum(sample_bary,axis=-1) > 1)[:,None]
        sample_bary = jnp.abs(jnp.where(outside,sample_bary-1,sample_bary))
        sample_bary = jnp.concatenate([sample_bary,1-jnp.sum(sample_bary,axis=-1,keepdims=True)],axis=-1)
        return sample_bary, fid_pick

    def smpDom(self,key,t,delta=0.00001,N=None):
        """Return ``[t, point, face normal]`` rows for ``N`` surface samples.

        ``delta`` is accepted for signature compatibility and unused.
        """
        if N is None:
            N = self.N
        pts_b,f = self.sample_surface_fixed_seed(key,N)
        pts = self._surface_points(pts_b,f)
        return jnp.concatenate([t,pts,self.normal[f]],axis=1)

    def smpBd(self,key,t,N=None):
        """Closed surface: no boundary; returns an empty dict."""
        if N is None:
            N = self.N
        bd = {}
        return bd

    def smpInit(self,key,t,N=None):
        """Return ``[t, point, face normal, barycentric weights, 3 vertex ids]`` rows."""
        if N is None:
            N = self.N
        pts_b,f = self.sample_surface_fixed_seed(key,N)
        tri = self.F[f]
        pts = self._surface_points(pts_b,f)
        return jnp.concatenate([t,pts,self.normal[f],pts_b,
                                tri[:,0].reshape(-1,1),
                                tri[:,1].reshape(-1,1),
                                tri[:,2].reshape(-1,1)],axis=1)
    

from models import Siren
import flax
from utils import loadState

class ImplicitMeshSampler(Sampler):
    """Sampler for points on the zero level set of a neural SDF (a Siren MLP).

    On construction the MLP weights are loaded from ``mlp_path``. The zero
    level set is also extracted as an explicit mesh with marching cubes and
    stored in ``self.v``, ``self.f`` and ``self.normals``.
    """

    #T is the end time for the domain
    #N is the default batch size
    def __init__(self,mlp_path,T=1,N=100,nside=4000):
        super().__init__()
        self.T = T
        self.N = N
        seed = np.random.randint(2**32)
        layers = 5
        flax.config.update('flax_return_frozendict', True)
        key =  random.PRNGKey(seed)
        x = random.normal(key,shape=(3,))
        mlp = Siren(num_layers=layers,output_dim=1,w0=30,w0_first_layer=30,use_bias=True)
        # NOTE: these randomly initialised params are discarded below in favour
        # of the weights loaded from ``mlp_path``.
        params = mlp.init(key,x)
        params = params.unfreeze()['params']
        func_mlp_ = lambda params,x: mlp.apply({'params':params}, x)
        params = loadState(mlp_path)

        self.implcit_mlp = lambda x: func_mlp_(params,x)

        # Evaluate the SDF on a res^3 grid over [-1, 1]^3, in chunks to bound memory.
        res = 512
        n_chunks = 128
        line = np.linspace(-1, 1, res)
        samples = jnp.array(np.stack(np.meshgrid(line, line, line), -1).reshape(-1, 3))
        step = res**3//n_chunks
        sdf = [vmap(self.implcit_mlp)(samples[i*step:(i+1)*step]) for i in range(n_chunks)]
        sdf = jnp.concatenate(sdf,axis=0).reshape(res,res,res)

        from skimage import measure
        verts, faces, normals, values = measure.marching_cubes(np.array(sdf), 0)

        # The bounding box is taken from the marching-cubes vertices, so it must
        # be computed after marching cubes has produced them.
        # NOTE(review): this rescales the mesh to its own bounding box rather than
        # mapping grid indices back to the SDF's [-1, 1] frame -- confirm intent.
        max_samples = np.max(verts,axis=0)
        min_samples = np.min(verts,axis=0)
        center = (max_samples+min_samples)/2
        scale = np.max(max_samples-center)
        verts = (verts-center)/scale
        # np.meshgrid's default 'xy' indexing swaps the first two grid axes, so
        # swap x and y back. Normals get the same permutation to stay consistent
        # with their vertices.
        R = np.array([[0,1,0],[1,0,0],[0,0,1]])
        verts = (R@verts.transpose()).transpose()
        normals = (R@normals.transpose()).transpose()

        self.v = jnp.array(verts)
        self.f = jnp.array(faces)
        self.normals = jnp.array(normals)

    def smpDom(self,key,t,delta=0.00001,batch_size=4000,N=None,thr=0.05,max_iters=1000):
        """Sample ``N`` points on the zero level set of the SDF.

        Candidates are drawn uniformly in ``[-1, 1]^3`` (the grid used in
        ``__init__``) and kept when ``|sdf| < thr``. They are then projected
        onto the surface with one step ``x - sdf(x) * grad / |grad|``.

        Args:
            key: JAX PRNG key.
            t: ``(N, 1)`` times, prepended as column 0.
            delta: unused.
            batch_size: candidates drawn per rejection round.
            N: number of points (defaults to ``self.N``).
            thr: acceptance band on ``|sdf|``.
            max_iters: maximum number of rejection rounds.

        Returns:
            ``(N, 7)`` array ``[t, x, y, z, gx, gy, gz]``. The last three columns
            are the (unnormalised) SDF gradient at the projected point.

        Raises:
            RuntimeError: if fewer than ``N`` candidates are accepted within
                ``max_iters`` rounds.
        """
        if N is None:
            N = self.N
        sdf_fn = vmap(self.implcit_mlp)
        grad_fn = vmap(jacfwd(self.implcit_mlp))

        out = []
        cnt = 0
        rounds = 0
        while cnt < N:
            if rounds >= max_iters:
                raise RuntimeError(
                    f"ImplicitMeshSampler.smpDom: only {cnt} of {N} points with "
                    f"|sdf| < {thr} after {max_iters} rounds")
            rounds += 1
            # Fresh subkey every round; reusing one key redraws the same batch.
            key, sub = random.split(key)
            cand = (random.uniform(sub,shape=(batch_size,3)) - 0.5)*2
            near = jnp.abs(sdf_fn(cand).reshape(batch_size)) < thr
            accepted = cand[near]
            out.append(accepted)
            cnt += int(accepted.shape[0])

        x = jnp.concatenate(out, axis=0)[:N, :]
        y = sdf_fn(x).reshape(N,1)
        g = grad_fn(x).reshape(N,3)
        g = g / jnp.linalg.norm(g,axis=-1,keepdims=True)
        x = x - g * y
        normal = grad_fn(x).reshape(N,3)
        return jnp.concatenate([t,x,normal],axis=1)

    def smpBd(self,key,t,N=None):
        """Closed surface: no boundary; returns an empty dict."""
        if N is None:
            N = self.N
        bd = {}
        return bd

    def smpInit(self,key,N=None):
        """Sample ``N`` surface points at fuzzed initial times in ``[0, delta)``."""
        if N is None:
            N = self.N
        keys = random.split(key,3)
        t = random.uniform(keys[0],shape=(N,1))*self.delta
        # Forward N so the point batch matches the N rows of ``t``.
        return self.smpDom(keys[1],t,N=N)