"""
An implementation of the Blockwise Parallel Transformer https://arxiv.org/abs/2305.19370
Also includes a reference implementation of the memory-efficient transformer https://arxiv.org/abs/2112.05682
"""

import functools
from typing import NamedTuple

import flax.linen as nn
import jax
import jax.lax as lax
import jax.numpy as jnp
from einops import rearrange

"""
Compute the FFN blockwise without materializing the large hidden tensor, which
allows training on sequences 4x longer than with the memory-efficient transformer.
Blockwise Parallel Transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
"""


def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
    """Apply ``remat_ffn`` to ``inputs`` one slice of ``chunk_size`` tokens at a time.

    The sequence axis of ``inputs`` (batch, seq_len, dim) is factored as
    (chunk_size, seq_len // chunk_size), so seq_len must be divisible by
    ``chunk_size``. ``nn.scan`` then iterates over the second factor, and every
    step feeds a (batch, chunk_size, dim) slice to ``remat_ffn``. These slices hold
    strided (not contiguous) positions. The result therefore equals a full-sequence
    application only if ``remat_ffn`` treats each position independently.
    Presumably that holds for an FFN; confirm before using other modules.

    Args:
        remat_ffn: flax module lifted through ``nn.scan`` and called as
            ``remat_ffn(x, deterministic=...)``. It is expected to be
            rematerialized (e.g. with ``jax.checkpoint_policies.nothing_saveable()``)
            so the large hidden activation is not kept for every step.
        inputs: array of shape (batch, seq_len, dim).
        chunk_size: number of tokens processed per scan step.
        deterministic: forwarded to ``remat_ffn``.

    Returns:
        The per-step outputs reassembled in the original token order,
        shape (batch, seq_len, out_dim).
    """
    chunked = rearrange(inputs, "b (c n) d -> b c n d", c=chunk_size)
    # Scan along the ``n`` axis so each step sees a (batch, chunk_size, dim) slice.
    step_axis = chunked.ndim - 2

    def ffn_step(module, carry, block):
        return carry, module(block, deterministic=deterministic)

    scanned_ffn = nn.scan(
        ffn_step,
        variable_broadcast="params",  # one set of FFN weights shared by all steps
        split_rngs={"params": False, "dropout": True},  # fresh dropout rng per step
        in_axes=step_axis,
        out_axes=step_axis,
    )
    _, outputs = scanned_ffn(remat_ffn, None, chunked)
    return rearrange(outputs, "b c n d -> b (c n) d")


"""
Compute attention blockwise without materializing the full attention matrix, as
initially proposed in the memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021.
FlashAttention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes an efficient
CUDA implementation. The Blockwise Parallel Transformer https://arxiv.org/abs/2305.19370
Liu et al. 2023 proposes computing both attention and the FFN blockwise. This enables
4x longer sequences than memory-efficient attention/FlashAttention and allows attention and the FFN to be fused.
"""


def blockwise_attn(
    query,
    key,
    value,
    bias=None,
    deterministic=True,
    dropout_rng=None,
    attn_pdrop=0.0,
    causal=True,
    query_chunk_size=2048,
    key_chunk_size=2048,
    dtype=jnp.float32,
    policy=jax.checkpoint_policies.nothing_saveable(),
    precision=None,
    float32_logits=True,
    prevent_cse=True,
    norm_query=True,
):
    """Softmax attention computed block by block, never materializing all scores.

    Queries are split into chunks of ``query_chunk_size`` and keys/values into
    chunks of ``key_chunk_size``. For every query chunk, the key/value chunks are
    scanned with an online (running-max) softmax. Only one
    (query_chunk, key_chunk) block of scores exists at a time.

    Args:
        query: (batch, q_len, num_heads, qk_dim).
        key: (batch, kv_len, num_heads, qk_dim).
        value: (batch, kv_len, num_heads, v_dim).
        bias: optional additive bias, 4-D and broadcastable to
            (batch, num_heads, q_len, kv_len). Every axis must be 1 or the full
            size. Large negative entries can mask out e.g. padding.
        deterministic: attention dropout is active only when this is False and
            ``attn_pdrop > 0``.
        dropout_rng: PRNG key for attention dropout; required when dropout is active.
        attn_pdrop: probability that an attention logit is dropped. Dropped logits
            are masked with ``finfo(dtype).min``; the remaining weights are not
            rescaled.
        causal: if True, query position i attends only to key positions <= i.
        query_chunk_size, key_chunk_size: block sizes. Each is clamped to its
            (non-empty) sequence length. The sequence lengths must be multiples of
            the resulting chunk sizes, otherwise the reshape below fails.
        dtype: output dtype, also used for the query scale and the mask values.
        policy: ``jax.checkpoint_policies`` policy used to rematerialize each block.
        precision: precision passed to both einsums.
        float32_logits: if True, query and key are cast to float32 first.
        prevent_cse: forwarded to ``jax.checkpoint``.
        norm_query: if True, query is divided by sqrt(qk_dim).

    Returns:
        Array of shape (batch, q_len, num_heads, v_dim) with dtype ``dtype``.
    """
    if norm_query:
        query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
    if float32_logits:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)

    batch, q_len, num_heads, qk_dim = query.shape
    kv_len = key.shape[1]
    # The value head size is allowed to differ from the query/key head size.
    v_dim = value.shape[-1]

    # A chunk never needs to be longer than the sequence it splits. Clamping lets
    # sequences shorter than the (large) default chunk sizes work instead of
    # failing in the reshape below (num_q / num_kv would be 0).
    if 0 < q_len < query_chunk_size:
        query_chunk_size = q_len
    if 0 < kv_len < key_chunk_size:
        key_chunk_size = kv_len

    num_q = q_len // query_chunk_size
    num_kv = kv_len // key_chunk_size
    query = query.reshape((batch, num_q, query_chunk_size, num_heads, qk_dim))
    key = key.reshape((batch, num_kv, key_chunk_size, num_heads, qk_dim))
    value = value.reshape((batch, num_kv, key_chunk_size, num_heads, v_dim))

    # Chunk axis first, so lax.scan iterates over chunks.
    query = jnp.moveaxis(query, 1, 0)
    key = jnp.moveaxis(key, 1, 0)
    value = jnp.moveaxis(value, 1, 0)

    if bias is not None:
        assert bias.ndim == 4, "bias must be 4-D: (batch, num_heads, q_len, kv_len)"
        for bias_dim, broadcast_dim in zip(
            bias.shape, (batch, num_heads, q_len, kv_len)
        ):
            assert bias_dim == 1 or bias_dim == broadcast_dim
    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_rng, _ = jax.random.split(dropout_rng)
        attn_dropout = jax.random.bernoulli(
            attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len)
        )
    else:
        attn_dropout = None

    _chunk_bias_fn = functools.partial(
        _chunk_attention_bias,
        query_chunk_size,
        key_chunk_size,
        bias,
        deterministic,
        attn_dropout,
        attn_pdrop,
        causal,
        dtype,
    )

    def scan_attention(args):
        query_chunk, query_chunk_idx = args
        # Absolute position of the last query in this chunk (for causal skipping).
        last_query_pos = (query_chunk_idx + 1) * query_chunk_size - 1

        @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
        def scan_kv_block(carry, args):
            key_chunk, value_chunk, key_chunk_idx = args
            (numerator, denominator, prev_max_score) = carry
            attn_weights = jnp.einsum(
                "bqhd,bkhd->bqhk", query_chunk, key_chunk, precision=precision
            )
            bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
            # (b, h, q, k) -> (b, q, h, k) to match the einsum layout above.
            bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
            attn_weights = attn_weights + bias_chunk

            # Online softmax: rescale the running sums to the new running max so
            # that exp() never overflows.
            max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
            max_score = jnp.maximum(prev_max_score, max_score)
            max_score = jax.lax.stop_gradient(max_score)
            exp_weights = jnp.exp(attn_weights - max_score)
            exp_values = jnp.einsum(
                "bqhv,bvhd->bqhd", exp_weights, value_chunk, precision=precision
            )
            correction = jnp.exp(prev_max_score - max_score)
            numerator = numerator * correction + exp_values
            denominator = denominator * correction + exp_weights.sum(
                axis=-1, keepdims=True
            )
            return Carry(numerator, denominator, max_score), None

        def skip_upper_half(carry, args):
            key_chunk_idx = args[2]
            skip_block = jnp.array(False)
            if causal:
                # A key block is fully masked only when its first key comes after
                # the last query of this chunk. Comparing chunk indices directly is
                # wrong when query_chunk_size != key_chunk_size.
                skip_block = key_chunk_idx * key_chunk_size > last_query_pos
            return jax.lax.cond(
                skip_block,
                lambda carry, args: (carry, None),
                scan_kv_block,
                carry,
                args,
            )

        init_carry = Carry(
            jnp.zeros((batch, query_chunk_size, num_heads, v_dim), dtype=query.dtype),
            # One normalizer per (query, head). It broadcasts over the head dim, so
            # v_dim identical copies are not needed.
            jnp.zeros((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
            jnp.full(
                (batch, query_chunk_size, num_heads, 1), -jnp.inf, dtype=query.dtype
            ),
        )
        (numerator, denominator, _), _ = lax.scan(
            skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
        )
        return (numerator / denominator).astype(dtype)

    _, res = lax.scan(
        lambda _, x: ((), scan_attention(x)), (), xs=(query, jnp.arange(0, num_q))
    )
    # (num_q, batch, chunk, heads, dim) -> (batch, q_len, heads, dim)
    res = rearrange(res, "n b c h d -> b (n c) h d")
    return res


# Running state threaded through the key/value scan of the online softmax:
#   numerator   -- running sum of exp(score - max_so_far) * value
#   denominator -- running sum of exp(score - max_so_far)
#   max_so_far  -- running maximum score, used to keep exp() numerically stable
Carry = NamedTuple(
    "Carry",
    [
        ("numerator", jax.Array),
        ("denominator", jax.Array),
        ("max_so_far", jax.Array),
    ],
)


def _chunk_attention_bias(
    query_chunk_size,
    key_chunk_size,
    bias,
    deterministic,
    attn_dropout,
    attn_pdrop,
    causal,
    dtype,
    query_chunk_idx,
    key_chunk_idx,
):
    query_offset = query_chunk_idx * query_chunk_size
    key_offset = key_chunk_idx * key_chunk_size
    chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
    if bias is not None:
        chunk_bias = lax.dynamic_slice(
            bias,
            start_indices=(0, 0, query_offset, key_offset),
            slice_sizes=(
                *bias.shape[:2],
                min(bias.shape[-2], query_chunk_size),
                min(bias.shape[-1], key_chunk_size),
            ),
        )

    if causal:
        query_idx = lax.broadcasted_iota(
            dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0
        )
        key_idx = lax.broadcasted_iota(
            dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1
        )
        offset = query_offset - key_offset
        query_idx += offset
        causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
        chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)

    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_slice = lax.dynamic_slice(
            attn_dropout,
            start_indices=(0, 0, query_offset, key_offset),
            slice_sizes=(
                *attn_dropout.shape[:2],
                min(attn_dropout.shape[-2], query_chunk_size),
                min(attn_dropout.shape[-1], key_chunk_size),
            ),
        )
        chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
    return chunk_bias.astype(dtype)


if __name__ == "__main__":
    # Smoke test: blockwise_attn must match dense softmax attention.

    def reference_attn(query, key, value, causal, dtype):
        """Dense (fully materialized) attention used as ground truth."""
        scaled = query / jnp.sqrt(query.shape[-1]).astype(dtype)
        scores = jnp.einsum("bqhc,bkhc->bhqk", scaled, key)
        if causal:
            q_len, kv_len = scaled.shape[1], key.shape[1]
            rows = jax.lax.broadcasted_iota(jnp.int32, (q_len, kv_len), 0)
            cols = jax.lax.broadcasted_iota(jnp.int32, (q_len, kv_len), 1)
            future = (rows < cols)[None, None, :, :]
            scores = scores + jnp.where(future, jnp.finfo(scores.dtype).min, 0.0)
        probs = jax.nn.softmax(scores, axis=-1)
        return jnp.einsum("bhqk,bkhc->bqhc", probs, value)

    # Random (batch, seq_len, num_heads, dim_per_head) inputs, one seed per tensor.
    shape = (1, 32, 8, 64)
    q, k, v = (
        jax.random.normal(jax.random.PRNGKey(seed), shape) for seed in (0, 1, 2)
    )

    causal = True
    chunk_size = 4
    policy = jax.checkpoint_policies.nothing_saveable()

    blockwise = blockwise_attn(
        q,
        k,
        v,
        bias=None,
        deterministic=False,
        dropout_rng=None,
        attn_pdrop=0.0,
        causal=causal,
        query_chunk_size=chunk_size,
        key_chunk_size=chunk_size,
        dtype=jnp.float32,
        policy=policy,
        precision="float32",
        float32_logits=True,
        prevent_cse=False,
    )
    reference = reference_attn(q, k, v, causal, "float32")

    assert jnp.allclose(reference, blockwise, atol=1e-6)
