"""
Simple checks to see whether the output of the GemmaMachiatto layer is the same in inference and in normal training.
TODO: Transform print statements into asserts.
Precision is low, so the outputs are close but not exactly the same; jnp.isclose will therefore probably result in a failure.
"""

from latte_trans.models.modules.inference.gemma_mach import (
    CausalRopeLatteMachiattoSliding,
)
from functools import partial
from tqdm import tqdm
from flax import struct
from jax import numpy as jnp
import jax
import flax.linen as nn


@struct.dataclass
class Config:
    """Small configuration passed as ``config=`` to ``CausalRopeLatteMachiattoSliding`` by the tests in this file.

    The meaning of each field is defined by the model module, which is not
    visible here; the notes on individual fields are hedged accordingly.
    """
    num_key_value_heads: int = 1
    nheads: int = 2
    L: int = 4  # NOTE(review): meaning not evident from this file -- confirm against the model
    dropout: float = 0.0  # both dropout rates are zero; presumably makes the dropout rng irrelevant -- confirm
    dropout_att: float = 0.0
    head_dim: int = 4
    # NOTE(review): the test inputs in this file have a trailing dimension of 16,
    # not 8 -- confirm which dimension this field describes.
    hidden_dim: int = 8
    pos_embed_max_len: int = 1024
    # Presumably the window length the test names refer to: the tests use
    # sequence lengths 3 ("within") and 15 ("outside") -- confirm.
    att_block_len: int = 4
    attention_bias: bool = False
    initializer_range: float = 0.02


def model_apply(fn, params, cache, dropout_key, data):
    """Call ``fn`` once with inference flags set and return its two results.

    Args:
        fn: Callable invoked as ``fn(variables, data, train=..., cache=...,
            do_inference=..., rngs=...)``; in this file it is ``model.apply``.
        params: Parameter collection, wrapped as ``{"params": params}``.
        cache: Cache forwarded to ``fn`` unchanged (``None`` on the first
            step of the loops in this file).
        dropout_key: Key handed to ``fn`` under the ``"dropout"`` rng name.
        data: Input forwarded to ``fn`` as its second positional argument.

    Returns:
        The ``(result, cache)`` pair produced by ``fn``.
    """
    call_options = {
        "train": False,
        "cache": cache,
        "do_inference": True,
        "rngs": {"dropout": dropout_key},
    }
    # Unpack explicitly so exactly two results are required from ``fn``.
    output, next_cache = fn({"params": params}, data, **call_options)
    return output, next_cache


def _run_full_vs_incremental(seq_len):
    """Print the layer output from one full pass and from cached stepwise passes.

    Shared body of ``test_within_window`` and ``test_outside_window``, which
    differ only in the length of the input sequence.

    Args:
        seq_len: Size of axis 1 of the random ``(2, seq_len, 16)`` input.

    Returns:
        ``(full_out, step_out)``: the single-call output with its last
        position dropped, and the per-step inference outputs concatenated
        along axis 1.
    """
    config = Config()
    model = CausalRopeLatteMachiattoSliding(config=config)
    key = jax.random.PRNGKey(0)
    # Keep the four-way split even though the last key is unused, so the
    # data / init / dropout keys stay the same as before the refactor.
    key, init_rng, dropout_key, _ = jax.random.split(key, 4)
    data = jax.random.normal(key, (2, seq_len, 16))
    variables = model.init(init_rng, data, do_inference=False)
    jax.debug.print("Tests:")
    model_apply2 = jax.jit(partial(model_apply, model.apply))

    # Reference: the whole sequence in a single non-inference call. The cache
    # it returns is not needed; the incremental run starts from ``None``.
    full_out, _ = model.apply(
        {"params": variables["params"]},
        data,
        train=False,
        cache=None,
        do_inference=False,
        rngs={"dropout": dropout_key},
    )
    full_out = full_out[:, :-1, :]  # get rid of the last token
    print("One go: ", full_out.shape)
    jax.debug.print("One go Logits: {x}", x=full_out[0])

    # Incremental: feed growing prefixes, threading the cache between calls.
    # Distinct names keep the reference output above available for comparison.
    cache = None
    step_outs = []
    for i in tqdm(range(1, data.shape[1])):
        step_key, dropout_key = jax.random.split(dropout_key)
        step_res, cache = model_apply2(
            variables["params"], cache, step_key, data[:, :i, :]
        )
        step_outs.append(step_res)
    step_out = jnp.concatenate(step_outs, axis=1)
    print("all_res: ", step_out.shape)
    jax.debug.print("All logits: {x}", x=step_out[0])

    # First step toward the module-level TODO: report how far apart the two
    # results are so a tolerance for a future assert can be chosen. Guarded by
    # a shape check because the shape returned by the inference path is not
    # established in this file.
    if full_out.shape == step_out.shape:
        print("max abs diff: ", jnp.max(jnp.abs(full_out - step_out)))
    return full_out, step_out


def test_within_window():
    """Run the full-vs-incremental comparison on a length-3 sequence.

    Length 3 is below ``Config.att_block_len`` (4), presumably the window the
    test name refers to. Output is printed only; nothing is asserted yet.
    """
    _run_full_vs_incremental(seq_len=3)


def test_outside_window():
    """Run the full-vs-incremental comparison on a length-15 sequence.

    Length 15 exceeds ``Config.att_block_len`` (4), presumably the window the
    test name refers to. Output is printed only; nothing is asserted yet.
    """
    _run_full_vs_incremental(seq_len=15)


if __name__ == "__main__":
    # Running this file as a script executes only the outside-window check;
    # test_within_window is not called from here.
    # NOTE(review): confirm whether skipping it is intentional.
    test_outside_window()
