import torch
from fomoh.hyperdual import HyperTensor as htorch
import torch.nn as nn
from fomoh.util import P_proj_control, projection_to_Hessian, P_proj_control_diag, projection_to_Hessian_diag
import math
from collections import OrderedDict
from fomoh.sample import mh_acc_step, get_directional_derivative

def update_or_add_key(ordered_dict, key, new_value):
    """
    Set ``ordered_dict[key]`` to ``new_value``.

    An existing key keeps its current position and only its value changes;
    a missing key is appended at the end of the mapping.

    Parameters:
    - ordered_dict: The OrderedDict to modify in place.
    - key: The key to update or add.
    - new_value: The value to store under ``key``.
    """
    # Plain item assignment already has exactly these semantics for
    # (Ordered)dicts: in-place update for present keys, append otherwise.
    ordered_dict[key] = new_value

class Model:
    """Base container for models evaluated with hyper-dual (``htorch``) parameters.

    Parameters are registered with :meth:`add_param`, which stores each one as
    an attribute (``self.<name>``) and, in registration order, in
    ``self.params``. That order defines the layout of the flat parameter vector
    used by :meth:`vec_to_params` / :meth:`params_to_vec` and the positional
    order in which subclasses consume parameters in their forward pass.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.params = OrderedDict([])
        self.named_params = []
        self.device = "cpu"
        self.training = True
        self.n_params = self.count_n_params()

    def add_param(self, name, params):
        """Register (or replace) parameter ``name`` holding tensor ``params``.

        The tensor is wrapped in ``torch.nn.Parameter(requires_grad=False)``.
        Replacing an existing name keeps its position in ``self.params``.
        ``n_params`` and ``named_params`` are refreshed afterwards.
        """
        setattr(self, name, torch.nn.Parameter(params, requires_grad=False))
        # Item assignment updates an existing key in place or appends a new one.
        self.params[name] = getattr(self, name)
        self.n_params = self.count_n_params()
        self.named_params = list(self.params.keys())

    def eval(self):
        """Switch to evaluation mode (read by subclasses, e.g. for dropout)."""
        self.training = False

    def train(self):
        """Switch to training mode (read by subclasses, e.g. for dropout)."""
        self.training = True

    def count_n_params(self):
        """Return the total number of scalar entries over all parameters.

        Returns ``0`` when nothing is registered, otherwise a 0-dim
        ``torch.long`` tensor.
        """
        n = 0
        for p in self.params.values():
            # dtype=torch.long: a 0-dim parameter has an empty shape, and a
            # default-dtype tensor would give a float product (1.0).
            n += torch.tensor(p.shape, dtype=torch.long).prod()
        return n

    def vec_to_params(self, vec):
        """Split flat ``vec`` into views shaped like each registered parameter.

        Slices are taken in registration order; returns a list of views.
        """
        vec_reshape = []
        i = 0
        for p in self.params.values():
            n = p.numel()  # int count, also correct for 0-dim parameters
            vec_reshape.append(vec[i:i+n].view(p.shape))
            i += n
        return vec_reshape

    def params_to_vec(self, v = None):
        """Flatten and concatenate tensors into one vector.

        Uses the registered parameters when ``v`` is None, otherwise the
        tensors in ``v`` (in the order given).
        """
        vec = []
        if v is None:
            for p in self.params.values():
                vec.append(p.reshape(-1))
        else:
            for p in v:
                vec.append(p.reshape(-1))
        return torch.cat(vec)

    def convert_params_to_htorch(self, v1, v2, requires_grad = False):
        """Wrap each parameter as ``htorch(param, eps1, eps2)``.

        ``v1`` / ``v2`` supply the per-parameter tangents (zipped in
        registration order). Note that ``requires_grad_`` mutates the flag on
        the registered parameters themselves.
        """
        hparams = []
        for p, eps1, eps2 in zip(self.params.values(), v1, v2):
            hparams.append(htorch(p.requires_grad_(requires_grad), eps1, eps2))
        return hparams

    def nn_module_to_htorch_model(self, model, verbose = True):
        """Copy weights from a ``torch.nn.Module`` into this model.

        Parameters are paired positionally (``zip`` stops at the shorter of the
        two sequences). 2-D weights are transposed (this model multiplies
        ``x @ W``); 1-D and 4-D tensors are copied as-is.

        Raises:
            NotImplementedError: for parameters of any other rank.
        """
        for (module_name, module_params), htorch_name  in zip(model.named_parameters(), self.named_params):
            if len(getattr(self, htorch_name).shape) == 2: # Linear weight
                assert(module_params.data.t().shape==getattr(self, htorch_name).shape)
                self.add_param(htorch_name, module_params.data.t())
            elif len(getattr(self, htorch_name).shape) == 4: # Conv2D layer
                assert(module_params.data.shape==getattr(self, htorch_name).shape)
                self.add_param(htorch_name, module_params.data)
            elif len(getattr(self, htorch_name).shape) == 1: # bias weight
                assert(module_params.data.shape==getattr(self, htorch_name).shape)
                self.add_param(htorch_name, module_params.data)
            else:
                raise NotImplementedError("Shape of weight means not implemented layer: " + htorch_name)
        if verbose:
            print("Weights transferred to htorch model.")

    def collect_nn_module_grads(self, model):
        """Return ``model``'s ``.grad`` tensors laid out like this model's parameters.

        Same positional pairing and rank handling as
        :meth:`nn_module_to_htorch_model`; 2-D gradients are transposed.
        """
        gradients = []
        for (module_name, module_params), htorch_name  in zip(model.named_parameters(), self.named_params):
            if len(getattr(self, htorch_name).shape) == 2: # Linear weight
                assert(module_params.grad.t().shape==getattr(self, htorch_name).shape)
                gradients.append(module_params.grad.t())
            elif len(getattr(self, htorch_name).shape) == 4: # Conv2D layer
                assert(module_params.grad.shape==getattr(self, htorch_name).shape)
                gradients.append(module_params.grad)
            elif len(getattr(self, htorch_name).shape) == 1: # bias weight
                assert(module_params.grad.shape==getattr(self, htorch_name).shape)
                gradients.append(module_params.grad)
            else:
                raise NotImplementedError("Shape of weight means not implemented layer: " + htorch_name)
        return gradients

    def sample_to_model(self, sample):
        '''
        Load a flat parameter vector ``sample`` into the model, re-registering
        every parameter with its slice of the vector.
        '''
        new_params = self.vec_to_params(sample)
        for name, p in zip(self.named_params, new_params):
            self.add_param(name, p)

    def to(self, device):
        """Move all parameters (and existing gradients) to ``device``; return ``self``.

        Storage is swapped in place through ``.data``. The same ``Parameter``
        objects therefore stay in ``self.params`` and in the per-name
        attributes, remain leaf tensors, and stay linked to any optimizer
        already built on them.
        """
        self.device = device
        for p in self.params.values():
            p.data = p.data.to(device)
            if p.grad is not None:
                p.grad.data = p.grad.data.to(device)
        return self
        

class LogisticRegressionModel(Model):
    """Multinomial logistic regression: ``logsoftmax(x @ W + b)`` over the last axis."""

    def __init__(self, input_dim, output_dim, bias = True):
        super().__init__()
        self.bias = bias
        # TODO: W and b are drawn from an unscaled standard normal; a proper
        # initializer is still needed.
        self.add_param("W", torch.randn(input_dim, output_dim))
        if bias:
            self.add_param("b", torch.randn(output_dim))
        self.n_params = self.count_n_params()

    def __call__(self, x, v1, v2=None, requires_grad=False):
        """Forward pass with hyper-dual parameters.

        ``v1``/``v2`` are per-parameter tangents for the two dual parts.
        ``v1=None`` means no tangents at all; ``v2=None`` reuses ``v1``.
        """
        if v1 is None:
            v1 = [None] * len(self.params)
            v2 = [None] * len(self.params)
        elif v2 is None:
            v2 = v1
        weight, *rest = self.convert_params_to_htorch(v1, v2, requires_grad=requires_grad)

        logits = x.matmul(weight)
        if self.bias:
            logits += rest[0]
        return logits.logsoftmax(-1)

class DenseModel(Model):
    """Fully-connected ReLU network over hyper-dual parameters.

    ``layers`` lists the layer widths, e.g. ``(1, 100, 1)`` gives one hidden
    layer of 100 units. Parameters are registered as ``W1, b1, W2, b2, ...``
    (biases only when ``bias`` is true). The output layer has no activation.
    """

    def __init__(self, layers = (1, 100, 1), bias = True):
        super(DenseModel, self).__init__()
        if len(layers) < 2:
            raise ValueError("layers needs at least an input and an output width")
        # Copy so neither a caller's list nor the default is aliased by the model.
        self.layers = list(layers)
        self.bias = bias
        widths = zip(self.layers[:-1], self.layers[1:])
        for n, (fan_in, fan_out) in enumerate(widths, start=1):
            self.add_param(f'W{n}', init_kaiming(fan_in, fan_out))
            if bias:
                self.add_param(f'b{n}', init_bias(fan_out))

        self.n_params = self.count_n_params()

    def __call__(self, x, v1, v2=None, requires_grad=False):
        """Forward pass; ``v1=None`` means no tangents, ``v2=None`` reuses ``v1``.

        Returns the raw output of the last layer (no softmax).
        """
        if v1 is None:
            v1 = [None for _ in self.params]
            v2 = [None for _ in self.params]
        elif v2 is None:
            v2 = v1
        params = self.convert_params_to_htorch(v1, v2, requires_grad=requires_grad)

        i = 0
        n_layers = len(self.layers) - 1
        for n in range(n_layers):
            x = x.matmul(params[i])
            i += 1
            if self.bias:
                x += params[i]
                i += 1
            if n < n_layers - 1:  # ReLU on hidden layers only
                x = x.relu()

        return x
    
class CNNModel(Model):
    """Conv2d/ReLU/max-pool feature extractor followed by a dense ReLU head.

    ``cnn_layers_channels`` lists conv channel counts (parameters ``CW1, Cb1,
    ...``). ``dense_layers`` lists dense widths (parameters ``W1, b1, ...``);
    its first entry must equal the flattened conv output size.
    ``maxpool_args`` is forwarded to ``maxpool2d`` after every conv layer.
    """

    def __init__(self, cnn_layers_channels = (1, 20, 50), cnn_filter_size = 5, dense_layers = (4*4*50, 500, 10), maxpool_args = (2, 2), bias = True):
        super(CNNModel, self).__init__()
        if len(cnn_layers_channels) < 2 or len(dense_layers) < 2:
            raise ValueError("cnn_layers_channels and dense_layers need at least two entries each")
        # Copies avoid aliasing caller-supplied (or default) sequences.
        self.cnn_channels = list(cnn_layers_channels)
        self.filter = cnn_filter_size
        self.dense_layers = list(dense_layers)
        self.bias = bias
        self.maxpool_args = list(maxpool_args)

        conv_pairs = zip(self.cnn_channels[:-1], self.cnn_channels[1:])
        for n, (c_in, c_out) in enumerate(conv_pairs, start=1):
            self.add_param(f'CW{n}', weight_init_conv2d(c_out, c_in, self.filter, self.filter))
            if bias:
                self.add_param(f'Cb{n}', bias_init_conv2d(c_out, c_in, self.filter, self.filter))

        dense_pairs = zip(self.dense_layers[:-1], self.dense_layers[1:])
        for n, (fan_in, fan_out) in enumerate(dense_pairs, start=1):
            self.add_param(f'W{n}', init_kaiming(fan_in, fan_out))
            if bias:
                self.add_param(f'b{n}', init_bias(fan_out))

        self.n_params = self.count_n_params()

    def __call__(self, x, v1, v2=None, requires_grad=False):
        """Forward pass; ``v1=None`` means no tangents, ``v2=None`` reuses ``v1``.

        Returns the raw output of the last dense layer (no softmax).
        """
        if v1 is None:
            v1 = [None for _ in self.params]
            v2 = [None for _ in self.params]
        elif v2 is None:
            v2 = v1
        params = self.convert_params_to_htorch(v1, v2, requires_grad=requires_grad)

        i = 0
        for n in range(len(self.cnn_channels) - 1):
            x = x.conv2d(params[i])
            i += 1
            if self.bias:
                x += params[i].view(1, self.cnn_channels[n+1], 1, 1)
                i += 1
            x = x.relu()
            x = x.maxpool2d(*self.maxpool_args)

        x = x.view(-1, self.dense_layers[0])

        n_dense = len(self.dense_layers) - 1
        for n in range(n_dense):
            x = x.matmul(params[i])
            i += 1
            if self.bias:
                x += params[i]
                i += 1
            if n < n_dense - 1:  # ReLU on hidden dense layers only
                x = x.relu()

        return x
    
 
class CNN_CIFAR10(Model):
    """Six-conv / two-dense image classifier over hyper-dual parameters.

    Three stages of two (3x3 conv, padding 1 -> channel bias -> batchnorm ->
    ReLU) layers, each stage followed by a 2x2 max-pool. The features are then
    flattened to 128*4*4 and passed through a ReLU dense layer and a linear
    output layer. Optional dropout precedes each dense layer. Parameters come
    from ``cnn_cifar10_params()`` and are consumed positionally in
    registration order.
    """

    def __init__(self, dropout = True):
        super(CNN_CIFAR10, self).__init__()
        initial = cnn_cifar10_params()
        self.dropout = dropout

        for name, tensor in initial.items():
            self.add_param(name, tensor)

        self.n_params = self.count_n_params()

    def block(self, x, params, count):
        """Apply two conv/bias/batchnorm/ReLU layers and a 2x2 max-pool.

        Consumes 8 parameters starting at ``count``; returns ``(x, count)``.
        """
        for _ in range(2):
            weight, bias, bn_weight, bn_bias = params[count:count + 4]
            x = x.conv2d(weight, padding=1)
            x = x + bias.view(1, -1, 1, 1)
            x = batchnorm2d(x, bn_weight, bn_bias)
            x = x.relu()
            count += 4
        x = x.maxpool2d(kernel_size=2, stride=2)
        return x, count

    def _maybe_dropout(self, x):
        # Dropout (p=0.5) only when enabled; train/test variant follows self.training.
        if not self.dropout:
            return x
        return dropout_train(x, 0.5) if self.training else dropout_test(x, 0.5)

    def __call__(self, x, v1, v2=None, requires_grad=False):
        """Forward pass; ``v1=None`` means no tangents, ``v2=None`` reuses ``v1``."""
        if v1 is None:
            v1 = [None for _ in self.params]
            v2 = [None for _ in self.params]
        elif v2 is None:
            v2 = v1
        params = self.convert_params_to_htorch(v1, v2, requires_grad=requires_grad)

        count = 0
        for _ in range(3):
            x, count = self.block(x, params, count)

        x = x.view(-1, 128*4*4)
        x = self._maybe_dropout(x)

        hidden_w, hidden_b, out_w, out_b = params[count:count + 4]
        x = x.matmul(hidden_w)
        x = x + hidden_b
        x = x.relu()
        x = self._maybe_dropout(x)

        x = x.matmul(out_w)
        x = x + out_b
        return x



class VGG16_CIFAR10(Model):
    """VGG-16 style network (13 conv layers + 3 dense layers) over hyper-dual parameters.

    Parameters come from ``vgg16_params(512)`` (first dense layer takes 512
    features) and are consumed positionally in registration order. Dropout
    (p=0.5) follows both hidden dense layers; the output is raw logits.
    """

    def __init__(self):
        super(VGG16_CIFAR10, self).__init__()
        for name, tensor in vgg16_params(512).items():
            self.add_param(name, tensor)

        self.n_params = self.count_n_params()

    def block(self, x, params, count, layers=2, maxpool = True):
        """Apply 2 (or 3 when ``layers == 3``) conv/bias/ReLU layers.

        Each layer uses a padding-1 conv and consumes 2 parameters starting at
        ``count``. A 2x2 max-pool follows when ``maxpool`` is true. Returns
        ``(x, count)`` with ``count`` advanced past the consumed parameters.
        """
        # Any value other than 3 behaves like the default of two convolutions.
        n_convs = 3 if layers == 3 else 2
        for _ in range(n_convs):
            weight, bias = params[count], params[count + 1]
            x = x.conv2d(weight, padding=1)
            x = x + bias.view(1, -1, 1, 1)
            x = x.relu()
            count += 2
        if maxpool:
            x = x.maxpool2d(kernel_size=2, stride=2)
        return x, count

    def _dense_dropout(self, x):
        # Train/test dropout variant selected by self.training.
        return dropout_train(x, 0.5) if self.training else dropout_test(x, 0.5)

    def __call__(self, x, v1, v2=None, requires_grad=False):
        """Forward pass; ``v1=None`` means no tangents, ``v2=None`` reuses ``v1``."""
        if v1 is None:
            v1 = [None for _ in self.params]
            v2 = [None for _ in self.params]
        elif v2 is None:
            v2 = v1
        params = self.convert_params_to_htorch(v1, v2, requires_grad=requires_grad)

        count = 0
        for _ in range(2):
            x, count = self.block(x, params, count, layers=2)
        for stage in range(3):
            # Only the first three-conv block pools.
            x, count = self.block(x, params, count, layers=3, maxpool=(stage == 0))

        # NOTE(review): there are four 2x2 pools in total: after blocks 1-3 and
        # once more here. For 32x32 inputs that leaves a 2x2x512 map, which
        # view(-1, 512) would fold into the batch dimension. Confirm the
        # intended input size / pooling schedule.
        x = x.maxpool2d(kernel_size=2, stride=2, padding=0)
        x = x.view(-1, 512)

        for _ in range(2):
            x = x.matmul(params[count])
            x = x + params[count + 1]
            x = x.relu()
            x = self._dense_dropout(x)
            count += 2

        x = x.matmul(params[count])
        x = x + params[count + 1]
        return x
    
class Resnet18(Model):
    """ResNet-18 built from hyper-dual (``htorch``) parameters.

    Parameters come from ``resnet18_params()`` and are registered in that
    dict's order. ``__call__`` and ``block`` consume them positionally via a
    running ``count`` index, so the forward pass depends on that ordering.
    """
    def __init__(self):
        super(Resnet18, self).__init__()
        params = resnet18_params()
        
        for key, value in params.items():
            self.add_param(key, value)
               
        self.n_params = self.count_n_params()
        
    def block(self, x, params, i, count):
        """Apply residual stage ``conv{i}``: two residual units of two 3x3 convs each.

        Args:
            x: input activation (``htorch`` value).
            params: ``htorch`` parameters in registration order.
            i: stage index; parameter names starting with ``conv{i}`` belong
                to this stage.
            count: index in ``params`` of this stage's first parameter.

        Returns:
            ``(x, count)``, with ``count`` advanced past the consumed parameters:
            12, or 15 when the stage has a downsample shortcut.
        """
        # get keys:
        conv_keys = [s for s in self.named_params if s.startswith(f'conv{i}')]
        
        xprev = x
        # If the stage has "downsample" parameters, the shortcut becomes a
        # stride-2 conv + batchnorm (3 params, consumed first), and the unit's
        # first conv also uses stride 2 so both branches match in shape.
        if any("downsample" in s for s in conv_keys):
            stride = 2
            xprev = xprev.conv2d(params[count], stride=2)
            xprev = batchnorm2d(xprev, params[count+1], params[count+2])
            count += 3
        else:
            stride = 1
        
        # First residual unit: conv-bn-relu-conv-bn, add shortcut, relu (6 params).
        x = x.conv2d(params[count], stride=stride, padding=1)
        x = batchnorm2d(x, params[count+1], params[count+2])
        x = x.relu()
        x = x.conv2d(params[count+3], padding=1)
        x = batchnorm2d(x, params[count+4], params[count+5])
        x = x + xprev
        x = x.relu()
        
        count += 6

        # Second residual unit: identity shortcut (6 params).
        xprev = x
        x = x.conv2d(params[count], padding=1)
        x = batchnorm2d(x, params[count+1], params[count+2])
        x = x.relu()
        x = x.conv2d(params[count+3], padding=1)
        x = batchnorm2d(x, params[count+4], params[count+5])
        x = x + xprev
        x = x.relu()
        
        count += 6
        
        return x, count
            
    
    def __call__(self, x, v1, v2=None, requires_grad=False):
        """Forward pass returning raw logits (no softmax is applied).

        ``v1``/``v2`` are per-parameter tangents for the two dual parts.
        ``v1=None`` means no tangents; ``v2=None`` reuses ``v1``.
        """
        if v1 is None:
            v1 = [None for _ in self.params]
            v2 = [None for _ in self.params]
        elif v2 is None:
            v2 = v1 #[v.clone() for v in v1]
        params = self.convert_params_to_htorch(v1,v2, requires_grad=requires_grad)
        
        # Stem: stride-2 conv (padding 3) + batchnorm + relu + 3x3/stride-2 max-pool.
        count = 0
        x = x.conv2d(params[count], stride=2, padding=3)
        x = batchnorm2d(x, params[count+1], params[count+2])
        x = x.relu()
        x = x.maxpool2d(3, stride=2, padding=1)
        count += 3
        
        # Residual stages conv2 .. conv5.
        for i in range(2,6):
            x, count = self.block(x, params, i, count)
    
        # Global average pool, flatten to 512 features, final linear layer.
        x = x.adaptiveavgpool2d((1, 1))
        x = x.view(-1, 512)
        x = x.matmul(params[count])
        x = x + params[count+1]

        return x


def init_uniform(shape, k):
    """Sample a tensor of ``shape`` uniformly from ``[-k, k)``."""
    noise = torch.rand(shape)
    return noise * 2 * k - k

def init_bias(out_features):
    """Bias vector of length ``out_features``, uniform in ``[-1/sqrt(out_features), 1/sqrt(out_features))``.

    NOTE(review): torch.nn.Linear derives this bound from the fan-in; here it
    comes from ``out_features``. Confirm this is intended.
    """
    bound = 1 / math.sqrt(out_features)
    return init_uniform(out_features, bound)

def init_kaiming(in_features, out_features):
    """Gaussian ``(in_features, out_features)`` weight matrix (used as ``x @ W``).

    The standard deviation is ``sqrt(2 / ((1 + a^2) * in_features))`` with
    ``a = sqrt(5)``. That equals ``1 / sqrt(3 * in_features)``, the same
    variance as ``kaiming_uniform_(a=sqrt(5))``.
    """
    slope = math.sqrt(5.)
    std = math.sqrt(2. / ((1. + slope * slope) * in_features))
    return std * torch.randn(in_features, out_features)

def weight_init_conv2d(out_channels, in_channels, kernel_height, kernel_width):
    """Conv2d weight ``(out, in, kh, kw)`` initialised like ``torch.nn.Conv2d``.

    Uses ``kaiming_uniform_(a=sqrt(5))``, i.e. uniform in
    ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
    """
    weight = torch.empty(out_channels, in_channels, kernel_height, kernel_width)
    return nn.init.kaiming_uniform_(weight, a=math.sqrt(5))

def bias_init_conv2d(out_channels, in_channels, kernel_height, kernel_width):
    """Conv2d bias of length ``out_channels``, uniform in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.

    ``fan_in = in_channels * kernel_height * kernel_width``, matching the
    ``torch.nn.Conv2d`` default bias initialisation.
    """
    fan_in = in_channels * kernel_height * kernel_width
    limit = 1 / math.sqrt(fan_in)
    return nn.init.uniform_(torch.zeros(out_channels), -limit, limit)

def nll_loss(input, target, reduce = "mean"):
    """Negative log-likelihood of ``target`` classes under log-probabilities ``input``.

    Args:
        input: 2-D ``(n, classes)`` log-probabilities (``htorch`` or tensor
            exposing ``.real``).
        target: 1-D class indices (its ``.real`` is read).
        reduce: ``"mean"`` divides by ``n``; any other value returns the sum.

    Raises:
        ValueError: if the input is not 2-D or the target is not 1-D.
    """
    if input.real.dim() != 2 or target.real.dim() != 1:
        raise ValueError('Expecting 2d input and 1d target')
    labels = target.real
    n_rows = input.real.shape[0]
    total = 0.
    for row in range(n_rows):
        total -= input[row, int(labels[row])]
    return total / n_rows if reduce == "mean" else total

def dropout_train(a, prob=0.5):
    """Inverted dropout: zero each entry with probability ``prob``, scale survivors by ``1/(1-prob)``.

    The Bernoulli mask is drawn directly on ``a.device``, with no host-to-device
    copy. ``prob == 1`` returns all zeros instead of the NaNs that ``0/0``
    would produce.
    """
    keep = 1 - prob
    if keep == 0:
        # Everything is dropped; skip the 1/keep rescaling (division by zero).
        return a * torch.zeros(a.shape, device=a.device)
    mask = torch.bernoulli(torch.full(a.shape, keep, device=a.device))
    return a * mask / keep

def dropout_test(a, prob=0.5):
    """Evaluation-time dropout: the identity.

    ``dropout_train`` already rescales kept units by ``1/(1-prob)``, so no
    scaling is needed here; ``prob`` is accepted only for a matching signature.
    """
    output = a
    return output
    
def mean(a, dim, keepdim=False):
    """Arithmetic mean of ``a`` along ``dim``, computed as ``sum / size``.

    Written with ``sum`` and division only, so it also applies to ``htorch``
    values.
    """
    count = a.shape[dim]
    total = a.sum(dim, keepdim=keepdim)
    return total / count

def variance(a, dim, unbiased, keepdim=False):
    """Variance of ``a`` along ``dim`` using the two-pass algorithm.

    See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
    The squared deviations are divided by ``N - 1`` when ``unbiased``,
    otherwise by ``N``.
    """
    size = a.shape[dim]
    denom = size - 1 if unbiased else size
    centered = a - mean(a, dim=dim, keepdim=True)
    return (centered ** 2).sum(dim=dim, keepdim=keepdim) / denom

def batchnorm2d(a, weight, bias, eps=1e-05):
    """Batch-normalise a ``B x C x H x W`` input with per-channel affine ``weight``/``bias``.

    Statistics are always computed from the current batch: biased variance
    over the B, H and W axes of each channel. No running estimates are kept.

    Raises:
        ValueError: if ``a`` is not 4-D.
    """
    if len(a.shape) != 4:
        raise ValueError('Expecting a 4d tensor with shape BxCxHxW')
    channels = a.shape[1]
    bcast = (1, channels, 1, 1)
    per_channel = a.transpose(0, 1).reshape(channels, -1)
    mu = mean(per_channel, dim=1)
    var = variance(per_channel, dim=1, unbiased=False)
    normed = (a - mu.view(*bcast)) / ((var.view(*bcast) + eps) ** 0.5 )
    return normed * weight.view(*bcast) + bias.view(*bcast)

def update_model_gradients_(model, external_derivatives):
    """Overwrite ``.grad`` of each registered parameter, pairing in ``model.params`` order."""
    for param, grad in zip(model.params.values(), external_derivatives):
        param.grad = grad

def update_model_parameters(params, update_term):
    """Add ``update_term[i]`` in place to the i-th parameter.

    ``params`` may be an iterable of tensors or a dict such as
    ``model.params``. For a dict its values are used, because iterating a dict
    yields the key strings. The update runs under ``torch.no_grad()`` so leaf
    parameters with ``requires_grad=True`` can be modified in place.
    """
    if isinstance(params, dict):
        params = params.values()
    with torch.no_grad():
        for param, p_up in zip(params, update_term):
            param += p_up

def tangent_dropout_mask(x, p=0.1):
    """Boolean mask shaped like ``x``; each entry is False with probability ``p``."""
    draws = torch.rand_like(x)
    return draws.gt(p)
    

def optimizer_step(model, loss_module, optimizer, n_sample_directions, input, target, device = "cpu", clip_value = 0.0, hess = False, epsilon = 10e-7, tangent_dropout = 0.0, backprop = False):
    """Take one optimizer step using forward-mode directional-derivative gradient estimates.

    For each of ``n_sample_directions`` tangent directions ``v``, the model is
    evaluated with hyper-dual parameters. The estimate accumulates
    ``loss.eps1 * v``, divided by ``|loss.eps1eps2| + epsilon`` when ``hess``.
    The average (optionally clamped to ``+/-clip_value``) is written into the
    parameters' ``.grad`` and ``optimizer.step()`` is called.

    Directions are random normal vectors (optionally masked by
    ``tangent_dropout``), or the backprop gradient when ``backprop`` is true,
    which requires ``n_sample_directions == 1``.

    Returns:
        ``(loss, pred_real)``: the loss object of the last evaluation and the
        ``.real`` part of its prediction.

    Raises:
        ValueError: if ``n_sample_directions < 1``.
    """
    if n_sample_directions < 1:
        raise ValueError("n_sample_directions must be at least 1")
    if backprop:
        assert(n_sample_directions==1)

    directional_derivative = torch.zeros(model.n_params).to(device)
    ### Might be able to parallelize this!
    for _ in range(n_sample_directions):
        if backprop:
            # Use the exact (reverse-mode) gradient as the tangent direction.
            for p in model.params.values():
                p.grad = None
            pred = model(input, None, requires_grad=True)
            out = loss_module(pred, target)
            out.real.backward()
            param_directions = model.params_to_vec([p.grad for p in model.params.values()])
        else:
            param_directions = torch.randn(model.n_params).to(device)
            if tangent_dropout != 0.0:
                param_directions *= tangent_dropout_mask(param_directions, p=tangent_dropout)
        param_directions_reshaped = model.vec_to_params(param_directions)
        pred = model(input, param_directions_reshaped)
        out = loss_module(pred, target)
        if hess:
            directional_derivative += out.eps1.item() * param_directions / (abs(out.eps1eps2.item()) + epsilon)
        else:
            directional_derivative += out.eps1.item() * param_directions

    # The real part of the loss is the same for every direction (only tangents change).
    loss = out
    directional_derivative /= n_sample_directions

    optimizer.zero_grad()

    if clip_value != 0.0:
        directional_derivative = torch.clamp(directional_derivative, -abs(clip_value), abs(clip_value))
    grads = model.vec_to_params(directional_derivative)

    update_model_gradients_(model, grads)
    optimizer.step()
    return loss, pred.real

def evaluate_model(input, target, param_directions, model, loss_module, device=None):
    """Evaluate ``loss_module(model(input, tangents), target)`` along one tangent direction.

    Args:
        param_directions: flat tangent vector, split per parameter with
            ``model.vec_to_params``.
        device: optional; when given, ``param_directions`` is moved there first.
            Accepting it keeps calls such as ``evaluate_model(..., device)``
            working.

    Returns:
        ``(out, pred)``: the loss output and the model prediction.
    """
    if device is not None:
        param_directions = param_directions.to(device)
    param_directions_reshaped = model.vec_to_params(param_directions)
    pred = model(input, param_directions_reshaped)
    out = loss_module(pred, target)
    return out, pred


def optimizer_step_mh(model, loss_module, optimizer, n_sample_directions, input, target, loss_prev=None, tangent=None, pred_prev = None, device = "cpu", clip_value = 0.0, hess = False, epsilon = 10e-7):
    """Optimizer step with a Metropolis-Hastings-style accept/reject test.

    A directional-derivative gradient estimate is built from random tangents.
    When ``loss_prev``/``tangent`` come from the previous call, that
    evaluation is reused as one of the directions. The estimate drives one
    ``optimizer.step()``. The loss is then re-evaluated with a fresh tangent
    and ``mh_acc_step(old, new)`` decides acceptance. On rejection the
    parameters are restored exactly from a pre-step snapshot.

    Returns:
        ``(loss_value, pred_real, accepted, loss_next, tangent_next, pred_next)``.
        On rejection the last three are ``(loss, None, pred)``.

    Raises:
        ValueError: if ``n_sample_directions < 1``.
    """
    if n_sample_directions < 1:
        raise ValueError("n_sample_directions must be at least 1")

    def get_grad(n_directions):
        # Average directional-derivative estimate over n_directions random tangents.
        directional_derivative = torch.zeros(model.n_params).to(device)
        for _ in range(n_directions):
            param_directions = torch.randn(model.n_params).to(device)
            out, pred = evaluate_model(input, target, param_directions, model, loss_module)
            directional_derivative += get_directional_derivative(out, param_directions, hess, epsilon)
        # The real part of the loss is the same for every direction (only tangents change).
        directional_derivative /= n_directions ### Moving average might be better if we get overflow
        return out, directional_derivative, pred

    if loss_prev is None or tangent is None:
        loss, directional_derivative, pred = get_grad(n_sample_directions)
    elif n_sample_directions == 1:
        loss = loss_prev
        directional_derivative = get_directional_derivative(loss_prev, tangent, hess, epsilon)
        pred = pred_prev
    else:
        directional_derivative_1 = get_directional_derivative(loss_prev, tangent, hess, epsilon)
        loss, directional_derivative, pred = get_grad(n_sample_directions - 1)
        directional_derivative = ((directional_derivative * (n_sample_directions - 1)) + directional_derivative_1)/n_sample_directions

    optimizer.zero_grad()

    if clip_value != 0.0:
        directional_derivative = torch.clamp(directional_derivative, -abs(clip_value), abs(clip_value))
    grads = model.vec_to_params(directional_derivative)

    # Snapshot parameters so a rejected step can be undone exactly. A
    # negative-gradient step only reverses plain SGD, not momentum/Adam.
    snapshot = [p.detach().clone() for p in model.params.values()]

    update_model_gradients_(model, grads)
    optimizer.step()

    # updated output
    tangent_next = torch.randn(model.n_params).to(device)
    loss_next, pred_next = evaluate_model(input, target, tangent_next, model, loss_module)

    if mh_acc_step(loss.real, loss_next.real):
        return loss.real.item(), pred.real, 1, loss_next, tangent_next, pred_next

    print("Assert SGD")
    # NOTE(review): the optimizer's internal state (e.g. momentum buffers) still
    # reflects the rejected step; only the parameter values are restored.
    with torch.no_grad():
        for p, saved in zip(model.params.values(), snapshot):
            p.copy_(saved)
    optimizer.zero_grad()
    return loss.real.item(), pred.real, 0, loss, None, pred


def newton_step(model, loss_module, n_sample_directions, input, target, lr=1.0, control = 0., epsilon = 1e-5, beta = 1.0, device = "cpu"):
    """Apply one stochastic full-Hessian Newton update to ``model`` in place.

    Random tangents give a directional-derivative gradient estimate ``g``
    and, via ``P_proj_control``/``projection_to_Hessian``, a Hessian estimate
    ``H``. ``epsilon * I`` is added as jitter. The update is
    ``-lr * (beta * H^{-1} g + (1 - beta) * g)``.

    Returns:
        float: the real loss averaged over the sampled directions.

    Raises:
        ValueError: if ``n_sample_directions < 1``.
    """
    if n_sample_directions < 1:
        raise ValueError("n_sample_directions must be at least 1")
    loss = 0.
    directional_derivative = torch.zeros(model.n_params).to(device)
    projection_outer_product = torch.zeros(model.n_params, model.n_params).to(device)
    for _ in range(n_sample_directions):
        param_directions = torch.randn(model.n_params).to(device)
        param_directions_reshaped = model.vec_to_params(param_directions)
        pred = model(input, param_directions_reshaped)
        out = loss_module(pred, target)
        loss += out.real.item()
        directional_derivative += out.eps1.item() * param_directions
        projection_outer_product += P_proj_control(param_directions, out.eps1eps2.item(), c = control)

    loss /= n_sample_directions
    directional_derivative /= n_sample_directions
    projection_outer_product /= n_sample_directions
    # Jitter keeps the estimate away from singularity.
    projection_outer_product += torch.eye(model.n_params).to(device) * epsilon

    ### Full Newton's step:
    H_tilde = projection_to_Hessian(projection_outer_product)
    grad = directional_derivative.view(-1, 1)
    # Solve H d = g instead of forming inv(H): cheaper and numerically stabler.
    newton_direction = torch.linalg.solve(H_tilde, grad)
    t = - lr * (beta * newton_direction + (1. - beta) * grad)

    additive_update = model.vec_to_params(t)

    # Pass the dict's values: iterating model.params itself yields key strings.
    # no_grad allows in-place updates of parameters that require grad.
    with torch.no_grad():
        update_model_parameters(model.params.values(), additive_update)

    return loss

def newton_step_diag(model, loss_module, n_sample_directions, input, target, lr=1.0, control = 0., epsilon = 1e-5, beta = 1.0, device = "cpu"):
    """Apply one stochastic diagonal-Hessian Newton update to ``model`` in place.

    Same estimator as ``newton_step``, but only a diagonal Hessian estimate
    is built (``P_proj_control_diag``/``projection_to_Hessian_diag``). The
    update is ``-lr * (beta * g / diag(H) + (1 - beta) * g)``.
    ``epsilon`` is currently unused: the jitter line is disabled.

    Returns:
        float: the real loss averaged over the sampled directions.

    Raises:
        ValueError: if ``n_sample_directions < 1``.
    """
    if n_sample_directions < 1:
        raise ValueError("n_sample_directions must be at least 1")
    loss = 0.
    directional_derivative = torch.zeros(model.n_params).to(device)
    projection_outer_product = torch.zeros(model.n_params).to(device)
    for _ in range(n_sample_directions):
        param_directions = torch.randn(model.n_params).to(device)
        param_directions_reshaped = model.vec_to_params(param_directions)
        pred = model(input, param_directions_reshaped)
        out = loss_module(pred, target)
        loss += out.real.item()
        directional_derivative += out.eps1.item() * param_directions
        projection_outer_product += P_proj_control_diag(param_directions, out.eps1eps2.item(), c = control)

    loss /= n_sample_directions
    directional_derivative /= n_sample_directions
    projection_outer_product /= n_sample_directions

    ### Diag Newton's step:
    H_tilde_diag = projection_to_Hessian_diag(projection_outer_product)
    grad = directional_derivative.view(-1, 1)
    t = - lr * (beta * grad / H_tilde_diag.view(-1, 1) + (1. - beta) * grad)

    additive_update = model.vec_to_params(t)

    # Pass the dict's values: iterating model.params itself yields key strings.
    # no_grad allows in-place updates of parameters that require grad.
    with torch.no_grad():
        update_model_parameters(model.params.values(), additive_update)

    return loss


def resnet18_params(device="cpu"):
    """Initial parameters for ``Resnet18``, as a dict in forward-pass order.

    Each conv has a weight ``<prefix>_w`` and batchnorm scale/shift
    ``<prefix>_bn_w`` (ones) / ``<prefix>_bn_b`` (zeros). The layout is:

    - a 7x7 stem conv (3 -> 64);
    - four stages ``conv2`` .. ``conv5`` of two residual units (``_1a``/``_1b``,
      ``_2a``/``_2b``) with 3x3 convs. When a stage widens the channels, a 1x1
      ``_1downsample`` conv comes first;
    - a final ``512 -> 10`` linear layer (``fc1_w``, ``fc1_b``).

    Initialisers are called in key order, so seeded results are reproducible.
    """
    params = {}

    def add_conv_bn(prefix, out_ch, in_ch, kernel):
        params[f'{prefix}_w'] = weight_init_conv2d(out_ch, in_ch, kernel, kernel).to(device)
        params[f'{prefix}_bn_w'] = torch.ones(out_ch).to(device)
        params[f'{prefix}_bn_b'] = torch.zeros(out_ch).to(device)

    add_conv_bn('conv1', 64, 3, 7)

    # (stage, input channels, output channels)
    for stage, in_ch, out_ch in ((2, 64, 64), (3, 64, 128), (4, 128, 256), (5, 256, 512)):
        if in_ch != out_ch:
            add_conv_bn(f'conv{stage}_1downsample', out_ch, in_ch, 1)
        add_conv_bn(f'conv{stage}_1a', out_ch, in_ch, 3)
        add_conv_bn(f'conv{stage}_1b', out_ch, out_ch, 3)
        add_conv_bn(f'conv{stage}_2a', out_ch, out_ch, 3)
        add_conv_bn(f'conv{stage}_2b', out_ch, out_ch, 3)

    params['fc1_w'] = init_kaiming(512, 10).to(device)
    params['fc1_b'] = init_bias(10).to(device)
    return params

def cnn_cifar10_params(device="cpu"):
    """Initial parameters for ``CNN_CIFAR10``, as a dict in forward-pass order.

    There are six 3x3 convs (3->32->32->64->64->128->128). Each has a weight,
    a bias and batchnorm scale (ones) / shift (zeros). They are followed by
    dense layers ``fc14`` (128*4*4 -> 1024) and ``fc15`` (1024 -> 10).
    Initialisers are called in key order, so seeded results are reproducible.
    """
    channels = [3, 32, 32, 64, 64, 128, 128]
    params = {}
    for idx, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:]), start=1):
        params[f'conv{idx}_w'] = weight_init_conv2d(c_out, c_in, 3, 3).to(device)
        params[f'conv{idx}_b'] = bias_init_conv2d(c_out, c_in, 3, 3).to(device)
        params[f'conv{idx}_bn_w'] = torch.ones(c_out).to(device)
        params[f'conv{idx}_bn_b'] = torch.zeros(c_out).to(device)

    params['fc14_w'] = init_kaiming(128*4*4, 1024).to(device)
    params['fc14_b'] = init_bias(1024).to(device)
    params['fc15_w'] = init_kaiming(1024, 10).to(device)
    params['fc15_b'] = init_bias(10).to(device)
    return params

def vgg16_params(reshape=25088, device="cpu"):
    """Initial parameters for ``VGG16_CIFAR10``, as a dict in forward-pass order.

    There are thirteen 3x3 convs (``conv1`` .. ``conv13``, each a weight plus
    bias), with channels 3->64->64->128->128->256->256->256->512 and then
    512->512 five times. They are followed by dense layers ``fc14``
    (``reshape`` -> 4096), ``fc15`` (4096 -> 4096) and ``fc16`` (4096 -> 10).
    Initialisers are called in key order, so seeded results are reproducible.
    """
    channels = [3, 64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512]
    params = {}
    for idx, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:]), start=1):
        params[f'conv{idx}_w'] = weight_init_conv2d(c_out, c_in, 3, 3).to(device)
        params[f'conv{idx}_b'] = bias_init_conv2d(c_out, c_in, 3, 3).to(device)

    dense = [(14, reshape, 4096), (15, 4096, 4096), (16, 4096, 10)]
    for idx, fan_in, fan_out in dense:
        params[f'fc{idx}_w'] = init_kaiming(fan_in, fan_out).to(device)
        params[f'fc{idx}_b'] = init_bias(fan_out).to(device)
    return params