import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import math

from torchmeta.modules import (MetaModule, MetaConv2d, MetaBatchNorm2d,
                               MetaSequential, MetaLinear)
from utils.data_utils import get_subdict
from collections import OrderedDict

def normal_conv_block(in_channels, out_channels, **kwargs):
    """Build a plain (non-meta) conv -> batch-norm -> ReLU -> 2x2 max-pool block.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution (and BatchNorm width).
        **kwargs: forwarded verbatim to ``nn.Conv2d`` (kernel_size, padding, ...).

    Returns:
        A ``MetaSequential`` whose children are named 'conv', 'norm', 'relu'
        and 'pool', in that order.
    """
    stages = OrderedDict()
    stages['conv'] = nn.Conv2d(in_channels, out_channels, **kwargs)
    # No running mean/var buffers are kept (track_running_stats=False).
    stages['norm'] = nn.BatchNorm2d(out_channels, track_running_stats=False)
    stages['relu'] = nn.ReLU()
    stages['pool'] = nn.MaxPool2d(2)
    return MetaSequential(stages)

class ConvModel(nn.Module):
    """Plain four-block convolutional feature extractor (no classifier head).

    Every block is a ``normal_conv_block`` with a 3x3 convolution (stride 1,
    padding 1), batch norm, ReLU and 2x2 max-pooling. ``forward`` flattens the
    resulting feature maps to one vector per sample.

    Args:
        in_channels: channels of the input images.
        hidden_size: channels produced by every block.
    """

    def __init__(self, in_channels, hidden_size=64):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size = hidden_size

        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True)
        blocks = OrderedDict()
        width_in = in_channels
        for block_name in ('layer1', 'layer2', 'layer3', 'layer4'):
            blocks[block_name] = normal_conv_block(width_in, hidden_size,
                                                   **conv_kwargs)
            # every block after the first consumes hidden_size channels
            width_in = hidden_size
        self.features = MetaSequential(blocks)

    def forward(self, inputs):
        """Return the flattened feature maps with shape (batch, -1)."""
        feature_maps = self.features(inputs)
        batch = feature_maps.size(0)
        return feature_maps.view(batch, -1)


def conv_block(in_channels, out_channels, non_lin, **kwargs):
    """Build a meta-learnable conv -> batch-norm -> activation -> max-pool block.

    Args:
        in_channels: input channels of the ``MetaConv2d``.
        out_channels: output channels of the convolution (and BatchNorm width).
        non_lin: activation module inserted as-is; the caller's instance is
            used directly, not copied.
        **kwargs: forwarded verbatim to ``MetaConv2d``.

    Returns:
        A ``MetaSequential`` with children 'conv', 'norm', 'relu' (holding
        ``non_lin``, whatever its type) and 'pool'.
    """
    stages = OrderedDict()
    stages['conv'] = MetaConv2d(in_channels, out_channels, **kwargs)
    # No running mean/var buffers are kept (track_running_stats=False).
    stages['norm'] = MetaBatchNorm2d(out_channels, track_running_stats=False)
    stages['relu'] = non_lin
    stages['pool'] = nn.MaxPool2d(2)
    return MetaSequential(stages)

class MetaConvModel(MetaModule):
    """Four conv blocks followed by a linear classifier, with overridable weights.

    Each block is a ``conv_block``: 3x3 conv (stride 1, padding 1) -> batch
    norm -> ``non_lin`` -> 2x2 max-pool. ``forward`` accepts an optional
    ``params`` mapping so the network can be evaluated with weights other than
    the registered ones (e.g. inner-loop adapted weights).

    Args:
        in_channels: channels of the input images.
        out_features: number of output logits.
        hidden_size: channels of every conv block.
        feature_size: flattened per-sample size of the conv features; it must
            match what the four blocks produce for the input resolution
            (e.g. 64 for 28x28 inputs with hidden_size=64 -- TODO confirm
            against the dataset in use).
        non_lin: activation module placed in every block. Defaults to a new
            ``nn.ReLU()`` per model. The same module object is reused by all
            four blocks of one model.
    """

    def __init__(self, in_channels, out_features, hidden_size=64,
                 feature_size=64, non_lin=None):
        super(MetaConvModel, self).__init__()
        if non_lin is None:
            # A fresh activation per model: a module instance used as a default
            # argument is created once at import time and shared by every
            # MetaConvModel that relies on the default.
            non_lin = nn.ReLU()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size
        self.feature_size = feature_size

        self.features = MetaSequential(OrderedDict([
            ('layer1', conv_block(in_channels, hidden_size, non_lin,
                                  kernel_size=3, stride=1, padding=1,
                                  bias=True)),
            ('layer2', conv_block(hidden_size, hidden_size, non_lin,
                                  kernel_size=3, stride=1, padding=1,
                                  bias=True)),
            ('layer3', conv_block(hidden_size, hidden_size, non_lin,
                                  kernel_size=3, stride=1, padding=1,
                                  bias=True)),
            ('layer4', conv_block(hidden_size, hidden_size, non_lin,
                                  kernel_size=3, stride=1, padding=1,
                                  bias=True)),
        ]))
        self.classifier = MetaLinear(feature_size, out_features, bias=True)

    def forward(self, inputs, params=None):
        """Compute classification logits.

        Args:
            inputs: image batch; presumably (batch, in_channels, H, W) --
                TODO confirm.
            params: optional parameter overrides. ``get_subdict`` selects the
                'features' and 'classifier' parts (presumably by name prefix;
                with ``None`` the registered parameters are presumably used,
                per the torchmeta convention -- TODO confirm).

        Returns:
            Logits of shape (batch, out_features).
        """
        features = self.features(inputs, params=get_subdict(params, 'features'))
        # flatten to (batch, feature_size) for the linear classifier
        features = features.view((features.size(0), -1))
        logits = self.classifier(features,
                                 params=get_subdict(params, 'classifier'))
        return logits

class XAttentionMask(nn.Module):
    """Input-dependent ("x-dependent") mask generator for conv weights.

    The task inputs are embedded with a ``ConvModel``, and the embeddings are
    averaged over the batch dimension. A single ``nn.Linear`` (``alphas``) then
    maps the result to mask logits for every entry of ``layer_names``.
    ``forward`` returns a dict ``{layer_name: mask}``.

    Args:
        layer_names: keys of the returned mask dict, one per layer.
        layer_sizes: per-layer row counts; layer ``i`` receives
            ``layer_sizes[i]`` rows of the logit matrix, which ``forward``
            reshapes to ``(64, layer_sizes[i], 3, 3)``.
        input_channels: channels of the inputs fed to the embedding net.
        hidden_size: width of the embedding ``ConvModel``.
        feature_size: input width of the ``alphas`` linear layer; must equal
            the flattened per-sample output size of the embedding net.
        meta_relu_sgd, meta_sgd_linear, meta_relu_through: select the mask
            transform applied in ``forward``.
        x_debug: replace the input embedding with a learned vector plus scaled
            Gaussian noise (``inputs`` is then ignored).
        x_debug_noise: noise scale used when ``x_debug`` is set.
        out_shift: constant added to the logits before the mask transform.
    """
    def __init__(self,layer_names,layer_sizes,input_channels,hidden_size=64,
                feature_size=64,
                meta_relu_sgd = False,
                meta_sgd_linear = False,
                meta_relu_through= False,
                x_debug=False,
                x_debug_noise=0.0,
                out_shift=0.):

        super(XAttentionMask,self).__init__()
        self.layer_names=layer_names
        self.layer_sizes=layer_sizes
        self.hidden_size=hidden_size
        self.feature_size=feature_size
        self.meta_relu_sgd = meta_relu_sgd
        self.meta_relu_through = meta_relu_through
        self.meta_sgd_linear = meta_sgd_linear
        self.x_debug = x_debug
        self.x_debug_noise = x_debug_noise
        self.total_layers=sum(layer_sizes)
        # Number of logits per row of the (total_layers, n) logit matrix.
        # NOTE(review): forward reshapes each layer's (layer_sizes[i], n) slice
        # to (64, layer_sizes[i], 3, 3), which only succeeds when n == 576
        # (= 64*3*3). The 64 and 3x3 are hard-coded there -- TODO confirm the
        # intended layer_sizes layout.
        self.n=layer_sizes[0]*layer_sizes[0]*layer_sizes[1]
        self.alpha_size=sum(layer_sizes)*(self.n)
        self.embedding=ConvModel(input_channels, hidden_size)
        self.out_shift=out_shift
        # Hypernetwork head: task embedding -> all mask logits in one go.
        self.alphas=nn.Linear(feature_size, self.alpha_size)

        # Only read when x_debug is set: [0] (init ones) scales the noise,
        # [1] (init zeros) is an additive learned offset.
        self.weight_embedding_list = nn.ParameterList([])
        self.weight_embedding_list.append(nn.Parameter(
                                   torch.ones(feature_size)))
        self.weight_embedding_list.append(nn.Parameter(
                                   torch.zeros(feature_size)))

        print("\nX dep modulation of the conv parameters turned on.\n")

    def forward(self, inputs):
        """Compute one mask per entry of ``self.layer_names``.

        Args:
            inputs: batch passed to the embedding ``ConvModel``; presumably
                (batch, input_channels, H, W) -- TODO confirm. Ignored when
                ``x_debug`` is set.

        Returns:
            dict mapping each layer name to a tensor of shape
            ``(64, layer_sizes[i], 3, 3)``. The transform is chosen in this
            priority: ``meta_relu_through`` (ReLU forward, identity gradient),
            ``meta_relu_sgd`` (ReLU), ``meta_sgd_linear`` (raw logits), else a
            binary ``0.5 * (1 + sign(.))`` mask. The binary mask takes values
            0 or 1, and 0.5 where the shifted logit is exactly 0.
        """

        # terms in the linear projection
        if self.x_debug:
            # Debug embedding: noise * w0 * x_debug_noise + w1, shape (feature_size,)
            x = torch.randn_like(self.weight_embedding_list[0])*\
                      self.weight_embedding_list[0]*self.x_debug_noise \
                                                        + self.weight_embedding_list[1]
        else:
            x = self.embedding(inputs)
            # aggregate statistics over the task: average the per-sample
            # embeddings over the batch dimension -> (1, feature_size)
            x = torch.mean(x, dim=0, keepdim=True)

        # (total_layers, n): consecutive row blocks belong to consecutive layers
        alphas=self.alphas(x).reshape(self.total_layers, self.n)
        #Divide alpha into the appropriate rows
        prev=0
        masks={}
        for i,name in enumerate(self.layer_names):
            alpha=alphas[prev:prev+self.layer_sizes[i]]
            # NOTE(review): hard-coded 64 output channels and 3x3 kernel.
            alpha=alpha.reshape(64, alpha.shape[0], 3, 3)
            # NOTE(review): AttentionMask.forward checks meta_sgd_linear before
            # meta_relu_sgd; the order here is the reverse. This only matters
            # if both flags are set.
            if self.meta_relu_through:
                masks[name] = ReluStraightThrough.apply(alpha + self.out_shift)
            elif self.meta_relu_sgd:
                masks[name] = torch.relu(alpha + self.out_shift)
            elif self.meta_sgd_linear:
                masks[name] = alpha + self.out_shift
            else:
                # sign() with a straight-through gradient, mapped to {0, 0.5, 1}
                sign = BinaryLayer.apply(alpha + self.out_shift)
                masks[name]= 0.5*(1 + sign)
            prev+=self.layer_sizes[i]
        return masks

class BinaryLayer(torch.autograd.Function):
    """Elementwise sign with a straight-through (identity) gradient.

    Forward returns ``torch.sign(x)``, i.e. -1, 0 or +1 per element (0 stays
    0). Backward passes the incoming gradient through unchanged, as if the
    forward pass were the identity. Use via ``BinaryLayer.apply(x)``.
    """

    @staticmethod
    def forward(ctx, x):
        # ``ctx`` is the autograd context; nothing needs saving because the
        # backward pass does not depend on the input.
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: d sign(x)/dx is treated as 1.
        return grad_output

class ReluStraightThrough(torch.autograd.Function):
    """ReLU whose backward pass is the identity (straight-through).

    Forward returns ``torch.relu(x)``. Backward passes the incoming gradient
    through unchanged, including where ``x < 0``, so clipped entries still
    receive gradient. Use via ``ReluStraightThrough.apply(x)``.
    """

    @staticmethod
    def forward(ctx, x):
        # ``ctx`` is the autograd context; nothing needs saving because the
        # backward pass does not depend on the input.
        return torch.relu(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: ignore the ReLU's zero-gradient region.
        return grad_output

class AttentionMask(nn.Module):
    """Learned, input-independent masks/modulations for named weights.

    One trainable logit tensor is created for every entry of ``weight_names``,
    minus those filtered out by ``no_head_masking``, ``no_bn_masking`` and
    ``x_conv_attention``. ``forward`` turns each logit tensor into a mask with
    one of four transforms, checked in this order:

    * ``meta_relu_through``: ReLU forward, identity gradient.
    * ``meta_sgd_linear``: the logits themselves.
    * ``meta_relu_sgd``: plain ReLU.
    * otherwise: binary ``0.5 * (sign(x) + 1)`` with a straight-through
      gradient (0 or 1, and 0.5 where a logit is exactly 0).

    Args:
        weight_names: names of the weights to modulate; they become the keys
            of the dict returned by ``forward``.
        weight_shapes: shapes matching ``weight_names`` one-to-one.
        kaiming_init: use Kaiming-uniform init for logits with more than one
            dimension; otherwise U(-0.5, 0.5) is used.
        init_shift: after (non Meta-SGD) init, each logit tensor is re-centred
            so that its mean equals ``init_shift``.
        x_shape: unused; kept for backward compatibility.
        x_conv_attention: optional callable. When given, weights whose name
            contains "conv" are skipped here, and ``forward`` merges in the
            dict returned by ``x_conv_attention(input_x)``.
        meta_relu_sgd, meta_sgd_linear, meta_relu_through: select the mask
            transform (see above).
        noise_std: when non-zero, Gaussian noise with this std is added to the
            logits on every ``forward`` call.
        no_bn_masking: skip weights whose name contains "norm".
        meta_sgd_init: initialise logits from U(0.005, 0.1) (the Meta-SGD init)
            and skip the re-centring.
        alpha_init: unused; kept for backward compatibility.
        no_head_masking: skip weights whose name contains "classifier".
        dynamic_mask: also learn a per-weight change direction that is added as
            ``change * t`` when ``forward`` receives a ``t``.
        dyn_mask_init: change directions are initialised from
            U(-dyn_mask_init, dyn_mask_init).
    """

    def __init__(self, weight_names, weight_shapes, kaiming_init=True,
                 init_shift=0, x_shape=None, x_conv_attention=None,
                 meta_relu_sgd=False, meta_sgd_linear=False,
                 meta_relu_through=False, noise_std=0.,
                 no_bn_masking=False,
                 meta_sgd_init=False,
                 alpha_init=0.1,
                 no_head_masking=False, dynamic_mask=False, dyn_mask_init=0):
        super(AttentionMask, self).__init__()

        self.weight_names = weight_names
        self.weight_shapes = weight_shapes
        self.weight_mask_list = nn.ParameterList([])
        self.x_conv_attention = x_conv_attention
        self.no_head_masking = no_head_masking
        self.no_bn_masking = no_bn_masking
        self.meta_relu_sgd = meta_relu_sgd
        self.meta_sgd_linear = meta_sgd_linear
        self.meta_relu_through = meta_relu_through
        self.noise_std = noise_std
        self.meta_sgd_init = meta_sgd_init
        self.dynamic_mask = dynamic_mask

        if self.dynamic_mask:
            print("You are using a dynamic mask.")
            self.mask_change_list = nn.ParameterDict({})

        kept_names = []
        for name, shape in zip(weight_names, weight_shapes):
            if not self._is_modulated(name):
                continue
            kept_names.append(name)
            self.weight_mask_list.append(
                self._init_logits(shape, kaiming_init, init_shift))

            if self.dynamic_mask:
                change = nn.Parameter(torch.zeros(shape))
                nn.init.uniform_(change, a=-dyn_mask_init, b=dyn_mask_init)
                self.mask_change_list[self._change_key(name)] = change

        # Only the weights that actually got a logit tensor; forward zips these
        # with weight_mask_list.
        self.weight_names = kept_names
        print("Inner loop modulation on: ", self.weight_names)

    def _is_modulated(self, name):
        """Return False for weights this module must not create a mask for."""
        if self.no_head_masking and "classifier" in name:
            return False  # head masking disabled
        if self.no_bn_masking and "norm" in name:
            return False  # BatchNorm masking disabled
        if self.x_conv_attention is not None and "conv" in name:
            return False  # mask will come from the x-dependent hypernetwork
        return True

    def _init_logits(self, shape, kaiming_init, init_shift):
        """Create and initialise one trainable logit tensor of ``shape``."""
        alpha = nn.Parameter(torch.zeros(shape))
        if self.meta_sgd_init:
            # original init proposed by meta-sgd (for regression)
            nn.init.uniform_(alpha, a=0.005, b=0.1)
            return alpha

        if len(shape) > 1 and kaiming_init:
            nn.init.kaiming_uniform_(alpha)
        else:
            nn.init.uniform_(alpha, a=-0.5, b=0.5)
        # control the mean / sparsity init explicitly
        with torch.no_grad():
            alpha.sub_(alpha.mean()).add_(init_shift)
        return alpha

    @staticmethod
    def _change_key(name):
        """ParameterDict key for ``name``: the name with all dots removed.

        Presumably needed because parameter names may not contain '.'.
        NOTE(review): distinct names can collide once dots are stripped
        (e.g. "a.bc" and "ab.c").
        """
        return name.replace(".", "")

    def alter_x(self, x, name, t):
        """Apply the optional dynamic change and Gaussian noise to logits ``x``."""
        if self.dynamic_mask and t is not None:
            x = x + self.mask_change_list[self._change_key(name)] * t
        if self.noise_std != 0:
            x = x + torch.randn_like(x) * self.noise_std
        return x

    def forward(self, input_x, t=None):
        """Return ``{weight_name: mask}`` for every modulated weight.

        Args:
            input_x: only forwarded to ``x_conv_attention`` (if set).
            t: optional scalar/tensor scaling the dynamic change directions.
        """
        masks = {}
        for name, logits in zip(self.weight_names, self.weight_mask_list):
            if self.meta_relu_through:
                mask = ReluStraightThrough.apply(self.alter_x(logits, name, t))
            elif self.meta_sgd_linear:
                mask = self.alter_x(logits, name, t)
            elif self.meta_relu_sgd:
                mask = torch.relu(self.alter_x(logits, name, t))
            else:
                mask = 0.5 * (BinaryLayer.apply(self.alter_x(logits, name, t)) + 1)
            masks[name] = mask

        if self.x_conv_attention is not None:
            masks = {**masks, **self.x_conv_attention(input_x)}

        return masks

    def analyse_dynamic_mask(self, writer=None, t=None):
        """Fraction of dynamic change entries whose binary mask is non-zero.

        Each change direction is binarised with ``0.5 * (sign(v) + 1)``.
        The non-zero fraction is optionally logged per weight to ``writer``
        under the tag ``'Dyn mask sparcity <key>'`` at step ``t``.

        Returns:
            The overall non-zero fraction across all change tensors, or 0.0
            when no weight is modulated.

        Raises:
            AttributeError: if the module was built with ``dynamic_mask=False``
                (no change directions exist).
        """
        if not self.dynamic_mask:
            raise AttributeError(
                "analyse_dynamic_mask requires AttentionMask(dynamic_mask=True)")

        count_z = 0
        count_n = 0
        with torch.no_grad():
            for name, value in self.mask_change_list.items():
                mask = 0.5 * (BinaryLayer.apply(value) + 1)
                count_z_cur = int(torch.count_nonzero(mask))
                count_n_cur = mask.numel()
                if writer is not None:
                    writer.add_scalar('Dyn mask sparcity ' + name,
                                      count_z_cur / count_n_cur, t)
                count_z += count_z_cur
                count_n += count_n_cur
        if count_n == 0:
            # nothing is modulated; avoid ZeroDivisionError
            return 0.0
        return count_z / count_n
