import torch.nn as nn
import torch.nn.functional as F
from layers import GraphConvolution,MLP
import torch

class GCN(nn.Module):
    """Shallow graph network built from ``GraphConvolution`` layers.

    The architecture is selected by ``model_type``:

    * ``'mlp'``, ``'gcn'``, ``'mfgcn'`` -- two ``GraphConvolution`` layers
      (``nfeat -> nhid -> nclass``) with ReLU and dropout in between.
    * ``'sgc'``, ``'mfsgc'`` -- a single ``GraphConvolution`` layer
      (``nfeat -> nclass``).

    The ``'mf*'`` variants additionally apply dropout to the input features
    before the first layer. ``model_type`` is also forwarded to every
    ``GraphConvolution``; what it changes inside the layer is defined in
    ``layers``, not here.
    """

    # model types that stack two GraphConvolution layers (hidden + output)
    _TWO_LAYER_TYPES = ('mlp', 'gcn', 'mfgcn')
    # model types that use a single GraphConvolution layer straight to nclass
    _ONE_LAYER_TYPES = ('sgc', 'mfsgc')
    # model types that also apply dropout to the raw input features
    _INPUT_DROPOUT_TYPES = ('mfgcn', 'mfsgc')

    def __init__(self, nfeat, nhid, nclass, dropout, hops, features, model_type):
        """Build the layer stack for ``model_type``.

        Args:
            nfeat: Input size of the first ``GraphConvolution`` layer.
            nhid: Hidden size for the two-layer variants (also the size of
                ``self.bn``).
            nclass: Output size of the final layer.
            dropout: Dropout probability passed to ``F.dropout``.
            hops: Not used by this class; accepted for interface compatibility.
            features: Not used by this class; accepted for interface
                compatibility.
            model_type: One of ``'mlp'``, ``'gcn'``, ``'mfgcn'``, ``'sgc'``,
                ``'mfsgc'``.

        Raises:
            ValueError: If ``model_type`` is not one of the supported values.
                Previously such a model was built with no layers and only
                failed later, inside ``forward``.
        """
        super(GCN, self).__init__()
        if model_type not in self._TWO_LAYER_TYPES + self._ONE_LAYER_TYPES:
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected one of "
                f"{self._TWO_LAYER_TYPES + self._ONE_LAYER_TYPES}"
            )

        # Registration order (gcns, mlps, bn) is kept as-is so state_dict keys
        # and ordering match checkpoints saved by earlier versions.
        self.gcns, self.mlps = nn.ModuleList(), nn.ModuleList()
        self.model_type = model_type
        # NOTE: self.bn and self.mlps are never used by forward(); they are
        # kept only so the module's parameters/buffers stay unchanged.
        self.bn = nn.BatchNorm1d(nhid)

        if model_type in self._TWO_LAYER_TYPES:
            self.gcns.append(GraphConvolution(nfeat, nhid, model_type=model_type))
            self.gcns.append(
                GraphConvolution(nhid, nclass, model_type=model_type, output_layer=1)
            )
        else:  # single-layer variants ('sgc', 'mfsgc')
            self.gcns.append(GraphConvolution(nfeat, nclass, model_type=model_type))
        self.dropout = dropout

    def forward(self, x, adj_low, adj_high):
        """Run the layer stack.

        Args:
            x: Input node features; presumably ``(num_nodes, nfeat)`` -- TODO
                confirm against ``GraphConvolution``.
            adj_low: Passed unchanged to every ``GraphConvolution`` call.
            adj_high: Passed unchanged to every ``GraphConvolution`` call.

        Returns:
            The output of the final layer. No activation or softmax is applied
            here.
        """
        if self.model_type in self._INPUT_DROPOUT_TYPES:
            x = F.dropout(x, self.dropout, training=self.training)

        fea = self.gcns[0](x, adj_low, adj_high)

        if self.model_type in self._TWO_LAYER_TYPES:
            fea = F.dropout(F.relu(fea), self.dropout, training=self.training)
            fea = self.gcns[-1](fea, adj_low, adj_high)
        return fea
