from src.models.feature_net import FeatureNet

import torch
from torch import nn
from typing import Optional
from transformers.cache_utils import Cache

class LinearAttnCache(Cache):
    """Running state for causal linear attention.

    Each call to :meth:`attn` accumulates prefix sums over the sequence axis
    and stores the sums at the final position, so that a later call can
    continue from that state instead of re-processing earlier positions.
    """

    # Prefix sum of the key-feature (x) value outer products at the last
    # processed position; None until attn() has been called once.
    Gps: Optional[torch.Tensor]
    # Matching prefix sum of the key features alone (the normaliser),
    # replicated along the value dimension; None until attn() has been called.
    Grenorm: Optional[torch.Tensor]
    # Number of key positions that attn() has folded into the state so far.
    past_length: int

    def __init__(self):
        # NOTE(review): the base-class initialiser is not invoked here, which
        # matches the previous behaviour -- confirm against the installed
        # transformers version whether `Cache.__init__` needs to run.
        self.Gps = None
        self.Grenorm = None
        # Must exist before size() is called; it used to be read without ever
        # being assigned, which raised AttributeError.
        self.past_length = 0

    def attn(self, k_prime: torch.Tensor, q_prime: torch.Tensor, v: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute un-normalised causal linear attention and update the state.

        Args:
            k_prime: Key features. The einsum below fixes the layout to four
                axes ``(b, h, l, f)``.
            q_prime: Query features, same layout as ``k_prime``.
            v: Values, ``(b, h, l, d)``.
            attention_mask: Additive mask; an entry of ``0`` keeps a key and an
                entry ``<= -1`` removes it (values in between scale linearly).
                ``None`` disables masking.
                NOTE(review): the squeeze/unsqueeze sequence presumably expects
                a ``(b, 1, 1, l)`` mask covering exactly the positions in
                ``k_prime`` -- confirm with the caller.

        Returns:
            ``(att_raw, att_norm)``, both ``(b, h, l, d)``: the numerator and
            the denominator of the attention output. The caller is responsible
            for the division.
        """
        # Per-position outer product k (x) v: (b, h, l, f, 1) * (b, h, l, 1, d).
        _Gps = k_prime.unsqueeze(-1) * v.unsqueeze(-2)
        # Same layout, but carrying only the key features (normaliser term).
        _Grenorm = k_prime.unsqueeze(-1) * torch.ones_like(v.unsqueeze(-2))

        if attention_mask is not None:
            # Turn the additive mask into a multiplicative 0/1 factor that
            # broadcasts over the feature and value axes.
            _attention_mask = attention_mask.squeeze(1).unsqueeze(-1).unsqueeze(-1)
            _attention_mask = _attention_mask.clamp(min=-1., max=0.) + 1.
            _Gps = _Gps * _attention_mask
            _Grenorm = _Grenorm * _attention_mask

        # Causality: position i only sees the sum over positions <= i.
        if k_prime.size(2) > 1:
            _Gps = _Gps.cumsum(dim=2)
            _Grenorm = _Grenorm.cumsum(dim=2)

        # Continue from the state left behind by the previous call.
        if self.Gps is not None:
            _Gps = _Gps + self.Gps.unsqueeze(2)
            _Grenorm = _Grenorm + self.Grenorm.unsqueeze(2)

        # Contract the feature axis against the query features.
        att_raw = torch.einsum("bhlfd,bhlf->bhld", _Gps, q_prime)
        att_norm = torch.einsum("bhlfd,bhlf->bhld", _Grenorm, q_prime)

        # Keep only the sums at the final position for the next call.
        self.Gps = _Gps[:, :, -1, :, :]
        self.Grenorm = _Grenorm[:, :, -1, :, :]
        self.past_length += k_prime.size(2)

        return att_raw, att_norm

    def size(self):
        """Return ``torch.Size([1, past_length, 1, 1])``."""
        return torch.Size([1, self.past_length, 1, 1])

def linear_attention(feature_net: "FeatureNet", attn_dropout: nn.Dropout,
                     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                     attention_mask: Optional[torch.Tensor], head_mask=None, cache: Optional["LinearAttnCache"] = None,
                     adaptive_shift: Optional[float] = None, bias: Optional[torch.Tensor] = None, kernel_clip: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, dict]:
    """Causal attention using a feature-map (linear) kernel.

    Queries and keys are mapped to features by ``feature_net`` and the kernel
    between a query and a key is the dot product of their features. Two
    execution paths exist:

    * ``cache`` given: the recurrent form is delegated to ``cache.attn`` and
      the result is numerator / denominator. Dropout and ``head_mask`` are not
      applied on this path and no ``"weight"`` entry is recorded.
    * ``cache`` is ``None``: the full query-by-key kernel matrix is built,
      masked, normalised per query row, and applied to ``value``.

    Args:
        feature_net: Callable as ``feature_net(x, korq='q'|'k')``; must also
            provide ``get_log_features(x)`` when ``adaptive_shift`` is used.
        attn_dropout: Applied to the attention weights (non-cached path only).
        query, key, value: Four-axis tensors; the einsum subscripts below fix
            the layout to ``(b, h, length, dim)``.
        attention_mask: Additive mask (``0`` keeps, ``<= -1`` removes) or
            ``None``.
        head_mask: Optional multiplicative mask on the attention weights.
        cache: Optional recurrent state; selects the cached path.
        adaptive_shift: When given, features are ``exp(shift + log_feature -
            max(log_feature))`` instead of the max-normalised raw features.
        bias: Boolean causal mask indexed as ``bias[:, :, q_range, k_range]``.
            When ``None`` an equivalent lower-triangular mask is built.
        kernel_clip: Lower bound applied to the kernel (non-cached path) or to
            the denominator (cached path). ``None`` skips the clamp.

    Returns:
        ``(attn_output, attn_info)`` where ``attn_info`` always holds
        ``"query"`` and ``"key"`` and, on the non-cached path, ``"weight"``.
    """
    attn_info = {"query": query, "key": key}

    if adaptive_shift is not None:
        # Work in log space and subtract the maximum before exponentiating so
        # that the largest feature is exp(adaptive_shift).
        log_query_feature = feature_net.get_log_features(query)
        max_log_query_feature = log_query_feature.amax(dim=-1, keepdim=True)
        query_feature = torch.exp(adaptive_shift
                                  + log_query_feature - max_log_query_feature)

        # Keys share one maximum across both the sequence and feature axes.
        log_key_feature = feature_net.get_log_features(key)
        max_log_key_feature = log_key_feature.amax(dim=(-2, -1), keepdim=True)
        key_feature = torch.exp(adaptive_shift
                                + log_key_feature - max_log_key_feature)
    else:
        query_feature = feature_net(query, korq='q')
        key_feature = feature_net(key, korq='k')

        max_query_feature = query_feature.amax(dim=-1, keepdim=True)
        max_key_feature = key_feature.amax(dim=(-2, -1), keepdim=True)

        # Only ever scale down: maxima below 1 leave the features untouched.
        max_query_feature = torch.clamp(max_query_feature, min=1.)
        max_key_feature = torch.clamp(max_key_feature, min=1.)

        query_feature = query_feature / max_query_feature
        key_feature = key_feature / max_key_feature

    if cache is not None:
        att_raw, att_norm = cache.attn(key_feature, query_feature, value, attention_mask)

        # clamp() raises when neither bound is given, so skip it for None.
        if kernel_clip is not None:
            att_norm = att_norm.clamp(min=kernel_clip)

        attn_output = att_raw / att_norm
        return attn_output, attn_info

    lin_kernel = torch.einsum("bhqf,bhkf->bhqk", query_feature, key_feature)

    query_length, key_length = query.size(-2), key.size(-2)
    if bias is not None:
        causal_mask = bias[:, :, key_length - query_length : key_length, :key_length]
    else:
        # No precomputed mask supplied: query i sits at absolute position
        # (key_length - query_length + i) and may see keys up to that position.
        query_pos = torch.arange(key_length - query_length, key_length,
                                 device=lin_kernel.device).unsqueeze(-1)
        key_pos = torch.arange(key_length, device=lin_kernel.device)
        causal_mask = key_pos <= query_pos
    lin_kernel = torch.where(causal_mask, lin_kernel, lin_kernel.new_zeros(()))

    if attention_mask is not None:
        # Additive mask -> multiplicative factor in [0, 1].
        _attn_mask = torch.clamp(attention_mask, -1., 0.) + 1.
        lin_kernel = lin_kernel * _attn_mask

    if kernel_clip is not None:
        lin_kernel = lin_kernel.clamp(min=kernel_clip)

    attn_querywise_max = lin_kernel.max(dim=-1, keepdim=True).values
    # A row whose every entry is masked has maximum 0; dividing by it would
    # give NaN. Divide by 1 instead so the floor below yields uniform weights.
    attn_querywise_max = attn_querywise_max.masked_fill(attn_querywise_max == 0, 1.)

    lin_kernel = lin_kernel / attn_querywise_max
    # The floor is applied after masking, so masked positions keep a weight of
    # the order of 1e-6 relative to the row maximum.
    lin_kernel = torch.clamp(lin_kernel, min=1e-6)

    kernel_sum = torch.sum(lin_kernel, dim=-1)
    attn_weights = lin_kernel / kernel_sum.unsqueeze(-1)

    attn_weights = attn_weights.type(value.dtype)
    attn_weights = attn_dropout(attn_weights)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    attn_info["weight"] = attn_weights
    attn_output = torch.matmul(attn_weights, value)

    return attn_output, attn_info