# Decoder based on PST.

import torch
import torch.nn as nn
from typing import Literal
from graphgps.encoder.utils import *


class Decoder(nn.Module):
    """Decoder that refines node features with pairwise scores.

    Every layer builds a pairwise score tensor from the current node
    features (``MLPAttention``) plus an embedding of ``E_emb``, zeroes all
    entries that involve a padded node, and uses the scores to aggregate
    linearly transformed node features.  The first layer replaces the node
    features with the aggregate; later layers add it as a residual.  A final
    head maps the result to ``out_dim`` per graph or per node.

    Args:
        X_in_dim: Size of the last dimension of ``X_emb``.
        E_in_dim: Size of the last dimension of ``E_emb``.
        U_in_dim: Size of the last dimension of ``U_emb``.
        hidden_dim: Width of the hidden node / pair features.
        out_dim: Size of the last dimension of the output.
        num_layers: Number of aggregation layers run in ``forward``.
            Expected to be >= 1; the input layer is always constructed.
        dropout: Dropout probability used inside the final head.
        task: ``'graph'`` to sum node features into one vector per graph,
            ``'node'`` to return one vector per node.
    """

    def __init__(self, X_in_dim: int,
                 E_in_dim: int,
                 U_in_dim: int,
                 hidden_dim: int,
                 out_dim: int,
                 num_layers: int,
                 dropout: float = 0.0,
                 task: Literal['graph', 'node'] = 'graph'):
        super().__init__()
        self.X_in_dim = X_in_dim
        self.E_in_dim = E_in_dim
        self.U_in_dim = U_in_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.out_dim = out_dim
        self.task = task

        # The first layer consumes the concatenated [X, U] features; all
        # following layers work on hidden_dim features.
        self.attentions = nn.ModuleList(
            [MLPAttention(X_in_dim + U_in_dim,
                          hidden_dim, use_layer_norm=False)] +
            [MLPAttention(hidden_dim, hidden_dim) for _
             in range(self.num_layers-1)]
        )
        self.vlins = nn.ModuleList(
            [NormedLinear(X_in_dim + U_in_dim, hidden_dim)] +
            [NormedLinear(hidden_dim, hidden_dim) for _
             in range(self.num_layers-1)]
        )
        self.E_embedder = MLP2(E_in_dim, hidden_dim)
        self.final_encoder = nn.Sequential(
            MLP2(hidden_dim, hidden_dim // 2),
            nn.ELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, out_dim)
        )

    def forward(self, X_emb, E_emb, U_emb, mask):
        """Run the decoder.

        Input:
            - X_emb: (batch_size, N, X_in_dim)
            - E_emb: (batch_size, N, N, E_in_dim)
            - U_emb: (batch_size, N, U_in_dim)
            - mask:  (batch_size, N), truthy for real nodes and falsy for
              padding.  Non-bool masks are converted with ``.bool()``.

        Returns:
            (batch_size, out_dim) for ``task == 'graph'`` or
            (batch_size, N, out_dim) for ``task == 'node'`` (padded nodes
            are zero).

        Raises:
            ValueError: if ``self.task`` is neither ``'graph'`` nor
                ``'node'``.
        """
        if self.task not in ('graph', 'node'):
            raise ValueError(
                f"Unsupported task {self.task!r}; expected 'graph' or 'node'."
            )

        # `~` is only a logical NOT on bool tensors; on a 0/1 integer mask it
        # would produce -1/-2 and turn the masking into integer indexing.
        mask_off = ~mask.bool()  # non-nodes
        node_off = mask_off.unsqueeze(-1)  # (batch_size, N, 1)
        # A pair (i, j) is invalid as soon as either endpoint is padding.
        pair_off = (
            mask_off.unsqueeze(2) | mask_off.unsqueeze(1)
        ).unsqueeze(-1)  # (batch_size, N, N, 1)

        scalar = torch.cat([X_emb, U_emb], dim=-1).masked_fill(node_off, 0)

        for i in range(self.num_layers):
            # score: (batch_size, N, N, hidden_dim)
            score = self.attentions[i](scalar, scalar)
            # NOTE(review): E_embedder(E_emb) is re-evaluated in every layer
            # with the same input; it is left inside the loop because MLP2
            # is not visible here and may be stochastic/stateful.
            score = score + self.E_embedder(E_emb)

            # Mask off non-nodes on both pair axes, then swap the two node
            # axes so the einsum below contracts over the second one.
            score = score.masked_fill(pair_off, 0).permute(0, 2, 1, 3)

            update = torch.einsum("bjid,bjd->bid", score, self.vlins[i](scalar))

            # The first layer changes the feature width, so it cannot use a
            # residual connection; later layers do.
            scalar = update if i == 0 else scalar + update

        if self.task == 'graph':
            pooled = scalar.masked_fill(node_off, 0).sum(1)
            return self.final_encoder(pooled)  # (batch_size, out_dim)

        # task == 'node'
        node_out = self.final_encoder(scalar)
        return node_out.masked_fill(node_off, 0)  # (batch_size, N, out_dim)


class MLPAttention(nn.Module):
    """Pairwise scoring module.

    Two separate linear maps project the (optionally layer-normalised)
    inputs, every (i, j) pair of projections is multiplied element-wise,
    and the product is passed through ``MLP2``.

    Args:
        in_channels: Size of the last dimension of both inputs.
        out_channels: Size of the last dimension of the output.
        use_layer_norm: When true, a shared ``nn.LayerNorm(in_channels)``
            is applied to both inputs before projection; otherwise the
            inputs are projected as they are.
    """

    def __init__(self, in_channels: int,
                 out_channels: int,
                 use_layer_norm: bool = True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_layer_norm = use_layer_norm

        # One projection per side of the pair, followed by a shared MLP.
        self.lins1 = nn.Linear(in_channels, out_channels)
        self.lins2 = nn.Linear(in_channels, out_channels)
        self.mlp = MLP2(out_channels, out_channels)

        # Identity keeps forward() free of branching when the norm is off.
        self.layer_norm = (
            nn.LayerNorm(in_channels) if self.use_layer_norm else nn.Identity()
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the projections, the MLP and, if present, the norm."""
        for module in (self.lins1, self.lins2, self.mlp):
            module.reset_parameters()

        # nn.Identity has no reset_parameters, so look the method up lazily.
        reset_norm = getattr(self.layer_norm, 'reset_parameters', None)
        if reset_norm is not None:
            reset_norm()

    def forward(self, s1, s2):
        """Return pairwise scores of shape (batch_size, N, N, out_channels).

        ``s1`` and ``s2`` are (batch_size, N, in_channels); entry
        ``[b, i, j]`` of the result is derived from ``s1[b, i]`` and
        ``s2[b, j]``.
        """
        # Projections: (batch_size, N, out_channels)
        left = self.lins1(self.layer_norm(s1))
        right = self.lins2(self.layer_norm(s2))

        # Channel-wise product of every pair: (batch_size, N, N, out_channels)
        pairwise = torch.einsum("bid,bjd->bijd", left, right)
        return self.mlp(pairwise)


class MessagePassingDecoder(nn.Module):
    """Decoder that aggregates node features with fixed edge weights.

    ``E_emb`` is embedded once, entries that involve a padded node are
    zeroed, and the result is used in every layer as channel-wise weights to
    aggregate MLP-transformed node features.  The first layer replaces the
    node features with the aggregate; later layers add it as a residual.  An
    optional batch or layer norm is applied to each aggregate.  A final head
    maps the result to ``out_dim`` per graph or per node.

    Args:
        X_in_dim: Size of the last dimension of ``X_emb``.
        E_in_dim: Size of the last dimension of ``E_emb``.
        U_in_dim: Size of the last dimension of ``U_emb`` (0 when ``forward``
            is called with ``U_emb=None``).
        hidden_dim: Width of the hidden node / edge features.
        out_dim: Size of the last dimension of the output.
        num_layers: Number of aggregation layers run in ``forward``.
            Expected to be >= 1; the input layer is always constructed.
        dropout: Dropout probability used inside the final head.
        use_bn: Apply ``nn.BatchNorm1d`` to each layer's aggregate.
        use_ln: Apply ``nn.LayerNorm`` to each layer's aggregate.  Mutually
            exclusive with ``use_bn``.
        task: ``'graph'`` to sum node features into one vector per graph,
            ``'node'`` to return per-node outputs.
    """

    def __init__(self, X_in_dim: int,
                 E_in_dim: int,
                 U_in_dim: int,
                 hidden_dim: int,
                 out_dim: int,
                 num_layers: int,
                 dropout: float = 0.0,
                 use_bn: bool = True,
                 use_ln: bool = False,
                 task: Literal['graph', 'node'] = 'graph'):
        super().__init__()
        self.X_in_dim = X_in_dim
        self.E_in_dim = E_in_dim
        self.U_in_dim = U_in_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.out_dim = out_dim
        self.use_bn = use_bn
        self.use_ln = use_ln
        assert not (self.use_bn and self.use_ln), \
            "use_bn and use_ln are mutually exclusive"
        self.task = task

        # The first MLP consumes the concatenated [X, U] features; all
        # following ones work on hidden_dim features.
        self.mlps = nn.ModuleList(
            [MLP2(X_in_dim + U_in_dim, hidden_dim)] +
            [MLP2(hidden_dim, hidden_dim) for _ in range(self.num_layers-1)]
        )
        # `norms` only exists when one of the two normalisations is enabled.
        if self.use_bn:
            self.norms = nn.ModuleList(
                [nn.BatchNorm1d(hidden_dim) for _ in range(self.num_layers)]
            )
        if self.use_ln:
            self.norms = nn.ModuleList(
                [nn.LayerNorm(hidden_dim) for _ in range(self.num_layers)]
            )
        self.E_embedder = MLP2(E_in_dim, hidden_dim)
        self.final_encoder = nn.Sequential(
            MLP2(hidden_dim, hidden_dim // 2),
            nn.ELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, out_dim)
        )

    def forward(self, X_emb, E_emb, U_emb, mask):
        """Run the decoder.

        Input:
            - X_emb: (batch_size, N, X_in_dim), or
              (batch_size, N, k, X_in_dim) when ``U_emb`` is None
            - E_emb: (batch_size, N, N, E_in_dim)
            - U_emb: (batch_size, N, U_in_dim) or None
            - mask:  (batch_size, N), truthy for real nodes and falsy for
              padding.  Non-bool masks are converted with ``.bool()``.

        Returns:
            (batch_size, out_dim) for ``task == 'graph'``;
            (batch_size, N, out_dim) for ``task == 'node'`` with ``U_emb``
            given (padded nodes are zero);
            (batch_size, N, k, out_dim) for ``task == 'node'`` with
            ``U_emb`` None (no output masking is applied in this mode).

        Raises:
            AssertionError: if ``U_emb`` is None while ``U_in_dim != 0``,
                ``use_bn`` is set, or ``task != 'node'``.
            ValueError: if ``self.task`` is neither ``'graph'`` nor
                ``'node'``.
        """
        has_u = U_emb is not None
        if not has_u:
            # The 4-D mode is only wired up for node-level output without
            # batch norm; check before doing any work.
            assert self.U_in_dim == 0, "U_emb=None requires U_in_dim == 0"
            assert not self.use_bn, "U_emb=None is not supported with use_bn"
            assert self.task == 'node', "U_emb=None requires task == 'node'"

        if self.task not in ('graph', 'node'):
            raise ValueError(
                f"Unsupported task {self.task!r}; expected 'graph' or 'node'."
            )

        # `~` is only a logical NOT on bool tensors; on a 0/1 integer mask it
        # would produce -1/-2 and turn the masking into integer indexing.
        mask_off = ~mask.bool()  # non-nodes
        node_off = mask_off.unsqueeze(-1)  # (batch_size, N, 1)
        # A pair (i, j) is invalid as soon as either endpoint is padding.
        pair_off = (
            mask_off.unsqueeze(2) | mask_off.unsqueeze(1)
        ).unsqueeze(-1)  # (batch_size, N, N, 1)

        if has_u:
            scalar = torch.cat([X_emb, U_emb], dim=-1).masked_fill(node_off, 0)
            aggregate = "bjid,bjd->bid"
        else:
            # X_emb is used as-is (and never written to) in this mode.
            scalar = X_emb
            aggregate = "bjid,bjkd->bikd"

        # (batch_size, N, N, hidden_dim); masked out-of-place so neither the
        # embedder output nor the caller's E_emb is modified, then the two
        # node axes are swapped for the contraction below.
        E_emb_ = self.E_embedder(E_emb).masked_fill(pair_off, 0)
        E_emb_ = E_emb_.permute(0, 2, 1, 3)

        for i in range(self.num_layers):
            update = torch.einsum(aggregate, E_emb_, self.mlps[i](scalar))

            if self.use_bn:
                # BatchNorm1d normalises dim 1, so move channels there.
                update = self.norms[i](update.permute(0, 2, 1)).permute(0, 2, 1)
            elif self.use_ln:
                update = self.norms[i](update)

            # The first layer changes the feature width, so it cannot use a
            # residual connection; later layers do.
            scalar = update if i == 0 else scalar + update

        if self.task == 'graph':
            pooled = scalar.masked_fill(node_off, 0).sum(1)
            return self.final_encoder(pooled)  # (batch_size, out_dim)

        # task == 'node'
        node_out = self.final_encoder(scalar)
        if has_u:
            node_out = node_out.masked_fill(node_off, 0)

        # (batch_size, N, out_dim) or (batch_size, N, k, out_dim)
        return node_out
