import copy
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import re


# Hyper-parameters shared by conv_block and Net.
N_FILTERS: int = 64  # output channels of every conv_block (default out_channels)
K_SIZE: int = 3  # side length of the square convolution kernel
MP_SIZE: int = 2  # max-pooling window size
EPS: float = 1e-8  # small epsilon for numerical stability


class MetaSGD(nn.Module):
    """
    The class defines meta-learner for Meta-SGD algorithm.
    Training details will be written in train.py.
    TODO base-model invariant MetaLearner class
    """

    def __init__(self, params):
        """
        Args:
            params: configuration dict. Reads params['encoder']
                ('wide-convnet4' or 'resnet12'), params['num_way'] (number of
                classes) and, for 'wide-convnet4',
                params['encoder_args']['in_channels'].

        Raises:
            ValueError: if params['encoder'] is not a supported encoder name.
                Previously this silently left `self.meta_learner` undefined,
                so the failure only surfaced later as an AttributeError.
        """
        super(MetaSGD, self).__init__()
        self.params = params
        encoder = params['encoder']
        if encoder == 'wide-convnet4':
            self.meta_learner = Net(
                params['encoder_args']['in_channels'], params['num_way'], dataset=None)
        elif encoder == 'resnet12':
            self.meta_learner = ResNet12([64, 128, 256, 512], params['num_way'])
        else:
            raise ValueError('Unsupported encoder: {!r}'.format(encoder))
        # Per-parameter learning rates for Meta-SGD, filled in by
        # define_task_lr_params(). This is a plain OrderedDict, so its tensors
        # are NOT registered with the module: they do not appear in
        # parameters()/state_dict() and are not moved by .to()/.cuda().
        # TODO do we need strictly positive task_lr?
        self.task_lr = OrderedDict()

    def forward(self, X, adapted_params=None):
        """
        Run the base model on X.

        Args:
            X: input batch passed straight to self.meta_learner.
            adapted_params: optional flat dict of weights (same keys as
                self.state_dict(), i.e. prefixed with 'meta_learner.'); when
                given, the base model's functional path uses these instead of
                its own modules' weights.
        Returns:
            whatever self.meta_learner returns (log-probabilities for Net and
            ResNet12).
        """
        if adapted_params is None:
            return self.meta_learner(X)
        return self.meta_learner(X, adapted_params)

    def cloned_state_dict(self):
        """
        Return a dict of clones of every entry in self.state_dict().

        This covers meta_learner's parameters and buffers (e.g. BatchNorm
        running statistics). task_lr is not included because it is not
        registered with the module.
        """
        cloned_state_dict = {
            key: val.clone()
            for key, val in self.state_dict().items()
        }
        return cloned_state_dict

    def define_task_lr_params(self):
        """
        Create one learnable learning-rate tensor per model parameter.

        Each entry of self.task_lr has the same shape/dtype/device as the
        parameter it belongs to, is initialised to 1e-3 and is keyed by the
        parameter's name from self.named_parameters().
        """
        for key, val in self.named_parameters():
            # nn.Parameter makes this a fresh leaf tensor with requires_grad=True,
            # so there is no need to track the scaling in autograd.
            self.task_lr[key] = nn.Parameter(1e-3 * torch.ones_like(val))



class Net(nn.Module):
    """
    The base CNN model for MAML (Meta-SGD) for few-shot learning.
    The architecture is the same as the embedding in MatchingNet.
    """

    def __init__(self, in_channels, num_classes, dataset='Omniglot'):
        """
        self.features returns:
            [N, 64, 1, 1] for Omniglot (28x28)
            [N, 64, 5, 5] for miniImageNet (84x84)
        self.fc returns:
            [N, num_classes]

        Args:
            in_channels: number of input channels feeding into first conv_block
            num_classes: number of classes for the task
            dataset: selects the number of input units of self.fc, which depends
                     on the input resolution. 'Omniglot' -> N_FILTERS;
                     'ImageNet' or any other value (including None) ->
                     N_FILTERS * 5 * 5.
        """
        super(Net, self).__init__()
        self.features = nn.Sequential(
            conv_block(0, in_channels, padding=1, pooling=True),
            conv_block(1, N_FILTERS, padding=1, pooling=True),
            conv_block(2, N_FILTERS, padding=1, pooling=True),
            conv_block(3, N_FILTERS, padding=1, pooling=True))
        if dataset == 'Omniglot':
            fc_in = N_FILTERS
        else:
            # 'ImageNet' and every other value: 84x84 input -> 5x5 feature maps.
            fc_in = N_FILTERS * 5 * 5
        self.add_module('fc', nn.Linear(fc_in, num_classes))

    def forward(self, X, params=None):
        """
        Args:
            X: [N, in_channels, W, H]
            params: optional flat dict shaped like MetaSGD.state_dict(), i.e.
                keys prefixed with 'meta_learner.' (e.g.
                'meta_learner.features.0.conv0.weight'). When given, the
                network is evaluated functionally with these tensors.
        Returns:
            out: [N, num_classes] log-probabilities (log_softmax over dim 1)
        """
        if params is None:
            out = self.features(X)
            out = out.view(out.size(0), -1)
            out = self.fc(out)
        else:
            # Functional replica of self.features, one conv_block at a time.
            out = X
            for index in range(len(self.features)):
                out = self._functional_block(out, params, index)
            out = out.view(out.size(0), -1)
            out = F.linear(out, params['meta_learner.fc.weight'],
                           params['meta_learner.fc.bias'])

        return F.log_softmax(out, dim=1)

    @staticmethod
    def _functional_block(X, params, index):
        """Apply conv_block number `index` (conv -> bn -> relu -> pool) using `params`."""
        block = 'meta_learner.features.{0}.'.format(index)
        conv = block + 'conv{0}.'.format(index)
        bn = block + 'bn{0}.'.format(index)
        out = F.conv2d(X, params[conv + 'weight'], params[conv + 'bias'], padding=1)
        # training=True normalises with batch statistics; with momentum=1 the
        # running buffers in `params` are simply overwritten, so their previous
        # values do not matter.
        out = F.batch_norm(
            out,
            params[bn + 'running_mean'],
            params[bn + 'running_var'],
            params[bn + 'weight'],
            params[bn + 'bias'],
            momentum=1,
            training=True)
        out = F.relu(out, inplace=True)
        return F.max_pool2d(out, MP_SIZE)


def conv_block(index,
               in_channels,
               out_channels=N_FILTERS,
               padding=0,
               pooling=True):
    """
    Build one Convolutional Block (CB) as an nn.Sequential.

    The CB consists of the following modules, in order:
        K_SIZE x K_SIZE conv with `out_channels` filters   -> 'conv<index>'
        batch normalization (momentum=1, affine)           -> 'bn<index>'
        ReLU (in place)                                    -> 'relu<index>'
        MaxPool with window MP_SIZE, only if `pooling`     -> 'pool<index>'

    The '<index>' suffix in the layer names produces state_dict keys such as
    'conv0.weight' / 'bn0.running_mean', which Net.forward's functional path
    looks up by name.

    Args:
        index: integer used as the suffix of every layer name.
        in_channels: input channels of the convolution.
        out_channels: number of convolution filters (default N_FILTERS).
        padding: zero-padding of the convolution.
        pooling: whether to append the max-pooling layer.
    """
    suffix = str(index)
    layers = [
        ('conv' + suffix, nn.Conv2d(in_channels, out_channels, K_SIZE, padding=padding)),
        ('bn' + suffix, nn.BatchNorm2d(out_channels, momentum=1, affine=True)),
        ('relu' + suffix, nn.ReLU(inplace=True)),
    ]
    if pooling:
        layers.append(('pool' + suffix, nn.MaxPool2d(MP_SIZE)))
    return nn.Sequential(OrderedDict(layers))

class ResNetBlock(nn.Module):
    """
    Residual block used by ResNet12.

    Three bias-free 3x3 conv -> BatchNorm stages (LeakyReLU after the first
    two), a shortcut (identity when indim == outdim, otherwise 1x1 conv +
    BatchNorm), then LeakyReLU on the sum and 2x2 max pooling.

    forward() has a module path (params is None) and a functional path that
    reads weights from a flat `params` dict. Both use the same activation and
    pooling modules, so they compute the same function. Note that the
    functional path always normalises with batch statistics (training=True).
    The module path does the same only while the module is in training mode.
    """

    def __init__(self, indim, outdim):
        super(ResNetBlock, self).__init__()
        self.indim = indim
        self.outdim = outdim

        self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=1, padding=1, bias=False)
        self.BN1 = nn.BatchNorm2d(outdim)
        self.C2 = nn.Conv2d(outdim, outdim, kernel_size=3, stride=1, padding=1, bias=False)
        self.BN2 = nn.BatchNorm2d(outdim)
        self.C3 = nn.Conv2d(outdim, outdim, kernel_size=3, stride=1, padding=1, bias=False)
        self.BN3 = nn.BatchNorm2d(outdim)

        self.relu1 = nn.LeakyReLU()
        self.relu2 = nn.LeakyReLU()
        self.relu3 = nn.LeakyReLU()
        self.pool = nn.MaxPool2d(2)

        self.parametrized_layers = [self.C1, self.C2, self.C3, self.BN1, self.BN2, self.BN3]

        if indim != outdim:
            self.shortcut = nn.Conv2d(indim, outdim, 1, 1, bias=False)
            self.BNshortcut = nn.BatchNorm2d(outdim)
            self.shortcut_type = '1x1'
        else:
            self.shortcut_type = 'identity'

        for layer in self.parametrized_layers:
            init_layer(layer)

    def forward(self, x, params=None, prefix=None):
        """
        Args:
            x: input feature map with `indim` channels.
            params: optional flat dict of weights/buffers. When given, entries
                are looked up as prefix + '.C1.weight',
                prefix + '.BN1.running_mean', prefix + '.shortcut.weight', ...
            prefix: key prefix for `params` (e.g. 'meta_learner.B1'); required
                when `params` is given.
        Returns:
            output feature map with `outdim` channels after 2x2 max pooling.
        """
        if params is None:
            out = self.relu1(self.BN1(self.C1(x)))
            out = self.relu2(self.BN2(self.C2(out)))
            out = self.BN3(self.C3(out))
            if self.shortcut_type == 'identity':
                short_out = x
            else:
                short_out = self.BNshortcut(self.shortcut(x))
        else:
            # Same LeakyReLU modules as the module path. The previous code used
            # F.relu here, so the adapted forward computed a different function.
            out = self.relu1(self._functional_conv_bn(x, params, prefix, 'C1', 'BN1', padding=1))
            out = self.relu2(self._functional_conv_bn(out, params, prefix, 'C2', 'BN2', padding=1))
            out = self._functional_conv_bn(out, params, prefix, 'C3', 'BN3', padding=1)
            if self.shortcut_type == 'identity':
                short_out = x
            else:
                short_out = self._functional_conv_bn(
                    x, params, prefix, 'shortcut', 'BNshortcut', padding=0)
        return self.pool(self.relu3(out + short_out))

    @staticmethod
    def _functional_conv_bn(x, params, prefix, conv_name, bn_name, padding):
        """Bias-free conv2d followed by batch norm, with all tensors taken from `params`."""
        out = F.conv2d(x, params[prefix + '.' + conv_name + '.weight'], padding=padding)
        bn = prefix + '.' + bn_name
        # training=True normalises with the current batch statistics; with
        # momentum=1 the running buffers in `params` are overwritten by them.
        return F.batch_norm(
            out,
            params[bn + '.running_mean'], params[bn + '.running_var'],
            params[bn + '.weight'], params[bn + '.bias'],
            momentum=1, training=True)

def init_layer(L):
    """
    Initialise a layer's parameters in place.

    Conv2d: weights ~ N(0, sqrt(2 / n)) with
        n = kernel_h * kernel_w * out_channels,
    i.e. He/Kaiming normal initialisation in *fan-out* mode (n counts output
    channels, not input channels). The conv bias, if any, is left untouched.
    BatchNorm2d: weight = 1, bias = 0. Nothing is done for affine=False,
    where these are None.
    Any other layer type is left unchanged.
    """
    if isinstance(L, nn.Conv2d):
        n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels
        L.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
    elif isinstance(L, nn.BatchNorm2d):
        if L.weight is not None:
            L.weight.data.fill_(1)
        if L.bias is not None:
            L.bias.data.fill_(0)

class ResNet12(nn.Module):
    """
    ResNet-12 backbone: four ResNetBlocks (B1..B4), global average pooling and
    a linear classifier head. forward() returns log-probabilities.
    """
    maml = False  # Default
    _BLOCK_NAMES = ('B1', 'B2', 'B3', 'B4')

    def __init__(self, list_of_out_dims, num_classes, flatten=True, in_channels=3):
        """
        Args:
            list_of_out_dims: output channels of the four stages, e.g.
                [64, 128, 256, 512]. Expects four entries; the last one is
                also the feature dimension fed to the classifier.
            num_classes: number of output classes of self.fc.
            flatten: accepted for interface compatibility; not used.
            in_channels: number of channels of the input images (default 3).
        """
        super(ResNet12, self).__init__()

        self.B1 = ResNetBlock(in_channels, list_of_out_dims[0])
        self.B2 = ResNetBlock(list_of_out_dims[0], list_of_out_dims[1])
        self.B3 = ResNetBlock(list_of_out_dims[1], list_of_out_dims[2])
        self.B4 = ResNetBlock(list_of_out_dims[2], list_of_out_dims[3])

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.final_feat_dim = list_of_out_dims[-1]
        self.fc = nn.Linear(self.final_feat_dim, num_classes)

    def forward(self, x, params=None):
        """
        Args:
            x: input image batch.
            params: optional flat dict shaped like MetaSGD.state_dict() (keys
                prefixed with 'meta_learner.'). When given, every block and
                the classifier use these tensors instead of their own weights.
        Returns:
            [N, num_classes] log-probabilities.
        """
        out = x
        for name in self._BLOCK_NAMES:
            block = getattr(self, name)
            if params is None:
                out = block(out)
            else:
                prefix = 'meta_learner.' + name
                out = block(out, get_child_dict(params, prefix), prefix=prefix)
        out = self.pool(out).flatten(1)

        if params is None:
            out = self.fc(out)
        else:
            out = F.linear(out, params['meta_learner.fc.weight'], params['meta_learner.fc.bias'])

        return F.log_softmax(out, dim=1)

    def get_out_dim(self):
        """Return the feature dimension produced before the classifier."""
        return self.final_feat_dim

def accuracy(outputs, labels):
    """
    Compute the classification accuracy of a batch.

    Args:
        outputs: (np.ndarray or array-like) shape [batch_size, num_classes];
            per-class scores, e.g. the log-softmax output of the model. The
            predicted class is the argmax over axis 1.
        labels: (np.ndarray or array-like) shape [batch_size]; the integer
            class index of each example.
    Returns: (float) fraction of examples whose predicted class equals the
        label, in [0, 1].
    """
    predictions = np.argmax(np.asarray(outputs), axis=1)
    labels = np.asarray(labels)
    return np.sum(predictions == labels) / float(labels.size)


# Registry of evaluation metrics, keyed by metric name. Each value is a
# callable taking (outputs, labels); the training and evaluation loops use
# this mapping. Further metrics (e.g. per-token-type accuracy) can be
# registered here.
metrics = dict(accuracy=accuracy)

def get_child_dict(params, key=None):
    """
    Select the entries of a flat parameter dictionary that belong to one module.

    Keys are NOT stripped of the prefix: the returned dict keeps full names
    (e.g. 'meta_learner.B1.C1.weight'), because ResNetBlock.forward looks
    entries up as prefix + '.C1.weight'.

    A name matches when `key` occurs in it as a run of whole dot-separated
    components. So 'meta_learner.B1' selects 'meta_learner.B1.C1.weight' (and
    'module.meta_learner.B1.C1.weight') but not 'meta_learner.B10.C1.weight'.

    Args:
    params (dict): a parent dictionary of named parameters, or None.
    key (str, optional): dotted path of the child module; None or '' selects
        everything.

    Returns:
    child_dict (dict): None if params is None; params itself (not a copy) if
        key is None or ''; otherwise a new dict of the matching entries, in
        params' iteration order.
    """
    if params is None:
        return None
    if key is None or (isinstance(key, str) and key == ''):
        return params

    # Pad both sides with '.' so only whole components match; a plain
    # substring test would let 'B1' also match 'B10'.
    needle = '.' + key + '.'
    return {name: value for name, value in params.items()
            if needle in '.' + name + '.'}