import numpy as np

from flax import linen as nn
import jax.numpy as jnp
import jax
from jax import jacfwd, grad, vmap, jacrev
import pickle
from utils import div
import jax.scipy as jsp
from utils import sin_thres,tan_thres,periodic
from jax.scipy.special import lpmn_values,sph_harm
from jax import config
from sampling import PlanarSamplerAutoEncoder
config.update("jax_debug_nans", True)

def div(F):
    """Return a callable computing the divergence of the vector field ``F``.

    The Jacobian of ``F`` is built with forward-mode autodiff and the
    divergence is its trace over the last two axes.

    Note: this definition shadows the ``div`` imported from ``utils``.
    """
    jacobian = jacfwd(F)

    def divergence(point):
        return jnp.trace(jacobian(point), axis1=-2, axis2=-1)

    return divergence



#analog of curl by taking norm of Df - Df^T
def curl(F,x):
    b = jacrev(F)
    B = b(x)
    C = jnp.array([B[2,1]-B[1,2],B[0,2]-B[2,0],B[1,0]-B[0,1]])
    return C

def u_f(u):
    """Return the callable ``x -> cross(u(x), x)``.

    Previously the lambda was built but never returned, so the function
    always returned ``None``.
    """
    B = lambda x: jnp.cross(u(x), x)
    return B

class NCLImplicit(object):
    """Turn a network's first output into a velocity field built from its gradient.

    ``network(x, params)`` must return a 1-D array. Its first entry is used
    as a scalar potential ``psi``. All remaining entries are passed through
    unchanged.
    """

    def __init__(self,network):
        # network: callable (x, params) -> 1-D array; entry 0 is the potential.
        self.network = network
        # Fixed unit axis used both to build the velocity and to project the curl.
        self.n = jnp.array([0,0,1])
        self.n = self.n/jnp.linalg.norm(self.n)

    def __call__(self,x,params):
        """Evaluate the model at ``x``.

        Returns the 1-D array ``[v_0, v_1, v_2, w, *rest]`` where:

        * ``v = grad(psi)(x)[:3] x n``. Because ``n`` is constant, this is
          divergence-free in ``x[:3]``.
        * ``w = n . curl(v)(x)``.
        * ``rest`` is ``network(x, params)[1:]``.
        """
        def potential_and_rest(y):
            out = self.network(y, params)
            # Differentiate only the potential; return the other outputs as aux
            # so a single network pass yields both.
            return out[:1], out[1:]

        grad_and_rest = jacrev(potential_and_rest, has_aux=True)

        def velocity(y):
            jac, _ = grad_and_rest(y)
            # jac has shape (1, len(y)); only the first three coordinates are used.
            return jnp.cross(jac[0, :3], self.n)

        jac, rest = grad_and_rest(x)
        v = jnp.cross(jac[0, :3], self.n)
        w = jnp.sum(self.n * curl(velocity, x))
        return jnp.concatenate([v, w[None], rest])



from models import MLP,Siren,VAE
import numpy as np
import jax.random as random
import jax
import jax.numpy as jnp
from pde import PDEDivForm
from losses import Loss,Sphere_Loss
from jax_sphere_experiment_setup import runBallExperiment
import optax
import flax
import gzip
import struct
import array
import os
from jax.scipy.ndimage import map_coordinates

def mnist_raw(data_dir='./alphabet_flow/data/EMNIST/raw/gzip'):
    """Parse the EMNIST "letters" training split from local gzipped IDX files.

    Nothing is downloaded: both files must already exist in ``data_dir``.

    Args:
        data_dir: directory containing
            ``emnist-letters-train-images-idx3-ubyte.gz`` and
            ``emnist-letters-train-labels-idx1-ubyte.gz``.

    Returns:
        ``(train_images, train_labels)``: a ``uint8`` array of shape
        ``(num_data, rows, cols)`` and a 1-D ``uint8`` label array.

    Raises:
        ValueError: if a file does not start with the expected IDX magic number.
    """
    # IDX magic numbers for unsigned-byte data: 1-D (labels) and 3-D (images).
    labels_magic = 2049
    images_magic = 2051

    def check_magic(filename, magic, expected):
        if magic != expected:
            raise ValueError(
                f"{filename}: unexpected IDX magic number {magic} (expected {expected})"
            )

    def parse_labels(filename):
        with gzip.open(filename, "rb") as fh:
            magic, _ = struct.unpack(">II", fh.read(8))
            check_magic(filename, magic, labels_magic)
            return np.array(array.array("B", fh.read()), dtype=np.uint8)

    def parse_images(filename):
        with gzip.open(filename, "rb") as fh:
            magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
            check_magic(filename, magic, images_magic)
            return np.array(array.array("B", fh.read()),
                            dtype=np.uint8).reshape(num_data, rows, cols)

    train_images = parse_images(os.path.join(data_dir, "emnist-letters-train-images-idx3-ubyte.gz"))
    train_labels = parse_labels(os.path.join(data_dir, "emnist-letters-train-labels-idx1-ubyte.gz"))

    return train_images, train_labels

def init_image(label,num,num_classes=4,dataset=None):
    """Select the training images with EMNIST label ``label`` and tag them as class ``num``.

    Args:
        label: raw EMNIST-letters label to keep (I:9 C:3 M:13 L:12).
        num: class index used for the one-hot tag.
        num_classes: length of each one-hot vector (default 4).
        dataset: optional pre-loaded ``(images, labels)`` pair, e.g. the
            result of ``mnist_raw()``. When omitted, ``mnist_raw()`` is called,
            which re-parses the data files on every call.

    Returns:
        ``(images, one_hot_labels)``: the matching images, and an array of
        shape ``(len(images), num_classes)`` whose every row is ``one_hot(num)``.
    """
    train_images, train_labels = mnist_raw() if dataset is None else dataset
    selected = train_images[train_labels==label]
    one_hot_labels = jax.nn.one_hot(jnp.array([num]), num_classes=num_classes)
    one_hot_labels = jnp.tile(one_hot_labels,(selected.shape[0],1))
    return selected,one_hot_labels


# Runs the Ball experiment with a NCL model
# Make flax init return a FrozenDict, so `.unfreeze()` below is valid.
flax.config.update('flax_return_frozendict', True)
#define hyperparams for u,rho,p
beta = 8
#act = lambda x: jax.nn.softplus(x*beta)/beta
act = lambda x: jax.nn.softplus(x*beta)**2/(beta*beta*2)
#
# act = jnp.sin
layers = 4
width = 128
advect_time_step = 300
time_step = 0.05
# dtype=np.int64: the default integer dtype is 32-bit on some platforms
# (e.g. Windows with NumPy < 2), where an exclusive bound of 2**32 raises.
seed = np.random.randint(2**32, dtype=np.int64)
# seed = 427443453
# seed = 2931118338
key =  random.PRNGKey(seed)
print("Random initial seed:", seed)
# Dummy input used only to initialise the network; length 3+4+28*28+4.
x = random.normal(key,shape=(3+4+28*28+4,))
#mlp = MLP(depth=layers,width=width,act=act,out_dim=1,std=0.01,bias=True) # 0.01
mlp = VAE(num_layers=layers,output_dim=1,w0=30,w0_first_layer=30,use_bias=True)
params = mlp.init(key,x)



# Down-scale every encoder parameter at initialisation.
scale = 8e-2
print(params['params'].keys())
params = params.unfreeze()['params']
# jax.tree_util.tree_map replaces the deprecated (and later removed) jax.tree_map.
params['encoder'] = jax.tree_util.tree_map(lambda x: x*scale, params['encoder'])
# print(params['decoder'].keys())
# params['decoder']['dense_1'] = jax.tree_util.tree_map(lambda x: x*scale, params['decoder']['dense_1'])
# params['decoder']['dense_2'] = jax.tree_util.tree_map(lambda x: x*scale, params['decoder']['dense_2'])
#print(params,type(params))
# import copy
# params_1 = copy.deepcopy(params)

#func_mlp_ = lambda x,params: mlp.apply({'params':params}, x)

# n = jnp.array([0,0,1])
# n = n/jnp.linalg.norm(n)
# rotation_m_inv = jnp.linalg.inv(jnp.array([[n[1]/jnp.sqrt(n[0]**2+n[1]**2),-n[0]/jnp.sqrt(n[0]**2+n[1]**2),0],
#                         [n[0]*n[2]/jnp.sqrt(n[0]**2+n[1]**2),n[1]*n[2]/jnp.sqrt(n[0]**2+n[1]**2),-jnp.sqrt(n[0]**2+n[1]**2)],
#                         [n[0],n[1],n[2]]
#                         ]).transpose())

def func_mlp(x,params):
    """Apply the module-level ``mlp`` to ``x`` with the parameter tree ``params``."""
    # Optional input rotation (currently disabled):
    # x = x.at[:3].set(rotation_m_inv@x[:3])
    variables = {'params': params}
    return mlp.apply(variables, x)

mlp_k = Siren(num_layers=layers,output_dim=1,w0=30,w0_first_layer=30,use_bias=True)
def func_mlp_generate(x,params):
    """Apply ``mlp_k`` to ``x`` using only the ``'decoder'`` sub-tree of ``params``."""
    # Optional input rotation (currently disabled):
    # x = x.at[:3].set(rotation_m_inv@x[:3])
    decoder_params = params['decoder']
    return mlp_k.apply({'params': decoder_params}, x)

#ncl outputs [v_x, v_y, v_z, w, *remaining network outputs] (see NCLImplicit.__call__)
ncl = NCLImplicit(func_mlp)
print("Sample NCL output:", ncl(x,params))

infer_ncl = NCLImplicit(func_mlp_generate)

#convenience for plotting, only ncl is passed to train/loss module
u = lambda x,params: ncl(x,params)
u_inference = lambda x,params: infer_ncl(x,params)

pde = PDEDivForm(spatial_m=True,time_step=time_step)
pde.setNormal(lambda y: y[1:])

loss = Sphere_Loss(ncl)
loss.addTermDom(pde.advect_loss,'advect')
loss.addTermInit(pde.init_w,'init')


gamma = {
    'advect':6e-1,
    'incp':1e-1,
    'cycle_l':3e-2,
    'cycle_r':3e-2,
    's_n':0e-2,
    's_s':0e-2,
    'vel_n':0e-2,
    'vel_s':0e-2,
    'init':5e1
}
loss.setGamma(gamma)

sched = optax.exponential_decay(init_value = 1e-4,transition_steps=400000,decay_rate=1e-1)

# sched = optax.piecewise_constant_schedule(init_value=1e-4,
#                         boundaries_and_scales={20000:1e-1,
#                                                40000:1e-2}
#                         )

# EMNIST letter labels used as the classes: I:9 C:3 M:13 L:12.
# Class index = position in l_list. init_image one-hot encodes with
# num_classes=4 by default, which must stay equal to len(l_list).
l_list = [9,3,13,12]
training_images = []
training_labels = []
for class_idx, letter_label in enumerate(l_list):
    a_o,b_o = init_image(letter_label,class_idx)
    training_images.append(a_o)
    training_labels.append(b_o)
training_images = jnp.concatenate(training_images,axis=0)
training_labels = jnp.concatenate(training_labels,axis=0)


# Shuffle images and labels with the same permutation, then flatten each image.
key_a = random.split(key, num=2)
indices = jax.random.permutation(key_a[0], training_images.shape[0])
training_images = training_images[indices].reshape(training_images.shape[0],-1)
training_labels = training_labels[indices]

smp = PlanarSamplerAutoEncoder(False,training_images,training_labels,T=0.5,N=1000)

print(key_a)
runBallExperiment(params=params, 
                  key=key_a[1],
                  pde=pde,
                  loss=loss,
                  pinn=u_inference,
                  advect_time_step=advect_time_step,
                  time_step=time_step,
                  sched=sched,
                  apx=str(seed)+"ncl_periods",
                  mlp = mlp,
                  smp = smp,
                    )
