import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from fast_attention import *


def _negative_shift(attention, beta_0, nega_k):
    """Apply the 'nega' adjustment to a 2-D attention matrix (out of place).

    For every row, ``beta / nega_k`` is subtracted from that row's ``nega_k``
    smallest entries, then ``beta / N`` is added to every entry, where ``N`` is
    the number of rows and ``beta = N * beta_0 / (N - nega_k)``. For a square
    matrix the two shifts cancel, so each row sum is unchanged.

    Args:
        attention: 2-D attention matrix.
        beta_0: base strength of the shift.
        nega_k: number of smallest entries per row to push down.

    Returns:
        A new tensor with the adjustment applied.

    Raises:
        ValueError: if ``nega_k`` is not in ``1 .. N-1``; the formula divides
            by both ``nega_k`` and ``N - nega_k``.
    """
    n_rows = attention.shape[0]
    if not 0 < nega_k < n_rows:
        raise ValueError(
            f"nega_k must satisfy 0 < nega_k < {n_rows}, got {nega_k}"
        )
    beta = (n_rows * beta_0) / (n_rows - nega_k)
    # topk works along the last dim: per row, the column indices of the
    # nega_k smallest entries.
    _, index = torch.topk(attention, nega_k, largest=False)
    mask = torch.zeros_like(attention, dtype=torch.bool)
    mask.scatter_(1, index, torch.ones_like(index, dtype=torch.bool))
    # Same arithmetic as the former per-row in-place loop, but vectorized:
    # selected entries become (a - beta/nega_k) + beta/N, the rest a + beta/N.
    return torch.where(mask, attention - beta / nega_k, attention) + beta / n_rows


def favor_attention_all(q, k, v, projection_matrix, para_dict):
    """Compute FAVOR random-feature attention with an explicit attention matrix.

    ``q`` and ``k`` are mapped to random features with
    ``softmax_kernel_transformation``. The matrix ``q' @ k'^T`` is formed
    explicitly, and each row is divided by ``noncausal_denominator(q', k')``
    (presumably the row normaliser of ``q' k'^T``; confirm in fast_attention).
    Optional adjustments are then applied before multiplying by ``v``.

    Args:
        q, k: 2-D query / key tensors. The einsum below fixes the feature maps
            to rank 2.
        v: 2-D value tensor with one row per key.
        projection_matrix: random projection passed to
            ``softmax_kernel_transformation``.
        para_dict: configuration dict. Keys read:
            'nega' (bool) -> if truthy, also reads 'beta_0' and 'nega_k'
                (see ``_negative_shift``);
            'regu' (bool) -> if truthy, also reads 'alpha' and scales the
                diagonal of the attention matrix by ``1 - alpha``.

    Returns:
        ``(res, attention)``, where ``res = attention @ v`` and ``attention``
        is the (adjusted) attention matrix.

    Raises:
        ValueError: if 'nega' is enabled with an invalid 'nega_k'.
    """
    q_prime = softmax_kernel_transformation(q, True, projection_matrix)   # [N, m]
    k_prime = softmax_kernel_transformation(k, False, projection_matrix)  # [N, m]
    attention = torch.einsum("Xm,Ym->XY", q_prime, k_prime)
    attention_normalize = noncausal_denominator(q_prime, k_prime)         # [N]
    # Broadcast the per-row normaliser across columns: [N] -> [N, 1].
    attention = attention / attention_normalize.unsqueeze(-1)

    if para_dict['nega']:
        attention = _negative_shift(
            attention, para_dict['beta_0'], para_dict['nega_k']
        )

    if para_dict['regu']:
        alpha = para_dict['alpha']
        # Subtract alpha * diag from the diagonal only (off-diagonal untouched).
        attention = attention - torch.diag(alpha * torch.diag(attention))

    res = torch.einsum("XY,Yd->Xd", attention, v)
    return res, attention

class self_fast_attention_layer_all(nn.Module):
    """Self-attention layer built on ``favor_attention_all``.

    Q, K and V come from bias-free ``input_dim -> input_dim`` linear maps of
    the input. Optional MLPs further transform K and/or V: 'augu_k' applies to
    K and 'augu_v' to V.

    Args:
        input_dim: feature size of the input tokens and of Q, K, V.
        output_dim: accepted for interface compatibility; not used to size any
            layer (all projections are ``input_dim -> input_dim``).
        num_random: first argument to ``create_projection_matrix``.
        para_dict: configuration dict. Keys read here: 'augu', and when it is
            truthy, 'augu_k' and 'augu_v'. The dict is also forwarded to
            ``favor_attention_all``.
    """

    def __init__(self, input_dim, output_dim, num_random, para_dict):
        super().__init__()
        self.input_dim = torch.tensor(input_dim, dtype=torch.float32)
        # NOTE(review): stores input_dim, not output_dim. Kept because it
        # matches the real width of Q/K/V and of the output; confirm intent.
        self.output_dim = torch.tensor(input_dim, dtype=torch.float32)
        self.W_q = nn.Linear(input_dim, input_dim, bias=False)
        self.W_k = nn.Linear(input_dim, input_dim, bias=False)
        self.W_v = nn.Linear(input_dim, input_dim, bias=False)
        self.softmax = nn.Softmax(dim=-1)  # not used by forward()
        # Keep the random projection only as a buffer. state_dict key
        # "projectionmatrix" is unchanged. A separate plain attribute would
        # not follow .to()/.cuda()/.half(), leaving forward() with a stale tensor.
        self.register_buffer(
            "projectionmatrix", create_projection_matrix(num_random, input_dim)
        )
        self.para_dict = para_dict

        # Construction order matches the original (W_q, W_k, W_v, projection,
        # augu_k, augu_v), so seeded initialisation is reproducible.
        if self.para_dict['augu']:
            if self.para_dict['augu_k']:
                self.augu_k = self._make_augment_mlp(input_dim)
            if self.para_dict['augu_v']:
                self.augu_v = self._make_augment_mlp(input_dim)

    @staticmethod
    def _make_augment_mlp(dim):
        """Build the dim -> 4*dim -> dim ELU MLP used for K/V augmentation."""
        return nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.ELU(),
            nn.Linear(dim * 4, dim),
            nn.ELU(),
        )

    @property
    def projection_matrix(self):
        """Alias of the ``projectionmatrix`` buffer (follows device/dtype moves)."""
        return self.projectionmatrix

    @projection_matrix.setter
    def projection_matrix(self, value):
        # Route assignments into the registered buffer.
        self.projectionmatrix = value

    def forward(self, x):
        """Run attention over ``x``; presumably shaped [token_num, input_dim].

        Returns:
            ``(output, attention)`` from ``favor_attention_all``. Both are also
            stored on the module as ``self.output`` / ``self.Attention``.
            Q, K, V are stored as ``self.Q`` / ``self.K`` / ``self.V``.
        """
        Q = self.W_q(x)  # Q, K, V: [token_num, input_dim]
        K = self.W_k(x)
        V = self.W_v(x)

        if self.para_dict['augu']:
            if self.para_dict['augu_k']:
                K = self.augu_k(K)
            if self.para_dict['augu_v']:
                V = self.augu_v(V)

        # Cached on the module for later inspection.
        self.Q = Q
        self.K = K
        self.V = V

        # Read the buffer so the projection lives on the same device/dtype as
        # the module's parameters.
        self.output, self.Attention = favor_attention_all(
            Q, K, V, self.projectionmatrix, self.para_dict
        )
        return self.output, self.Attention
