from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np


class CircularMLP(nn.Module):
    """Base class shared by all MLP variants used in the project.

    Subclasses implement :meth:`bias` and :meth:`extract_embeddings_ab`; the
    remaining helpers are derived from ``extract_embeddings_ab`` and always
    return NumPy arrays.
    """

    p: int               # size of the modular alphabet (labels 0 ... p-1)
    num_neurons: int      # width of every hidden layer

    # ---------------------------------------------------------------------
    # Interfaces that the training / analysis code relies on
    # ---------------------------------------------------------------------
    def bias(self, params):
        """Return the bias vector of the *first* hidden layer.

        Must be overridden; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def extract_embeddings_ab(self, params):
        """Return (embedding_a, embedding_b) - each of shape (p, D).

        Must be overridden; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    # Convenience helpers used elsewhere in the code-base ------------------
    def extract_effective_embeddings_horizontal(self, params):
        """Return concatenated embeddings [A | B] of shape (p, 2*D)."""
        emb_a, emb_b = self.extract_embeddings_ab(params)
        return np.concatenate([np.asarray(emb_a), np.asarray(emb_b)], axis=1)

    def compute_effective_embeddings_vertical(self, params):
        """Return stacked embeddings of shape (2p, D)."""
        emb_a, emb_b = self.extract_embeddings_ab(params)
        return np.concatenate([np.asarray(emb_a), np.asarray(emb_b)], axis=0)

    def all_p_squared_embeddings(self, params):
        """Return the (p^2, 2*D) matrix for all (a, b) pairs.

        Row ``i * p + j`` is ``[A[i] | B[j]]``, i.e. ``a`` varies slowest and
        ``b`` fastest. Only the first ``p`` rows of each embedding are used;
        an ``IndexError`` is raised if an embedding has fewer than ``p`` rows.
        For ``p == 0`` an array with zero rows is returned.
        """
        emb_a, emb_b = self.extract_embeddings_ab(params)
        # Convert once up front (as the sibling helpers do) so the gather
        # below is a single NumPy operation rather than p^2 row look-ups.
        emb_a = np.asarray(emb_a)
        emb_b = np.asarray(emb_b)

        labels = np.arange(self.p)
        rows_a = np.repeat(labels, self.p)   # 0,0,...,0,1,1,...,1,...
        rows_b = np.tile(labels, self.p)     # 0,1,...,p-1,0,1,...,p-1,...
        return np.concatenate([emb_a[rows_a], emb_b[rows_b]], axis=1)

# =====================================================================
# Helper to build an arbitrary‑depth feed‑forward tower
# =====================================================================

def _forward_tower(x, num_layers, num_neurons, first_layer_name_prefix="dense"):
    """Utility: build *num_layers* Dense -> ReLU blocks.

    The Dense layers are created with He-normal kernel initialisation and are
    named ``f"{first_layer_name_prefix}_{k}"`` for ``k = 1 .. num_layers``.
    The callers in this file invoke the helper from inside an ``@nn.compact``
    ``__call__``.

    Parameters
    ----------
    x : jnp.ndarray
        Input fed to the first Dense layer.
    num_layers : int
        Number of hidden Dense -> ReLU blocks; must be at least 1.
    num_neurons : int
        Output width (``features``) of every Dense layer in the tower.
    first_layer_name_prefix : str
        Prefix used when naming the Dense sub-modules.

    Returns
    -------
    activation : jnp.ndarray
        Output after the final ReLU.
    preactivations : list[jnp.ndarray]
        List of pre-ReLU tensors, one per hidden layer (index 0 == layer 1).
    first_kernel : jnp.ndarray
        Kernel of the *first* Dense layer - needed for contribution splits.

    Raises
    ------
    ValueError
        If ``num_layers`` is smaller than 1. Without a first layer there is
        no kernel to return.
    """
    if num_layers < 1:
        raise ValueError(
            f"num_layers must be at least 1, got {num_layers!r}")

    preacts = []
    activation = x
    first_kernel = None

    for layer_idx in range(1, num_layers + 1):
        dense = nn.Dense(
            features=num_neurons,
            kernel_init=nn.initializers.he_normal(),
            name=f"{first_layer_name_prefix}_{layer_idx}")
        pre_act = dense(activation)
        activation = nn.relu(pre_act)
        preacts.append(pre_act)

        if layer_idx == 1:
            # Read the kernel only after the layer has been applied, so its
            # parameters exist when they are looked up.
            first_kernel = dense.variables["params"]["kernel"]

    return activation, preacts, first_kernel

# =====================================================================
# 1) One‑Hot concatenation
# =====================================================================
class MLPOneHot(CircularMLP):
    """MLP whose input is the concatenation of two one-hot vectors."""

    num_layers: int = 1
    features: int = 128   # not referenced anywhere inside this class

    @nn.compact
    def __call__(self, x, training: bool = False):
        """Run the network on a batch of (a, b) pairs.

        ``x[:, 0]`` and ``x[:, 1]`` are one-hot encoded with ``p`` classes
        (presumably integer labels in ``[0, p)`` - confirm against the data
        pipeline). ``training`` is accepted but not used here.

        Returns ``(logits, preacts, contribution_a, contribution_b)``, where
        the two contributions are the products of each one-hot input with
        its slice of the first-layer kernel.
        """
        onehot_a = jax.nn.one_hot(x[:, 0], self.p)
        onehot_b = jax.nn.one_hot(x[:, 1], self.p)
        tower_input = jnp.concatenate([onehot_a, onehot_b], axis=-1)

        # Hidden Dense -> ReLU stack; also hands back the first-layer kernel.
        hidden, preacts, w_first = _forward_tower(
            tower_input,
            self.num_layers,
            self.num_neurons,
            first_layer_name_prefix="dense",
        )

        # Rows [0, p) of the first kernel act on a, rows [p, 2p) act on b.
        w_a = w_first[: self.p, :]
        w_b = w_first[self.p : 2 * self.p, :]
        contribution_a = jnp.dot(onehot_a, w_a)
        contribution_b = jnp.dot(onehot_b, w_b)

        # Final projection onto the p output classes.
        readout = nn.Dense(
            features=self.p,
            kernel_init=nn.initializers.he_normal(),
            name="output_dense",
        )
        logits = readout(hidden)
        return logits, preacts, contribution_a, contribution_b

    # Interfaces -----------------------------------------------------------
    def bias(self, params):
        """Return the bias stored under ``params["dense_1"]``."""
        first_layer = params["dense_1"]
        return first_layer["bias"]

    def extract_embeddings_ab(self, params):
        """Return the a-rows and b-rows of the first-layer kernel."""
        kernel = params["dense_1"]["kernel"]   # (2p, num_neurons)
        rows_a = kernel[: self.p, :]
        rows_b = kernel[self.p : 2 * self.p, :]
        return rows_a, rows_b

# =====================================================================
# 2) One shared embedding (duplicated)
# =====================================================================
class MLPOneEmbed(CircularMLP):
    """MLP that looks both inputs up in one shared embedding table."""

    features: int         # width of the shared embedding
    num_layers: int = 1

    @nn.compact
    def __call__(self, x, training: bool = False):
        """Run the network on a batch of (a, b) pairs.

        ``x[:, 0]`` and ``x[:, 1]`` are both embedded with the same
        ``shared_embed`` table and concatenated along the last axis.
        ``training`` is accepted but not used here.

        Returns ``(logits, preacts, contribution_a, contribution_b)``, where
        the two contributions are the products of each embedded input with
        its slice of the first-layer kernel.
        """
        table = nn.Embed(
            num_embeddings=self.p,
            features=self.features,
            name="shared_embed",
            embedding_init=nn.initializers.he_normal(),
        )
        a_vec = table(x[:, 0])
        b_vec = table(x[:, 1])
        tower_input = jnp.concatenate([a_vec, b_vec], axis=-1)

        # Hidden Dense -> ReLU stack; also hands back the first-layer kernel.
        hidden, preacts, w_first = _forward_tower(
            tower_input,
            self.num_layers,
            self.num_neurons,
            first_layer_name_prefix="dense",
        )

        # The first `features` kernel rows act on a, the remainder on b.
        w_a = w_first[: self.features, :]
        w_b = w_first[self.features :, :]
        contribution_a = jnp.dot(a_vec, w_a)
        contribution_b = jnp.dot(b_vec, w_b)

        # Final projection onto the p output classes.
        readout = nn.Dense(
            features=self.p,
            kernel_init=nn.initializers.he_normal(),
            name="output_dense",
        )
        logits = readout(hidden)
        return logits, preacts, contribution_a, contribution_b

    def bias(self, params):
        """Return the bias stored under ``params["dense_1"]``."""
        first_layer = params["dense_1"]
        return first_layer["bias"]

    def extract_embeddings_ab(self, params):
        """Return the shared embedding table twice, as a NumPy array."""
        shared_table = np.asarray(params["shared_embed"]["embedding"])
        # Both inputs use the same table, so the same array is returned twice.
        return shared_table, shared_table

