import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import dgl
import dgl.function as fn


class EGTLayer(nn.Module):
    """Graph-transformer layer with sigmoid-gated dot-product attention on edges.

    One forward pass performs:
      1. Multi-head edge attention. Every edge gets, per head, the score
         ``sigmoid(<K_src, Q_dst> * head_dim ** -0.5)``. The score is optionally
         multiplied by a per-edge weight. Each destination node receives the
         score-weighted sum of its source nodes' values. Scores are gated
         independently by a sigmoid (no softmax), so incoming weights need not
         sum to 1.
      2. A learnable-gated residual ``(1 + eps) * x + attn_out``, then
         LayerNorm and dropout.
      3. A two-layer feed-forward network (4x expansion) with a second gated
         residual ``(1 + eps2) * x + ffn_out`` and LayerNorm.

    Args:
        hidden_channels: Input/output feature size; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        activation: Activation used between the two feed-forward layers.
        dropout: Dropout probability, applied after the attention sub-layer
            and after the feed-forward activation.

    Raises:
        ValueError: If ``hidden_channels`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_channels=64, num_heads=4, activation=F.gelu, dropout=0.1):
        # Fail fast: otherwise the per-head ``view`` in forward() fails with
        # an opaque shape error.
        if hidden_channels % num_heads != 0:
            raise ValueError(
                f"hidden_channels ({hidden_channels}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_heads = num_heads
        self.head_dim = hidden_channels // num_heads
        # Standard 1/sqrt(d) attention scaling, folded into the queries.
        self.scaling = self.head_dim**-0.5

        # NOTE: construction order is kept stable so parameter initialisation
        # and state_dict keys do not change.
        self.q_proj = nn.Linear(hidden_channels, hidden_channels)
        self.k_proj = nn.Linear(hidden_channels, hidden_channels)
        self.v_proj = nn.Linear(hidden_channels, hidden_channels)
        self.norm = nn.LayerNorm(hidden_channels)
        self.norm2 = nn.LayerNorm(hidden_channels)
        self.FFN1 = nn.Linear(hidden_channels, hidden_channels * 4)
        self.FFN2 = nn.Linear(hidden_channels * 4, hidden_channels)
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        # Learnable residual gates (GIN-style), initialised so (1 + eps) == 1.
        self.eps = nn.Parameter(torch.zeros(1,))
        self.eps2 = nn.Parameter(torch.zeros(1,))

    @staticmethod
    def _expand_edge_weight(edge_weight, attn):
        """Reshape ``edge_weight`` so it broadcasts against ``attn``.

        ``attn`` has shape ``(E, num_heads, 1)``. A 1-D weight ``(E,)`` becomes
        ``(E, 1, 1)`` and is shared by all heads. A weight that already carries
        trailing dims (e.g. ``(E, num_heads)``) gets one singleton dim appended.
        """
        if edge_weight.dim() == 1:
            # unsqueeze(-1) alone would give (E, 1), which right-aligns against
            # (num_heads, 1) instead of the edge axis.
            return edge_weight.reshape(-1, *([1] * (attn.dim() - 1)))
        return edge_weight.unsqueeze(-1)

    def forward(self, graph: "dgl.DGLGraph", feat, edge_weight=None):
        """Apply the layer to ``graph``.

        Args:
            graph: A DGL graph or message-flow-graph block. For blocks, the
                destination nodes are assumed to be the first
                ``graph.number_of_dst_nodes()`` source nodes; the slicing
                below relies on this.
            feat: Source-node features; assumed ``(num_src_nodes,
                hidden_channels)``.
            edge_weight: Optional per-edge weights multiplied into the
                attention scores. Shape ``(num_edges,)`` applies one weight to
                all heads. Tensors with extra trailing dims are matched against
                the per-head scores after appending a singleton dim.

        Returns:
            Destination-node features of shape
            ``(num_dst_nodes, hidden_channels)``.
        """
        num_dst = graph.number_of_dst_nodes()
        heads, head_dim = self.num_heads, self.head_dim
        with graph.local_scope():
            # Split projections into heads: (N, hidden) -> (N, heads, head_dim).
            feat_K = self.k_proj(feat).view(*feat.shape[:-1], heads, head_dim)
            feat_V = self.v_proj(feat).view(*feat.shape[:-1], heads, head_dim)
            # Queries are only needed for destination nodes, so a block
            # projects just those rows.
            q_in = feat[:num_dst] if graph.is_block else feat
            feat_Q = self.q_proj(q_in).view(*q_in.shape[:-1], heads, head_dim) * self.scaling

            graph.srcdata.update({'K': feat_K, 'V': feat_V})
            graph.dstdata.update({'Q': feat_Q})
            # Per-edge, per-head dot product <K_src, Q_dst>.
            graph.apply_edges(fn.u_dot_v('K', 'Q', 'score'))
            attn = torch.sigmoid(graph.edata.pop('score'))
            if edge_weight is not None:
                attn = attn * self._expand_edge_weight(edge_weight, attn)
            graph.edata['attn'] = attn
            # Message passing: weighted sum of source values per destination.
            graph.update_all(fn.e_mul_u('attn', 'V', 'm'), fn.sum('m', 'out'))
            # Merge heads back into one feature vector per destination node.
            out = graph.dstdata['out'].reshape(feat_Q.shape[0], -1)

        residual = feat[:num_dst]
        out = self.norm((1 + self.eps) * residual + out)
        out = self.dropout(out)
        out = self.FFN2(self.dropout(self.activation(self.FFN1(out))))
        # NOTE(review): this second residual reuses the layer input ``feat``
        # rather than the attention sub-layer output, unlike a standard
        # post-norm Transformer. Kept as-is to preserve behaviour; confirm
        # this is intended.
        out = self.norm2((1 + self.eps2) * residual + out)
        return out


class EGT(nn.Module):
    """Stack of ``EGTLayer`` modules between a linear node embedding and a linear classifier head.

    Args:
        in_channels: Size of the raw input node features.
        edge_channels: Accepted for interface compatibility; not used by this
            implementation.
        hidden_channels: Feature width of every ``EGTLayer``.
        num_class: Number of output logits per node.
        num_heads: Attention heads per layer.
        num_layers: Number of stacked ``EGTLayer`` modules.
        dropout: Dropout probability forwarded to every layer.
    """

    def __init__(self, in_channels, edge_channels, hidden_channels, num_class, num_heads=16, num_layers=2, dropout=0.0):
        super().__init__()
        self.hidden_channels = hidden_channels
        # The layers are created before the embedding and the head. Reordering
        # these would change parameter initialisation under a fixed seed and
        # the state_dict key order.
        self.convs = nn.ModuleList(
            [EGTLayer(hidden_channels, num_heads, F.gelu, dropout) for _ in range(num_layers)]
        )
        self.node_embed = nn.Linear(in_channels, hidden_channels)
        self.predict = nn.Linear(hidden_channels, num_class)

    def forward(self, blocks, x, return_h=False):
        """Embed ``x``, run it through the layers, and classify.

        Args:
            blocks: Sequence of graphs/blocks, consumed one per layer in order.
                Iteration stops at whichever of ``blocks`` or the layer list is
                shorter.
            x: Input node features with last dimension ``in_channels``.
            return_h: If true, also return the final hidden representation.

        Returns:
            The logits, or ``(logits, hidden)`` when ``return_h`` is true.
        """
        hidden = self.node_embed(x)
        for layer, block in zip(self.convs, blocks):
            hidden = layer(block, hidden)
        logits = self.predict(hidden)
        return (logits, hidden) if return_h else logits
