import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from fast_attention import *

class linear_model(nn.Module):
    """Bias-free linear map, optionally initialised with a caller-supplied weight.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        weight: optional weight in ``nn.Linear`` layout, i.e. shape
            [output_dim, input_dim]. A ``Parameter`` is installed as-is (and is
            therefore shared with the caller); a plain tensor is wrapped in a
            ``Parameter`` first, since ``nn.Linear.weight`` only accepts one.
    """

    def __init__(self, input_dim, output_dim, weight = None):
        super(linear_model, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.weight = weight
        self.linear = nn.Linear(input_dim, output_dim, bias=False)
        # `is not None` avoids routing a tensor-vs-None comparison through torch.
        if weight is not None:
            # nn.Module only accepts a Parameter (or None) in a parameter slot.
            self.linear.weight = weight if isinstance(weight, Parameter) else Parameter(weight)

    def forward(self, x):   # x : [token_num, input_dim]
        """Apply the linear map: [token_num, input_dim] -> [token_num, output_dim]."""
        res = self.linear(x)
        return res
    


class self_attention_layer(nn.Module):
    """Single-head, bias-free softmax self-attention.

    Args:
        input_dim: width of the input tokens.
        output_dim: width of the Q/K/V projections and of the output.

    ``forward`` accepts [token_num, input_dim] or any batched
    [..., token_num, input_dim] input and returns ``(output, attention)``.
    """

    def __init__(self, input_dim, output_dim):
        super(self_attention_layer,self).__init__()
        # 0-dim float tensors so torch.sqrt can be applied directly in forward.
        self.input_dim = torch.tensor(input_dim, dtype=torch.float32)
        self.output_dim = torch.tensor(output_dim, dtype=torch.float32)
        self.W_q = nn.Linear(input_dim, output_dim, bias = False)
        self.W_k = nn.Linear(input_dim, output_dim, bias=False)
        self.W_v = nn.Linear(input_dim, output_dim, bias=False)
        self.softmax = nn.Softmax(dim = -1)

    def forward(self, x): # x: [..., token_num, input_dim]
        Q = self.W_q(x)  # Q, K, V : [..., token_num, output_dim]
        K = self.W_k(x)
        V = self.W_v(x)
        # Swap only the last two axes so leading batch dimensions broadcast;
        # for 2-D input this is identical to transpose(0, 1).
        # NOTE(review): scaled by sqrt(input_dim); textbook attention uses the
        # key width (output_dim). They coincide for the callers visible here.
        scores = Q @ K.transpose(-2, -1) / torch.sqrt(self.input_dim)
        Attention = self.softmax(scores)    # Attention : [..., token_num, token_num]
        output = Attention @ V              # output: [..., token_num, output_dim]
        return output, Attention

    def get_weight(self):
        """Return the (W_q, W_k, W_v) weight Parameters."""
        return self.W_q.weight, self.W_k.weight, self.W_v.weight

class self_fast_attention_layer(nn.Module):
    """Single-head self-attention computed with ``favor_attention``, plus an exact reference.

    Each forward pass computes the approximate attention from
    ``favor_attention`` (star-imported from ``fast_attention``) AND the exact
    softmax attention. It caches Q, K, V and both results on the module so the
    ``get_*`` helpers can compare them. The exact path builds the full
    [token_num, token_num] matrix on every call.

    Args:
        input_dim: width of the input tokens.
        output_dim: width of the Q/K/V projections and of the output.
        num_random: first argument to ``create_projection_matrix`` (presumably
            the number of random features -- confirm in ``fast_attention``).
    """

    def __init__(self, input_dim, output_dim, num_random):
        super(self_fast_attention_layer,self).__init__()
        # 0-dim float tensors so torch.sqrt can be applied directly in exact_res.
        self.input_dim = torch.tensor(input_dim, dtype=torch.float32)
        self.output_dim = torch.tensor(output_dim, dtype=torch.float32)
        self.W_q = nn.Linear(input_dim, output_dim, bias = False)
        self.W_k = nn.Linear(input_dim, output_dim, bias=False)
        self.W_v = nn.Linear(input_dim, output_dim, bias=False)
        self.softmax = nn.Softmax(dim = -1)
        # Stored only as a registered buffer (aliased by the `projection_matrix`
        # property) so .to()/.cuda() move it together with the weights. The
        # buffer name is unchanged so saved state_dicts still load.
        # NOTE(review): sized with input_dim although it is applied to Q/K of
        # width output_dim -- equal for every caller visible here.
        self.register_buffer("projectionmatrix", create_projection_matrix(num_random, input_dim))

    @property
    def projection_matrix(self):
        """Random projection used by ``favor_attention`` (alias of the ``projectionmatrix`` buffer)."""
        return self.projectionmatrix

    @projection_matrix.setter
    def projection_matrix(self, value):
        # Assigning through the alias replaces the registered buffer.
        self.projectionmatrix = value

    def forward(self, x):  # x: [token_num, input_dim]
        """Return ``(output, Attention)`` from ``favor_attention``; exact results are cached too."""
        Q = self.W_q(x)  # Q, K, V : [token_num, output_dim]
        K = self.W_k(x)
        V = self.W_v(x)
        # Kept for the inspection helpers; these references (and any autograd
        # graph attached to them) live until the next forward call.
        self.Q = Q
        self.K = K
        self.V = V
        self.output_exact, self.Attention_exact = self.exact_res(Q,K,V)
        self.output, self.Attention = favor_attention(Q, K, V, self.projection_matrix)
        return self.output, self.Attention

    def exact_res(self, Q,K,V):
        """Exact softmax attention for 2-D Q, K, V; returns ``(output, attention)``."""
        # NOTE(review): scaled by sqrt(input_dim) rather than the key width.
        Attention_exact = self.softmax(
            Q @ K.transpose(0, 1) / torch.sqrt(self.input_dim))  # Attention : [token_num, token_num]
        output_exact = torch.einsum("XY,Yd->Xd", Attention_exact, V)
        return output_exact, Attention_exact

    def get_output(self):
        """Return ``(approximate output, exact output)`` from the last forward."""
        return self.output, self.output_exact

    def get_V_attention(self):
        """Return ``(V, approximate attention, exact attention)`` from the last forward."""
        return self.V, self.Attention, self.Attention_exact

    def get_mse_error(self):
        """Return ``(attention_error, output_error)`` as mean squared errors vs. the exact path."""
        output_error = torch.nn.functional.mse_loss(self.output, self.output_exact)
        attention_error = torch.nn.functional.mse_loss(self.Attention, self.Attention_exact)
        return attention_error, output_error

    def get_mae_error(self):
        """Return ``(attention_error, output_error)`` as mean absolute errors vs. the exact path."""
        output_error = torch.nn.functional.l1_loss(self.output,self.output_exact)
        attention_error = torch.nn.functional.l1_loss(self.Attention, self.Attention_exact)
        return attention_error, output_error

    def get_max_error(self):
        """Print, and return as ``(attention_error, output_error)``, the max absolute errors."""
        output_error = torch.max(torch.abs(self.output - self.output_exact))
        attention_error = torch.max(torch.abs(self.Attention - self.Attention_exact))
        print("attention max error is {attention_error}, output max error is {output_error}".format(output_error = output_error, attention_error = attention_error))
        return attention_error, output_error

    def get_QKV_mean_std(self):
        """Print the mean and variance of the cached Q, K and V."""
        print("Q mean is {mean} , var is {std}".format(mean=torch.mean(self.Q), std=torch.var(self.Q)))
        print("K mean is {mean} , var is {std}".format(mean=torch.mean(self.K), std=torch.var(self.K)))
        print("V mean is {mean} , var is {std}".format(mean=torch.mean(self.V), std=torch.var(self.V)))

    def get_weight(self):
        """Return the (W_q, W_k, W_v) weight Parameters."""
        return self.W_q.weight, self.W_k.weight, self.W_v.weight

    def init_weight(self, Q, K, V):
        """Replace the W_q/W_k/W_v weights with new Parameters wrapping Q, K, V.

        New Parameter objects are installed, so an optimizer created before
        this call still references the old ones.
        """
        self.W_q.weight = Parameter(Q)
        self.W_k.weight = Parameter(K)
        self.W_v.weight = Parameter(V)
        print("  Parameter migration complete!  ")
        print("  Parameter migration complete!  ")

class readout_head_mlp(nn.Module):
    """Token-wise two-layer MLP with a GELU after each linear layer.

    Maps [token_num, input_dim] -> [token_num, output_dim] through a hidden
    layer of width ``hidden_dim``. The layers live in ``self.model`` as
    Linear, GELU, Linear, GELU (in that order).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        layers = [nn.Linear(input_dim, hidden_dim), nn.GELU()]
        layers += [nn.Linear(hidden_dim, output_dim), nn.GELU()]
        self.model = nn.Sequential(*layers)

    def forward(self, tokens):
        """Return predictions of shape [token_num, output_dim]."""
        return self.model(tokens)

class one_layer_transformer(nn.Module):
    """One exact self-attention layer (input_dim -> input_dim) plus an optional readout head.

    Args:
        input_dim: token width.
        output_dim: output width of the readout head, if one is built.
        hidden_dim: hidden width of the "mlp" readout head.
        readout_type: "linear", "mlp" or "direct" (``readout_head = None``).
            Any other value leaves ``readout_head`` unset.

    Note: ``forward`` returns the attention layer's result directly;
    ``readout_head`` is constructed but not applied here.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, readout_type = "direct"):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.readout_type = readout_type

        self.sa_layer = self_attention_layer(input_dim, input_dim)
        head_builders = (
            ("linear", lambda: nn.Linear(input_dim, output_dim)),
            ("mlp", lambda: readout_head_mlp(input_dim, hidden_dim, output_dim)),
            ("direct", lambda: None),
        )
        for kind, build in head_builders:
            if readout_type == kind:
                self.readout_head = build()
                break

    def forward(self, tokens):
        """Return ``(updated_tokens, attention)`` from the self-attention layer."""
        new_tokens, attn = self.sa_layer(tokens)
        return new_tokens, attn



class one_layer_fast_transformer(nn.Module):
    """One ``self_fast_attention_layer`` (input_dim -> input_dim) plus an optional readout head.

    Args:
        input_dim: token width.
        output_dim: output width of the readout head, if one is built.
        hidden_dim: hidden width of the "mlp" readout head.
        num_random: forwarded to ``self_fast_attention_layer``.
        readout_type: "linear", "mlp" or "direct" (``readout_head = None``).
            Any other value leaves ``readout_head`` unset.

    Note: ``forward`` returns the attention layer's result directly;
    ``readout_head`` is constructed but not applied here.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, num_random, readout_type = "direct"):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.readout_type = readout_type
        self.num_random = num_random
        self.sa_layer = self_fast_attention_layer(input_dim, input_dim, num_random)

        head_builders = (
            ("linear", lambda: nn.Linear(input_dim, output_dim)),
            ("mlp", lambda: readout_head_mlp(input_dim, hidden_dim, output_dim)),
            ("direct", lambda: None),
        )
        for kind, build in head_builders:
            if readout_type == kind:
                self.readout_head = build()
                break

    def forward(self, tokens):
        """Return ``(updated_tokens, attention)`` from the fast attention layer."""
        new_tokens, attn = self.sa_layer(tokens)
        return new_tokens, attn



#%%

class self_fast_regul_attention_layer(nn.Module):
    """Self-attention via ``favor_attention_regul``, plus an exact softmax reference.

    Each forward pass computes the approximate attention from
    ``favor_attention_regul`` (star-imported from ``fast_attention``) AND the
    exact softmax attention. It caches Q, K, V and both results on the module
    so the ``get_*`` helpers can compare them.

    Args:
        input_dim: width of the input tokens.
        output_dim: width of the Q/K/V projections and of the output.
        num_random: first argument to ``create_projection_matrix`` (presumably
            the number of random features -- confirm in ``fast_attention``).
        alpha: forwarded unchanged to ``favor_attention_regul``; its meaning is
            defined there.
    """

    def __init__(self, input_dim, output_dim, num_random, alpha):
        super(self_fast_regul_attention_layer,self).__init__()
        # 0-dim float tensors so torch.sqrt can be applied directly in exact_res.
        self.input_dim = torch.tensor(input_dim, dtype=torch.float32)
        self.output_dim = torch.tensor(output_dim, dtype=torch.float32)
        self.W_q = nn.Linear(input_dim, output_dim, bias = False)
        self.W_k = nn.Linear(input_dim, output_dim, bias=False)
        self.W_v = nn.Linear(input_dim, output_dim, bias=False)
        self.softmax = nn.Softmax(dim = -1)
        # Stored only as a registered buffer (aliased by the `projection_matrix`
        # property) so .to()/.cuda() move it together with the weights. The
        # buffer name is unchanged so saved state_dicts still load.
        self.register_buffer("projectionmatrix", create_projection_matrix(num_random, input_dim))
        self.alpha = alpha

    @property
    def projection_matrix(self):
        """Random projection used by ``favor_attention_regul`` (alias of the ``projectionmatrix`` buffer)."""
        return self.projectionmatrix

    @projection_matrix.setter
    def projection_matrix(self, value):
        # Assigning through the alias replaces the registered buffer.
        self.projectionmatrix = value

    def forward(self, x):  # x: [token_num, input_dim]
        """Return ``(output, Attention)`` from ``favor_attention_regul``; exact results are cached too."""
        Q = self.W_q(x)  # Q, K, V : [token_num, output_dim]
        K = self.W_k(x)
        V = self.W_v(x)
        # Kept for the inspection helpers until the next forward call.
        self.Q = Q
        self.K = K
        self.V = V
        self.output_exact, self.Attention_exact = self.exact_res(Q,K,V)
        self.output, self.Attention = favor_attention_regul(Q, K, V, self.projection_matrix, self.alpha)
        return self.output, self.Attention

    def exact_res(self, Q,K,V):
        """Exact softmax attention for 2-D Q, K, V; returns ``(output, attention)``."""
        # NOTE(review): scaled by sqrt(input_dim) rather than the key width.
        Attention_exact = self.softmax(
            Q @ K.transpose(0, 1) / torch.sqrt(self.input_dim))  # Attention : [token_num, token_num]
        output_exact = torch.einsum("XY,Yd->Xd", Attention_exact, V)
        return output_exact, Attention_exact

    def get_output(self):
        """Return ``(approximate output, exact output)`` from the last forward."""
        return self.output, self.output_exact

    def get_V_attention(self):
        """Return ``(V, approximate attention, exact attention)`` from the last forward."""
        return self.V, self.Attention, self.Attention_exact

    def get_mse_error(self):
        """Print, and return as ``(attention_error, output_error)``, the mean squared errors."""
        output_error = torch.nn.functional.mse_loss(self.output, self.output_exact)
        attention_error = torch.nn.functional.mse_loss(self.Attention, self.Attention_exact)
        print("attention mse error is {attention_error}, output mse error is {output_error}".format(output_error = output_error, attention_error = attention_error))
        return attention_error, output_error

    def get_mae_error(self):
        """Print, and return as ``(attention_error, output_error)``, the mean absolute errors."""
        output_error = torch.nn.functional.l1_loss(self.output,self.output_exact)
        attention_error = torch.nn.functional.l1_loss(self.Attention, self.Attention_exact)
        print("attention mae error is {attention_error}, output mae error is {output_error}".format(output_error = output_error, attention_error = attention_error))
        return attention_error, output_error

    def get_max_error(self):
        """Print, and return as ``(attention_error, output_error)``, the max absolute errors."""
        output_error = torch.max(torch.abs(self.output - self.output_exact))
        attention_error = torch.max(torch.abs(self.Attention - self.Attention_exact))
        print("attention max error is {attention_error}, output max error is {output_error}".format(output_error = output_error, attention_error = attention_error))
        return attention_error, output_error

    def get_QKV_mean_std(self):
        """Print the mean and variance of the cached Q, K and V."""
        print("Q mean is {mean} , var is {std}".format(mean=torch.mean(self.Q), std=torch.var(self.Q)))
        print("K mean is {mean} , var is {std}".format(mean=torch.mean(self.K), std=torch.var(self.K)))
        print("V mean is {mean} , var is {std}".format(mean=torch.mean(self.V), std=torch.var(self.V)))

    def get_weight(self):
        """Return the (W_q, W_k, W_v) weight Parameters."""
        return self.W_q.weight, self.W_k.weight, self.W_v.weight

    def init_weight(self, Q, K, V):
        """Replace the W_q/W_k/W_v weights with new Parameters wrapping Q, K, V.

        New Parameter objects are installed, so an optimizer created before
        this call still references the old ones.
        """
        self.W_q.weight = Parameter(Q)
        self.W_k.weight = Parameter(K)
        self.W_v.weight = Parameter(V)
        print("  Parameter migration complete!  ")
    
class one_layer_fast_regul_transformer(nn.Module):
    """One ``self_fast_regul_attention_layer`` (input_dim -> input_dim) plus an optional readout head.

    Args:
        input_dim: token width.
        output_dim: output width of the readout head, if one is built.
        hidden_dim: hidden width of the "mlp" readout head.
        num_random: forwarded to the attention layer.
        readout_type: "linear", "mlp" or "direct" (``readout_head = None``).
            Any other value leaves ``readout_head`` unset.
        alpha: forwarded to the attention layer.

    Note: ``forward`` returns the attention layer's result directly;
    ``readout_head`` is constructed but not applied here.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, num_random, readout_type = "direct", alpha = 0.5):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.readout_type = readout_type
        self.num_random = num_random
        self.sa_layer = self_fast_regul_attention_layer(input_dim, input_dim, num_random, alpha)

        head_builders = (
            ("linear", lambda: nn.Linear(input_dim, output_dim)),
            ("mlp", lambda: readout_head_mlp(input_dim, hidden_dim, output_dim)),
            ("direct", lambda: None),
        )
        for kind, build in head_builders:
            if readout_type == kind:
                self.readout_head = build()
                break

    def forward(self, tokens):
        """Return ``(updated_tokens, attention)`` from the regularised fast attention layer."""
        new_tokens, attn = self.sa_layer(tokens)
        return new_tokens, attn


#%%


class augu_mlp(nn.Module):
    """Token-wise two-layer MLP with an ELU after each linear layer.

    Maps [token_num, input_dim] -> [token_num, output_dim] through a hidden
    layer of width ``hidden_dim``. The layers live in ``self.model`` as
    Linear, ELU, Linear, ELU (in that order).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        widths = (input_dim, hidden_dim, output_dim)
        layers = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(d_in, d_out), nn.ELU()])
        self.model = nn.Sequential(*layers)

    def forward(self, tokens):
        """Return predictions of shape [token_num, output_dim]."""
        return self.model(tokens)



class self_fast_augu_attention_layer(nn.Module):
    """Fast self-attention whose keys pass through an extra MLP (``augu_k``).

    W_q, W_k and W_v map input_dim -> input_dim. K is then transformed by
    ``augu_k`` (an ``augu_mlp``: input_dim -> input_dim * ratio -> output_dim).
    Each forward pass computes the ``favor_attention`` approximation AND the
    exact softmax attention, caching both for the ``get_*`` helpers.

    Args:
        input_dim: width of the input tokens.
        output_dim: output width of ``augu_k`` / ``augu_v``. The Q @ K^T
            product needs it to equal input_dim (the only caller visible here
            passes input_dim for both).
        num_random: first argument to ``create_projection_matrix`` (presumably
            the number of random features -- confirm in ``fast_attention``).

    NOTE(review): ``augu_v`` is constructed (and so appears in parameters() /
    state_dict) but is not applied in ``forward``.
    """

    def __init__(self, input_dim, output_dim, num_random):
        super(self_fast_augu_attention_layer,self).__init__()

        self.ratio = 1  # hidden width of the augmentation MLPs = input_dim * ratio
        # 0-dim float tensors so torch.sqrt can be applied directly in exact_res.
        self.input_dim = torch.tensor(input_dim, dtype=torch.float32)
        self.output_dim = torch.tensor(output_dim, dtype=torch.float32)
        self.W_q = nn.Linear(input_dim, input_dim, bias = False)
        self.W_k = nn.Linear(input_dim, input_dim, bias=False)
        self.W_v = nn.Linear(input_dim, input_dim, bias=False)

        self.augu_k = augu_mlp(input_dim, input_dim*self.ratio, output_dim)
        self.augu_v = augu_mlp(input_dim, input_dim*self.ratio, output_dim)

        self.softmax = nn.Softmax(dim = -1)
        # Stored only as a registered buffer (aliased by the `projection_matrix`
        # property) so .to()/.cuda() move it together with the weights. The
        # buffer name is unchanged so saved state_dicts still load.
        self.register_buffer("projectionmatrix", create_projection_matrix(num_random, input_dim))

    @property
    def projection_matrix(self):
        """Random projection used by ``favor_attention`` (alias of the ``projectionmatrix`` buffer)."""
        return self.projectionmatrix

    @projection_matrix.setter
    def projection_matrix(self, value):
        # Assigning through the alias replaces the registered buffer.
        self.projectionmatrix = value

    def forward(self, x):  # x: [token_num, input_dim]
        """Return ``(output, Attention)`` from ``favor_attention``; exact results are cached too."""
        Q = self.W_q(x)  # Q, V : [token_num, input_dim]
        K = self.W_k(x)
        V = self.W_v(x)

        K = self.augu_k(K)  # K : [token_num, output_dim]

        # Kept for the inspection helpers until the next forward call.
        self.Q = Q
        self.K = K
        self.V = V

        self.output_exact, self.Attention_exact = self.exact_res(Q,K,V)
        self.output, self.Attention = favor_attention(Q, K, V, self.projection_matrix)
        return self.output, self.Attention

    def exact_res(self, Q,K,V):
        """Exact softmax attention for 2-D Q, K, V; returns ``(output, attention)``."""
        # NOTE(review): scaled by sqrt(input_dim) rather than the key width.
        Attention_exact = self.softmax(
            Q @ K.transpose(0, 1) / torch.sqrt(self.input_dim))  # Attention : [token_num, token_num]
        output_exact = torch.einsum("XY,Yd->Xd", Attention_exact, V)
        return output_exact, Attention_exact

    def get_output(self):
        """Return ``(approximate output, exact output)`` from the last forward."""
        return self.output, self.output_exact

    def get_V_attention(self):
        """Return ``(V, approximate attention, exact attention)`` from the last forward."""
        return self.V, self.Attention, self.Attention_exact

    def get_mse_error(self):
        """Print, and return as ``(attention_error, output_error)``, the mean squared errors."""
        output_error = torch.nn.functional.mse_loss(self.output, self.output_exact)
        attention_error = torch.nn.functional.mse_loss(self.Attention, self.Attention_exact)
        print("attention mse error is {attention_error}, output mse error is {output_error}".format(output_error = output_error, attention_error = attention_error))
        return attention_error, output_error

    def get_mae_error(self):
        """Print, and return as ``(attention_error, output_error)``, the mean absolute errors."""
        output_error = torch.nn.functional.l1_loss(self.output,self.output_exact)
        attention_error = torch.nn.functional.l1_loss(self.Attention, self.Attention_exact)
        print("attention mae error is {attention_error}, output mae error is {output_error}".format(output_error = output_error, attention_error = attention_error))
        return attention_error, output_error

    def get_max_error(self):
        """Print, and return as ``(attention_error, output_error)``, the max absolute errors."""
        output_error = torch.max(torch.abs(self.output - self.output_exact))
        attention_error = torch.max(torch.abs(self.Attention - self.Attention_exact))
        print("attention max error is {attention_error}, output max error is {output_error}".format(output_error = output_error, attention_error = attention_error))
        return attention_error, output_error

    def get_QKV_mean_std(self):
        """Print the mean and variance of the cached Q, K (post-augu_k) and V."""
        print("Q mean is {mean} , var is {std}".format(mean=torch.mean(self.Q), std=torch.var(self.Q)))
        print("K mean is {mean} , var is {std}".format(mean=torch.mean(self.K), std=torch.var(self.K)))
        print("V mean is {mean} , var is {std}".format(mean=torch.mean(self.V), std=torch.var(self.V)))

    def get_weight(self):
        """Return the (W_q, W_k, W_v) weight Parameters."""
        return self.W_q.weight, self.W_k.weight, self.W_v.weight

    def init_weight(self, Q, K, V):
        """Replace the W_q/W_k/W_v weights with new Parameters wrapping Q, K, V.

        New Parameter objects are installed, so an optimizer created before
        this call still references the old ones.
        """
        self.W_q.weight = Parameter(Q)
        self.W_k.weight = Parameter(K)
        self.W_v.weight = Parameter(V)
        print("  Parameter migration complete!  ")
        
        
class one_layer_fast_augu_transformer(nn.Module):
    """One ``self_fast_augu_attention_layer`` (input_dim -> input_dim) plus an optional readout head.

    Args:
        input_dim: token width.
        output_dim: output width of the readout head, if one is built.
        hidden_dim: hidden width of the "mlp" readout head.
        num_random: forwarded to the attention layer.
        readout_type: "linear", "mlp" or "direct" (``readout_head = None``).
            Any other value leaves ``readout_head`` unset.

    Note: ``forward`` returns the attention layer's result directly;
    ``readout_head`` is constructed but not applied here.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, num_random, readout_type = "direct"):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.readout_type = readout_type
        self.num_random = num_random
        self.sa_layer = self_fast_augu_attention_layer(input_dim, input_dim, num_random)

        head_builders = (
            ("linear", lambda: nn.Linear(input_dim, output_dim)),
            ("mlp", lambda: readout_head_mlp(input_dim, hidden_dim, output_dim)),
            ("direct", lambda: None),
        )
        for kind, build in head_builders:
            if readout_type == kind:
                self.readout_head = build()
                break

    def forward(self, tokens):
        """Return ``(updated_tokens, attention)`` from the augmented fast attention layer."""
        new_tokens, attn = self.sa_layer(tokens)
        return new_tokens, attn

#%%


class self_fast_nega_attention_layer(nn.Module):
    """Self-attention via ``favor_attention_nega``, plus an exact softmax reference.

    Each forward pass computes the approximate attention from
    ``favor_attention_nega`` (star-imported from ``fast_attention``) AND the
    exact softmax attention. It caches Q, K, V and both results on the module
    so the ``get_*`` helpers can compare them.

    Args:
        input_dim: width of the input tokens.
        output_dim: width of the Q/K/V projections and of the output.
        num_random: first argument to ``create_projection_matrix`` (presumably
            the number of random features -- confirm in ``fast_attention``).
        nega_k: forwarded unchanged to ``favor_attention_nega``; its meaning is
            defined there.
    """

    def __init__(self, input_dim, output_dim, num_random, nega_k = 1):
        super(self_fast_nega_attention_layer,self).__init__()
        # 0-dim float tensors so torch.sqrt can be applied directly in exact_res.
        self.input_dim = torch.tensor(input_dim, dtype=torch.float32)
        self.output_dim = torch.tensor(output_dim, dtype=torch.float32)
        self.W_q = nn.Linear(input_dim, output_dim, bias = False)
        self.W_k = nn.Linear(input_dim, output_dim, bias=False)
        self.W_v = nn.Linear(input_dim, output_dim, bias=False)
        self.softmax = nn.Softmax(dim = -1)
        # Stored only as a registered buffer (aliased by the `projection_matrix`
        # property) so .to()/.cuda() move it together with the weights. The
        # buffer name is unchanged so saved state_dicts still load.
        self.register_buffer("projectionmatrix", create_projection_matrix(num_random, input_dim))
        self.nega_k = nega_k

    @property
    def projection_matrix(self):
        """Random projection used by ``favor_attention_nega`` (alias of the ``projectionmatrix`` buffer)."""
        return self.projectionmatrix

    @projection_matrix.setter
    def projection_matrix(self, value):
        # Assigning through the alias replaces the registered buffer.
        self.projectionmatrix = value

    def forward(self, x):  # x: [token_num, input_dim]
        """Return ``(output, Attention)`` from ``favor_attention_nega``; exact results are cached too."""
        Q = self.W_q(x)  # Q, K, V : [token_num, output_dim]
        K = self.W_k(x)
        V = self.W_v(x)
        # Kept for the inspection helpers until the next forward call.
        self.Q = Q
        self.K = K
        self.V = V
        self.output_exact, self.Attention_exact = self.exact_res(Q,K,V)
        self.output, self.Attention = favor_attention_nega(Q, K, V, self.projection_matrix, self.nega_k)
        return self.output, self.Attention

    def exact_res(self, Q,K,V):
        """Exact softmax attention for 2-D Q, K, V; returns ``(output, attention)``."""
        # NOTE(review): scaled by sqrt(input_dim) rather than the key width.
        Attention_exact = self.softmax(
            Q @ K.transpose(0, 1) / torch.sqrt(self.input_dim))  # Attention : [token_num, token_num]
        output_exact = torch.einsum("XY,Yd->Xd", Attention_exact, V)
        return output_exact, Attention_exact

    def get_output(self):
        """Return ``(approximate output, exact output)`` from the last forward."""
        return self.output, self.output_exact

    def get_V_attention(self):
        """Return ``(V, approximate attention, exact attention)`` from the last forward."""
        return self.V, self.Attention, self.Attention_exact

    def get_mse_error(self):
        """Print, and return as ``(attention_error, output_error)``, the mean squared errors."""
        output_error = torch.nn.functional.mse_loss(self.output, self.output_exact)
        attention_error = torch.nn.functional.mse_loss(self.Attention, self.Attention_exact)
        print("attention mse error is {attention_error}, output mse error is {output_error}".format(output_error = output_error, attention_error = attention_error))
        return attention_error, output_error

    def get_mae_error(self):
        """Print, and return as ``(attention_error, output_error)``, the mean absolute errors."""
        output_error = torch.nn.functional.l1_loss(self.output,self.output_exact)
        attention_error = torch.nn.functional.l1_loss(self.Attention, self.Attention_exact)
        print("attention mae error is {attention_error}, output mae error is {output_error}".format(output_error = output_error, attention_error = attention_error))
        return attention_error, output_error

    def get_max_error(self):
        """Print, and return as ``(attention_error, output_error)``, the max absolute errors."""
        output_error = torch.max(torch.abs(self.output - self.output_exact))
        attention_error = torch.max(torch.abs(self.Attention - self.Attention_exact))
        print("attention max error is {attention_error}, output max error is {output_error}".format(output_error = output_error, attention_error = attention_error))
        return attention_error, output_error

    def get_QKV_mean_std(self):
        """Print the mean and variance of the cached Q, K and V."""
        print("Q mean is {mean} , var is {std}".format(mean=torch.mean(self.Q), std=torch.var(self.Q)))
        print("K mean is {mean} , var is {std}".format(mean=torch.mean(self.K), std=torch.var(self.K)))
        print("V mean is {mean} , var is {std}".format(mean=torch.mean(self.V), std=torch.var(self.V)))

    def get_weight(self):
        """Return the (W_q, W_k, W_v) weight Parameters."""
        return self.W_q.weight, self.W_k.weight, self.W_v.weight

    def init_weight(self, Q, K, V):
        """Replace the W_q/W_k/W_v weights with new Parameters wrapping Q, K, V.

        New Parameter objects are installed, so an optimizer created before
        this call still references the old ones.
        """
        self.W_q.weight = Parameter(Q)
        self.W_k.weight = Parameter(K)
        self.W_v.weight = Parameter(V)
        print("  Parameter migration complete!  ")
        
class one_layer_fast_nega_transformer(nn.Module):
    """One ``self_fast_nega_attention_layer`` (input_dim -> input_dim) plus an optional readout head.

    Args:
        input_dim: token width.
        output_dim: output width of the readout head, if one is built.
        hidden_dim: hidden width of the "mlp" readout head.
        num_random: forwarded to the attention layer.
        readout_type: "linear", "mlp" or "direct" (``readout_head = None``).
            Any other value leaves ``readout_head`` unset.
        nega_k: forwarded to the attention layer.

    Note: ``forward`` returns the attention layer's result directly;
    ``readout_head`` is constructed but not applied here.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, num_random, readout_type = "direct", nega_k = 1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.readout_type = readout_type
        self.num_random = num_random
        self.nega_k = nega_k
        self.sa_layer = self_fast_nega_attention_layer(input_dim, input_dim, num_random, nega_k)

        head_builders = (
            ("linear", lambda: nn.Linear(input_dim, output_dim)),
            ("mlp", lambda: readout_head_mlp(input_dim, hidden_dim, output_dim)),
            ("direct", lambda: None),
        )
        for kind, build in head_builders:
            if readout_type == kind:
                self.readout_head = build()
                break

    def forward(self, tokens):
        """Return ``(updated_tokens, attention)`` from the nega fast attention layer."""
        new_tokens, attn = self.sa_layer(tokens)
        return new_tokens, attn
    

#%%
class FFN(nn.Module):
    """Two-layer ReLU feed-forward network that can report its local affine map.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        hidden_dim: width of the hidden ReLU layer.
    """

    def __init__(self, input_dim, output_dim, hidden_dim):
        super(FFN,self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.F1 = nn.Linear(input_dim, hidden_dim, bias=True)
        self.F2 = nn.Linear(hidden_dim, output_dim, bias = True)

    def forward(self, x): # x: [token_num, input_dim]
        """Return F2(relu(F1(x))); the F1 pre-activation is cached for get_wb()."""
        self.res_tmp = self.F1(x)   # res_tmp: [token_num, hidden_dim]
        res_1 = nn.functional.relu(self.res_tmp)
        res_2 = self.F2(res_1)
        return res_2

    def get_wb(self, token_index=-1):
        """Return the affine map ``(W_F, b_F)`` the FFN applied to one token in the last forward.

        With the ReLU on/off pattern m of that token fixed, the FFN is affine:
        ``F(x) = W_F @ x + b_F`` where ``W_F = W2 diag(m) W1`` and
        ``b_F = W2 diag(m) b1 + b2``.

        Args:
            token_index: which token's activation pattern to use; all leading
                dims of the cached pre-activation are flattened into tokens
                (so 1-D input works too). Defaults to the last token.

        Returns:
            ``(W_F, b_F)`` of shapes [output_dim, input_dim] and [output_dim].

        Raises:
            AttributeError: if ``forward`` has not been called yet.
        """
        b1 = self.F1.bias
        b2 = self.F2.bias
        W1 = self.F1.weight
        W2 = self.F2.weight
        pre_act = self.res_tmp.reshape(-1, self.res_tmp.shape[-1])[token_index]  # [hidden_dim]
        # Match the weight dtype so matmul also works for non-float32 models.
        mask = (pre_act > 0).to(W1.dtype)
        # Row-masking W1 / b1 equals multiplying by diag(mask) without the H x H matrix.
        W_F = torch.matmul(W2, mask.unsqueeze(-1) * W1)
        b_F = torch.matmul(W2, mask * b1) + b2
        return W_F, b_F
    

class one_layer_atten_FFN(nn.Module):
    """A ``self_fast_attention_layer`` followed by an ``FFN`` (no residuals or normalisation).

    Both sub-layers keep the token width at input_dim. ``output_dim`` is stored
    but not used to size any layer here.

    Args:
        input_dim: token width.
        output_dim: stored as an attribute only.
        hidden_dim: hidden width of the FFN.
        num_random: forwarded to the attention layer.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, num_random):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_random = num_random
        self.sa_layer = self_fast_attention_layer(input_dim, input_dim, num_random)
        self.FFN = FFN(input_dim, input_dim, hidden_dim)

    def forward(self, tokens):
        """Return ``(FFN(attended_tokens), attention)``."""
        attended, attn = self.sa_layer(tokens)
        return self.FFN(attended), attn


