from functools import partial
from typing import List

import jax
import jax.numpy as jnp
from flax import struct


def build_rope(positions, head_dim, base):
    """Compute rotary position embedding (RoPE) sine and cosine tables.

    Args:
        positions: Array of position indices. Any rank is accepted; the
            trailing ``head_dim`` axis is appended to its shape. Negative
            entries are treated as padding.
        head_dim: Size of the rotary dimension. Expected to be even, in which
            case the last axis of the result has exactly ``head_dim`` entries.
        base: Base of the geometric progression of rotation frequencies.

    Returns:
        A ``(sin_values, cos_values)`` tuple of float32 arrays, each of shape
        ``positions.shape + (head_dim,)``. At padded (negative) positions the
        angle is forced to zero, so sine is 0 and cosine is 1 there.
    """
    # One inverse frequency per pair of channels: base ** (-2i / head_dim).
    exponents = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim
    inv_freq = 1.0 / (base ** exponents)

    positions = positions.astype(jnp.float32)
    # Ellipsis indexing keeps this rank-agnostic instead of assuming
    # a (batch, seq) layout.
    freqs = positions[..., None] * inv_freq
    # Both halves of the head dimension share the same angles.
    emb = jnp.concatenate((freqs, freqs), axis=-1)

    # Zero the angle wherever the position is padding. The mask broadcasts
    # over the last axis, so no explicit repeat is needed.
    is_real = (positions >= 0)[..., None]
    emb = jnp.where(is_real, emb, 0.0)

    return jnp.sin(emb), jnp.cos(emb)


@struct.dataclass
class KVCache:
    """Key and value buffers for a single attention layer.

    Both arrays are allocated by ``init`` with shape
    ``(batch_size, max_len, num_kv_heads, head_dim)``.
    """

    keys: jnp.ndarray
    values: jnp.ndarray

    @classmethod
    def init(cls, batch_size, max_len, model_cfg, dtype):
        """Allocate a zero-filled cache.

        Args:
            batch_size: Size of the leading (batch) axis.
            max_len: Size of the sequence axis.
            model_cfg: Mapping providing ``"num_kv_heads"`` and
                ``"hidden_size"``, and normally ``"num_heads"``.
            dtype: Element type of the allocated buffers.

        Returns:
            A ``KVCache`` whose ``keys`` and ``values`` are zeros of shape
            ``(batch_size, max_len, num_kv_heads, head_dim)``.
        """
        num_kv_heads = model_cfg["num_kv_heads"]
        # The per-head width is set by the number of attention (query) heads,
        # matching how TransformerCache.create sizes the RoPE tables. Dividing
        # by the KV head count gives the wrong width whenever the two counts
        # differ. Fall back to the KV head count only when the config has no
        # "num_heads" entry.
        try:
            num_attn_heads = model_cfg["num_heads"]
        except KeyError:
            num_attn_heads = num_kv_heads
        head_dim = model_cfg["hidden_size"] // num_attn_heads
        shape = (batch_size, max_len, num_kv_heads, head_dim)

        # Allocate directly in the target dtype rather than casting afterwards.
        return cls(
            keys=jnp.zeros(shape, dtype=dtype),
            values=jnp.zeros(shape, dtype=dtype),
        )


@struct.dataclass
class TransformerCache:
    """Decoding state: positions, RoPE tables and one ``KVCache`` per layer.

    Instances are immutable; ``roll`` returns an updated copy via ``replace``.
    Indexing the cache (``cache[i]``) returns the ``KVCache`` of layer ``i``.
    """

    # RoPE sine / cosine values exposed for the current step.
    sin: jnp.ndarray
    cos: jnp.ndarray
    # Positions exposed for the current step.
    cur_pos: jnp.ndarray
    # Per-layer key/value buffers.
    layers: List[KVCache]
    # Full (batch, max_len) position table; ``roll`` writes one new entry
    # per row into it.
    _max_pos: jnp.ndarray
    # RoPE tables covering every slot; ``roll`` gathers from them.
    _full_sin: jnp.ndarray
    _full_cos: jnp.ndarray
    # Carried as a pytree leaf; not read or written by the methods below.
    dirty: bool = struct.field(default=False, pytree_node=True)
    # Static (non-pytree) switch selecting the branch taken in ``roll``.
    dynamic: bool = struct.field(default=True, pytree_node=False)

    def __getitem__(self, idx):
        """Return the ``KVCache`` stored for layer ``idx``."""
        return self.layers[idx]

    @classmethod
    @partial(jax.jit, static_argnames=["cls", "model_config", "dtype", "dynamic"])
    def create(cls, positions, model_config, dtype=jnp.bfloat16, dynamic=True):
        """Build a fresh cache sized after ``positions``.

        Args:
            positions: 2-D array ``(batch_size, max_len)`` of position ids;
                cast to int32 and stored as both ``cur_pos`` and ``_max_pos``.
            model_config: Mapping providing ``"hidden_size"``,
                ``"num_heads"``, ``"rope_base"`` and ``"num_layers"``, plus
                whatever ``KVCache.init`` reads. It is a static jit argument,
                so it must be hashable.
            dtype: Element type of the per-layer key/value buffers.
            dynamic: Stored on the instance; selects the branch in ``roll``.

        Returns:
            A new ``TransformerCache`` with zero-filled layer caches.
        """
        pos = positions.astype(jnp.int32)
        batch_size, max_len = pos.shape
        head_dim = model_config["hidden_size"] // model_config["num_heads"]

        # The RoPE tables are built for slot indices 0..max_len-1 rather than
        # for the values in ``positions``.
        slot_ids = jnp.broadcast_to(
            jnp.arange(max_len, dtype=jnp.int32), (batch_size, max_len)
        )
        sin, cos = build_rope(slot_ids, head_dim, model_config["rope_base"])

        layers = [
            KVCache.init(batch_size, max_len, model_config, dtype)
            for _ in range(model_config["num_layers"])
        ]

        return cls(
            sin=sin,
            cos=cos,
            cur_pos=pos,
            layers=layers,
            dynamic=dynamic,
            _max_pos=pos,
            _full_sin=sin,
            _full_cos=cos,
        )

    @jax.jit
    def roll(self):
        """Advance every batch row by one position.

        For each row the next position is ``max(_max_pos[row]) + 1``; that
        value is written into ``_max_pos`` at the column of the same index.

        Returns:
            A copy of the cache with ``_max_pos`` updated. When ``dynamic`` is
            true, ``cur_pos`` has shape ``(batch, 1)`` and ``sin`` / ``cos``
            hold only the row for the new position, with a singleton sequence
            axis. Otherwise ``cur_pos`` is the full updated table and ``sin``
            / ``cos`` are left unchanged.
        """
        batch_size = self._max_pos.shape[0]
        rows = jnp.arange(batch_size, dtype=jnp.int32)
        next_pos = jnp.max(self._max_pos, axis=-1).astype(jnp.int32) + 1
        # NOTE(review): when a row is already full, ``next_pos`` equals the
        # sequence length and is out of bounds. JAX does not raise on
        # out-of-bounds indices -- confirm callers never roll a full cache.
        full_pos = self._max_pos.at[rows, next_pos].set(next_pos)

        if self.dynamic:
            # Expose only the newly added position and its RoPE row.
            new_pos = full_pos[rows, next_pos][..., None]
            new_sin = self._full_sin[rows, next_pos][:, None, :]
            new_cos = self._full_cos[rows, next_pos][:, None, :]
        else:
            new_pos = full_pos
            new_sin = self.sin
            new_cos = self.cos

        return self.replace(
            cur_pos=new_pos, sin=new_sin, cos=new_cos, _max_pos=full_pos
        )
