import torch
from torch import nn
import torch.nn.functional as F

from nonweight_gcn_conv import nonWeightedGCNConv
from inits import glorot, zeros

class Learner(nn.Module):
    """GCN network whose trainable tensors can be swapped per forward call.

    The architecture is given by ``config``, a sequence of ``(name, param)``
    pairs applied in order:

    * ``('gcn_conv', param)`` -- allocates a weight of shape
      ``(param[0], param[1])`` (initialised with ``glorot``) and a bias of
      length ``param[1]`` (initialised with ``zeros``), and applies them
      through a ``nonWeightedGCNConv`` layer.
    * ``('dropout', param)`` -- dropout with probability ``param[0]``.
    * ``'elu'``, ``'relu'``, ``'softmax'``, ``'log_softmax'``, ``'sigmoid'``
      -- the corresponding activation; ``param`` is ignored.

    All trainable tensors live in ``self.vars`` (weight then bias for each
    ``gcn_conv``, in config order), so ``forward`` can be run with any other
    list of tensors laid out the same way.
    """

    # Config entries that own no trainable tensors and no layer object.
    _PARAMETER_FREE_OPS = ('elu', 'relu', 'dropout', 'softmax', 'log_softmax', 'sigmoid')

    def __init__(self, config):
        """Build parameters and conv layers from ``config``.

        Raises:
            NotImplementedError: if ``config`` contains an unknown layer name.
        """
        super().__init__()
        self.config = config
        self.vars = nn.ParameterList()
        # NOTE(review): a plain list, so these conv modules are not registered
        # as submodules (no .to()/.train()/state_dict propagation). That is
        # only harmless if nonWeightedGCNConv holds no parameters/buffers --
        # confirm before relying on it.
        self.layers = []

        for name, param in self.config:
            # ``==`` not ``is``: identity comparison with a string literal only
            # works by accident of interning and fails for runtime-built names.
            if name == 'gcn_conv':
                w = nn.Parameter(torch.Tensor(*param[:2]))
                b = nn.Parameter(torch.Tensor(param[1]))
                glorot(w)
                zeros(b)
                self.vars.append(w)
                self.vars.append(b)
                self.layers.append(nonWeightedGCNConv())
            elif name in self._PARAMETER_FREE_OPS:
                continue
            else:
                raise NotImplementedError('unsupported layer type: {!r}'.format(name))

    def forward(self, x, edge_index, edge_weight=None, vars=None, is_training=True):
        """Run the configured network.

        Args:
            x: input tensor for the first layer; ``softmax``/``log_softmax``
                normalise over dim 1.
            edge_index: passed unchanged to every ``gcn_conv`` layer.
            edge_weight: optional, passed unchanged to every ``gcn_conv`` layer.
            vars: tensors to use instead of ``self.vars``; must hold exactly one
                (weight, bias) pair per ``gcn_conv`` entry, in config order.
            is_training: forwarded to ``F.dropout`` as ``training``.

        Returns:
            The output of the last configured layer.

        Raises:
            NotImplementedError: if ``self.config`` holds an unknown layer name.
            AssertionError: if ``vars`` or ``self.layers`` were not fully consumed.
        """
        if vars is None:
            vars = self.vars

        idx_v = 0  # next unread position in ``vars``
        idx_l = 0  # next unused entry in ``self.layers``

        for name, param in self.config:
            if name == 'gcn_conv':
                w, b = vars[idx_v], vars[idx_v + 1]
                x = self.layers[idx_l](x, edge_index, edge_weight, w, b)
                idx_v += 2
                idx_l += 1
            elif name == 'relu':
                x = F.relu(x)
            elif name == 'elu':
                x = F.elu(x)
            elif name == 'dropout':
                x = F.dropout(x, p=param[0], training=is_training)
            elif name == 'log_softmax':
                x = F.log_softmax(x, dim=1)
            elif name == 'softmax':
                x = F.softmax(x, dim=1)
            elif name == 'sigmoid':
                # torch.sigmoid is the non-deprecated form of F.sigmoid.
                x = torch.sigmoid(x)
            else:
                raise NotImplementedError('unsupported layer type: {!r}'.format(name))

        # Every supplied tensor and every conv layer must have been used exactly once.
        assert idx_v == len(vars)
        assert idx_l == len(self.layers)

        return x

    def parameters(self):
        """Return ``self.vars`` itself.

        Overrides ``nn.Module.parameters``, which would return a generator,
        so callers get the ``ParameterList`` directly.
        """
        return self.vars

    def zero_grad(self, vars=None):
        """Zero, in place, the gradients of ``vars`` (default: ``self.vars``).

        Tensors whose ``.grad`` is ``None`` are left untouched.
        """
        params = self.vars if vars is None else vars
        with torch.no_grad():
            for p in params:
                if p.grad is not None:
                    p.grad.zero_()