import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from model.layer import TransConvLayer
from torch.nn import Module, Parameter, Linear, Sequential, LogSoftmax

class UGCFormer(Module):
    """Two-branch network that co-updates an attribute and a topology embedding.

    The attribute branch starts from ``mlp_x(feat)`` and the topology branch
    from ``mlp_a(topo)``.  Each of the ``nlayer`` layers runs one
    ``TransConvLayer`` per branch (the two branches feed each other as the
    layer's two inputs) and blends the result with a second signal using a
    learnable ``tanh`` gate:

    * topology branch: blend with ``adj @ v`` where ``v`` is the second value
      returned by the layer;
    * attribute branch: blend with the initial hidden features.

    Both branches end in their own ``Linear -> LogSoftmax`` head and the two
    log-probability outputs are combined with the fixed weight ``weight_attr``.

    Args:
        n: Input width of the topology encoder ``mlp_a`` (last dimension of
            ``topo``; presumably the number of nodes -- confirm with caller).
        nclass: Number of output classes.
        nfeat: Input width of the attribute encoder ``mlp_x``.
        nhidden: Hidden width shared by every layer.
        nlayer: Number of co-update layers.
        dropout: Dropout probability, used in ``forward`` and passed to every
            ``TransConvLayer``.
        activation: Callable applied to each normalised layer output.
        weight_attr: Weight of the attribute head in the combined output; the
            topology head gets ``1 - weight_attr``.
        num_heads: Forwarded to ``TransConvLayer``.
        use_weight: Forwarded to ``TransConvLayer``.
        aggregate: Accepted for backward compatibility; not used by this class.
    """

    def __init__(self, n, nclass, nfeat, nhidden, nlayer,
                 dropout, activation,
                 weight_attr=0.1,
                 num_heads=4,
                 use_weight=True,
                 aggregate='add'):
        super().__init__()
        self.n = n
        self.activation = activation
        self.nlayer = nlayer
        self.nhidden = nhidden
        self.dropout = dropout
        self.weight_attr = weight_attr
        # Normalisation and activation after each conv layer are always on.
        self.use_bn = True
        self.use_act = True

        # Input encoders that bring both views to the shared hidden width.
        self.mlp_a = nn.Linear(n, nhidden)
        self.mlp_x = nn.Linear(nfeat, nhidden)

        # Per-layer modules; the "bns_*" lists hold LayerNorm instances.
        self.bns_x = nn.ModuleList()
        self.bns_a = nn.ModuleList()
        self.convs_x = nn.ModuleList()
        self.convs_a = nn.ModuleList()

        for _ in range(nlayer):
            self.convs_x.append(
                TransConvLayer(nhidden, nhidden,
                               num_heads=num_heads,
                               dropout=dropout,
                               use_weight=use_weight))
            self.convs_a.append(
                TransConvLayer(nhidden, nhidden,
                               num_heads=num_heads,
                               dropout=dropout,
                               use_weight=use_weight))
            self.bns_x.append(nn.LayerNorm(nhidden))
            self.bns_a.append(nn.LayerNorm(nhidden))

        # One scalar gate per layer and branch; values are set in
        # reset_parameters().
        self.weight_z = Parameter(torch.empty(size=(nlayer, 1)))
        self.weight_y = Parameter(torch.empty(size=(nlayer, 1)))

        # NOTE: ``self.mlp`` is not used by forward(); it is kept so that
        # existing checkpoints (state_dict keys) keep loading.
        self.mlp = Sequential(Linear(nhidden, nclass),
                              LogSoftmax(dim=-1))
        self.mlp_fin_a = Sequential(Linear(nhidden, nclass),
                                    LogSoftmax(dim=-1))
        self.mlp_fin_y = Sequential(Linear(nhidden, nclass),
                                    LogSoftmax(dim=-1))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable parameter of the model.

        The gates are zeroed, so ``tanh(gate) == 0`` and each layer initially
        outputs only the propagated / initial-feature signal.  The classifier
        heads are reset as well, so that calling this method on a trained
        model really yields a fresh model.
        """
        nn.init.zeros_(self.weight_z)
        nn.init.zeros_(self.weight_y)
        self.mlp_a.reset_parameters()
        self.mlp_x.reset_parameters()
        for layers in (self.convs_x, self.convs_a, self.bns_x, self.bns_a):
            for layer in layers:
                layer.reset_parameters()
        # The heads are Sequential containers; only some of their children
        # (the Linear layers) have parameters to reset.
        for head in (self.mlp, self.mlp_fin_a, self.mlp_fin_y):
            for layer in head:
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()

    def get_TopoRepresentation(self, y, z, adj, layer_id):
        """Compute the topology embedding produced by layer ``layer_id``.

        Args:
            y: Current attribute embedding.
            z: Current topology embedding.
            adj: Matrix multiplied with the layer's second output via
                ``torch.spmm`` (presumably a sparse normalised adjacency --
                confirm with caller).
            layer_id: Index of the layer to apply.

        Returns:
            ``(1 - w) * (adj @ v) + w * former_z`` with
            ``w = tanh(weight_z[layer_id])``.
        """
        # The layer returns its output plus a second tensor ``v`` that is
        # propagated over ``adj`` (presumably the value projection -- see
        # TransConvLayer).
        former_z, v = self.convs_a[layer_id](y, z)
        if self.use_bn:
            former_z = self.bns_a[layer_id](former_z)
        if self.use_act:
            former_z = self.activation(former_z)
        gnn_z = torch.spmm(adj, v)
        gate = torch.tanh(self.weight_z[layer_id])
        return (1 - gate) * gnn_z + gate * former_z

    def get_AttriRepresentation(self, y, z, feat, layer_id):
        """Compute the attribute embedding produced by layer ``layer_id``.

        Args:
            y: Current attribute embedding.
            z: Current topology embedding.
            feat: Tensor blended with the layer output (``forward`` passes the
                initial hidden attribute features).
            layer_id: Index of the layer to apply.

        Returns:
            ``(1 - w) * feat + w * former_y`` with
            ``w = tanh(weight_y[layer_id])``.
        """
        # Argument order is swapped relative to the topology branch; the
        # layer's second return value is not needed here.
        former_y, _ = self.convs_x[layer_id](z, y)
        if self.use_bn:
            former_y = self.bns_x[layer_id](former_y)
        if self.use_act:
            former_y = self.activation(former_y)
        gate = torch.tanh(self.weight_y[layer_id])
        return (1 - gate) * feat + gate * former_y

    def forward(self, feat, topo, adj_normal):
        """Run both branches and return their log-probability outputs.

        Args:
            feat: Attribute input with last dimension ``nfeat``.
            topo: Topology input with last dimension ``n``.
            adj_normal: Matrix handed to ``torch.spmm`` in every layer.

        Returns:
            Tuple ``(outputs_z, outputs_y, output)``: the topology head
            output, the attribute head output, and
            ``weight_attr * outputs_y + (1 - weight_attr) * outputs_z``.
        """
        z: Tensor = self.mlp_a(topo)
        y: Tensor = self.mlp_x(feat)
        # Keep the pre-dropout hidden features as the per-layer residual
        # signal of the attribute branch.
        feat = y
        z = F.dropout(z, self.dropout, training=self.training)
        y = F.dropout(y, self.dropout, training=self.training)

        for layer_id in range(self.nlayer):
            # Both updates must read the embeddings of the previous layer, so
            # they are computed before either variable is overwritten.
            z, y = (self.get_TopoRepresentation(y, z, adj_normal, layer_id),
                    self.get_AttriRepresentation(y, z, feat, layer_id))
            y = F.normalize(y)
            z = F.normalize(z)

        y = F.dropout(y, self.dropout, training=self.training)
        z = F.dropout(z, self.dropout, training=self.training)
        outputs_z = self.mlp_fin_a(z)
        outputs_y = self.mlp_fin_y(y)
        output = self.weight_attr * outputs_y + (1 - self.weight_attr) * outputs_z
        return outputs_z, outputs_y, output
