from typing import Union

from functools import partial

import jax
import jax.numpy as jnp
from jax.nn import gelu, softmax
from flax import linen as nn


# For typing
from enf.bi_invariant._base_bi_invariant import BaseBiInvariant


class RFFNet(nn.Module):
    """MLP operating on a random-Fourier-feature encoding of its input.

    The input is first passed through an ``RFFEmbedding``, then through
    ``num_layers - 1`` hidden ``Layer`` blocks (dense + ReLU), and finally
    through a dense output projection of width ``output_dim``.

    Attributes:
        in_dim: Declared input dimensionality. Not read by the code in this
            class; the dense layers infer their input size at init time.
        output_dim: Number of features produced by the final dense layer.
        hidden_dim: Width of the hidden layers; also passed to the encoding
            as its ``embedding_dim``.
        num_layers: Total layer count (hidden layers plus the output layer);
            must be at least 2.
        learnable_coefficients: Forwarded to ``RFFEmbedding``.
        std: Standard deviation forwarded to ``RFFEmbedding``.
        numerator: Scale passed to the variance-scaling kernel initializers.
    """
    in_dim: int
    output_dim: int
    hidden_dim: int
    num_layers: int
    learnable_coefficients: bool
    std: float
    numerator: float = 2.0

    def setup(self):
        """Build the encoding, the hidden stack and the output projection."""
        assert (
            self.num_layers >= 2
        ), "At least two layers (the hidden plus the output one) are required."

        # Fourier-feature encoding applied to the raw input.
        self.encoding = RFFEmbedding(
            embedding_dim=self.hidden_dim,
            learnable_coefficients=self.learnable_coefficients,
            std=self.std,
        )

        # One layer of the budget is reserved for the output projection.
        num_hidden_layers = self.num_layers - 1
        self.layers = [
            Layer(hidden_dim=self.hidden_dim, numerator=self.numerator)
            for _ in range(num_hidden_layers)
        ]

        # Output projection: fan-in variance scaling drawn from a uniform
        # distribution (the hidden layers draw from a normal one).
        output_kernel_init = nn.initializers.variance_scaling(self.numerator, "fan_in", "uniform")
        output_bias_init = nn.initializers.normal(stddev=1e-6)
        self.linear_final = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=output_kernel_init,
            bias_init=output_bias_init,
        )

    def __call__(self, x):
        """Encode ``x``, run the hidden stack, and project to ``output_dim``."""
        hidden = self.encoding(x)
        for hidden_layer in self.layers:
            hidden = hidden_layer(hidden)
        return self.linear_final(hidden)


class Layer(nn.Module):
    """Single hidden block: a biased dense layer followed by ReLU.

    Attributes:
        hidden_dim: Number of output features of the dense layer.
        numerator: Scale passed to the fan-in variance-scaling kernel
            initializer (normal distribution).
    """
    hidden_dim: int
    numerator: float = 2.0

    def setup(self):
        """Create the dense layer and bind the activation function."""
        kernel_init = nn.initializers.variance_scaling(self.numerator, "fan_in", "normal")
        # Biases start as very small noise rather than exact zeros.
        bias_init = nn.initializers.normal(stddev=1e-6)
        self.linear = nn.Dense(
            features=self.hidden_dim,
            use_bias=True,
            kernel_init=kernel_init,
            bias_init=bias_init,
        )
        self.activation = nn.relu

    def __call__(self, x):
        """Return ``relu(dense(x))``."""
        pre_activation = self.linear(x)
        return self.activation(pre_activation)


class RFFEmbedding(nn.Module):
    """Fourier-feature embedding built from a bias-free dense projection.

    The coordinates are shifted and scaled by ``pi * (coords + 1)``, projected
    to ``embedding_dim // 2`` phases, and the output is the sine of the
    concatenation ``[coords, phases, phases + pi/2]`` along the last axis.
    Since ``sin(t + pi/2) == cos(t)``, the last two chunks form sine/cosine
    pairs. The output's last dimension is therefore
    ``coords.shape[-1] + 2 * (embedding_dim // 2)``.

    Attributes:
        embedding_dim: Target size of the sine/cosine part of the embedding.
        learnable_coefficients: Not read by this module; the projection kernel
            is an ordinary parameter regardless of this flag.
            NOTE(review): confirm whether ``False`` was meant to freeze it.
        std: Standard deviation of the normal initializer for the projection.
    """
    embedding_dim: int
    learnable_coefficients: bool
    std: float

    @nn.compact
    def __call__(self, coords: jnp.ndarray) -> jnp.ndarray:
        """Embed ``coords``; the raw coordinates are also passed through sine."""
        frequency_projection = nn.Dense(
            self.embedding_dim // 2,
            kernel_init=nn.initializers.normal(self.std),
            use_bias=False,
        )
        # Maps to [0, 2pi] assuming coords lie in [-1, 1] -- TODO confirm range.
        scaled_coords = jnp.pi * (coords + 1)
        phases = frequency_projection(scaled_coords)

        # The pi/2 phase shift turns the second copy into cosines.
        stacked = jnp.concatenate([coords, phases, phases + jnp.pi / 2.0], axis=-1)
        return jnp.sin(stacked)


class EquivariantCrossAttention(nn.Module):
    """Cross attention from input coordinates to their k nearest latents.

    Queries are computed from an embedding of the bi-invariants between each
    coordinate and each selected latent; keys and values are computed from the
    latent context vectors. The values are additionally modulated (scale and
    shift) by a second bi-invariant embedding, and a gaussian window term
    provided by the bi-invariant object is added to the attention logits.
    """
    # Per-head feature width.
    num_hidden: int
    # Number of attention heads.
    num_heads: int
    # Callable producing bi-invariants of (x, p); also provides
    # `num_z_pos_dims` and `calculate_gaussian_window`.
    bi_invariant: BaseBiInvariant
    # Pair of values; the first is used as `std` of the query embedding, the
    # second as `std` of the value embedding.
    embedding_freq_multiplier: tuple
    # Number of nearest latents each input coordinate attends to.
    k_nearest: int

    def setup(self):
        """Create the embedding networks and the linear projections."""
        # Bi-invariant embedding for the query and value transforms.
        emb_freq_mult_q, emb_freq_mult_v = self.embedding_freq_multiplier
        self.emb_q = nn.Sequential([
            RFFEmbedding(embedding_dim=2 * self.num_hidden, learnable_coefficients=True, std=emb_freq_mult_q),
            PointwiseFFN(self.num_hidden, self.num_hidden, norm=False),
        ])
        self.emb_v = nn.Sequential([
            RFFEmbedding(embedding_dim=2 * self.num_hidden, learnable_coefficients=True, std=emb_freq_mult_v),
            PointwiseFFN(self.num_hidden, self.num_hidden, norm=False),
        ])

        # Bi-invariant embedding -> query
        self.emb_to_q = nn.Dense(self.num_heads * self.num_hidden)

        # Context vector -> key, value (one projection, split in two in __call__).
        self.c_to_kv = nn.Dense(2 * self.num_heads * self.num_hidden)
        # Bi-invariant embedding -> (gamma, beta) modulation of the values.
        self.emb_to_v = nn.Dense(2 * self.num_heads * self.num_hidden)
        # Per-head mixing of the modulated values.
        self.v_mixer = nn.Dense(self.num_hidden)

        # Output projection
        self.out_proj = nn.Dense(self.num_heads * self.num_hidden)

        # Scale factor for the attention logits; 1.0 means they are left unscaled.
        self.scale = 1.0

    def __call__(self, x, p, c, g):
        """Apply equivariant cross attention.

        Args:
            x (jax.numpy.ndarray): The input coordinates. Shape (batch_size, num_coords, coord_dim).
            p (jax.numpy.ndarray): The latent poses. Shape (batch_size, num_latents, pose_dim); the first
                `bi_invariant.num_z_pos_dims` entries of the last axis are used as positions.
            c (jax.numpy.ndarray): The latent context vectors. Shape (batch_size, num_latents, latent_dim).
            g: Passed as `sigma` to `bi_invariant.calculate_gaussian_window`. Presumably of shape
                (batch_size, num_latents, 1) -- verify against the bi-invariant implementation.

        Returns:
            jax.numpy.ndarray: Shape (batch_size, num_coords, num_heads * num_hidden).
        """
        # Get bi-invariants of input coordinates wrt latent coordinates. Depending on the bi-invariant, the shape of the
        # bi-invariants tensor will be different.
        bi_inv = self.bi_invariant(x, p)

        # Take top-k nearest latents to every input coordinate (full argsort of the pairwise distances, keep first k).
        latent_coord_distances = jnp.linalg.norm(x[:, :, None] - p[:, None, :, :self.bi_invariant.num_z_pos_dims], axis=-1)
        nearest_latents = jnp.argsort(latent_coord_distances, axis=-1)[:, :, :self.k_nearest]
        # Expand the indices so they can gather along the latent axis (axis 2) of the bi-invariants.
        nearest_latents_exp = nearest_latents[..., :, jnp.newaxis]
        nearest_latents_exp = jnp.broadcast_to(nearest_latents_exp, (*bi_inv.shape[:2], self.k_nearest, *bi_inv.shape[3:]))

        # Keep only the selected latents, per coordinate, in both the bi-invariants and the context vectors.
        bi_inv = jnp.take_along_axis(bi_inv, nearest_latents_exp, axis=2)
        c = jnp.take_along_axis(c[:, None, :, :], nearest_latents[:, :, :, None], axis=2)

        # Disabled alternatives (presumably attending to all latents instead of the k nearest -- confirm):
        # c = c[:, None, :, :]
        # c = jnp.repeat(c[:, None, :, :], x.shape[1], axis=1)

        # Apply bi-invariant embedding for the query transform and conditioning of the value transform.
        emb_q = self.emb_q(bi_inv)
        emb_v = self.emb_v(bi_inv)

        # Calculate the query, key and value.
        q = self.emb_to_q(emb_q)
        k, v = jnp.split(self.c_to_kv(c), 2, axis=-1)

        # Get gamma, beta conditioning variables for the value transform.
        cond_v_g, cond_v_b = jnp.split(self.emb_to_v(emb_v), 2, axis=-1)

        # Apply conditioning (scale and shift) to the values.
        v = v * cond_v_g + cond_v_b

        # Reshape to separate the heads, mix the values.
        v = self.v_mixer(v.reshape(v.shape[:-1] + (self.num_heads, self.num_hidden)))

        # Reshape the query and key to separate the heads.
        q = q.reshape(q.shape[:-1] + (self.num_heads, self.num_hidden))
        k = k.reshape(k.shape[:-1] + (self.num_heads, self.num_hidden))

        # For every input coordinate, calculate the attention logits for every selected latent.
        att = (q * k).sum(axis=-1) * self.scale

        # Add the gaussian window term to the logits, restricted to the selected latents.
        gaussian_window = self.bi_invariant.calculate_gaussian_window(x, p, sigma=g)
        gaussian_window = jnp.take_along_axis(gaussian_window, nearest_latents[..., None], axis=2)
        att = att + gaussian_window
        # Normalize over the latent axis (second to last; the last axis is heads).
        att = softmax(att, axis=-2)

        # Apply attention to the values.
        y = (att[..., None] * v).sum(axis=2)  # 'bczh,bczhd->bchd'

        # Reshape y to concatenate the heads.
        y = y.reshape(*y.shape[:2], self.num_heads * self.num_hidden)

        # output projection
        y = self.out_proj(y)
        return y


class PointwiseFFN(nn.Module):
    """Feed-forward network applied along the last axis of its input.

    Runs ``num_layers`` blocks of dense -> GELU (-> LayerNorm when ``norm`` is
    set), followed by a final dense projection to ``num_out`` features.

    Attributes:
        num_hidden: Width of every hidden dense layer.
        num_out: Width of the final dense layer.
        num_layers: Number of hidden blocks before the output projection.
        norm: Whether each hidden block ends with a ``LayerNorm``.
    """
    num_hidden: int
    num_out: int
    num_layers: int = 1
    norm: bool = True

    @nn.compact
    def __call__(self, x):
        """Return the network output for ``x``."""
        hidden = x
        for _block in range(self.num_layers):
            hidden = gelu(nn.Dense(self.num_hidden)(hidden))
            if not self.norm:
                continue
            hidden = nn.LayerNorm()(hidden)
        return nn.Dense(self.num_out)(hidden)


class EquivariantNeuralField(nn.Module):
    """Neural field that decodes input coordinates by cross-attending to posed latents.

    The latent features are projected to ``num_hidden``, layer-normalized, fed
    as context to a single ``EquivariantCrossAttention`` block together with
    the coordinates and latent poses, and the attention output is projected to
    ``num_out`` features.

    Attributes:
        num_hidden (int): The number of hidden units.
        num_heads (int): The number of attention heads.
        num_out (int): The number of output features per input coordinate.
        latent_dim (int): The dimensionality of the latent code. Not read by
            the code in this class; the stem infers its input size.
        bi_invariant (BaseBiInvariant): The invariant to use for the attention operation.
        embedding_freq_multiplier (tuple): Pair of frequency multipliers; the
            attention block unpacks it into one value for the query embedding
            and one for the value embedding.
        k_nearest (int): Number of nearest latents each coordinate attends to.
    """
    num_hidden: int
    num_heads: int
    num_out: int
    latent_dim: int
    bi_invariant: BaseBiInvariant
    # Annotated as a tuple (it was `Union[float, float]`, which collapses to
    # `float`) to match how EquivariantCrossAttention declares and unpacks it.
    embedding_freq_multiplier: tuple
    # NOTE(review): unannotated, so this is a plain class attribute shared by
    # all instances rather than a module field. Nothing in this class reads
    # it; kept only for backward compatibility -- do not mutate it.
    cross_attention_blocks = []
    k_nearest: int

    def setup(self):
        """Create the latent stem, the pre-norm, the attention block and the output head."""
        # Maps latent to hidden space
        self.latent_stem = nn.Dense(self.num_hidden)

        # Cross attn block
        self.layer_norm_attn = nn.LayerNorm()
        self.attn = EquivariantCrossAttention(
            num_hidden=self.num_hidden,
            num_heads=self.num_heads,
            bi_invariant=self.bi_invariant,
            embedding_freq_multiplier=self.embedding_freq_multiplier,
            k_nearest=self.k_nearest,
        )
        self.ffn_out = nn.Dense(self.num_out)

    def __call__(self, x, p, c, g):
        """Evaluate the field at the given coordinates.

        Args:
            x (jnp.ndarray): The input coordinates. Shape (batch_size, num_coords, coord_dim).
            p (jnp.ndarray): The poses of the latent points. The attention block indexes this as a
                rank-3 array, i.e. (batch_size, num_latents, pose_dim).
            c (jnp.ndarray): The latent features. Shape (batch_size, num_latents, latent_dim).
            g: The window size for the gaussian window; forwarded unchanged to the attention block.
                NOTE(review): accepted types/shapes depend on the bi-invariant implementation -- confirm.

        Returns:
            jnp.ndarray: Field values of shape (batch_size, num_coords, num_out).
        """
        # Project the latent code to the hidden width, then pre-norm it
        # before it is used as the attention context.
        context = self.layer_norm_attn(self.latent_stem(c))

        # Cross attention from coordinates to latents.
        f_hat = self.attn(x=x, p=p, c=context, g=g)

        # Output layer
        return self.ffn_out(f_hat)
