import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from functools import partial
from einops import rearrange, repeat
import warnings
warnings.filterwarnings("ignore")

from .polyconv import PolyConvFrame
    
## SingularValueDecomposition attention
class SVDAttention(nn.Module):
    """Multi-head attention built from SVD-like factors ``U``, ``Sigma``, ``V``.

    For every head the output is::

        (softmax(U) * filter(sigmoid(Sigma))) @ (softmax(V^T) @ values)

    where ``filter`` is that head's ``PolyConvFrame`` module.  Alongside the
    attention output, ``forward`` returns a penalty measuring how far the
    Gram matrices of the normalised ``U`` and ``V^T`` factors are from the
    identity.

    Args:
        attention_dropout: Probability used to build ``self.dropout``.
            NOTE(review): ``forward`` never applies ``self.dropout``; confirm
            whether that is intentional.
        n_heads: Number of heads.  Must lie between 1 and the number of
            predefined ``(alpha, beta)`` pairs (6), because each head takes
            its own pair from that table.
        d_model: Size of the last (feature) axis of every input tensor.
        poly_type: Passed to ``PolyConvFrame`` as ``conv_fn_type``.
        K: Passed to ``PolyConvFrame`` as ``depth``.
        fixI: Passed through to ``PolyConvFrame``.

    Raises:
        ValueError: If ``n_heads`` is outside the supported range.
    """

    # (alpha, beta) pair handed to each head's PolyConvFrame, in head order.
    _HEAD_COEFFS = (
        (-0.5, -0.5),
        (0.0, 0.0),
        (0.5, 0.5),
        (1.0, 1.0),
        (-0.8, -0.8),
        (2.0, 2.0),
    )

    def __init__(self, attention_dropout, n_heads, d_model, poly_type="jacobi",
                 K=5, fixI=True):
        super().__init__()
        max_heads = len(self._HEAD_COEFFS)
        # Fail fast: slicing the table would otherwise silently produce a
        # filter list whose length differs from n_heads, and forward() would
        # only fail later with an opaque indexing/stacking error.
        if not 1 <= n_heads <= max_heads:
            raise ValueError(
                f"n_heads must be between 1 and {max_heads}, got {n_heads}"
            )

        self.dropout = nn.Dropout(attention_dropout)
        self.coeff_list = list(self._HEAD_COEFFS[:n_heads])
        self.poly = nn.ModuleList([
            PolyConvFrame(conv_fn_type=poly_type, depth=K, alpha=alpha,
                          beta=beta, fixI=fixI)
            for alpha, beta in self.coeff_list
        ])
        self.d_model = d_model
        self.n_heads = n_heads

    def forward(self, U, Sigma, V, values):
        """Compute the attention output and the orthogonality penalty.

        Args:
            U: ``[batch, L, n_heads, d_model]``.
            Sigma: ``[batch, L, n_heads, d_model]``.
            V: ``[batch, S, n_heads, d_model]``.
            values: ``[batch, S, n_heads, d_model]``.

        Returns:
            A tuple ``(out, ortho_loss)``: ``out`` has shape
            ``[batch, L, n_heads, d_model]`` and ``ortho_loss`` has shape
            ``[batch * n_heads]``.
        """
        B, L, H, _ = U.size()
        S = V.size(1)
        d = self.d_model

        # Fold the head axis into the batch axis so torch.bmm can be used.
        # U, values: [B*H, len, d]   V_t: [B*H, d, S]
        U = U.permute(0, 2, 1, 3).reshape(B * H, L, d)
        V_t = V.permute(0, 2, 3, 1).reshape(B * H, d, S)
        values = values.permute(0, 2, 1, 3).reshape(B * H, S, d)

        # U is normalised over the feature axis, V_t over the S axis.
        U = F.softmax(U, dim=-1)
        V_t = F.softmax(V_t, dim=-1)

        # Squash Sigma into (0, 1), then run each head's slice
        # ([B, 1, L, d]) through that head's polynomial filter.
        Sigma = torch.sigmoid(Sigma.permute(0, 2, 1, 3))
        graph_filter = torch.stack(
            [poly(Sigma[:, i:i + 1]) for i, poly in enumerate(self.poly)],
            dim=1,
        ).reshape(B * H, L, d)

        # Multiplying V_t @ values first keeps the intermediate at [d, d]
        # per head instead of forming an [L, S] attention map.
        out = torch.bmm(U * graph_filter, torch.bmm(V_t, values))
        out = out.reshape(B, H, L, d).permute(0, 2, 1, 3)

        # Orthogonality penalty: mean |G - I| for G = U^T U and G = V_t V_t^T.
        # (Original note: this regulariser is "eta" in loss.py.)
        eye = torch.eye(d, device=U.device)
        gram_u = torch.bmm(U.transpose(1, 2), U)
        gram_v = torch.bmm(V_t, V_t.transpose(1, 2))
        ortho_loss = (
            (gram_u - eye).abs().mean(dim=(-1, -2))
            + (gram_v - eye).abs().mean(dim=(-1, -2))
        )

        return out, ortho_loss


## Attention layer
class AttentionLayer(nn.Module):
    """Project inputs into per-head factors and run them through SVDAttention.

    ``q`` is linearly projected into the ``U`` (query) and ``Sigma`` factors,
    ``kv`` into the ``V`` (key) factor and the values.  Every head works at
    the full ``d_model`` width, so each projection maps ``d_model`` to
    ``d_model * n_heads``; the concatenated head outputs are projected back
    to ``d_model``.

    Args:
        d_model: Feature size of the inputs and of the layer output.
        n_heads: Number of heads.
        attention_dropout: Forwarded to ``SVDAttention``.
        poly_type: Forwarded to ``SVDAttention``.
        K: Forwarded to ``SVDAttention``.
        alpha: Accepted but not used anywhere in this class.
        beta: Accepted but not used anywhere in this class.
        fixI: Forwarded to ``SVDAttention``.
    """

    def __init__(self, d_model, n_heads=5, attention_dropout=0.1,
                 poly_type="jacobi", K=5, alpha=2.0, beta=-1.0, fixI=True):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        width = d_model * n_heads

        # Sub-modules are created in a fixed order so that parameter order
        # and seed-dependent initialisation stay stable.
        self.query_projection = nn.Linear(d_model, width)
        self.key_projection = nn.Linear(d_model, width)
        self.value_projection = nn.Linear(d_model, width)
        self.Sigma_projection = nn.Linear(d_model, width)
        self.out_projection = nn.Linear(width, d_model)

        self.inner_attention = SVDAttention(
            attention_dropout=attention_dropout,
            n_heads=n_heads,
            d_model=d_model,
            poly_type=poly_type,
            K=K,
            fixI=fixI,
        )

    def _split_heads(self, projected, batch, length):
        """Reshape ``[batch, length, n_heads * d_model]`` into per-head form."""
        return projected.reshape(batch, length, self.n_heads, self.d_model)

    def forward(self, q, kv):
        """Attend from ``q`` over ``kv``.

        Args:
            q: ``[batch, L, d_model]``.
            kv: ``[batch, S, d_model]``.

        Returns:
            A tuple ``(out, ortho_loss)``: ``out`` has shape
            ``[batch, L, d_model]``; ``ortho_loss`` is passed through
            unchanged from the inner attention module.
        """
        batch, tgt_len, _ = q.shape
        _, src_len, _ = kv.shape

        # Factors derived from the query side.
        queries = self._split_heads(self.query_projection(q), batch, tgt_len)
        Sigma = self._split_heads(self.Sigma_projection(q), batch, tgt_len)

        # Factors derived from the key/value side.
        keys = self._split_heads(self.key_projection(kv), batch, src_len)
        values = self._split_heads(self.value_projection(kv), batch, src_len)

        attended, ortho_loss = self.inner_attention(queries, Sigma, keys, values)

        # Concatenate the heads along the feature axis before the final mix.
        merged = attended.reshape(batch, tgt_len, self.d_model * self.n_heads)
        return self.out_projection(merged), ortho_loss