import torch
import torch.nn as nn

from graph_learning.module import ModuleConfig

class MultiLayerMPConfig(ModuleConfig):
    """Shared configuration for multi-layer message-passing modules.

    Registers the size / depth / regularisation command-line options and
    prepares the activation, dropout and batch-norm pieces as instance
    attributes.
    """

    def __init__(self, args, context):
        super().__init__(args, context)

        # NOTE(review): `self.dropout` and `self.hidden_size` are not assigned
        # in this class -- presumably the base class fills them in from the
        # parsed arguments; confirm against ModuleConfig.
        def build_bn():
            # `hidden_size` is looked up when the builder is called, not now.
            return nn.BatchNorm1d(self.hidden_size, track_running_stats=False)

        self.activation_layer = nn.ReLU()
        self.dropout_layer = nn.Dropout(p=self.dropout)
        self.bn_layer_builder = build_bn

    @classmethod
    def define_parser(cls, parser):
        """Add this module's options to ``parser`` after the parent's own."""
        super().define_parser(parser)

        # Integer-valued size/depth options.
        for int_flag in ('--in-size', '--hidden-size', '--out-size', '--num-layers'):
            parser.add_argument(int_flag, type=int)

        parser.add_argument('--dropout', type=float)

        # Boolean switches.
        switches = (
            ('--bn', 'use batchnorm'),
            ('--raw', 'return node representations only'),
            ('--return-hiddens', 'return outputs of all layers'),
        )
        for switch, description in switches:
            parser.add_argument(switch, action='store_true', help=description)

class CommonMultiLayerMPConfig(MultiLayerMPConfig):
    """Config that derives input / hidden / output layer builders from sizes.

    Subclasses implement ``_layer_builder`` (and may override the three
    per-role hooks) to say how a single layer is created.
    """

    def __init__(self, args, context):
        super().__init__(args, context)

        # Hidden layers exist only when a hidden size is given and we are not
        # in the "single layer straight to out_size" situation.
        has_hidden = (self.hidden_size is not None
                      and (self.out_size is None
                           or self.num_layers != 1))
        has_output = self.out_size is not None

        # Each closure reads the sizes from `self` when it is called, so the
        # values in effect at build time are the ones used.
        def build_input():
            return self._input_layer_builder(
                in_size=self.in_size,
                out_size=self.hidden_size)

        def build_hidden():
            return self._hidden_layer_builder(
                in_size=self.hidden_size,
                out_size=self.hidden_size)

        def build_output():
            # Without hidden layers the output layer consumes the raw input.
            source_size = self.hidden_size if has_hidden else self.in_size
            return self._output_layer_builder(
                in_size=source_size,
                out_size=self.out_size)

        self.input_layer_builder = build_input if has_hidden else None
        self.hidden_layer_builder = build_hidden if has_hidden else None
        self.output_layer_builder = build_output if has_output else None

    def _layer_builder(self, in_size, out_size):
        """Create one layer mapping ``in_size`` to ``out_size`` (abstract)."""
        raise NotImplementedError

    def _input_layer_builder(self, in_size, out_size):
        """Create the first layer; delegates to ``_layer_builder``."""
        return self._layer_builder(in_size=in_size, out_size=out_size)

    def _hidden_layer_builder(self, in_size, out_size):
        """Create an intermediate layer; delegates to ``_layer_builder``."""
        return self._layer_builder(in_size=in_size, out_size=out_size)

    def _output_layer_builder(self, in_size, out_size):
        """Create the final layer; delegates to ``_layer_builder``."""
        return self._layer_builder(in_size=in_size, out_size=out_size)

    @property
    def builder(self):
        """The module class associated with this config."""
        return MultiLayerMP

class MultiLayerMP(nn.Module):
    """A stack of message-passing layers, each invoked as ``layer(g, h)``.

    Layer layout, decided by which builders are supplied:

    * ``input_layer_builder is None``: a single layer from
      ``output_layer_builder``.
    * ``output_layer_builder is None``: one input layer followed by
      ``num_layers - 1`` hidden layers; there is no dedicated output layer.
    * otherwise: one input layer, ``num_layers - 2`` hidden layers and one
      output layer.  Because a list multiplied by a non-positive count is
      empty, ``num_layers < 2`` still yields two layers in this case.

    Every layer except a dedicated output layer is followed by optional
    batch norm, the activation and dropout.

    Args:
        num_layers: requested number of layers (see layout above).
        input_layer_builder: zero-argument callable creating the first layer,
            or ``None``.
        hidden_layer_builder: zero-argument callable creating a hidden layer;
            required whenever the layout contains hidden layers.
        output_layer_builder: zero-argument callable creating the output
            layer, or ``None``.
        dropout_layer: module applied after the activation; the same instance
            is reused after every layer.
        activation_layer: module applied after each non-output layer; the
            same instance is reused.
        bn_layer_builder: zero-argument callable creating a batch-norm
            module; only called when ``bn`` is true.
        bn: whether to apply batch norm after each non-output layer.
        raw: if true, ``forward`` returns the result directly instead of
            wrapping it in a dict.
        return_hiddens: if true, ``forward`` returns the outputs of all
            layers together with a mask.
        name: stored as ``self.name``.

    Raises:
        ValueError: if a builder required by the layout is ``None``.
    """

    def __init__(self, num_layers,
                 input_layer_builder,
                 hidden_layer_builder,
                 output_layer_builder,
                 dropout_layer,
                 activation_layer,
                 bn_layer_builder,
                 bn,
                 raw,
                 return_hiddens,
                 name):
        super().__init__()
        # flags
        self._return_hiddens = False
        if return_hiddens:
            # The argument shadows the method only as a local name; this
            # calls the method below.
            self.return_hiddens()
        self.use_bn = bn
        self.is_raw = raw
        self.name = name

        # Decide which builders make up the stack.
        if input_layer_builder is None:
            layer_builders = [output_layer_builder]
            self.use_output_layer = True
        elif output_layer_builder is None:
            layer_builders = ([input_layer_builder]
                              + [hidden_layer_builder] * (num_layers - 1))
            self.use_output_layer = False
        else:
            layer_builders = ([input_layer_builder]
                              + [hidden_layer_builder] * (num_layers - 2)
                              + [output_layer_builder])
            self.use_output_layer = True

        # Fail with a clear message instead of "'NoneType' object is not
        # callable" when the configuration left a needed builder unset.
        if any(builder is None for builder in layer_builders):
            raise ValueError(
                'MultiLayerMP cannot be built: a required layer builder is '
                'None (input_layer_builder and output_layer_builder must not '
                'both be None, and hidden_layer_builder is required when '
                'hidden layers are requested)')

        # layers
        self.mp_layers = nn.ModuleList()
        self.bn_layers = nn.ModuleList()
        last = len(layer_builders) - 1
        for i, layer_builder in enumerate(layer_builders):
            self.mp_layers.append(layer_builder())
            # Batch norm follows every layer except a dedicated output layer,
            # so bn_layers[i] always pairs with mp_layers[i].
            if self.use_bn and (i != last or not self.use_output_layer):
                self.bn_layers.append(bn_layer_builder())
        self.dropout = dropout_layer
        self.activation = activation_layer

    def return_hiddens(self):
        """Make ``forward`` return the outputs of all layers."""
        self._return_hiddens = True

    def forward(self, data, x):
        """Run ``x`` through the layer stack.

        Args:
            data: object whose ``graph()`` returns something with a
                ``local_var()`` method; the result is passed to every layer.
                (Presumably a DGL graph wrapper -- confirm against callers.)
            x: input features; the first dimension is used as the row count
                of the mask when hidden states are returned.

        Returns:
            ``h`` (the last layer's output) or, when returning hiddens, the
            tuple ``(hs, mask)`` where ``hs`` is the list of per-layer outputs
            and ``mask`` is an all-True bool tensor with one column per layer.
            Unless ``raw`` was set, that value is wrapped as
            ``{'hidden': value, 'outputs': {}}``.
        """
        g = data.graph().local_var()
        last = len(self.mp_layers) - 1
        h = x
        hs = []
        for i, layer in enumerate(self.mp_layers):
            h = layer(g, h)
            if i != last or not self.use_output_layer:
                if self.use_bn:
                    h = self.bn_layers[i](h)
                h = self.activation(h)
                h = self.dropout(h)
            hs.append(h)

        if self._return_hiddens:
            mask = torch.stack(
                [torch.ones(hidden.size(0), dtype=torch.bool, device=hidden.device)
                 for hidden in hs], 1)
            ret = (hs, mask)
        else:
            ret = h

        if self.is_raw:
            return ret
        return {'hidden': ret,
                'outputs': {}}
