import torch.nn as nn

from graph_learning.module import Module, ModuleConfig, get_module
from graph_learning.config import multi_getattr, config_dispatch

@ModuleConfig.register('mlp',
                       help='[Classifier] MLP')
class LinearModuleConfig(ModuleConfig):
    """Command-line config for the ``mlp`` classifier; builds :class:`MLP`.

    The option names declared in ``define_parser`` match the constructor
    arguments of ``MLP`` (``in_size``, ``hidden_size``, ``num_hidden_layers``,
    ``dropout``, ``out_size``, ``bn``).
    """

    def __init__(self, args, context):
        super().__init__(args, context)
        # NOTE(review): fetch_or_set is inherited from ModuleConfig; presumably
        # it reads the input width from / records it into `context` — verify.
        resolved_in_size = self.fetch_or_set(self.in_size, context, int)
        self.in_size = resolved_in_size

    @classmethod
    def define_parser(cls, parser):
        """Register the MLP options on ``parser`` after the base options."""
        super().define_parser(parser)
        add = parser.add_argument
        # Deliberately untyped: __init__ converts it via fetch_or_set(..., int).
        add('--in-size')
        add('--hidden-size', type=int)
        add('--num-hidden-layers', type=int)
        add('--dropout', type=float)
        add('--out-size', type=int)
        add('--bn', action='store_true')

    @property
    def builder(self):
        """The module class this config constructs."""
        return MLP

class MLP(Module):
    """Stack of fully connected layers with ReLU, optional BatchNorm and dropout.

    Each hidden layer is ``Linear -> [BatchNorm1d] -> ReLU -> Dropout``. When
    ``out_size`` is given, the final layer is a bare ``Linear`` projection;
    when it is ``None``, the final layer is one more hidden layer of width
    ``hidden_size``.

    Args:
        in_size: width of the input features.
        hidden_size: width of every hidden layer.
        out_size: width of the output projection, or ``None`` (see above).
        num_hidden_layers: number of hidden ``Linear`` layers before the final
            layer (``fc_layers`` holds ``num_hidden_layers + 1`` layers).
        dropout: probability for the ``nn.Dropout`` applied after each ReLU.
        bn: if truthy, a ``BatchNorm1d`` is applied before each ReLU.
    """

    def __init__(self, in_size,
                 hidden_size,
                 out_size,
                 num_hidden_layers,
                 dropout,
                 bn):
        super().__init__()

        # Without an explicit output width, the final Linear is treated as one
        # more hidden layer (normalised, activated, dropped out).
        self.use_output_layer = out_size is not None
        if not self.use_output_layer:
            out_size = hidden_size

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.use_bn = bn

        self.fc_layers = nn.ModuleList()
        if self.use_bn:
            self.bn_layers = nn.ModuleList()

        widths = [in_size, *([hidden_size] * num_hidden_layers), out_size]
        for depth in range(num_hidden_layers + 1):
            fan_in, fan_out = widths[depth], widths[depth + 1]
            self.fc_layers.append(nn.Linear(fan_in, fan_out))
            # bn_layers[k] pairs with fc_layers[k]; an output projection gets none.
            if self.use_bn and self._is_hidden(depth):
                self.bn_layers.append(nn.BatchNorm1d(hidden_size))

        self.dropout = nn.Dropout(p=dropout)
        self.activate = nn.ReLU()

    def _is_hidden(self, depth):
        """Return True if ``fc_layers[depth]`` is followed by BN/ReLU/dropout."""
        return depth != self.num_hidden_layers or not self.use_output_layer

    def forward(self, xs):
        """Run the MLP on ``xs``.

        If ``xs`` is a tuple, only its first element goes through the network;
        the remaining elements are returned unchanged after the result.
        """
        extras = None
        if isinstance(xs, tuple):
            xs, *extras = xs

        hidden = xs
        for depth, fc in enumerate(self.fc_layers):
            hidden = fc(hidden)
            if not self._is_hidden(depth):
                continue
            if self.use_bn:
                hidden = self.bn_layers[depth](hidden)
            hidden = self.dropout(self.activate(hidden))

        return hidden if extras is None else (hidden, *extras)

@ModuleConfig.register('identity',
                       help='[Classifier] identity layer.')
class IdentityModuleConfig(ModuleConfig):
    """Config for a pass-through classifier; builds ``nn.Identity``.

    Declares no options beyond those registered by ``ModuleConfig``.
    """

    def __init__(self, args, context):
        """Initialise purely through the base ``ModuleConfig``."""
        super().__init__(args, context)

    @classmethod
    def define_parser(cls, parser):
        """Register only the base ``ModuleConfig`` options."""
        super().define_parser(parser)

    @property
    def builder(self):
        """The module class this config constructs (``nn.Identity``)."""
        return nn.Identity

@ModuleConfig.register('multi-classifiers',
                       help='[Classifier] Multiple classifiers.')
class MultiClassifiersModuleConfig(ModuleConfig):
    """Config for :class:`MultiClassifiers`, built from named sub-classifiers."""

    def __init__(self, args, context):
        super().__init__(args, context)
        # Replace every sub-classifier name with whatever get_module returns
        # for it. NOTE(review): assumes --classifiers was supplied (a list);
        # presumably get_module resolves a registered module — verify.
        resolved = []
        for name in self.classifiers:
            resolved.append(get_module(context, name))
        self.classifiers = resolved

    @classmethod
    def define_parser(cls, parser):
        """Register the base options plus ``--classifiers`` and ``--freeze``."""
        super().define_parser(parser)
        parser.add_argument('--classifiers', nargs='+', help='sub-classifiers')
        parser.add_argument('--freeze', action='store_true', help='freeze sub-classifiers')

    @property
    def builder(self):
        """The module class this config constructs."""
        return MultiClassifiers

class MultiClassifiers(nn.Module):
    """Apply a list of sub-classifiers pairwise to a list of hidden inputs.

    ``forward(hiddens)`` returns ``[classifiers[i](hiddens[i]) for each i]``.
    If ``hiddens`` is a tuple, its first element is the list of inputs and the
    remaining elements are passed through unchanged after the result.

    Args:
        classifiers: iterable of ``nn.Module`` sub-classifiers.
        freeze: if truthy, the sub-classifiers are frozen. Their parameters get
            ``requires_grad=False``, and they are kept in eval mode even while
            this module is in training mode. As a result, BatchNorm running
            statistics inside them are not updated and their dropout is off.
    """

    def __init__(self, classifiers, freeze):
        super().__init__()
        self.classifiers = nn.ModuleList(classifiers)
        self.freeze = bool(freeze)
        if self.freeze:
            for p in self.classifiers.parameters():
                p.requires_grad = False
            # requires_grad=False alone does not stop BatchNorm layers from
            # updating their running statistics in training mode, so the
            # frozen sub-classifiers are also pinned to eval mode.
            self.classifiers.eval()

    def train(self, mode=True):
        """Set training mode, keeping frozen sub-classifiers in eval mode."""
        super().train(mode)
        if self.freeze:
            self.classifiers.eval()
        return self

    def forward(self, hiddens):
        outs = None
        if isinstance(hiddens, tuple):
            hiddens, *outs = hiddens
        # More inputs than classifiers raises IndexError from the ModuleList.
        ret = [self.classifiers[i](hidden)
               for i, hidden in enumerate(hiddens)]

        if outs is None:
            return ret
        else:
            return (ret, *outs)

