import torch
import __init__
from gcn_lib.torch_vertex import GENConv
from gcn_lib.torch_nn import norm_layer
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
import logging
import pdb

class DeeperGCN(torch.nn.Module):
    """Graph-level prediction model built from a deep stack of GENConv layers.

    As implemented in ``forward``, the model does the following:

    * linear node and edge encoders;
    * ``num_layers`` GENConv layers, where every layer after the first is
      wrapped in a pre-activation residual ("res+") block;
    * a final norm and dropout;
    * graph pooling;
    * a linear head with ``num_tasks`` outputs.

    Depth, dropout rate, block type, hidden width and number of tasks are
    fixed constants in ``__init__``. The remaining hyper-parameters are read
    from ``args``.
    """

    def __init__(self, args):
        """Build the layers.

        Args:
            args: Namespace-like object. The attributes read here are
                ``conv``, ``gcn_aggr``, ``conv_encode_edge``, ``norm``,
                ``mlp_layers`` and ``graph_pooling``.

        Raises:
            NotImplementedError: If the block type is ``'dense'``.
            ValueError: If the block, conv or pooling type is unknown.
        """
        super().__init__()

        self.num_layers = 28
        self.dropout = 0.5
        self.block = "res+"

        hidden_channels = 128
        num_tasks = 1
        conv = args.conv
        aggr = args.gcn_aggr

        learn_msg_scale = False
        conv_encode_edge = args.conv_encode_edge
        norm = args.norm
        mlp_layers = args.mlp_layers
        graph_pooling = args.graph_pooling

        # Flags consulted by print_params(). GENConv is constructed below
        # with msg_norm=False and without any learn_t / learn_p arguments.
        # NOTE(review): if GENConv's own defaults make t/p learnable, update
        # these flags to match.
        self.learn_t = False
        self.learn_p = False
        self.msg_norm = False

        print('The number of layers {}'.format(self.num_layers),
              'Aggr aggregation method {}'.format(aggr),
              'block: {}'.format(self.block))
        # Only the 'res+' ordering is implemented in forward(); the other
        # names are recognised here for reporting purposes.
        if self.block == 'res+':
            print('LN/BN->ReLU->GraphConv->Res')
        elif self.block == 'res':
            print('GraphConv->LN/BN->ReLU->Res')
        elif self.block == 'dense':
            raise NotImplementedError('To be implemented')
        elif self.block == "plain":
            print('GraphConv->LN/BN->ReLU')
        else:
            raise ValueError('Unknown block Type')

        # Validate once instead of on every loop iteration.
        if conv != 'gen':
            raise ValueError('Unknown Conv Type')

        self.gcns = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        # Construction order (conv, norm per layer, then encoders and head)
        # is kept stable so parameter registration / init order is unchanged.
        for _ in range(self.num_layers):
            self.gcns.append(GENConv(hidden_channels,
                                     hidden_channels,
                                     aggr=aggr,
                                     msg_norm=False,
                                     learn_msg_scale=learn_msg_scale,
                                     encode_edge=conv_encode_edge,
                                     edge_feat_dim=hidden_channels,
                                     norm=norm, mlp_layers=mlp_layers))
            self.norms.append(norm_layer(norm, hidden_channels))

        # NOTE(review): input widths are hard-coded, assuming 1 node-feature
        # column and 7 edge-feature columns. Confirm against the dataset.
        self.node_features_encoder = torch.nn.Linear(1, hidden_channels)
        self.edge_encoder = torch.nn.Linear(7, hidden_channels)

        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        else:
            raise ValueError('Unknown Pool Type')

        self.graph_pred_linear = torch.nn.Linear(hidden_channels, num_tasks)

    def forward(self, input_batch, data_mask=None):
        """Run the network on a batch of graphs.

        Args:
            input_batch: Batch object. The attributes read are ``x``,
                ``edge_index``, ``edge_attr`` and ``batch``, where ``batch``
                is the per-node graph assignment passed to the pooling
                function.
            data_mask: Forwarded unchanged to every GENConv call. Its
                semantics are defined by GENConv.

        Returns:
            Output of ``graph_pred_linear`` applied to the pooled graph
            embeddings. Presumably shaped (num_graphs, num_tasks); confirm
            against the pooling implementation.
        """
        x = input_batch.x
        edge_index = input_batch.edge_index
        edge_attr = input_batch.edge_attr
        batch = input_batch.batch

        h = self.node_features_encoder(x)
        edge_emb = self.edge_encoder(edge_attr)

        # The first conv is applied directly to the encoded node features.
        h = self.gcns[0](h, edge_index, edge_emb, data_mask=data_mask)

        # Remaining layers use "res+" blocks:
        # norm -> ReLU -> dropout -> conv, then add the block input back.
        for layer in range(1, self.num_layers):
            h_in = F.relu(self.norms[layer - 1](h))
            h_in = F.dropout(h_in, p=self.dropout, training=self.training)
            h = self.gcns[layer](h_in, edge_index, edge_emb,
                                 data_mask=data_mask) + h

        # The last norm in the list is reserved for the output of the stack.
        h = self.norms[self.num_layers - 1](h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        h_graph = self.pool(h, batch)

        return self.graph_pred_linear(h_graph)

    def print_params(self, epoch=None, final=False):
        """Report per-layer learnable GENConv scalars, when enabled.

        The flags select which values are reported:

        * ``learn_t``: each conv's ``t``;
        * ``learn_p``: each conv's ``p``;
        * ``msg_norm``: each conv's ``msg_norm.msg_scale``.

        Args:
            epoch: Epoch number included in the message when not ``final``.
            final: If True, print to stdout; otherwise log at INFO level.
        """
        # getattr keeps this usable on instances that lack these flags,
        # e.g. models unpickled from an older version of this class.
        if getattr(self, 'learn_t', False):
            self._report_values(
                't', [gcn.t.item() for gcn in self.gcns], epoch, final)
        if getattr(self, 'learn_p', False):
            self._report_values(
                'p', [gcn.p.item() for gcn in self.gcns], epoch, final)
        if getattr(self, 'msg_norm', False):
            self._report_values(
                's', [gcn.msg_norm.msg_scale.item() for gcn in self.gcns],
                epoch, final)

    @staticmethod
    def _report_values(name, values, epoch, final):
        """Print (final) or log (per-epoch) one list of per-layer values."""
        if final:
            print('Final {} {}'.format(name, values))
        else:
            logging.info('Epoch {}, {} {}'.format(epoch, name, values))