import numpy as np

from flax import linen as nn
import jax.numpy as jnp
import jax
from jax import jacfwd, grad, vmap
import pickle
from utils import div
import jax.scipy as jsp
from utils import sin_thres,tan_thres,periodic
from jax.scipy.special import lpmn_values,sph_harm


from jax import config
# Globally enable JAX's NaN checker for this process.
# NOTE(review): per the JAX debugging docs this raises an error as soon as a
# computation produces a NaN and adds runtime overhead; consider disabling it
# outside of debugging runs.
config.update("jax_debug_nans", True)

# def sin_thres(x,thres=1e1):
#     x = jnp.minimum(1/(jnp.sin(x)+1e-10),thres)
#     x = jnp.maximum(x,-thres) 
#     return x

# def tan_thres(x,thres=1e1):
#     x = jnp.minimum(jnp.cos(x)/(jnp.sin(x)+1e-10),thres)
#     x = jnp.maximum(x,-thres)
#     return x

# def p_div(F,x):
#     val_ = F(x)
#     B = jacfwd(F)
#     H = B(x)
#     Z = jnp.zeros_like(H)[0,0]
#     Z = Z.at[0].set(-H[0,1,1])
#     Z = Z.at[1].set(H[0,1,0]*(jnp.sin(x[0])**2)- val_[0,1]*jnp.cos(x[0])*(jnp.sin(x[0])))
#     return Z

# def checkify(Z,a,b,c,d,e,x):
#     if Z>10:
#         print(a,b,c,d,e,x)

# def laplacian(F,x):
#     # val_: 2x2
#     val_ = F(x)
#     # B: 2x2x2
#     B = jacfwd(F)
#     b = B(x)
#     # D : 2x2x2x2
#     D = jacfwd(B)
#     d = D(x)
#     Z = d[0,1,0,0]*((jnp.sin(x[0])**2))-b[0,1,0]*jnp.cos(x[0])*(jnp.sin(x[0]))+ d[0,1,1,1]+ val_[0,1]
    
#     # a = d[0,1,0,0]
#     # b = b[0,1,0]
#     # c = d[0,1,1,1]
#     # d = val_[0,1]
#     # e = (sin_thres(x[0])**3)
#     # #jax.debug.print("{},{},{},{}",d[0,1,0,0]*(sin_thres(x[0])),b[0,1,0]*jnp.cos(x[0])*(sin_thres(x[0])**2),val_[0,1]*(sin_thres(x[0])**3),d[0,1,1,1] * (sin_thres(x[0])**3))
#     # jax.debug.callback(checkify,Z,a,b,c,d,e,x)
#     return Z

def div(F):
    """Return a function that evaluates the divergence of the vector field ``F``.

    The divergence is taken as the trace over the last two axes of the
    forward-mode Jacobian of ``F``.

    NOTE: this definition shadows the ``div`` imported from ``utils`` at the
    top of the file.
    """
    jac_F = jacfwd(F)

    def divergence(x):
        return jnp.trace(jac_F(x), axis1=-2, axis2=-1)

    return divergence

#analog of curl by taking norm of Df - Df^T
# def curl(F):
#     B = jacfwd(F)
#     C = lambda x: jnp.stack([(B(x)-B(x).T)[2,1],(B(x)-B(x).T)[2,0],(B(x)-B(x).T)[1,0]],axis=-1)
#     return C

# Curl of a 3-D vector field F at x, read off the off-diagonal entries of Df - Df^T
def curl(F, x):
    """Evaluate the curl of the 3-D vector field ``F`` at the point ``x``.

    With ``jac = jacfwd(F)(x)`` and ``jac[i, j] = dF_i / dx_j``, the result is
    ``[jac[2,1] - jac[1,2], jac[0,2] - jac[2,0], jac[1,0] - jac[0,1]]``.
    """
    jac = jacfwd(F)(x)
    # (i, j) pairs selecting the antisymmetric part jac[i, j] - jac[j, i]
    index_pairs = ((2, 1), (0, 2), (1, 0))
    return jnp.array([jac[i, j] - jac[j, i] for i, j in index_pairs])


#analog of curl by taking norm of Df - Df^T
# def curl(F,x):
#     b = jacfwd(F)
#     B = b(x)
#     C = jnp.array([B[2,1]-B[1,2],B[0,2]-B[2,0],B[1,0]-B[0,1]])
#     return C


# def P_div(u):
#     n = lambda x: x/jnp.linalg.norm(x)
#     B = lambda x: jnp.cross(u(x),n(x))
#     curl_B = lambda x: curl(B,x)
#     D = lambda x: jnp.sum(n(x)*curl_B(x))
#     E = jacfwd(D)
#     return lambda x: jnp.cross(E(x),n(x))

# def laplacian(div):
#     n = lambda x: x/jnp.linalg.norm(x)
#     curl_ = lambda x:curl(div,x)
#     return lambda x: jnp.sum(n(x),curl_(x))

def u_f(u):
    """Return the vector field ``x -> cross(u(x), x)``.

    The original body built this lambda but never returned it, so the function
    always returned ``None``; the field is now returned to the caller.

    Args:
        u: callable mapping a 3-vector ``x`` to a 3-vector.

    Returns:
        A callable ``x -> jnp.cross(u(x), x)``.
    """
    def crossed(x):
        return jnp.cross(u(x), x)

    return crossed

class NCLImplicit(object):
    """Wrap a scalar network ``psi = network(x, params)`` into the output ``[v, w]``.

    ``network(x, params)`` is expected to return a length-1 array: only row 0
    of its Jacobian is used. Presumably ``psi`` plays the role of a stream
    function on the unit sphere; TODO confirm.
    """

    def __init__(self, network):
        # callable (x, params) -> array; differentiated with jacfwd in __call__
        self.network = network

    def __call__(self, x, params):
        """Return a length-4 array ``[v_0, v_1, v_2, w]`` evaluated at ``x``.

        ``v = cross(grad psi(x), n(x))`` and ``w = dot(n(x), curl(v)(x))``,
        where the normal is taken to be ``n(x) = x``. That equals the unit
        normal only when ``|x| == 1``; callers presumably pass points on the
        unit sphere — TODO confirm.
        """
        def normal(p):
            return p

        def psi(p):
            return self.network(p, params)

        grad_psi = jacfwd(psi)

        def velocity(p):
            # row 0 of the Jacobian of the (length-1) network output
            return jnp.cross(grad_psi(p)[0], normal(p))

        def vorticity(p):
            return jnp.sum(normal(p) * curl(velocity, p))

        vel = velocity(x)
        vort = vorticity(x)
        return jnp.array([*vel, vort])

# def embd(x,freqs=[1/4,1/6,1/8]):
#     return jnp.concatenate([x] + [periodic(x,0)] + [periodic(x,k) for k in freqs])

def ll_specify(x, freq):
    """Return a one-element list holding normalized associated Legendre values.

    ``lpmn_values`` is evaluated up to order and degree ``freq`` at ``x``.
    The slice keeps the last degree (``n = freq``) for every order
    ``m = 0..freq`` at the first point of ``x``. The single list element
    has shape ``(freq + 1,)``.
    """
    table = lpmn_values(freq, freq, x, is_normalized=True)
    # table axes: [order m, degree n, evaluation point]
    top_degree_row = table[:, -1, 0]
    return [top_degree_row]

def embd(x, freq=3):
    """Build a damped Legendre feature vector for the point ``x``.

    The features are:

    * ``x`` itself;
    * the normalized ``P_0^0`` value at ``cos(x[0])``;
    * for each degree ``1 <= l < freq``, the values returned by
      ``ll_specify`` at ``0.99999 * cos(x[0])``.

    Every feature is multiplied by ``sin(x[0])**6``. The concatenation is
    then passed through ``jnp.nan_to_num``.

    Presumably ``x[0]`` is a polar angle — TODO confirm against callers.
    """
    cos_polar = jnp.cos(x[:1])
    # Scaled argument for the l >= 1 terms, so |argument| <= 0.99999.
    shrunk = cos_polar * 0.99999
    features = [x, lpmn_values(0, 0, cos_polar, is_normalized=True)[0, 0]]
    for degree in range(1, freq):
        features.extend(ll_specify(shrunk, degree))
    damping = jnp.sin(x[0]) ** 6
    return jnp.nan_to_num(jnp.concatenate([f * damping for f in features]))

# def ll_specify(x,freq):
#     return [sph_harm(jnp.array([i]),jnp.array([freq]),0,x[:1],n_max=3) for i in range(-freq,freq)]

# def embd(x,freq=3):

#     a = [x] + [sph_harm(jnp.array([0]),jnp.array([0]),0,x[:1],n_max=freq)]
#     for i in range(1,freq):
#         a += ll_specify(x,i)
#     #jax.debug.print("a {}",a)
#     b = jnp.concatenate(a)
#     b = jnp.real(b)
#     return b

from models import MLP,Siren
import numpy as np
import jax.random as random
import jax
import jax.numpy as jnp
from pde import PDEDivForm
from losses import Loss,Sphere_Loss
from jax_sphere_experiment_setup import runBallExperiment
import optax
import flax

# NOTE(review): not referenced anywhere in the visible code.
freqs: list = []




def syn(x):
    """Return a copy of ``x`` with entry 2 replaced by its absolute value.

    This folds a point onto the half-space ``x[2] >= 0``; the other entries
    are unchanged.
    """
    folded = jnp.abs(x[2])
    return x.at[2].set(folded)

def deal_sign(v, x):
    """Multiply entries 2 and 3 of ``v`` by ``sign(x[2])``; entries 0 and 1 are kept.

    Because ``jnp.sign(0) == 0``, a point with ``x[2] == 0`` zeroes entries
    2 and 3.
    """
    flip = jnp.sign(x[2])
    for idx in (2, 3):
        v = v.at[idx].set(v[idx] * flip)
    return v


# Runs the Ball experiment with an NCL model.
# Make flax return FrozenDicts from init; `params.unfreeze()` below relies on it.
flax.config.update('flax_return_frozendict', True)

# Hyperparameters for the network and time stepping.
beta = 8
# act = lambda x: jax.nn.softplus(x*beta)/beta
# Squared softplus scaled by 1/(2*beta^2); in the visible code only the
# commented-out MLP below would use it.
act = lambda x: jax.nn.softplus(x*beta)**2/(beta*beta*2)
# act = jnp.sin
layers = 4
width = 128
advect_time_step = 300
time_step = 0.05

# Use an explicit 64-bit dtype for the seed draw. randint's default dtype is
# the platform C long, which is 32 bits on Windows with NumPy < 2; there
# high=2**32 raises ValueError. int() keeps `seed` a plain Python int, the
# same type the default-dtype call returned.
seed = int(np.random.randint(2**32, dtype=np.int64))
# Alternative fixed seeds:
# seed = 427443453
# seed = 2931118338
key = random.PRNGKey(seed)
print("Random initial seed:", seed)

# NOTE(review): `key` is reused for the sample point, for parameter init, and
# is also passed to runBallExperiment. jax.random.split would decorrelate
# these, but it would change results for a given seed.
x = random.normal(key, shape=(3,))
# mlp = MLP(depth=layers,width=width,act=act,out_dim=1,std=0.01,bias=True) # 0.01
mlp = Siren(num_layers=layers, output_dim=1, w0=30, w0_first_layer=100, use_bias=True)
params = mlp.init(key, x)

# Optional rescaling of the initial parameters:
# scale = 8e-2
# params = jax.tree_map(lambda x: x*scale, params)
params = params.unfreeze()['params']

# Adapter giving the network the (x, params) signature NCLImplicit expects.
func_mlp_ = lambda x, params: mlp.apply({'params': params}, x)

# ncl(x, params) returns a length-4 array: the three components of the
# vector v followed by the scalar w (see NCLImplicit.__call__).
ncl = NCLImplicit(func_mlp_)

print("Sample NCL output:", ncl(x, params))

# Plotting convenience. The loss module is handed `ncl` directly, while `u`
# is what runBallExperiment receives as `pinn`.
u = lambda x, params: ncl(x, params)

# PDE in divergence form, using the same time step as the experiment driver.
pde = PDEDivForm(spatial_m=True, time_step=time_step)
pde.setNormal(lambda y: y[1:])

# Sphere loss with one domain term ('advect') and one initial-condition term ('init').
loss = Sphere_Loss(ncl)
loss.addTermDom(pde.advect_loss, 'advect')
loss.addTermInit(pde.init_w, 'init')

# Loss weights keyed by term name.
gamma = dict(
    advect=6e-1,
    incp=1e-1,
    cycle_l=3e-2,
    cycle_r=3e-2,
    s_n=0.0,
    s_s=0.0,
    vel_n=0.0,
    vel_s=0.0,
    init=5e1,
)
loss.setGamma(gamma)

# Learning rate starts at 1e-4 and decays by a factor of 10 every 200k steps.
sched = optax.exponential_decay(
    init_value=1e-4,
    transition_steps=200000,
    decay_rate=1e-1,
)

# Alternative step schedule:
# sched = optax.piecewise_constant_schedule(init_value=1e-4,
#                         boundaries_and_scales={20000:1e-1,
#                                                40000:1e-2}
#                         )

runBallExperiment(
    params=params,
    key=key,
    pde=pde,
    loss=loss,
    pinn=u,
    advect_time_step=advect_time_step,
    time_step=time_step,
    sched=sched,
    apx=str(seed) + "ncl_periods",
    mlp=mlp,
)
