"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import jax.numpy as jnp 
import jax.random as random 
import healpy as hp
import numpy as np
import jax

# Sampler classes draw batches of points from
# 1. the interior of the domain,
# 2. the boundary, and
# 3. the initial-condition interval.
# Item 3 exists because we "fuzz" the initial condition: instead of sampling only at t = 0,
# we also sample at small non-zero times in [0, delta). This seems to improve stability in practice.


#base class for samplers (to be extended by subclasses for specific domains)
class Sampler(object):
    """Base class for space-time point samplers.

    Subclasses implement ``smpDom`` (and usually ``smpBd``) for a concrete
    domain. They must also set ``self.N``: it is the default batch size used
    whenever a method is called with ``N=None``, and this base class never
    sets it itself.
    """

    #delta is the fuzzing timestep for the initial condition
    def __init__(self,delta=1e-2):
        self.delta = delta
        self.T = 1  # end time of the time interval; subclasses usually overwrite it
        self.bsize = 100

    #N is always the batch size
    def smpDom(self,key,t,N=None):
        """Sample N interior points with the (N, 1) time column ``t`` prepended."""
        raise NotImplementedError

    def smpBd(self,key,N=None):
        """Sample boundary points (most subclasses in this file take ``(key, t, N)``)."""
        raise NotImplementedError

    def smpTime(self,key,N=None):
        """Sample N times uniformly from [0, T) as an (N, 1) array.

        Subclasses may override this for a different time distribution.
        """
        if N is None:
            N = self.N
        return random.uniform(key,shape=(N,1))*self.T

    def smpfixTime(self,t,N=None):
        """Return an (N, 1) array filled with the fixed time ``t``."""
        if N is None:
            N = self.N
        return jnp.ones((N,1))*t

    def smpInit(self,key,N=None):
        """Fuzz the initial condition: sample N interior points at times in [0, delta).

        Args:
            key: JAX PRNG key.
            N: batch size; defaults to ``self.N``.

        Returns:
            Whatever ``smpDom`` returns for the (N, 1) sampled time column.
        """
        if N is None:
            N = self.N
        # keys[2] is unused; the 3-way split is kept so existing seeds reproduce the same samples
        keys = random.split(key,3)
        t = random.uniform(keys[0],shape=(N,1))*self.delta
        # forward N so the interior points line up with the (N, 1) time column
        return self.smpDom(keys[1],t,N=N)


#sampler class for the periodic square (flat 2-Tori)
class ToriSampler(Sampler):
    """Sampler for the flat 2-torus: the unit square [0, 1)^2 with periodic edges.

    Interior samples are rows ``[t, x, y]``.
    """

    def __init__(self,T=1,N=100):
        """
        Args:
            T: end time of the time interval.
            N: default batch size used when a method receives ``N=None``.
        """
        super().__init__()
        self.T, self.N = T, N

    def smpDom(self,key,t,N=None):
        """Draw N uniform points in the unit square and prepend the (N, 1) times ``t``."""
        n = self.N if N is None else N
        subkey = random.split(key,2)[0]
        xy = random.uniform(subkey,shape=(n,2))
        return jnp.concatenate([t,xy],axis=1)

    def smpBd(self,key,t,N=None):
        """Return an (N, 2) block of zeros: a periodic domain has no boundary.

        ``key`` and ``t`` are accepted only for interface compatibility and are ignored.
        """
        n = self.N if N is None else N
        return jnp.zeros(shape=(n,2))
    
    
#Sampler for the 3d unit ball problem 
class BallSampler(Sampler):
    """Sampler for the 3D unit ball; samples are rows ``[t, x, y, z]``."""

    #T is the end time for the domain
    #N is the default batch size
    def __init__(self,T=1,N=100):
        super().__init__()
        self.T = T
        self.N = N

    def smpDom(self,key,t,N=None):
        """Sample N points inside the unit ball, prepended with the (N, 1) times ``t``.

        A direction is drawn with ``smpBd`` and scaled by the radius ``sqrt(u)``,
        where ``u ~ U[0, 1)``.
        """
        if N is None:
            N = self.N
        keys = random.split(key,2)
        # forward N so the directions match the requested batch size (and t)
        surf = self.smpBd(keys[0],t,N=N)
        # NOTE(review): sqrt(u) gives a radial density proportional to r (uniform for a
        # 2D disk); a volume-uniform 3D ball would use u**(1/3). Confirm this is intended.
        radius = jnp.sqrt(random.uniform(keys[1],shape=(N,1)))
        # drop the time column of surf and keep the caller's t unscaled
        return jnp.concatenate([t,surf[:,1:]*radius],axis=1)

    def smpBd(self,key,t,N=None):
        """Sample N points on the unit sphere, prepended with the (N, 1) times ``t``."""
        if N is None:
            N = self.N
        # NOTE(review): normalising points drawn uniformly from the cube [-1, 1)^3 is not
        # uniform on the sphere (directions toward cube corners are over-represented);
        # normalising standard-normal draws would be. Kept as-is to preserve behaviour.
        pts = random.uniform(key,shape=(N,3))*2 - 1
        pts = pts / jnp.linalg.norm(pts,axis=1).reshape(-1,1)
        return jnp.concatenate([t,pts],axis=1)
    
    

class CubeSampler(Sampler):
    """Sampler for the cube [-width/2, width/2)^dim centred at the origin.

    Interior samples are rows ``[t, x_1, ..., x_dim]``. The end time ``T`` is the
    value set by ``Sampler.__init__``.
    """

    #dim is the spatial dimension
    #width is the side length of the cube
    #N is the default batch size
    def __init__(self,dim=3,width=2,N=500):
        super().__init__()
        self.dim = dim
        # Honour the argument (it used to be ignored in favour of a hard-coded 2).
        # The default is 2, so callers relying on the old default see no change.
        self.width = width
        self.N = N

    def smpDom(self,key,t,N=None):
        """Sample N uniform points in the cube and prepend the (N, 1) times ``t``."""
        if N is None:
            N = self.N
        pts = (random.uniform(key,shape=(N,self.dim)) - 0.5)*self.width
        return jnp.concatenate([t,pts],axis=1)

    def smpBd(self,key,N=None):
        """Return the scalar 0.; this sampler produces no boundary points."""
        return 0.
    


class SphereSampler(Sampler):
    """Sampler on the sphere in angle coordinates; samples are rows ``[t, theta, phi]``.

    theta lies in [0, pi] and phi in [-pi, pi).
    """

    #T is the end time for the domain
    #N is the default batch size
    #heal: if True, interior angles are drawn from HEALPix pixel centres
    #nside: HEALPix resolution parameter passed to healpy
    def __init__(self,heal,T=1,N=100,nside=4000):
        super().__init__()
        self.T = T
        self.N = N
        self.heal = heal
        self.nside = nside
        self.npix = hp.nside2npix(nside)
        # (npix, 2) table of (theta, phi) pixel centres; built even when heal is False.
        # NOTE(review): npix = 12*nside**2 (~1.9e8 rows at nside=4000) - consider building lazily.
        self.ag_array = jnp.array(np.stack(hp.pix2ang(self.nside,np.arange(self.npix)),axis=-1))
        print("max",jnp.max(self.ag_array[:,0]),jnp.max(self.ag_array[:,1]))
        print("min",jnp.min(self.ag_array[:,0]),jnp.min(self.ag_array[:,1]))

    def smpDom(self,key,t,N=None):
        """Sample N (theta, phi) pairs and prepend the (N, 1) times ``t``."""
        if N is None:
            N = self.N
        keys = random.split(key,2)
        if not self.heal:
            # theta uniform in [0, pi), phi uniform in [-pi, pi)
            pts = random.uniform(keys[0],shape=(N,2))
            pts = pts.at[:,1].set(pts[:,1]*2-1)
            pts = pts * jnp.pi
        else:
            # HEALPix pixel centres drawn with replacement, phi shifted down by pi
            pts = random.choice(keys[0],self.ag_array,shape=(N,),axis=0)
            pts = pts.at[:,1].set(pts[:,1] - jnp.pi)
        return jnp.concatenate([t,pts],axis=1)

    def smpBd(self,key,t,N=None):
        """Sample points on the seams of the (theta, phi) chart.

        Returns:
            dict of (N, 3) arrays ``[t, theta, phi]`` keyed by
            'l' (phi = -pi), 'r' (phi = +pi), both with theta in [0, pi);
            'n' (theta = 0) and 's' (theta = pi - 1e-6), both with phi in [-pi, pi).
        """
        if N is None:
            N = self.N
        keys = random.split(key,4)
        pts_l = (random.uniform(keys[0],shape=(N,2))*2 - 1)* jnp.pi
        pts_r = (random.uniform(keys[1],shape=(N,2))*2 - 1)* jnp.pi
        pts_s = (random.uniform(keys[2],shape=(N,2))*2 - 1)* jnp.pi
        pts_n = (random.uniform(keys[3],shape=(N,2))*2 - 1)* jnp.pi

        bd = {}
        pts_l = pts_l.at[:,1].set(-jnp.pi)
        pts_r = pts_r.at[:,1].set(jnp.pi)
        # rescale theta from [-pi, pi) to [0, pi)
        pts_r = pts_r.at[:,0].set((pts_r[:,0]+jnp.pi)/2)
        pts_l = pts_l.at[:,0].set((pts_l[:,0]+jnp.pi)/2)

        pts_n = pts_n.at[:,0].set(0)
        # presumably offset from pi to stay off the exact pole - TODO confirm
        pts_s = pts_s.at[:,0].set(jnp.pi-(1e-6))

        bd['l'] = jnp.concatenate([t,pts_l],axis=1)
        bd['r'] = jnp.concatenate([t,pts_r],axis=1)
        bd['n'] = jnp.concatenate([t,pts_n],axis=1)
        bd['s'] = jnp.concatenate([t,pts_s],axis=1)
        return bd

    def smpInit(self,key,N=None):
        """Sample N interior points at fuzzed initial times in [0, delta)."""
        if N is None:
            N = self.N
        # keys[2] is unused; the 3-way split is kept so existing seeds reproduce the same samples
        keys = random.split(key,3)
        t = random.uniform(keys[0],shape=(N,1))*self.delta
        # forward N so the points match the (N, 1) time column
        return self.smpDom(keys[1],t,N=N)
    


class SphereSampler3d(Sampler):
    """Sampler on the unit sphere embedded in 3D; samples are rows ``[t, x, y, z]``.

    Angles (theta, phi) map to Cartesian coordinates via
    ``x = sin(theta) sin(phi)``, ``y = sin(theta) cos(phi)``, ``z = cos(theta)``.
    """

    #T is the end time for the domain
    #N is the default batch size
    #heal: if True, interior angles are drawn from HEALPix pixel centres
    #nside: HEALPix resolution parameter passed to healpy
    def __init__(self,heal,T=1,N=100,nside=4000):
        super().__init__()
        self.T = T
        self.N = N
        self.heal = heal
        self.nside = nside
        self.npix = hp.nside2npix(nside)
        # (npix, 2) table of (theta, phi) pixel centres; built even when heal is False.
        # NOTE(review): npix = 12*nside**2 (~1.9e8 rows at nside=4000) - consider building lazily.
        self.ag_array = jnp.array(np.stack(hp.pix2ang(self.nside,np.arange(self.npix)),axis=-1))
        print("max",jnp.max(self.ag_array[:,0]),jnp.max(self.ag_array[:,1]))
        print("min",jnp.min(self.ag_array[:,0]),jnp.min(self.ag_array[:,1]))

    @staticmethod
    def _sph2cart(theta,phi):
        """Map (N,) angle arrays to (N, 3) points on the unit sphere."""
        sin_theta = jnp.sin(theta)
        return jnp.stack([sin_theta*jnp.sin(phi),sin_theta*jnp.cos(phi),jnp.cos(theta)],axis=1)

    def smpDom(self,key,t,delta=0.00001,N=None):
        """Sample N points on the unit sphere (radially jittered in the HEALPix branch).

        Args:
            key: JAX PRNG key.
            t: (N, 1) time column prepended to the points.
            delta: width of the radial band [1 - delta/2, 1 + delta/2) used only
                when ``self.heal`` is True.
            N: batch size; defaults to ``self.N``.

        Returns:
            (N, 4) array of rows ``[t, x, y, z]``.
        """
        if N is None:
            N = self.N
        keys = random.split(key,2)
        if not self.heal:
            # theta uniform in [0, pi), phi uniform in [0, 2*pi); points lie exactly on the sphere
            ang = random.uniform(keys[0],shape=(N,2))
            ang = ang.at[:,1].set(ang[:,1]*2)
            ang = ang * jnp.pi
            pts = self._sph2cart(ang[:,0],ang[:,1])
        else:
            # HEALPix pixel centres drawn with replacement, phi shifted down by pi
            ang = random.choice(keys[0],self.ag_array,shape=(N,),axis=0)
            ang = ang.at[:,1].set(ang[:,1] - jnp.pi)
            r = random.uniform(keys[1],shape=(N,1))
            r = (r-0.5) * delta + 1
            pts = self._sph2cart(ang[:,0],ang[:,1]) * r
        return jnp.concatenate([t,pts],axis=1)

    def smpBd(self,key,t,N=None):
        """Sample points on the seams of the (theta, phi) chart, mapped to 3D.

        Returns:
            dict of (N, 4) arrays ``[t, x, y, z]`` keyed by
            'l' / 'r' (phi = -pi / +pi, theta in [0, pi)),
            'n' (theta = 0) and 's' (theta = pi - 1e-6).
        """
        if N is None:
            N = self.N
        keys = random.split(key,4)
        pts_l = (random.uniform(keys[0],shape=(N,2))*2 - 1)* jnp.pi
        pts_r = (random.uniform(keys[1],shape=(N,2))*2 - 1)* jnp.pi
        pts_s = (random.uniform(keys[2],shape=(N,2))*2 - 1)* jnp.pi
        pts_n = (random.uniform(keys[3],shape=(N,2))*2 - 1)* jnp.pi

        pts_l = pts_l.at[:,1].set(-jnp.pi)
        pts_r = pts_r.at[:,1].set(jnp.pi)
        # rescale theta from [-pi, pi) to [0, pi)
        pts_r = pts_r.at[:,0].set((pts_r[:,0]+jnp.pi)/2)
        pts_l = pts_l.at[:,0].set((pts_l[:,0]+jnp.pi)/2)

        pts_n = pts_n.at[:,0].set(0)
        # presumably offset from pi to stay off the exact pole - TODO confirm
        pts_s = pts_s.at[:,0].set(jnp.pi-(1e-6))

        # Each seam uses its own phi. The 'n' seam previously borrowed pts_s's phi by mistake;
        # since sin(0) = 0 the resulting points are numerically identical.
        bd = {}
        for name,ang in (('l',pts_l),('r',pts_r),('n',pts_n),('s',pts_s)):
            bd[name] = jnp.concatenate([t,self._sph2cart(ang[:,0],ang[:,1])],axis=1)
        return bd

    def smpInit(self,key,N=None):
        """Sample N points at fuzzed initial times in [0, delta)."""
        if N is None:
            N = self.N
        # keys[2] is unused; the 3-way split is kept so existing seeds reproduce the same samples
        keys = random.split(key,3)
        t = random.uniform(keys[0],shape=(N,1))*self.delta
        # forward N so the points match the (N, 1) time column
        return self.smpDom(keys[1],t,N=N)
    




class PlanarSamplerAutoEncoder(Sampler):
    """Sampler for points in the z = 0 plane paired with training examples.

    Each interior sample concatenates a time, a planar point, a randomly chosen
    training label/image pair and a 4-dimensional standard-normal vector.
    """

    #T is the end time for the domain
    #N is the default batch size
    #heal / nside: HEALPix settings; the pixel table is built but not read by smpDom
    #training_images / training_labels: indexable by a batch of integer indices
    def __init__(self,heal,training_images,training_labels,T=1,N=100,nside=4000):
        super().__init__()
        self.T = T
        self.N = N
        self.heal = heal
        self.nside = nside
        self.npix = hp.nside2npix(nside)
        # (npix, 2) table of (theta, phi) pixel centres.
        # NOTE(review): npix = 12*nside**2 (~1.9e8 rows at nside=4000) - consider building lazily.
        self.ag_array = jnp.array(np.stack(hp.pix2ang(self.nside,np.arange(self.npix)),axis=-1))
        print("max",jnp.max(self.ag_array[:,0]),jnp.max(self.ag_array[:,1]))
        print("min",jnp.min(self.ag_array[:,0]),jnp.min(self.ag_array[:,1]))
        self.training_images = training_images
        self.training_labels = training_labels
        # number of training examples that indices are drawn from
        self.total_shape = training_labels.shape[0]

    @staticmethod
    def _sph2cart(theta,phi):
        """Map (N,) angle arrays to (N, 3) points on the unit sphere."""
        sin_theta = jnp.sin(theta)
        return jnp.stack([sin_theta*jnp.sin(phi),sin_theta*jnp.cos(phi),jnp.cos(theta)],axis=1)

    def smpDom(self,key,t,delta=0.00001,N=None):
        """Sample N planar points together with conditioning data.

        Args:
            key: JAX PRNG key.
            t: (N, 1) time column.
            delta: unused here; kept so the signature is unchanged.
            N: batch size; defaults to ``self.N``.

        Returns:
            Array whose columns are, in order: t; (x, y, 0) with x, y uniform in
            [-pi, pi); training_labels[ind]; training_images[ind]; 4 standard-normal
            values. ``ind`` is drawn uniformly with replacement from range(total_shape).
            Assumes labels and images are 2-D (M, k) arrays so they concatenate
            along axis 1 - TODO confirm.
        """
        if N is None:
            N = self.N
        keys = random.split(key,3)
        pts = random.uniform(keys[0],shape=(N,2))
        ind = random.choice(keys[1], self.total_shape, shape=(N,))
        pts = (pts-0.5)*2*jnp.pi
        # append z = 0 so every point lies in the xy-plane of 3D space
        z = jnp.zeros_like(pts[:,0]).reshape([-1,1])
        pts = jnp.concatenate([pts,z],axis=1)
        zk = random.normal(keys[2],shape=(N,4))
        return jnp.concatenate([t,pts,self.training_labels[ind],self.training_images[ind],zk],axis=1)

    def smpBd(self,key,t,N=None):
        """Sample points on the seams of a (theta, phi) chart, mapped to the unit sphere.

        NOTE(review): this mirrors SphereSampler3d.smpBd (sphere points, not planar ones);
        confirm it is what this sampler should return.

        Returns:
            dict of (N, 4) arrays ``[t, x, y, z]`` keyed by
            'l' / 'r' (phi = -pi / +pi, theta in [0, pi)),
            'n' (theta = 0) and 's' (theta = pi - 1e-6).
        """
        if N is None:
            N = self.N
        keys = random.split(key,4)
        pts_l = (random.uniform(keys[0],shape=(N,2))*2 - 1)* jnp.pi
        pts_r = (random.uniform(keys[1],shape=(N,2))*2 - 1)* jnp.pi
        pts_s = (random.uniform(keys[2],shape=(N,2))*2 - 1)* jnp.pi
        pts_n = (random.uniform(keys[3],shape=(N,2))*2 - 1)* jnp.pi

        pts_l = pts_l.at[:,1].set(-jnp.pi)
        pts_r = pts_r.at[:,1].set(jnp.pi)
        # rescale theta from [-pi, pi) to [0, pi)
        pts_r = pts_r.at[:,0].set((pts_r[:,0]+jnp.pi)/2)
        pts_l = pts_l.at[:,0].set((pts_l[:,0]+jnp.pi)/2)

        pts_n = pts_n.at[:,0].set(0)
        # presumably offset from pi to stay off the exact pole - TODO confirm
        pts_s = pts_s.at[:,0].set(jnp.pi-(1e-6))

        # Each seam uses its own phi. The 'n' seam previously borrowed pts_s's phi by mistake;
        # since sin(0) = 0 the resulting points are numerically identical.
        bd = {}
        for name,ang in (('l',pts_l),('r',pts_r),('n',pts_n),('s',pts_s)):
            bd[name] = jnp.concatenate([t,self._sph2cart(ang[:,0],ang[:,1])],axis=1)
        return bd

    def smpInit(self,key,N=None):
        """Sample N interior points at fuzzed initial times in [0, delta)."""
        if N is None:
            N = self.N
        # keys[2] is unused; the 3-way split is kept so existing seeds reproduce the same samples
        keys = random.split(key,3)
        t = random.uniform(keys[0],shape=(N,1))*self.delta
        # forward N so the points match the (N, 1) time column
        return self.smpDom(keys[1],t,N=N)