import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import time
import copy
import numpy as np


#only works for ResNet
def define_struct_to_prune(net, mod_names):
    """Resolve dotted parameter names into ``(module, attribute)`` pairs.

    Every entry of ``mod_names`` (for example ``"layer1.0.conv1.weight"``) is
    split at its last dot.  The part before the dot is looked up in
    ``net.named_modules()`` and the part after the dot is kept as the
    attribute name.  A name without any dot resolves to the module that is
    registered under the empty name.

    Args:
        net: Object exposing ``named_modules()``.
        mod_names: Iterable of dotted names.

    Returns:
        A tuple of ``(module, attribute_name)`` tuples, in the order of
        ``mod_names``.

    Raises:
        KeyError: If the module prefix of a name is not registered in ``net``.
    """
    # Build the name -> module lookup once instead of once per requested name.
    modules_by_name = dict(net.named_modules())
    structs_to_prune_iterative = []
    for mod_name in mod_names:
        parent_name, _, attr_name = mod_name.rpartition('.')
        structs_to_prune_iterative.append((modules_by_name[parent_name], attr_name))
    return tuple(structs_to_prune_iterative)


def calc_first_order_unstructured(model, trainloader, num_parameters, layer_names, device, samples_N=100000):
    """Estimate the element-wise first-order term ``gradient * weight``.

    The dataset behind ``trainloader`` is re-iterated in order with batches of
    100 samples.  For every batch the cross-entropy loss is back-propagated
    and, for each parameter named in ``layer_names``, the element-wise product
    of gradient and weight is flattened and concatenated (in ``layer_names``
    order).  Because the loss is a batch mean, each batch contribution is
    weighted by the number of samples actually contained in that batch.  The
    accumulated sum is finally divided by the full dataset length.

    Args:
        model: Network whose ``named_parameters()`` contain ``layer_names``.
        trainloader: Object with a ``dataset`` attribute; only the dataset is
            used, the loader's own batching is ignored.
        num_parameters: Unused; kept for interface compatibility.
        layer_names: Names of the parameters to evaluate.
        device: Device the inputs and labels are moved to.
        samples_N: Stop once at least this many samples have been processed
            (at least one batch is always processed).

    Returns:
        1-D numpy array holding the averaged ``gradient * weight`` values.
    """
    total_number_of_train_datapoints = len(trainloader.dataset)
    starttime = time.time()
    gradient_batch_size = 100
    criterion = nn.CrossEntropyLoss()
    dataloader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=gradient_batch_size,
                                             shuffle=False, num_workers=4)
    # The parameter objects do not change between batches: look them up once.
    net_parameters = dict(model.named_parameters())
    first_order = None
    for i, t_data in enumerate(dataloader, 0):
        if i % 100 == 0 and i > 0:
            print('Calculated up to datapoint', i*gradient_batch_size)
            print('Time needed: ', time.time() - starttime)
        # Check the sample budget before spending a forward/backward pass.
        if i > 0 and i * gradient_batch_size >= samples_N:
            break

        data_in, label_in = t_data[0].to(device), t_data[1].to(device)
        model.zero_grad()
        criterion(model(data_in), label_in).backward()

        g_k = np.concatenate([
            np.multiply(net_parameters[l_name].grad.cpu().detach().numpy().reshape(-1),
                        net_parameters[l_name].cpu().detach().numpy().reshape(-1))
            for l_name in layer_names
        ], axis=0)

        # The loss is averaged over the batch, so scale by the real batch
        # length (the last batch may hold fewer than gradient_batch_size).
        weighted = g_k * int(data_in.shape[0])
        if first_order is None:
            first_order = weighted
        else:
            first_order += weighted

    first_order = first_order * float(1 / total_number_of_train_datapoints)

    print('Time needed to calculate the first order: ', time.time() - starttime)
    return first_order



def calc_first_order(model, trainloader, num_strucs, layer_names, device, samples_N=100000):
    """Estimate the per-structure first-order term as a diagonal matrix.

    The dataset behind ``trainloader`` is re-iterated in order with batches of
    100 samples.  For every batch the cross-entropy loss is back-propagated
    and, for each parameter named in ``layer_names``, one value per index of
    the parameter's first dimension is computed: the dot product of that
    slice's gradient with that slice's weights.  Values of all layers are
    concatenated in ``layer_names`` order, weighted by the real batch length
    (the loss is a batch mean), summed over batches and divided by the full
    dataset length.

    Args:
        model: Network whose ``named_parameters()`` contain ``layer_names``.
        trainloader: Object with a ``dataset`` attribute; only the dataset is
            used, the loader's own batching is ignored.
        num_strucs: Total number of structures, i.e. the sum of the first
            dimensions of all selected parameters.
        layer_names: Names of the parameters to evaluate.
        device: Device the inputs and labels are moved to.
        samples_N: Stop once at least this many samples have been processed
            (at least one batch is always processed).

    Returns:
        ``(num_strucs, num_strucs)`` float64 numpy array with the averaged
        values on its diagonal and zeros elsewhere.

    Raises:
        ValueError: If the selected layers do not provide exactly
            ``num_strucs`` structures.
    """
    total_number_of_train_datapoints = len(trainloader.dataset)
    starttime = time.time()
    gradient_batch_size = 100
    criterion = nn.CrossEntropyLoss()
    dataloader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=gradient_batch_size,
                                             shuffle=False, num_workers=4)
    # The parameter objects do not change between batches: look them up once.
    net_parameters = dict(model.named_parameters())
    # Only the diagonal is ever non-zero, so accumulate a vector and expand
    # it to a matrix once at the end.
    diagonal = np.zeros(num_strucs)
    for i, t_data in enumerate(dataloader, 0):
        if i % 100 == 0 and i > 0:
            print('Calculated up to datapoint', i*gradient_batch_size)
            print('Time needed: ', time.time() - starttime)
        # Check the sample budget before spending a forward/backward pass.
        if i > 0 and i * gradient_batch_size >= samples_N:
            break

        data_in, label_in = t_data[0].to(device), t_data[1].to(device)
        model.zero_grad()
        criterion(model(data_in), label_in).backward()

        per_layer = []
        for l_name in layer_names:
            weights_l = net_parameters[l_name]
            num_of_struct = int(weights_l.shape[0])
            # Row-wise <gradient, weight>, one value per structure.
            products = (weights_l.grad * weights_l.detach()).reshape(num_of_struct, -1)
            per_layer.append(products.sum(dim=1).cpu().numpy())
        g_k = np.concatenate(per_layer) if per_layer else np.zeros(0)
        if g_k.shape[0] != num_strucs:
            raise ValueError('layer_names provide %d structures but num_strucs is %d'
                             % (g_k.shape[0], num_strucs))

        # The loss is averaged over the batch, so scale by the real batch
        # length (the last batch may hold fewer than gradient_batch_size).
        diagonal += g_k * int(data_in.shape[0])

    first_order = np.diag(diagonal) * float(1 / total_number_of_train_datapoints)

    print('Time needed to calculate the first order: ', time.time() - starttime)
    return first_order


def calc_second_order(model, trainloader, num_strucs, layer_names, num_layers, device, num_outputs, samples_N=3000):
    """Accumulate a per-structure second-order matrix from output gradients.

    The dataset behind ``trainloader`` is iterated one sample at a time.  For
    each of the first ``samples_N`` samples and each network output ``o`` two
    vectors with one entry per structure (index along the first dimension of
    every selected parameter) are built:

    * ``v[k]``  = <d output_o / d W_k, W_k>
    * ``sv[k]`` = <d softmax(output)_o / d W_k, W_k>

    and the outer product ``v sv^T`` is added to the result.  The sum is
    finally multiplied by ``0.5 / len(dataset)``.

    Two deep copies of ``model`` are used so that the gradients of the raw
    output and of the softmax output land in separate ``.grad`` buffers;
    ``model`` itself is left untouched.

    Args:
        model: Network to analyse.
        trainloader: Object with a ``dataset`` attribute; only the dataset is
            used.
        num_strucs: Total number of structures over the selected layers.
        layer_names: Parameter names; the first ``num_layers`` are used.
        num_layers: Number of entries of ``layer_names`` to evaluate.
        device: Device the inputs and the accumulator are placed on.
        num_outputs: Unused; kept for interface compatibility.
        samples_N: Number of samples to process; must be smaller than the
            dataset length.

    Returns:
        ``(num_strucs, num_strucs)`` numpy array.

    Raises:
        Exception: If ``samples_N`` is not smaller than the dataset length.
    """
    total_number_of_train_datapoints = len(trainloader.dataset)
    if samples_N >= total_number_of_train_datapoints:
        raise Exception('Not a valid value for n! n must be smaller than length of train data set')

    net_softmax = copy.deepcopy(model)
    net_plain = copy.deepcopy(model)

    # Resolve the tracked parameters once; they are the same objects for the
    # whole run, so there is no need to rebuild the dicts inside the loops.
    plain_lookup = dict(net_plain.named_parameters())
    softmax_lookup = dict(net_softmax.named_parameters())
    selected_names = [layer_names[l_idx] for l_idx in range(num_layers)]
    tracked = [(plain_lookup[name], softmax_lookup[name]) for name in selected_names]

    hess_structures = torch.zeros([num_strucs, num_strucs]).to(device)
    starttime = time.time()

    # define dataloader with batchsize=1 to enable autograd
    dataloader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=1, shuffle=False, num_workers=4)
    for i, t_data in enumerate(dataloader, 0):
        # print time after 100 datapoints
        if i % 100 == 0 and i > 0:
            print('Calculated up to datapoint', i)
            print('Time: ', time.time() - starttime)
        # end loop after n datapoints
        if i % samples_N == 0 and i > 0:
            break

        t_inputs = t_data[0].to(device)

        sy_i = F.softmax(net_softmax(t_inputs), dim=1).view(-1)
        y_i = net_plain(t_inputs).view(-1)

        # loop over outputs
        for idx_o in range(len(y_i)):
            # gradients accumulate, so clear them before every output
            net_plain.zero_grad()
            net_softmax.zero_grad()

            # the graphs are reused for the remaining outputs of this sample
            y_i[idx_o].backward(retain_graph=True)
            sy_i[idx_o].backward(retain_graph=True)

            v_parts = []
            sv_parts = []
            for plain_param, softmax_param in tracked:
                rows = plain_param.shape[0]
                weights = plain_param.detach().reshape(rows, -1)
                # row-wise dot products, computed without leaving the device
                v_parts.append((plain_param.grad.reshape(rows, -1) * weights).sum(dim=1))
                sv_parts.append((softmax_param.grad.reshape(rows, -1) * weights).sum(dim=1))

            v_ki = torch.cat(v_parts).view(-1, 1).to(device)
            sv_ki = torch.cat(sv_parts).view(1, -1).to(device)
            hess_structures += torch.matmul(v_ki, sv_ki)

    hess_structures = hess_structures.data.cpu().numpy()
    print('Time needed to calculate the second order: ', time.time() - starttime)
    hess_structures = 0.5*hess_structures * float(1 / total_number_of_train_datapoints)
    return hess_structures



def underestimated_calc_parameters_pruned_resnet(net_in, pruned_parameters):
    """Estimate the fraction of parameters that remain after structured pruning.

    For every ``(module, name)`` entry of ``pruned_parameters`` the module's
    ``weight`` is inspected.  A structure (slice along the first dimension)
    counts as active when at least one of its weights is non-zero.  Each
    structure is charged its full size (product of the last three weight
    dimensions) plus two extra parameters.  Reductions of a layer's input
    width caused by pruning the preceding layer are NOT taken into account.
    The second-to-last parameter of ``net_in`` is counted, unpruned, as a
    2-D weight with one bias value per row.

    Args:
        net_in: Network; only ``list(net_in.parameters())[-2]`` is read.
        pruned_parameters: Iterable of sequences whose first item is a module
            with a ``weight`` tensor of at least three dimensions.

    Returns:
        Tuple ``(fraction, struc_array, pruned_struc_array)``: the ratio of
        remaining to original parameters, the original weight shapes and the
        weight shapes with the first dimension replaced by the number of
        active structures (both as numpy arrays, one row per layer).
    """
    struc_array = []
    pruned_struc_array = []
    total_num_params = 0
    total_pruned_num_params = 0
    for layer in pruned_parameters:
        weight = layer[0].weight
        layer_shape = np.array(list(weight.shape))
        num_structs = int(weight.shape[0])
        struc_size = int(weight.shape[-1] * weight.shape[-2] * weight.shape[-3])

        # A structure is active if ANY weight is non-zero.  Testing the sum
        # against zero would misclassify structures whose weights cancel out.
        active_strucs = int((weight.detach() != 0).flatten(1).any(dim=1).sum())

        pruned_shape = layer_shape.copy()
        pruned_shape[0] = active_strucs
        struc_array.append(layer_shape)
        pruned_struc_array.append(pruned_shape)

        # weights of the structures plus two extra parameters per structure
        total_num_params += num_structs * struc_size + 2 * num_structs
        total_pruned_num_params += active_strucs * struc_size + 2 * active_strucs

    # not sure how to automate this
    shape_last_layer = list(net_in.parameters())[-2].shape
    last_layer_params = shape_last_layer[0] * shape_last_layer[1] + shape_last_layer[0]

    total_pruned_num_params += last_layer_params
    total_num_params += last_layer_params
    fraction = total_pruned_num_params / float(total_num_params)
    print('Fraction: ' + str(fraction))
    return fraction, np.array(struc_array), np.array(pruned_struc_array)


def vgg19_param_count(net_in, pruned_parameters):
    """Count remaining parameters and MACs of a pruned VGG-style network.

    The layers in ``pruned_parameters`` are walked in order alongside a fixed
    layout list in which ``'M'`` marks a halving of the feature-map side
    length (starting at 32).  A structure (slice along the first weight
    dimension) counts as active when at least one of its weights is non-zero.
    Each layer contributes

    * ``active * previous_active * kernel_h * kernel_w`` weights, where
      ``previous_active`` is the active count of the preceding layer (3 for
      the first layer),
    * two extra parameters per active structure, and
    * the weight count times the squared feature-map side length as MACs.

    The second-to-last parameter of ``net_in`` is counted, unpruned, as a
    2-D weight with one bias value per row.

    Args:
        net_in: Network; only ``list(net_in.parameters())[-2]`` is read.
        pruned_parameters: Iterable of sequences whose first item is a module
            with a ``weight`` tensor of at least two dimensions.

    Returns:
        Tuple ``(remaining_parameters, remaining_macs)``.
    """
    cfg_vgg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
               512, 512, 512, 512, 'M']
    previous_width_pruned = 3
    total_pruned_num_params = 0
    total_pruned_num_macs = 0
    picture_size = 32
    pic_counter = 0
    for layer in pruned_parameters:
        # an 'M' entry in front of this layer halves the feature-map size
        if cfg_vgg[pic_counter] == 'M':
            picture_size = picture_size // 2
            pic_counter += 1
        pic_counter += 1

        weight = layer[0].weight
        filter_size = int(weight.shape[-1] * weight.shape[-2])

        # A structure is active if ANY weight is non-zero.  Testing the sum
        # against zero would misclassify structures whose weights cancel out.
        active_strucs = int((weight.detach() != 0).flatten(1).any(dim=1).sum())

        conv_weights = active_strucs * previous_width_pruned * filter_size
        total_pruned_num_params += conv_weights + 2 * active_strucs
        total_pruned_num_macs += conv_weights * picture_size * picture_size

        previous_width_pruned = active_strucs

    # not sure how to automate this
    shape_last_layer = list(net_in.parameters())[-2].shape

    total_pruned_num_params += shape_last_layer[0] * shape_last_layer[1] + shape_last_layer[0]
    total_pruned_num_macs += shape_last_layer[0] * shape_last_layer[1]

    return total_pruned_num_params, total_pruned_num_macs


def xor_oh(channels1, channels2):  # both one-hot vectors must have the same length
    """Combine two 0/1 channel masks into the mask of channels set in either.

    Despite its name this is a logical OR, not an exclusive OR: an entry of
    the result is 1 when the corresponding entry of ``channels1`` or of
    ``channels2`` equals 1, and 0 otherwise.

    Args:
        channels1: 1-D sequence of channel flags.
        channels2: 1-D sequence of channel flags with the same length.

    Returns:
        ``np.int64`` array of the same length holding the combined mask.

    Raises:
        ValueError: If the two masks do not have the same shape.
    """
    first = np.asarray(channels1)
    second = np.asarray(channels2)
    # Guard explicitly: numpy broadcasting would otherwise silently accept
    # some mismatching lengths (e.g. a length-1 mask).
    if first.shape != second.shape:
        raise ValueError('channel masks must have the same length, got %s and %s'
                         % (first.shape, second.shape))
    return ((first == 1) | (second == 1)).astype(np.int64)


def resnet56_param_count(layer_names, pruned_channels_oh, pruned_strucs, unpruned_strucs, ratio_num=0,
                         num_classes=10):
    """Count remaining parameters and MACs of a pruned basic-block ResNet.

    The layers are walked in the order of ``layer_names`` while a list of
    "pictures" (0/1 masks of the channels present in each intermediate
    feature map) is maintained.  Layer kinds are recognised by slicing the
    name: ``name[9:13] == 'down'`` marks a pruned downsampling layer and
    ``name[9:14]`` equal to ``'conv1'`` / ``'conv2'`` marks the two
    convolutions of a block (this matches names such as
    ``'layer1.0.conv1.weight'``); every other name is treated like a
    ``conv1`` layer.  Each convolution is charged
    ``active_in * active_out * kernel_h * kernel_w`` weights plus two
    parameters per active output channel.

    Args:
        layer_names: Sequence of layer names in network order.
        pruned_channels_oh: Array indexed ``[ratio_num, layer]`` yielding the
            0/1 mask of output channels kept in that layer.
        pruned_strucs: Array indexed ``[ratio_num, layer, dim]``; the last two
            entries are used as kernel height and width.
        unpruned_strucs: Array indexed ``[ratio_num, layer, dim]``; entries 0
            and 1 are used as original output and input channel counts.
        ratio_num: Index of the pruning ratio to evaluate.
        num_classes: Number of outputs of the final linear layer.

    Returns:
        Tuple ``(total_param, mac_count)``.
    """
    # One slot per intermediate picture.  The builtin ``object`` is used
    # because the ``np.object`` alias was removed in NumPy 1.24.
    pic = np.zeros((56,), dtype=object)

    pic_nr = 0  # index of the current picture; slot 0 is the network input
    pic[pic_nr] = np.ones(unpruned_strucs[ratio_num, 0, 1], dtype=np.int64)  # all input channels active

    pic_wh = 32  # width and height of picture (input size = output size for conv layers)
    mac_count = 0
    total_param = 0
    last_layer = len(layer_names) - 1

    for lay, name in enumerate(layer_names):
        out_mask = pruned_channels_oh[ratio_num, lay]
        active_out = np.sum(out_mask)
        kernel_area = pruned_strucs[ratio_num, lay, -2] * pruned_strucs[ratio_num, lay, -1]

        if name[9:13] == 'down':  # pruned downsampling layer on the skip path
            # its input is the block input, i.e. the picture before the
            # output of the block's conv1
            conv_weights = np.sum(pic[pic_nr - 1]) * active_out * kernel_area
            total_param += conv_weights + 2 * active_out  # conv + batchnorm parameters
            mac_count += pic_wh * pic_wh * conv_weights

            # block output: channels produced by the skip path or by conv2
            pic_nr += 1
            pic[pic_nr] = xor_oh(out_mask, pruned_channels_oh[ratio_num, lay - 1])
            continue

        # regular convolution (first conv layer, any conv1 layer, or any conv2 layer)
        conv_weights = np.sum(pic[pic_nr]) * active_out * kernel_area
        total_param += conv_weights + 2 * active_out  # conv + batchnorm parameters

        # a conv1 that changes the channel count also halves the picture size
        if name[9:14] == 'conv1' and unpruned_strucs[ratio_num, lay, 0] != unpruned_strucs[ratio_num, lay, 1]:
            pic_wh = pic_wh // 2
        mac_count += pic_wh * pic_wh * conv_weights

        if name[9:14] != 'conv2':  # first conv layer, or any conv1 layer
            pic_nr += 1
            pic[pic_nr] = out_mask
            continue

        if lay < last_layer and layer_names[lay + 1][9:13] == 'down':
            # the pruned downsampling layer that follows closes this block
            continue

        block_out = unpruned_strucs[ratio_num, lay, 0]
        if block_out == unpruned_strucs[ratio_num, lay - 1, 1]:  # skip connection is identity
            pic_nr += 1
            pic[pic_nr] = xor_oh(out_mask, pic[pic_nr - 2])
        else:  # skip connection is a downsampling layer that has not been subjected to pruning
            shortcut_weights = np.sum(pic[pic_nr - 1]) * block_out  # 1x1 kernel size
            total_param += shortcut_weights + 2 * block_out  # conv + batchnorm parameters
            mac_count += pic_wh * pic_wh * shortcut_weights

            pic_nr += 1
            pic[pic_nr] = np.ones(block_out, dtype=np.int64)

    # final linear layer: one weight per remaining channel plus a bias, per class
    total_param += (1 + np.sum(pic[pic_nr])) * num_classes
    mac_count += np.sum(pic[pic_nr]) * num_classes

    return total_param, mac_count


def resnet32_param_count(layer_names, pruned_channels_oh, pruned_strucs, unpruned_strucs, ratio_num=0,
                         num_classes=10):
    """Count remaining parameters and MACs of a pruned basic-block ResNet.

    The layers are walked in the order of ``layer_names`` while a list of
    "pictures" (0/1 masks of the channels present in each intermediate
    feature map) is maintained.  Layer kinds are recognised by slicing the
    name: ``name[9:13] == 'down'`` marks a pruned downsampling layer and
    ``name[9:14]`` equal to ``'conv1'`` / ``'conv2'`` marks the two
    convolutions of a block (this matches names such as
    ``'layer1.0.conv1.weight'``); every other name is treated like a
    ``conv1`` layer.  Each convolution is charged
    ``active_in * active_out * kernel_h * kernel_w`` weights plus two
    parameters per active output channel.

    Args:
        layer_names: Sequence of layer names in network order.
        pruned_channels_oh: Array indexed ``[ratio_num, layer]`` yielding the
            0/1 mask of output channels kept in that layer.
        pruned_strucs: Array indexed ``[ratio_num, layer, dim]``; the last two
            entries are used as kernel height and width.
        unpruned_strucs: Array indexed ``[ratio_num, layer, dim]``; entries 0
            and 1 are used as original output and input channel counts.
        ratio_num: Index of the pruning ratio to evaluate.
        num_classes: Number of outputs of the final linear layer.

    Returns:
        Tuple ``(total_param, mac_count)``.
    """
    # One slot per intermediate picture (32 in total).  The builtin ``object``
    # is used because the ``np.object`` alias was removed in NumPy 1.24.
    pic = np.zeros((32,), dtype=object)
    pic_nr = 0  # index of the current picture; slot 0 is the network input
    pic[pic_nr] = np.ones(unpruned_strucs[ratio_num, 0, 1], dtype=np.int64)  # all input channels active

    pic_wh = 32  # width and height of picture (input size = output size for conv layers)
    mac_count = 0
    total_param = 0
    last_layer = len(layer_names) - 1

    for lay, name in enumerate(layer_names):
        out_mask = pruned_channels_oh[ratio_num, lay]
        active_out = np.sum(out_mask)
        kernel_area = pruned_strucs[ratio_num, lay, -2] * pruned_strucs[ratio_num, lay, -1]

        if name[9:13] == 'down':  # pruned downsampling layer on the skip path
            # its input is the block input, i.e. the picture before the
            # output of the block's conv1
            conv_weights = np.sum(pic[pic_nr - 1]) * active_out * kernel_area
            total_param += conv_weights + 2 * active_out  # conv + batchnorm parameters
            mac_count += pic_wh * pic_wh * conv_weights

            # block output: channels produced by the skip path or by conv2
            pic_nr += 1
            pic[pic_nr] = xor_oh(out_mask, pruned_channels_oh[ratio_num, lay - 1])
            continue

        # regular convolution (first conv layer, any conv1 layer, or any conv2 layer)
        conv_weights = np.sum(pic[pic_nr]) * active_out * kernel_area
        total_param += conv_weights + 2 * active_out  # conv + batchnorm parameters

        # a conv1 that changes the channel count also halves the picture size
        if name[9:14] == 'conv1' and unpruned_strucs[ratio_num, lay, 0] != unpruned_strucs[ratio_num, lay, 1]:
            pic_wh = pic_wh // 2
        mac_count += pic_wh * pic_wh * conv_weights

        if name[9:14] != 'conv2':  # first conv layer, or any conv1 layer
            pic_nr += 1
            pic[pic_nr] = out_mask
            continue

        if lay < last_layer and layer_names[lay + 1][9:13] == 'down':
            # the pruned downsampling layer that follows closes this block
            continue

        block_out = unpruned_strucs[ratio_num, lay, 0]
        if block_out == unpruned_strucs[ratio_num, lay - 1, 1]:  # skip connection is identity
            pic_nr += 1
            pic[pic_nr] = xor_oh(out_mask, pic[pic_nr - 2])
        else:  # skip connection is a downsampling layer that has not been subjected to pruning
            shortcut_weights = np.sum(pic[pic_nr - 1]) * block_out  # 1x1 kernel size
            total_param += shortcut_weights + 2 * block_out  # conv + batchnorm parameters
            mac_count += pic_wh * pic_wh * shortcut_weights

            pic_nr += 1
            pic[pic_nr] = np.ones(block_out, dtype=np.int64)

    # final linear layer: one weight per remaining channel plus a bias, per class
    total_param += (1 + np.sum(pic[pic_nr])) * num_classes
    mac_count += np.sum(pic[pic_nr]) * num_classes
    return total_param, mac_count












def densenet40real_param_count(layer_names, pruned_channels_oh, pruned_strucs, unpruned_strucs, algo_num=0, ratio_num=0,
                         num_classes=10):
    """Count parameters and MACs of a pruned densely connected network.

    Only ``pruned_strucs[ratio_num]`` is read: a 2-D array with one row per
    convolution whose entry 0 is the number of output channels and whose last
    two entries are the kernel dimensions.  Entry 1 of the first row is the
    number of input channels of the network.

    Every convolution is charged ``in * out * kernel_h * kernel_w`` weights
    plus two parameters per output channel.  After the first layer and after
    every layer whose last kernel dimension is 1 (treated as a transition
    layer) the input width of the next layer is that layer's output width;
    otherwise the output channels are appended to the running input width.
    A transition layer additionally halves the feature-map side length,
    which starts at 32.

    ``layer_names``, ``pruned_channels_oh``, ``unpruned_strucs`` and
    ``algo_num`` are accepted but not used.

    Returns:
        Tuple ``(param_count, mac_count)``.
    """
    layers = pruned_strucs[ratio_num]
    n_layers = layers.shape[0]

    in_width = layers[0, 1]
    side = 32  # feature-map width and height
    params = 0
    macs = 0

    for idx in range(n_layers):
        out_width = layers[idx, 0]
        is_transition = layers[idx, -1] == 1

        conv_weights = in_width * out_width * layers[idx, -1] * layers[idx, -2]
        params += conv_weights  # convolution
        params += 2 * out_width  # batchnorm
        macs += side * side * conv_weights

        if is_transition or idx == 0:
            # the next layer sees only this layer's output
            in_width = out_width
        else:
            # dense connectivity: the output is appended to the input
            in_width = in_width + out_width

        if is_transition:
            side = side // 2  # average pooling with factor 2 after a transition
        if idx == n_layers - 1:
            side = side // 8  # global average pooling after the last convolution

    # final linear layer: weights plus one bias per class
    params += (in_width + 1) * num_classes
    macs += in_width * num_classes

    return params, macs


def resnet50_param_count(layer_names, pruned_channels_oh, pruned_strucs, unpruned_strucs, algo_num=0, ratio_num=0,
                         num_classes=10):
    """Count remaining parameters and MACs of a pruned bottleneck ResNet.

    The layers are walked in the order of ``layer_names`` while a list of
    "pictures" (0/1 masks of the channels present in each intermediate
    feature map) is maintained.  Layer kinds are recognised by slicing the
    name: ``name[9:13] == 'down'`` marks a pruned downsampling layer,
    ``name[9:14] == 'conv3'`` the last convolution of a block and
    ``name[0:5] == 'conv1'`` the stem convolution; every other name is
    treated as an inner convolution.  Each convolution is charged
    ``active_in * active_out * kernel_h * kernel_w`` weights plus two
    parameters per active output channel.  The picture side length starts at
    112 (224 halved for the stride of the first layer).

    A second, simplified MAC estimate (``alternative_mac``) is computed from
    the structure counts alone and only printed.

    Args:
        layer_names: Sequence of layer names in network order.
        pruned_channels_oh: Array indexed ``[ratio_num, layer]`` yielding the
            0/1 mask of output channels kept in that layer.
        pruned_strucs: Array indexed ``[ratio_num, layer, dim]``; entry 0 and
            the last two entries (kernel height and width) are used.
        unpruned_strucs: Array indexed ``[ratio_num, layer, dim]``; entries 0
            and 1 are used as original output and input channel counts.
        algo_num: Unused; kept for interface compatibility.
        ratio_num: Index of the pruning ratio to evaluate.
        num_classes: Number of outputs of the final linear layer.

    Returns:
        Tuple ``(total_param, mac_count)``.
    """
    # One slot per intermediate picture.  The builtin ``object`` is used
    # because the ``np.object`` alias was removed in NumPy 1.24.
    pic = np.zeros((50,), dtype=object)

    pic_nr = 0  # index of the current picture; slot 0 is the network input
    pic[pic_nr] = np.ones(unpruned_strucs[ratio_num, 0, 1], dtype=np.int64)  # all input channels active

    # half of the original picture size, since the first layer already has stride=2
    pic_wh = 224 // 2
    mac_count = 0

    alternative_mac = 0  # estimate that ignores which input channels are present
    alternative_mac_last_filters = 3

    total_param = 0
    last_layer = len(layer_names) - 1

    for lay, name in enumerate(layer_names):
        out_mask = pruned_channels_oh[ratio_num, lay]
        active_out = np.sum(out_mask)
        kernel_area = pruned_strucs[ratio_num, lay, -2] * pruned_strucs[ratio_num, lay, -1]

        if name[9:13] == 'down':  # pruned downsampling layer on the skip path
            conv_weights = np.sum(pic[pic_nr - 1]) * active_out * kernel_area
            total_param += conv_weights + 2 * active_out  # conv + batchnorm parameters
            mac_count += pic_wh * pic_wh * conv_weights

            # block output: channels produced by the skip path or by the previous layer
            pic_nr += 1
            pic[pic_nr] = xor_oh(out_mask, pruned_channels_oh[ratio_num, lay - 1])
            continue

        # regular convolution (first conv layer, any conv1/conv2/conv3 layer)
        conv_weights = np.sum(pic[pic_nr]) * active_out * kernel_area
        total_param += conv_weights + 2 * active_out  # conv + batchnorm parameters

        # a conv1 that halves the channel count also halves the picture size
        if name[9:14] == 'conv1' and (unpruned_strucs[ratio_num, lay, 0] == (
                unpruned_strucs[ratio_num, lay, 1] // 2)):
            pic_wh = pic_wh // 2
        mac_count += pic_wh * pic_wh * conv_weights

        alternative_mac += pic_wh * pic_wh * pruned_strucs[ratio_num, lay, 0] * \
                           alternative_mac_last_filters * kernel_area
        alternative_mac_last_filters = pruned_strucs[ratio_num, lay, 0]

        if name[0:5] == 'conv1':  # MaxPool with stride = 2 after the first conv layer
            pic_wh = pic_wh // 2

        if name[9:14] != 'conv3':  # first conv layer, or any conv1 or conv2 layer
            pic_nr += 1
            pic[pic_nr] = out_mask
            continue

        if lay < last_layer and layer_names[lay + 1][9:13] == 'down':
            # the pruned downsampling layer that follows closes this block
            continue

        block_out = unpruned_strucs[ratio_num, lay, 0]
        if block_out == unpruned_strucs[ratio_num, lay - 2, 1]:  # skip connection is identity
            pic_nr += 1
            pic[pic_nr] = xor_oh(out_mask, pic[pic_nr - 3])
        else:  # skip connection is a downsampling layer that has not been subjected to pruning
            shortcut_weights = np.sum(pic[pic_nr - 2]) * block_out  # 1x1 kernel size
            total_param += shortcut_weights + 2 * block_out  # conv + batchnorm parameters
            mac_count += pic_wh * pic_wh * shortcut_weights

            alternative_mac += pic_wh * pic_wh * pruned_strucs[ratio_num, lay - 3, 0] * block_out

            pic_nr += 1
            pic[pic_nr] = np.ones(block_out, dtype=np.int64)

    # final linear layer: one weight per remaining channel plus a bias, per class
    total_param += (1 + np.sum(pic[pic_nr])) * num_classes
    mac_count += np.sum(pic[pic_nr]) * num_classes

    alternative_mac += (alternative_mac_last_filters + 1) * num_classes
    print('alternative_mac=', alternative_mac)

    print('total_param end =', total_param)
    print('MAC count end =', mac_count)

    return total_param, mac_count


def resnet18_param_count(layer_names, pruned_channels_oh, pruned_strucs, unpruned_strucs, algo_num=0, ratio_num=0,
                         num_classes=10):
    """Count remaining parameters and MACs of a pruned basic-block ResNet.

    The layers are walked in the order of ``layer_names`` while a list of
    "pictures" (0/1 masks of the channels present in each intermediate
    feature map) is maintained.  Layer kinds are recognised by slicing the
    name: ``name[9:13] == 'down'`` marks a pruned downsampling layer,
    ``name[9:14]`` equal to ``'conv1'`` / ``'conv2'`` the two convolutions of
    a block and ``name[0:5] == 'conv1'`` the stem convolution; every other
    name is treated like a ``conv1`` layer.  Each convolution is charged
    ``active_in * active_out * kernel_h * kernel_w`` weights plus two
    parameters per active output channel.  The picture side length starts at
    112 (224 halved for the stride of the first layer).

    A second, simplified MAC estimate (``alternative_mac``) is computed from
    the structure counts alone and only printed.

    Args:
        layer_names: Sequence of layer names in network order.
        pruned_channels_oh: Array indexed ``[ratio_num, layer]`` yielding the
            0/1 mask of output channels kept in that layer.
        pruned_strucs: Array indexed ``[ratio_num, layer, dim]``; entry 0 and
            the last two entries (kernel height and width) are used.
        unpruned_strucs: Array indexed ``[ratio_num, layer, dim]``; entries 0
            and 1 are used as original output and input channel counts.
        algo_num: Unused; kept for interface compatibility.
        ratio_num: Index of the pruning ratio to evaluate.
        num_classes: Number of outputs of the final linear layer.

    Returns:
        Tuple ``(total_param, mac_count)``.
    """
    # One slot per intermediate picture (18 in total).  The builtin ``object``
    # is used because the ``np.object`` alias was removed in NumPy 1.24.
    pic = np.zeros((18,), dtype=object)

    pic_nr = 0  # index of the current picture; slot 0 is the network input
    pic[pic_nr] = np.ones(unpruned_strucs[ratio_num, 0, 1], dtype=np.int64)  # all input channels active

    # half of the original picture size, since the first layer already has stride=2
    pic_wh = 224 // 2
    mac_count = 0

    alternative_mac = 0  # estimate that ignores which input channels are present
    alternative_mac_last_filters = 3

    total_param = 0
    last_layer = len(layer_names) - 1

    for lay, name in enumerate(layer_names):
        out_mask = pruned_channels_oh[ratio_num, lay]
        active_out = np.sum(out_mask)
        kernel_area = pruned_strucs[ratio_num, lay, -2] * pruned_strucs[ratio_num, lay, -1]

        if name[9:13] == 'down':  # pruned downsampling layer on the skip path
            conv_weights = np.sum(pic[pic_nr - 1]) * active_out * kernel_area
            total_param += conv_weights + 2 * active_out  # conv + batchnorm parameters
            mac_count += pic_wh * pic_wh * conv_weights

            # block output: channels produced by the skip path or by conv2
            pic_nr += 1
            pic[pic_nr] = xor_oh(out_mask, pruned_channels_oh[ratio_num, lay - 1])
            continue

        # regular convolution (first conv layer, any conv1 layer, or any conv2 layer)
        conv_weights = np.sum(pic[pic_nr]) * active_out * kernel_area
        total_param += conv_weights + 2 * active_out  # conv + batchnorm parameters

        # a conv1 that changes the channel count also halves the picture size
        if name[9:14] == 'conv1' and unpruned_strucs[ratio_num, lay, 0] != unpruned_strucs[ratio_num, lay, 1]:
            pic_wh = pic_wh // 2
        mac_count += pic_wh * pic_wh * conv_weights

        alternative_mac += pic_wh * pic_wh * pruned_strucs[ratio_num, lay, 0] * \
                           alternative_mac_last_filters * kernel_area
        alternative_mac_last_filters = pruned_strucs[ratio_num, lay, 0]

        if name[0:5] == 'conv1':  # MaxPool with stride = 2 after the first conv layer
            pic_wh = pic_wh // 2

        if name[9:14] != 'conv2':  # first conv layer, or any conv1 layer
            pic_nr += 1
            pic[pic_nr] = out_mask
            continue

        if lay < last_layer and layer_names[lay + 1][9:13] == 'down':
            # the pruned downsampling layer that follows closes this block
            continue

        block_out = unpruned_strucs[ratio_num, lay, 0]
        if block_out == unpruned_strucs[ratio_num, lay - 1, 1]:  # skip connection is identity
            pic_nr += 1
            pic[pic_nr] = xor_oh(out_mask, pic[pic_nr - 2])
        else:  # skip connection is a downsampling layer that has not been subjected to pruning
            shortcut_weights = np.sum(pic[pic_nr - 1]) * block_out  # 1x1 kernel size
            total_param += shortcut_weights + 2 * block_out  # conv + batchnorm parameters
            mac_count += pic_wh * pic_wh * shortcut_weights

            alternative_mac += pic_wh * pic_wh * pruned_strucs[ratio_num, lay - 2, 0] * block_out

            pic_nr += 1
            pic[pic_nr] = np.ones(block_out, dtype=np.int64)

    # final linear layer: one weight per remaining channel plus a bias, per class
    total_param += (1 + np.sum(pic[pic_nr])) * num_classes
    mac_count += np.sum(pic[pic_nr]) * num_classes

    alternative_mac += (alternative_mac_last_filters + 1) * num_classes
    print('alternative_mac=', alternative_mac)

    return total_param, mac_count