import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.models import GCN, GraphSAGE, GIN, GAT

from .mlp import MLP

# Public export list for ``from <module> import *``: only the wrapper class is
# exported. The ``*_model`` factory functions defined below are not listed here,
# so they must be imported by explicit name.
# Earlier export list, kept for reference:
# __all__ = ['mlp_model', 'gcn_model', 'gin_model', 'sage_model', 'gat_model']
__all__ = ['WrapperModel']

class WrapperModel(nn.Module):
    """Adapter that lets a model be called with a single graph ``data`` object.

    The wrapped module is registered as the ``model`` submodule. ``forward``
    unpacks the fields the wrapped module needs from ``data``:

    * ``is_mlp=True``  -> ``model(data.x)``
    * ``is_mlp=False`` -> ``model(data.x, data.edge_index)``

    Args:
        model: Module to wrap. Its output is returned unchanged.
        is_mlp: When true, ``model`` receives node features only (no graph
            structure). Otherwise it also receives ``data.edge_index``.
    """

    def __init__(self, model, is_mlp=False):
        super().__init__()
        self.model = model
        self.is_mlp = is_mlp

    def forward(self, data):
        """Feed ``data`` to the wrapped model and return its output as-is.

        ``data`` only needs an ``x`` attribute, plus ``edge_index`` when the
        wrapper was built with ``is_mlp=False``.
        """
        features = data.x
        if self.is_mlp:
            return self.model(features)
        return self.model(features, data.edge_index)

def gcn_model(**kwargs):
    """Build a PyG ``GCN`` and wrap it for ``data``-object calls.

    All keyword arguments are forwarded unchanged to ``GCN``.

    Returns:
        WrapperModel: wraps the GCN with the default ``is_mlp=False``, so it
        is invoked as ``model(data.x, data.edge_index)``.
    """
    return WrapperModel(GCN(**kwargs))


def gin_model(**kwargs):
    """Build a PyG ``GIN`` and wrap it for ``data``-object calls.

    All keyword arguments are forwarded unchanged to ``GIN``.

    Returns:
        WrapperModel: wraps the GIN with the default ``is_mlp=False``, so it
        is invoked as ``model(data.x, data.edge_index)``.
    """
    return WrapperModel(GIN(**kwargs))


def sage_model(**kwargs):
    """Build a PyG ``GraphSAGE`` and wrap it for ``data``-object calls.

    All keyword arguments are forwarded unchanged to ``GraphSAGE``.

    Returns:
        WrapperModel: wraps the GraphSAGE model with the default
        ``is_mlp=False``, so it is invoked as
        ``model(data.x, data.edge_index)``.
    """
    return WrapperModel(GraphSAGE(**kwargs))


def gat_model(**kwargs):
    """Build a PyG ``GAT`` and wrap it for ``data``-object calls.

    All keyword arguments are forwarded unchanged to ``GAT``.

    Returns:
        WrapperModel: wraps the GAT with the default ``is_mlp=False``, so it
        is invoked as ``model(data.x, data.edge_index)``.
    """
    return WrapperModel(GAT(**kwargs))

def mlp_model(**kwargs):
    """Build the local ``MLP`` and wrap it for ``data``-object calls.

    All keyword arguments are forwarded unchanged to ``MLP`` (imported from
    the sibling ``.mlp`` module).

    Returns:
        WrapperModel: wraps the MLP with ``is_mlp=True``, so it is invoked as
        ``model(data.x)``. Graph structure (``edge_index``) is not passed.
    """
    return WrapperModel(MLP(**kwargs), is_mlp=True)

# def geniepath_model(**kwargs):
#     """
#     Constructs a geniepath model.
#     """
# 
# 
#     model = BaselineModel('||'.join(['geniepath'] * num_layers), False, **kwargs)
#     return model


# def chebconv_model(**kwargs):
#     """
#     Constructs a chebconv model.
#     """
# 
# 
#     model = BaselineModel('||'.join(['chebconv'] * num_layers), False, **kwargs)
#     return model

