from collections import OrderedDict

import torch.nn as nn

from .encoders import register
from ..modules import *

# Names exported when this module is star-imported: the two encoder factories
# defined at the bottom of the file.
__all__ = ['resnet10', 'resnet16']


def conv3x3(in_channels, out_channels):
  """Build the project's ``Conv2d`` layer used for 3x3 convolutions.

  The layer is created with positional arguments ``3`` and ``1`` after the
  channel counts (presumably kernel size and stride, as in
  ``torch.nn.Conv2d`` -- confirm against the project's ``Conv2d``), with
  ``padding=1`` and no bias term.
  """
  third_arg, fourth_arg = 3, 1  # passed positionally, exactly as before
  layer = Conv2d(
    in_channels, out_channels, third_arg, fourth_arg, padding=1, bias=False)
  return layer


def conv1x1(in_channels, out_channels):
  """Build the project's ``Conv2d`` layer used for 1x1 convolutions.

  The layer is created with positional arguments ``1`` and ``1`` after the
  channel counts (presumably kernel size and stride, as in
  ``torch.nn.Conv2d`` -- confirm against the project's ``Conv2d``), with
  ``padding=0`` and no bias term.
  """
  third_arg, fourth_arg = 1, 1  # passed positionally, exactly as before
  layer = Conv2d(
    in_channels, out_channels, third_arg, fourth_arg, padding=0, bias=False)
  return layer

# Batch normalization layer that supports MAML-style fast weights.
class BatchNorm2d_fw(nn.BatchNorm2d):
    """Batch norm used in MAML to forward input with "fast" weights.

    ``weight.fast`` and ``bias.fast`` are extra attributes attached to the
    affine parameters. When both are set they replace ``weight`` / ``bias``
    in the forward pass; otherwise the layer's own parameters are used.

    Normalization always uses the statistics of the current batch
    (``training=True``). The running statistics handed to ``batch_norm`` are
    throw-away tensors created on every call, so the module's own
    ``running_mean`` / ``running_var`` buffers are not read or updated by
    ``forward``.
    """

    def __init__(self, num_features):
        super(BatchNorm2d_fw, self).__init__(num_features)
        # Fast weights start unset; an outer training loop is expected to
        # assign tensors here when adapted parameters should be used.
        self.weight.fast = None
        self.bias.fast = None

    def forward(self, x, params=None, episode=None):
        """Normalize ``x`` with batch statistics and apply the affine transform.

        Args:
            x: input tensor; its size along dim 1 is used as the channel count.
            params: accepted so the layer can be called as
                ``layer(x, params, episode)`` like Block.forward does; it is
                not used here.
            episode: accepted for the same reason; not used here.

        Returns:
            The batch-normalized tensor.
        """
        # NOTE(review): ``params`` is ignored -- adapted parameters reach this
        # layer only through ``weight.fast`` / ``bias.fast``. Confirm that is
        # the intended mechanism for callers that pass a params dict.
        num_channels = x.size(1)
        # Scratch statistics live on the input's device (no hard-coded
        # ``.cuda()``), so the layer also works on CPU and on any GPU index.
        # Their dtype follows the affine weight.
        running_mean = x.new_zeros(num_channels, dtype=self.weight.dtype)
        running_var = x.new_ones(num_channels, dtype=self.weight.dtype)

        if self.weight.fast is not None and self.bias.fast is not None:
            weight, bias = self.weight.fast, self.bias.fast
        else:
            weight, bias = self.weight, self.bias

        # batch_norm momentum hack: follow hack of Kate Rakelly in
        # pytorch-maml/src/layers.py
        return nn.functional.batch_norm(
            x, running_mean, running_var, weight, bias,
            training=True, momentum=1)

class Block(Module):
  """Residual block with a two-convolution main branch and a 1x1 shortcut.

  Main branch: ``conv1 -> bn1 -> relu -> conv2 -> bn2``. The shortcut
  (``res_conv``) is a 1x1 convolution followed by the project's
  ``BatchNorm2d`` built from ``bn_args``. Both branches are summed, passed
  through ReLU, max-pooled with a window of 2, and passed through ReLU again.
  """

  def __init__(self, in_planes, planes, bn_args):
    super(Block, self).__init__()
    self.in_planes, self.planes = in_planes, planes

    # Main branch layers (kept in the original creation order).
    self.conv1 = conv3x3(in_planes, planes)
    self.bn1 = BatchNorm2d_fw(planes)
    self.conv2 = conv3x3(planes, planes)
    self.bn2 = BatchNorm2d_fw(planes)

    # Shortcut branch: 1x1 conv, then batch norm configured by bn_args.
    shortcut_layers = OrderedDict()
    shortcut_layers['conv'] = conv1x1(in_planes, planes)
    shortcut_layers['bn'] = BatchNorm2d(planes, **bn_args)
    self.res_conv = Sequential(shortcut_layers)

    self.relu = nn.ReLU(inplace=True)
    self.pool = nn.MaxPool2d(2)

  def forward(self, x, params=None, episode=None):
    """Run the block on ``x``.

    ``params`` is split per submodule with ``get_child_dict`` and handed to
    each layer; ``episode`` is forwarded to the batch-norm layers and to the
    shortcut branch.
    """
    # NOTE(review): bn1/bn2 are called with (x, params, episode); verify that
    # BatchNorm2d_fw.forward accepts that call shape.
    main = self.conv1(x, get_child_dict(params, 'conv1'))
    main = self.relu(self.bn1(main, get_child_dict(params, 'bn1'), episode))

    main = self.conv2(main, get_child_dict(params, 'conv2'))
    main = self.bn2(main, get_child_dict(params, 'bn2'), episode)

    shortcut = self.res_conv(x, get_child_dict(params, 'res_conv'), episode)
    merged = self.relu(main + shortcut)
    pooled = self.pool(merged)
    # The original applies ReLU once more after pooling; kept as-is even
    # though the pooled values come from a ReLU output.
    return self.relu(pooled)


@register('resnet10')
def resnet10(bn_args=None):
  """Construct a ``ResNet10`` with the channel list ``[64, 128, 256, 512]``.

  The function is decorated with ``register('resnet10')``.

  Args:
    bn_args: optional dict passed through to ``ResNet10``. When omitted, a
      fresh empty dict is created for this call, so no default object is
      shared between calls.

  Returns:
    The object returned by ``ResNet10``.
  """
  if bn_args is None:
    bn_args = {}
  return ResNet10([64, 128, 256, 512], bn_args)


@register('resnet16')
def resnet16(bn_args=None):
  """Construct a ``ResNet16`` with the channel list ``[64, 128, 256, 512]``.

  The function is decorated with ``register('resnet16')``.

  Args:
    bn_args: optional dict passed through to ``ResNet16``. When omitted, a
      fresh empty dict is created for this call, so no default object is
      shared between calls.

  Returns:
    The object returned by ``ResNet16``.
  """
  if bn_args is None:
    bn_args = {}
  return ResNet16([64, 128, 256, 512], bn_args)