import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import higher
# Name -> torch.nn activation *class* (not instance). construct_fullcon_layers
# looks names up here and instantiates a fresh module for every layer.
nonlin_map = dict(
    relu=nn.ReLU,
    leaky_relu=nn.LeakyReLU,
    tanh=nn.Tanh,
    sigmoid=nn.Sigmoid,
    identity=nn.Identity,
)

def construct_fullcon_layers(indim, dims=None, nonlins=None):
    """Build a stack of fully connected layers, each followed by an activation.

    Args:
        indim: Number of input features of the first ``Linear`` layer.
        dims: Output width of each layer, in order. Must be non-empty.
            Any iterable of ints is accepted.
        nonlins: Name of the activation applied after each layer; every name
            must be a key of ``nonlin_map`` and there must be exactly one
            name per entry of ``dims``.

    Returns:
        A ``torch.nn.Sequential`` whose children are named ``layer0``,
        ``layer1``, ...; child ``i`` is ``Sequential(Linear(w_i, w_{i+1}),
        activation_i)`` where ``w = [indim, *dims]``.

    Raises:
        ValueError: If ``dims`` is empty or its length differs from
            ``nonlins``.
        KeyError: If a name in ``nonlins`` is not a key of ``nonlin_map``.
    """
    # ``None`` defaults avoid shared mutable ``[]`` defaults; copying also
    # lets callers pass tuples or other iterables.
    dims = [] if dims is None else list(dims)
    nonlins = [] if nonlins is None else list(nonlins)

    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not dims or len(dims) != len(nonlins):
        raise ValueError(
            f"dims and nonlins must be non-empty and of equal length; "
            f"got {len(dims)} dims and {len(nonlins)} nonlins"
        )

    unknown = [name for name in nonlins if name not in nonlin_map]
    if unknown:
        raise KeyError(
            f"unknown nonlinearity name(s) {unknown}; "
            f"expected one of {sorted(nonlin_map)}"
        )

    # Instantiate a fresh activation module per layer (nonlin_map holds classes).
    activations = [nonlin_map[name]() for name in nonlins]

    widths = [indim, *dims]
    net = torch.nn.Sequential()
    for i, (in_features, out_features, activation) in enumerate(
        zip(widths[:-1], widths[1:], activations)
    ):
        net.add_module(
            name=f"layer{i:d}",
            module=torch.nn.Sequential(
                torch.nn.Linear(in_features, out_features),
                activation,
            ),
        )
    return net

def module2functional(torch_net):
    """Wrap ``torch_net`` as a ``higher`` functional (monkey-patched) module.

    After wrapping:
      * the wrapper's internal ``_fast_params`` stack is replaced by ``[[]]``;
      * ``track_higher_grads`` is switched off;
      * every ``torch.nn.BatchNorm2d`` submodule has ``running_mean``,
        ``running_var`` and ``num_batches_tracked`` set to ``None``.

    Args:
        torch_net: The ``torch.nn.Module`` to convert.

    Returns:
        The functional module produced by ``higher.patch.make_functional``.
    """
    functional_net = higher.patch.make_functional(module=torch_net)

    # Overwrites higher's private parameter stack with one empty entry.
    # Presumably this means no parameters are bound to the wrapper, so callers
    # must supply them explicitly on each call. This relies on higher
    # internals — confirm against the installed version.
    functional_net._fast_params = [[]]

    # Tell higher not to track higher-order gradients through this module.
    functional_net.track_higher_grads = False

    for submodule in functional_net.modules():
        # NOTE(review): only BatchNorm2d is matched; BatchNorm1d/3d keep their
        # running statistics — confirm this is intended.
        if not isinstance(submodule, torch.nn.BatchNorm2d):
            continue
        # With these buffers set to None the layer keeps no running statistics,
        # so PyTorch's batch norm falls back to per-batch statistics.
        for buffer_name in ('running_mean', 'running_var', 'num_batches_tracked'):
            setattr(submodule, buffer_name, None)

    return functional_net
