"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

from flax import linen as nn
import jax.numpy as jnp
import jax
from jax import jacfwd, grad, vmap
import pickle
from utils import div
import jax.scipy as jsp

class MLP(nn.Module):
    """Plain fully connected network.

    `depth` hidden Dense layers of size `width` are followed by a Dense
    read-out of size `out_dim`. `act` is applied after every hidden layer but
    not after the read-out.
    """

    depth: int
    width: int
    out_dim: int
    std: float  # scaling factor for initialization (NOTE(review): not read anywhere in this module)
    act: callable
    bias: bool

    def setup(self):
        hidden = [nn.Dense(self.width, use_bias=self.bias) for _ in range(self.depth)]
        readout = nn.Dense(self.out_dim, use_bias=self.bias)
        # A single list attribute keeps submodule names layers_0 ... layers_{depth}.
        self.layers = hidden + [readout]

    def __call__(self, inputs):
        *hidden, readout = self.layers
        h = inputs
        for dense in hidden:
            h = self.act(dense(h))
        # No activation on the read-out layer.
        return readout(h)
    
class MLP_Skip(nn.Module):
    """MLP in which every layer also receives a Dense projection of the raw input.

    Layer k computes ``layers[k](h) + skip[k](inputs)``. `act` follows every
    layer except the last. Both stacks hold `depth` Dense layers of size
    `width` followed by one Dense layer of size `out_dim`.
    """

    depth: int
    width: int
    out_dim: int
    std: float  # scaling factor for initialization (NOTE(review): not read anywhere in this module)
    act: callable
    bias: bool

    def setup(self):
        self.layers = [nn.Dense(self.width, use_bias=self.bias) for _ in range(self.depth)] + [
            nn.Dense(self.out_dim, use_bias=self.bias)
        ]
        self.skip = [nn.Dense(self.width, use_bias=self.bias) for _ in range(self.depth)] + [
            nn.Dense(self.out_dim, use_bias=self.bias)
        ]

    def __call__(self, inputs):
        n_layers = len(self.layers)
        h = inputs
        for k, (main, shortcut) in enumerate(zip(self.layers, self.skip)):
            # The shortcut always projects the ORIGINAL input, not the running state.
            h = main(h) + shortcut(inputs)
            if k < n_layers - 1:
                h = self.act(h)
        return h
    

# class that parameterizes the NCL using the (dense) matrix form
class NCL(object):
    """Parameterize the NCL through an antisymmetric matrix field (dense form).

    All network outputs except the last fill the strictly upper triangle of an
    N x N matrix A(x), where N = len(x). The network is therefore expected to
    emit N*(N-1)/2 + 1 values. The last output is used as the pressure p.
    """

    def __init__(self, network, mass_constant=2):
        # network: callable invoked as network(x, params[0]).
        self.network = network
        # Constant added to the first (density) component of the result.
        self.mc = mass_constant

    # return type of NCL is [rho, rho u, p] (note middle!)
    def __call__(self, x, params):
        """Evaluate the parameterized fields at the point `x`.

        Args:
            x: input point of length N.
            params: ``(network_params, offset)``. ``offset`` is unpacked into the
                constant vector ``[mass_constant, *offset, 0]`` that is added to
                the output, so it needs N - 1 entries for the shapes to match.

        Returns:
            ``[*div(A)(x), p] + [mass_constant, *offset, 0]``, where A(x) is the
            antisymmetrized matrix built below. ``div`` comes from ``utils`` and
            is presumably the row-wise divergence of a matrix field (TODO confirm).
        """

        def A(x):
            # Every output except the last fills the strictly upper triangle.
            # Subtracting the transpose makes the matrix antisymmetric.
            u_v = self.network(x, params[0])[:-1]
            N = len(x)
            A = jnp.zeros((N, N))
            idx = jnp.triu_indices(N, 1)
            A = A.at[idx].set(u_v)
            return A - A.T

        # The pressure is the one output A does not consume, i.e. the last one.
        # Indexing with -1 instead of a fixed position keeps this correct for any N.
        p = self.network(x, params[0])[-1]
        return jnp.array([*div(A)(x), p]) + jnp.array([self.mc, *params[1], 0])

# class that parameterizes the NCL using a sparse (cyclic superdiagonal) matrix form
class NCL_sparse(object):
    """Parameterize the NCL through a sparse antisymmetric matrix field.

    The first N network outputs u (N = len(x)) fill
    ``A[i, i+1] = u[i] * x[i+1]`` for i < N - 1, plus the corner entry
    ``A[0, N-1] = x[0] * u[N-1]``. The corner is written after the diagonal is
    placed. The network is therefore expected to emit N + 1 values; the last
    output is used as the pressure p.
    """

    def __init__(self, network, mass_constant=2):
        # network: callable invoked as network(x, params[0]).
        self.network = network
        # Constant added to the first (density) component of the result.
        self.mc = mass_constant

    # return type of NCL is [rho, rho u, p] (note middle!)
    def __call__(self, x, params):
        """Evaluate the parameterized fields at the point `x`.

        Args:
            x: input point of length N.
            params: ``(network_params, offset)``. ``offset`` is unpacked into the
                constant vector ``[mass_constant, *offset, 0]`` that is added to
                the output, so it needs N - 1 entries for the shapes to match.

        Returns:
            ``[*div(A)(x), p] + [mass_constant, *offset, 0]``. ``div`` comes from
            ``utils`` and is presumably the row-wise divergence of a matrix field
            (TODO confirm).
        """

        def A(x):
            u_v = self.network(x, params[0])[:-1]
            A = jnp.diag((u_v * jnp.roll(x, -1))[:-1], k=1)
            A = A.at[0, -1].set(x[0] * u_v[-1])
            # Antisymmetrize.
            return A - A.T

        # The pressure is the one output A does not consume, i.e. the last one.
        # Indexing with -1 instead of a fixed position keeps this correct for any N.
        p = self.network(x, params[0])[-1]
        return jnp.array([*div(A)(x), p]) + jnp.array([self.mc, *params[1], 0])
        
# class that parameterizes the NCL using the vector form
class NCLImplicit(object):
    """Parameterize the NCL through the antisymmetrized Jacobian of a vector field.

    With u(x) = network(x, params)[:-1], the matrix field is ``J - J^T`` where
    J is the forward-mode Jacobian of u. The last network output is appended
    as the pressure. No constant offset is added.
    """

    def __init__(self, network):
        self.network = network

    # return type of NCL is [rho, rho u, p] (note middle!)
    def __call__(self, x, params):
        """Return ``[*div(J - J^T)(x), network(x, params)[-1]]``.

        ``div`` comes from ``utils`` and is presumably the row-wise divergence
        of a matrix field (TODO confirm).
        """

        def antisym_jacobian(z):
            potential = lambda y: self.network(y, params)[:-1]
            J = jacfwd(potential)(z)
            return J - J.T

        pressure = self.network(x, params)[-1]
        return jnp.array([*div(antisym_jacobian)(x), pressure])
    
# utility bump function
def bump(x):
    """Smooth, compactly supported bump ``exp(-1 / (1 - |x|^2 / c))`` with c = 5 * dim.

    Returns 0 whenever ``|x|^2 >= c - 1e-3``.

    Uses the "double where" pattern rather than ``lax.cond``. Under ``vmap``
    with a batched predicate, ``cond`` becomes ``select`` and evaluates both
    branches. Outside the support ``1 - |x|^2 / c <= 0`` makes the exponential
    blow up, and the resulting inf * 0 turns gradients into NaN. Feeding a
    safe value to the exponential outside the support keeps gradients finite
    (zero there) without changing forward values.
    """
    dim = x.shape[0]
    const = dim * 5
    r2 = jnp.dot(x, x)
    inside = r2 < const - 1e-3
    safe_r2 = jnp.where(inside, r2, 0.0)
    value = jnp.exp(-1 / (1 - safe_r2 / const))
    return jnp.where(inside, value, 0.0)
    
# class that parameterizes arbitrary div_free net (dense antisymmetric form)
class DivFree(object):
    """Divergence of an antisymmetric matrix field built from a network.

    The network output fills the strictly upper triangle of an N x N matrix
    (N = len(x)), so it is expected to have N*(N-1)/2 entries. The field is
    that matrix minus its transpose.
    """

    def __init__(self, network):
        self.network = network

    def __call__(self, x, params):
        """Return ``div(M)(x)`` for the antisymmetric field M described above.

        ``div`` comes from ``utils`` and is presumably the row-wise divergence
        of a matrix field (TODO confirm).
        """

        def antisym(z):
            n = len(z)
            entries = self.network(z, params)
            rows, cols = jnp.triu_indices(n, 1)
            upper = jnp.zeros((n, n)).at[rows, cols].set(entries)
            return upper - upper.T

        return div(antisym)(x)
# class that parameterizes arbitrary div_free net (sparse cyclic superdiagonal form)
class DivFreeSparse(object):
    """Divergence of a sparse antisymmetric matrix field built from a network.

    With w = network(x, params) (expected length N = len(x)), the upper part is
    ``U[i, i+1] = w[i] * x[i+1]`` for i < N - 1, plus the corner
    ``U[0, N-1] = x[0] * w[N-1]``. The corner is written after the diagonal.
    The field is ``U - U^T``.
    """

    def __init__(self, network):
        self.network = network

    def __call__(self, x, params):
        """Return ``div(U - U^T)(x)``.

        ``div`` comes from ``utils`` and is presumably the row-wise divergence
        of a matrix field (TODO confirm).
        """

        def antisym(z):
            w = self.network(z, params)
            shifted = jnp.roll(z, -1)
            upper = jnp.diag((w * shifted)[:-1], k=1)
            upper = upper.at[0, -1].set(z[0] * w[-1])
            return upper - upper.T

        return div(antisym)(x)
        
        
class DivFreeImplicit(object):
    """Divergence of the antisymmetrized Jacobian of a network vector field.

    The field is ``J - J^T`` with J the forward-mode Jacobian of
    ``network(., params)``. The network output is expected to have the same
    length as its input so that J is square.
    """

    def __init__(self, network):
        self.network = network

    def __call__(self, x, params):
        """Return ``div(J - J^T)(x)``.

        ``div`` comes from ``utils`` and is presumably the row-wise divergence
        of a matrix field (TODO confirm).
        """

        def antisym_jacobian(z):
            J = jacfwd(lambda y: self.network(y, params))(z)
            return J - J.T

        return div(antisym_jacobian)(x)
    
# class for the mixture of gaussians score
class MixtureScore:
    """Score (gradient of the log density) of an equal-weight Gaussian mixture.

    Component means move linearly from `mu0` (t = 0) to `mu1` (t = 1). Both
    stack the component means along axis 0, one row per component.
    """

    def __init__(self, mu0, mu1):
        self.mu0 = mu0
        self.mu1 = mu1
        self.sig0 = 1
        self.sig1 = 1e-2

    def sigma(self, t):
        """Blend sig0 (t = 0) and sig1 (t = 1) with weights (1 - t)^2 and t^2."""
        return self.sig0 * (1 - t) ** 2 + self.sig1 * t ** 2

    def __call__(self, x):
        """Return the score at ``x = [t, *point]``.

        The result is the gradient with respect to ``point`` and has the same
        shape as ``point``.
        """
        t = x[0]
        x = x[1:]
        mu_total = self.mu0 * (1 - t) + self.mu1 * t
        # NOTE(review): sqrt(sigma(t)) is passed as the isotropic covariance, i.e.
        # as a variance. Confirm whether sig0/sig1 are meant as variances or stds.
        cov = jnp.sqrt(self.sigma(t))

        def log_pdf(y):
            # Per-component log densities, one entry per row of mu_total.
            comps = vmap(
                lambda mu: jsp.stats.multivariate_normal.logpdf(y, mean=mu, cov=cov)
            )(mu_total)
            # Mixture: log sum_k N_k. Summing the log-pdfs would instead give the
            # log of a PRODUCT of Gaussians. Equal weights contribute a constant
            # -log K, which does not affect the gradient.
            return jsp.special.logsumexp(comps)

        return grad(log_pdf)(x)


from flax import linen as nn

import jax.numpy as jnp
from jax.nn.initializers import uniform as uniform_init
from jax import lax
from jax.random import uniform
from typing import Any, Callable, Sequence, Tuple
from functools import partial
import jax

def siren_init(weight_std, dtype):
    """Build an initializer drawing entries uniformly from [-weight_std, weight_std).

    Args:
        weight_std: half-width of the uniform interval.
        dtype: default dtype of the generated array. For complex dtypes the
            real and imaginary parts are drawn independently from the same
            interval, using two keys split from `key`.

    Returns:
        ``init_fun(key, shape, dtype=dtype)``, usable as a Flax param initializer.
    """

    def init_fun(key, shape, dtype=dtype):
        # Any complex dtype (complex64 or complex128), not only the default-precision one.
        if jnp.issubdtype(dtype, jnp.complexfloating):
            key1, key2 = jax.random.split(key)
            # Real counterpart of the complex dtype (e.g. complex64 -> float32).
            real_dtype = jnp.zeros((), dtype).real.dtype
            a = uniform(key1, shape, real_dtype) * 2 * weight_std - weight_std
            b = uniform(key2, shape, real_dtype) * 2 * weight_std - weight_std
            return a + 1j * b
        return uniform(key, shape, dtype) * 2 * weight_std - weight_std

    return init_fun


class Sine(nn.Module):
    """Elementwise activation ``sin(w0 * x)``, computed after casting x to `dtype`."""

    w0: float = 1.0
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs):
        scaled = self.w0 * jnp.asarray(inputs, self.dtype)
        return jnp.sin(scaled)


class SirenLayer(nn.Module):
    """Linear projection followed by ``act(w0 * (x @ W + b))``, with SIREN init.

    Kernel entries are drawn from U(-s, s). For the first layer s = 1 / fan_in;
    otherwise s = sqrt(c / fan_in) / w0. The optional bias starts at zero.
    """

    features: int = 32
    w0: float = 1.0
    c: float = 6.0
    is_first: bool = False
    use_bias: bool = True
    act: Callable = jnp.sin
    precision: Any = None
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs):
        x = jnp.asarray(inputs, self.dtype)
        fan_in = x.shape[-1]

        # Linear projection with the initialization proposed in the SIREN paper.
        if self.is_first:
            weight_std = 1 / fan_in
        else:
            weight_std = jnp.sqrt(self.c / fan_in) / self.w0

        W = jnp.asarray(
            self.param("kernel", siren_init(weight_std, self.dtype), (fan_in, self.features)),
            self.dtype,
        )
        # Contract the last axis of x with the first axis of W; no batch dims.
        contract_last_with_first = (((x.ndim - 1,), (0,)), ((), ()))
        pre = lax.dot_general(x, W, contract_last_with_first, precision=self.precision)

        if self.use_bias:
            b = self.param("bias", nn.initializers.zeros_init(), (self.features,))
            pre = pre + jnp.asarray(b, self.dtype)

        return self.act(self.w0 * pre)
class Siren(nn.Module):
    """Stack of SirenLayer blocks (SIREN network).

    `num_layers - 1` sine layers of width `hidden_dim` come first. The first of
    them uses `w0_first_layer` and the first-layer initialization. A read-out
    SirenLayer of width `output_dim` follows, whose activation is
    `final_activation`. All layers compute in `dtype`.
    """

    hidden_dim: int = 256
    output_dim: int = 3
    num_layers: int = 5
    w0: float = 1.0
    w0_first_layer: float = 1.0
    use_bias: bool = True
    final_activation: Callable = lambda x: x  # Identity
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs):
        x = jnp.asarray(inputs, self.dtype)

        for layernum in range(self.num_layers - 1):
            is_first = layernum == 0

            x = SirenLayer(
                features=self.hidden_dim,
                w0=self.w0_first_layer if is_first else self.w0,
                is_first=is_first,
                use_bias=self.use_bias,
                # Propagate the network dtype. Otherwise each layer recasts to its
                # own float32 default, downcasting float64 inputs and dropping
                # the imaginary part of complex ones.
                dtype=self.dtype,
            )(x)

        # Last layer, with a different activation function.
        x = SirenLayer(
            features=self.output_dim,
            w0=self.w0,
            is_first=False,
            use_bias=self.use_bias,
            act=self.final_activation,
            dtype=self.dtype,
        )(x)

        return x


        



# class Siren(MLP):
#     layers: list[eqx.Module]
#     input_scale: float

#     def __init__(self,
#                  in_features: int,
#                  hidden_features: int,
#                  hidden_layers: int,
#                  out_features: int,
#                  key: jax.random.PRNGKey,
#                  first_omega_0: float = 30,
#                  hidden_omega_0: float = 30,
#                  input_scale: float = 1,
#                  **kwargs):
#         keys = jax.random.split(key, hidden_layers + 2)
#         self.input_scale = input_scale

#         # Section 3.2
#         # For [-1, 1], first_omega_0 span it to [-30, 30]
#         # Here to scale periods back
#         first_omega_0 = first_omega_0 / input_scale

#         self.layers = [
#             SineLayer(in_features,
#                       hidden_features,
#                       keys[0],
#                       is_first=True,
#                       omega_0=first_omega_0)
#         ] + [
#             SineLayer(hidden_features,
#                       hidden_features,
#                       keys[i + 1],
#                       omega_0=hidden_omega_0) for i in range(hidden_layers)
#         ] + [Linear(hidden_features, out_features, keys[-1], False)]

#     def single_call(self, x, z):
#         x = jnp.hstack([self.input_scale * x, z])
#         for i in range(len(self.layers)):
#             x = self.layers[i](x)
#         return x
