import torch
import torch.nn.functional as F

from utils.config import cfg
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree
from torch_geometric.nn.conv import GCNConv, GINConv, GINEConv
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

def get_layer(input_dim, emb_dim):
    """Build one message-passing layer according to ``cfg.gnn.layer_type``.

    Args:
        input_dim: Width of the node features fed into the layer.
        emb_dim: Width of the node features produced by the layer.

    Returns:
        A GIN/GINE layer (via ``get_gin_layer``) for ``'GIN'`` / ``'GINE'``;
        for ``'GCN'`` an ``OGBGCNConv`` when ``cfg.gnn.use_edge_attr`` is set,
        otherwise a plain ``GCNConv``.

    Raises:
        ValueError: If ``cfg.gnn.layer_type`` is none of the supported types.

    Note:
        ``OGBGCNConv`` is constructed from ``emb_dim`` only, so ``input_dim``
        is ignored on that branch.
    """
    layer_type = cfg.gnn.layer_type

    if layer_type in ('GIN', 'GINE'):
        return get_gin_layer(input_dim, emb_dim)

    if layer_type != 'GCN':
        raise ValueError('Invalid layer type')

    if cfg.gnn.use_edge_attr:
        return OGBGCNConv(emb_dim)
    return GCNConv(input_dim, emb_dim)

def get_gin_layer(input_dim, emb_dim):
    """Build a GIN or GINE convolution around a two-layer MLP.

    The MLP is ``Linear -> BatchNorm1d -> ReLU -> Linear``. Its hidden width is
    ``2 * emb_dim`` when ``cfg.dataset.format == 'OGB'`` and ``emb_dim``
    otherwise.

    Side effect: ``cfg.gnn.use_edge_attr`` is overwritten with ``True`` when a
    ``GINEConv`` is returned and with ``False`` when a ``GINConv`` is returned.

    Args:
        input_dim: Width of the node features fed into the MLP.
        emb_dim: Width of the node features produced by the MLP.

    Returns:
        ``GINEConv(mlp)`` if the layer type is ``'GINE'`` or edge attributes
        are enabled in the config, otherwise ``GINConv(mlp)``.
    """
    hidden_dim = 2 * emb_dim if cfg.dataset.format == 'OGB' else emb_dim

    mlp = torch.nn.Sequential(
        torch.nn.Linear(input_dim, hidden_dim),
        torch.nn.BatchNorm1d(hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_dim, emb_dim),
    )

    wants_edge_attr = cfg.gnn.layer_type == 'GINE' or cfg.gnn.use_edge_attr
    conv_cls = GINEConv if wants_edge_attr else GINConv

    # Keep the global flag in sync with the layer that is actually built.
    # NOTE(review): GINEConv is created without an explicit edge feature size;
    # confirm the edge attributes passed at call time have the width it expects.
    cfg.gnn.use_edge_attr = bool(wants_edge_attr)
    return conv_cls(mlp)

class GNN_node(torch.nn.Module):
    """Stack of message-passing layers producing node representations.

    All hyper-parameters are read from the global ``cfg``:

    * ``cfg.gnn.num_layers`` / ``dropout`` / ``emb_dim`` size the stack.
    * ``cfg.gnn.use_atom_encoder`` decides whether node features go through
      ``AtomEncoder`` first; otherwise the first layer consumes
      ``cfg.gnn.input_dim`` raw features.
    * ``cfg.gnn.use_edge_attr`` decides whether edge attributes are passed to
      the convolutions.
    * ``cfg.expander.type`` (when not ``None``) makes odd-indexed layers run on
      ``batched_data.expander_edge_index`` instead of the input graph.
    """

    def __init__(self):
        super().__init__()

        self.num_layers = cfg.gnn.num_layers
        self.dropout = cfg.gnn.dropout
        self.emb_dim = cfg.gnn.emb_dim

        # Both encoders are always registered, whether or not the current
        # configuration ends up calling them in forward().
        self.atom_encoder = AtomEncoder(self.emb_dim)
        self.bond_encoder = BondEncoder(self.emb_dim)

        # One convolution and one batch norm per layer.
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(self.num_layers):
            # Only the first layer can see raw (un-encoded) node features.
            if cfg.gnn.use_atom_encoder or layer > 0:
                input_dim = self.emb_dim
            else:
                input_dim = cfg.gnn.input_dim

            self.convs.append(get_layer(input_dim, self.emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(self.emb_dim))

    def forward(self, batched_data):
        """Compute node representations for a batch.

        Args:
            batched_data: Batch object providing ``x``, ``edge_index`` and
                ``batch``; additionally ``edge_attr`` when edge attributes are
                enabled, ``expander_edge_index`` / ``expander_edge_attr`` when
                an expander is configured, and ``virtual_node_mask`` for the
                ``'CGP'`` expander.

        Returns:
            Tuple ``(node_representation, batch)``. For ``'CGP'`` with
            ``cfg.expander.truncate_batch`` both are restricted to the nodes
            where ``virtual_node_mask`` is false.
        """
        x, edge_index, batch = batched_data.x, batched_data.edge_index, batched_data.batch

        use_edge_attr = cfg.gnn.use_edge_attr
        use_atom_encoder = cfg.gnn.use_atom_encoder
        has_expander = cfg.expander.type is not None
        is_cgp = cfg.expander.type == 'CGP'
        # The edge-aware GCN layer encodes bonds itself, so it gets raw attributes.
        no_encode_edge_attr = cfg.gnn.layer_type == 'GCN' and use_edge_attr

        edge_attr = None
        expander_edge_attr = None

        if use_edge_attr:
            if no_encode_edge_attr:
                edge_attr = batched_data.edge_attr
            else:
                edge_attr = self.bond_encoder(batched_data.edge_attr)

            if is_cgp and cfg.expander.zero_edge_embeddings:
                # Pad with zero rows for the edges in edge_index that have no
                # attribute row. Allocated directly on the target device.
                # NOTE(review): the padding width is emb_dim even when
                # edge_attr is passed through unencoded; torch.cat needs
                # matching widths - confirm this combination is supported.
                num_missing = edge_index.shape[1] - edge_attr.shape[0]
                virtual_edge_attr = torch.zeros(
                    (num_missing, self.emb_dim),
                    dtype=edge_attr.dtype,
                    device=x.device,
                )
                edge_attr = torch.cat((edge_attr, virtual_edge_attr), dim=0)

        if is_cgp:
            virtual_node_mask = batched_data.virtual_node_mask

        if use_edge_attr and has_expander:
            if cfg.expander.zero_edge_embeddings:
                expander_edge_attr = torch.zeros(
                    (batched_data.expander_edge_attr.shape[0], self.emb_dim),
                    dtype=edge_attr.dtype,
                    device=x.device,
                )
            elif no_encode_edge_attr:
                expander_edge_attr = batched_data.expander_edge_attr
            else:
                expander_edge_attr = self.bond_encoder(batched_data.expander_edge_attr)

        if is_cgp and cfg.expander.zero_node_embeddings:
            # Virtual nodes start from zeros; only the remaining nodes receive
            # (optionally encoded) input features.
            real_nodes = ~virtual_node_mask
            feature_dim = self.emb_dim if use_atom_encoder else x.shape[1]
            h = torch.zeros((x.shape[0], feature_dim), device=x.device)
            if use_atom_encoder:
                h[real_nodes] = self.atom_encoder(x[real_nodes])
            else:
                h[real_nodes] = x[real_nodes]
        else:
            h = self.atom_encoder(x) if use_atom_encoder else x

        last_layer = self.num_layers - 1
        for layer in range(self.num_layers):
            # Alternate between the input graph (even layers) and the expander
            # graph (odd layers) when an expander is configured.
            is_expander_layer = has_expander and layer % 2 == 1
            conv = self.convs[layer]

            if is_expander_layer:
                layer_edge_index = batched_data.expander_edge_index
                layer_edge_attr = expander_edge_attr
            else:
                layer_edge_index = edge_index
                layer_edge_attr = edge_attr

            if use_edge_attr:
                h = conv(h, layer_edge_index, layer_edge_attr)
            else:
                h = conv(h, layer_edge_index)

            h = self.batch_norms[layer](h)

            # No ReLU after the last layer.
            if layer != last_layer:
                h = F.relu(h)
            h = F.dropout(h, self.dropout, training=self.training)

        node_representation = h

        # Remove virtual nodes for downstream tasks.
        if is_cgp and cfg.expander.truncate_batch:
            real_nodes = ~virtual_node_mask
            node_representation = node_representation[real_nodes]
            batch = batch[real_nodes]

        return node_representation, batch

class OGBGCNConv(MessagePassing):
    """GCN-style convolution that mixes encoded edge attributes into messages.

    Node features go through a linear map; each message is
    ``norm * relu(x_j + bond_embedding)`` and messages are summed
    (``aggr='add'``). A separate root term, ``relu(x + root_emb) / deg``, is
    added to the aggregated result.

    Args:
        emb_dim: Width of both the input and the output node features, and of
            the bond embeddings.
    """

    def __init__(self, emb_dim):
        super().__init__(aggr='add')

        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        """Run one convolution step.

        Args:
            x: Node features; transformed by ``self.linear``.
            edge_index: Pair of index tensors describing the edges.
            edge_attr: Raw edge attributes; encoded here by ``self.bond_encoder``.

        Returns:
            Aggregated neighbour messages plus the degree-scaled root term.
        """
        node_feats = self.linear(x)
        bond_feats = self.bond_encoder(edge_attr)

        first, second = edge_index

        # Count of each node in the first row of edge_index, plus one
        # (presumably for the node itself - see the root term below).
        node_deg = degree(first, node_feats.size(0), dtype=node_feats.dtype) + 1
        inv_sqrt_deg = node_deg.pow(-0.5)
        inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0

        # Symmetric normalisation coefficient per edge.
        edge_norm = inv_sqrt_deg[first] * inv_sqrt_deg[second]

        neighbour_term = self.propagate(
            edge_index, x=node_feats, edge_attr=bond_feats, norm=edge_norm
        )
        root_term = F.relu(node_feats + self.root_emb.weight) / node_deg.view(-1, 1)
        return neighbour_term + root_term

    def message(self, x_j, edge_attr, norm):
        """Return the normalised, ReLU-activated sum of neighbour and edge features."""
        activated = F.relu(x_j + edge_attr)
        return norm.view(-1, 1) * activated

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out

# Running this module directly does nothing; it only provides definitions.
if __name__ == "__main__":
    pass
