from functools import partial
from typing import Any, Callable, Sequence, Tuple, Optional, Union, Dict

from flax import linen as nn
from flax.core.frozen_dict import freeze

from jax import random, jit, vmap
import jax.numpy as jnp
from jax.nn.initializers import glorot_normal, kaiming_normal, normal, zeros, constant

# Registry of activation callables keyed by their configuration name; looked
# up by `_get_activation`.
activation_fn = dict(
    relu=nn.relu,
    gelu=nn.gelu,
    swish=nn.swish,
    sigmoid=nn.sigmoid,
    tanh=jnp.tanh,
    sin=jnp.sin,
)


def _get_activation(str):
    """Return the activation callable registered under a name.

    Args:
        str: Key into the module-level ``activation_fn`` table (e.g. ``"tanh"``).
            The name shadows the builtin ``str``; it is kept so keyword callers
            keep working.

    Returns:
        The callable stored in ``activation_fn`` for that key.

    Raises:
        NotImplementedError: If the name is not a key of ``activation_fn``.
    """
    if str not in activation_fn:
        raise NotImplementedError(f"Activation {str} not supported yet!")
    return activation_fn[str]


# def _weight_fact(init_fn, mean, stddev):
#     def init(key, shape):
#         key1, key2 = random.split(key)
#         w = init_fn(key1, shape)
#         g = mean + normal(stddev)(key2, (shape[-1],))
#         g = jnp.exp(g)
#         v = w / g
#         return g, v

#     return init

def _weight_fact(init_fn, mean, stddev):
    """Build an initializer for a weight-factorized kernel.

    The returned ``init(key, shape)`` draws a base kernel ``w`` with
    ``init_fn`` and a log-scale vector ``g = mean + normal(stddev)`` of shape
    ``(shape[-1],)``. It returns the pair ``(g, v)`` with ``v = w / exp(g)``
    (``g`` broadcast over the last axis), so ``exp(g) * v`` recovers ``w`` up
    to rounding.

    Args:
        init_fn: Base initializer called as ``init_fn(key, shape)``.
        mean: Offset added to the sampled log-scales.
        stddev: Standard deviation passed to ``normal`` for the log-scales.
    """

    def init(key, shape):
        kernel_key, scale_key = random.split(key)
        base_kernel = init_fn(kernel_key, shape)
        log_scale = mean + normal(stddev)(scale_key, (shape[-1],))
        direction = base_kernel / jnp.exp(log_scale)
        return log_scale, direction

    return init


class PeriodEmbs(nn.Module):
    """Periodic (cos/sin) embedding of selected input features.

    Every feature ``x[..., i]`` whose index ``i`` appears in ``axis`` is
    replaced by the pair ``cos(p * x_i), sin(p * x_i)``, where ``p`` is the
    entry of ``period`` at the same position as ``i`` in ``axis``. All other
    features pass through unchanged. Features are read from the LAST axis of
    ``x``, so a single point of shape ``(d,)`` and a batch of shape
    ``(..., d)`` are both supported.
    """

    period: Tuple[float]  # Periods for different axes
    axis: Tuple[int]  # Axes where the period embeddings are to be applied
    trainable: Tuple[
        bool
    ]  # Specifies whether the period for each axis is trainable or not

    def setup(self):
        # Trainable periods become scalar parameters initialized to the
        # configured value; the others are stored as plain constants.
        period_params = {}
        for idx, is_trainable in enumerate(self.trainable):
            if is_trainable:
                period_params[f"period_{idx}"] = self.param(
                    f"period_{idx}", constant(self.period[idx]), ()
                )
            else:
                period_params[f"period_{idx}"] = self.period[idx]

        self.period_params = freeze(period_params)

    @nn.compact
    def __call__(self, x):
        """Apply the period embeddings to the selected feature indices.

        Args:
            x: Array whose last axis holds the features.

        Returns:
            Array with the same leading shape as ``x``; each selected feature
            contributes two entries (cos, sin) to the last axis, every other
            feature one.
        """
        features = []
        for i in range(x.shape[-1]):
            xi = x[..., i]
            if i in self.axis:
                idx = self.axis.index(i)
                period = self.period_params[f"period_{idx}"]
                features.extend([jnp.cos(period * xi), jnp.sin(period * xi)])
            else:
                features.append(xi)

        # Stacking on the last axis reproduces the former `hstack` result for a
        # 1-D input while keeping any leading (batch) dimensions intact.
        return jnp.stack(features, axis=-1)


# class FourierEmbs(nn.Module):
#     embed_scale: float
#     embed_dim: int

#     @nn.compact
#     def __call__(self, x):
#         kernel = self.param(
#             "kernel", normal(self.embed_scale), (x.shape[-1], self.embed_dim // 2)
#         )
#         y = jnp.concatenate(
#             [jnp.cos(jnp.dot(x, kernel)), jnp.sin(jnp.dot(x, kernel))], axis=-1
#         )
#         return y


class FourierEmbs(nn.Module):
    """Random Fourier-feature embedding ``[cos(x @ B), sin(x @ B)]``.

    Each kernel ``B`` has shape ``(x.shape[-1], embed_dim // 2)`` and is a
    parameter initialized from ``normal(scale)``.

    * ``embed_scale`` is a float or a length-1 sequence: a single kernel named
      ``"kernel"`` is created, and the output has ``2 * (embed_dim // 2)``
      features.
    * ``embed_scale`` is a longer sequence: one kernel per scale
      (``"kernel_0"``, ``"kernel_1"``, ...). The per-scale ``[cos, sin]``
      blocks are concatenated, giving ``len(embed_scale) * 2 * (embed_dim // 2)``
      features.

    An odd ``embed_dim`` therefore yields ``embed_dim - 1`` features per scale.

    Raises:
        ValueError: If ``embed_scale`` is an empty sequence.
    """

    # embed_scale can be a single float (old behavior) or a sequence of floats (multi-scale)
    embed_scale: Union[float, Sequence[float]]
    embed_dim: int

    @nn.compact
    def __call__(self, x):
        if hasattr(self.embed_scale, "__len__"):
            scales = tuple(self.embed_scale)
        else:
            scales = (self.embed_scale,)
        if not scales:
            raise ValueError("FourierEmbs.embed_scale must contain at least one scale.")

        # Parameter names are identical to the previous two-branch version:
        # "kernel" for a single scale, "kernel_{i}" for several scales.
        if len(scales) == 1:
            names = ("kernel",)
        else:
            names = tuple(f"kernel_{i}" for i in range(len(scales)))

        kernel_shape = (x.shape[-1], self.embed_dim // 2)
        features = []
        for name, scale in zip(names, scales):
            kernel = self.param(name, normal(scale), kernel_shape)
            # Project once and reuse the result for both the cosine and sine halves.
            projection = jnp.dot(x, kernel)
            features.extend([jnp.cos(projection), jnp.sin(projection)])

        return jnp.concatenate(features, axis=-1)

from jax.scipy.special import logsumexp


def _safe_k(k: jnp.ndarray, k_min: float) -> jnp.ndarray:
    """Push exponents away from zero to avoid division blow-ups / NaNs.

    Each entry keeps its sign (zero is treated as positive), and its magnitude
    is raised to at least ``k_min``.

    Args:
        k: Array of exponents.
        k_min: Smallest magnitude allowed for any entry.

    Returns:
        Array shaped like ``k`` whose non-NaN entries satisfy ``|k| >= k_min``.
    """
    is_negative = k < 0.0
    magnitude = jnp.maximum(jnp.abs(k), k_min)
    return jnp.where(is_negative, -1.0, 1.0) * magnitude


class SymmetricEmbs(nn.Module):
    """Permutation-invariant power-mean embedding of the input features.

    For every exponent ``k`` (one per output feature) it computes, over the
    last axis of ``x`` with ``d = x.shape[-1]``::

        u_k = ( (1/d) * sum_i (x_i**2 + eps)**k )**(1/k)
            = exp( (logsumexp_i(k * log(x_i**2 + eps)) - log d) / k )

    The result is clipped from above at 1000 and divided by
    ``embed_dim**rescale_factor``. The exponents are the parameter ``"k"``,
    initialized by ``get_k``.

    ``embed_scale_min`` / ``embed_scale_max`` give one exponent range per
    entry (a scalar counts as a single range). Each range contributes
    ``embed_dim // n_ranges`` exponents, so the output has
    ``n_ranges * (embed_dim // n_ranges)`` features.
    """

    embed_scale_min: Union[float, Sequence[float]]
    embed_scale_max: Union[float, Sequence[float]]
    embed_dim: int
    rescale_factor: float = 0.
    eps: float = 1e-12       # added inside log(x**2 + eps) so x == 0 stays finite
    # NOTE(review): k_min is currently NOT applied by this module (no _safe_k
    # clamp here, unlike AntiSymmetricEmbs); kept for config compatibility.
    k_min: float = 1e-2

    def get_k(self, key):
        """Return the initial exponents; ``key`` is unused (deterministic init).

        For each ``(s_min, s_max)`` pair, ``embed_dim // n_ranges`` values are
        spaced evenly in log-space between the two bounds. For a negative
        range the magnitudes run from ``|s_max|`` to ``|s_min|``.

        Raises:
            ValueError: If the bounds are empty, have different lengths, or a
                pair does not share a strict sign.
        """
        mins = self.embed_scale_min
        maxs = self.embed_scale_max
        # Scalars are accepted (as the type hints advertise) as a single range.
        if not hasattr(mins, "__len__"):
            mins = (mins,)
        if not hasattr(maxs, "__len__"):
            maxs = (maxs,)
        if len(mins) != len(maxs):
            raise ValueError(
                "embed_scale_min and embed_scale_max must have the same length."
            )
        if len(mins) == 0:
            raise ValueError("embed_scale_min and embed_scale_max must not be empty.")

        n_embeds = self.embed_dim // len(mins)
        scales = []
        for s_min, s_max in zip(mins, maxs):
            if s_min > 0 and s_max > 0:
                scales.append(jnp.exp(jnp.linspace(jnp.log(s_min), jnp.log(s_max), n_embeds)))
            elif s_min < 0 and s_max < 0:
                scales.append(-jnp.exp(jnp.linspace(jnp.log(-s_max), jnp.log(-s_min), n_embeds)))
            else:
                raise ValueError("embed_scale_min and embed_scale_max must be both positive or both negative for each scale.")
        return jnp.concatenate(scales, axis=0)

    def _embed_from_k(self, x: jnp.ndarray, k: jnp.ndarray) -> jnp.ndarray:
        """Evaluate the clipped power mean of ``x**2 + eps`` for every exponent in ``k``."""
        # log(x_i**2 + eps) over the last axis: (..., d)
        log_sq_x = jnp.log(x ** 2 + self.eps)

        # Broadcast against k: (..., d, n_k)
        scaled = log_sq_x[..., :, None] * k

        # Log of the mean over the d inputs: (..., n_k)
        lse = logsumexp(scaled, axis=-2) - jnp.log(x.shape[-1])

        # exp(lse / k) is the power mean; clip large values from above.
        return jnp.clip(jnp.exp(lse / k), max=1000)  # (..., n_k)

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        k = self.param("k", self.get_k)
        u = self._embed_from_k(x, k)
        u = u / self.embed_dim**(self.rescale_factor)
        return u
 # 1805075 1836031 1841434

class AntiSymmetricEmbs(nn.Module):
    """Power-mean embedding of the pairwise feature differences.

    Builds the squared differences ``(x_i - x_j)**2`` for every pair
    ``i < j`` on the last axis of ``x`` (``P = d * (d - 1) / 2`` pairs). For
    every exponent ``k`` it then computes::

        u_k = ( (1/P) * sum_{i<j} ((x_i - x_j)**2 + eps)**k )**(1/k)

    ``k`` is first pushed to ``|k| >= k_min`` by ``_safe_k``. The result is
    clipped from above at 1000 and divided by ``embed_dim**rescale_factor``.
    Because only squared differences enter, the result is invariant to
    permuting the input features. With fewer than two input features no pair
    exists, and the embedding is all zeros.

    ``embed_scale_min`` / ``embed_scale_max`` give one exponent range per
    entry (a scalar counts as a single range). Each range contributes
    ``embed_dim // n_ranges`` exponents.
    """

    embed_scale_min: Union[float, Sequence[float]]
    embed_scale_max: Union[float, Sequence[float]]
    embed_dim: int
    rescale_factor: float = 0.
    eps: float = 1e-12  # added inside log(diff**2 + eps) so equal features stay finite
    k_min: float = 1e-2  # minimum |k| enforced via _safe_k

    def get_k(self, key):
        """Return the initial exponents; ``key`` is unused (deterministic init).

        For each ``(s_min, s_max)`` pair, ``embed_dim // n_ranges`` values are
        spaced evenly in log-space between the two bounds. For a negative
        range the magnitudes run from ``|s_max|`` to ``|s_min|``.

        Raises:
            ValueError: If the bounds are empty, have different lengths, or a
                pair does not share a strict sign.
        """
        mins = self.embed_scale_min
        maxs = self.embed_scale_max
        # Scalars are accepted (as the type hints advertise) as a single range.
        if not hasattr(mins, "__len__"):
            mins = (mins,)
        if not hasattr(maxs, "__len__"):
            maxs = (maxs,)
        if len(mins) != len(maxs):
            raise ValueError(
                "embed_scale_min and embed_scale_max must have the same length."
            )
        if len(mins) == 0:
            raise ValueError("embed_scale_min and embed_scale_max must not be empty.")

        n_embeds = self.embed_dim // len(mins)
        scales = []
        for s_min, s_max in zip(mins, maxs):
            if s_min > 0 and s_max > 0:
                scales.append(jnp.exp(jnp.linspace(jnp.log(s_min), jnp.log(s_max), n_embeds)))
            elif s_min < 0 and s_max < 0:
                scales.append(-jnp.exp(jnp.linspace(jnp.log(-s_max), jnp.log(-s_min), n_embeds)))
            else:
                raise ValueError("embed_scale_min and embed_scale_max must be both positive or both negative for each scale.")
        return jnp.concatenate(scales, axis=0)

    def _pairwise_log_diffs(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Return log((x_i - x_j)**2 + eps) for all i < j along the last axis.

        Shape: (..., n_pairs); n_pairs is 0 when x.shape[-1] < 2.
        """
        d = x.shape[-1]
        if d < 2:
            # No pairs exist
            return jnp.empty(x.shape[:-1] + (0,), dtype=x.dtype)

        ii, jj = jnp.triu_indices(d, k=1)  # (n_pairs,), (n_pairs,)
        diffs = (x[..., ii] - x[..., jj]) ** 2  # (..., n_pairs)
        return jnp.log(diffs + self.eps)  # (..., n_pairs)

    def _embed_from_k(self, log_diffs: jnp.ndarray, k_raw: jnp.ndarray) -> jnp.ndarray:
        """Evaluate the clipped power mean over the pairs for every exponent."""
        k = _safe_k(k_raw, self.k_min)  # (n_k,)

        n_pairs = log_diffs.shape[-1]
        if n_pairs == 0:
            # The mean over an empty set of pairs is undefined (the formula
            # below would evaluate logsumexp(empty) - log(0), i.e. NaN), so
            # emit all-zero features instead.
            return jnp.zeros(
                log_diffs.shape[:-1] + k.shape, dtype=jnp.result_type(log_diffs, k)
            )

        # (..., n_pairs, n_k)
        scaled = log_diffs[..., :, None] * k
        # Log of the mean over the pairs: (..., n_k)
        lse = logsumexp(scaled, axis=-2) - jnp.log(n_pairs)
        return jnp.clip(jnp.exp(lse / k), max=1000)  # (..., n_k)

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        log_diffs = self._pairwise_log_diffs(x)  # (..., n_pairs)

        k = self.param("k", self.get_k)
        u = self._embed_from_k(log_diffs, k)
        u = u / self.embed_dim**(self.rescale_factor)
        return u


class Dense(nn.Module):
    """Affine layer ``y = x @ kernel + bias`` with optional weight factorization.

    ``reparam`` selects how the kernel is parameterized:

    * ``None``: a plain kernel of shape ``(x.shape[-1], features)`` from
      ``kernel_init``.
    * ``{"type": "weight_fact", "mean": ..., "stddev": ...}``: the ``"kernel"``
      parameter stores a pair ``(g, v)`` built by ``_weight_fact``, and the
      effective kernel is ``exp(g) * v``.

    Raises:
        ValueError: If ``reparam["type"]`` is not a supported value.
    """

    features: int
    kernel_init: Callable = glorot_normal()
    # kernel_init: Callable = kaiming_normal()
    bias_init: Callable = zeros
    reparam: Union[None, Dict] = None

    @nn.compact
    def __call__(self, x):
        kernel_shape = (x.shape[-1], self.features)

        if self.reparam is None:
            kernel = self.param("kernel", self.kernel_init, kernel_shape)

        elif self.reparam["type"] == "weight_fact":
            g, v = self.param(
                "kernel",
                _weight_fact(
                    self.kernel_init,
                    mean=self.reparam["mean"],
                    stddev=self.reparam["stddev"],
                ),
                kernel_shape,
            )
            # g holds log-scales, so the effective kernel is exp(g) * v.
            kernel = jnp.exp(g) * v

        else:
            # Previously fell through with `kernel` unbound (UnboundLocalError).
            raise ValueError(
                f"Unsupported reparam type: {self.reparam['type']!r}"
            )

        bias = self.param("bias", self.bias_init, (self.features,))

        return jnp.dot(x, kernel) + bias


# TODO: Make it more general, e.g. imposing periodicity for the given axis


class Mlp(nn.Module):
    """Plain fully connected network with optional input embeddings.

    Pipeline: optional ``PeriodEmbs`` -> optional ``FourierEmbs`` ->
    ``num_layers`` x (``Dense(hidden_dim)`` followed by the activation) ->
    ``Dense(out_dim)`` with no output activation.
    """

    arch_name: Optional[str] = "Mlp"
    num_layers: int = 4
    hidden_dim: int = 256
    out_dim: int = 1
    activation: str = "tanh"
    periodicity: Union[None, Dict] = None
    fourier_emb: Union[None, Dict] = None
    reparam: Union[None, Dict] = None

    def setup(self):
        # Resolve the configured activation name to a callable.
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        h = x
        if self.periodicity:
            h = PeriodEmbs(**self.periodicity)(h)
        if self.fourier_emb:
            h = FourierEmbs(**self.fourier_emb)(h)

        for _layer in range(self.num_layers):
            hidden = Dense(features=self.hidden_dim, reparam=self.reparam)
            h = self.activation_fn(hidden(h))

        head = Dense(features=self.out_dim, reparam=self.reparam)
        return head(h)


class ModifiedMlp(nn.Module):
    """MLP whose hidden layers are mixed with two gating branches ``u`` and ``v``.

    Steps on the input ``x`` (when ``no_embedding_for_first_dim`` is set, the
    first feature is kept raw and excluded from the periodic/symmetric
    embeddings):

    1. optional ``PeriodEmbs`` on ``x[..., no_embed_dims:]``;
    2. optional ``SymmetricEmbs`` and ``AntiSymmetricEmbs`` (both built from
       the same ``symmetric_emb`` dict) on ``x[..., no_embed_dims:]``; their
       outputs replace those features;
    3. optional ``FourierEmbs`` on the whole current ``x``;
    4. ``u = act(Dense(x))``, ``v = act(Dense(x))``, then ``num_layers`` times
       ``x = act(Dense(x)); x = x * u + (1 - x) * v``;
    5. final ``Dense(out_dim)``, multiplied by ``global_multiplier`` and, if
       ``global_add_func`` is given, shifted by ``global_add_func`` evaluated
       on the raw input.
    """
    arch_name: Optional[str] = "ModifiedMlp"
    num_layers: int = 4
    hidden_dim: int = 256
    out_dim: int = 1
    activation: str = "tanh"
    periodicity: Union[None, Dict] = None
    fourier_emb: Union[None, Dict] = None
    symmetric_emb: Union[None, Dict] = None
    reparam: Union[None, Dict] = None
    # Scales the network output.
    global_multiplier: float = 1.0
    # Optional callable on the raw input whose result is added to the output.
    global_add_func: Optional[Callable] = None
    # If True, the first input feature bypasses the periodic/symmetric embeddings.
    no_embedding_for_first_dim: bool = False

    def setup(self):
        # Resolve the configured activation name to a callable.
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        # Raw input: supplies the un-embedded prefix and the global_add_func argument.
        x_orig = x.copy()
        # Number of leading features excluded from the periodic/symmetric embeddings.
        no_embed_dims = 1 if self.no_embedding_for_first_dim else 0
        if self.periodicity:
            x = PeriodEmbs(**self.periodicity)(x[..., no_embed_dims:])
            x = jnp.concatenate([x_orig[..., :no_embed_dims], x], axis=-1)

        
        if self.symmetric_emb:
            # Both embeddings read the same features; their outputs replace them.
            x_sym = SymmetricEmbs(**self.symmetric_emb)(x[...,no_embed_dims:])
            x_antisym = AntiSymmetricEmbs(**self.symmetric_emb)(x[...,no_embed_dims:])
            x = jnp.concatenate([x_orig[...,:no_embed_dims], x_sym, x_antisym], axis=-1)
        
        if self.fourier_emb:
            x = FourierEmbs(**self.fourier_emb)(x)

        # Gating branches computed once from the embedded input.
        u = Dense(features=self.hidden_dim, reparam=self.reparam)(x)
        v = Dense(features=self.hidden_dim, reparam=self.reparam)(x)

        u = self.activation_fn(u)
        v = self.activation_fn(v)

        for _ in range(self.num_layers):
            x = Dense(features=self.hidden_dim, reparam=self.reparam)(x)
            x = self.activation_fn(x)
            # Convex-style mix of the two gates weighted by the hidden activation.
            x = x * u + (1 - x) * v


        x = Dense(features=self.out_dim, reparam=self.reparam)(x)
        x = x * self.global_multiplier
        if self.global_add_func is not None:
            x = x + self.global_add_func(x_orig)
        return x


class MlpBlock(nn.Module):
    """Plain MLP block with optional input embeddings.

    Embeddings are applied in this order: ``PeriodEmbs``, then
    ``SymmetricEmbs`` + ``AntiSymmetricEmbs`` (both built from
    ``symmetric_emb``), then ``FourierEmbs``. When
    ``no_embedding_for_first_dim`` is set, the first feature is kept raw and
    excluded from the periodic/symmetric embeddings. After the embeddings come
    ``num_layers`` x (``Dense(hidden_dim)`` + activation) and a final
    ``Dense(out_dim)``. The activation is applied to the output only when
    ``final_activation`` is true.

    Note: ``global_multiplier`` and ``global_add_func`` are accepted but not
    used by this block.
    """

    num_layers: int = 4
    hidden_dim: int = 256
    out_dim: int = 1
    activation: str = "tanh"
    reparam: Union[None, Dict] = None
    periodicity: Union[None, Dict] = None
    fourier_emb: Union[None, Dict] = None
    symmetric_emb: Union[None, Dict] = None
    global_multiplier: float = 1.0
    global_add_func: Optional[Callable] = None
    no_embedding_for_first_dim: bool = False
    # Apply the activation after the output layer (default keeps it linear).
    final_activation: bool = False

    def setup(self):
        # Resolve the configured activation name to a callable.
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        x_orig = x.copy()
        # Number of leading features excluded from the periodic/symmetric embeddings.
        no_embed_dims = 1 if self.no_embedding_for_first_dim else 0
        if self.periodicity:
            x = PeriodEmbs(**self.periodicity)(x[..., no_embed_dims:])
            x = jnp.concatenate([x_orig[..., :no_embed_dims], x], axis=-1)

        if self.symmetric_emb:
            x_sym = SymmetricEmbs(**self.symmetric_emb)(x[..., no_embed_dims:])
            x_antisym = AntiSymmetricEmbs(**self.symmetric_emb)(x[..., no_embed_dims:])
            x = jnp.concatenate([x_orig[..., :no_embed_dims], x_sym, x_antisym], axis=-1)

        if self.fourier_emb:
            x = FourierEmbs(**self.fourier_emb)(x)

        for _ in range(self.num_layers):
            x = Dense(features=self.hidden_dim, reparam=self.reparam)(x)
            x = self.activation_fn(x)

        x = Dense(features=self.out_dim, reparam=self.reparam)(x)
        if self.final_activation:
            x = self.activation_fn(x)

        return x


class DeepONet(nn.Module):
    """Deep operator network: a branch MLP on ``u`` times a trunk MLP on ``x``.

    The branch (``MlpBlock`` without input embeddings) and the trunk (``Mlp``
    with the optional ``periodicity`` / ``fourier_emb`` embeddings) both
    produce ``hidden_dim`` features. Their element-wise product goes through
    the activation and a final ``Dense(out_dim)``.
    """

    arch_name: Optional[str] = "DeepONet"
    num_branch_layers: int = 4
    num_trunk_layers: int = 4
    hidden_dim: int = 256
    out_dim: int = 1
    activation: str = "tanh"
    periodicity: Union[None, Dict] = None
    fourier_emb: Union[None, Dict] = None
    reparam: Union[None, Dict] = None

    def setup(self):
        # Resolve the configured activation name to a callable.
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, u, x):
        # Branch net on the input function samples; its output layer is linear.
        u = MlpBlock(
            num_layers=self.num_branch_layers,
            hidden_dim=self.hidden_dim,
            out_dim=self.hidden_dim,
            activation=self.activation,
            reparam=self.reparam,
        )(u)

        # Trunk net on the query coordinates, with the optional embeddings.
        x = Mlp(
            num_layers=self.num_trunk_layers,
            hidden_dim=self.hidden_dim,
            out_dim=self.hidden_dim,
            activation=self.activation,
            periodicity=self.periodicity,
            fourier_emb=self.fourier_emb,
            reparam=self.reparam,
        )(x)

        y = u * x
        y = self.activation_fn(y)
        y = Dense(features=self.out_dim, reparam=self.reparam)(y)
        return y