import torch
from torch import nn
from torch.nn import functional as F


class LinearAttention(nn.Module):
    """Linear attention with the positive feature map ``elu(x) + 1``.

    The key/value summary ``phi(K)^T V`` is formed once, so no (N x N)
    attention matrix is ever materialised. Each query row is normalised by
    its dot product with the sum of the mapped keys.
    """

    def __init__(self):
        super().__init__()
        # Substituted for an exactly-zero normaliser (e.g. all keys masked out)
        # so the reciprocal below stays finite.
        self.eps = 1e-5

    def forward(self, Q, K, V, mask):
        """Attend ``Q`` over ``K``/``V``, zeroing masked-out keys and values.

        Args:
            Q, K, V: head-split tensors; in this file they are (H, N, D) —
                the second-to-last axis is the token axis.
            mask: per-token weights indexed as ``mask[None, :, None]``, so it
                is expected to be 1-D of length N.

        Returns:
            Tensor shaped like ``V`` along its token/feature axes.
        """
        token_mask = mask[None, :, None]
        phi_q = F.elu(Q) + 1
        phi_k = (F.elu(K) + 1) * token_mask
        values = V * token_mask

        key_sum = phi_k.sum(dim=-2)
        normaliser = torch.einsum('...ld,...d->...l', phi_q, key_sum.type_as(phi_q))
        normaliser = normaliser.masked_fill(normaliser == 0, self.eps)
        inv_norm = 1. / normaliser

        kv_summary = torch.einsum('...ld,...le->...de', phi_k, values)
        return torch.einsum('...de,...ld,...l->...le', kv_summary, phi_q, inv_norm)


class AttentionLayer(nn.Module):
    """Multi-head attention: Q/K/V projections, per-head attention, output projection.

    Works on an unbatched token matrix: ``split_heads`` turns the projected
    (N, num_heads * head_dim) tensor into (num_heads, N, head_dim), and
    ``combine_heads`` inverts that.

    Args:
        dim: input and output feature size.
        head_dim: feature size of each head.
        num_heads: number of heads.
        attn_type: stored on the module only; ``LinearAttention`` is always used.
        attn_drop: stored on the module only; not applied in this class.
    """

    def __init__(self, dim, head_dim, num_heads, attn_type="vanilla", attn_drop=0.):
        super().__init__()

        self.dim = dim
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.attn_type = attn_type
        self.attn_drop = attn_drop

        inner_dim = num_heads * head_dim
        # Created in the order q, k, v, out so parameter init and state_dict order are stable.
        self.W_q = nn.Linear(dim, inner_dim)
        self.W_k = nn.Linear(dim, inner_dim)
        self.W_v = nn.Linear(dim, inner_dim)
        # NOTE(review): attn_type is ignored; the "vanilla" branch was disabled upstream.
        self.attn = LinearAttention()
        self.ff = nn.Linear(inner_dim, dim)

    def forward(self, input):
        """Attend over ``input = (X, mask)`` and project back to ``dim`` features."""
        X, mask = input
        q, k, v = (self.split_heads(proj(X)) for proj in (self.W_q, self.W_k, self.W_v))
        per_head = self.attn(q, k, v, mask)
        return self.ff(self.combine_heads(per_head))

    def combine_heads(self, X):
        """(H, N, D) -> (N, H * D)."""
        token_major = X.transpose(0, 1)
        return token_major.reshape(token_major.size(0), self.num_heads * self.head_dim)

    def split_heads(self, X):
        """(N, H * D) -> (H, N, D)."""
        n_tokens = X.size(0)
        return X.reshape(n_tokens, self.num_heads, self.head_dim).transpose(0, 1)
  

class Transformer(nn.Module):
    """One pre-norm transformer layer: attention then MLP, each wrapped in a residual.

    ``args`` supplies num_layers (asserted positive), dim, head_dim, num_heads,
    attn_type, attn_dropout, dropout, act (name of a ``torch.nn`` activation
    class) and layer_norm (LayerNorm vs. identity before each sub-layer).
    """

    def __init__(self, args):
        super().__init__()
        assert args.num_layers > 0

        self.mha = AttentionLayer(
            args.dim,
            args.head_dim,
            args.num_heads,
            attn_type=args.attn_type,
            attn_drop=args.attn_dropout,
        )
        self.dropout1 = nn.Dropout(p=args.dropout)

        self.act = getattr(nn, args.act)()

        use_ln = args.layer_norm
        self.norm1 = nn.LayerNorm(args.dim) if use_ln else nn.Identity()
        self.norm2 = nn.LayerNorm(args.dim) if use_ln else nn.Identity()

        width = args.dim
        self.mlpblock = nn.Sequential(
            nn.Linear(width, width),
            self.act,
            nn.Dropout(p=args.dropout),
            nn.Linear(width, width),
            nn.Dropout(p=args.dropout),
        )

    def forward(self, X, mask):
        """Apply attention and MLP sub-layers to tokens ``X`` with residual connections."""
        attended = self.mha((self.norm1(X), mask))
        X = X + self.dropout1(attended)
        X = X + self.mlpblock(self.norm2(X))
        return X


class Encoder(nn.Module):
    """Tokenized Graph Transformer encoder.

    Builds N node tokens followed by M edge tokens, runs them through the
    transformer stack and returns the N node-token embeddings.

    Each token is ``lap_encoder(positional ids) + order_embed(type)``, where
    type 0 marks node tokens and type 1 edge tokens. Node tokens use the
    positional pair ``[id_v, id_v]`` and edge tokens ``[id_src, id_dst]``.
    ``id`` is a node's row of ``lap_eigvec`` after mixing the hop axis with
    softmax(hop_mixer) weights. Node tokens also receive
    ``input_encoder(hop-mixed features)``.

    Because ``lap_encoder`` takes ``4 * lap_k`` inputs and the positional pair
    doubles the id width, each mixed id row must have ``2 * lap_k`` entries.
    """

    def __init__(self, args):
        super().__init__()
        assert args.num_layers > 0

        self.num_layers = args.num_layers
        self.tied_weights = args.tied_weights
        self.attn_type = args.attn_type
        self.feat_dim = args.feat_dim
        self.dim = args.dim
        self.lap_k = args.lap_k
        self.lap_dropout = args.lap_dropout
        self.lap_sign_flip = args.lap_sign_flip
        self.num_heads = args.num_heads
        self.hops = args.hops
        self.task = args.task
        self.dropout = args.dropout
        self.act = getattr(nn, args.act)()

        # One learnable logit per hop (0..hops); softmax-normalised in forward.
        self.hop_mixer = nn.Linear(self.hops + 1, 1, bias=False)
        self.input_encoder = nn.Linear(self.feat_dim, self.dim, bias=True)

        self.lap_encoder = nn.Sequential(
            nn.Linear(4*self.lap_k, self.dim, bias=False),
            self.act,
            nn.Linear(self.dim, self.dim, bias=False)
        )

        self.lap_eig_dropout = nn.Dropout(p=self.lap_dropout) if self.lap_dropout > 0 else None
        # Token type embedding: index 0 = node token, index 1 = edge token.
        self.order_embed = nn.Embedding(2, self.dim)
        self.embed_dropout = torch.nn.Dropout(p = self.dropout)

        if self.tied_weights:
            # A single shared layer, applied num_layers times in forward.
            self.layers = nn.Sequential(*[Transformer(args)])
        else:
            layers = []
            for _ in range(self.num_layers):
                layers.append(Transformer(args))
            self.layers = nn.Sequential(*layers)

        if args.layer_norm:
            self.norm = nn.LayerNorm(args.dim)
        else:
            self.norm = nn.Identity()

        if args.task == 'nc':
            # Node-classification head; only defined for task == 'nc'.
            self.cls = nn.Sequential(
                nn.Linear(self.dim, args.n_classes, args.bias),
                nn.Dropout(args.dropout)
            )

    @staticmethod
    def get_random_sign_flip(eigvec):
        """Multiply each column (last axis) of ``eigvec`` by an independent random sign.

        The random signs are drawn as constants, and the multiplication stays
        on the autograd graph. In ``forward`` the input is derived from the
        learnable hop mixer, so detaching here would silently cut its
        gradient whenever flipping is active.
        """
        num_cols = eigvec.shape[-1]
        draws = torch.rand(num_cols, device=eigvec.device)
        signs = torch.where(draws >= 0.5, torch.ones_like(draws), -torch.ones_like(draws))
        return eigvec * signs

    def logit(self, embeddings):
        """Log-softmax (over dim 1) of the 'nc' classification head applied to ``embeddings``."""
        return F.log_softmax(self.cls(embeddings), dim=1)

    def sqdist(self, x_1, x_2):
        """Squared Euclidean distance between ``x_1`` and ``x_2`` along the last axis."""
        return (x_1 - x_2).pow(2).sum(dim=-1)

    def forward(self, data):
        """Encode a graph and return its node-token embeddings.

        Args:
            data: dict with
                'features'       -- per-hop node features; assumed (N, hops+1, feat_dim), TODO confirm.
                'adj_train_norm' -- only ``shape[0]`` is used, as the node count N.
                'lap_eigvec'     -- per-hop Laplacian ids; assumed (N, hops+1, 2*lap_k), TODO confirm.
                'edge_index'     -- (2, M) source/destination node indices.

        Returns:
            Tensor of shape (N, dim): the first N (node) tokens after the final norm.
        """
        x, adj, lap_eigvec, edge_index = data['features'], data['adj_train_norm'], data['lap_eigvec'], data['edge_index']
        N = adj.shape[0]
        M = edge_index.shape[-1]
        device = x.device

        # Convex hop weights, shape (hops+1, 1); shared by features and Laplacian ids.
        hop_weights = F.softmax(self.hop_mixer.weight, dim=-1).t()
        # squeeze(-1) drops only the matmul's trailing size-1 axis; a bare squeeze()
        # would also collapse N == 1 or feat_dim == 1 and break the shapes below.
        node_feat = torch.matmul(x.transpose(-1, -2), hop_weights).squeeze(-1)
        node_feat = self.input_encoder(node_feat)
        node_id = torch.matmul(lap_eigvec.transpose(-1, -2), hop_weights).squeeze(-1)

        if self.lap_sign_flip and self.training:
            node_id = self.get_random_sign_flip(node_id)

        if self.lap_eig_dropout is not None:
            node_id = self.lap_eig_dropout(node_id)

        # Step 1: Laplacian positional encodings ([id, id] for nodes, [src, dst] for edges).
        src_id = node_id[edge_index[0]]
        dst_id = node_id[edge_index[1]]
        pos_id = torch.cat(
            (torch.cat((node_id, node_id), dim=1), torch.cat((src_id, dst_id), dim=1)),
            dim=0
        )
        tokens = self.lap_encoder(pos_id)

        # Step 2: add order (token-type) embeddings: 0 for the N node tokens, 1 for the M edge tokens.
        order_ids = torch.cat((
            torch.zeros(N, dtype=torch.long, device=device),
            torch.ones(M, dtype=torch.long, device=device),
        ))
        tokens = tokens + self.order_embed(order_ids)

        # Step 3: add node features to the node tokens only.
        tokens[:N] = tokens[:N] + node_feat
        tokens = self.embed_dropout(tokens)
        mask = torch.ones_like(tokens[:, 0])

        for idx in range(self.num_layers):
            layer = self.layers[0] if self.tied_weights else self.layers[idx]
            tokens = layer(tokens, mask)

        tokens = self.norm(tokens)

        return tokens[:N]

