import torch
import torch_geometric.graphgym.models.head  # noqa, register module
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP
from torch_geometric.graphgym.register import register_network

from graphgps.layer.gatedgcn_layer import GatedGCNLayer
from graphgps.layer.gine_conv_layer import GINEConvLayer
from graphgps.layer.gcn_conv_layer import GCNConvLayer


class LLayer(torch.nn.Module):
    """MLP that maps ``dim_inner * layers_mp``-wide node features to ``dim_inner``.

    The input width equals the JK head input width used by ``CustomGNN``
    (``cfg.gnn.dim_inner * cfg.gnn.layers_mp``). That is the width of one
    ``dim_inner``-wide slice per message-passing layer concatenated together.

    Args:
        cfg: GraphGym config; reads ``gnn.dim_inner``, ``gnn.layers_mp``,
            ``gnn.act`` and ``gnn.dropout``. (The name shadows the module-level
            ``cfg`` import; it is kept for interface compatibility.)
        num_hidden_layers: Number of extra ``Dropout -> Linear -> activation``
            stages after the input projection. Defaults to 0, which gives a
            single projection followed by the activation.
    """

    def __init__(self, cfg, num_hidden_layers: int = 0):
        super().__init__()
        dim = cfg.gnn.dim_inner
        act = register.act_dict[cfg.gnn.act]  # factory; called once per stage

        layers = [torch.nn.Linear(dim * cfg.gnn.layers_mp, dim, bias=True), act()]
        for _ in range(num_hidden_layers):
            layers += [
                torch.nn.Dropout(cfg.gnn.dropout),
                torch.nn.Linear(dim, dim, bias=True),
                act(),
            ]

        self.mlp = torch.nn.Sequential(*layers)

    def reset_parameters(self):
        """Re-initialise every submodule of ``mlp`` that defines ``reset_parameters``."""
        for module in self.mlp.modules():
            reset = getattr(module, 'reset_parameters', None)
            if callable(reset):
                reset()

    def forward(self, data):
        """Apply the MLP to ``data.x``, store the result back, and return ``data``."""
        data.x = self.mlp(data.x)
        return data

    def __repr__(self):
        return self.__class__.__name__


@register_network('custom_gnn')
class CustomGNN(torch.nn.Module):
    """
    GNN model that customizes the torch_geometric.graphgym.models.gnn.GNN
    to support specific handling of new conv layers.

    Child modules, in registration (and execution) order:
    ``encoder`` -> ``pre_mp`` (only if ``cfg.gnn.layers_pre_mp > 0``) ->
    ``gnn_layers`` -> ``post_mp``.

    If the batch carries a ``layer_values`` attribute when it reaches
    ``post_mp``, those tensors are concatenated along the last dimension and
    replace ``batch.x`` before the head runs. This is a jumping-knowledge
    style readout. ``layer_values`` is presumably populated by the conv
    layers; that is not visible here, so verify against the layer
    implementations.

    Args:
        dim_in: Input feature dimension passed to ``FeatureEncoder``.
        dim_out: Output dimension of the prediction head.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        conv_model = self.build_conv_model(cfg.gnn.layer_type)
        self.gnn_layers = torch.nn.Sequential(*[
            conv_model(dim_in,
                       dim_in,
                       dropout=cfg.gnn.dropout,
                       dropout_global=cfg.gnn.global_dropout,
                       pooling_layer=cfg.gnn.pooling_layer,
                       residual=cfg.gnn.residual,
                       add_layer_pooling=cfg.gnn.layer_pooling,
                       add_feedforward=cfg.gnn.feedforward,
                       add_norm_weighting=cfg.gnn.norm_weighting)
            for _ in range(cfg.gnn.layers_mp)
        ])

        GNNHead = register.head_dict[cfg.gnn.head]

        # With JK pooling the head receives one dim_inner-wide slice per
        # message-passing layer, concatenated.
        head_dim_in = cfg.gnn.dim_inner
        if cfg.gnn.layer_pooling == 'JK':
            head_dim_in *= cfg.gnn.layers_mp
        self.post_mp = GNNHead(dim_in=head_dim_in, dim_out=dim_out)

    def build_conv_model(self, model_type):
        """Return the conv layer class for ``model_type``.

        Raises:
            ValueError: if ``model_type`` is not a supported layer name.
        """
        conv_models = {
            'gatedgcnconv': GatedGCNLayer,
            'gineconv': GINEConvLayer,
            'gcnconv': GCNConvLayer,
        }
        try:
            return conv_models[model_type]
        except KeyError:
            raise ValueError("Model {} unavailable".format(model_type)) from None

    def forward(self, batch):
        """Run every child module in order and return the head's output.

        Just before the head (``post_mp``), if ``batch`` has ``layer_values``,
        they are concatenated on the last dim and written to ``batch.x``.
        """
        for module in self.children():
            # Identify the head by identity rather than by child index: its
            # position shifts depending on whether ``pre_mp`` was created.
            if module is self.post_mp and 'layer_values' in batch:
                batch.x = torch.cat(list(batch.layer_values), dim=-1)
            batch = module(batch)
        return batch
