# Basic Library Imports
import jax
import jax.numpy as jnp
from jax import random
from jax import vmap, jit

from flax import linen as nn

from typing import Any, Callable, Sequence, Tuple, Union


PrecisionLike = Union[None, str, jax.lax.Precision, Tuple[str, str],
                      Tuple[jax.lax.Precision, jax.lax.Precision]]


def identity(x):
    """Return ``x`` unchanged; used as the default (no-op) output activation."""
    return x


######################################################
#################### Initializers ####################
######################################################

# Siren Initialization
def siren_initializer(key, shape, dtype=jnp.float32):
    """SIREN hidden-layer kernel initializer.

    Draws entries uniformly from ``[-sqrt(6 / fan_in), sqrt(6 / fan_in)]``, where
    ``fan_in = shape[0]`` (the input dimension of a flax ``Dense`` kernel of shape (in, out)).
    """
    fan_in = shape[0]
    bound = jnp.sqrt(6. / fan_in)
    return random.uniform(key, shape=shape, dtype=dtype, minval=-bound, maxval=bound)

def siren_first_layer_initializer(key, shape, dtype=jnp.float32):
    """SIREN first-layer kernel initializer.

    Draws entries uniformly from ``[-1 / fan_in, 1 / fan_in]`` with ``fan_in = shape[0]``.
    ``dtype`` defaults to float32, matching ``siren_initializer``, so the function can also
    be called directly as ``init(key, shape)``.
    """
    bound = 1 / shape[0]
    return random.uniform(key, shape, minval=-bound, maxval=bound, dtype=dtype)

# Custom Initialization
def kan_initializer(key, shape, dtype=jnp.float32, sigma_0=0.1):
    """Gaussian kernel initializer with standard deviation ``sigma_0 / sqrt(shape[0])``."""
    std = sigma_0/jnp.sqrt(shape[0])
    samples = random.normal(key, shape=shape, dtype=dtype)
    return std*samples

def get_kan_initializer(sigma=0.1):
    """Return a flax-style initializer ``init(key, shape, dtype=jnp.float32)``.

    The returned function forwards to ``kan_initializer`` with ``sigma_0=sigma``.
    """
    def _init(key, shape, dtype=jnp.float32):
        return kan_initializer(key, shape, dtype=dtype, sigma_0=sigma)
    return _init


###############################################################
######################## Architectures ########################
###############################################################

#############
#### MLP ####
#############

class MLP(nn.Module):
    """Fully connected network: Dense + activation hidden layers, then a Dense output layer.

    ``features`` lists the width of every Dense layer; the last entry is the output width.
    ``activation`` follows each hidden Dense layer, and ``output_activation`` is applied to
    the final Dense output. ``precision`` is forwarded to every Dense layer.
    """
    features: Sequence[int]
    activation : Callable=nn.gelu
    output_activation : Callable=identity
    precision: PrecisionLike = None

    @nn.compact
    def __call__(self, x):
        h = x
        for width in self.features[:-1]:
            dense = nn.Dense(width, precision=self.precision)
            h = self.activation(dense(h))
        out_layer = nn.Dense(self.features[-1], precision=self.precision)
        return self.output_activation(out_layer(h))
  

###############
#### Siren ####
###############

# See https://arxiv.org/abs/2006.09661 for details about Siren, an MLP with sine activations.
# See below for an interactive Colab notebook provided by the authors:
# https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb

class Siren(nn.Module):
    """SIREN: an MLP with sine activations (https://arxiv.org/abs/2006.09661).

    The input is multiplied by ``w0``. Then, in order:

    - a sine layer of width ``features[0]`` (kernel init ``siren_first_layer_initializer``);
    - sine layers of widths ``features[1:-1]`` (kernel init ``siren_initializer``);
    - a linear Dense layer of width ``features[-1]``, followed by ``output_activation``.

    ``precision`` is forwarded to every Dense layer, including the final one.
    """
    features: Sequence[int]
    w0 : float
    output_activation : Callable=identity
    precision: PrecisionLike = None

    @nn.compact
    def __call__(self, x):
        x = x*self.w0
        x = jnp.sin(nn.Dense(self.features[0], kernel_init=siren_first_layer_initializer, precision=self.precision)(x))
        for feat in self.features[1:-1]:
            x = jnp.sin(nn.Dense(feat, kernel_init=siren_initializer, precision=self.precision)(x))
        # Linear read-out. No kernel_init is passed, so flax's default Dense initializer is used.
        # precision is now forwarded here too (it was previously dropped for this layer only).
        x = nn.Dense(self.features[-1], precision=self.precision)(x)
        return self.output_activation(x)


################
#### ActNet ####
################

# from https://www.wolframalpha.com/input?i=E%5B%28sin%28wx%2Bp%29%29%5D+where+x+is+normally+distributed
def _mean_transf(mu, sigma, w, p):
    """Expected value of ``sin(w * x + p)`` for ``x ~ N(mu, sigma**2)``.

    Closed form: ``exp(-(sigma * w)**2 / 2) * sin(p + mu * w)``; broadcasts elementwise.
    """
    envelope = jnp.exp(-0.5* (sigma*w)**2)
    return envelope * jnp.sin(p + mu*w)

# from https://www.wolframalpha.com/input?i=E%5Bsin%28wx%2Bp%29%5E2%5D+where+x+is+normally+distributed
def _var_transf(mu, sigma, w, p):
    """Variance of ``sin(w * x + p)`` for ``x ~ N(mu, sigma**2)``.

    Computed as ``E[sin^2] - E[sin]^2``, with
    ``E[sin^2] = 1/2 - 1/2 * exp(-2 * (sigma * w)**2) * cos(2 * (p + mu * w))``.
    """
    second_moment = 0.5 - 0.5*jnp.exp(-2 * ((sigma*w)**2))*jnp.cos(2*(p+mu*w))
    return second_moment - _mean_transf(mu, sigma, w, p)**2

class ActLayer(nn.Module):
    """Single ActNet layer: sinusoidal basis expansion followed by a structured linear mix.

    For an input whose last dimension is ``d``, each scalar ``x_i`` is expanded into
    ``num_freqs`` basis values ``basis[..., i, j] = sin(freqs_j * x_i + phases_j)``.
    When ``freq_scaling`` is set, each basis value is standardized using the closed-form
    mean/variance under a standard normal input.

    The output is
        ``out[..., k] = sum_{i, j} lamb[i, k] * beta[j, k] * basis[..., i, j]  (+ bias[k])``.

    Leading dimensions of ``x`` are arbitrary. They may be absent (a single point, e.g. under
    ``jax.vmap``), a single batch axis, or several axes; the output is ``(..., out_dim)``.
    """
    out_dim : int
    num_freqs : int
    use_bias : bool=True
    # parameter initializers
    freqs_init : Callable=nn.initializers.normal(stddev=1.)  # normal entries w/ mean zero
    phases_init : Callable=nn.initializers.zeros
    beta_init : Callable=nn.initializers.variance_scaling(1., 'fan_in', distribution='uniform')
    lamb_init : Callable=nn.initializers.variance_scaling(1., 'fan_in', distribution='uniform')
    bias_init : Callable=nn.initializers.zeros
    # other configurations
    freeze_basis : bool=False
    freq_scaling : bool=True
    freq_scaling_eps : float=1e-3 # used for numerical stability of gradients
    precision: PrecisionLike = None

    @nn.compact
    def __call__(self, x):
        # x has shape (..., d); all leading dims are treated as batch dims
        lead_shape = x.shape[:-1]
        d = x.shape[-1]

        # Trainable parameters. Names, shapes and creation order are unchanged,
        # so existing checkpoints and init RNG streams stay valid.
        freqs = self.param('freqs',
                           self.freqs_init,
                           (1,1,self.num_freqs)) # shape (1, 1, num_freqs)
        phases = self.param('phases',
                            self.phases_init,
                            (1,1,self.num_freqs)) # shape (1, 1, num_freqs)
        beta = self.param('beta',
                          self.beta_init,
                          (self.num_freqs, self.out_dim)) # shape (num_freqs, out_dim)
        lamb = self.param('lamb',
                          self.lamb_init,
                          (d, self.out_dim)) # shape (d, out_dim)

        if self.freeze_basis:
            freqs = jax.lax.stop_gradient(freqs)
            phases = jax.lax.stop_gradient(phases)

        # Basis expansion: (..., d) -> (..., d, 1) -> (..., d, num_freqs).
        # For an unbatched input the (1, 1, num_freqs) parameters broadcast in an extra
        # leading axis of size 1; the reshape below removes it again.
        x = jnp.expand_dims(x, -1)
        x = jnp.sin(freqs*x + phases)
        if self.freq_scaling:
            # standardize each basis function using its mean/variance for inputs ~ N(0, 1)
            x = (x - _mean_transf(0., 1., freqs, phases)) / (jnp.sqrt(self.freq_scaling_eps + _var_transf(0., 1., freqs, phases)))

        # Combine lamb and beta into a single matrix 'aux'.
        # 'aux' encodes out_dim outer products between columns of lamb and columns of beta.
        # This is a batch-efficient way of carrying out the forward pass prescribed by the
        # Kolmogorov representation; otherwise each batch element would implicitly repeat
        # several computations.
        # Equivalent to jnp.einsum('...ij, jk, ik->...k', x, beta, lamb), which was observed
        # to run slower.
        aux = jnp.matmul(lamb.T[...,None], beta.T[:,None,:], precision=self.precision) # shape (out_dim, d, num_freqs)
        aux = aux.reshape((self.out_dim,-1)).T # shape (d*num_freqs, out_dim)
        # flatten with explicit sizes (no -1) so that an empty batch also reshapes cleanly
        x = x.reshape(lead_shape + (d * self.num_freqs,)) # shape (..., d*num_freqs)
        x = jnp.matmul(x, aux, precision=self.precision) # shape (..., out_dim)

        # optionally add bias
        if self.use_bias:
            bias = self.param('bias',
                              self.bias_init,
                              (self.out_dim,))
            x = x + bias # shape (..., out_dim)

        return x # shape (..., out_dim)
    

class ActNet(nn.Module):
    """ActNet: scale by w0, Dense embedding, ``num_layers`` blocks, Dense read-out.

    Each block applies the operations in ``op_order`` in sequence:
    'A' = ActLayer, 'S' = skip connection (adds the block input), 'L' = LayerNorm.
    The output projection uses a He-uniform kernel and is followed by ``output_activation``.

    ``w0_fixed``: ``False`` (the default) learns w0 as ``softplus`` of a trainable scalar
    parameter ``'w0'``. Any other value is used directly as a fixed w0.
    """
    embed_dim : int
    num_layers : int
    out_dim : int
    num_freqs : int
    output_activation : Callable = identity
    op_order : str='A' # string containing only 'A' (ActLayer), 'S' (Skip connection) or 'L' (LayerNorm) characters
    # op_order was used for development/debugging, but is not used in any experiment

    # parameter initializers
    freqs_init : Callable=nn.initializers.normal(stddev=1.)  # normal entries w/ mean zero
    phases_init : Callable=nn.initializers.zeros
    beta_init : Callable=nn.initializers.variance_scaling(1., 'fan_in', distribution='uniform')
    lamb_init : Callable=nn.initializers.variance_scaling(1., 'fan_in', distribution='uniform')
    act_bias_init : Callable=nn.initializers.zeros
    proj_bias_init : Callable=lambda key, shape, dtype : random.uniform(key, shape, dtype, minval=-jnp.sqrt(3), maxval=jnp.sqrt(3))

    w0_init : Callable=nn.initializers.constant(30.) # following SIREN strategy
    # False selects a trainable w0; any float is used as a fixed w0.
    # (Union[False, float] is not a valid annotation and fails at import on Python <= 3.10.)
    w0_fixed : Union[bool, float]=False

    # other ActLayer configurations
    use_act_bias : bool=True
    freeze_basis : bool=False
    freq_scaling : bool=True
    freq_scaling_eps : float=1e-3 # used for numerical stability of gradients
    precision: PrecisionLike = None

    @nn.compact
    def __call__(self, x):
        # input frequency scale w0
        if self.w0_fixed is False:
            # Trainable scalar, passed through softplus to keep it positive.
            # With the default constant init of 30, softplus(30) ~= 30.
            w0 = self.param('w0',
                            self.w0_init,
                            ())
            w0 = nn.softplus(w0)
        else: # use user-specified value for w0
            w0 = self.w0_fixed
        # scale inputs and project to the embedding dimension
        x = x*w0
        x = nn.Dense(self.embed_dim, bias_init=self.proj_bias_init, precision=self.precision)(x)

        for _ in range(self.num_layers):
            y = x # keep the block input in x; operations act on y
            for char in self.op_order:
                if char == 'A': # ActLayer
                    y = ActLayer(
                            out_dim = self.embed_dim,
                            num_freqs = self.num_freqs,
                            use_bias = self.use_act_bias,
                            freqs_init = self.freqs_init,
                            phases_init = self.phases_init,
                            beta_init = self.beta_init,
                            lamb_init = self.lamb_init,
                            bias_init = self.act_bias_init,
                            freeze_basis = self.freeze_basis,
                            freq_scaling = self.freq_scaling,
                            freq_scaling_eps = self.freq_scaling_eps,
                            precision=self.precision,
                            )(y)
                elif char == 'S': # Skip connection
                    y = y + x
                elif char == 'L': # LayerNorm
                    y = nn.LayerNorm()(y)
                else:
                    raise NotImplementedError(f"Could not recognize option '{char}'. Options for op_order should be 'A' (ActLayer), 'S' (Skip connection) or 'L' (LayerNorm).")
            x = y # the block output becomes the next block's input

        # project to output dimension and apply the output activation
        x = nn.Dense(self.out_dim, kernel_init=nn.initializers.he_uniform(), precision=self.precision)(x)
        x = self.output_activation(x)

        return x


##############
#### KAN #####
##############

# adapted to JAX from the "EfficientKAN" GitHub repository (PyTorch)
# https://github.com/Blealtan/efficient-kan/blob/master/src/efficient_kan/kan.py

class KANLinear(nn.Module):
    """Kolmogorov-Arnold layer: a base-activation branch plus a learnable B-spline branch.

    For ``x`` of shape ``(..., in_features)`` the output has shape ``(..., out_features)``:
        ``base_activation(x) @ base_weight.T + flat(B(x)) @ flat(scaled_spline_weight).T``
    Here ``B(x)`` holds the B-spline basis values of each input feature on a fixed
    uniform knot grid.

    NOTE(review): ``scale_noise``, ``scale_base``, ``scale_spline`` and ``grid_eps`` are not
    read anywhere in this class. Presumably they are kept for parity with EfficientKAN's
    constructor; confirm before relying on them.
    """
    in_features : int
    out_features : int
    grid_size : int=5
    spline_order: int=3
    scale_noise : float=0.1
    scale_base : float=1.0
    scale_spline : float=1.0
    enable_standalone_scale_spline : bool=True
    base_activation : Callable=nn.silu
    grid_eps : float=0.02
    grid_range : Sequence[Union[float, int]]=(-1,1)
    precision: PrecisionLike = None

    def setup(self):
        # uniform knot spacing over grid_range
        h = (self.grid_range[1] - self.grid_range[0]) / self.grid_size
        self.h = h
        # Knots are extended by spline_order points beyond each end of grid_range.
        # There is one copy per input feature: shape (in_features, grid_size + 2 * spline_order + 1).
        grid = (
            jnp.arange(start=-self.spline_order, stop=self.grid_size + self.spline_order + 1) * h
            + self.grid_range[0]
        )
        self.grid = grid * jnp.ones((self.in_features, 1))

        # NOTE(review): flax's he_uniform defaults to in_axis=-2 / out_axis=-1, but these weights
        # use an (out, in[, basis]) layout. So the fan is not computed from in_features alone.
        # This is kept unchanged to preserve the existing initialization; confirm it is intended.
        self.base_weight = self.param('base_weight', # parameter name
                                      nn.initializers.he_uniform(), # initialization function
                                      (self.out_features, self.in_features)) # shape info
        self.spline_weight = self.param('spline_weight', # parameter name
                                        nn.initializers.he_uniform(), # initialization function
                                        (self.out_features, self.in_features, self.grid_size+self.spline_order)) # shape info

        if self.enable_standalone_scale_spline:
            self.spline_scaler = self.param('spline_scaler', # parameter name
                                            nn.initializers.he_uniform(), # initialization function
                                            (self.out_features, self.in_features)) # shape info

    def b_splines(self, x: jax.Array):
        """
        Compute the B-spline bases for the given input array (Cox-de Boor recursion).

        Args:
            x (jax.Array): Input array of shape (batch_size, in_features).

        Returns:
            jax.Array: B-spline bases of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert len(x.shape) == 2 and x.shape[1] == self.in_features

        # grid is shape (in_features, grid_size + 2 * spline_order + 1)
        grid = self.grid
        x = jnp.expand_dims(x, -1)
        # order-0 bases: indicator of the knot interval containing x
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:]))
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.shape == (
            x.shape[0],
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases

    @property
    def scaled_spline_weight(self):
        """Spline weights, multiplied per (out, in) pair by ``spline_scaler`` when enabled."""
        return self.spline_weight * (
            jnp.expand_dims(self.spline_scaler, -1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def __call__(self, x: jax.Array):
        """Apply the layer to ``x`` of shape ``(..., in_features)``; returns ``(..., out_features)``."""
        assert x.shape[-1] == self.in_features, f"x.shape[-1]={x.shape[-1]} should be equal to {self.in_features}"
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)
        n_basis = self.grid_size + self.spline_order

        base_output = jnp.matmul(self.base_activation(x), self.base_weight.T, precision=self.precision)
        # flatten with explicit sizes (no -1) so that an empty batch also reshapes cleanly
        spline_output = jnp.matmul(
            self.b_splines(x).reshape(x.shape[0], self.in_features * n_basis),
            self.scaled_spline_weight.reshape(self.out_features, self.in_features * n_basis).T,
            precision=self.precision,
        )
        output = base_output + spline_output

        output = output.reshape(*original_shape[:-1], self.out_features)
        return output
    

class KAN(nn.Module):
    """Stack of ``KANLinear`` layers whose widths are given by ``features``.

    ``features = (d_in, h_1, ..., d_out)`` creates ``len(features) - 1`` layers. All spline,
    grid and precision options are forwarded unchanged to every layer, and
    ``output_activation`` is applied to the output of the last layer.
    """
    features : Sequence[int]
    output_activation : Callable=identity
    grid_size : int=5
    spline_order: int=3
    scale_noise : float=0.1
    scale_base : float=1.0
    scale_spline : float=1.0
    enable_standalone_scale_spline : bool=True
    base_activation : Callable=nn.silu
    grid_eps : float=0.02
    grid_range : Sequence[Union[float, int]]=(-1,1)
    precision: PrecisionLike = None

    def setup(self):
        shared = dict(
            grid_size=self.grid_size,
            spline_order=self.spline_order,
            scale_noise=self.scale_noise,
            scale_base=self.scale_base,
            scale_spline=self.scale_spline,
            enable_standalone_scale_spline=self.enable_standalone_scale_spline,
            base_activation=self.base_activation,
            grid_eps=self.grid_eps,
            grid_range=self.grid_range,
            precision=self.precision,
        )
        self.layers = [
            KANLinear(n_in, n_out, **shared)
            for n_in, n_out in zip(self.features[:-1], self.features[1:])
        ]

    def __call__(self, x):
        h = x
        for layer in self.layers:
            h = layer(h)
        return self.output_activation(h)
    


############################################################
################### Architecture Builder ###################
############################################################


def arch_from_config(arch_config):
    """Instantiate the architecture named by ``arch_config.arch_type``.

    ``arch_config.hyperparams`` is unpacked as keyword arguments into the selected class.

    Raises:
        NotImplementedError: if ``arch_type`` is not 'ActNet', 'MLP', 'Siren' or 'KAN'.
    """
    candidates = (
        ('ActNet', ActNet),
        ('MLP', MLP),
        ('Siren', Siren),
        ('KAN', KAN),
    )
    for name, arch_cls in candidates:
        if arch_config.arch_type == name:
            return arch_cls(**arch_config.hyperparams)
    raise NotImplementedError(f"Cannot recognize arch_type {arch_config.arch_type}. So far, only 'ActNet', 'MLP', 'Siren' and 'KAN' are implemented")