import torch
import torch.nn as nn
import math
import re
import matplotlib.pyplot as plt
import torchvision.models as models
import torchvision.transforms.functional as ttf
import os

def ShellParser(ShellNetwork: nn.Module, general=0, force_stride = False, change_linear=0):
    """
    Summarise the layers of ``ShellNetwork`` in ``named_modules()`` order.

    :param ShellNetwork: the target ("shell") network to describe.
    :param general: selects what is returned
        0 -> (parameter_count, layers_name_list, layers_detail)
        1 -> (cnn_param_count, cnn_detail) for the Conv2d layers only
        2 -> mlp_detail = (out_features, in_features, previous_cnn_channels, has_bias)
             of the first nn.Linear, or (0, 0, 0, 0) when there is none
        any other value returns None.
    :param force_stride: record stride 1 for every Conv2d and MaxPool2d.
    :param change_linear: 0 keeps each nn.Linear's out_features; a positive int
        replaces the out_features recorded for *every* nn.Linear (e.g. to
        re-target a classifier head to a dataset's number of classes).

    parameter_count / cnn_param_count hold two ints per layer: the weight count
    and the bias count (0 when the layer has no bias).

    layers_detail entries (aligned with layers_name_list):
        'Conv2D'            -> (out_ch, in_ch, k_h, k_w, has_bias, stride, padding)
        'Linear' (first)    -> (out_features, in_features, previous_cnn_channels, has_bias)
        'Linear' (others)   -> (out_features, in_features, has_bias)
        'LeakyReLU'         -> [negative_slope]
        'MaxPool2D'         -> (kernel_size, [stride, padding])
        'BatchNorm2D'       -> [num_features]
        'Dropout'           -> [p]
        'AdaptiveAvgPool2D' -> [output_size]
        'ReLU', 'Save', 'Downsample', 'Conv-MLP-PlaceHolder' -> None
    'Save' marks modules named ``layerN.0`` and 'Downsample' marks modules whose
    name ends in ``.downsample`` or contains ``.shortcut`` (residual-block naming
    as in torchvision ResNets).
    """
    parameter_count = []
    layers_name_list = []
    layers_detail = []
    cnn_param_count = []
    cnn_detail = []
    mlp_detail = (0, 0, 0, 0)
    first_linear_seen = False
    previous_layer = ''
    previous_cnn_channels = 0
    pattern_layer = r'^layer(\d+)\.0$'
    pattern_ds = r'^.*\.downsample$'
    pattern_sc = r'^.*\.shortcut'

    for name, module in ShellNetwork.named_modules():
        if isinstance(module, nn.Conv2d):
            previous_layer = 'Conv2D'
            layers_name_list.append('Conv2D')
            has_bias = module.bias is not None
            stride = 1 if force_stride else module.stride[0]
            filter_info = (module.out_channels, module.in_channels,
                           module.kernel_size[0], module.kernel_size[1],
                           has_bias, stride, module.padding[0])
            filter_count = (module.in_channels * module.out_channels *
                            module.kernel_size[0] * module.kernel_size[1])
            bias_count = module.out_channels if has_bias else 0
            parameter_count.append(filter_count)
            parameter_count.append(bias_count)
            layers_detail.append(filter_info)
            cnn_param_count.append(filter_count)
            cnn_param_count.append(bias_count)
            cnn_detail.append(filter_info)
            # remembered so the first Linear knows how many channels were flattened
            previous_cnn_channels = module.out_channels
        elif isinstance(module, nn.Linear):
            if previous_layer == 'Conv2D':
                # Directly after a Conv2D layer we add a placeholder (flatten boundary).
                layers_name_list.append('Conv-MLP-PlaceHolder')
                layers_detail.append(None)
            has_bias = module.bias is not None
            if change_linear == 0:
                out_features = module.out_features
            else:
                # type first, so a non-int fails this assert instead of a TypeError on '>'
                assert type(change_linear) == int, "change_linear should be an integer"
                assert change_linear > 0, "change_linear should be greater than 0"
                print("For models like ResNet, we want to change the last linear layer to have out_features based on the dataset.")
                out_features = change_linear
            layers_name_list.append('Linear')
            # the weight matrix is (out_features, in_features)
            filter_count = out_features * module.in_features
            bias_count = out_features if has_bias else 0
            parameter_count.append(filter_count)
            parameter_count.append(bias_count)

            if not first_linear_seen:
                # first layer after flatten: also record the flattened channel count
                first_linear_seen = True
                mlp_detail = (out_features, module.in_features,
                              previous_cnn_channels, has_bias)
                layers_detail.append(mlp_detail)
            else:
                layers_detail.append((out_features, module.in_features, has_bias))
            previous_layer = 'Linear'
        elif isinstance(module, nn.ReLU):
            layers_name_list.append('ReLU')
            previous_layer = 'ReLU'
            layers_detail.append(None)
        elif isinstance(module, nn.LeakyReLU):
            layers_name_list.append('LeakyReLU')
            previous_layer = 'LeakyReLU'
            layers_detail.append([module.negative_slope])
        elif isinstance(module, nn.MaxPool2d):
            layers_name_list.append('MaxPool2D')
            previous_layer = 'MaxPool2D'
            stride = 1 if force_stride else module.stride
            layers_detail.append((module.kernel_size, [stride, module.padding]))
        elif isinstance(module, nn.BatchNorm2d):
            layers_name_list.append('BatchNorm2D')
            previous_layer = 'BatchNorm2D'
            layers_detail.append([module.num_features])
        elif isinstance(module, nn.Dropout):
            layers_name_list.append('Dropout')
            previous_layer = 'Dropout'
            layers_detail.append([module.p])
        elif isinstance(module, nn.AdaptiveAvgPool2d):
            layers_name_list.append('AdaptiveAvgPool2D')
            previous_layer = 'AdaptiveAvgPool2D'
            layers_detail.append([module.output_size])
        elif re.match(pattern_layer, name):
            # beginning of a block; in case of downsamples, we save the input.
            layers_name_list.append('Save')
            layers_detail.append(None)
        elif re.match(pattern_ds, name) or re.match(pattern_sc, name):
            # downsample / shortcut branch of a residual block.
            layers_name_list.append('Downsample')
            layers_detail.append(None)
        else:
            # containers and unsupported modules break the Conv2D -> Linear adjacency
            previous_layer = ''
    if general == 0:
        return parameter_count, layers_name_list, layers_detail
    if general == 1:
        # only CNN info
        return cnn_param_count, cnn_detail
    if general == 2:
        return mlp_detail


def ShellParser_Layers(model, remove_grad = False):
    """
    Collect the raw parameter tensors of every Conv2d and Linear layer.

    :param model: network to scan, in ``modules()`` traversal order.
    :param remove_grad: if True, first set ``requires_grad = False`` on every
        parameter of ``model`` (this mutates the model in place).
    :return: (conv_list, linear_list); each is a flat list alternating
        [weight, bias, weight, bias, ...], where a bias is None for bias-free layers.
    """
    if remove_grad:
        for tensor in model.parameters():
            tensor.requires_grad = False
    # insertion order matters: Conv2d is tested before Linear
    buckets = {nn.Conv2d: [], nn.Linear: []}
    for layer in model.modules():
        for layer_type, bucket in buckets.items():
            if isinstance(layer, layer_type):
                bucket.extend((layer.weight, layer.bias))
                break
    return buckets[nn.Conv2d], buckets[nn.Linear]

class param_matching_loss():
    """
    Parameter-matching objective for training a hypernetwork against a pretrained
    target network.

    The Conv2d / Linear parameters of ``pretrained_model`` (gathered with
    ShellParser_Layers) are the targets. ``forward`` generates parameters with
    equi_loader_general, averages each generated tensor over dim 0 (the batch),
    and sums one ``loss_fn(generated, target)`` term per pretrained tensor that
    is not None.

    :param pretrained_model: network whose parameters are to be matched.
    :param layers_info: ShellParser(general=0) layer details of the target network.
    :param permutations: orbit permutation list forwarded to equi_loader_general.
    :param head: forwarded to equi_loader_general as ``head``.
    :param loss_fn: callable ``(prediction, target) -> scalar tensor``; defaults to
        ``nn.MSELoss()``.
    """

    def __init__(self, pretrained_model, layers_info, permutations, head = False, loss_fn=None):
        super(param_matching_loss, self).__init__()
        self.cnn_para, self.linear_para = ShellParser_Layers(pretrained_model)
        self.length = len(self.cnn_para)
        self.head = head
        self.layers_info = layers_info
        self.permutation_list = permutations
        # A matching loss has to be minimised when generated == target. MSE is;
        # nn.CrossEntropyLoss with unconstrained float "probability" targets is
        # linear and unbounded below in those targets, so it is not used here.
        self.loss_fn = nn.MSELoss() if loss_fn is None else loss_fn

    def _match(self, targets, generated, message):
        """Return one loss term per non-None target against the batch-mean of its generated tensor."""
        terms = []
        for i, target in enumerate(targets):
            if target is None:
                continue
            prediction = torch.mean(generated[i], dim=0, keepdim=False)
            assert target.shape == prediction.shape, message
            # pretrained parameters are fixed targets: keep gradients out of them
            terms.append(self.loss_fn(prediction, target.detach()))
        return terms

    def forward(self, hypernetwork, x):
        cnn_gen, linear_gen = equi_loader_general(hypernetwork, self.layers_info, x, self.permutation_list,
                                                  hyper_goal_output(self.layers_info), head=self.head)
        loss = self._match(self.cnn_para, cnn_gen, "cnn_para shape not match with cnn_gen shape")
        loss += self._match(self.linear_para, linear_gen, "linear_para shape not match with linear_gen shape")
        return sum(loss)

    def __call__(self, hypernetwork, x):
        return self.forward(hypernetwork, x)


def hyper_goal_output(infos):
    """
    Compute the split sizes of the hypernetwork's flat output vector.

    Every Conv2d / Linear entry of a ShellParser layer-detail list contributes two
    sizes: the number of generated weight values, then the number of bias values
    (0 when the layer has no bias). None entries and entries of any other length
    (ReLU, BatchNorm, pooling, ...) contribute nothing.

    - 7-tuple (Conv2d, (out, in, k_h, k_w, bias, stride, pad)):
      out * in * ceil(k_h / 2) ** 2 weights (one kernel quadrant).
    - 4-tuple (first Linear, (out, in, channels, bias)): with
      positions = in / channels per channel, out * in / 4 weights when positions
      is divisible by 4, otherwise out * (in - channels) / 4 + out * channels.
      ``channels`` must be non-zero.
    - 3-tuple (other Linear, (out, in, bias)): out * in weights.

    :param infos: layer detail list from ShellParser (general=0).
    :return: list of ints, two per Conv2d / Linear entry.
    """
    sizes = []
    for info in infos:
        kind = None if info is None else len(info)
        if kind == 7:
            weights = info[0] * info[1] * math.ceil(info[2] / 2) ** 2
            has_bias = info[4]
        elif kind == 4:
            out_f, in_f, channels = info[0], info[1], info[2]
            positions = in_f / channels
            if positions % 4 == 0:  # even x even map
                weights = int(out_f * in_f / 4)
            else:  # odd x odd map: orbits of 4 plus one centre value per channel
                weights = int(out_f * (in_f - channels) / 4) + out_f * channels
            has_bias = info[3]
        elif kind == 3:
            weights = info[0] * info[1]
            has_bias = info[2]
        else:
            continue
        sizes.extend((weights, info[0] if has_bias else 0))
    return sizes


def plot(image, title=None, k=0, saving_path = None):
    """
    Display (or save) an image tensor with matplotlib.

    :param image: tensor of shape (H, W), (C, H, W) or (B, C, H, W); for 4-D
        input only the first item of the batch is shown. Tensors are detached
        and moved to the CPU first.
    :param title: optional figure title.
    :param k: if non-zero, the image is rotated with
        ``torch.rot90(image, -k, dims=[-2, -1])`` before plotting.
    :param saving_path: if given, save the figure there instead of showing it.
    :raises ValueError: if the image does not have 2, 3 or 4 dimensions.
    """
    # An explicit type check instead of a bare try/except around .detach() and
    # duck-typing on `.device`; tensors on any accelerator are moved to the host.
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu()
    if k != 0:
        image = torch.rot90(image, -1 * k, dims=[-2, -1])
    if len(image.shape) == 2:
        plt.imshow(image)
    elif len(image.shape) == 3:
        # (C, H, W) -> (H, W, C): imshow expects channels last
        plt.imshow(image.permute(1, 2, 0))
    elif len(image.shape) == 4:
        # (B, C, H, W): show only the first item of the batch
        plt.imshow(image[0].permute(1, 2, 0))
    else:
        raise ValueError("Invalid image shape")
    if title is not None:
        plt.title(title)
    plt.axis('off')
    if saving_path is not None:
        plt.savefig(saving_path)
    else:
        plt.show()
    plt.close()
def equi_loader_90(hypernetwork, layers_info, x, permutation_list, goal_list):
    """
    This function takes the hypernetwork output and returns the corresponding parameters for the target network.
    The returned parameters are two lists of parameters, one for the CNN layers and one for the linear layers.
    cnn = [info], info = (batch, out_channels, in_channels, kernel_size)
    mlp = [info], info = (batch, out_features, in_features)

    Z_4-only (90-degree) variant: the hypernetwork is run on ``x`` rotated by
    ``torch.rot90`` k = 0..3 times over dims 2 and 3, and each output is split by
    ``goal_list`` along dim 1. Entries of ``layers_info`` are dispatched on length:
        4    -> Conv2d (out, in, k, k); info[3] selects the 2x2 / even / odd case
        3    -> first Linear (out, in, previous_channels), assembled via permutation_list
        2    -> any other Linear (out, in), weight and bias averaged over the rotations
        None -> consumes no hypernetwork outputs
    Any other length appends nothing but still advances the output index by two.
    Biases are always read, i.e. every Conv2d / Linear is assumed to have one.

    NOTE(review): this layout differs from what ShellParser emits now (Conv2d
    entries are 7-tuples, the first Linear a 4-tuple, other Linears 3-tuples and
    MaxPool2D entries 2-tuples), so ShellParser output would be misread here;
    equi_loader_general handles the current format.
    """
    # goal_list = hyper_goal_output(layers_info)  -- we did it outside to save time.
    batch_size = x.shape[0]
    # one tuple of split chunks per rotation k = 0..3
    hypernetwork_output = [torch.split(hypernetwork(torch.rot90(x, i, dims=[2, 3])), goal_list, dim=1) for i in
                           range(4)]
    # each hypernetwork shape: [batch, corresponding length]. x shape = [batch, 1,28,28]
    # print("layer info: ", layers_info)  # [(16, 1, 2, 2), None, (8, 16, 2, 2), None, (100, 5408, 8), None, (10, 100)]
    cnn_para = []
    mlp_para = []
    # index of the current layer's weight chunk; its bias chunk is layer_counter + 1
    layer_counter = 0
    for info in layers_info:
        if info is None:
            layer_counter -= 2
            # when generating and splitting we ignored none part.
        elif len(info) == 4:
            # This is a CNN layer, 2 is easy case, then depends on even or odd. don't forget to load.
            if info[3] == 2:
                # 2x2 filter, using 1 value
                view_shape = (-1, info[0], info[1], 1)
                info_0 = hypernetwork_output[0][layer_counter].view(view_shape)
                info_1 = hypernetwork_output[1][layer_counter].view(view_shape)
                info_2 = hypernetwork_output[2][layer_counter].view(view_shape)
                info_3 = hypernetwork_output[3][layer_counter].view(view_shape)
                bias_0 = hypernetwork_output[0][layer_counter + 1]
                bias_1 = hypernetwork_output[1][layer_counter + 1]
                bias_2 = hypernetwork_output[2][layer_counter + 1]
                bias_3 = hypernetwork_output[3][layer_counter + 1]
                # after the view below the kernel reads [[r0, r1], [r3, r2]]
                filters = torch.cat([info_0, info_1, info_3, info_2], dim=3)
                cnn_para.append(filters.view(-1, info[0], info[1], info[2], info[3]))
                bias = (bias_0 + bias_1 + bias_3 + bias_2) / 4
                cnn_para.append(bias)
            elif info[3] % 2 == 0:
                # even x even filter (k >= 4): each rotation supplies one (k/2 x k/2)
                # quadrant, rotated back by -k quarter turns before tiling
                f_dim = int(info[3] / 2)
                view_shape = (-1, info[0], info[1], f_dim, f_dim)
                info_0 = hypernetwork_output[0][layer_counter].reshape(view_shape)
                info_1 = torch.rot90(hypernetwork_output[1][layer_counter].reshape(view_shape), -1, dims=[3, 4])
                info_2 = torch.rot90(hypernetwork_output[2][layer_counter].reshape(view_shape), -2, dims=[3, 4])
                info_3 = torch.rot90(hypernetwork_output[3][layer_counter].reshape(view_shape), -3, dims=[3, 4])
                bias_0 = hypernetwork_output[0][layer_counter + 1]
                bias_1 = hypernetwork_output[1][layer_counter + 1]
                bias_2 = hypernetwork_output[2][layer_counter + 1]
                bias_3 = hypernetwork_output[3][layer_counter + 1]

                # quadrant layout [[r0, r1], [r3, r2]]
                top_half = torch.cat([info_0, info_1], dim=4)
                bot_half = torch.cat([info_3, info_2], dim=4)
                full_filter = torch.cat([top_half, bot_half], dim=3)
                cnn_para.append(full_filter)

                # don't even need view.view(-1, info[0], info[1], info[2], info[3].
                cnn_para.append((bias_0 + bias_1 + bias_3 + bias_2) / 4)
            #     # even x even filter using even**2/4 values
            elif info[3] % 2 == 1:
                # odd x odd filter: quadrants of size ceil(k/2) overlap on the middle
                # row/column, where neighbouring quadrants are averaged
                dim = math.ceil(info[3] / 2)
                info_0 = hypernetwork_output[0][layer_counter].view(-1, info[0], info[1], dim, dim)
                info_1 = torch.rot90(hypernetwork_output[1][layer_counter].view(-1, info[0], info[1], dim, dim), -1,
                                     dims=[3, 4])
                info_2 = torch.rot90(hypernetwork_output[2][layer_counter].view(-1, info[0], info[1], dim, dim), -2,
                                     dims=[3, 4])
                info_3 = torch.rot90(hypernetwork_output[3][layer_counter].view(-1, info[0], info[1], dim, dim), -3,
                                     dims=[3, 4])
                top_half = torch.cat(
                    [info_0[:, :, :, :, :-1], (info_0[:, :, :, :, [-1]] + info_1[:, :, :, :, [0]]) / 2,
                     info_1[:, :, :, :, 1:]], dim=4)
                bot_half = torch.cat(
                    [info_3[:, :, :, :, :-1], (info_3[:, :, :, :, [-1]] + info_2[:, :, :, :, [0]]) / 2,
                     info_2[:, :, :, :, 1:]], dim=4)
                full_filter = torch.cat(
                    [top_half[:, :, :, :-1, :], (top_half[:, :, :, [-1], :] + bot_half[:, :, :, [0], :]) / 2,
                     bot_half[:, :, :, 1:, :]], dim=3)

                cnn_para.append(full_filter)
                bias_0 = hypernetwork_output[0][layer_counter + 1]
                bias_1 = hypernetwork_output[1][layer_counter + 1]
                bias_2 = hypernetwork_output[2][layer_counter + 1]
                bias_3 = hypernetwork_output[3][layer_counter + 1]
                cnn_para.append((bias_0 + bias_1 + bias_3 + bias_2) / 4)

        elif len(info) == 3:
            # this is a first linear layer, permutation symmetry.
            final_weight_in_chunks = []
            # one slot per spatial position of a single input channel
            one_weight = [0] * int(info[1] / info[2])
            # print(hypernetwork_output[0][layer_counter].shape)
            # split each rotation's weight chunk into one piece per input channel
            hyper_out_weight = [torch.chunk(hypernetwork_output[rotation][layer_counter], info[2], dim=-1) for
                                rotation in range(4)]
            # print(batch_size, info[0], info[1], info[2])
            final_list = [
                [hyper_out_weight[rotation][chunks].view(batch_size, info[0], -1) for chunks in range(info[2])]
                for rotation in range(4)]
            for chunks in range(info[2]):
                for j, permute in enumerate(permutation_list):
                    # j-th generated value of this channel, one per rotation
                    para = [final_list[k][chunks][:, :, j] for k in range(4)]
                    if len(permute) == 4:
                        # 4-position orbit: each rotation fills one position
                        for k in range(4):
                            one_weight[permute[k]] = para[k].unsqueeze(-1)
                    elif len(permute) == 1:
                        # single-position orbit: average the four rotations
                        one_weight[permute[0]] = torch.sum(torch.stack(para, dim=2), dim=2, keepdim=True)/4
                        # one_weight[permute[0]] = (para[0] + para[1] + para[3] + para[2]) / 4
                weight = torch.cat(one_weight, dim=-1)
                final_weight_in_chunks.append(weight)
            final_weight = torch.cat(final_weight_in_chunks, dim=-1)
            mlp_para.append(final_weight)
            bias_0 = hypernetwork_output[0][layer_counter + 1]
            bias_1 = hypernetwork_output[1][layer_counter + 1]
            bias_2 = hypernetwork_output[2][layer_counter + 1]
            bias_3 = hypernetwork_output[3][layer_counter + 1]
            final_bias = (bias_0 + bias_1 + bias_3 + bias_2) / 4
            mlp_para.append(final_bias)

        elif len(info) == 2:
            view_shape = (-1, info[0], info[1])
            info_0 = hypernetwork_output[0][layer_counter].view(view_shape)
            info_1 = hypernetwork_output[1][layer_counter].view(view_shape)
            info_2 = hypernetwork_output[2][layer_counter].view(view_shape)
            info_3 = hypernetwork_output[3][layer_counter].view(view_shape)
            bias_0 = hypernetwork_output[0][layer_counter + 1]
            bias_1 = hypernetwork_output[1][layer_counter + 1]
            bias_2 = hypernetwork_output[2][layer_counter + 1]
            bias_3 = hypernetwork_output[3][layer_counter + 1]
            # all others linear layers, invariant, take average
            weight = (info_0 + info_1 + info_3 + info_2) / 4
            bias = (bias_0 + bias_1 + bias_3 + bias_2) / 4
            mlp_para.append(weight)
            mlp_para.append(bias)
        # each entry owns two chunks (weight, bias); None entries cancelled this above
        layer_counter += 2

    return cnn_para, mlp_para


def _average_quadrants(generated, info_needed, rotate):
    """
    Average the fine-rotation outputs that fall inside each 90-degree quadrant.

    ``generated`` holds ``4 * info_needed`` tensors ordered by input rotation, so
    quadrant q owns indices ``q * info_needed`` .. ``(q + 1) * info_needed - 1``.
    When ``rotate`` is true the q-th average is also passed through
    ``torch.rot90(..., -q, dims=[3, 4])``. Returns a list of four tensors.
    """
    sections = []
    for quadrant in range(4):
        members = generated[quadrant * info_needed:(quadrant + 1) * info_needed]
        average = sum(members) / info_needed
        if rotate:
            average = torch.rot90(average, -quadrant, dims=[3, 4])
        sections.append(average)
    return sections


def _stitch_odd_kernel(sections):
    """
    Assemble an odd k x k kernel from four ceil(k/2) x ceil(k/2) quadrants laid out
    as [[s0, s1], [s3, s2]] over dims 3 (rows) and 4 (columns). The middle row and
    column are shared by neighbouring quadrants and hold the mean of the overlap.
    """
    s0, s1, s2, s3 = sections
    top_half = torch.cat(
        [s0[:, :, :, :, :-1], (s0[:, :, :, :, [-1]] + s1[:, :, :, :, [0]]) / 2,
         s1[:, :, :, :, 1:]], dim=4)
    bot_half = torch.cat(
        [s3[:, :, :, :, :-1], (s3[:, :, :, :, [-1]] + s2[:, :, :, :, [0]]) / 2,
         s2[:, :, :, :, 1:]], dim=4)
    return torch.cat(
        [top_half[:, :, :, :-1, :], (top_half[:, :, :, [-1], :] + bot_half[:, :, :, [0], :]) / 2,
         bot_half[:, :, :, 1:, :]], dim=3)


def _mirror_first_linear(weight, channels):
    """
    Mirror the first Linear layer's weight horizontally.

    Treats the last axis (in_features) as a flattened (channels, side, side) feature
    map and flips the width axis of every channel; leading dims are kept.
    NOTE(review): assumes row-major spatial order (as nn.Flatten produces) matching
    the position indices used in permutation_list -- confirm.

    :raises ValueError: if in_features is not channels * side * side.
    """
    in_features = weight.shape[-1]
    side = math.isqrt(in_features // channels)
    if channels * side * side != in_features:
        raise ValueError("first linear layer input is not a square (channels, side, side) feature map")
    lead = weight.shape[:-1]
    return weight.reshape(*lead, channels, side, side).flip(-1).reshape(*lead, in_features)


def _average_pair(first, second):
    """Mean of two tensors, or None when both are None."""
    if first is None and second is None:
        return None
    return (first + second) / 2


def equi_loader_general(hypernetwork, layers_info, x,
                        permutation_list, goal_list, rotation=0, head = False, reflection=False):
    """
    Build target-network parameters from a hypernetwork so they are equivariant to
    rotations of ``x`` (group Z_4*(2^rotation)) and, optionally, mirror reflections.

    The hypernetwork is evaluated on ``4 * 2**rotation`` copies of ``x`` rotated by
    multiples of ``90 / 2**rotation`` degrees (``ttf.rotate``). Each output is split
    by ``goal_list`` (see hyper_goal_output) along dim 1, unless ``head`` is true, in
    which case the hypernetwork must already return the per-layer list of chunks.

    Only the 7-tuple (Conv2d), 4-tuple (first Linear) and 3-tuple (other Linear)
    entries of ``layers_info`` (ShellParser output) produce parameters.

    :return: (cnn_para, mlp_para).
        cnn_para = [filter, bias, ...] with filters of shape
        (batch, out_channels, in_channels, k, k).
        mlp_para = [weight, bias, ...] with weights of shape
        (batch, out_features, in_features).
        A bias entry is None when the layer has no bias.

    If ``reflection`` is True, mirror reflections are also considered (group
    D_4*(2^rotation)). Parameters generated from ``x`` are averaged with the
    mirrored parameters generated from the horizontally flipped input. Conv
    filters are flipped along the kernel width, and the first Linear weight along
    the width of its input feature map. Biases and deeper Linear layers are left
    unflipped.
    """
    if reflection:
        reflected_x = torch.flip(x, [-1])
        cnn_para_1, linear_para_1 = equi_loader_general(hypernetwork, layers_info, x,
                                                        permutation_list, goal_list, rotation, head,
                                                        reflection=False)
        cnn_para_2, linear_para_2 = equi_loader_general(hypernetwork, layers_info, reflected_x,
                                                        permutation_list, goal_list, rotation, head,
                                                        reflection=False)
        # cnn lists alternate [filter, bias, ...]. Flipping a filter only touches its
        # kernel-width axis, so a bias (indexed by output channel) must stay unflipped.
        cnn_para = []
        for idx, (tensor1, tensor2) in enumerate(zip(cnn_para_1, cnn_para_2)):
            if idx % 2 == 0:
                tensor2 = torch.flip(tensor2, [-1])
            cnn_para.append(_average_pair(tensor1, tensor2))

        # linear lists alternate [weight, bias, ...] in layers_info order; only the
        # first Linear (4-tuple info) reads a spatial feature map.
        linear_infos = [info for info in layers_info if info is not None and len(info) in (3, 4)]
        linear_para = []
        for idx, (tensor1, tensor2) in enumerate(zip(linear_para_1, linear_para_2)):
            if idx % 2 == 0 and len(linear_infos[idx // 2]) == 4:
                tensor2 = _mirror_first_linear(tensor2, linear_infos[idx // 2][2])
            linear_para.append(_average_pair(tensor1, tensor2))
        return cnn_para, linear_para

    total_amount = 4 * (2 ** rotation)    # number of rotated copies of x
    each_rotation = 90 / (2 ** rotation)  # angle step in degrees
    info_needed = 2 ** rotation           # rotated copies per 90-degree quadrant
    # goal_list = hyper_goal_output(layers_info) is computed by the caller to save time.
    batch_size = x.shape[0]
    if head:
        # when we have multi-head, the output is already parsed per layer.
        hypernetwork_output = [hypernetwork(ttf.rotate(x, angle=(each_rotation * i)))
                               for i in range(total_amount)]
    else:
        hypernetwork_output = [torch.split(hypernetwork(ttf.rotate(x, angle=(each_rotation * i))), goal_list, dim=1)
                               for i in range(total_amount)]
    # TODO: instead of splitting into lists/tuples, index the needed positions directly?
    cnn_para = []
    mlp_para = []
    # index of the current layer's weight chunk; its bias chunk is layer_counter + 1
    layer_counter = 0
    for info in layers_info:
        if info is None or len(info) not in [7, 4, 3]:
            # hyper_goal_output emitted no chunks for this entry; cancel the += 2 below
            layer_counter -= 2
        elif len(info) == 7:
            out_ch, in_ch, k_h, k_w, bias_needed = info[0], info[1], info[2], info[3], info[4]
            # each copy supplies one kernel quadrant
            if k_w == 2:
                view_shape = (-1, out_ch, in_ch, 1)
            elif k_w % 2 == 0:
                view_shape = (-1, out_ch, in_ch, int(k_h / 2), int(k_w / 2))
            else:
                dim = math.ceil(k_w / 2)
                view_shape = (-1, out_ch, in_ch, dim, dim)
            generated_filters = [hypernetwork_output[i][layer_counter].view(view_shape)
                                 for i in range(total_amount)]
            if k_w == 2:
                # 2x2 filter: one value per quadrant, laid out [[s0, s1], [s3, s2]]
                sections = _average_quadrants(generated_filters, info_needed, rotate=False)
                filters = torch.cat([sections[0], sections[1], sections[3], sections[2]], dim=3)
                cnn_para.append(filters.view(-1, out_ch, in_ch, k_h, k_w))
            elif k_w % 2 == 0:
                # even x even filter: tile the quadrants [[s0, s1], [s3, s2]]
                sections = _average_quadrants(generated_filters, info_needed, rotate=True)
                top_half = torch.cat([sections[0], sections[1]], dim=4)
                bot_half = torch.cat([sections[3], sections[2]], dim=4)
                cnn_para.append(torch.cat([top_half, bot_half], dim=3))
            else:
                sections = _average_quadrants(generated_filters, info_needed, rotate=True)
                cnn_para.append(_stitch_odd_kernel(sections))
            if bias_needed:
                biases = [hypernetwork_output[i][layer_counter + 1] for i in range(total_amount)]
                cnn_para.append(sum(biases) / total_amount)
            else:
                cnn_para.append(None)
        elif len(info) == 4:
            # first linear layer: weights tied over the orbits in permutation_list
            bias_needed = info[3]
            linear_weight = 0
            for num_permute in range(info_needed):
                final_weight_in_chunks = []
                # one slot per spatial position of a single input channel
                one_weight = [0] * int(info[1] / info[2])
                # per quadrant, split the weight chunk into one piece per input channel
                hyper_out_weight = [torch.chunk(hypernetwork_output[num_permute + quadrant * info_needed][layer_counter],
                                                info[2], dim=-1)
                                    for quadrant in range(4)]
                final_list = [
                    [hyper_out_weight[quadrant][chunks].view(batch_size, info[0], -1) for chunks in range(info[2])]
                    for quadrant in range(4)]
                for chunks in range(info[2]):
                    for j, permute in enumerate(permutation_list):
                        para = [final_list[k][chunks][:, :, j] for k in range(4)]
                        if len(permute) == 4:
                            # 4-position orbit: each quadrant's output fills one position
                            for k in range(4):
                                one_weight[permute[k]] = para[k].unsqueeze(-1)
                        elif len(permute) == 1:
                            # single-position orbit: average the four quadrants
                            one_weight[permute[0]] = torch.sum(torch.stack(para, dim=2), dim=2, keepdim=True) / 4
                    weight = torch.cat(one_weight, dim=-1)
                    final_weight_in_chunks.append(weight)
                final_weight = torch.cat(final_weight_in_chunks, dim=-1)
                linear_weight += final_weight
            linear_weight = linear_weight / info_needed
            mlp_para.append(linear_weight)
            if bias_needed:
                bias_collection = [hypernetwork_output[i][layer_counter + 1] for i in range(total_amount)]
                mlp_para.append(sum(bias_collection) / total_amount)
            else:
                mlp_para.append(None)
        elif len(info) == 3:
            # all other linear layers are invariant: plain average over every copy
            bias_needed = info[2]
            view_shape = (-1, info[0], info[1])
            total_weight = [hypernetwork_output[i][layer_counter].view(view_shape) for i in range(total_amount)]
            mlp_para.append(sum(total_weight) / total_amount)
            if bias_needed:
                total_bias = [hypernetwork_output[i][layer_counter + 1] for i in range(total_amount)]
                mlp_para.append(sum(total_bias) / total_amount)
            else:
                mlp_para.append(None)
        layer_counter += 2
    return cnn_para, mlp_para

if __name__ == '__main__':
    # Demo: first-Linear summary (ShellParser general=2) of a ResNet-18, once with
    # the recorded out_features overridden via change_linear and once without.
    demo_net = models.resnet18(num_classes=10)
    for extra_kwargs in ({'change_linear': 10}, {}):
        print(ShellParser(demo_net, 2, **extra_kwargs))
    print()
