import torch
import torch.nn as nn
import torch.nn.functional as F
import seaborn as sns
import numpy as np
import math
from math import sqrt
import os
import torch
 
  
class Flow_Attention(nn.Module):
    """Scaled dot-product softmax attention with optional key-mean centering.

    Before the scores are computed, ``beta * mean(keys over the key axis)``
    is subtracted from both the queries and the keys.

    Args:
        attention_dropout: Probability used to build ``self.dropout``. The
            dropout module is stored but is not applied in ``forward``.
        beta: Scale of the key-mean shift. ``None`` (the default) disables
            the centering, which yields plain scaled softmax attention.
    """

    def __init__(self, attention_dropout=0.1, beta=None):
        super().__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.beta = beta

    def kernel_method(self, x):
        """Return the element-wise sigmoid of ``x`` (not used by ``forward``)."""
        return torch.sigmoid(x)

    def forward(self, queries, keys, values, beta0, gamma):
        """Compute the attention output.

        Args:
            queries: Tensor laid out as (n, l, h, d).
            keys: Tensor laid out as (n, s, h, d).
            values: Tensor laid out as (n, s, h, d_v).
            beta0: Accepted for interface compatibility; not used here.
            gamma: Accepted for interface compatibility; not used here.

        Returns:
            Tensor laid out as (n, h, l, d_v). The head axis is NOT moved
            back behind the sequence axis.
        """
        # Move the head axis forward: (n, l, h, d) -> (n, h, l, d).
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # The default beta=None used to crash on `None * Tensor`; it now
        # simply means "no centering".
        if self.beta is not None:
            shift = self.beta * keys.mean(dim=-2, keepdim=True)
            queries = queries - shift
            keys = keys - shift

        scores = torch.einsum('nhld,nhsd->nhls', queries, keys) / math.sqrt(queries.shape[-1])
        weights = torch.softmax(scores, dim=-1)

        return torch.matmul(weights, values)
      
class LinearAttention(nn.Module):
    """Attention with an ``elu(x) + 1`` feature map and sum-normalised weights.

    Queries and keys are optionally shifted by ``beta * mean(keys)``, passed
    through ``elu(x) + 1``, and their dot products are normalised so that each
    query's weights sum to (almost) one. The full (l, s) weight matrix is
    materialised explicitly.

    Args:
        attention_dropout: Probability used to build ``self.dropout``. The
            dropout module is stored but is not applied in ``forward``.
        beta: Scale of the key-mean shift. ``None`` (the default) disables
            the centering.
    """

    def __init__(self, attention_dropout=0.1, beta=None):
        super().__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.beta = beta

    def kernel_method(self, x):
        """Return the element-wise sigmoid of ``x`` (not used by ``forward``)."""
        return torch.sigmoid(x)

    def forward(self, queries, keys, values, beta0, gamma):
        """Compute the attention output.

        Args:
            queries: Tensor laid out as (n, l, h, d).
            keys: Tensor laid out as (n, s, h, d).
            values: Tensor laid out as (n, s, h, d_v).
            beta0: Accepted for interface compatibility; not used here.
            gamma: Accepted for interface compatibility; not used here.

        Returns:
            Tensor laid out as (n, h, l, d_v). The head axis is NOT moved
            back behind the sequence axis.
        """
        # Move the head axis forward: (n, l, h, d) -> (n, h, l, d).
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Centre first, then apply the feature map. The default beta=None
        # used to crash on `None * Tensor`; it now means "no centering".
        if self.beta is not None:
            shift = self.beta * keys.mean(dim=-2, keepdim=True)
            queries = queries - shift
            keys = keys - shift

        # elu(x) + 1 is strictly positive, so every weight is positive.
        queries = F.elu(queries) + 1
        keys = F.elu(keys) + 1

        weights = torch.einsum('nhld,nhsd->nhls', queries, keys)
        # The epsilon guards the division against a vanishing row sum.
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-6)

        return torch.matmul(weights, values)
      
class MrAttention(nn.Module):
    """Multi-resolution softmax attention.

    Heads 0 and 1 attend over the full key/value sequence. Heads 2-7 attend
    over keys/values that were average-pooled along the sequence axis with
    window ``2 ** (h // 2)`` (2 for heads 2-3, 4 for heads 4-5, 8 for heads
    6-7). Heads with index 8 or above are left as zeros in the output.

    Args:
        attention_dropout: Probability used to build ``self.dropout``. The
            dropout module is stored but is not applied in ``forward``.
        beta: Scale of the key-mean shift; must be set (not ``None``) when
            ``mode`` is ``'norm_first'`` or ``'norm_last'``.
        mode: ``'norm_first'`` centres queries/keys once, before any pooling;
            ``'norm_last'`` centres per head group using the mean of the
            (possibly pooled) keys; any other value disables centering.
    """

    def __init__(self, attention_dropout=0.1, beta=None, mode=None):
        super().__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.beta = beta
        self.mode = mode

    def kernel_method(self, x):
        """Return the element-wise sigmoid of ``x`` (not used by ``forward``)."""
        return torch.sigmoid(x)

    @staticmethod
    def _pool_tokens(x, window):
        """Average-pool ``x`` of layout (n, s, d) along its sequence axis."""
        pooled = F.avg_pool1d(x.transpose(1, 2), kernel_size=window,
                              stride=window, ceil_mode=True)
        return pooled.transpose(1, 2)

    def _attend(self, q, k, v, scale):
        """Softmax attention of ``q`` over ``k``/``v``, centring in 'norm_last' mode."""
        if self.mode == 'norm_last':
            shift = self.beta * k.mean(dim=-2, keepdim=True)
            q = q - shift
            k = k - shift
        weights = (q @ k.transpose(-2, -1) / scale).softmax(dim=-1)
        return weights @ v

    def forward(self, queries, keys, values, beta0=None, gamma=None):
        """Compute the multi-resolution attention output.

        Args:
            queries: Tensor laid out as (n, l, h, d).
            keys: Tensor laid out as (n, s, h, d).
            values: Tensor laid out as (n, s, h, d).
            beta0: Accepted for interface compatibility; not used here.
            gamma: Accepted for interface compatibility; not used here.

        Returns:
            Tensor laid out as (n, h, l, d). The head axis is NOT moved back
            behind the sequence axis.
        """
        q = queries.transpose(1, 2)
        k = keys.transpose(1, 2)
        v = values.transpose(1, 2)
        if self.mode == 'norm_first':
            # Out-of-place on purpose: q and k are views of the caller's
            # tensors, so `-=` would overwrite the inputs (and fails outright
            # when an input is a leaf tensor that requires grad).
            shift = self.beta * k.mean(dim=-2, keepdim=True)
            q = q - shift
            k = k - shift

        scale = math.sqrt(queries.shape[-1])
        out = torch.zeros_like(q)
        n_heads = q.shape[1]

        # Full-resolution heads 0 and 1 are handled together.
        if n_heads > 0:
            out[:, :2] = self._attend(q[:, :2], k[:, :2], v[:, :2], scale)

        # Coarser heads: the pooling window doubles every two heads.
        for h in range(2, min(n_heads, 8)):
            window = 2 ** (h // 2)
            pooled_v = self._pool_tokens(v[:, h], window)
            pooled_k = self._pool_tokens(k[:, h], window)
            out[:, h] = self._attend(q[:, h], pooled_k, pooled_v, scale)

        return out
      
class MrLinAttention(nn.Module):
    """Multi-resolution attention with an ``elu(x) + 1`` feature map.

    Heads 0 and 1 attend over the full key/value sequence. Heads 2-7 attend
    over keys/values that were average-pooled along the sequence axis with
    window ``2 ** (h // 2)`` (2 for heads 2-3, 4 for heads 4-5, 8 for heads
    6-7). Heads with index 8 or above are left as zeros in the output.
    Weights are normalised by their row sum instead of a softmax.

    Args:
        attention_dropout: Probability used to build ``self.dropout``. The
            dropout module is stored but is not applied in ``forward``.
        beta: Scale of the key-mean shift; must be set (not ``None``) when
            ``mode`` is ``'norm_first'``.
        mode: ``'norm_first'`` centres queries/keys once, before pooling and
            before the feature map. ``'norm_last'`` is deliberately disabled
            and makes ``forward`` raise. Any other value disables centering.
    """

    def __init__(self, attention_dropout=0.1, beta=None, mode=None):
        super().__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.beta = beta
        self.mode = mode

    def kernel_method(self, x):
        """Return the element-wise sigmoid of ``x`` (not used by ``forward``)."""
        return torch.sigmoid(x)

    @staticmethod
    def _pool_tokens(x, window):
        """Average-pool ``x`` of layout (n, s, d) along its sequence axis."""
        pooled = F.avg_pool1d(x.transpose(1, 2), kernel_size=window,
                              stride=window, ceil_mode=True)
        return pooled.transpose(1, 2)

    @staticmethod
    def _attend(q, k, v):
        """Sum-normalised attention of ``q`` over ``k``/``v`` using elu(x) + 1."""
        weights = (F.elu(q) + 1) @ (F.elu(k) + 1).transpose(-2, -1)
        # The epsilon guards the division against a vanishing row sum.
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-6)
        return weights @ v

    def forward(self, queries, keys, values, beta0=None, gamma=None):
        """Compute the multi-resolution attention output.

        Args:
            queries: Tensor laid out as (n, l, h, d).
            keys: Tensor laid out as (n, s, h, d).
            values: Tensor laid out as (n, s, h, d).
            beta0: Accepted for interface compatibility; not used here.
            gamma: Accepted for interface compatibility; not used here.

        Returns:
            Tensor laid out as (n, h, l, d). The head axis is NOT moved back
            behind the sequence axis.

        Raises:
            AssertionError: If ``mode`` is ``'norm_last'``.
        """
        if self.mode == 'norm_last':
            # This mode was switched off with `assert 1==2`, which
            # `python -O` strips. An explicit raise keeps the same exception
            # type while staying effective under -O.
            raise AssertionError(
                "mode='norm_last' is disabled for MrLinAttention")

        q = queries.transpose(1, 2)
        k = keys.transpose(1, 2)
        v = values.transpose(1, 2)
        if self.mode == 'norm_first':
            # Out-of-place on purpose: q and k are views of the caller's
            # tensors, so `-=` would overwrite the inputs (and fails outright
            # when an input is a leaf tensor that requires grad).
            shift = self.beta * k.mean(dim=-2, keepdim=True)
            q = q - shift
            k = k - shift

        out = torch.zeros_like(q)
        n_heads = q.shape[1]

        # Full-resolution heads 0 and 1 are handled together.
        if n_heads > 0:
            out[:, :2] = self._attend(q[:, :2], k[:, :2], v[:, :2])

        # Coarser heads: the pooling window doubles every two heads.
        for h in range(2, min(n_heads, 8)):
            window = 2 ** (h // 2)
            pooled_v = self._pool_tokens(v[:, h], window)
            pooled_k = self._pool_tokens(k[:, h], window)
            out[:, h] = self._attend(q[:, h], pooled_k, pooled_v)

        return out
      
class AttentionLayer(nn.Module):
    """Multi-head wrapper: project inputs, run an attention kernel, project back.

    Args:
        attention: Callable invoked as
            ``attention(queries, keys, values, beta0, gamma)``.
        d_model: Feature size of the layer's inputs and output.
        n_heads: Number of attention heads.
        d_keys: Per-head query/key size; falls back to ``d_model // n_heads``
            when falsy.
        d_values: Per-head value size; falls back to ``d_model // n_heads``
            when falsy.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None):
        super().__init__()
        head_dim = d_model // n_heads
        if not d_keys:
            d_keys = head_dim
        if not d_values:
            d_values = head_dim

        self.inner_attention = attention
        qk_width = d_keys * n_heads
        v_width = d_values * n_heads
        self.query_projection = nn.Linear(d_model, qk_width)
        self.key_projection = nn.Linear(d_model, qk_width)
        self.value_projection = nn.Linear(d_model, v_width)
        self.out_projection = nn.Linear(v_width, d_model)
        self.n_heads = n_heads

        # Learnable per-head shift/scale handed to the inner attention on
        # every call. They are sized from d_model // n_heads, not from d_keys.
        self.beta0 = nn.Parameter(
            torch.zeros(n_heads, 1, head_dim, dtype=torch.float32),
            requires_grad=True,
        )
        self.gamma = nn.Parameter(
            torch.ones(n_heads, 1, head_dim, dtype=torch.float32),
            requires_grad=True,
        )

    def forward(self, queries, keys, values):
        """Apply attention to 3-D inputs.

        Args:
            queries: Tensor of shape (B, L, d_model).
            keys: Tensor of shape (B, S, d_model).
            values: Tensor whose projection is viewed as (B, S, H, -1).

        Returns:
            Tensor of shape (B, L, d_model).
        """
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        # Split the projected features into heads: (B, len, H, per-head dim).
        q_heads = self.query_projection(queries).view(batch, q_len, heads, -1)
        k_heads = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v_heads = self.value_projection(values).view(batch, kv_len, heads, -1)

        attended = self.inner_attention(q_heads, k_heads, v_heads,
                                        self.beta0, self.gamma)

        # NOTE(review): this reshape merges heads correctly only if the inner
        # attention returns (B, L, H, D). The attention classes in this file
        # return (B, H, L, D), for which the reshape would interleave heads
        # and positions - confirm the intended layout.
        merged = attended.reshape(batch, q_len, -1)

        return self.out_projection(merged)
 

