import pdb
from typing import Sequence

import jax.numpy as jnp
from flax import linen as nn
from functools import partial
from typing import Any, Callable, Sequence, Tuple, Optional, Union, Dict

from flax import linen as nn
from flax.core.frozen_dict import freeze

from jax import random, jit, vmap
import jax.numpy as jnp
from jax.nn.initializers import glorot_normal, normal, zeros, constant, uniform
import numpy as np

# Registry mapping an activation name to its callable; looked up by _get_activation.
activation_fn = dict(
    relu=nn.relu,
    gelu=nn.gelu,
    swish=nn.swish,
    sigmoid=nn.sigmoid,
    tanh=jnp.tanh,
    sin=jnp.sin,
)


def _get_activation(str):
    """Return the activation callable registered under the given name.

    Args:
        str: Key into ``activation_fn`` (e.g. ``"tanh"``). The parameter name
            shadows the builtin ``str``; it is kept so keyword callers still work.

    Raises:
        NotImplementedError: If the name is not registered in ``activation_fn``.
    """
    name = str
    if name not in activation_fn:
        raise NotImplementedError(f"Activation {name} not supported yet!")
    return activation_fn[name]


def _weight_fact(init_fn, mean, stddev):
    """Wrap ``init_fn`` into an initializer for a weight-factorized kernel.

    The returned ``init(key, shape)`` draws a base matrix ``w = init_fn(...)``
    and a per-column log-scale ``mean + N(0, stddev^2)`` of length
    ``shape[-1]``. It returns ``(g, v)`` with ``g = exp(log_scale)`` and
    ``v = w / g``, so that ``g * v`` reproduces ``w``.
    """

    def init(key, shape):
        w_key, g_key = random.split(key)
        base = init_fn(w_key, shape)
        log_scale = mean + normal(stddev)(g_key, (shape[-1],))
        scale = jnp.exp(log_scale)
        return scale, base / scale

    return init


class PeriodEmbs(nn.Module):
    """Periodic embedding of selected input coordinates.

    For every position ``i`` of the input's leading axis that appears in
    ``axis``, the entry ``x[i]`` is replaced by the pair
    ``cos(p * x[i]), sin(p * x[i])``. All other entries pass through
    unchanged. ``p`` comes from ``period[k]`` with ``k = axis.index(i)``.
    It is a learnable scalar parameter named ``period_{k}`` when
    ``trainable[k]`` is true, and the fixed Python constant otherwise.

    NOTE(review): ``p`` multiplies the input directly, so it acts as an
    angular frequency rather than a period length. Confirm the values callers
    pass.
    """

    period: Tuple[float]  # one value per entry of ``axis``
    axis: Tuple[int]  # input positions that receive the periodic embedding
    trainable: Tuple[bool]  # whether period[k] is a learnable parameter

    def setup(self):
        params = {}
        for k, learnable in enumerate(self.trainable):
            key = f"period_{k}"
            if not learnable:
                params[key] = self.period[k]
                continue
            params[key] = self.param(key, constant(self.period[k]), ())
        # Stored as an immutable FrozenDict on the module.
        self.period_params = freeze(params)

    @nn.compact
    def __call__(self, x):
        """Expand the entries listed in ``axis`` into (cos, sin) pairs.

        Iterates over the leading axis of ``x`` and concatenates the pieces
        with ``jnp.hstack``. Presumably ``x`` is a single point's coordinate
        vector; TODO confirm with callers.
        """
        pieces = []
        for i, value in enumerate(x):
            if i not in self.axis:
                pieces.append(value)
                continue
            p = self.period_params[f"period_{self.axis.index(i)}"]
            pieces += [jnp.cos(p * value), jnp.sin(p * value)]
        return jnp.hstack(pieces)


class FourierEmbs(nn.Module):
    """Fourier-feature embedding ``[cos(x @ K), sin(x @ K)]``.

    ``K`` is a learnable parameter named ``"kernel"`` of shape
    ``(x.shape[-1], embed_dim // 2)``, initialised from
    ``normal(embed_scale)``. The output's last axis therefore has
    ``2 * (embed_dim // 2)`` entries.
    """

    embed_scale: float  # stddev of the kernel initialiser
    embed_dim: int  # requested embedding width (an odd value loses one feature)

    @nn.compact
    def __call__(self, x):
        shape = (x.shape[-1], self.embed_dim // 2)
        kernel = self.param("kernel", normal(self.embed_scale), shape)
        projected = jnp.dot(x, kernel)  # shared by the cos and sin halves
        return jnp.concatenate([jnp.cos(projected), jnp.sin(projected)], axis=-1)


class Dense(nn.Module):
    """Affine layer ``x @ kernel + bias`` with optional weight factorization.

    Attributes:
        features: Output width.
        kernel_init: Initializer for the kernel (default ``glorot_normal()``).
        bias_init: Initializer for the bias (default ``zeros``).
        reparam: ``None`` for a plain kernel, or a dict with
            ``{"type": "weight_fact", "mean": ..., "stddev": ...}``. In that
            case the ``"kernel"`` parameter holds a pair ``(g, v)`` and the
            effective kernel is ``g * v``.

    Raises:
        NotImplementedError: If ``reparam`` names an unsupported type.
    """

    features: int
    kernel_init: Callable = glorot_normal()
    bias_init: Callable = zeros
    reparam: Union[None, Dict] = None

    @nn.compact
    def __call__(self, x):
        shape = (x.shape[-1], self.features)

        if self.reparam is None:
            kernel = self.param("kernel", self.kernel_init, shape)

        elif self.reparam["type"] == "weight_fact":
            g, v = self.param(
                "kernel",
                _weight_fact(
                    self.kernel_init,
                    mean=self.reparam["mean"],
                    stddev=self.reparam["stddev"],
                ),
                shape,
            )
            # Per-output-column scale g times direction matrix v.
            kernel = g * v

        else:
            # Without this branch ``kernel`` would be unbound and the call
            # would fail with an opaque UnboundLocalError.
            raise NotImplementedError(
                f"Reparameterization {self.reparam['type']} not supported yet!"
            )

        bias = self.param("bias", self.bias_init, (self.features,))

        return jnp.dot(x, kernel) + bias


        
def _navier_stokes4d_exact_w(t, x, y, z, nu=0.05):
    """Analytic vorticity of the decaying 3-D Navier-Stokes test solution.

    The three components equal the curl of the velocity returned by
    ``_navier_stokes4d_exact_u``. They decay in time as ``exp(-9 * nu * t)``.

    Args:
        t, x, y, z: Scalars or broadcast-compatible arrays.
        nu: Viscosity. It defaults to 0.05, the same default as
            ``_navier_stokes4d_exact_u``, so the two helpers agree when called
            without it.

    Returns:
        Tuple ``(w_x, w_y, w_z)``.
    """
    decay = jnp.exp(-9*nu*t)
    w_x = -3*decay*jnp.sin(2*x)*jnp.cos(2*y)*jnp.cos(z)
    w_y = 6*decay*jnp.cos(2*x)*jnp.sin(2*y)*jnp.cos(z)
    w_z = -6*decay*jnp.cos(2*x)*jnp.cos(2*y)*jnp.sin(z)
    return w_x, w_y, w_z

def _navier_stokes4d_exact_u(t, x, y, z, nu=0.05):
    """Analytic velocity of the decaying 3-D Navier-Stokes test solution.

    The field is divergence-free and every component decays as
    ``exp(-9 * nu * t)``.

    Args:
        t, x, y, z: Scalars or broadcast-compatible arrays.
        nu: Viscosity (default 0.05).

    Returns:
        Tuple ``(u_x, u_y, u_z)``.
    """
    amp = jnp.exp(-9*nu*t)
    s2x, c2x = jnp.sin(2*x), jnp.cos(2*x)
    s2y, c2y = jnp.sin(2*y), jnp.cos(2*y)
    sz, cz = jnp.sin(z), jnp.cos(z)
    return (
        2*amp*c2x*s2y*sz,
        -1*amp*s2x*c2y*sz,
        -2*amp*s2x*s2y*cz,
    )


class NS_exact(nn.Module):
    """Parameter-free module that evaluates the analytic velocity on a grid.

    Any input with more than one dimension is squeezed along axis 1 (for
    example an ``(n, 1)`` column becomes ``(n,)``). The four 1-D coordinate
    arrays are then expanded with ``jnp.meshgrid(..., indexing='ij')``. The
    velocity is evaluated with ``_navier_stokes4d_exact_u`` at its default
    viscosity.

    Returns:
        Tuple ``(u_x, u_y, u_z)``. Each has shape
        ``(len(t), len(x), len(y), len(z))`` when the inputs are 1-D after
        squeezing.
    """

    @nn.compact
    def __call__(self, t, x, y, z):
        def _flatten(coord):
            return jnp.squeeze(coord, axis=1) if jnp.ndim(coord) > 1 else coord

        grids = jnp.meshgrid(*(_flatten(c) for c in (t, x, y, z)), indexing='ij')
        return _navier_stokes4d_exact_u(*grids)



class PINN2d(nn.Module):
    """Plain tanh MLP on the column-wise concatenation of ``x`` and ``y``.

    ``features[:-1]`` are the hidden widths (Dense + tanh each). The last
    entry of ``features`` is the width of the final linear layer. All kernels
    use Glorot-normal initialisation.
    """

    features: Sequence[int]

    @nn.compact
    def __call__(self, x, y):
        glorot = nn.initializers.glorot_normal()
        h = jnp.concatenate([x, y], axis=1)
        for width in self.features[:-1]:
            h = nn.activation.tanh(nn.Dense(width, kernel_init=glorot)(h))
        return nn.Dense(self.features[-1], kernel_init=glorot)(h)
    
import pdb
from typing import Sequence

import jax.numpy as jnp
from flax import linen as nn
from functools import partial
from jax.nn.initializers import constant, normal, uniform
from functools import partial
from typing import Any, Callable, Sequence, Tuple, Optional, Union, Dict

from flax import linen as nn
from flax.core.frozen_dict import freeze
from jax import lax

class Gaussian3dFull(nn.Module):
    """Embed 3-D points with ``mlp_dim`` learnable mixtures of anisotropic Gaussians.

    Each of the ``mlp_dim`` output channels is a weighted sum of
    ``num_gaussian`` bumps ``exp(-0.5 * d^T M d)``. Here ``d`` is the offset
    from a learnable centre and ``M = (R S)(R S)^T``, built from a normalised
    quaternion (parameter ``r``) and per-axis scales ``S`` (``sigmas``).
    ``M`` multiplies ``d`` directly (no inverse is taken), so it plays the
    role of a precision matrix: larger ``sigmas`` give narrower bumps.

    Attributes:
        num_gaussian: Gaussians per output channel.
        grid_range: Centres are initialised with ``uniform(grid_range)`` and
            inputs are rescaled onto roughly ``[0, grid_range]``.
        sigmas_range: Initial value of every per-axis scale.
        mlp_dim: Number of output channels.
    """

    num_gaussian: int
    grid_range: float
    sigmas_range : float
    mlp_dim : int

    def setup(self):
        self.mu_x = self.param("mu_x", uniform(self.grid_range), (self.mlp_dim, 1, self.num_gaussian, 1))
        self.mu_y = self.param("mu_y", uniform(self.grid_range), (self.mlp_dim, 1, self.num_gaussian, 1))
        self.mu_z = self.param("mu_z", uniform(self.grid_range), (self.mlp_dim, 1, self.num_gaussian, 1))
        self.sigmas = self.param("sigmas", constant(self.sigmas_range), (self.mlp_dim, 1, self.num_gaussian, 3))
        self.r = self.param("r", constant(1.0), (self.mlp_dim, 1, self.num_gaussian, 4))
        self.weight = self.param("weight", normal(), (self.mlp_dim, self.num_gaussian, 1))

        L = self.build_scaling_rotation(self.sigmas, self.r)
        # (mlp_dim, 1, num_gaussian, 3, 3); used as a precision matrix below.
        self.cov = L @ L.transpose(0, 1, 2, 4, 3)

    @nn.compact
    def __call__(self, x, y, z):
        """Evaluate the Gaussian mixtures at the given points.

        Args:
            x, y, z: Coordinate columns, assumed to be of shape ``(N, 1)``
                each (the ``[None, :, None, :]`` indexing needs 2-D inputs).
                TODO confirm against callers.

        Returns:
            Array of shape ``(N, mlp_dim)`` under that assumption.
        """
        # klein-gordon 3d: x is divided by 10 (presumably a time axis on
        # [0, 10]; confirm), while y and z are mapped from [-1, 1] onto
        # [0, grid_range].
        x = (x / 10.) * self.grid_range
        y = ((y + 1.) / 2.) * self.grid_range
        z = ((z + 1.) / 2.) * self.grid_range

        mu = jnp.concatenate([self.mu_x, self.mu_y, self.mu_z], -1)  # (mlp_dim, 1, G, 3)

        # With (N, 1) inputs X is (1, N, 1, 3); d broadcasts to (mlp_dim, N, G, 3).
        X = jnp.concatenate([x[None, :, None, :], y[None, :, None, :], z[None, :, None, :]], -1)
        d = X - mu

        # Quadratic form d^T M d. Index the trailing 1x1 matrix explicitly:
        # a bare .squeeze() would also drop mlp_dim, N or num_gaussian
        # whenever one of them equals 1.
        out = self.cov @ d[..., None]
        out = (d[..., None, :] @ out)[..., 0, 0]  # (mlp_dim, N, G)

        pdf = jnp.exp(-0.5 * out)

        # weight (mlp_dim, G, 1) -> (mlp_dim, 1, G) so it broadcasts over N
        # even when mlp_dim or num_gaussian is 1.
        rasterized_color_primes = pdf * self.weight[:, None, :, 0]

        output = rasterized_color_primes.sum(2)  # (mlp_dim, N)

        return output.T

    def build_rotation(self, r):
        """Turn quaternions in the last axis of a 4-D array ``r`` into 3x3 rotations.

        ``r`` is normalised first. Component 0 is treated as the scalar part
        and components 1..3 as the vector part.

        Returns:
            Array of shape ``r.shape[:3] + (3, 3)``.
        """
        norm = jnp.sqrt(r[..., 0]*r[..., 0] + r[..., 1]*r[..., 1] + r[..., 2]*r[..., 2] + r[..., 3]*r[..., 3])

        q = r / norm[..., None]

        # Match the parameter dtype instead of forcing float32.
        R = jnp.zeros((q.shape[0], q.shape[1], q.shape[2], 3, 3), dtype=q.dtype)

        r = q[..., 0]
        x = q[..., 1]
        y = q[..., 2]
        z = q[..., 3]

        R = R.at[..., 0, 0].set(1 - 2 * (y*y + z*z))
        R = R.at[..., 0, 1].set(2 * (x*y - r*z))
        R = R.at[..., 0, 2].set(2 * (x*z + r*y))
        R = R.at[..., 1, 0].set(2 * (x*y + r*z))
        R = R.at[..., 1, 1].set(1 - 2 * (x*x + z*z))
        R = R.at[..., 1, 2].set(2 * (y*z - r*x))
        R = R.at[..., 2, 0].set(2 * (x*z - r*y))
        R = R.at[..., 2, 1].set(2 * (y*z + r*x))
        R = R.at[..., 2, 2].set(1 - 2 * (x*x + y*y))
        return R

    def build_scaling_rotation(self, s, r):
        """Return ``R @ diag(s)`` for scales ``s[..., :3]`` and quaternions ``r``."""
        L = jnp.zeros((s.shape[0], s.shape[1], s.shape[2], 3, 3), dtype=s.dtype)
        R = self.build_rotation(r)

        L = L.at[..., 0, 0].set(s[..., 0])
        L = L.at[..., 1, 1].set(s[..., 1])
        L = L.at[..., 2, 2].set(s[..., 2])
        L = R @ L
        return L


class PINN3d(nn.Module):
    """Tanh MLP applied to a ``Gaussian3dFull`` embedding of ``(x, y, z)``.

    Attributes:
        features: One hidden Dense + tanh layer is built per entry of
            ``features[:-1]``, each with its own width. ``features[-1]`` is
            not used as a width, because the output layer has ``out_dim``
            units.
        out_dim: Width of the final linear layer.
        pos_enc: Currently unused by this module.
        num_gaussian, grid_range, sigmas_range, mlp_dim: Forwarded to
            ``Gaussian3dFull``.
        reparam: Forwarded unchanged to every ``Dense`` layer.
    """

    features: Sequence[int]
    out_dim: int
    pos_enc: int
    num_gaussian: int = 100
    grid_range: float = 2.
    sigmas_range : float = 15.
    mlp_dim: int= 4
    reparam: Union[None, Dict] = None

    def setup(self):
        self.activation_fn = _get_activation('tanh')

    @nn.compact
    def __call__(self, x, y, z):
        X = Gaussian3dFull(self.num_gaussian, self.grid_range, self.sigmas_range, self.mlp_dim)(x, y, z)
        # Use each configured hidden width. Indexing features[0] for every
        # layer silently ignored the remaining widths. Same layer count as
        # before, and consistent with PINN2d / SPINN3d.
        for fs in self.features[:-1]:
            X = Dense(features=fs, reparam=self.reparam)(X)
            X = self.activation_fn(X)

        X = Dense(features=self.out_dim, reparam=self.reparam)(X)
        return X

class SPINN3d(nn.Module):
    """Separable network over three factorised input axes.

    Each of ``x``, ``y`` and ``z`` is passed through its own MLP (a fresh set
    of Dense layers is created per input) that emits ``r * out_dim``
    features. For every output channel ``i`` the three rank-``r`` factor
    slices are contracted as
    ``pred_i[a, b, c] = sum_f X[f, a] * Y[f, b] * Z[f, c]``. The prediction
    therefore lives on the tensor-product grid of the three inputs.

    Attributes:
        features: Hidden widths; layers are built from ``features[:-1]``.
            ``'modified_mlp'`` also uses ``features[0]`` for its U/V/H
            encoders, so it appears to require equal hidden widths.
        r: Rank of the separable factorisation per output channel.
        out_dim: Number of output channels.
        pos_enc: If non-zero, ``y`` and ``z`` are replaced by
            ``[1, sin(k*v), cos(k*v)]`` for ``k = 1..pos_enc`` before the MLPs.
        mlp: ``'mlp'`` (plain tanh MLP) or ``'modified_mlp'`` (blends two
            encodings U and V with a learned gate Z).
    """

    features: Sequence[int]
    r: int
    out_dim: int
    pos_enc: int
    mlp: str

    @nn.compact
    def __call__(self, x, y, z):
        '''
        inputs: input factorized coordinates
        outputs: feature output of each body network
        xy: intermediate tensor for feature merge btw. x and y axis
        pred: final model prediction (e.g. for 2d output, pred=[u, v])

        x, y, z are assumed to be column arrays of shape (N_x, 1), (N_y, 1)
        and (N_z, 1). The positional-encoding matmul and the final transpose
        require 2-D inputs. TODO confirm.

        Returns an (N_x, N_y, N_z) array when out_dim == 1, otherwise a list
        of out_dim such arrays.

        Raises NotImplementedError if mlp is not 'mlp' or 'modified_mlp'.
        '''
        # Without this check an unknown value leaves ``outputs`` empty and
        # fails later with an unrelated IndexError.
        if self.mlp not in ('mlp', 'modified_mlp'):
            raise NotImplementedError(f"MLP type {self.mlp} not supported yet!")

        if self.pos_enc != 0:
            # positional encoding only to spatial coordinates
            freq = jnp.expand_dims(jnp.arange(1, self.pos_enc+1, 1), 0)
            y = jnp.concatenate((jnp.ones((y.shape[0], 1)), jnp.sin(y@freq), jnp.cos(y@freq)), 1)
            z = jnp.concatenate((jnp.ones((z.shape[0], 1)), jnp.sin(z@freq), jnp.cos(z@freq)), 1)

            # causal PINN version (also on time axis)
            #  freq_x = jnp.expand_dims(jnp.power(10.0, jnp.arange(0, 3)), 0)
            # x = x@freq_x

        inputs, outputs, xy, pred = [x, y, z], [], [], []
        init = nn.initializers.glorot_normal()

        if self.mlp == 'mlp':
            for X in inputs:
                for fs in self.features[:-1]:
                    X = nn.Dense(fs, kernel_init=init)(X)
                    X = nn.activation.tanh(X)
                X = nn.Dense(self.r*self.out_dim, kernel_init=init)(X)
                # (N_axis, r*out_dim) -> (r*out_dim, N_axis)
                outputs += [jnp.transpose(X, (1, 0))]

        elif self.mlp == 'modified_mlp':
            for X in inputs:
                U = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=init)(X))
                V = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=init)(X))
                H = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=init)(X))
                for fs in self.features[:-1]:
                    Z = nn.Dense(fs, kernel_init=init)(H)
                    Z = nn.activation.tanh(Z)
                    # Gate between the two encodings U and V.
                    H = (jnp.ones_like(Z)-Z)*U + Z*V
                H = nn.Dense(self.r*self.out_dim, kernel_init=init)(H)
                outputs += [jnp.transpose(H, (1, 0))]

        # Rows r*i .. r*(i+1)-1 of each factor belong to output channel i.
        for i in range(self.out_dim):
            xy += [jnp.einsum('fx, fy->fxy', outputs[0][self.r*i:self.r*(i+1)], outputs[1][self.r*i:self.r*(i+1)])]
            pred += [jnp.einsum('fxy, fz->xyz', xy[i], outputs[-1][self.r*i:self.r*(i+1)])]

        if len(pred) == 1:
            # 1-dimensional output
            return pred[0]
        else:
            # n-dimensional output
            return pred

