import math
from typing import Union, Tuple, List, Callable, Any

import jax
import jax.numpy as jnp
from jax.nn import gelu, softmax
from flax import linen as nn
import numpy as np

# Import modules
from equiv_eikonal.steerable_attention.threeway_invariants import BaseThreewayInvariants
from equiv_eikonal.steerable_attention.equivariant_cross_attention import (
    EquivariantCrossAttention,
    PointwiseFFN,
)
from equiv_eikonal.models.dense import DenseBody

from equiv_eikonal.utils import torch_compatible_dense, ones_dense


class EquivariantNeuralField(nn.Module):
    """Equivariant neural field built around one equivariant cross-attention layer.

    Forward pipeline (see ``__call__``):

    1. ``latent_stem`` maps the latent codes ``a`` from ``latent_dim`` to
       ``num_hidden`` features.
    2. ``layer_norm_attn`` normalises those projected latents.
    3. ``EquivariantCrossAttention`` consumes ``inputs``, poses ``p`` and the
       normalised latents.
    4. ``PointwiseFFN`` (width ``num_heads * num_hidden``) is applied, followed
       by a GELU.
    5. ``DenseBody`` maps the result to a single output channel (``out_dim=1``).

    No residual connections are applied anywhere in the forward pass. The only
    normalisation is the one on the latent codes before attention.

    Attributes:
        num_hidden: Width of the latent stem. ``num_heads * num_hidden`` is the
            feature width used by the feed-forward block and output head.
        num_heads: Number of attention heads.
        latent_dim: Feature size of the incoming latent codes ``a``.
        num_out: Stored but not read by this module; the output head is
            hard-wired to ``out_dim=1``. NOTE(review): confirm whether this
            field was meant to drive ``out_proj``.
        invariant: Invariant module handed to the attention layer (presumably a
            ``BaseThreewayInvariants`` instance — typed as ``Any`` here).
        embedding_type: Embedding type forwarded unchanged to the attention layer.
        embedding_freq_multiplier: Frequency multipliers forwarded unchanged to
            the attention layer for its invariant embeddings.
    """

    num_hidden: int
    num_heads: int
    latent_dim: int
    num_out: int
    invariant: Any  # expected: BaseThreewayInvariants (kept as Any)
    embedding_type: str
    embedding_freq_multiplier: Tuple[float, float]

    def setup(self):
        # Feature width produced after attention and used by every later block.
        width = self.num_heads * self.num_hidden

        self.activation = gelu

        # Latent codes: latent_dim -> num_hidden.
        self.latent_stem = torch_compatible_dense(
            in_features=self.latent_dim,
            out_features=self.num_hidden,
        )

        # Normalises the projected latents before attention. The epsilon and
        # exact (non-fast) variance presumably mirror torch.nn.LayerNorm
        # numerics, in line with torch_compatible_dense — TODO confirm.
        self.layer_norm_attn = nn.LayerNorm(
            epsilon=1e-5,
            use_fast_variance=False,
            force_float32_reductions=False,
        )

        self.attn = EquivariantCrossAttention(
            num_hidden=self.num_hidden,
            num_heads=self.num_heads,
            invariant=self.invariant,
            embedding_type=self.embedding_type,
            embedding_freq_multiplier=self.embedding_freq_multiplier,
        )

        # Width-preserving feed-forward block on the attention output.
        self.pointwise_ffn = PointwiseFFN(
            num_in=width,
            num_hidden=width,
            num_out=width,
        )

        # Output head: 3 hidden layers of num_hidden units with the
        # "ad-gauss-1" activation, then a linear map to one channel.
        self.out_proj = DenseBody(
            input_dim=width,
            nu=self.num_hidden,
            nl=3,
            out_dim=1,
            act="ad-gauss-1",
            out_act="linear",
        )

    def __call__(self, inputs, p, a):
        """Evaluate the field.

        Args:
            inputs: Forwarded as the first argument of the attention layer.
            p: Poses, forwarded as the second argument of the attention layer.
            a: Latent codes whose trailing dimension is assumed to be
                ``latent_dim`` — TODO confirm against callers.

        Returns:
            The output of ``out_proj`` (a single output channel).
        """
        latents = self.latent_stem(a)
        normed_latents = self.layer_norm_attn(latents)

        hidden = self.attn(inputs, p, normed_latents)
        hidden = self.activation(self.pointwise_ffn(hidden))

        return self.out_proj(hidden)
