'''
modified from https://github.com/jaketae/alibi/blob/main/alibi/attention.py
'''

import math
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from typing import Tuple, Optional


def get_relative_positions(t_q: Tensor, t_k: Tensor) -> Tensor:
    '''
    Pairwise time offset from every query to every key.

    t_q: (*, L_q, ) float, time of query
    t_k: (*, L_k, ) float, time of key
    ret: (*, L_q, L_k), ret[..., i, j] = t_k[..., j] - t_q[..., i]
    '''
    query_column = t_q[..., None]  # (*, L_q, 1)
    key_row = t_k[..., None, :]  # (*, 1, L_k)
    # Broadcasting the row against the column yields the full (L_q, L_k) grid.
    return key_row - query_column



class ALiBiCausalAttention(nn.Module):
    '''
    Causal multi-head self-attention with an ALiBi-style additive bias.

    For head h the attention logit of (query i, key j) receives the bias
    alibi_slope[h] * (t_k[j] - t_q[i]), where the times are supplied by the caller.
    Keys, values and key times from an earlier call can be passed back in as a cache;
    every query of the current call can attend to all cached entries.
    '''

    def __init__(self, hiddim: int, num_head: int, attn_dp: float) -> None:
        '''
        hiddim: model width d, must be divisible by num_head
        num_head: number of attention heads
        attn_dp: attention dropout probability, only applied in training mode
        '''
        super().__init__()
        assert hiddim % num_head == 0, "hiddim should equal to num_head * head_dim"
        self.num_head = num_head
        self.head_dim = hiddim//num_head
        self.hiddim = hiddim
        self.scaling = self.head_dim ** -0.5
        # Per-head slopes 2^(-8 * (i + 1) / num_head); a buffer so that it follows .to()/.cuda().
        self.register_buffer("alibi_slope", torch.tensor([2 ** (-8 * (i + 1)/self.num_head) for i in range(self.num_head)]))
        self.linqkv = nn.Linear(hiddim, 3 * hiddim, bias=False)
        self.lino = nn.Linear(hiddim, hiddim, bias=False)
        self.attn_dp = attn_dp

    def forward(self, x: Tensor, t_q: Tensor, t_k: Tensor, pask_kvt: Optional[Tuple[Tensor, Tensor, Tensor]]=None) -> Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]:
        '''
        x (L, d), float
        t_q (L, #head) or (L, 1), float
        t_k (L, #head) or (L, 1), float
        pask_kvt (sic, name kept for backward compatibility): optional cache
            (past_k, past_v, past_t) as returned by a previous call, with
            past_k, past_v (#head, L_past, headdim) and past_t (L_past, #head) or (L_past, 1)
        ret: (out, (k, v, t_k)) where
            out (L, d), float
            k, v (#head, L_past + L, headdim), cached entries first
            t_k (L_past + L, #head) or (L_past + L, 1)
        '''
        L: int = x.shape[0]

        qkv: Tensor = self.linqkv(x)
        q, k, v = qkv.chunk(3, dim=-1) # (L, d), (L, d), (L, d)
        q = q.unflatten(-1, (self.num_head, self.head_dim)).transpose(0, 1) # (#head, L, headdim)
        k = k.unflatten(-1, (self.num_head, self.head_dim)).transpose(0, 1) # (#head, L, headdim)
        v = v.unflatten(-1, (self.num_head, self.head_dim)).transpose(0, 1) # (#head, L, headdim)

        slope: Tensor = self.alibi_slope.reshape(-1, 1, 1) # (#head, 1, 1)
        bias: Tensor = slope * get_relative_positions(t_q.t(), t_k.t()) # (#head, L, L)
        # Causality is encoded in the bias itself: position i must not see any j > i.
        future = torch.tril(torch.ones((L, L), dtype=torch.bool, device=bias.device)).logical_not_() # (L, L)
        bias = bias.masked_fill_(future, float("-inf")) # (#head, L, L)

        if pask_kvt is not None:
            past_k, past_v, past_t = pask_kvt
            # Cached entries are never masked, only biased by their time offset.
            past_bias = slope * get_relative_positions(t_q.t(), past_t.t()) # (#head, L, L_past)

            t_k = torch.concat((past_t, t_k), dim=0)
            k = torch.concat((past_k, k), dim=1)
            v = torch.concat((past_v, v), dim=1)
            bias = torch.concat((past_bias, bias), dim=2) # (#head, L, L_past + L)

        # A single call replaces the former try / bare-except pair:
        # * is_causal stays False because SDPA rejects an explicit attn_mask combined with
        #   is_causal=True, and the -inf entries of `bias` already implement the causal mask;
        # * `scale` is left at its default, 1 / sqrt(head_dim), which equals self.scaling and
        #   keeps the call valid on PyTorch versions without the `scale` keyword.
        o = F.scaled_dot_product_attention(
            q.unsqueeze(0),
            k.unsqueeze(0),
            v.unsqueeze(0),
            attn_mask=bias.unsqueeze(0),
            dropout_p=self.attn_dp if self.training else 0.,
            is_causal=False,
        ).squeeze(0) # (#head, L, headdim)
        o = o.transpose(0, 1).flatten(1, 2) # (L, d)

        ret = self.lino(o)

        return ret, (k, v, t_k)