import torch
import numpy as np
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F

from script.models.SGFormer.layers import TransConv, GraphConv


class SGFormer(nn.Module):
    """SGFormer: a transformer branch optionally combined with a GNN branch.

    Node representations come from ``TransConv`` (always) and ``GraphConv``
    (when ``use_graph`` is true). The two are merged either by a weighted sum
    (``aggregate='add'``) or by concatenation (``aggregate='cat'``), then
    projected to ``args.nout`` channels by a linear head ``fc``.

    When ``forward`` receives no features, it uses a learnable node-feature
    table ``feat`` of shape ``(args.num_nodes, args.nfeat)``, initialised to
    ones.

    Args:
        args: config object. Reads ``nfeat``, ``nhid``, ``nout``,
            ``trans_num_layers``, ``trans_num_heads``, ``trans_use_bn``,
            ``trans_use_residual``, ``trans_use_weight``, ``trans_use_act``,
            ``dropout`` and ``num_nodes``.
        gnn_num_layers, gnn_dropout, gnn_use_weight, gnn_use_init, gnn_use_bn,
        gnn_use_residual, gnn_use_act: forwarded positionally to ``GraphConv``.
        use_graph: if false, the GNN branch is skipped in ``forward``.
        graph_weight: weight of the GNN branch in ``'add'`` mode; the
            transformer branch gets ``1 - graph_weight``.
        aggregate: ``'add'`` or ``'cat'``.

    Raises:
        ValueError: if ``aggregate`` is not ``'add'`` or ``'cat'``.
    """

    def __init__(self, args, gnn_num_layers=4, gnn_dropout=0.0, gnn_use_weight=False,
                 gnn_use_init=False, gnn_use_bn=False, gnn_use_residual=True, gnn_use_act=True,
                 use_graph=True, graph_weight=0.7, aggregate='add'):
        super(SGFormer, self).__init__()

        self.in_channels = args.nfeat
        self.hidden_channels = args.nhid
        self.out_channels = args.nout
        self.num_layers = args.trans_num_layers
        self.num_heads = args.trans_num_heads
        self.trans_use_bn = args.trans_use_bn
        self.trans_use_residual = args.trans_use_residual
        self.trans_use_weight = args.trans_use_weight
        self.trans_use_act = args.trans_use_act
        self.dropout = args.dropout

        self.use_graph = use_graph
        self.graph_weight = graph_weight
        self.aggregate = aggregate
        # Validate `aggregate` and size the head before building the heavier
        # sub-modules, so a bad config fails fast.
        fc_in = self._fc_in_channels(self.hidden_channels, aggregate, use_graph)

        self.trans_conv = TransConv(self.in_channels, self.hidden_channels, self.num_layers, self.num_heads, self.dropout, self.trans_use_bn, self.trans_use_residual, self.trans_use_weight, self.trans_use_act)
        self.graph_conv = GraphConv(self.in_channels, self.hidden_channels, gnn_num_layers, gnn_dropout, gnn_use_bn, gnn_use_residual, gnn_use_weight, gnn_use_init, gnn_use_act)
        self.fc = nn.Linear(fc_in, self.out_channels)

        # Parameter groups: transformer vs. GNN + output head
        # (presumably so callers can use separate optimizer settings).
        self.params1 = list(self.trans_conv.parameters())
        self.params2 = list(self.graph_conv.parameters())
        self.params2.extend(self.fc.parameters())
        # Learnable node features used when forward() gets no `x`.
        self.feat = Parameter(torch.ones(args.num_nodes, self.in_channels), requires_grad=True)

    @staticmethod
    def _fc_in_channels(hidden_channels, aggregate, use_graph):
        """Return the input width of the output head ``fc``.

        Raises:
            ValueError: if ``aggregate`` is neither ``'add'`` nor ``'cat'``.
        """
        if aggregate not in ('add', 'cat'):
            raise ValueError(f'Invalid aggregate type:{aggregate}')
        # Only 'cat' with the graph branch enabled doubles the width. Without
        # the graph branch, forward() feeds the transformer output alone.
        if aggregate == 'cat' and use_graph:
            return 2 * hidden_channels
        return hidden_channels

    def forward(self, edge_index, x=None):
        """Compute output node representations.

        Args:
            edge_index: graph connectivity, passed to ``graph_conv`` only when
                ``use_graph`` is true.
            x: input node features; defaults to the learnable ``self.feat``.

        Returns:
            ``self.fc`` applied to the (possibly merged) branch outputs.
        """
        if x is None:
            x = self.feat
        x1 = self.trans_conv(x)
        if self.use_graph:
            x2 = self.graph_conv(x, edge_index)
            if self.aggregate == 'add':
                x = self.graph_weight * x2 + (1 - self.graph_weight) * x1
            else:
                x = torch.cat((x1, x2), dim=1)
        else:
            x = x1
        return self.fc(x)

    def get_attentions(self, x):
        """Return the attention maps from ``trans_conv.get_attentions``.

        NOTE(review): the original comment describes the result as
        ``[layer num, N, N]``; confirm against ``TransConv``.
        """
        return self.trans_conv.get_attentions(x)

    def reset_parameters(self):
        """Re-initialise every trainable component of the model.

        Resets the transformer branch, the GNN branch (only when it is used),
        the output head ``fc``, and the learnable features ``feat``. ``feat``
        is reset to ones, matching its construction-time value.
        """
        self.trans_conv.reset_parameters()
        if self.use_graph:
            self.graph_conv.reset_parameters()
        self.fc.reset_parameters()
        nn.init.ones_(self.feat)

    def decoding_lp(self, z, edge_index):
        """Score edges for link prediction by a dot product of endpoint embeddings.

        Args:
            z: node embeddings, indexed by row (2-D expected).
            edge_index: two rows of node indices; ``edge_index[0]`` holds
                sources and ``edge_index[1]`` holds targets.

        Returns:
            One score per edge: ``sum(z[src] * z[dst], dim=1)``.
        """
        z_i = z[edge_index[0]]
        z_j = z[edge_index[1]]
        return (z_i * z_j).sum(dim=1)
