"""
This file is from https://github.com/mlpen/Nystromformer
"""

from http.cookiejar import LoadError
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import json
from torch.utils.checkpoint import checkpoint
import pdb
import numpy as np

class SoftmaxAttention(nn.Module):
    """Scaled dot-product softmax attention with optional key-mean centering.

    Args:
        config: object exposing ``attention_dropout`` (dropout probability
            applied to the attention weights), ``head_dim`` (the scores are
            divided by ``sqrt(head_dim)``) and ``beta`` (``None`` disables
            centering; otherwise ``beta * mean(K)`` is subtracted from Q and K).
        idx: layer index; stored as ``self.layer`` and not otherwise used here.
    """

    def __init__(self, config, idx):
        super().__init__()
        self.drop_attn = torch.nn.Dropout(p=config.attention_dropout)
        self.layer = idx
        self.beta = config.beta
        self.head_dim = config.head_dim

    def forward(self, Q, K, V, mask):
        """Return ``softmax(Q K^T / sqrt(head_dim) + mask_penalty) V``.

        The last two dims of Q/K/V are (sequence, feature), as implied by the
        ``Q @ K^T`` product. ``mask`` is broadcast as ``mask[:, None, None, :]``,
        so it is indexed (batch, key position); a value of 1 keeps a key and
        0 suppresses it. Presumably Q/K/V are (batch, heads, seq, head_dim) —
        TODO confirm against callers.
        """
        if self.beta is not None:
            # Shift queries and keys by a beta-scaled mean of the keys over
            # the sequence axis (dim -2).
            shift = self.beta * K.mean(dim=-2, keepdim=True)
            Q, K = Q - shift, K - shift

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # A large negative offset on suppressed keys drives their softmax weight to ~0.
        scores = scores - 1e6 * (1 - mask[:, None, None, :])

        weights = self.drop_attn(nn.functional.softmax(scores, dim=-1))
        return torch.matmul(weights, V)

class SHAttention(nn.Module):
    """Two-head attention whose second head attends to 2x-pooled keys/values.

    Head 0 is ordinary scaled dot-product softmax attention. Head 1 first
    average-pools its keys, its values and the key mask along the sequence
    axis (kernel 2, stride 2, ceil mode), so it attends over
    ceil(seq_len / 2) pooled positions. Only heads 0 and 1 are computed; any
    further heads are left as zeros (NOTE(review): confirm this is intended
    when num_head > 2).

    Args:
        config: object exposing ``attention_dropout``, ``head_dim`` (scores are
            divided by ``sqrt(head_dim)``) and ``beta`` (``None`` disables
            key-mean centering).
        idx: layer index; stored as ``self.layer`` and not otherwise used here.
    """

    def __init__(self, config, idx):
        super().__init__()
        self.drop_attn = torch.nn.Dropout(p=config.attention_dropout)
        self.head_dim = config.head_dim
        self.beta = config.beta
        self.layer = idx

    def _attend(self, q, k, v, key_mask):
        """Masked, scaled softmax attention for a single head.

        ``q``/``k``/``v`` are 3-D with (sequence, feature) as their last two
        dims. ``key_mask`` is indexed (batch, key position) and is broadcast
        over query positions.
        """
        dot = torch.matmul(q, torch.transpose(k, -2, -1))
        dot = dot / math.sqrt(self.head_dim)
        # Keys with mask < 1 receive a large negative offset. A pooled window
        # that is only half valid (mask 0.5) still gets a 5e5 penalty, so it is
        # effectively masked out as well.
        dot = dot - 1e6 * (1 - key_mask[:, None, :])
        attn = nn.functional.softmax(dot, dim=-1)
        attn = self.drop_attn(attn)
        return torch.matmul(attn, v)

    @staticmethod
    def _pool_seq(x):
        """Average-pool a 3-D ``x`` along dim 1 (kernel 2, stride 2, ceil mode)."""
        pooled = F.avg_pool1d(x.transpose(1, 2), kernel_size=2, stride=2, ceil_mode=True)
        return pooled.transpose(1, 2)

    def forward(self, Q, K, V, mask):
        """Compute head 0 at full resolution and head 1 over pooled keys/values.

        Q/K/V are indexed [batch, head, ...]; presumably
        (batch, heads, seq, head_dim) — TODO confirm against callers. ``mask``
        is (batch, seq) with 1 = keep, 0 = suppress.

        Raises:
            ValueError: if fewer than two heads are supplied.
        """
        if Q.size(1) < 2:
            raise ValueError(f"SHAttention needs at least 2 heads, got {Q.size(1)}")

        if self.beta is not None:
            mean_k = K.mean(dim=-2, keepdim=True)
            Q = Q - self.beta * mean_k
            K = K - self.beta * mean_k

        out = torch.zeros_like(Q)

        # Head 0: full-resolution attention.
        out[:, 0] = self._attend(Q[:, 0], K[:, 0], V[:, 0], mask)

        # Head 1: keys, values and the key mask are pooled over the sequence
        # axis; the queries stay at full resolution.
        k_pooled = self._pool_seq(K[:, 1])
        v_pooled = self._pool_seq(V[:, 1])
        mask_pooled = self._pool_seq(mask.unsqueeze(-1)).squeeze(-1)
        out[:, 1] = self._attend(Q[:, 1], k_pooled, v_pooled, mask_pooled)

        return out

class NoneAttention(nn.Module):
    """Pass-through "attention": ignores Q, K and the mask and returns V as-is."""

    def __init__(self, config):
        # `config` is accepted only so this class can be constructed like the
        # other attention modules; nothing is read from it.
        super().__init__()
        del config

    def forward(self, Q, K, V, mask):
        # No mixing across positions: the value tensor itself is handed back.
        passthrough = V
        return passthrough


class Attention(nn.Module):
    """Multi-head attention layer: Q/K/V projections, a pluggable kernel, output projection.

    ``config.attn_type`` selects the attention kernel. It is either an exact
    name ("softmax", "sh", "none") or a prefix ("linformer", "reformer",
    "nystrom", "performer", "linear", "mra_head"). The prefixed kernels are
    imported lazily from sibling modules.

    Args:
        config: object exposing ``attention_grad_checkpointing``,
            ``transformer_dim``, ``head_dim``, ``num_head`` and ``attn_type``,
            plus whatever the selected kernel reads.
        idx: layer index, forwarded to the kernels that accept it.

    Raises:
        ValueError: if ``config.attn_type`` matches no known kernel.
    """

    def __init__(self, config, idx):
        super().__init__()

        self.grad_checkpointing = config.attention_grad_checkpointing

        self.dim = config.transformer_dim
        self.head_dim = config.head_dim
        self.num_head = config.num_head

        self.attn_type = config.attn_type

        self.W_q = nn.Linear(self.dim, self.num_head * self.head_dim)
        self.W_k = nn.Linear(self.dim, self.num_head * self.head_dim)
        self.W_v = nn.Linear(self.dim, self.num_head * self.head_dim)

        self.dconv_fc = None

        # A single if/elif chain, so exactly one kernel is chosen.
        if self.attn_type == "softmax":
            self.attn = SoftmaxAttention(config, idx)
        elif self.attn_type == "sh":
            self.attn = SHAttention(config, idx)
        elif self.attn_type == "none":
            self.attn = NoneAttention(config)
        elif self.attn_type.startswith("linformer"):
            from attention_linformer import LinformerAttention
            self.attn = LinformerAttention(config)
        elif self.attn_type.startswith("reformer"):
            from attention_reformer import LSHAttention
            # The reformer kernel receives the projection layers directly.
            self.attn = LSHAttention(config, self.W_q, self.W_k, self.W_v)
        elif self.attn_type.startswith("nystrom"):
            from attention_nystrom import NystromAttention
            self.attn = NystromAttention(config)
        elif self.attn_type.startswith("performer"):
            from attention_performer import PerformerAttention
            self.attn = PerformerAttention(config)
        elif self.attn_type.startswith("linear"):
            from attention_linear import LinearAttention
            self.attn = LinearAttention(config)
        elif self.attn_type.startswith("mra_head"):
            from attention_mra_head import mra_headAttention
            self.attn = mra_headAttention(config)
        else:
            # Fail at construction rather than with an AttributeError on
            # `self.attn` at the first forward pass.
            raise ValueError(f"Unsupported attn_type: {self.attn_type!r}")

        self.ff = nn.Linear(self.num_head * self.head_dim, self.dim)

    def forward(self, X, mask):
        """Apply attention to ``X`` and project back to ``transformer_dim``.

        ``X`` is 3-D; presumably (batch, seq, transformer_dim), and ``mask``
        is (batch, seq) — TODO confirm against callers. The attention kernel
        always runs in float32 with autocast disabled.
        """
        # These kernels consume X directly instead of pre-split Q/K/V. Note
        # that no constructor branch above builds a "longformer" kernel.
        if self.attn_type.startswith("longformer") or self.attn_type.startswith("reformer"):
            with torch.cuda.amp.autocast(enabled = False):
                attn_out = self.attn(X.float(), mask.float())
        else:
            Q = self.split_heads(self.W_q(X))
            K = self.split_heads(self.W_k(X))
            V = self.split_heads(self.W_v(X))

            with torch.cuda.amp.autocast(enabled = False):
                if self.grad_checkpointing:
                    attn_out = checkpoint(self.attn, Q.float(), K.float(), V.float(), mask.float())
                else:
                    attn_out = self.attn(Q.float(), K.float(), V.float(), mask.float())
            attn_out = self.combine_heads(attn_out)

        out = self.ff(attn_out)

        return out

    def combine_heads(self, X):
        """Merge (d0, num_head, d2, head_dim) into (d0, d2, num_head * head_dim)."""
        X = X.transpose(1, 2)
        X = X.reshape(X.size(0), X.size(1), self.num_head * self.head_dim)
        return X

    def split_heads(self, X):
        """Split (d0, d1, num_head * head_dim) into (d0, num_head, d1, head_dim)."""
        X = X.reshape(X.size(0), X.size(1), self.num_head, self.head_dim)
        X = X.transpose(1, 2)
        return X
