""" GraphGPS """

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import GCNConv

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a batched sequence.

    Args:
        dim: Width of the input and output embeddings. Must be divisible by
            ``heads``.
        heads: Number of attention heads; each head works on ``dim // heads``
            features.
        dropout: Dropout probability applied to the attention weights.
        causal: When true, position ``i`` can only attend to positions
            ``j <= i``.
    """

    def __init__(self, dim, heads=8, dropout=0.1, causal=False):
        super().__init__()
        assert dim % heads == 0

        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        self.causal = causal

        # A single bias-free projection yields Q, K and V in one pass; the
        # result is split three ways in ``forward``.
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        # Mixes the concatenated per-head outputs back into ``dim`` features.
        self.out_proj = nn.Linear(dim, dim)

        # Regularises the attention weights (applied after the softmax).
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend every position of ``x`` to every (unmasked) position.

        Args:
            x: Tensor of shape ``(B, N, D)`` (batch, sequence length,
                embedding width).
            mask: Optional tensor broadcastable to ``(B, heads, N, N)``.
                Score entries where ``mask == 0`` are set to ``-inf`` before
                the softmax.

        Returns:
            Tensor of shape ``(B, N, D)``.
        """
        batch, seq_len, width = x.shape

        # (B, N, 3*D) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim),
        # then peel Q, K and V off the leading axis.
        fused = self.to_qkv(x).reshape(
            batch, seq_len, 3, self.heads, self.head_dim)
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product scores: (B, heads, N, N).
        scores = torch.matmul(queries, keys.transpose(-2, -1))
        scores = scores / (self.head_dim ** 0.5)

        # Caller-supplied mask (e.g. padding): zeros mark blocked entries.
        if mask is not None:
            scores.masked_fill_(mask == 0, float('-inf'))

        # Causal masking: block everything strictly above the diagonal,
        # i.e. keys that lie in the future of the query position.
        if self.causal:
            future = torch.ones(
                seq_len, seq_len, device=x.device).triu(diagonal=1).bool()
            scores.masked_fill_(future, float('-inf'))

        # Normalise over the key axis, then drop out individual weights.
        weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of the values: (B, heads, N, head_dim).
        context = torch.matmul(weights, values)

        # Put the heads back next to each other and project: (B, N, D).
        merged = context.permute(0, 2, 1, 3).reshape(batch, seq_len, width)
        return self.out_proj(merged)

class GPSLayer(nn.Module):
    """Local MPNN + full graph attention x-former layer.

    Two branches are computed from the same node features: a GCN convolution
    over ``edge_index`` (local) and dense self-attention over all nodes
    (global). Each branch gets dropout, a residual connection and optional
    batch norm; the two results are summed and passed through a residual
    two-layer feed-forward block.

    Args:
        in_channels: Node feature width; the layer maps ``in_channels`` to
            ``in_channels``. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads in the global branch.
        dropout: Dropout probability for the branch outputs and the
            feed-forward block.
        attn_dropout: Dropout probability on the attention weights.
        use_bn: Whether to apply ``BatchNorm1d`` after each residual sum.
    """

    def __init__(self, in_channels,
                 num_heads,
                 dropout=0.0,
                 attn_dropout=0.0, use_bn=True):
        super().__init__()

        self.dim_h = in_channels
        self.num_heads = num_heads
        self.attn_dropout = attn_dropout
        self.batch_norm = use_bn

        # Local message-passing model.
        self.local_model = GCNConv(in_channels, in_channels)

        # Global attention transformer-style model.
        self.self_attn = SelfAttention(
            dim=in_channels, heads=num_heads,
            dropout=self.attn_dropout, causal=False)

        # Normalization for MPNN and Self-Attention representations.
        if self.batch_norm:
            self.norm1_local = nn.BatchNorm1d(in_channels)
            self.norm1_attn = nn.BatchNorm1d(in_channels)
        self.dropout_local = nn.Dropout(dropout)
        self.dropout_attn = nn.Dropout(dropout)

        # Feed Forward block.
        self.activation = F.relu
        self.ff_linear1 = nn.Linear(in_channels, in_channels * 2)
        self.ff_linear2 = nn.Linear(in_channels * 2, in_channels)
        if self.batch_norm:
            self.norm2 = nn.BatchNorm1d(in_channels)
        self.ff_dropout1 = nn.Dropout(dropout)
        self.ff_dropout2 = nn.Dropout(dropout)

        # Device of the layer's parameters as of the last reset_parameters()
        # call; None until then.
        self.device = None

    @staticmethod
    def _reset_module(module):
        """Re-initialise ``module`` in place.

        Calls ``module.reset_parameters()`` when the module provides it;
        otherwise descends into its children. Modules with neither (e.g.
        dropout) are left untouched.
        """
        reset = getattr(module, 'reset_parameters', None)
        if callable(reset):
            reset()
            return
        for child in module.children():
            GPSLayer._reset_module(child)

    def reset_parameters(self):
        """Re-initialise all learnable parameters of the layer in place.

        Existing parameter tensors are kept and only their values are
        re-drawn, so optimizers that already hold references to them, as well
        as the current device, dtype and train/eval mode, remain valid.
        """
        attn = self.self_attn
        for child in self.children():
            if child is attn:
                continue
            self._reset_module(child)

        # The attention block is handled last so that, for a fixed RNG seed,
        # random numbers are consumed in the same order as the sub-modules
        # were reset before (everything else first, attention afterwards).
        self._reset_module(attn)

        # Keep the public ``device`` attribute in sync with where the
        # parameters currently live.
        first_param = next(self.local_model.parameters(), None)
        if first_param is not None:
            self.device = first_param.device

    def forward(self, x, edge_index):
        """Apply the layer to one graph.

        Args:
            x: Node features of shape ``(n, in_channels)``.
            edge_index: Graph connectivity, forwarded unchanged to the GCN
                convolution.

        Returns:
            Updated node features of shape ``(n, in_channels)``.
        """
        # Local branch: message passing -> dropout -> residual -> norm.
        # Only connectivity is used; no edge attributes are passed.
        h_local = self.local_model(x, edge_index)
        h_local = self.dropout_local(h_local)
        h_local = x + h_local  # Residual connection.
        if self.batch_norm:
            h_local = self.norm1_local(h_local)

        # Global branch: the node set is treated as a single sequence
        # (batch of one), so every node attends to every other node in ``x``.
        # NOTE(review): if ``x`` stacks several graphs, attention will cross
        # graph boundaries because no mask is supplied -- confirm with callers.
        h_attn = self.self_attn(x.unsqueeze(0)).squeeze(0)  # (n, in_channels)
        h_attn = self.dropout_attn(h_attn)
        h_attn = x + h_attn  # Residual connection.
        if self.batch_norm:
            h_attn = self.norm1_attn(h_attn)

        # Combine local and global outputs.
        h = h_local + h_attn

        # Feed Forward block.
        h = h + self._ff_block(h)
        if self.batch_norm:
            h = self.norm2(h)

        return h

    def _ff_block(self, x):
        """Feed Forward block: linear -> activation -> dropout -> linear -> dropout.
        """
        x = self.ff_dropout1(self.activation(self.ff_linear1(x)))
        return self.ff_dropout2(self.ff_linear2(x))