import jax
import jax.numpy as jnp
from flax import nnx
from layers import *


class Reasoner(nnx.Module):
    def __init__(self, feature: int, attn_feature: int, ffn_feature: int, num_head: int, decoder_count: int,
                 is_causal: bool, init_scalar: float, vocab_size: int, key, ema_interval: int = None,
                 dtype=jnp.float32):
        """ The JEPA-Reasoner Model

        :param feature: model width shared by the featurizer, the output RMSNorm and every decoder
        :param attn_feature: attention width forwarded to every MultiHeadDecoder
        :param ffn_feature: feed-forward width forwarded to every MultiHeadDecoder
        :param num_head: number of attention heads forwarded to every MultiHeadDecoder
        :param decoder_count: number of stacked MultiHeadDecoder layers
        :param is_causal: forwarded to every MultiHeadDecoder
        :param init_scalar: initialization scale forwarded to every MultiHeadDecoder
        :param vocab_size: vocabulary size of the Featurize embedding / assembly
        :param key: JAX PRNG key; split into decoder_count + 2 sub-keys
        :param ema_interval: forwarded to Featurize (None presumably disables the EMA embedding — confirm in Featurize)
        :param dtype: computation dtype; the output RMSNorm keeps its parameters in float32
        """
        super().__init__()
        # keys[0 .. decoder_count-1] -> decoders, keys[-2] -> output norm, keys[-1] -> featurizer
        keys = jax.random.split(key, num=decoder_count + 2)

        self.featurizer = Featurize(vocab_size=vocab_size, feature=feature, ema_interval=ema_interval, key=keys[-1],
                                    dtype=dtype)
        self.output_norm = nnx.RMSNorm(
            num_features=feature,
            epsilon=1e-6,
            dtype=dtype,
            param_dtype=jnp.float32,
            rngs=nnx.rnglib.Rngs(keys[-2])
        )
        self.decoders = [
            MultiHeadDecoder(
                feature=feature,
                attn_feature=attn_feature,
                ffn_feature=ffn_feature,
                num_head=num_head,
                is_causal=is_causal,
                init_scalar=init_scalar,
                key=keys[i],
                dtype=dtype
            ) for i in range(decoder_count)
        ]

    def embed(self, x):
        """ Create JEPA embedding.
        :param x: token sequence, shaped (batch_size, context_len)
        :return: embedded token sequence, shaped (batch_size, context_len, feature)
        :note: intended for inference
        """
        return self.featurizer.embed(x)

    def embed_ema(self, x):
        """ Create JEPA exponential averaged embedding.
        :param x: token sequence, shaped (batch_size, context_len)
        :return: embedded token sequence, shaped (batch_size, context_len, feature)
        :note: intended for target embedding
        """
        return self.featurizer.embed_ema(x)

    def embed_rec(self, x):
        """ Create JEPA embedding through the featurizer's forward (``__call__``) path.
        :param x: token sequence, shaped (batch_size, context_len)
        :return: embedded token sequence, shaped (batch_size, context_len, feature)
        :note: intended for data embedding in EMA enabled training; how this path differs from
            ``embed`` is defined by Featurize.__call__ (not visible here)
        """
        return self.featurizer(x)

    def reset_embed_ema(self, ema_interval: int):
        """ Reset EMA cache of embedding to current embedding values
        :param ema_interval: new EMA interval forwarded to Featurize.ema_reset
        """
        self.featurizer.ema_reset(ema_interval)

    def assemble(self, x):
        """ Assemble token from JEPA embedding.
        :param x: embedding matrix, shaped (batch_size, context_len, feature)
        :return: assembled sequence of token probabilities, shaped (batch_size, context_len, vocab_size)
        """
        return self.featurizer.assemble(x)

    def reason_free(self, x, mask=None, q_len=None, kv_len=None):
        """ Reason without L2 norm: run the decoder stack, then the output RMSNorm.
        :param x: embedding matrix, shaped (batch_size, context_len, feature)
        :param mask: mask matrix, shaped (batch_size, context_len{query}, context_len{key, value})
        :param q_len: valid sequence length of Q
        :param kv_len: valid sequence length of KV
        :return: reasoning result in embedding matrix, shaped (batch_size, context_len, feature)
        """
        for decoder_instance in self.decoders:
            x = decoder_instance(x, mask=mask, q_len=q_len, kv_len=kv_len)
        return self.output_norm(x)

    def reason(self, x, mask=None, q_len=None, kv_len=None, eps: float = 1e-12):
        """ Reason with L2 norm, forcing to unit hypersphere
        :param x: embedding matrix, shaped (batch_size, context_len, feature)
        :param mask: mask matrix, shaped (batch_size, context_len{query}, context_len{key, value})
        :param q_len: valid sequence length of Q
        :param kv_len: valid sequence length of KV
        :param eps: lower bound applied to the squared L2 norm before the square root, so an
            all-zero row yields 0 instead of NaN and gradients stay finite. As a Python scalar it
            is promoted to x's dtype; for float16 choose a value representable in that dtype.
        :return: reasoning result in embedding matrix forced to unit hypersphere, shaped (batch_size, context_len, feature)
        """
        x = self.reason_free(x, mask, q_len, kv_len)
        # Clamp the squared norm (not the norm itself) so sqrt never sees 0: this avoids both a
        # 0/0 forward result and the infinite derivative of sqrt at 0 in the backward pass.
        squared_norm = jnp.sum(x ** 2, axis=-1, keepdims=True)
        return x / jnp.sqrt(jnp.maximum(squared_norm, eps))


class DualTalker(nnx.Module):
    def __init__(self, feature: int, latent_feature: int, attn_feature: int, ffn_feature: int, num_head: int,
                 encoder_count: int, decoder_count: int, init_scalar: float, vocab_size: int, key, dtype=jnp.float32):
        """ The Talker Model: translate embedding space into tokens with encoder-decoder architecture
        Handles different input latent length and output token length

        :param feature: model width used by the featurizer, norm, encoders and decoders
        :param latent_feature: width of the incoming latent; projected to ``feature`` by a Linear layer
        :param attn_feature: attention width forwarded to every encoder / decoder
        :param ffn_feature: feed-forward width forwarded to every encoder / decoder
        :param num_head: number of attention heads forwarded to every encoder / decoder
        :param encoder_count: number of stacked MultiHeadEncoder layers
        :param decoder_count: number of stacked CrossDecoder layers
        :param init_scalar: initialization scale (also the variance-scaling scale of decoder_query_embed)
        :param vocab_size: vocabulary size of the Featurize embedding / assembly
        :param key: JAX PRNG key; split into encoder_count + decoder_count + 4 sub-keys
        :param dtype: computation dtype (the Linear projection also stores its parameters in it;
            the RMSNorm keeps float32 parameters)
        """
        super().__init__()
        # keys[0 .. encoder_count-1] -> encoders,
        # keys[encoder_count .. encoder_count+decoder_count-1] -> decoders,
        # keys[-4] -> decoder_query_embed, keys[-3] -> latent_projection, keys[-2] -> norm, keys[-1] -> featurizer
        keys = jax.random.split(key, num=encoder_count + decoder_count + 4)

        self.featurizer = Featurize(vocab_size=vocab_size, feature=feature, ema_interval=None, key=keys[-1], dtype=dtype)
        self.norm = nnx.RMSNorm(
            num_features=feature,
            epsilon=1e-6,
            dtype=dtype,
            param_dtype=jnp.float32,
            rngs=nnx.rnglib.Rngs(keys[-2])
        )
        self.latent_projection = nnx.Linear(
            in_features=latent_feature,
            out_features=feature,
            use_bias=True,
            rngs=nnx.rnglib.Rngs(keys[-3]),
            param_dtype=dtype,
            dtype=dtype,
        )

        # Encoder layers to process input latent
        self.encoders = [
            MultiHeadEncoder(
                feature=feature,
                attn_feature=attn_feature,
                ffn_feature=ffn_feature,
                num_head=num_head,
                init_scalar=init_scalar,
                key=keys[i],
                dtype=dtype
            ) for i in range(encoder_count)
        ]

        # Decoder layers for autoregressive token generation
        self.decoders = [
            CrossDecoder(
                feature=feature,
                attn_feature=attn_feature,
                ffn_feature=ffn_feature,
                num_head=num_head,
                init_scalar=init_scalar,
                key=keys[encoder_count + i],
                dtype=dtype
            ) for i in range(decoder_count)
        ]

        # Learnable query embeddings for decoder initialization
        # NOTE(review): this parameter is not referenced by __call__ in this file — confirm it is used elsewhere.
        initializer = jax.nn.initializers.variance_scaling(init_scalar, 'fan_in', 'normal')
        self.decoder_query_embed = nnx.Param(
            initializer(keys[-4], shape=(1, feature), dtype=dtype),
            name='DECODER_QUERY_EMBED'
        )

    def __call__(self, latent_input, latent_mask=None, decoder_mask=None,
                 latent_len=None, token_len=None, input_tokens=None):
        """ Translate latent embedding matrix into tokens using encoder-decoder architecture
        :param latent_input: latent embedding matrix, shaped (batch_size, latent_len, latent_feature)
        :param token_len: target token sequence length, shape (batch,)
        :param latent_mask: mask for encoder input, shaped (batch_size, latent_len)
        :param decoder_mask: causal mask for decoder, shaped (batch_size, token_len).
            NOTE(review): currently accepted but NOT forwarded to the decoders — confirm whether
            CrossDecoder applies its own causal masking.
        :param latent_len: valid sequence length of latent input
        :param input_tokens: inputted tokens for decoder, shaped (batch_size, token_len); required
        :return: sequence of token probabilities, shaped (batch_size, token_len, vocab_size)
        :raises ValueError: if ``input_tokens`` is None (the decoder has no other input source here)
        """
        # Fail fast with a clear message instead of an opaque error from the embedding lookup.
        if input_tokens is None:
            raise ValueError("DualTalker requires `input_tokens` to seed the decoder")

        # Project latent to feature space
        encoder_input = self.latent_projection(latent_input)

        # Encode the latent input
        for encoder_instance in self.encoders:
            encoder_input = encoder_instance(encoder_input, mask=latent_mask, q_len=latent_len, kv_len=latent_len)

        decoder_input = self.featurizer.embed(input_tokens)
        # Decode with cross-attention to encoder output
        # NOTE(review): kv_len is given latent_len (same as encoder_len); if CrossDecoder uses kv_len
        # for its self-attention over decoder tokens it should probably be token_len — confirm.
        for decoder_instance in self.decoders:
            decoder_input = decoder_instance(
                decoder_input,
                encoder_input,
                q_len=token_len,
                kv_len=latent_len,
                encoder_len=latent_len
            )

        # Final normalization and projection to vocabulary
        decoder_output = self.norm(decoder_input)
        return self.featurizer.assemble(decoder_output)


class MonoTalker(nnx.Module):
    def __init__(self, feature: int, latent_feature: int, attn_feature: int, ffn_feature: int, num_head: int,
                 decoder_count: int, is_causal: bool, init_scalar: float, vocab_size: int, key, dtype=jnp.float32):
        """ The Talker Model: translate embedding space into tokens with a decoder-only stack.
        The latent is linearly projected to ``feature`` and run through the MultiHeadDecoder stack,
        so the output token sequence has the same length as the input latent sequence.

        :param feature: model width used by the featurizer, output norm and decoders
        :param latent_feature: width of the incoming latent; projected to ``feature`` by a Linear layer
        :param attn_feature: attention width forwarded to every MultiHeadDecoder
        :param ffn_feature: feed-forward width forwarded to every MultiHeadDecoder
        :param num_head: number of attention heads forwarded to every MultiHeadDecoder
        :param decoder_count: number of stacked MultiHeadDecoder layers
        :param is_causal: forwarded to every MultiHeadDecoder
        :param init_scalar: initialization scale forwarded to every MultiHeadDecoder
        :param vocab_size: vocabulary size of the Featurize assembly
        :param key: JAX PRNG key; split into decoder_count + 3 sub-keys
        :param dtype: computation dtype (the Linear projection also stores its parameters in it;
            the RMSNorm keeps float32 parameters)
        """
        super().__init__()
        # keys[0 .. decoder_count-1] -> decoders, keys[-3] -> latent_projection,
        # keys[-2] -> output norm, keys[-1] -> featurizer
        keys = jax.random.split(key, num=decoder_count + 3)

        self.featurizer = Featurize(vocab_size=vocab_size, feature=feature, ema_interval=None, key=keys[-1],
                                    dtype=dtype)
        self.output_norm = nnx.RMSNorm(
            num_features=feature,
            epsilon=1e-6,
            dtype=dtype,
            param_dtype=jnp.float32,
            rngs=nnx.rnglib.Rngs(keys[-2])
        )
        self.latent_projection = nnx.Linear(
            in_features=latent_feature,
            out_features=feature,
            use_bias=True,
            rngs=nnx.rnglib.Rngs(keys[-3]),
            param_dtype=dtype,
            dtype=dtype,
        )

        # Decoder layers applied directly to the projected latent
        self.decoders = [
            MultiHeadDecoder(
                feature=feature,
                attn_feature=attn_feature,
                ffn_feature=ffn_feature,
                num_head=num_head,
                is_causal=is_causal,
                init_scalar=init_scalar,
                key=keys[i],
                dtype=dtype
            ) for i in range(decoder_count)
        ]

    def __call__(self, x, mask=None, q_len=None, kv_len=None):
        """ Translate latent embedding matrix into tokens: projection -> decoder stack -> RMSNorm -> assemble
        :param x: latent embedding matrix, shaped (batch_size, latent_len, latent_feature)
        :param mask: mask matrix, shaped (batch_size, context_len{query}, context_len{key, value})
        :param q_len: valid sequence length of Q
        :param kv_len: valid sequence length of KV
        :return: logits from Featurize.assemble, one row per input latent position
        """
        x = self.latent_projection(x)

        for decoder_instance in self.decoders:
            x = decoder_instance(x, mask=mask, q_len=q_len, kv_len=kv_len)

        x = self.output_norm(x)
        return self.featurizer.assemble(x)


class Transformer(nnx.Module):
    def __init__(self, feature: int, attn_feature: int, ffn_feature: int, num_head: int, decoder_count: int,
                 is_causal: bool, init_scalar: float, vocab_size: int, key, dtype=jnp.float32):
        """ The Reed Model
        input shape: (batch, context_len)
        output shape: (batch, context_len, vocab_size)

        :param feature: model width used by the featurizer, norm and decoders
        :param attn_feature: attention width forwarded to every MultiHeadDecoder
        :param ffn_feature: feed-forward width forwarded to every MultiHeadDecoder
        :param num_head: number of attention heads forwarded to every MultiHeadDecoder
        :param decoder_count: number of stacked MultiHeadDecoder layers
        :param is_causal: forwarded to every MultiHeadDecoder
        :param init_scalar: initialization scale forwarded to every MultiHeadDecoder
        :param vocab_size: vocabulary size of the Featurize embedding / assembly
        :param key: JAX PRNG key; split into decoder_count + 3 sub-keys
        :param dtype: computation dtype; the RMSNorm keeps float32 parameters
        """
        super().__init__()
        # keys[0 .. decoder_count-1] -> decoders, keys[-3] -> norm, keys[-1] -> featurizer.
        # keys[-2] is unused; the allocation is kept as-is so parameter initialization stays reproducible.
        keys = jax.random.split(key, num=decoder_count + 3)

        self.featurizer = Featurize(vocab_size=vocab_size, feature=feature, key=keys[-1], dtype=dtype)
        self.norm = nnx.RMSNorm(
            num_features=feature,
            epsilon=1e-6,
            dtype=dtype,
            param_dtype=jnp.float32,
            rngs=nnx.rnglib.Rngs(keys[-3])
        )

        self.decoders = [
            MultiHeadDecoder(
                feature=feature,
                attn_feature=attn_feature,
                ffn_feature=ffn_feature,
                num_head=num_head,
                is_causal=is_causal,
                init_scalar=init_scalar,
                key=keys[i],
                dtype=dtype
            ) for i in range(decoder_count)
        ]

    def latent_reasoning(self, latent, mask=None, q_len=None, kv_len=None):
        """ Run the decoder stack and final RMSNorm on an already-embedded sequence.
        :param latent: embedding matrix, presumably (batch, context_len, feature) — confirm against callers
        :param mask: attention mask forwarded to every decoder
        :param q_len: valid sequence length of Q
        :param kv_len: valid sequence length of KV
        :return: normalized hidden states, same shape as ``latent``
        """
        for decoder_instance in self.decoders:
            # Keyword arguments match how MultiHeadDecoder is called elsewhere in this file.
            latent = decoder_instance(latent, mask=mask, q_len=q_len, kv_len=kv_len)

        return self.norm(latent)

    def get_last_hidden(self, x, mask=None, q_len=None, kv_len=None):
        """ Embed tokens and return the normalized hidden states of the last layer.
        :param x: token sequence, shaped (batch, context_len)
        :param mask: attention mask forwarded to every decoder
        :param q_len: valid sequence length of Q
        :param kv_len: valid sequence length of KV
        :return: normalized hidden states from ``latent_reasoning``
        """
        return self.latent_reasoning(self.featurizer.embed(x), mask, q_len, kv_len)

    def __call__(self, x, mask=None, q_len=None, kv_len=None):
        """ Full forward pass: embed -> decoder stack -> RMSNorm -> assemble.
        :param x: token sequence, shaped (batch, context_len)
        :param mask: attention mask forwarded to every decoder
        :param q_len: valid sequence length of Q
        :param kv_len: valid sequence length of KV
        :return: output of Featurize.assemble, shaped (batch, context_len, vocab_size)
        """
        return self.featurizer.assemble(self.get_last_hidden(x, mask, q_len, kv_len))

    def embed(self, x):
        """ Embed a token sequence with the featurizer. """
        return self.featurizer.embed(x)

    def assemble(self, x):
        """ Map hidden states back to vocabulary space with the featurizer. """
        return self.featurizer.assemble(x)