import torch
import numpy as np
import math
import torch.nn.functional as F
from create_data import *

def create_projection_matrix(m, d, scale=0):
    """Build an ``[m, d]`` random projection matrix for FAVOR+ features.

    Each row starts as an independent standard-normal vector rescaled to unit
    length (rows are NOT orthogonalized). Row ``i`` is then multiplied by a
    length ``multiplier[i]``:

    * ``scale == 0``: the norm of a fresh ``d``-dim standard-normal draw, so
      row lengths vary like those of Gaussian vectors;
    * ``scale == 1``: ``sqrt(d)`` for every row.

    Args:
        m: number of random features (rows).
        d: original feature dimension (columns).
        scale: row-length scheme, 0 or 1 (see above).

    Returns:
        float32 tensor of shape ``[m, d]``.

    Raises:
        ValueError: if ``scale`` is not 0 or 1.
    """
    # Validate before touching any RNG state; previously an unknown scale
    # surfaced only later as an UnboundLocalError on `multiplier`.
    if scale not in (0, 1):
        raise ValueError(f"scale must be 0 or 1, got {scale!r}")

    # NOTE(review): setup_seed comes from create_data; presumably it reseeds
    # the global numpy/torch RNGs so the matrix is reproducible -- confirm.
    setup_seed(202)
    directions = np.random.normal(size=(m, d))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    directions = torch.tensor(directions, dtype=torch.float32)

    if scale == 0:
        multiplier = torch.norm(torch.normal(mean=0, std=1, size=(m, d)), dim=1)
    else:
        multiplier = torch.sqrt(torch.tensor(d, dtype=torch.float32)) * torch.ones(m)

    # Row-wise scaling; equivalent to diag(multiplier) @ directions without
    # materializing the [m, m] diagonal matrix.
    return multiplier.unsqueeze(-1) * directions

def softmax_kernel_transformation(data, is_query, projection_matrix = None, episilon = 0.000001):
    """Map queries or keys to positive random features for FAVOR+ attention.

    Each row ``x`` of ``data`` is first scaled by ``d ** -0.25``. The feature
    map is then ``ratio * (exp(W x - |x|^2 / 2 - stabilizer) + episilon)``,
    with ``ratio = 1 / sqrt(m)``. Because both queries and keys are scaled by
    ``d ** -0.25``, ``phi(q) . phi(k)`` is intended to approximate
    ``exp(q . k / sqrt(d))``, the kernel of scaled dot-product softmax.
    The approximation holds up to the stabilizer constants.

    Args:
        data: ``[token_num, token_dim]`` tensor (no batch or head axes).
        is_query: if truthy, subtract each row's own maximum before ``exp``.
            A per-query shift cancels when attention rows are normalized.
            Otherwise subtract one global maximum over all tokens and features.
            A per-key shift would change the relative weights of keys.
        projection_matrix: ``[num_random, token_dim]`` projection ``W``.
        episilon: small constant added after ``exp`` for numerical stability.

    Returns:
        ``[token_num, num_random]`` tensor of non-negative features.

    Raises:
        ValueError: if ``projection_matrix`` is None.
    """
    if projection_matrix is None:
        raise ValueError("projection_matrix is required")
    token_dim = data.shape[1]                       # d
    num_random = projection_matrix.shape[0]         # m
    # d ** -0.25 on each side => q'.k' targets exp(q.k / sqrt(d)).
    data_normalizer = 1.0 / torch.sqrt(torch.sqrt(torch.tensor(token_dim, dtype=torch.float32)))
    data = data_normalizer * data
    ratio = 1.0 / torch.sqrt(torch.tensor(num_random, dtype=torch.float32))
    data_dash = torch.einsum("Nd,md->Nm", data, projection_matrix)     # [N, m]
    # |x|^2 / 2 per row, shaped [N, 1] to broadcast over the m features.
    diag_data = (torch.square(data).sum(dim=-1) / 2.0).unsqueeze(-1)
    # torch.amax returns a tensor (not a (values, indices) tuple), so no [0]:
    # indexing it would pick row 0's max for every query row.
    if is_query:
        stabilizer = torch.amax(data_dash, dim=-1, keepdim=True)        # [N, 1]
    else:
        stabilizer = torch.amax(data_dash, dim=(0, -1), keepdim=True)   # [1, 1]
    return ratio * (torch.exp(data_dash - diag_data - stabilizer) + episilon)  # [N, m]

def noncausal_numerator(q, k, v):
    """Return the un-normalized linear-attention output ``Q' (K'^T V)``.

    Args:
        q: query features, ``[token_num, random_num]``.
        k: key features, ``[token_num, random_num]``.
        v: values, ``[token_num, token_dim]``.

    Returns:
        ``[token_num, token_dim]`` tensor. Contracting ``K'^T V`` first keeps
        the cost linear in ``token_num``; no ``[N, N]`` matrix is formed.
    """
    keys_values = torch.einsum("nm,nd->md", k, v)      # [m, d]
    return torch.einsum("nm,md->nd", q, keys_values)

def noncausal_denominator(q, k):
    """Return the per-query normalizer ``Q' (K'^T 1)`` for linear attention.

    Entry ``i`` equals ``sum_j q[i] . k[j]``, i.e. the row sum of ``Q' K'^T``.

    Args:
        q: query features, ``[num_queries, random_num]``.
        k: key features, ``[num_keys, random_num]``; ``num_keys`` may differ
            from ``num_queries``.

    Returns:
        ``[num_queries]`` tensor with the same dtype/device as the inputs.
    """
    # Summing k directly replaces an einsum with torch.ones(...). That ones
    # vector was always CPU float32 and sized from q, which broke GPU/float64
    # inputs and cases where k has a different token count.
    k_sum = k.sum(dim=0)                                # [m]
    return torch.einsum("Nm,m->N", q, k_sum)


def favor_attention(q,k,v,projection_matrix):
    """Non-causal FAVOR+ attention using positive random features.

    Args:
        q, k, v: ``[token_num, token_dim]`` tensors.
        projection_matrix: ``[num_random, token_dim]`` random projection.

    Returns:
        ``(res, attention)``:
        * ``res``: ``[N, d]`` output, computed in linear time as
          ``Q'(K'^T V)`` divided by the per-row normalizer;
        * ``attention``: ``[N, N]`` explicit approximate attention matrix
          ``Q'K'^T`` with each row divided by its sum. It is materialized
          (O(N^2) memory) so callers can inspect it; ``res`` does not use it.
    """
    q_prime = softmax_kernel_transformation(q, True, projection_matrix)     # [N, m]
    k_prime = softmax_kernel_transformation(k, False, projection_matrix)    # [N, m]
    av_attention = noncausal_numerator(q_prime, k_prime, v)                 # [N, d]
    attention_normalize = noncausal_denominator(q_prime, k_prime).unsqueeze(-1)  # [N, 1]
    attention = torch.einsum("Xm,Ym->XY", q_prime, k_prime) / attention_normalize
    res = av_attention / attention_normalize
    return res, attention


def favor_attention_regul(q,k,v,projection_matrix, alpha=0.1):
    """FAVOR+ attention with the diagonal (self-attention) weights damped.

    Builds the explicit approximate attention matrix with each row divided by
    its sum. It then subtracts ``alpha`` times the matrix's own diagonal, so
    token ``i`` attends to itself with weight ``(1 - alpha) * a_ii``. As a
    result rows generally no longer sum to 1.

    Args:
        q, k, v: ``[token_num, token_dim]`` tensors.
        projection_matrix: ``[num_random, token_dim]`` random projection.
        alpha: fraction of each diagonal entry to remove.

    Returns:
        ``(res, attention)`` with ``res = attention @ v`` (``[N, d]``) and the
        adjusted ``[N, N]`` attention matrix.
    """
    # Positive random features for queries and keys: [N, m] each.
    q_feat = softmax_kernel_transformation(q, True, projection_matrix)
    k_feat = softmax_kernel_transformation(k, False, projection_matrix)

    raw_scores = torch.einsum("Xm, Ym->XY", q_feat, k_feat)
    row_sums = noncausal_denominator(q_feat, k_feat).unsqueeze(-1)    # [N, 1]
    scores = raw_scores / row_sums

    # Shrink only the diagonal entries, by alpha * a_ii each.
    self_weights = alpha * scores.diagonal()
    scores = scores - torch.diag(self_weights)

    output = torch.einsum("XY, Yd->Xd", scores, v)
    return output, scores


def favor_attention_nega(q, k, v, projection_matrix, nega_k=1, alpha=0.5):
    """FAVOR+ attention that shifts weight away from each row's weakest keys.

    Starts from the explicit approximate attention matrix, with each row
    divided by its sum. With ``beta = N * alpha / (N - nega_k)``, each row is
    adjusted in two steps:
    * its ``nega_k`` smallest entries are each reduced by ``beta / nega_k``;
    * then every entry is increased by ``beta / N``.
    The two changes cancel per row, so row sums are preserved.

    Args:
        q, k, v: ``[token_num, token_dim]`` tensors.
        projection_matrix: ``[num_random, token_dim]`` random projection.
        nega_k: number of lowest-weight keys penalized per row;
            must satisfy ``1 <= nega_k < N``.
        alpha: strength of the shift (previously hard-coded to 0.5).

    Returns:
        ``(res, attention)`` with ``res = attention @ v`` (``[N, d]``) and the
        adjusted ``[N, N]`` attention matrix.

    Raises:
        ValueError: if ``nega_k`` is outside ``[1, N - 1]``; otherwise it
            would divide by zero or exceed the number of keys.
    """
    q_prime = softmax_kernel_transformation(q, True, projection_matrix)     # [N, m]
    k_prime = softmax_kernel_transformation(k, False, projection_matrix)    # [N, m]
    attention = torch.einsum("Xm,Ym->XY", q_prime, k_prime)
    attention_normalize = noncausal_denominator(q_prime, k_prime).unsqueeze(-1)  # [N, 1]
    attention = attention / attention_normalize

    N = attention.shape[0]
    if not 1 <= nega_k < N:
        raise ValueError(f"nega_k must be in [1, {N - 1}], got {nega_k}")
    beta = (N * alpha) / (N - nega_k)

    values, index = torch.topk(attention, nega_k, largest=False)    # [N, nega_k]
    # One vectorized scatter instead of a per-row Python loop; topk indices
    # within a row are distinct, so this matches `attention[i, idx] -= c`.
    penalty = torch.full_like(values, -beta / nega_k)
    attention.scatter_add_(1, index, penalty)
    attention += beta / N

    res = torch.einsum("XY,Yd->Xd", attention, v)
    return res, attention


def test_softmax_noncausal_attention_block_output(N, d, m, return_attention_error=False):
    """Compare FAVOR+ attention with exact softmax attention on random data.

    Draws Gaussian query/key/value tensors of shape ``[N, d]`` with fixed
    means and standard deviations. It builds an ``m``-feature projection
    (``scale=0``) and compares ``favor_attention`` with the exact
    ``softmax(q k^T / sqrt(d)) v``.

    NOTE(review): despite the ``test_`` prefix this takes required arguments,
    so pytest cannot collect and run it as-is.

    Args:
        N: number of tokens.
        d: token dimension.
        m: number of random features.
        return_attention_error: if True, also return the attention-matrix MSE.

    Returns:
        The mean squared error between exact and FAVOR+ outputs. If
        ``return_attention_error`` is True, returns the tuple
        ``(output_mse, attention_mse)`` instead.
    """
    query = torch.normal(mean = -0.001247, std = 0.4965, size=(N,d))
    key = torch.normal(mean = 0.000384, std = 1.4052, size=(N,d))
    value = torch.normal(mean = 0.0157, std = 4.395, size=(N,d))

    projection_matrix = create_projection_matrix(m, d, scale=0)
    favor_output, favor_attn = favor_attention(query, key, value, projection_matrix)

    # Exact scaled dot-product softmax attention.
    scores = torch.einsum("Xd,Yd->XY", query / math.sqrt(float(d)), key)
    exact_attn = F.softmax(scores, dim = 1)
    exact_output = torch.einsum("XY,Yd->Xd", exact_attn, value)

    error_mse = torch.mean(torch.square(exact_output - favor_output))
    if not return_attention_error:
        return error_mse
    error_a_mse = torch.mean(torch.square(favor_attn - exact_attn))
    return error_mse, error_a_mse

