from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np


class DonutMLP(nn.Module):
    group_size: int
    num_neurons: int

    def _combine_embeddings(self,
                             vec_a: np.ndarray,
                             vec_b: np.ndarray) -> np.ndarray:
        return np.concatenate([vec_a, vec_b], axis=0)
    
    def bias(self, params):
        raise NotImplementedError

    def extract_embeddings_ab(self, params):
        raise NotImplementedError

    def all_p_squared_embeddings(self, params):
        emb_a, emb_b = self.extract_embeddings_ab(params)
        return np.stack([
            self._combine_embeddings(emb_a[i], emb_b[j])
            for i in range(self.group_size)
            for j in range(self.group_size)
        ], axis=0)
    
    def call_from_embedding(self, emb: jnp.ndarray, params: dict):
        hidden = emb
        preacts = []
        for layer_idx in range(1, self.num_layers + 1):
            key = f"dense_{layer_idx}"
            w = params[key]["kernel"]
            b = params[key]["bias"]
            pre_act = jnp.dot(hidden, w) + b
            hidden = nn.relu(pre_act)
            preacts.append(pre_act)
        w_out = params["output_dense"]["kernel"]
        b_out = params["output_dense"]["bias"]
        logits = jnp.dot(hidden, w_out) + b_out
        return logits, preacts


def _forward_tower(x, num_layers, num_neurons, first_layer_name_prefix="dense"):
    preacts = []
    activation = x
    for layer_idx in range(1, num_layers + 1):
        dense = nn.Dense(
            features=num_neurons,
            kernel_init=nn.initializers.he_normal(),
            name=f"{first_layer_name_prefix}_{layer_idx}")
        pre_act = dense(activation)
        activation = nn.relu(pre_act)
        preacts.append(pre_act)
        if layer_idx == 1:
            first_kernel = dense.variables["params"]["kernel"]
    return activation, preacts, first_kernel


class MLPOneHot(DonutMLP):
    num_layers: int = 1
    features: int = 128
    @nn.compact
    def __call__(self, x, training: bool = False):
        a, b = x[:, 0], x[:, 1]
        a_onehot = jax.nn.one_hot(a, self.group_size)
        b_onehot = jax.nn.one_hot(b, self.group_size)
        concat = jnp.concatenate([a_onehot, b_onehot], axis=-1)
        hidden, preacts, kernel1 = _forward_tower(concat, self.num_layers, self.num_neurons,
                                                  first_layer_name_prefix="dense")
        contribution_a = jnp.dot(a_onehot, kernel1[: self.group_size, :])
        contribution_b = jnp.dot(b_onehot, kernel1[self.group_size : 2 * self.group_size, :])
        logits = nn.Dense(features=self.group_size,
                          kernel_init=nn.initializers.he_normal(),
                          name="output_dense")(hidden)
        return logits, preacts, contribution_a, contribution_b

    def bias(self, params):
        return params["dense_1"]["bias"]
    
    def all_p_squared_embeddings(self, params=None):
        eye = np.eye(self.group_size, dtype=np.float32)
        a_idx = np.repeat(np.arange(self.group_size), self.group_size)
        b_idx = np.tile  (np.arange(self.group_size), self.group_size)
        return np.concatenate([eye[a_idx], eye[b_idx]], axis=1)

    def extract_embeddings_ab(self, params):
        W = params["dense_1"]["kernel"]
        return W[: self.group_size, :], W[self.group_size : 2 * self.group_size, :]


class MLPOneEmbed(DonutMLP):
    features: int
    num_layers: int = 1

    @nn.compact
    def __call__(self, x, training: bool = False):
        a, b = x[:, 0], x[:, 1]
        shared = nn.Embed(self.group_size, self.features, name="shared_embed",
                          embedding_init=nn.initializers.he_normal())
        a_emb = shared(a)
        b_emb = shared(b)
        concat = jnp.concatenate([a_emb, b_emb], axis=-1)
        hidden, preacts, kernel1 = _forward_tower(concat, self.num_layers, self.num_neurons,
                                                  first_layer_name_prefix="dense")
        contribution_a = jnp.dot(a_emb, kernel1[: self.features, :])
        contribution_b = jnp.dot(b_emb, kernel1[self.features :, :])
        logits = nn.Dense(self.group_size, kernel_init=nn.initializers.he_normal(),
                          name="output_dense")(hidden)
        return logits, preacts, contribution_a, contribution_b

    def bias(self, params):
        return params["dense_1"]["bias"]

    def extract_embeddings_ab(self, params):
        emb = params["shared_embed"]["embedding"]
        return emb, emb


class MLPTwoEmbed(DonutMLP):
    features: int
    num_layers: int = 1

    @nn.compact
    def __call__(self, x, training: bool = False):
        a, b = x[:, 0], x[:, 1]
        emb_a = nn.Embed(self.group_size, self.features, name="embedding_a",
                         embedding_init=nn.initializers.he_normal())
        emb_b = nn.Embed(self.group_size, self.features, name="embedding_b",
                         embedding_init=nn.initializers.he_normal())
        a_emb = emb_a(a)
        b_emb = emb_b(b)
        concat = jnp.concatenate([a_emb, b_emb], axis=-1)
        hidden, preacts, kernel1 = _forward_tower(concat, self.num_layers, self.num_neurons,
                                                  first_layer_name_prefix="dense")
        contribution_a = jnp.dot(a_emb, kernel1[: self.features, :])
        contribution_b = jnp.dot(b_emb, kernel1[self.features :, :])
        logits = nn.Dense(self.group_size, kernel_init=nn.initializers.he_normal(),
                          name="output_dense")(hidden)
        return logits, preacts, contribution_a, contribution_b

    def bias(self, params):
        return params["dense_1"]["bias"]

    def extract_embeddings_ab(self, params):
        return (params["embedding_a"]["embedding"],
                params["embedding_b"]["embedding"])


class MLPOneHot_cheating(DonutMLP):
    num_layers: int = 1
    features: int = 128
    @nn.compact
    def __call__(self, x, training: bool = False):
        a, b = x[:, 0], x[:, 1]
        a_oh, b_oh = jax.nn.one_hot(a, self.group_size), jax.nn.one_hot(b, self.group_size)
        added = a_oh + b_oh
        hidden, preacts, kernel1 = _forward_tower(added, self.num_layers, self.num_neurons,
                                                  first_layer_name_prefix="dense")
        contribution_a = jnp.dot(a_oh, kernel1)
        contribution_b = jnp.dot(b_oh, kernel1)
        logits = nn.Dense(self.group_size, kernel_init=nn.initializers.he_normal(),
                          name="output_dense")(hidden)
        return logits, preacts, contribution_a, contribution_b

    def bias(self, params):
        return params["dense_1"]["bias"]

    def extract_embeddings_ab(self, params):
        eye = np.eye(self.group_size)
        return eye, eye
    
    def _combine_embeddings(self, vec_a, vec_b):
        return vec_a + vec_b


class MLPOneEmbed_cheating(DonutMLP):
    features: int
    num_layers: int = 1

    @nn.compact
    def __call__(self, x, training: bool = False):
        a, b = x[:, 0], x[:, 1]
        shared = nn.Embed(self.group_size, self.features, name="shared_embed",
                          embedding_init=nn.initializers.he_normal())
        a_emb, b_emb = shared(a), shared(b)
        added = a_emb + b_emb
        hidden, preacts, kernel1 = _forward_tower(added, self.num_layers, self.num_neurons,
                                                  first_layer_name_prefix="dense")
        contribution_a = jnp.dot(a_emb, kernel1)
        contribution_b = jnp.dot(b_emb, kernel1)
        logits = nn.Dense(self.group_size, kernel_init=nn.initializers.he_normal(),
                          name="output_dense")(hidden)
        return logits, preacts, contribution_a, contribution_b

    def bias(self, params):
        return params["dense_1"]["bias"]

    def extract_embeddings_ab(self, params):
        emb = params["shared_embed"]["embedding"]
        return emb, emb
    
    def _combine_embeddings(self, vec_a, vec_b):
        return vec_a + vec_b


class MLPTwoEmbed_cheating(DonutMLP):
    features: int
    num_layers: int = 1

    @nn.compact
    def __call__(self, x, training: bool = False):
        a, b = x[:, 0], x[:, 1]
        emb_a = nn.Embed(self.group_size, self.features, name="embedding_a",
                         embedding_init=nn.initializers.he_normal())
        emb_b = nn.Embed(self.group_size, self.features, name="embedding_b",
                         embedding_init=nn.initializers.he_normal())
        a_emb, b_emb = emb_a(a), emb_b(b)
        added = a_emb + b_emb
        hidden, preacts, kernel1 = _forward_tower(added, self.num_layers, self.num_neurons,
                                                  first_layer_name_prefix="dense")
        contribution_a = jnp.dot(a_emb, kernel1)
        contribution_b = jnp.dot(b_emb, kernel1)
        logits = nn.Dense(self.group_size, kernel_init=nn.initializers.he_normal(),
                          name="output_dense")(hidden)
        return logits, preacts, contribution_a, contribution_b

    def bias(self, params):
        return params["dense_1"]["bias"]

    def extract_embeddings_ab(self, params):
        return (params["embedding_a"]["embedding"],
                params["embedding_b"]["embedding"])
    
    def _combine_embeddings(self, vec_a, vec_b):
        return vec_a + vec_b


class MLPOneEmbedResidual(DonutMLP):
    features: int
    num_layers: int = 1

    @nn.compact
    def __call__(self, x, *, training: bool = False):
        a, b = x[:, 0], x[:, 1]
        shared = nn.Embed(self.group_size, self.features,
                          embedding_init=nn.initializers.he_normal(),
                          name="shared_embed")
        pos = self.param("pos_bias", nn.initializers.zeros, (2, self.features))
        a_emb = shared(a) + pos[0]
        b_emb = shared(b) + pos[1]
        V_proj = nn.Dense(
            self.features,
            use_bias=False,
            kernel_init=nn.initializers.he_normal(),
            name="V_proj",
        )
        v_a = V_proj(a_emb)
        v_b = V_proj(b_emb)
        v_sum = v_a + v_b
        O_proj = nn.Dense(
            self.features,
            use_bias=False,
            kernel_init=nn.initializers.he_normal(),
            name="O_proj",
        )
        mix = O_proj(v_sum)
        resid1 = b_emb + mix
        hidden, preacts, kernel1 = _forward_tower(
            resid1,
            num_layers=self.num_layers,
            num_neurons=self.num_neurons,
            first_layer_name_prefix="dense"
        )
        logits = nn.Dense(self.group_size,
                          kernel_init=nn.initializers.he_normal(),
                          name="output_dense")(hidden)
        contribution_a = jnp.dot(a_emb, kernel1)
        contribution_b = jnp.dot(b_emb, kernel1)
        return logits, preacts, contribution_a, contribution_b

    def bias(self, params):
        return params["dense_1"]["bias"]

    def extract_embeddings_ab(self, params):
        emb = params["shared_embed"]["embedding"]
        pos = params["pos_bias"]
        return emb + pos[0], emb + pos[1]

    def call_from_embedding(self, emb: jnp.ndarray, params: dict):
        D = params["dense_1"]["kernel"].shape[0]
        if emb.shape[-1] != 2 * D:
            raise ValueError(
                f"call_from_embedding expects {2*D} features, got {emb.shape[-1]}"
            )
        a_emb, b_emb = jnp.split(emb, 2, axis=-1)
        W1, b1 = params["dense_1"]["kernel"], params["dense_1"]["bias"]
        a_pre = a_emb @ W1 + b1
        b_pre = b_emb @ W1 + b1
        a_act = jnp.maximum(a_pre, 0.)
        b_act = jnp.maximum(b_pre, 0.)
        h_sum  = a_act + b_act
        if "residual_proj" in params:
            W_res, b_res = params["residual_proj"]["kernel"], params["residual_proj"]["bias"]
            h_final = b_act + (b_emb @ W_res + b_res) + h_sum
        else:
            h_final = b_act + h_sum
        W_out, b_out = params["output_dense"]["kernel"], params["output_dense"]["bias"]
        logits = h_final @ W_out + b_out
        return logits, [b_pre]
