import torch.nn as nn
from pdb import set_trace

from collections import OrderedDict
from torchmeta.modules import (MetaModule, MetaConv2d, MetaBatchNorm2d,
                               MetaSequential, MetaLinear)
from torchmeta.modules.utils import get_subdict

def conv_block(in_channels, out_channels, **kwargs):
    """Build one convolutional stage: conv -> batch norm -> ReLU -> 2x2 max-pool.

    Parameters
    ----------
    in_channels : int
        Channels of the incoming feature map.
    out_channels : int
        Channels produced by the convolution (and normalized by the norm layer).
    **kwargs
        Forwarded unchanged to ``MetaConv2d`` (kernel size, stride, padding, ...).

    Returns
    -------
    MetaSequential
        Sub-modules named ``conv``, ``norm``, ``relu`` and ``pool``, in that order.
    """
    stage = OrderedDict()
    stage['conv'] = MetaConv2d(in_channels, out_channels, **kwargs)
    # track_running_stats=False: no running mean/var buffers are kept, so the
    # layer normalizes with the statistics of the batch it is given.
    # NOTE(review): this is a plain nn.BatchNorm2d, not a MetaModule, so it
    # presumably ignores any `params` override routed through MetaSequential
    # (MetaBatchNorm2d is imported in this file but unused) — confirm intended.
    stage['norm'] = nn.BatchNorm2d(out_channels, momentum=1.,
                                   track_running_stats=False)
    stage['relu'] = nn.ReLU()
    # Halves the spatial size (floor) of the feature map.
    stage['pool'] = nn.MaxPool2d(2)
    return MetaSequential(stage)

def mlp_block(feature_size, output_size, **kwargs):
    """Build a fully-connected stage: MetaLinear followed by ReLU.

    Parameters
    ----------
    feature_size : int
        Number of input features.
    output_size : int
        Number of output features.
    **kwargs
        Forwarded unchanged to ``MetaLinear`` (e.g. ``bias``).

    Returns
    -------
    MetaSequential
        Sub-modules named ``linear`` and ``relu``, in that order.
    """
    projection = MetaLinear(feature_size, output_size, **kwargs)
    activation = nn.ReLU()
    return MetaSequential(OrderedDict((
        ('linear', projection),
        ('relu', activation),
    )))

class MetaConvModel(MetaModule):
    """4-layer Convolutional Neural Network architecture from [1].

    The network is four ``conv_block`` stages (3x3 conv, batch norm, ReLU,
    2x2 max-pool), followed by an optional stack of ``deeper`` fully-connected
    blocks, followed by a linear classifier.

    Parameters
    ----------
    in_channels : int
        Number of channels for the input images.

    out_features : int
        Number of classes (output of the model).

    hidden_size : int (default: 64)
        Number of channels in the intermediate representations.

    feature_size : int (default: 64)
        Number of features returned by the convolutional head, i.e. the length
        of the flattened output of the last conv stage. It must equal
        ``hidden_size * H * W``, where ``H x W`` is the spatial size left after
        four 2x2 max-pools (the 3x3 convolutions use padding 1 and stride 1,
        so they preserve spatial size).

    deeper : int (default: 0)
        Number of extra ``feature_size -> feature_size`` linear + ReLU blocks
        inserted between the convolutional head and the classifier. When
        falsy (``0``), no such blocks are built and the flattened conv
        features feed the classifier directly.

    References
    ----------
    .. [1] Finn C., Abbeel P., and Levine, S. (2017). Model-Agnostic Meta-Learning
           for Fast Adaptation of Deep Networks. International Conference on
           Machine Learning (ICML) (https://arxiv.org/abs/1703.03400)
    """
    def __init__(self, in_channels, out_features, hidden_size=64, feature_size=64, deeper=0):
        super(MetaConvModel, self).__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size
        self.feature_size = feature_size
        self.deeper = deeper

        # Every conv stage shares the same geometry; only the first one changes
        # the channel count (in_channels -> hidden_size).
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True)
        self.visual_features = MetaSequential(OrderedDict([
            ('layer1', conv_block(in_channels, hidden_size, **conv_kwargs)),
            ('layer2', conv_block(hidden_size, hidden_size, **conv_kwargs)),
            ('layer3', conv_block(hidden_size, hidden_size, **conv_kwargs)),
            ('layer4', conv_block(hidden_size, hidden_size, **conv_kwargs)),
        ]))

        if self.deeper:
            # Collect the blocks in a local list rather than temporarily storing
            # a plain list on the module and then overwriting that attribute.
            mlp_layers = [
                ('layer{}'.format(i + 1),
                 mlp_block(feature_size, feature_size, bias=True))
                for i in range(self.deeper)
            ]
            self.features = MetaSequential(OrderedDict(mlp_layers))

        self.classifier = MetaLinear(feature_size, out_features, bias=True)

    def forward(self, inputs, params=None):
        """Compute class logits for a batch of images.

        Parameters
        ----------
        inputs : Tensor
            Batch of images; the first dimension is the batch dimension.
            NOTE(review): presumably (batch, in_channels, H, W) — confirm.
        params : dict or None
            Optional parameter overrides. Each sub-module receives the entries
            under its own prefix via ``get_subdict`` (``visual_features``,
            ``features``, ``classifier``).

        Returns
        -------
        Tensor
            Logits with ``out_features`` columns per example.
        """
        visual_features = self.visual_features(inputs, params=get_subdict(params, 'visual_features'))
        # Flatten everything except the batch dimension.
        features = visual_features.view((visual_features.size(0), -1))
        if self.deeper:
            features = self.features(features, params=get_subdict(params, 'features'))
        logits = self.classifier(features, params=get_subdict(params, 'classifier'))
        return logits

class MetaMLPModel(MetaModule):
    """Multi-layer Perceptron architecture from [1].

    Each hidden layer is a ``MetaLinear`` followed by ReLU. A final
    ``MetaLinear`` classifier maps the last representation to
    ``out_features`` outputs.

    Parameters
    ----------
    in_features : int
        Number of input features.

    out_features : int
        Number of classes (output of the model).

    hidden_sizes : sequence of int
        Size of the intermediate representations. The length of this sequence
        corresponds to the number of hidden layers. Any sequence (list, tuple,
        ...) is accepted. If it is empty, no hidden layers are built and the
        classifier takes ``in_features`` inputs.

    References
    ----------
    .. [1] Finn C., Abbeel P., and Levine, S. (2017). Model-Agnostic Meta-Learning
           for Fast Adaptation of Deep Networks. International Conference on
           Machine Learning (ICML) (https://arxiv.org/abs/1703.03400)
    """
    def __init__(self, in_features, out_features, hidden_sizes):
        super(MetaMLPModel, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Store a copy so the model never aliases (and can never mutate) the
        # caller's list, e.g. a shared default argument.
        self.hidden_sizes = list(hidden_sizes)

        layer_sizes = [in_features] + self.hidden_sizes
        hidden_layers = []
        # Pair each layer's input size with its output size.
        for i, (size_in, size_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            hidden_layers.append(('layer{0}'.format(i + 1), MetaSequential(OrderedDict([
                ('linear', MetaLinear(size_in, size_out, bias=True)),
                ('relu', nn.ReLU()),
            ]))))
        self.features = MetaSequential(OrderedDict(hidden_layers))
        # layer_sizes[-1] is the last hidden size, or in_features when
        # hidden_sizes is empty (indexing hidden_sizes[-1] would fail then).
        self.classifier = MetaLinear(layer_sizes[-1], out_features, bias=True)

    def forward(self, inputs, params=None):
        """Compute the MLP output for a batch of inputs.

        Parameters
        ----------
        inputs : Tensor
            Input features. NOTE(review): presumably (batch, in_features) — confirm.
        params : dict or None
            Optional parameter overrides. Each sub-module receives the entries
            under its own prefix via ``get_subdict`` (``features``,
            ``classifier``).

        Returns
        -------
        Tensor
            Output with ``out_features`` columns per example.
        """
        features = self.features(inputs, params=get_subdict(params, 'features'))
        logits = self.classifier(features, params=get_subdict(params, 'classifier'))
        return logits

def ModelConvOmniglot(out_features, hidden_size=64, deeper=0):
    """Build a ``MetaConvModel`` for single-channel images.

    ``feature_size`` equals ``hidden_size``, which corresponds to a 1x1
    spatial map after the four 2x2 max-pools.
    NOTE(review): e.g. 28x28 inputs (28 -> 14 -> 7 -> 3 -> 1) — confirm against
    the data pipeline.
    """
    spatial_cells = 1 * 1
    return MetaConvModel(
        1,
        out_features,
        hidden_size=hidden_size,
        feature_size=spatial_cells * hidden_size,
        deeper=deeper,
    )

def ModelConvMiniImagenet(out_features, hidden_size=64, deeper=0):
    """Build a ``MetaConvModel`` for three-channel images with a 4x4 final map.

    ``feature_size`` is ``4 * 4 * hidden_size``: the flattened output of the
    last conv stage when four floor-halving 2x2 max-pools leave a 4x4 map.
    NOTE(review): that holds for inputs of roughly 64-79 px per side; 84x84
    inputs would leave 5x5 (84 -> 42 -> 21 -> 10 -> 5). Confirm that the data
    pipeline resizes images accordingly.
    """
    spatial_cells = 4 * 4
    return MetaConvModel(
        3,
        out_features,
        hidden_size=hidden_size,
        feature_size=spatial_cells * hidden_size,
        deeper=deeper,
    )

def ModelConvSynbols(out_features, hidden_size=64, deeper=0):
    """Build a ``MetaConvModel`` for three-channel images with a 2x2 final map.

    ``feature_size`` is ``2 * 2 * hidden_size``: the flattened output of the
    last conv stage when four 2x2 max-pools leave a 2x2 map.
    NOTE(review): e.g. 32x32 inputs (32 -> 16 -> 8 -> 4 -> 2) — confirm.
    """
    spatial_cells = 2 * 2
    return MetaConvModel(
        3,
        out_features,
        hidden_size=hidden_size,
        feature_size=spatial_cells * hidden_size,
        deeper=deeper,
    )

def ModelMLPSinusoid(hidden_sizes=None):
    """Build the MLP used for sinusoid regression (1 input, 1 output).

    Parameters
    ----------
    hidden_sizes : list of int, optional
        Hidden layer sizes. Defaults to ``[40, 40]``. A fresh list is built on
        every call: a list literal as the default would be a single object
        shared by every call (and stored on every model as ``hidden_sizes``),
        so mutating one model's list would silently change later defaults.

    Returns
    -------
    MetaMLPModel
    """
    if hidden_sizes is None:
        hidden_sizes = [40, 40]
    return MetaMLPModel(1, 1, hidden_sizes)

if __name__ == '__main__':
    # Smoke check when run as a script: building the default sinusoid MLP
    # must not raise. Nothing is printed or returned.
    sinusoid_model = ModelMLPSinusoid()
