import numpy as np

from flax import linen as nn
import jax.numpy as jnp
import jax
from jax import jacfwd, grad, vmap
import pickle
from utils import div
import jax.scipy as jsp
from utils import sin_thres,tan_thres,periodic
from jax.scipy.special import lpmn_values,sph_harm


from jax import config
# Debug aid: make JAX raise an error as soon as a computation produces a NaN.
config.update("jax_debug_nans", True)

# def sin_thres(x,thres=1e1):
#     x = jnp.minimum(1/(jnp.sin(x)+1e-10),thres)
#     x = jnp.maximum(x,-thres) 
#     return x

# def tan_thres(x,thres=1e1):
#     x = jnp.minimum(jnp.cos(x)/(jnp.sin(x)+1e-10),thres)
#     x = jnp.maximum(x,-thres)
#     return x

def p_div(F, x):
    """Return a first-order vector built from the (0, 1) entry of ``F``.

    ``F`` maps a coordinate vector ``x`` to a matrix. In this file it is the
    antisymmetric ``J - J.T`` built inside ``NCLImplicit``. Only
    ``f = F(x)[0, 1]`` and its partial derivatives are used:

    * component 0: ``-(df/dx1) * sin_thres(x[0])**2``
    * component 1: ``(df/dx0) - f * tan_thres(x[0])``

    Any further components (when ``x`` has more than two entries) stay zero.
    ``sin_thres`` / ``tan_thres`` come from ``utils``. Judging by the
    commented-out copies at the top of this file, they are clamped ``1/sin``
    and ``cos/sin``; TODO confirm against ``utils``.

    Args:
        F: callable ``x -> matrix``, differentiable with ``jax.jacfwd``.
        x: 1-D coordinate array.

    Returns:
        Array shaped like ``x`` with the two components above in slots 0 and 1.
    """
    f01 = F(x)[0, 1]
    jac = jacfwd(F)(x)  # jac[i, j, k] = dF[i, j] / dx[k]
    x0 = x[0]
    first = -jac[0, 1, 1] * (sin_thres(x0) ** 2)
    second = jac[0, 1, 0] - f01 * tan_thres(x0)
    return jnp.zeros_like(jac[0, 0]).at[0].set(first).at[1].set(second)

def checkify(Z, a, b, c, d, e, x):
    """Host-side debug hook: print the given values when ``Z`` exceeds 10.

    ``laplacian`` passes this to ``jax.debug.callback``, so the arguments
    arrive as concrete values. A NaN ``Z`` compares false and prints nothing.
    """
    if not Z > 10:
        return
    fields = (a, b, c, d, e, x)
    print(*fields)

def laplacian(F, x, *, debug=True):
    """Return a second-order operator applied to the (0, 1) entry of ``F``.

    With ``f = F(x)[0, 1]`` and ``s = sin_thres(x[0])`` this evaluates::

        d2f/dx0^2 * s - df/dx0 * cos(x[0]) * s**2 + d2f/dx1^2 * s**3 + f * s**3

    ``sin_thres`` comes from ``utils``. The commented-out copy at the top of
    this file suggests it is a clamped ``1/sin``; TODO confirm.

    Args:
        F: callable ``x -> matrix`` (2x2 in this file), differentiable twice
            with ``jax.jacfwd``.
        x: 1-D coordinate array (length 2 in this file).
        debug: when True (default), the intermediate terms are sent to
            ``checkify`` via ``jax.debug.callback``, which prints them
            whenever the result exceeds 10. Pass False to skip the host
            callback, e.g. inside training loops.

    Returns:
        Scalar array.
    """
    # For 2-D x and 2x2 F: val is 2x2, jac is 2x2x2, hess is 2x2x2x2;
    # the trailing axes index the partial derivatives w.r.t. x.
    val = F(x)
    jac_fn = jacfwd(F)
    jac = jac_fn(x)
    hess = jacfwd(jac_fn)(x)

    # Distinct names instead of re-binding the derivative tensors to scalars.
    # The old re-binding made later uses of `b`/`d` silently wrong.
    f = val[0, 1]
    f_0 = jac[0, 1, 0]
    f_00 = hess[0, 1, 0, 0]
    f_11 = hess[0, 1, 1, 1]
    s = sin_thres(x[0])  # hoisted: previously evaluated four times
    s3 = s ** 3

    result = f_00 * s - f_0 * jnp.cos(x[0]) * (s ** 2) + f_11 * s3 + f * s3
    if debug:
        # Same argument order checkify has always received: (Z, a, b, c, d, e, x).
        jax.debug.callback(checkify, result, f_00, f_0, f_11, f, s3, x)
    return result



class NCLImplicit:
    """Model whose outputs are derived from an antisymmetrised network Jacobian.

    Calling an instance evaluates ``network(x, params)`` and drops its last
    output entry. It then takes the Jacobian ``J`` of the remaining entries
    with respect to ``x`` and forms ``A = J - J.T``. The result is the vector
    ``[*p_div(A, x), laplacian(A, x)]``, which has 3 entries for 2-D ``x``.
    """

    def __init__(self, network):
        # Callable invoked as ``network(x, params)``.
        self.network = network

    def __call__(self, x, params):
        def truncated(pt):
            # Every network output except the last one.
            return self.network(pt, params)[:-1]

        def antisym(pt):
            jac = jacfwd(truncated)(pt)
            return jac - jac.T

        head = p_div(antisym, x)
        tail = laplacian(antisym, x)
        return jnp.array([*head, tail])

# def embd(x,freqs=[1/4,1/6,1/8]):
#     return jnp.concatenate([x] + [periodic(x,0)] + [periodic(x,k) for k in freqs])

def ll_specify(x, freq):
    """Return ``[P]``, where ``P`` holds normalized associated Legendre values.

    ``P[m]`` is the value of degree ``freq`` and order ``m`` (``m = 0..freq``)
    at ``x[0]``. ``lpmn_values(m, n, z, ...)`` returns a 3-D table indexed
    ``[order, degree, point]`` (per the JAX docs). This takes the last degree
    (``freq``) at the first point. The value is wrapped in a list so callers
    can extend a feature list with ``+=``.
    """
    table = lpmn_values(freq, freq, x, is_normalized=True)
    degree_row = table[:, -1, 0]
    return [degree_row]

def embd(x, freq=3):
    """Feature embedding of a coordinate vector ``x``.

    The features are concatenated in this order:

    * ``x`` itself;
    * the normalized degree-0 Legendre value at ``cos(x[0])``;
    * for each degree ``l`` in ``1..freq-1``, the normalized associated
      Legendre values of degree ``l`` (orders ``0..l``) at
      ``0.99999 * cos(x[0])``.

    The 0.99999 factor keeps the argument strictly inside [-1, 1]. This is
    presumably to avoid singular derivatives at the poles; TODO confirm.

    Every feature is then multiplied by ``sin(x[0])**6``, and any NaN/inf is
    replaced via ``jnp.nan_to_num``. For 2-D ``x`` and ``freq=3`` the result
    has 2 + 1 + 2 + 3 = 8 entries.
    """
    cos0 = jnp.cos(x[:1])
    features = [x, lpmn_values(0, 0, cos0, is_normalized=True)[0, 0]]
    shrunk = cos0 * 0.99999
    for degree in range(1, freq):
        features.extend(ll_specify(shrunk, degree))
    damping = jnp.sin(x[0]) ** 6
    stacked = jnp.concatenate([feat * damping for feat in features])
    return jnp.nan_to_num(stacked)

# def ll_specify(x,freq):
#     return [sph_harm(jnp.array([i]),jnp.array([freq]),0,x[:1],n_max=3) for i in range(-freq,freq)]

# def embd(x,freq=3):

#     a = [x] + [sph_harm(jnp.array([0]),jnp.array([0]),0,x[:1],n_max=freq)]
#     for i in range(1,freq):
#         a += ll_specify(x,i)
#     #jax.debug.print("a {}",a)
#     b = jnp.concatenate(a)
#     b = jnp.real(b)
#     return b

from models import MLP
import numpy as np
import jax.random as random
import jax
import jax.numpy as jnp
from pde import PDEDivForm
from losses import Loss,Sphere_Loss
from jax_sphere_experiment_setup import runBallExperiment
import optax
import flax

# NOTE(review): `freqs` is not referenced anywhere else in this file.
freqs = []

# Runs the Ball experiment with a NCL model.
# Make flax's `init` return a FrozenDict; `.unfreeze()` below relies on it.
flax.config.update('flax_return_frozendict', True)

# Network hyperparameters.
beta = 25
# Softplus sharpened by `beta` (tends to ReLU as beta grows, but stays smooth).
act = lambda x: jax.nn.softplus(x*beta)/beta
#act = lambda x: jax.nn.softplus(x*beta)**2/(beta*beta*2)
#act = lambda x: jax.nn.relu(x)
layers = 4
width = 128
advect_time_step = 100
time_step = 0.1

# Draw a fresh 32-bit seed. An explicit int64 dtype avoids NumPy's
# "high is out of bounds" error where the default int is 32-bit.
# int() gives a plain Python int, the same type as the manual override below.
seed = int(np.random.randint(2**32, dtype=np.int64))
# seed = 427443453
key = random.PRNGKey(seed)
print("Random initial seed:", seed)

# Random 2-D input used to initialise the network and print a sample output.
# NOTE(review): the same `key` is reused here, for mlp.init and for runBallExperiment.
x = random.normal(key, shape=(2,))
mlp = MLP(depth=layers, width=width, act=act, out_dim=3, std=0.01, bias=0.0001)
params = mlp.init(key, embd(x))

# Shrink every initial parameter by `scale`.
# jax.tree_map is deprecated/removed in recent JAX; tree_util.tree_map is the stable API.
scale = 8e-2
params = jax.tree_util.tree_map(lambda p: p * scale, params)

params = params.unfreeze()['params']

# Network as a function of raw coordinates: inputs are embedded with `embd` first.
func_mlp = lambda x, params: mlp.apply({'params': params}, embd(x))

# NCLImplicit returns [*p_div(A, x), laplacian(A, x)], where A is the
# antisymmetrised Jacobian of the network's first out_dim-1 outputs
# (3 entries for 2-D x).
ncl = NCLImplicit(func_mlp)

print("Sample NCL output:", ncl(x, params))

# `u` is a thin alias for `ncl`. The loss below is built from `ncl` directly,
# while `u` is what runBallExperiment receives as `pinn` (presumably for
# plotting/evaluation; confirm in jax_sphere_experiment_setup).
u = lambda x,params: ncl(x,params)


# PDE object (see pde.PDEDivForm) configured with the step `time_step`.
pde = PDEDivForm(spatial_m=True,time_step=time_step)
# The normal map drops the first entry of its input. The semantics are
# defined in the pde module; TODO confirm.
pde.setNormal(lambda y: y[1:])

# Register loss terms. The string names match the keys of `gamma` below.
loss = Sphere_Loss(ncl)
loss.addTermDom(pde.advect_loss,'advect')
#loss.addTermDom(pde.rho_pos,'rho')
# loss.addTermDom(pde.incp,'incp')
#loss.addTermInit(pde.init,'init')
loss.addTermInit(pde.init_w,'init')
loss.addTermBd_r(pde.bdry_cycle_r, 'cycle_r')
loss.addTermBd_l(pde.bdry_cycle_l, 'cycle_l')
loss.addTermBd_n(pde.bdry_scalar_singular_n, 's_n')
loss.addTermBd_s(pde.bdry_scalar_singular_s, 's_s')
loss.addTermBd_n(pde.bdry_vel_singular_n, 'vel_n')
loss.addTermBd_s(pde.bdry_vel_singular_s, 'vel_s')

# Per-term loss weights. 's_n', 's_s', 'vel_n' and 'vel_s' are weighted 0
# (registered but effectively disabled). 'incp' has a weight, but its term
# is not registered above.
gamma = {
    'advect':6e3,
    'incp':1e-1,
    'cycle_l':3e-2,
    'cycle_r':3e-2,
    's_n':0e-2,
    's_s':0e-2,
    'vel_n':0e-2,
    'vel_s':0e-2,
    'init':5e1
}
loss.setGamma(gamma)

# NOTE(review): optax multiplies `init_value` by the product of every scale
# whose boundary has been passed, so the learning rate is 1e-3, then 1e-6
# after step 50000, then 1e-10 after step 80000. If absolute learning rates
# (1e-3, then 1e-4) were intended, the scales should be e.g. {80000: 1e-1}.
# Confirm the intent.
sched = optax.piecewise_constant_schedule(init_value=1e-3,
                                    boundaries_and_scales={50000:1e-3,
                                                           80000:1e-4}
                                   )

# Run the experiment. `apx` is the seed followed by "ncl_periods"
# (presumably a run-name suffix).
runBallExperiment(params=params, 
                  key=key,
                  pde=pde,
                  loss=loss,
                  pinn=u,
                  advect_time_step=advect_time_step,
                  time_step=time_step,
                  sched=sched,
                  apx=str(seed)+"ncl_periods",
                    )
