from config import *
from layers import *
from metrics import *

import torch
import torch.nn as nn
from torch.nn import init

class GCN2(nn.Module):
    """Stack of GraphConvolution layers with a max+sum readout and a linear head.

    The architecture is driven by the global ``args`` (star-imported from
    ``config``):

    * ``args.hiddens`` -- '-'-separated hidden sizes, e.g. ``"64-32"``. If it
      is missing or cannot be parsed, ``[args.hidden1]`` is used instead.
    * ``args.bias`` -- passed as ``bias`` to every GraphConvolution.
    * ``args.bn`` -- when truthy, an ``nn.BatchNorm1d`` is inserted before every
      GraphConvolution except the first.
    * ``args.concat`` -- when truthy, :meth:`getNodeEmb` concatenates the
      per-stage outputs along the last axis.

    Args:
        input_dim: indexable with at least two entries. ``input_dim[1]`` is the
            input size of the first GraphConvolution and ``input_dim[0]`` is the
            ``num_features`` of every BatchNorm1d (presumably
            ``(num_nodes, num_features)`` -- TODO confirm).
        output_dim: output size of the final ``nn.Linear``.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        usebias = args.bias
        self.bn = args.bn

        try:
            hiddens = [int(s) for s in args.hiddens.split('-')]
        except (AttributeError, ValueError):
            # No usable ``args.hiddens`` (absent, None, not a str, or a
            # non-integer token): fall back to a single hidden layer.
            hiddens = [args.hidden1]

        print('input dim', input_dim)
        print('hiddens', hiddens)

        # Plain list kept for backward compatibility; only the ModuleList
        # built below registers the sub-modules' parameters with this module.
        self.layers = [
            GraphConvolution(input_dim=input_dim[1],
                             output_dim=hiddens[0],
                             activation=nn.ReLU,
                             bias=usebias)
        ]

        # One GraphConvolution per consecutive (in, out) pair of hidden sizes.
        for in_size, out_size in zip(hiddens[:-1], hiddens[1:]):
            if self.bn:
                # NOTE(review): num_features is input_dim[0] rather than the
                # preceding hidden size -- confirm which axis BatchNorm1d
                # should normalize for the layout GraphConvolution produces.
                self.layers.append(nn.BatchNorm1d(input_dim[0]))
            self.layers.append(GraphConvolution(input_dim=in_size,
                                                output_dim=out_size,
                                                activation=nn.ReLU,
                                                bias=usebias))

        self.layers_ = nn.ModuleList(self.layers)

        # Final linear prediction head.
        self.pred_layer = nn.Linear(hiddens[-1], output_dim)

        self.hiddens = hiddens

    def forward(self, inputs, training=None):
        """Compute predictions from node embeddings.

        Args:
            inputs: ``(x, support)`` pair, forwarded to :meth:`getNodeEmb`.
            training: forwarded to every GraphConvolution call.

        Returns:
            ``self.pred_layer`` applied to the concatenation (along the last
            axis) of the max and the sum of the embeddings over their last axis.
        """
        out = self.getNodeEmb(inputs, training)
        # NOTE(review): both reductions run over dim=-1, the same axis
        # getNodeEmb concatenates stage outputs along, while pred_layer expects
        # hiddens[-1] input features -- verify the embedding layout; pooling
        # over the node axis may have been intended.
        out_max, _ = torch.max(out, dim=-1)
        out_sum = torch.sum(out, dim=-1)
        out = torch.cat([out_max, out_sum], dim=-1)
        return self.pred_layer(out)

    def getNodeEmb(self, inputs, training=None):
        """Run the layer stack and return the node embeddings.

        Args:
            inputs: ``(x, support)`` pair. ``support`` is passed unchanged to
                every GraphConvolution together with the current ``x``.
            training: forwarded to every GraphConvolution call.

        Returns:
            The output of the last layer or, when ``args.concat`` is set, the
            per-stage outputs concatenated along the last axis. With batch norm
            enabled the stages are each BatchNorm output plus the final
            GraphConvolution output; otherwise every GraphConvolution output.
        """
        x, support = inputs
        stage_outputs = []
        for layer in self.layers_:
            # Call the module itself (not .forward) so registered hooks run.
            if isinstance(layer, nn.BatchNorm1d):
                x = layer(x)
                stage_outputs.append(x)
            else:
                x = layer((x, support), training)
                if not self.bn:
                    stage_outputs.append(x)
        if self.bn:
            # The last GraphConvolution has no BatchNorm after it; record it.
            stage_outputs.append(x)
        if args.concat:
            x = torch.cat(stage_outputs, dim=-1)
        return x
