from __future__ import annotations

from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union

import jax
import jax.numpy as jnp
from flax import linen as nn
import numpy as np


# Shorthand for the JAX array type used in annotations throughout this module.
Array = jnp.ndarray
# An activation function maps an array to an array.
ActivationFn = Callable[[Array], Array]
# An activation may be given by name (resolved in `_get_activation`) or as a callable.
Activation = Union[str, ActivationFn]

def triu_size(n: int) -> int:
    """Return the number of entries in the upper triangle (incl. diagonal) of an n x n matrix."""
    doubled = n * (n + 1)
    return doubled // 2


def triu_pack(x: Array, idx: Tuple[np.ndarray, np.ndarray]) -> Array:
    """Gather the entries of (..., n, n) selected by `idx` into a flat (..., d) vector.

    `idx` is a (row_indices, col_indices) pair, e.g. from `np.triu_indices(n)`;
    leading batch axes of `x` are preserved.
    """
    rows = idx[0]
    cols = idx[1]
    return x[..., rows, cols]


def triu_unpack_symmetric(v: Array, n: int, idx: Tuple[np.ndarray, np.ndarray]) -> Array:
    """Scatter (..., d) into the `idx` positions of (..., n, n) and symmetrize.

    The scattered matrix is added to its transpose; the diagonal, which is then
    counted twice, has one copy subtracted.
    """
    rows = idx[0]
    cols = idx[1]
    batch_shape = v.shape[:-1]

    # Place the packed values at the (row, col) positions; everything else stays zero.
    tri = jnp.zeros(batch_shape + (n, n), dtype=v.dtype).at[..., rows, cols].set(v)

    # Diagonal entries expanded back to a diagonal matrix of shape (..., n, n).
    diag_entries = jnp.diagonal(tri, axis1=-2, axis2=-1)
    diag_only = diag_entries[..., :, None] * jnp.eye(n, dtype=v.dtype)

    return tri + jnp.swapaxes(tri, -1, -2) - diag_only

def triu_unpack_cholesky(v: Array, n: int, idx: Tuple[np.ndarray, np.ndarray]) -> Array:
    """Unpack (..., d) into an upper-triangular factor U and return U^T U.

    The packed values are scattered into the `idx` positions of a zero
    (..., n, n) matrix U (with `idx = np.triu_indices(n)` this is the upper
    triangle including the diagonal). The result U^T U has shape (..., n, n)
    and is symmetric positive semi-definite by construction. The diagonal of U
    is not constrained to be positive, so the result is not guaranteed to be
    strictly positive definite.
    """
    upper = jnp.zeros(v.shape[:-1] + (n, n), dtype=v.dtype)
    upper = upper.at[..., idx[0], idx[1]].set(v)

    # Gram matrix of the factor: transpose the last two axes and multiply.
    return jnp.swapaxes(upper, -1, -2) @ upper


def _get_activation(act: Activation) -> ActivationFn:
    """Resolve an activation spec to a callable.

    Callables are returned unchanged. Strings are matched case-insensitively
    against 'tanh', 'gelu', 'swish' and 'silu'; any other string raises
    ValueError.
    """
    if callable(act):
        return act

    # 'swish' and 'silu' both map to jax.nn.silu.
    by_name = {
        "tanh": jnp.tanh,
        "gelu": jax.nn.gelu,
        "swish": jax.nn.silu,
        "silu": jax.nn.silu,
    }
    resolved = by_name.get(act.lower())
    if resolved is None:
        raise ValueError(
            f"Unknown activation '{act}'. Expected one of: 'tanh', 'gelu', 'swish' (or a callable)."
        )
    return resolved


def _default_hidden_kernel_init(act: Activation) -> Callable:
    """
    Pick a default hidden-layer kernel initializer from the activation spec.

    - 'gelu' / 'swish' / 'silu' (case-insensitive): variance scaling with
      scale=2, fan_in, truncated normal (He-style).
    - everything else ('tanh', unknown strings, callables): Xavier/Glorot uniform.
    """
    he_style_names = ("gelu", "swish", "silu")
    if isinstance(act, str) and act.lower() in he_style_names:
        return nn.initializers.variance_scaling(
            scale=2.0, mode="fan_in", distribution="truncated_normal"
        )
    # 'tanh', callables and unrecognized names all share the Xavier default.
    return nn.initializers.xavier_uniform()


class MLP(nn.Module):
    """
    Plain fully-connected network built from Flax Linen Dense layers.

    Args:
      num_layers: total Dense layer count (>= 1). With 1, the module is a single
        linear projection to out_dim.
      hidden_dim: width of each of the first num_layers-1 layers.
      out_dim: width of the last layer.
      activation: 'tanh', 'gelu', or 'swish' (or a callable), applied after every
        hidden layer but not after the output layer.
    """
    num_layers: int
    hidden_dim: int
    out_dim: int
    activation: Activation

    use_bias: bool = True

    # Compute dtype and parameter dtype are separate so parameters can stay in
    # fp32 while computation runs in a lower precision.
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32

    # Initializer overrides; None selects the defaults chosen in __call__.
    kernel_init: Optional[Callable] = None         # hidden layers
    out_kernel_init: Optional[Callable] = None     # output layer
    bias_init: Callable = nn.initializers.zeros

    @nn.compact
    def __call__(self, x: Array, *, train: bool = False) -> Array:
        """Apply the network to `x`; `train` is accepted but not used here."""
        if self.num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {self.num_layers}")

        nonlinearity = _get_activation(self.activation)

        # Hidden layers default to an activation-dependent initializer.
        if self.kernel_init:
            hidden_init = self.kernel_init
        else:
            hidden_init = _default_hidden_kernel_init(self.activation)

        # The output layer defaults to Xavier uniform; pass out_kernel_init
        # (e.g. nn.initializers.normal(1e-2)) for smaller initial outputs.
        if self.out_kernel_init:
            final_init = self.out_kernel_init
        else:
            final_init = nn.initializers.xavier_uniform()

        # Settings shared by every Dense layer in the stack.
        shared = dict(
            use_bias=self.use_bias,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            bias_init=self.bias_init,
        )

        h = x
        n_hidden = self.num_layers - 1
        for layer_idx in range(n_hidden):
            h = nn.Dense(
                self.hidden_dim,
                kernel_init=hidden_init,
                name=f"dense_{layer_idx}",
                **shared,
            )(h)
            h = nonlinearity(h)

        return nn.Dense(
            self.out_dim,
            kernel_init=final_init,
            name="out",
            **shared,
        )(h)


class FourierTimeEmbedding(nn.Module):
    """
    Fourier embedding with fixed (non-trainable) frequencies:
      emb(t) = concat(sin(w t), cos(w t))

    Frequencies w (n_fourier of them):
      - fourier_log_scale=False: linspace(0, fourier_scale, n_fourier+1)[1:]
        (the zero frequency is dropped).
      - fourier_log_scale=True: log-spaced from 1 (= 10**0) up to fourier_scale;
        requires fourier_scale > 0.

    Notes:
      - t must have shape (..., 1); output is (..., 2*n_fourier). For a t of
        shape (...,), add a trailing axis first (t[..., None]).
      - With log_t=True, log(t + 1e-10) is embedded instead of t.
      - If you want 2π scaling, multiply wt by (2*jnp.pi) in __call__.

    Raises:
      ValueError: if t does not have a trailing axis of size 1, or if
        fourier_log_scale=True with a non-positive fourier_scale.
    """
    n_fourier: int
    fourier_scale: float
    dtype: jnp.dtype = jnp.float32
    fourier_log_scale: bool = False  # if True, use log-spaced frequencies
    log_t: bool = False               # if True, embed log(t) instead of t

    def setup(self):
        scale = float(self.fourier_scale)
        count = int(self.n_fourier)
        if self.fourier_log_scale:
            # log10 of a non-positive (or NaN) scale would silently yield NaN
            # frequencies, so reject it up front.
            if not scale > 0.0:
                raise ValueError(
                    "fourier_scale must be > 0 when fourier_log_scale=True, "
                    f"got {self.fourier_scale}"
                )
            w_np = np.logspace(0.0, np.log10(scale), count, dtype=np.float32)
        else:
            w_np = np.linspace(0.0, scale, count + 1, dtype=np.float32)[1:]  # drop 0 frequency
        self.w = jnp.asarray(w_np, dtype=self.dtype)  # non-trainable constant

    def __call__(self, t: Array) -> Array:
        t = jnp.asarray(t, dtype=self.dtype)

        # Enforce (..., 1) before doing any work on t.
        if t.ndim < 1 or t.shape[-1] != 1:
            raise ValueError(f"FourierTimeEmbedding expects t shape (..., 1), got {t.shape}")

        if self.log_t:
            t = jnp.log(t + 1e-10)

        # (..., 1) * (n_fourier,) broadcasts to (..., n_fourier).
        wt = t * self.w
        return jnp.concatenate([jnp.sin(wt), jnp.cos(wt)], axis=-1)


# -----------------------------------------------------------------------------
# SPD -> symmetric PSD model: pack triu + concat time embedding + MLP + unpack as U^T U
# -----------------------------------------------------------------------------

class SPDToSymmetricMLP(nn.Module):
    """
    MLP acting on the packed upper triangle of a matrix input.

    Pipeline: pack the upper triangle (incl. diagonal) of x_spd into a vector,
    optionally append a Fourier embedding of t and a conditioning vector, run
    an MLP that emits d = n(n+1)/2 values, and unpack those values into an
    upper-triangular factor U, returning U^T U.

    Attributes:
      mat_dim: matrix size n.
      num_layers, hidden_dim, activation: forwarded to the MLP.
      n_fourier: number of Fourier frequencies; 0 (or None) disables the time
        embedding, in which case t is ignored.
      fourier_scale, fourier_log_scale, log_t: forwarded to FourierTimeEmbedding.
      dtype, param_dtype: forwarded to the submodules.
      out_kernel_init: optional override for the MLP's output-layer initializer.
    """
    mat_dim: int
    num_layers: int
    hidden_dim: int
    activation: Activation

    # time embedding knobs
    n_fourier: int = 64
    fourier_scale: float = 16.0
    fourier_log_scale: bool = False
    log_t: bool = False

    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32

    # optional init overrides
    out_kernel_init: Optional[Callable] = None

    def setup(self):
        self._triu_idx = np.triu_indices(self.mat_dim)
        self._d = triu_size(self.mat_dim)

        self.time_emb = None
        if self.n_fourier and self.n_fourier > 0:
            self.time_emb = FourierTimeEmbedding(
                n_fourier=self.n_fourier,
                fourier_scale=self.fourier_scale,
                fourier_log_scale=self.fourier_log_scale,
                log_t=self.log_t,
                dtype=self.dtype,
            )

        self.mlp = MLP(
            num_layers=self.num_layers,
            hidden_dim=self.hidden_dim,
            out_dim=self._d,  # one output per upper-triangular entry
            activation=self.activation,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            out_kernel_init=self.out_kernel_init,  # None -> MLP's own default
        )

    def __call__(self, x_spd: Array, t: Array, *, train: bool = False, condition = None) -> Array:
        """
        x_spd: (..., n, n) SPD input
        t:     (...) or (..., 1) diffusion time (ignored when the time
               embedding is disabled)
        condition: optional array concatenated to the features along the last
               axis; its leading axes must match the packed input's
        returns: (..., n, n) symmetric positive semi-definite output (U^T U)
        """
        v = triu_pack(x_spd, self._triu_idx)  # (..., d)

        if self.time_emb is not None:
            t = jnp.asarray(t, dtype=self.dtype)
            # The embedding requires a trailing axis of size 1. A t with one
            # fewer axis than the packed features is the documented (...) form,
            # so give it that axis; any other shape is passed through and
            # validated by the embedding itself.
            if t.ndim == v.ndim - 1:
                t = t[..., None]
            temb = self.time_emb(t)          # (..., 2*n_fourier)
            h = jnp.concatenate([v, temb], axis=-1)
        else:
            h = v
        if condition is not None:
            h = jnp.concatenate([h, condition], axis=-1)

        v_out = self.mlp(h, train=train)     # (..., d)
        # Interpret the d outputs as an upper-triangular factor U and return U^T U.
        return triu_unpack_cholesky(v_out, self.mat_dim, self._triu_idx)