import jax
import jax.numpy as jnp
from flax import nnx
from . import Cache


class Featurize(nnx.Module):
    """ Input embedding layer
    input shape: (batch, context_len)
    output shape: (batch, context_len, feature)

    Optionally maintains an exponential moving average (EMA) copy of the
    embedding matrix. The EMA state only exists when `ema_interval` is given;
    with `ema_interval=None` the layer is a plain embedding lookup.
    """
    def __init__(self, vocab_size: int, feature: int, key, ema_interval: int = None, momentum: float = 0.98, dtype=jnp.float32):
        """ Build the embedding matrix and, if requested, the EMA state

        vocab_size: number of rows of the embedding matrix
        feature: number of columns of the embedding matrix
        key: PRNG key passed to the initializer
        ema_interval: initial value of the EMA countdown; None disables EMA
        momentum: weight kept from the previous EMA value on each update
        dtype: dtype of the embedding matrix

        Raises ValueError if momentum is outside [0, 1] or ema_interval is
        outside [1, int8 max].
        """
        super().__init__()

        if momentum < 0 or momentum > 1:
            raise ValueError('Momentum must be in [0, 1]')

        self.feature = feature
        self.momentum = momentum
        self.ema_interval = ema_interval
        self.dtype = dtype

        initializer = nnx.initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0)

        self.embedding = nnx.Param(initializer(key, shape=(vocab_size, feature), dtype=dtype), name='EMBD_MATRIX')

        if ema_interval is not None:
            # The countdown is stored as int8, so the interval has to fit in it.
            if ema_interval > jnp.iinfo(jnp.int8).max:
                raise ValueError(f'Too many accumulation step, support no more than {jnp.iinfo(jnp.int8).max}')
            # A non-positive interval would drive the int8 countdown negative
            # and the EMA would only update again after the counter wraps.
            if ema_interval < 1:
                raise ValueError('ema_interval must be at least 1')

            self.embedding_ema = Cache(self.embedding.value)
            self.step = Cache(jnp.array(ema_interval, dtype=jnp.int8))

    def ema_reset(self, ema_interval: int):
        """ Reset EMA cache to current embedding values

        `ema_interval` only seeds the countdown to the next EMA update; later
        cycles are reloaded from `self.ema_interval`, which is not modified.
        Requires the layer to have been built with an `ema_interval`.
        """
        self.embedding_ema.value = self.embedding.value
        # Keep the counter an int8 array, matching what __init__/__call__ store.
        self.step.value = jnp.asarray(ema_interval, dtype=jnp.int8)

    def __call__(self, x):
        """ Get embedding & build EMA embedding

        When EMA is enabled, every call advances the countdown; when the
        countdown is at zero the EMA matrix is blended with the current
        embedding and the countdown is reloaded with `ema_interval - 1`.
        """
        # EMA state is only created when an interval was configured.
        if self.ema_interval is not None:
            def count_down():
                # Counter still running: decrement it, leave the EMA untouched.
                return jnp.asarray(self.step.value - 1, dtype=jnp.int8), self.embedding_ema.value

            def update_ema():
                # Counter expired: reload it and blend the EMA towards the current weights.
                blended = self.momentum * self.embedding_ema.value + (1 - self.momentum) * self.embedding.value
                return jnp.asarray(self.ema_interval - 1, dtype=jnp.int8), blended

            self.step.value, self.embedding_ema.value = jax.lax.cond(
                self.step.value != 0, count_down, update_ema
            )

        return self.embed(x)

    def embed(self, x):
        """ Get embedding """
        return jnp.take(self.embedding.value, x, axis=0)

    def embed_ema(self, x):
        """ Get exponential averaged embedding (requires EMA to be enabled) """
        return jnp.take(self.embedding_ema.value, x, axis=0)

    def assemble(self, x):
        """ Assemble embedding to logits by multiplying with the transposed embedding matrix """
        return x @ jnp.transpose(self.embedding.value, axes=(1, 0))


class RotaryEmbed(nnx.Module):
    """ Rotary positional embedding layer
    input shape: (batch, context_len, num_head, head_feature)
    output shape: (batch, context_len, num_head, head_feature)

    Consecutive feature pairs (2i, 2i + 1) are rotated by a position-dependent
    angle taken from a table precomputed for `max_len` positions.
    """
    def __init__(self, head_feature: int, dtype=jnp.float32):
        """ Precompute theta and the cos/sin table

        head_feature: size of the last axis of the input (must be even for
            the pairwise reshape in `__call__` to succeed)
        dtype: dtype of the cached theta and cos/sin table
        """
        super().__init__()
        self.head_feature = head_feature
        self.dtype = dtype

        # Default settings. Configure by passing param through `instance.train()/.eval()`
        self.rope_base = 100000
        self.max_len = 8192

        # theta must exist before embed_init(), which reads it.
        self.theta = Cache(self.theta_init())
        self.embed = Cache(self.embed_init())

    def __call__(self, x):
        """ Rotate each feature pair of x by its position's angle """
        context_len = x.shape[-3]
        table_len = self.embed.value.shape[0]
        # Slicing past the table would silently truncate and then fail (or
        # mis-broadcast) in the multiplications below; fail clearly instead.
        if context_len > table_len:
            raise ValueError(
                f'context_len {context_len} exceeds rotary table length {table_len}'
            )
        embed = self.embed.value[:context_len]  # (context_len, 1, head_feature / 2, 2)

        x = x.reshape(*x.shape[:-1], -1, 2)     # (batch, context_len, num_head, head_feature / 2, 2)

        # embed[..., 0] is cos, embed[..., 1] is sin: standard 2D rotation per pair.
        x = jnp.stack(
            [
                x[..., 0] * embed[..., 0] - x[..., 1] * embed[..., 1],
                x[..., 0] * embed[..., 1] + x[..., 1] * embed[..., 0],
            ],
            -1,
        )

        return x.reshape(*x.shape[:-2], self.head_feature)

    def embed_init(self):
        """ Create the rotary matrix
        output shape: (max_len, 1, head_feature / 2, 2)

        The last axis holds (cos, sin); the singleton axis broadcasts over heads.
        """
        # Build positions and angles in at least float32: in half-precision
        # dtypes large integer positions are not exactly representable, which
        # would make distinct positions share the same rotation.
        compute_dtype = jnp.promote_types(self.dtype, jnp.float32)
        positions = jnp.arange(self.max_len, dtype=compute_dtype)
        freq = jnp.outer(positions, jnp.asarray(self.theta.value, dtype=compute_dtype))
        table = jnp.stack((jnp.cos(freq), jnp.sin(freq)), axis=-1)
        # Store the table in the configured dtype, as before.
        return jnp.expand_dims(table, axis=-3).astype(self.dtype)

    def theta_init(self):
        """ Create theta
        output shape: (head_feature / 2,)

        theta[i] = 1 / rope_base ** (2i / head_feature)
        """
        return 1.0 / (
                self.rope_base ** (
                jnp.arange(0, self.head_feature, 2, dtype=self.dtype)[:(self.head_feature // 2)] / self.head_feature
            )
        )
