import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from .ops import NaOp, ScOp, LaOp

# Public API: the genotype-driven BaselineModel plus one factory per
# homogeneous baseline architecture defined in this module.
__all__ = [
    'BaselineModel',
    'mlp_model',
    'gcn_model',
    'gin_model',
    'sage_model',
    'gat_model',
    'geniepath_model',
    'chebconv_model',
]
# NOTE(review): 'graph_sage' was noted here but is not exported and has no
# factory in this module.

class BaselineModel(nn.Module):
    '''
        Network assembled directly from a SANE genotype string.

        SANE can be seen as the combination of three kinds of ops: node
        aggregators, skip connections, and a layer aggregator. Because the
        DAG is the whole search space, no separate "cell" abstraction is
        needed; this class implements the DAG itself.

        `genotype` is a '||'-separated list of op names:
          * ops[0:num_layers]          -> one NaOp (node aggregator) per layer;
          * (jk only) following entries -> ScOp skip ops (see __init__);
          * (jk only) ops[-1]           -> LaOp layer aggregator.
    '''

    def __init__(self, genotype, jk, in_dim, out_dim, hidden_size, num_layers=3, in_dropout=0.5, out_dropout=0.5, act='relu', config=None):
        '''
            Args:
                genotype: '||'-separated op names (layout described on the class).
                jk: if true, every layer output goes through a skip op and the
                    results are merged by the layer aggregator before the
                    classifier; otherwise only the last layer output is classified.
                in_dim: input feature size, projected to hidden_size by lin1.
                out_dim: output size of the final classifier.
                hidden_size: width of every hidden layer.
                num_layers: number of node-aggregator layers.
                in_dropout: dropout prob applied after the input projection and
                    after every layer.
                out_dropout: dropout prob applied to the layer-aggregator output
                    (jk only).
                act: activation name forwarded to each NaOp.
                config: options object; `with_linear` and `with_layernorm` are
                    always read, `fix_last` only when jk is true.

            Raises:
                ValueError: if config is None.
        '''
        super(BaselineModel, self).__init__()
        if config is None:
            # config is dereferenced unconditionally below and in forward();
            # fail with a clear message instead of an AttributeError on None.
            raise ValueError('BaselineModel requires a config object '
                             '(with_linear, with_layernorm; fix_last when jk is set)')
        self.arch = genotype
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_size = hidden_size
        self.jk = jk
        self.num_layers = num_layers
        self.in_dropout = in_dropout
        self.out_dropout = out_dropout
        ops = genotype.split('||')
        self.config = config

        # node aggregator ops: one per layer, taken from ops[0:num_layers]
        self.lin1 = nn.Linear(in_dim, hidden_size)
        self.gnn_layers = nn.ModuleList(
                [NaOp(ops[i], hidden_size, hidden_size, act, with_linear=config.with_linear) for i in range(num_layers)])

        if self.jk:
            # skip ops
            if self.config.fix_last:
                # forward() feeds the last layer output straight to the layer
                # aggregator, so only num_layers - 1 skip ops are needed.
                if self.num_layers > 1:
                    self.sc_layers = nn.ModuleList([ScOp(ops[i+num_layers]) for i in range(num_layers - 1)])
                else:
                    # With a single layer this skip op is built but forward()
                    # never calls it (fix_last bypasses the only layer).
                    self.sc_layers = nn.ModuleList([ScOp(ops[num_layers])])
            else:
                # no output conditions: one skip op per layer.
                skip_op = ops[num_layers:2 * num_layers]
                if skip_op == ['none'] * num_layers:
                    # Force the last skip op to 'skip' when every entry is
                    # 'none' -- presumably so not all layer outputs are dropped
                    # before the layer aggregator (confirm ScOp semantics).
                    skip_op[-1] = 'skip'
                    print('skip_op:', skip_op)
                self.sc_layers = nn.ModuleList([ScOp(skip_op[i]) for i in range(num_layers)])

            # layer aggregator op: the final genotype entry
            self.layer6 = LaOp(ops[-1], hidden_size, 'linear', num_layers)
        self.classifier = nn.Linear(hidden_size, out_dim)

    def forward(self, data):
        '''
            Compute classifier logits.

            Args:
                data: object exposing `x` (features fed to lin1) and
                    `edge_index` (passed to every NaOp).

            Returns:
                Output of `classifier`, whose last dimension is out_dim.
        '''
        x, edge_index = data.x, data.edge_index

        x = self.lin1(x)
        x = F.dropout(x, p=self.in_dropout, training=self.training)
        js = []

        for i in range(self.num_layers):
            x = self.gnn_layers[i](x, edge_index)
            if self.config.with_layernorm:
                # Parameter-free layer norm over the *whole* tensor
                # (normalized_shape == x.size()). Same result as building
                # nn.LayerNorm(x.size(), elementwise_affine=False) here, but
                # without allocating a new Module every layer of every pass.
                # NOTE(review): this also normalizes across the first (row)
                # dimension; confirm that is intended vs. x.size()[1:].
                x = F.layer_norm(x, x.size())
            x = F.dropout(x, p=self.in_dropout, training=self.training)
            if self.jk:
                if i == self.num_layers - 1 and self.config.fix_last:
                    # last layer bypasses its skip op (see __init__)
                    js.append(x)
                else:
                    js.append(self.sc_layers[i](x))

        if self.jk:
            x5 = self.layer6(js)
            x5 = F.dropout(x5, p=self.out_dropout, training=self.training)
            logits = self.classifier(x5)
        else:
            logits = self.classifier(x)

        return logits

    def genotype(self):
        '''Return the genotype string this model was built from.'''
        return self.arch

def mlp_model(*, num_layers, **kwargs):
    """
    Constructs a baseline model whose every layer uses the ``'mlp'`` op.

    No skip / layer-aggregation path is used (``jk=False``).

    Args:
        num_layers: number of node-aggregator layers (required, keyword-only).
        **kwargs: forwarded unchanged to ``BaselineModel`` (e.g. ``in_dim``,
            ``out_dim``, ``hidden_size``, ``config``).

    Raises:
        TypeError: if ``num_layers`` is not supplied (raised by Python itself,
            so the check survives ``python -O`` unlike an ``assert``).
    """
    genotype = '||'.join(['mlp'] * num_layers)
    return BaselineModel(genotype, False, num_layers=num_layers, **kwargs)


def gcn_model(*, num_layers, **kwargs):
    """
    Constructs a baseline model whose every layer uses the ``'gcn'`` op.

    No skip / layer-aggregation path is used (``jk=False``).

    Args:
        num_layers: number of node-aggregator layers (required, keyword-only).
        **kwargs: forwarded unchanged to ``BaselineModel`` (e.g. ``in_dim``,
            ``out_dim``, ``hidden_size``, ``config``).

    Raises:
        TypeError: if ``num_layers`` is not supplied (raised by Python itself,
            so the check survives ``python -O`` unlike an ``assert``).
    """
    genotype = '||'.join(['gcn'] * num_layers)
    return BaselineModel(genotype, False, num_layers=num_layers, **kwargs)


def gin_model(*, num_layers, **kwargs):
    """
    Constructs a baseline model whose every layer uses the ``'gin'`` op.

    No skip / layer-aggregation path is used (``jk=False``).

    Args:
        num_layers: number of node-aggregator layers (required, keyword-only).
        **kwargs: forwarded unchanged to ``BaselineModel`` (e.g. ``in_dim``,
            ``out_dim``, ``hidden_size``, ``config``).

    Raises:
        TypeError: if ``num_layers`` is not supplied (raised by Python itself,
            so the check survives ``python -O`` unlike an ``assert``).
    """
    genotype = '||'.join(['gin'] * num_layers)
    return BaselineModel(genotype, False, num_layers=num_layers, **kwargs)


def sage_model(*, num_layers, **kwargs):
    """
    Constructs a baseline model whose every layer uses the ``'sage'`` op.

    No skip / layer-aggregation path is used (``jk=False``).

    Args:
        num_layers: number of node-aggregator layers (required, keyword-only).
        **kwargs: forwarded unchanged to ``BaselineModel`` (e.g. ``in_dim``,
            ``out_dim``, ``hidden_size``, ``config``).

    Raises:
        TypeError: if ``num_layers`` is not supplied (raised by Python itself,
            so the check survives ``python -O`` unlike an ``assert``).
    """
    genotype = '||'.join(['sage'] * num_layers)
    return BaselineModel(genotype, False, num_layers=num_layers, **kwargs)


def gat_model(*, num_layers, **kwargs):
    """
    Constructs a baseline model whose every layer uses the ``'gat'`` op.

    No skip / layer-aggregation path is used (``jk=False``).

    Args:
        num_layers: number of node-aggregator layers (required, keyword-only).
        **kwargs: forwarded unchanged to ``BaselineModel`` (e.g. ``in_dim``,
            ``out_dim``, ``hidden_size``, ``config``).

    Raises:
        TypeError: if ``num_layers`` is not supplied (raised by Python itself,
            so the check survives ``python -O`` unlike an ``assert``).
    """
    genotype = '||'.join(['gat'] * num_layers)
    return BaselineModel(genotype, False, num_layers=num_layers, **kwargs)


def geniepath_model(*, num_layers, **kwargs):
    """
    Constructs a baseline model whose every layer uses the ``'geniepath'`` op.

    No skip / layer-aggregation path is used (``jk=False``).

    Args:
        num_layers: number of node-aggregator layers (required, keyword-only).
        **kwargs: forwarded unchanged to ``BaselineModel`` (e.g. ``in_dim``,
            ``out_dim``, ``hidden_size``, ``config``).

    Raises:
        TypeError: if ``num_layers`` is not supplied (raised by Python itself,
            so the check survives ``python -O`` unlike an ``assert``).
    """
    genotype = '||'.join(['geniepath'] * num_layers)
    return BaselineModel(genotype, False, num_layers=num_layers, **kwargs)


def chebconv_model(*, num_layers, **kwargs):
    """
    Constructs a baseline model whose every layer uses the ``'chebconv'`` op.

    No skip / layer-aggregation path is used (``jk=False``).

    Args:
        num_layers: number of node-aggregator layers (required, keyword-only).
        **kwargs: forwarded unchanged to ``BaselineModel`` (e.g. ``in_dim``,
            ``out_dim``, ``hidden_size``, ``config``).

    Raises:
        TypeError: if ``num_layers`` is not supplied (raised by Python itself,
            so the check survives ``python -O`` unlike an ``assert``).
    """
    genotype = '||'.join(['chebconv'] * num_layers)
    return BaselineModel(genotype, False, num_layers=num_layers, **kwargs)

