import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

import dgl
import dgl.function as fn
import numpy as np

from model.GT_KAN.efficient_kan import make_kans

"""
    Graph Transformer Layer
    
"""

"""
    Util functions
"""
def src_dot_dst(src_field, dst_field, out_field):
    """Build an edge function computing a source/destination dot product.

    Args:
        src_field: key looked up in ``edges.src``.
        dst_field: key looked up in ``edges.dst``.
        out_field: key under which the result is returned.

    Returns:
        A callable taking an edge batch (as passed to ``g.apply_edges``) and
        returning ``{out_field: score}``, where ``score`` is the element-wise
        product of the two features summed over the last dimension. The
        reduced dimension is kept with size 1.
    """
    def func(edges):
        src_feat = edges.src[src_field]
        dst_feat = edges.dst[dst_field]
        dot = torch.sum(src_feat * dst_feat, dim=-1, keepdim=True)
        return {out_field: dot}
    return func

def scaled_exp(field, scale_constant):
    """Build an edge function that exponentiates a scaled, clamped edge feature.

    Args:
        field: key in ``edges.data`` that is read and also used for the result.
        scale_constant: divisor applied to the feature before clamping.

    Returns:
        A callable taking an edge batch and returning
        ``{field: exp(clamp(edges.data[field] / scale_constant, -5, 5))}``.
    """
    def func(edges):
        scaled = edges.data[field] / scale_constant
        # Bound the exponent so the un-normalised softmax terms stay finite.
        bounded = torch.clamp(scaled, -5, 5)
        return {field: torch.exp(bounded)}

    return func


"""
    Multi-Head Attention Layers
"""

class MultiHeadAttentionLayer(nn.Module):
    """Multi-head dot-product attention evaluated along the edges of a graph.

    Queries, keys and values are linear projections of the node features.
    For every edge the per-head score is the dot product of the source key
    and the destination query, divided by ``sqrt(out_dim)``, clamped to
    ``[-5, 5]`` and exponentiated. Each node then receives the score-weighted
    sum of the values sent along its incoming edges, divided by the sum of
    those scores.

    Args:
        in_dim: size of the input node features.
        out_dim: feature size of a single head.
        num_heads: number of attention heads.
        use_bias: whether the Q/K/V projections carry a bias term.
    """

    def __init__(self, in_dim, out_dim, num_heads, use_bias):
        super().__init__()

        self.out_dim = out_dim
        self.num_heads = num_heads

        # The three projections differ only in name; all heads are produced
        # by one matrix and split apart in forward().
        projected_dim = out_dim * num_heads
        with_bias = bool(use_bias)
        self.Q = nn.Linear(in_dim, projected_dim, bias=with_bias)
        self.K = nn.Linear(in_dim, projected_dim, bias=with_bias)
        self.V = nn.Linear(in_dim, projected_dim, bias=with_bias)

        # Xavier re-initialisation is available through _reset_parameters()
        # but is not applied by default.

    def _reset_parameters(self):
        """Xavier-initialise the projection weights and zero their biases."""
        projections = (self.Q, self.K, self.V)
        for proj in projections:
            init.xavier_uniform_(proj.weight)
        if self.Q.bias is not None:
            for proj in projections:
                init.zeros_(proj.bias)

    def propagate_attention(self, g):
        """Compute edge scores and aggregate them onto the nodes of ``g``.

        Reads ``Q_h``, ``K_h`` and ``V_h`` from ``g.ndata`` and writes
        ``score`` to the edge data and ``wV`` / ``z`` to the node data.
        """
        scale = np.sqrt(self.out_dim)
        g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'))
        g.apply_edges(scaled_exp('score', scale))

        # Aggregate over every edge: weighted values into 'wV', scores into 'z'.
        all_edges = g.edges()
        weighted_values = fn.u_mul_e('V_h', 'score', 'V_h')
        g.send_and_recv(all_edges, weighted_values, fn.sum('V_h', 'wV'))
        edge_scores = fn.copy_e('score', 'score')
        g.send_and_recv(all_edges, edge_scores, fn.sum('score', 'z'))

        # A zero normaliser would make the division in forward() blow up.
        normaliser = g.ndata['z']
        normaliser[normaliser == 0] = 1

    def forward(self, g, h):
        """Return the per-head attention output for node features ``h``.

        The projections are viewed as ``[-1, num_heads, out_dim]`` and stored
        on ``g`` before message passing; the result is ``wV / z``.
        """
        head_shape = (-1, self.num_heads, self.out_dim)
        named_projections = (('Q_h', self.Q), ('K_h', self.K), ('V_h', self.V))
        for key, proj in named_projections:
            g.ndata[key] = proj(h).view(*head_shape)

        self.propagate_attention(g)

        return g.ndata['wV'] / g.ndata['z']

class MultiHeadAttentionLayer_KAN(nn.Module):
    """Multi-head graph attention whose key projection is built by ``make_kans``.

    Queries and values are linear projections of the node features; the key
    module comes from ``make_kans`` (imported from
    ``model.GT_KAN.efficient_kan``, not defined in this file). Scoring and
    aggregation are the same as in the linear attention layer: exponentiated,
    clamped, scaled dot products of source keys and destination queries,
    used to weight the values summed onto each destination node.

    Args:
        in_dim: size of the input node features.
        out_dim: feature size of a single head.
        num_heads: number of attention heads.
        use_bias: whether the Q and V projections carry a bias term.
        spline_order: forwarded to ``make_kans``.
        grid_size: forwarded to ``make_kans``.
        hidden_layers: forwarded to ``make_kans``.
    """

    def __init__(self, in_dim, out_dim, num_heads, use_bias, spline_order, grid_size, hidden_layers):
        super().__init__()

        self.out_dim = out_dim
        self.num_heads = num_heads

        # Only the linear projections depend on use_bias; the key module is
        # built the same way in both cases.
        projected_dim = out_dim * num_heads
        with_bias = bool(use_bias)
        self.Q = nn.Linear(in_dim, projected_dim, bias=with_bias)
        self.K = make_kans(in_dim, out_dim, out_dim, num_heads, hidden_layers, grid_size, spline_order)
        self.V = nn.Linear(in_dim, projected_dim, bias=with_bias)

        # Xavier re-initialisation is available through _reset_parameters()
        # but is not applied by default.

    def _reset_parameters(self):
        """Xavier-initialise the linear projections and zero their biases.

        ``K`` is only touched when it is an ``nn.Linear``.
        """
        key_is_linear = isinstance(self.K, nn.Linear)

        init.xavier_uniform_(self.Q.weight)
        if key_is_linear:
            init.xavier_uniform_(self.K.weight)
        init.xavier_uniform_(self.V.weight)

        if self.Q.bias is None:
            return
        init.zeros_(self.Q.bias)
        if key_is_linear:
            init.zeros_(self.K.bias)
        init.zeros_(self.V.bias)

    def propagate_attention(self, g):
        """Compute edge scores and aggregate them onto the nodes of ``g``.

        Reads ``Q_h``, ``K_h`` and ``V_h`` from ``g.ndata`` and writes
        ``score`` to the edge data and ``wV`` / ``z`` to the node data.
        """
        scale = np.sqrt(self.out_dim)
        g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'))
        g.apply_edges(scaled_exp('score', scale))

        # Aggregate over every edge: weighted values into 'wV', scores into 'z'.
        all_edges = g.edges()
        weighted_values = fn.u_mul_e('V_h', 'score', 'V_h')
        g.send_and_recv(all_edges, weighted_values, fn.sum('V_h', 'wV'))
        edge_scores = fn.copy_e('score', 'score')
        g.send_and_recv(all_edges, edge_scores, fn.sum('score', 'z'))

        # A zero normaliser would make the division in forward() blow up.
        normaliser = g.ndata['z']
        normaliser[normaliser == 0] = 1

    def forward(self, g, h):
        """Return the per-head attention output for node features ``h``.

        The key module receives ``h`` expanded to one copy per head; all three
        projections are viewed as ``[-1, num_heads, out_dim]`` and stored on
        ``g`` before message passing. The result is ``wV / z``.
        """
        head_shape = (-1, self.num_heads, self.out_dim)

        queries = self.Q(h)
        per_head_input = h.unsqueeze(1).expand(-1, self.num_heads, -1)
        keys = self.K(per_head_input)
        values = self.V(h)

        g.ndata['Q_h'] = queries.view(*head_shape)
        g.ndata['K_h'] = keys.view(*head_shape)
        g.ndata['V_h'] = values.view(*head_shape)

        self.propagate_attention(g)

        return g.ndata['wV'] / g.ndata['z']
    

class GraphTransformerLayer(nn.Module):
    """Graph transformer block: multi-head graph attention followed by an FFN.

    The attention output is flattened to ``out_dim`` features, passed through
    dropout and the output projection ``O``, then optionally added to the
    layer input and normalised. A two-layer ReLU feed-forward network with
    hidden width ``2 * out_dim`` follows, with its own optional residual
    connection and normalisation.

    Args:
        kind: ``'KAA_GT'`` selects ``MultiHeadAttentionLayer_KAN``; any other
            value selects ``MultiHeadAttentionLayer``.
        in_dim: size of the input node features.
        out_dim: size of the output node features. Must be a multiple of
            ``num_heads``: each head gets ``out_dim // num_heads`` features
            and the concatenated heads are viewed as ``out_dim`` columns.
        num_heads: number of attention heads (at least 1).
        spline_order: forwarded to the KAN attention layer; unused otherwise.
        grid_size: forwarded to the KAN attention layer; unused otherwise.
        hidden_layers: forwarded to the KAN attention layer; unused otherwise.
        dropout: dropout probability applied to the attention output and
            inside the FFN.
        layer_norm: apply ``nn.LayerNorm`` after each residual step.
        batch_norm: apply ``nn.BatchNorm1d`` after each residual step (after
            layer norm when both are enabled).
        residual: add skip connections around the attention and FFN parts.
            The first one adds the raw layer input to an ``out_dim``-wide
            tensor, so the two shapes must be compatible.
        use_bias: forwarded to the attention layer's projections.

    Raises:
        ValueError: if ``num_heads`` is smaller than 1 or does not divide
            ``out_dim``.
    """
    def __init__(self, kind, in_dim, out_dim, num_heads, spline_order, grid_size, hidden_layers, dropout=0.0, layer_norm=False, batch_norm=True, residual=True, use_bias=False):
        super().__init__()

        # Without this check the floor division below silently drops features
        # and the later view(-1, out_dim) either fails or mixes rows that
        # belong to different nodes.
        if num_heads < 1 or out_dim % num_heads != 0:
            raise ValueError(
                'out_dim ({}) must be a positive multiple of num_heads ({})'.format(out_dim, num_heads)
            )

        self.in_channels = in_dim
        self.out_channels = out_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.residual = residual
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm

        head_dim = out_dim // num_heads
        if kind == 'KAA_GT':
            self.attention = MultiHeadAttentionLayer_KAN(in_dim, head_dim, num_heads, use_bias, spline_order, grid_size, hidden_layers)
        else:
            self.attention = MultiHeadAttentionLayer(in_dim, head_dim, num_heads, use_bias)

        # Output projection applied to the concatenated heads.
        self.O = nn.Linear(out_dim, out_dim)

        if self.layer_norm:
            self.layer_norm1 = nn.LayerNorm(out_dim)

        if self.batch_norm:
            self.batch_norm1 = nn.BatchNorm1d(out_dim)

        # FFN
        self.FFN_layer1 = nn.Linear(out_dim, out_dim*2)
        self.FFN_layer2 = nn.Linear(out_dim*2, out_dim)

        if self.layer_norm:
            self.layer_norm2 = nn.LayerNorm(out_dim)

        if self.batch_norm:
            self.batch_norm2 = nn.BatchNorm1d(out_dim)

        # Xavier re-initialisation is available through _reset_parameters()
        # but is not applied by default.

    def _reset_parameters(self):
        """Xavier-initialise ``O`` and the FFN weights and zero their biases."""
        init.xavier_uniform_(self.O.weight)
        init.xavier_uniform_(self.FFN_layer1.weight)
        init.xavier_uniform_(self.FFN_layer2.weight)
        init.zeros_(self.O.bias)
        init.zeros_(self.FFN_layer1.bias)
        init.zeros_(self.FFN_layer2.bias)

    def forward(self, g, h):
        """Apply attention and the FFN to node features ``h`` on graph ``g``.

        Returns the updated node features with ``out_channels`` columns.
        """
        h_in1 = h  # for first residual connection

        # Multi-head attention, heads concatenated into out_channels columns.
        attn_out = self.attention(g, h)
        h = attn_out.view(-1, self.out_channels)
        h = F.dropout(h, self.dropout, training=self.training)

        h = self.O(h)

        if self.residual:
            h = h_in1 + h  # residual connection

        if self.layer_norm:
            h = self.layer_norm1(h)

        if self.batch_norm:
            h = self.batch_norm1(h)

        h_in2 = h  # for second residual connection

        # FFN
        h = self.FFN_layer1(h)
        h = F.relu(h)
        h = F.dropout(h, self.dropout, training=self.training)
        h = self.FFN_layer2(h)

        if self.residual:
            h = h_in2 + h  # residual connection

        if self.layer_norm:
            h = self.layer_norm2(h)

        if self.batch_norm:
            h = self.batch_norm2(h)

        return h

    def __repr__(self):
        return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(
            self.__class__.__name__,
            self.in_channels,
            self.out_channels,
            self.num_heads,
            self.residual,
        )