"""Inspired by: https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L44"""

import jax
from jax import numpy as jnp
import einops


class LocalAtt:
    """Causal local attention computed over fixed-size windows.

    The sequence axis is split into consecutive windows of ``window_size``
    tokens. Each query is scored against the keys of its own window and of
    the window immediately before it; a mask then removes future keys and
    the padding that stands in for the (non-existent) window before the
    first one.

    Args:
        window_size: number of tokens per window. The sequence length given
            to ``__call__`` must be a multiple of it.
        exact_windowsize: when true, keys that lie more than ``window_size``
            positions behind the query are masked out as well.
    """

    def __init__(self, window_size: int, exact_windowsize: bool = True):
        self.window_size = window_size
        self.exact_windowsize = exact_windowsize

    @staticmethod
    def look_around(x, backward=1, pad_value=-1, dim=2):
        """Append to every window the ``backward`` windows that precede it.

        Args:
            x: array whose axis 1 indexes windows.
            backward: how many preceding windows to gather for each window.
            pad_value: constant used to fill the windows that would precede
                the first one.
            dim: axis along which the gathered windows are concatenated.

        Returns:
            The concatenation along ``dim`` of ``backward + 1`` shifted copies
            of ``x``, ordered from the oldest window to the current one.
        """
        n_windows = x.shape[1]
        pad_width = [(0, 0)] * x.ndim
        pad_width[1] = (backward, 0)  # pad only in front of the window axis
        padded = jnp.pad(x, pad_width=pad_width, constant_values=pad_value)
        # Offset 0 is the window `backward` steps back; offset `backward` is
        # the window itself.
        shifted = [
            padded[:, offset : offset + n_windows, ...]
            for offset in range(backward + 1)
        ]
        return jnp.concatenate(shifted, axis=dim)

    def __call__(self, Q, K, V):
        """Compute windowed causal attention.

        Args:
            Q, K, V: jax.Array of shape (..., T, D). All leading axes (for
                example batch and heads) are flattened into a single batch
                axis for the computation. K and V are expected to have the
                same leading axes and the same T as Q.

        Returns:
            A tuple ``(out, sim)``:
            out: attention output with the leading axes and T of Q and the
                feature axis of V.
            sim: masked attention logits (before softmax) of shape
                (batch, windows, window_size, 2 * window_size), where batch
                is the product of the leading axes. Masked entries hold
                -9e15.

        Raises:
            AssertionError: if T is not divisible by ``window_size``.
        """
        *_, T, D = Q.shape
        assert (
            T % self.window_size
        ) == 0, f"sequence length {T} must be divisible by window size {self.window_size} for local attention"
        pad_value = -1  # never a valid position, so it can mark padded keys
        mask_fill = -9e15  # large negative logit: softmax weight becomes ~0
        windows = T // self.window_size

        # Flatten every leading axis into one so the rest is rank-3 code;
        # `packed_shape` remembers Q's leading axes for the final unpack.
        Q, packed_shape = einops.pack([Q], "* n d")
        K, _ = einops.pack([K], "* n d")
        V, _ = einops.pack([V], "* n d")

        def to_windows(t):
            return einops.rearrange(t, "b (w n) d -> b w n d", w=windows)

        bq, bk, bv = to_windows(Q), to_windows(K), to_windows(V)
        bq = bq * (D**-0.5)  # attention scale sqrt(1/dim_head)

        # Prepend the previous window to every key/value window so the first
        # token of a window still sees earlier context. The window before the
        # first one is filled with `pad_value` and is masked out below.
        bk = self.look_around(bk, backward=1, pad_value=pad_value)
        bv = self.look_around(bv, backward=1, pad_value=pad_value)

        sim = einops.einsum(bq, bk, "b h i e, b h j e -> b h i j")

        # Build the mask from absolute token positions arranged exactly like
        # the queries and the look-around keys.
        positions = einops.rearrange(
            jnp.arange(T), "(w n) -> 1 w n", w=windows, n=self.window_size
        )
        q_pos = einops.rearrange(positions, "... i -> ... i 1")
        k_pos = einops.rearrange(
            self.look_around(positions, backward=1, pad_value=pad_value),
            "... j -> ... 1 j",
        )

        # Hide future keys and padded keys.
        masked = (q_pos < k_pos) | (k_pos == pad_value)
        if self.exact_windowsize:
            # Also hide keys further back than one full window.
            masked = masked | (q_pos > (k_pos + self.window_size))
        sim = jnp.where(masked, mask_fill, sim)

        attn = jax.nn.softmax(sim, axis=-1)
        out = einops.einsum(attn, bv, "b h i j, b h j e -> b h i e")
        out = einops.rearrange(out, "b w n d -> b (w n) d")
        # Restore the leading axes that were flattened at the top.
        out, *_ = einops.unpack(out, packed_shape, "* n d")
        return out, sim


def main():
    """Run LocalAtt once on a small hand-built integer tensor.

    The input has 1 batch entry, 5 heads, 8 tokens and 3 features; the same
    tensor is used for queries, keys and values.
    """

    def rows(*values):
        # One token per value, each token repeating that value 3 times.
        return [[value] * 3 for value in values]

    filler = rows(4, 4, 4, 4)
    leading_head = rows(1, 2, 3, 4) + filler
    repeated_head = rows(10, 20, 30, 40) + filler

    # The first head differs; the remaining four heads are identical.
    Q = jnp.array([leading_head] + [repeated_head] * 4)
    Q = Q[None, ...]  # add the batch axis
    jax.debug.print("Q is: {x}", x=Q)
    print("Q shape: ", Q.shape)
    K, V = (jnp.copy(Q) for _ in range(2))

    attention = LocalAtt(window_size=4, exact_windowsize=True)
    attention(Q, K, V)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
