import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from ..utils.diff_utils import drop_path
from ..ops.normal_layers_withbn import NaOp, ScOp, LaOp

class ReducedBNSANEBackBone(nn.Module):
    '''
        Fixed-architecture SANE backbone built from a derived genotype string.

        SANE can be seen as the combination of three searched components: node
        aggregators, skip connections, and a layer aggregator. Because the whole
        DAG is the search space, no cell abstraction is needed; this module
        implements that DAG directly.

        The genotype is a '||'-separated string of op names. As indexed below,
        entries [0, num_layers) name the node-aggregator ops (NaOp), the entries
        starting at index num_layers name the skip ops (ScOp), and the last entry
        names the layer aggregator (LaOp).
    '''
    def __init__(self, genotype, in_dim, out_dim, hidden_size, num_layers=3, in_dropout=0.5, out_dropout=0.5, act='relu', config=None):
        '''
        Args:
            genotype: '||'-separated op names (layout described in the class docstring).
            in_dim: input feature size, projected to ``hidden_size`` by ``lin1``.
            out_dim: output size of the final linear classifier.
            hidden_size: width of every hidden representation.
            num_layers: number of node-aggregator layers.
            in_dropout: dropout probability applied after the input projection
                and after every node-aggregator layer.
            out_dropout: dropout probability applied before the classifier.
            act: activation name forwarded to each NaOp.
            config: object exposing ``with_linear``, ``fix_last`` and
                ``with_layernorm``. Required despite the ``None`` default.

        Raises:
            ValueError: if ``config`` is None.
        '''
        super(ReducedBNSANEBackBone, self).__init__()
        if config is None:
            raise ValueError('ReducedBNSANEBackBone requires a config object '
                             '(with_linear, fix_last, with_layernorm); got None.')
        self.arch = genotype
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.in_dropout = in_dropout
        self.out_dropout = out_dropout
        # kept so that new() can rebuild an identically configured model
        self.act = act
        ops = genotype.split('||')
        self.config = config

        # node aggregator op
        self.lin1 = nn.Linear(in_dim, hidden_size)
        self.gnn_layers = nn.ModuleList(
                [NaOp(ops[i], hidden_size, hidden_size, act, with_linear=config.with_linear) for i in range(num_layers)])

        # skip op
        if self.config.fix_last:
            # With fix_last, forward() always passes the last layer's output to
            # the layer aggregator unchanged, so only num_layers - 1 skip ops
            # are used.
            if self.num_layers > 1:
                self.sc_layers = nn.ModuleList([ScOp(ops[i+num_layers]) for i in range(num_layers - 1)])
            else:
                # Built but never called by forward(): with a single layer,
                # layer 0 is the (fixed) last layer.
                self.sc_layers = nn.ModuleList([ScOp(ops[num_layers])])
        else:
            # no output conditions.
            skip_op = ops[num_layers:2 * num_layers]
            # If every skip op is 'none', force the last one to 'skip' so at
            # least one layer output reaches the layer aggregator.
            # (skip_op is a slice copy, so ops / self.arch are not modified.)
            if skip_op == ['none'] * num_layers:
                skip_op[-1] = 'skip'
                print('skip_op:', skip_op)
            self.sc_layers = nn.ModuleList([ScOp(skip_op[i]) for i in range(num_layers)])

        #layer aggregator op
        self.layer6 = LaOp(ops[-1], hidden_size, 'linear', num_layers)
        self.classifier = nn.Linear(hidden_size, out_dim)

    def new(self):
        '''
        Build a fresh backbone with the same genotype and hyper-parameters.

        The returned model has newly initialised weights; no weights are copied
        from ``self``. This backbone has no architecture parameters to copy,
        because its architecture is fixed by the genotype string. The new model
        is moved to the device of this model's parameters.

        Returns:
            ReducedBNSANEBackBone: the new model.
        '''
        model_new = ReducedBNSANEBackBone(
            self.arch, self.in_dim, self.out_dim, self.hidden_size,
            num_layers=self.num_layers, in_dropout=self.in_dropout,
            out_dropout=self.out_dropout, act=self.act, config=self.config)
        first_param = next(self.parameters(), None)
        if first_param is not None:
            model_new = model_new.to(first_param.device)
        return model_new

    def forward(self, data):
        '''
        Run the backbone on one graph input.

        Args:
            data: object exposing ``x`` (node features) and ``edge_index``.
                Presumably a torch_geometric ``Data``/``Batch``; TODO confirm.

        Returns:
            Output of ``self.classifier`` on the aggregated representation.
        '''
        x, edge_index = data.x, data.edge_index

        # project input features to the hidden width
        x = self.lin1(x)
        x = F.dropout(x, p=self.in_dropout, training=self.training)
        js = []
        last = self.num_layers - 1
        for i, gnn in enumerate(self.gnn_layers):
            x = gnn(x, edge_index)
            if self.config.with_layernorm:
                # normalized_shape is the full shape of x, so the statistics
                # are taken over all of x jointly, with no affine parameters.
                # NOTE(review): confirm this is intended rather than per-node
                # normalization over the feature dimension only.
                x = F.layer_norm(x, x.size())
            x = F.dropout(x, p=self.in_dropout, training=self.training)
            if i == last and self.config.fix_last:
                # the last layer's output is always kept
                js.append(x)
            else:
                js.append(self.sc_layers[i](x))
        x5 = self.layer6(js)
        x5 = F.dropout(x5, p=self.out_dropout, training=self.training)

        logits = self.classifier(x5)
        return logits

    def genotype(self):
        '''Return the genotype string this backbone was built from.'''
        return self.arch

def reduced_bn_sane_back_bone(**kwargs):
    '''
    Factory for :class:`ReducedBNSANEBackBone`.

    Every keyword argument is forwarded unchanged to the constructor.
    '''
    backbone = ReducedBNSANEBackBone(**kwargs)
    return backbone