from typing import Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from Alibi import ALiBiCausalAttention

class GLU(nn.Module):
    '''
    Gated feed-forward block: down(up(x) * silu(gate(x))).

    Projects the input to interdim twice (a value branch and a SiLU-activated
    gate branch), multiplies the two element-wise and projects the product
    back to hiddim.
    '''

    def __init__(self, hiddim: int, interdim: int) -> None:
        '''
        hiddim   size of the last dimension of the input and of the output
        interdim size of the gated intermediate representation
        '''
        super().__init__()
        # Sub-modules are created in the order up -> gate -> down so that
        # seeded parameter initialisation stays reproducible.
        self.lin_up = nn.Linear(hiddim, interdim)
        gate_proj = nn.Linear(hiddim, interdim)
        # The in-place SiLU only overwrites the fresh output of gate_proj.
        self.lin_gate = nn.Sequential(gate_proj, nn.SiLU(inplace=True))
        self.lin_down = nn.Linear(interdim, hiddim)

    def forward(self, x: Tensor):
        '''
        x (L, d) float

        Returns a tensor whose last dimension is hiddim again.
        '''
        value = self.lin_up(x)
        gate = self.lin_gate(x)
        return self.lin_down(value * gate)
        
class ALiBiTransformer(nn.Module):
    '''
    Stack of num_layer pre-norm residual layers.

    Every layer applies
        x = x + attn_i(ln(x))
        x = x + glu_i(ln(x))
    where attn_i is an ALiBiCausalAttention and glu_i is a GLU.
    '''

    def __init__(self, hiddim: int, num_head: int, num_layer: int, inter_ratio: float = 1., attn_dp: float = 0.05):
        '''
        hiddim      width d of the residual stream
        num_head    forwarded to ALiBiCausalAttention (presumably the head count)
        num_layer   number of (attention, GLU) layers
        inter_ratio GLU intermediate width is int(inter_ratio * hiddim)
        attn_dp     forwarded to ALiBiCausalAttention (presumably attention dropout - confirm there)
        '''
        super().__init__()
        self.num_layer = num_layer
        self.attns = nn.ModuleList([ALiBiCausalAttention(hiddim, num_head, attn_dp) for _ in range(num_layer)])
        self.glus = nn.ModuleList([GLU(hiddim, int(inter_ratio*hiddim)) for _ in range(num_layer)])
        # Built without affine parameters, so this single instance holds no
        # learnable state and is reused before every sub-layer of every layer.
        self.ln = nn.LayerNorm(hiddim, elementwise_affine=False)

    def forward(self, x: Tensor, realtime: Tensor, past_kvts: Optional[List[Tuple[Tensor, Tensor, Tensor]]] = None, out_cache: bool = False) -> Tuple[Tensor, list]:
        '''
        x (L, d) float
        realtime (L, ) float
        past_kvts optional list with one entry per layer; entry i is handed
            unchanged to attention layer i (presumably cached key / value /
            time tensors - confirm against ALiBiCausalAttention)
        out_cache when True, the cache returned by each attention layer is collected

        Returns (x, caches): the transformed x and the list of per-layer
        caches, which is empty when out_cache is False.
        '''
        # The time stamps are identical for every layer, so the query / key
        # views are built once instead of once per layer.
        t_q = realtime.unsqueeze(-1)
        t_k = realtime.unsqueeze(-1)

        ret_caches = []
        for i, (attn, glu) in enumerate(zip(self.attns, self.glus)):
            layer_past = None if past_kvts is None else past_kvts[i]

            # Invoke the module itself rather than .forward so that hooks
            # registered on the attention layer are executed.
            attnout, cache = attn(self.ln(x), t_q, t_k, layer_past)
            x = x + attnout

            if out_cache:
                ret_caches.append(cache)

            x = x + glu(self.ln(x))

        return x, ret_caches

