import torch
import torch.nn as nn
import math
import torch.nn.init as init
import torch.nn.functional as F
import numpy as np


def power_method(A, num_iters=10):
    """Estimate the magnitude of the dominant eigenvalue of ``A`` by power iteration.

    Starting from the all-ones vector, ``A`` is applied repeatedly; each new
    iterate is divided by the max-abs entry of the previous one, and the
    max-abs entry of the latest iterate is the running estimate.

    Args:
        A: Square 2-D tensor. The iterate has length ``A.shape[0]`` and is
            created on ``A``'s device with ``A``'s dtype.
        num_iters: Number of iterations to run. Defaults to 10, the count that
            was previously hard-coded.

    Returns:
        A 0-dim tensor holding the estimate (non-negative, since it is a
        max of absolute values). If ``num_iters`` is 0 the initial estimate,
        a shape-``(1,)`` tensor of ones, is returned unchanged.
    """
    n = A.shape[0]
    x = torch.ones(n, device=A.device, dtype=A.dtype)
    l = torch.ones(1, device=A.device, dtype=A.dtype)

    for _ in range(num_iters):
        # If the iterate collapsed to exactly zero (e.g. A is the zero matrix
        # or nilpotent), dividing by l would give 0/0 = NaN. Divide by 1 in
        # that case so the estimate stays at 0. torch.where keeps this on
        # device, avoiding a host sync per iteration.
        scale = torch.where(l == 0, torch.ones_like(l), l)
        x = A @ x / scale
        l = x.abs().max()

    return l

def calc_matrix_entropy(A):
    """Return the mean normalised base-2 entropy over the last axis of ``A``.

    For each slice along the last dimension this computes
    ``-sum(p * log2(p + 1e-10)) / log2(K)`` with ``K = A.shape[-1]``, then
    averages the result over all remaining dimensions. The ``1e-10`` offset
    keeps ``log2`` finite for zero entries.
    """
    num_categories = A.shape[-1]
    weighted_log = A * torch.log2(A + 1e-10)
    normalised = torch.sum(weighted_log, dim=-1) / np.log2(num_categories)
    return -normalised.mean()

class KernelAttention(nn.Module):
    """Dense attention with a selectable similarity kernel and diagnostics.

    ``kernel_type`` selects how the row-normalised attention weights are
    built from ``q`` and ``k`` (``D`` is the size of the last dimension):

    * ``"exp"``      -- ``softmax(q k^T / sqrt(D))``
    * ``"rbf"``      -- ``softmax(-||q - k||^2 / (2 sqrt(D)))``
    * ``"quad"``     -- ``(q k^T + 1)^2``, divided by its row sum
    * ``"tanh"``     -- ``tanh(q k^T) + 1 + eps``, divided by its row sum
    * ``"quad_lin"`` -- ``(q^2) (k^2)^T``, divided by its row sum
    * ``"relu"``     -- ``relu(q k^T) + eps``, divided by its row sum
    * ``"relu_lin"`` -- ``(relu(q) + eps) (relu(k) + eps)^T``, divided by its row sum

    Every forward pass overwrites the scalar buffers ``sigma_q`` and
    ``sigma_k`` (std of ``q`` / ``k``) and, when the matching ``calc_*`` flag
    is set, ``sg``, ``entropy`` and ``trace``.
    """

    def __init__(self, kernel_type="exp"):
        super().__init__()
        self.kernel_type = kernel_type
        # 0-dim diagnostic buffers; replaced (not updated in place) by forward().
        self.register_buffer("sigma_q", torch.zeros(()))
        self.register_buffer("sigma_k", torch.zeros(()))
        self.register_buffer("sg", torch.zeros(()))
        self.register_buffer("entropy", torch.zeros(()))
        self.register_buffer("trace", torch.zeros(()))
        # Small offset that keeps the row sums of the non-softmax kernels positive.
        self.eps = 1e-15
        self.calc_spectral_gap = True
        self.calc_entropy = True
        self.calc_trace = True

    def _attention_weights(self, q, k):
        """Return the row-normalised attention weights for ``self.kernel_type``."""
        scale = math.sqrt(q.shape[-1])
        k_t = k.transpose(-2, -1)
        kind = self.kernel_type

        if kind == "exp":
            return F.softmax(torch.matmul(q, k_t) / scale, dim=-1)
        if kind == "rbf":
            # 0.5*|q|^2 + 0.5*|k|^2 - q.k == 0.5*|q - k|^2. The key norm must
            # be taken from k (it was previously computed from q by mistake).
            qnorm2 = q.norm(dim=-1, keepdim=True) ** 2
            knorm2 = k.norm(dim=-1, keepdim=True) ** 2
            half_sq_dist = 0.5 * qnorm2 + 0.5 * knorm2.transpose(-2, -1) - torch.matmul(q, k_t)
            return F.softmax(-half_sq_dist / scale, dim=-1)

        if kind == "quad":
            attn = (q @ k_t + 1) ** 2
        elif kind == "tanh":
            attn = torch.tanh(torch.matmul(q, k_t)) + 1 + self.eps
        elif kind == "quad_lin":
            attn = torch.matmul(q ** 2, k_t ** 2)
        elif kind == "relu":
            attn = torch.relu(torch.matmul(q, k_t)) + self.eps
        elif kind == "relu_lin":
            attn = torch.matmul(torch.relu(q) + self.eps, torch.relu(k_t) + self.eps)
        else:
            raise ValueError("Unsupported kernel_type: {!r}".format(kind))
        return attn / attn.sum(dim=-1, keepdim=True)

    def forward(self, q, k, v, mask=None):
        """Apply kernel attention and refresh the diagnostic buffers.

        Args:
            q: 4-D tensor, unpacked as ``(B, H, N, D)``.
            k: Keys; last dimension must match ``q``.
            v: Values; multiplied by the attention weights.
            mask: Accepted for interface compatibility; not used.

        Returns:
            ``attn @ v``. float16 inputs are computed in float32 and the
            result is cast back to float16.

        Raises:
            ValueError: If ``self.kernel_type`` is not a supported kernel.
        """
        B, H, N, D = q.shape
        is_half = q.dtype == torch.float16
        if is_half:
            q, k, v = q.float(), k.float(), v.float()

        with torch.no_grad():
            self.sigma_q = q.std()
            self.sigma_k = k.std()

        attn = self._attention_weights(q, k)

        if self.calc_spectral_gap:
            with torch.no_grad():
                # Power-iteration estimate on (attn - 1/N), averaged over
                # head 0 of up to the first 4 batch samples.
                # NOTE(review): power_method starts from the all-ones vector,
                # which (attn - 1/N) maps to ~0 when rows sum to 1 -- confirm
                # the estimate is reliable for this use.
                P = attn - 1 / N
                num_samples = min(4, P.shape[0])
                if num_samples > 0:
                    lambda2 = sum(power_method(P[i][0]) for i in range(num_samples))
                    self.sg = torch.abs(1 - lambda2 / num_samples)

        if self.calc_entropy:
            with torch.no_grad():
                self.entropy = calc_matrix_entropy(attn)

        if self.calc_trace:
            with torch.no_grad():
                # Mean of the diagonal attention weights.
                self.trace = attn.diagonal(offset=0, dim1=-1, dim2=-2).mean()

        out = torch.matmul(attn, v)

        if is_half:
            out = out.half()

        return out

    def __repr__(self):
        return super().__repr__() + ": kernel_type={}".format(self.kernel_type)

    
def rms_norm(X, eps=1e-5):
    """Divide ``X`` by the root-mean-square of its last dimension.

    The RMS is computed as ``||X||_2 / sqrt(X.shape[-1])`` along the last
    axis; ``eps`` is added to it before dividing so all-zero rows stay finite.
    """
    feature_dim = X.shape[-1]
    rms = X.norm(dim=-1, keepdim=True) / feature_dim ** 0.5
    return X / (rms + eps)

    
class LLNAttention(nn.Module):
    """Linear-complexity attention using exponential feature maps.

    ``q`` and ``k`` are rescaled by data-dependent factors, mapped through
    ``exp`` element-wise, and combined as ``Q (K^T v) / (Q . sum_n K)`` so the
    ``N x N`` attention matrix is never materialised.

    Args:
        c: Stored on the module and shown in ``repr``; not used by ``forward``.
        normalize: If true, apply ``layer_norm`` over the last dimension of
            the output.
        eps: Added to the normaliser ``S`` to avoid division by zero.
    """

    def __init__(self, c=1., normalize=False, eps=1e-8):
        super().__init__()
        # Previously ``c`` was dropped, so str()/repr() raised AttributeError.
        self.c = c
        self.normalize = normalize
        self.eps = eps

    def forward(self, q, k, v, mask=None):
        """Compute the attention output.

        Args:
            q: 4-D tensor, unpacked as ``(B, H, N, D)``.
            k: Keys; feature dimension must match ``q``.
            v: Values.
            mask: Accepted for interface compatibility; not used.

        Returns:
            Tensor shaped like ``q @ (k^T v)``. float16 inputs are computed in
            float32 and the result is cast back to float16.
        """
        B, H, N, D = q.shape

        is_half = q.dtype == torch.float16
        if is_half:
            q = q.float()
            k = k.float()
            v = v.float()

        with torch.no_grad():
            sig_q = q.std()
            sig_k = k.std()
            # NOTE(review): a and b are hard-coded constants whose derivation
            # is not visible in this file -- confirm their source.
            a = 0.1485517825109655
            b = -0.3548703913401894

            # Common target scale; alpha/beta bring q and k to that std.
            sig_tild = ((sig_q**2 * sig_k**2 - b) / (2*a))**0.5
            alpha = sig_tild / sig_q
            beta = sig_tild / sig_k

        q = alpha * q
        k = beta * k

        # Shift before exp so the largest exponent is 0 (no overflow). q is
        # shifted per row, k per (batch, head) slice; apart from eps these
        # shifts cancel between G and S below.
        q = q - q.amax(dim=-1, keepdim=True)
        k = k - k.amax(dim=(-2, -1), keepdim=True)

        Q = torch.exp(q)
        K = torch.exp(k)
        # Multiply K^T v first: cost is linear in sequence length.
        G = torch.matmul(Q, torch.matmul(K.transpose(-2, -1), v))
        S = torch.einsum('...nd,...d->...n', Q, K.sum(dim=-2)) + self.eps

        out = G / S.unsqueeze(-1)
        if self.normalize:
            out = torch.nn.functional.layer_norm(out, [D])

        if is_half:
            out = out.half()

        return out

    def __str__(self):
        return self.str()

    def __repr__(self):
        return self.str()

    def str(self):
        """Return the one-line description used by ``str()`` and ``repr()``."""
        return "LLNAttention(normalize={}, c={}, eps={})".format(self.normalize, self.c, self.eps)


def scaled_dot_product_attn(q, k, v, mask):
    """Plain softmax attention: ``softmax((q / sqrt(D)) k^T) v``.

    ``q`` must be 4-D; ``D`` is its last dimension. ``mask`` is accepted for
    signature compatibility but is not applied, and no dropout is performed.
    """
    _, _, _, head_dim = q.shape

    scaled_q = q / math.sqrt(head_dim)
    scores = scaled_q @ k.transpose(-2, -1)
    weights = scores.softmax(dim=-1)

    return weights @ v

class LLNPlusAttention(nn.Module):
    """Average of ``LLNAttention`` and block-diagonal softmax attention.

    The softmax branch splits the sequence into consecutive blocks of ``D``
    tokens (``D`` being the head dimension) and attends only within each
    block; any leftover tokens at the end form one smaller block.

    Args:
        c: Stored on the module and shown in ``repr``; not used by ``forward``.
        normalize: If true, apply ``layer_norm`` over the last dimension of
            the averaged output.
        eps: Forwarded to the inner ``LLNAttention``.
    """

    def __init__(self, c=1., normalize=False, eps=1e-8):
        super().__init__()
        # Previously ``c`` was dropped, so str()/repr() raised AttributeError.
        self.c = c
        self.normalize = normalize
        self.exp_attn = LLNAttention(eps=eps)

    def forward(self, q, k, v, mask=None):
        """Return ``0.5 * (linear attention + block-diagonal attention)``.

        Args:
            q: 4-D tensor, unpacked as ``(B, H, N, D)``.
            k: Keys, sliced and reshaped the same way as ``q``.
            v: Values, sliced and reshaped the same way as ``q``.
            mask: Passed through to both attention branches.
        """
        lin_attn_out = self.exp_attn(q, k, v, mask)

        B, H, N, D = q.shape
        # Number of tokens covered by whole D-sized blocks.
        ND = D * (N // D)
        # Fold each block into its own "head" so attention stays inside it.
        h = H * (ND // D)
        Q = q[:, :, :ND, :].reshape(B, h, D, D)
        K = k[:, :, :ND, :].reshape(B, h, D, D)
        V = v[:, :, :ND, :].reshape(B, h, D, D)

        attn_out = scaled_dot_product_attn(Q, K, V, mask).reshape(B, H, ND, D)
        # Trailing N - ND tokens attend among themselves.
        attn_out_rem = scaled_dot_product_attn(q[:, :, ND:, :], k[:, :, ND:, :], v[:, :, ND:, :], mask)
        attn_out = torch.cat([attn_out, attn_out_rem], dim=-2)

        out = 0.5 * (lin_attn_out + attn_out)
        if self.normalize:
            out = torch.nn.functional.layer_norm(out, [D])
        return out

    def __str__(self):
        return self.str()

    def __repr__(self):
        return self.str()

    def str(self):
        """Return the one-line description used by ``str()`` and ``repr()``."""
        return "LLNPlusAttention(normalize={}, c={}, lln={})".format(self.normalize, self.c, self.exp_attn)

    
class BlockDAttention(nn.Module):
    """Block-diagonal softmax attention with block size equal to the head dim.

    The sequence is cut into consecutive blocks of ``D`` tokens, each block is
    attended independently, and any leftover tokens at the end are attended
    as one final smaller block.
    """

    def __init__(self):
        super().__init__()

    def forward(self, q, k, v, mask=None):
        """Return block-wise attention output with the same layout as ``q``.

        ``q`` is unpacked as ``(B, H, N, D)``; ``mask`` is passed through to
        ``scaled_dot_product_attn``.
        """
        B, H, N, D = q.shape
        covered = D * int(N / D)
        folded_heads = H * int(covered / D)

        def fold(t):
            # Whole blocks, each moved into its own pseudo-head.
            return t[:, :, :covered, :].reshape(B, folded_heads, D, D)

        def tail(t):
            # Tokens left over after the last whole block.
            return t[:, :, covered:, :]

        block_out = scaled_dot_product_attn(fold(q), fold(k), fold(v), mask)
        block_out = block_out.reshape(B, H, covered, D)
        tail_out = scaled_dot_product_attn(tail(q), tail(k), tail(v), mask)

        return torch.cat([block_out, tail_out], dim=-2)