import torch
import torch.nn as nn
import torch.nn.functional as F

from ..._layers import GCN_layer

# Module-level prefix for the debug label; GCN stores
# `global_debug_ckpt + '.GCN'` on each instance as `debug_ckpt`.
global_debug_ckpt = '_GCN'

class GCN(nn.Module):
    """Sequential stack of ``GCN_layer`` modules.

    The first layer is always built with ``0.`` in its dropout slot and the
    last layer always receives an identity function as its activation. When
    more than two layers are requested, the layers in between map ``ft_hid``
    to ``ft_hid`` and use the configured dropout and activation.

    Args:
        num_layers: Number of ``GCN_layer`` modules to stack; must be > 0.
        ft_in: First positional argument of the first layer (its input
            feature size, going by the name).
        ft_hid: Feature size used between layers when ``num_layers > 1``.
        ft_out: Second positional argument of the last layer (its output
            feature size, going by the name).
        *args: Accepted but not used.
        **kwargs: Optional settings read by key:
            ``dropout`` (default ``.5``), ``act`` (default ``F.relu``) and
            ``bias`` (default ``True``). Other keys are ignored.

    Raises:
        AssertionError: If ``num_layers`` is not greater than 0.
        ValueError: If ``num_layers`` matches neither the single-layer nor
            the multi-layer branch (in practice only when assertions are
            disabled, e.g. under ``python -O``).
    """

    def __init__(self, num_layers, ft_in, ft_hid, ft_out, *args, **kwargs):
        super().__init__()
        assert num_layers > 0, (
            'num of layers should be greater than 0, received {}'.format(num_layers)
        )
        self.num_layers = num_layers
        self.ft_in = ft_in
        self.ft_hid = ft_hid
        self.ft_out = ft_out

        # Optional hyper-parameters, each with a fallback default.
        self.dropout = kwargs.get('dropout', .5)
        self.act = kwargs.get('act', F.relu)
        self.bias = kwargs.get('bias', True)

        # Describe every layer first as (in, out, dropout, activation) and
        # instantiate afterwards, so the layers are created in stack order.
        if self.num_layers == 1:
            # Single layer: no dropout on input, no act on output.
            layer_specs = [(self.ft_in, self.ft_out, 0., lambda x: x)]
        elif self.num_layers > 1:
            # Entry layer: no dropout on input.
            layer_specs = [(self.ft_in, self.ft_hid, 0., self.act)]
            # Hidden-to-hidden layers (none when num_layers == 2).
            layer_specs.extend(
                (self.ft_hid, self.ft_hid, self.dropout, self.act)
                for _ in range(self.num_layers - 2)
            )
            # Exit layer: no act on output.
            layer_specs.append((self.ft_hid, self.ft_out, self.dropout, lambda x: x))
        else:
            # In practice only reached when the assert above was compiled out.
            raise ValueError("GCN initialization error: num_layers = {}".format(self.num_layers))

        self.gcns = nn.ModuleList()
        for n_in, n_out, drop_rate, activation in layer_specs:
            self.gcns.append(GCN_layer(n_in, n_out, drop_rate, activation, self.bias))

        self.debug_ckpt = global_debug_ckpt + '.GCN'

    def forward(self, adj, fts):
        """Run ``fts`` through the first ``num_layers`` entries of ``self.gcns``.

        Each layer is called as ``layer(adj, h)``, where ``h`` is ``fts`` for
        the first layer and the previous layer's result afterwards.

        Args:
            adj: Passed unchanged to every layer (presumably the graph
                adjacency structure; its exact form is defined by
                ``GCN_layer``).
            fts: Input handed to the first layer.

        Returns:
            The result of the last layer that was applied.
        """
        hidden = fts
        for depth in range(self.num_layers):
            layer = self.gcns[depth]
            hidden = layer(adj, hidden)

        return hidden
        
