import numpy as np
import random
import torch
from models import Resnet20
from typing import Dict, List
from copy import deepcopy
from collections import OrderedDict
from flwr.server.strategy.aggregate import aggregate
from dataset import cinicDataset
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays, FitRes
from torch.utils.data import DataLoader

# Module-wide constants.
# NUM_CHANNELS seeds the second-dimension index list for the first layer in
# dropout_aggregation (presumably the RGB input channels -- confirm).
NUM_CHANNELS = 3
# CLASSES is the number of rows iterated for 'fc.weight' in dropout_aggregation.
CLASSES = 10
# Batch is the DataLoader batch size used by get_mean_test_acc.
Batch = 128
# DEVICE is where evaluation batches are moved; CUDA when available, else CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else "cpu"

def _conv_bn_names(tag):
    """Return the five state_dict keys of the conv/BN pair identified by ``tag``.

    ``tag`` is the suffix shared by the pair, e.g. ``'1'`` or ``'3_2'``.
    """
    bn_fields = ('weight', 'bias', 'running_mean', 'running_var')
    return ['conv%s.weight' % tag] + ['bn%s.%s' % (tag, field) for field in bn_fields]


def _stage_names(stage, first_unit, last_unit):
    """Concatenate the conv/BN keys of units ``first_unit..last_unit`` of ``stage``."""
    names = []
    for unit in range(first_unit, last_unit + 1):
        names.extend(_conv_bn_names('%d_%d' % (stage, unit)))
    return names


# Parameter names grouped by block, listed from the classifier back to the
# stem: [fc, conv4_x, conv3_x, conv2_x, conv1].  The helpers below that take
# a number of frozen blocks ``lf`` use the first ``len(Param_by_Blocks) - lf``
# groups as the set of parameters that are still being trained.
Param_by_Blocks = [
    ['fc.weight'],
    _stage_names(4, 0, 6),
    _stage_names(3, 0, 6),
    _stage_names(2, 1, 6),  # the conv2 group has no unit 0
    _conv_bn_names('1'),
]


def get_parameters(net:torch.nn.Module) -> List[np.ndarray]:
    """Return every ``state_dict`` entry of ``net`` as a NumPy array.

    The arrays follow ``state_dict`` order and include all entries (nothing
    is filtered out); tensors are copied to the CPU first.
    """
    arrays = []
    for tensor in net.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays

def get_filters(net:torch.nn.Module) -> List[np.ndarray]:
    """Return the ``state_dict`` entries of ``net`` as NumPy arrays.

    Entries whose key contains ``"num_batches"`` are left out; the remaining
    arrays keep ``state_dict`` order.
    """
    return [
        tensor.cpu().numpy()
        for key, tensor in net.state_dict().items()
        if "num_batches" not in key
    ]

def set_filters(net:torch.nn.Module, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net`` in place.

    ``parameters`` is matched positionally against the ``state_dict`` keys of
    ``net`` that do not contain ``"num_batches"`` (the same order
    ``get_filters`` produces).  Entries are converted with ``torch.Tensor``
    and loaded with ``strict=False``, so the skipped keys keep their current
    values.  Surplus trailing arrays are ignored; too few raise ``IndexError``.
    """
    target_keys = [key for key in net.state_dict() if "num_batches" not in key]
    new_state = OrderedDict()
    for position, key in enumerate(target_keys):
        new_state[key] = torch.Tensor(parameters[position])
    net.load_state_dict(new_state, strict=False)

def get_updated_layers_with_approx(model:Resnet20, lf:int) -> Dict:
    """Collect the trained (non-frozen) parameters of ``model`` as NumPy arrays.

    The first ``len(Param_by_Blocks) - lf`` groups of ``Param_by_Blocks`` are
    treated as trained.  Two exceptions are left out of the result: keys
    containing ``conv3_0``/``conv3_1`` when ``lf == 2`` and keys containing
    ``conv4_0``/``conv4_1`` when ``lf == 3``.

    Returns a dict mapping parameter name to array, in ``state_dict`` order.
    """
    # Names of every parameter that belongs to a block still being trained.
    trained_names = set()
    for group_idx in range(len(Param_by_Blocks) - lf):
        trained_names.update(Param_by_Blocks[group_idx])

    # Key fragments excluded for this particular number of frozen blocks.
    excluded_fragments = {
        2: ("conv3_0", "conv3_1"),
        3: ("conv4_0", "conv4_1"),
    }.get(lf, ())

    layer_dict = {}
    for key, tensor in model.state_dict().items():
        if key not in trained_names:
            continue
        if any(fragment in key for fragment in excluded_fragments):
            continue
        layer_dict[key] = tensor.cpu().numpy()
    return layer_dict

def get_updated_layers(model:Resnet20, lf:int) -> Dict:
    """Collect the trained (non-frozen) parameters of ``model`` as NumPy arrays.

    The first ``len(Param_by_Blocks) - lf`` groups of ``Param_by_Blocks`` are
    treated as trained; every ``state_dict`` entry named in them is returned,
    keyed by parameter name and in ``state_dict`` order.
    """
    trained_names = set()
    for group_idx in range(len(Param_by_Blocks) - lf):
        trained_names.update(Param_by_Blocks[group_idx])
    return {
        key: tensor.cpu().numpy()
        for key, tensor in model.state_dict().items()
        if key in trained_names
    }

def get_random_updated_layers(model:Resnet20, frozen_layers:List[str]) -> Dict:
    """Collect the parameters of every block not listed in ``frozen_layers``.

    ``frozen_layers`` may contain ``'conv1'``, ``'conv2'``, ``'conv3'`` and
    ``'conv4'``; each tag that is absent adds the matching group of
    ``Param_by_Blocks`` (``[-1]`` .. ``[-4]``).  The ``Param_by_Blocks[0]``
    group (``'fc.weight'``) is always included.

    Returns a dict mapping parameter name to NumPy array, in ``state_dict``
    order.
    """
    active_names = []
    # 'conv1' -> Param_by_Blocks[-1], 'conv2' -> [-2], and so on.
    for offset, tag in enumerate(('conv1', 'conv2', 'conv3', 'conv4'), start=1):
        if tag not in frozen_layers:
            active_names.extend(Param_by_Blocks[-offset])
    active_names.extend(Param_by_Blocks[0])

    return {
        key: tensor.cpu().numpy()
        for key, tensor in model.state_dict().items()
        if key in active_names or key == 'fc.weight'
    }

def get_mean_test_acc(global_model:Resnet20):
    """Evaluate ``global_model`` on ``clientdata/cinic_test.csv``.

    The model is switched to eval mode and run without gradients over the
    whole file in batches of ``Batch``; inputs are moved to ``DEVICE`` (the
    model itself is not moved here).

    Returns ``(loss, accuracy)``: the sum of the per-batch cross-entropy
    values divided by the number of samples in the dataset, and the number of
    correct top-1 predictions divided by the number of samples seen.
    """
    test_set = cinicDataset("clientdata/cinic_test.csv")
    loader = DataLoader(test_set, Batch, shuffle=False)
    loss_fn = torch.nn.CrossEntropyLoss()

    n_correct = 0
    n_seen = 0
    loss_sum = 0.0
    global_model.eval()
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)
            logits = global_model(batch_x)
            loss_sum += loss_fn(logits, batch_y).item()
            n_seen += batch_y.size(0)
            # Index of the largest logit per sample.
            top1 = torch.max(logits, 1)[1]
            n_correct += (top1 == batch_y).sum()
    return loss_sum / len(test_set), n_correct / n_seen

def generate_filters_random(global_model:Resnet20, rate, sd=1620):
    """Slice a narrower sub-network out of ``global_model`` using random filters.

    For each of the groups conv1, conv2_x, conv3_x and conv4_x one sorted
    random set of ``max(1, int(n * rate))`` indices is drawn (``n`` being the
    size of dimension 0 of the first conv weight met in that group) and
    shared by every conv/BN tensor of the group.  Conv weights are sliced
    along dimension 0 with their own group's indices and along dimension 1
    with the indices of the group feeding them; ``fc.weight`` is sliced along
    dimension 1 with the conv4 indices.  ``state_dict`` entries whose key
    contains ``"num_batches"`` are skipped.

    Args:
        global_model: model whose ``state_dict`` uses the conv1/bn1,
            conv{2,3,4}_k / bn{2,3,4}_k and fc.weight naming scheme.
        rate: fraction of filters to keep; ``rate >= 0.99`` returns the full
            model unchanged.
        sd: seed for the index draws.

    Returns:
        ``(drop_information, subparams)``.  ``drop_information`` maps each
        processed parameter name to the list of kept indices (empty dict for
        the full model); ``subparams`` is the list of sliced NumPy arrays in
        ``state_dict`` order.

    Raises:
        ValueError: if a parameter name fits none of the known patterns.
        KeyError: if a tensor is met before the conv weight that defines its
            group's indices.
    """
    drop_information = {}
    if rate >= 0.99:
        return drop_information, get_filters(global_model)

    # A private generator yields the same draws as random.seed(sd) followed
    # by random.sample, without reseeding the process-wide RNG on every call.
    rng = random.Random(sd)

    bn1_names = ('bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var')
    # Conv weights whose dimension 1 is indexed by the *previous* group.
    entry_convs = {
        "conv2": ("conv2_1.weight",),
        "conv3": ("conv3_0.weight", "conv3_1.weight"),
        "conv4": ("conv4_0.weight", "conv4_1.weight"),
    }
    upstream = {"conv2": "conv1", "conv3": "conv2", "conv4": "conv3"}

    def stage_of(param_name):
        """Return (group key, is_conv) for conv2-4 / bn2-4 names, else (None, False)."""
        for digit in "234":
            if "conv" + digit + "_" in param_name:
                return "conv" + digit, True
            if "bn" + digit + "_" in param_name:
                return "conv" + digit, False
        return None, False

    def draw_ids(total_filters):
        """Draw the sorted kept indices for a group with ``total_filters`` filters."""
        num_selected_filters = max(1, int(total_filters * rate))
        return sorted(rng.sample(range(total_filters), num_selected_filters))

    conv_layer_index_by_block = {}
    subparams = []
    for name, tensor in global_model.state_dict().items():
        if "num_batches" in name:
            continue
        w = tensor.cpu()
        stage, is_conv = stage_of(name)
        if name == 'conv1.weight':
            kept_ids = draw_ids(w.shape[0])
            conv_layer_index_by_block["conv1"] = kept_ids
            sub_param = torch.index_select(w, 0, torch.tensor(kept_ids))
        elif name in bn1_names:
            kept_ids = conv_layer_index_by_block["conv1"]
            sub_param = torch.index_select(w, 0, torch.tensor(kept_ids))
        elif stage is not None and is_conv:
            # The first conv of a group fixes the indices for the whole group.
            if stage not in conv_layer_index_by_block:
                conv_layer_index_by_block[stage] = draw_ids(w.shape[0])
            kept_ids = conv_layer_index_by_block[stage]
            if name in entry_convs[stage]:
                lastindices = conv_layer_index_by_block[upstream[stage]]
            else:
                lastindices = kept_ids
            rows = torch.index_select(w, 0, torch.tensor(kept_ids))
            sub_param = torch.index_select(rows, 1, torch.tensor(lastindices))
        elif stage is not None:
            kept_ids = conv_layer_index_by_block[stage]
            sub_param = torch.index_select(w, 0, torch.tensor(kept_ids))
        elif name == 'fc.weight':
            kept_ids = conv_layer_index_by_block["conv4"]
            sub_param = torch.index_select(w, 1, torch.tensor(kept_ids))
        else:
            # Previously such a name silently reused the tensor and indices
            # left over from the preceding iteration.
            raise ValueError(f"generate_filters_random: unsupported parameter name {name!r}")
        drop_information[name] = kept_ids
        subparams.append(sub_param.numpy())
    return drop_information, subparams

def generate_subnet_ordered(global_model:Resnet20, rate):
    """Slice a narrower sub-network out of ``global_model`` using leading filters.

    For each of the groups conv1, conv2_x, conv3_x and conv4_x the kept
    indices are ``0 .. max(1, int(n * rate)) - 1`` (``n`` being the size of
    dimension 0 of the first conv weight met in that group) and are shared by
    every conv/BN tensor of the group.  Conv weights are sliced along
    dimension 0 with their own group's indices and along dimension 1 with the
    indices of the group feeding them; ``fc.weight`` is sliced along
    dimension 1 with the conv4 indices.  ``state_dict`` entries whose key
    contains ``"num_batches"`` are skipped.

    Args:
        global_model: model whose ``state_dict`` uses the conv1/bn1,
            conv{2,3,4}_k / bn{2,3,4}_k and fc.weight naming scheme.
        rate: fraction of filters to keep; ``rate >= 0.99`` returns the full
            model unchanged.

    Returns:
        ``(drop_information, subparams)``.  ``drop_information`` maps each
        processed parameter name to the list of kept indices (empty dict for
        the full model); ``subparams`` is the list of sliced NumPy arrays in
        ``state_dict`` order.

    Raises:
        ValueError: if a parameter name fits none of the known patterns.
        KeyError: if a tensor is met before the conv weight that defines its
            group's indices.
    """
    drop_information = {}
    if rate >= 0.99:
        return drop_information, get_filters(global_model)

    bn1_names = ('bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var')
    # Conv weights whose dimension 1 is indexed by the *previous* group.
    entry_convs = {
        "conv2": ("conv2_1.weight",),
        "conv3": ("conv3_0.weight", "conv3_1.weight"),
        "conv4": ("conv4_0.weight", "conv4_1.weight"),
    }
    upstream = {"conv2": "conv1", "conv3": "conv2", "conv4": "conv3"}

    def stage_of(param_name):
        """Return (group key, is_conv) for conv2-4 / bn2-4 names, else (None, False)."""
        for digit in "234":
            if "conv" + digit + "_" in param_name:
                return "conv" + digit, True
            if "bn" + digit + "_" in param_name:
                return "conv" + digit, False
        return None, False

    def leading_ids(total_filters):
        """Return the first ``max(1, int(total_filters * rate))`` indices."""
        return list(range(max(1, int(total_filters * rate))))

    conv_layer_index_by_block = {}
    subparams = []
    for name, tensor in global_model.state_dict().items():
        if "num_batches" in name:
            continue
        w = tensor.cpu()
        stage, is_conv = stage_of(name)
        if name == 'conv1.weight':
            kept_ids = leading_ids(w.shape[0])
            conv_layer_index_by_block["conv1"] = kept_ids
            sub_param = torch.index_select(w, 0, torch.tensor(kept_ids))
        elif name in bn1_names:
            kept_ids = conv_layer_index_by_block["conv1"]
            sub_param = torch.index_select(w, 0, torch.tensor(kept_ids))
        elif stage is not None and is_conv:
            # The first conv of a group fixes the indices for the whole group.
            if stage not in conv_layer_index_by_block:
                conv_layer_index_by_block[stage] = leading_ids(w.shape[0])
            kept_ids = conv_layer_index_by_block[stage]
            if name in entry_convs[stage]:
                lastindices = conv_layer_index_by_block[upstream[stage]]
            else:
                lastindices = kept_ids
            rows = torch.index_select(w, 0, torch.tensor(kept_ids))
            sub_param = torch.index_select(rows, 1, torch.tensor(lastindices))
        elif stage is not None:
            kept_ids = conv_layer_index_by_block[stage]
            sub_param = torch.index_select(w, 0, torch.tensor(kept_ids))
        elif name == 'fc.weight':
            kept_ids = conv_layer_index_by_block["conv4"]
            sub_param = torch.index_select(w, 1, torch.tensor(kept_ids))
        else:
            # Previously such a name silently reused the tensor and indices
            # left over from the preceding iteration.
            raise ValueError(f"generate_subnet_ordered: unsupported parameter name {name!r}")
        drop_information[name] = kept_ids
        subparams.append(sub_param.numpy())
    return drop_information, subparams

def dropout_aggregation(Fit_res:List[FitRes], global_param:List[np.ndarray]):
    """Aggregate client updates that may cover only a slice of the full model.

    Every client result carries ``metrics["drop_info"]``: an empty mapping
    for a full-model update, otherwise a mapping from parameter name to the
    kept indices (as produced by the ``generate_*`` sub-network extractors),
    in the same order as the returned parameter list.

    Contributions are collected per position of the full model:

    * 1-D tensors (BN entries) under ``(layer, i)``;
    * every other tensor under ``(layer, i, j)``, i.e. per (dim-0, dim-1)
      pair.

    Each position is then averaged with flwr's ``aggregate`` over exactly
    the clients that trained it.  Positions no client covered keep the
    average of the full-model updates, or a copy of ``global_param`` when
    there were none.

    Args:
        Fit_res: client results; ``parameters``, ``num_examples`` and
            ``metrics["drop_info"]`` are read from each.
        global_param: current global parameters, used as the fallback base.
            It is deep-copied, never modified.

    Returns:
        The aggregated full-size parameter list.
    """
    Aggregation_Dict = {}
    full_results = []
    for fit_res in Fit_res:
        param = parameters_to_ndarrays(fit_res.parameters)
        num = fit_res.num_examples
        merge_info = fit_res.metrics["drop_info"]
        if len(merge_info) == 0:
            # Full-model update: local positions equal global positions.
            full_results.append((param, num))
            for l1, layer in enumerate(param):
                for l2, dim0_slice in enumerate(layer):
                    # Conv (4-D) and fc (2-D) tensors must use the same
                    # 3-part key as the sub-model branch below, otherwise
                    # full and partial clients are never averaged together.
                    # (The previous ``== 3`` test matched neither.)
                    if layer.ndim > 1:
                        for l3, dim1_slice in enumerate(dim0_slice):
                            Aggregation_Dict.setdefault((l1, l2, l3), []).append(([dim1_slice], num))
                    else:
                        Aggregation_Dict.setdefault((l1, l2), []).append(([dim0_slice], num))
        else:
            # Sub-model update: map local positions back to global indices.
            last_layer_indices = list(range(NUM_CHANNELS))
            for layer_count, (k, selected_filters) in enumerate(merge_info.items()):
                layer = param[layer_count]
                if "bn" in k:
                    # i1 is the local position, f the global index.  (The
                    # local position used to stay at 0 for every f.)
                    for i1, f in enumerate(selected_filters):
                        Aggregation_Dict.setdefault((layer_count, f), []).append(([layer[i1]], num))
                elif k == 'fc.weight':
                    # All CLASSES rows are present; only dim 1 was sliced.
                    for f in range(CLASSES):
                        for j1, j in enumerate(last_layer_indices):
                            Aggregation_Dict.setdefault((layer_count, f, j), []).append(([layer[f][j1]], num))
                else:
                    indices_along_dim2 = last_layer_indices
                    # These two layers take their dim-1 indices from the last
                    # conv of the previous group, not from the entry that
                    # precedes them in merge_info.
                    if k == 'conv3_1.weight':
                        indices_along_dim2 = merge_info["conv2_6.weight"]
                    if k == 'conv4_1.weight':
                        indices_along_dim2 = merge_info["conv3_6.weight"]
                    for i1, f in enumerate(selected_filters):
                        for j1, j in enumerate(indices_along_dim2):
                            Aggregation_Dict.setdefault((layer_count, f, j), []).append(([layer[i1][j1]], num))
                last_layer_indices = selected_filters

    full_param = aggregate(full_results) if len(full_results) > 0 else deepcopy(global_param)
    for key, contributions in Aggregation_Dict.items():
        averaged = aggregate(contributions)[0]
        if len(key) == 2:
            layer_idx, dim0_idx = key
            full_param[layer_idx][dim0_idx] = averaged
        else:
            layer_idx, dim0_idx, dim1_idx = key
            full_param[layer_idx][dim0_idx][dim1_idx] = averaged
    return full_param

def generate_subnet_SLT(global_model:Resnet20, rate, lf=0):
    """Slice a sub-network that is narrowed only in the first ``lf`` groups.

    Works like an ordered (leading-filter) extraction, except that group
    number ``s`` (conv1 = 1, conv2_x = 2, conv3_x = 3, conv4_x = 4) is
    reduced to ``max(1, int(n * rate))`` filters only when ``lf >= s``;
    the other groups keep all of their filters.  Conv weights are sliced
    along dimension 0 with their own group's indices and along dimension 1
    with the indices of the group feeding them; ``fc.weight`` is sliced along
    dimension 1 with the conv4 indices.  ``state_dict`` entries whose key
    contains ``"num_batches"`` are skipped.

    Args:
        global_model: model whose ``state_dict`` uses the conv1/bn1,
            conv{2,3,4}_k / bn{2,3,4}_k and fc.weight naming scheme.
        rate: fraction of filters to keep in the narrowed groups;
            ``rate >= 0.99`` returns the full model unchanged.
        lf: number of leading groups to narrow (``0`` narrows nothing).

    Returns:
        ``(drop_information, subparams)``.  ``drop_information`` maps each
        processed parameter name to the list of kept indices (empty dict when
        ``rate >= 0.99``); ``subparams`` is the list of sliced NumPy arrays
        in ``state_dict`` order.

    Raises:
        ValueError: if a parameter name fits none of the known patterns.
        KeyError: if a tensor is met before the conv weight that defines its
            group's indices.
    """
    drop_information = {}
    if rate >= 0.99:
        return drop_information, get_filters(global_model)

    bn1_names = ('bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var')
    # Conv weights whose dimension 1 is indexed by the *previous* group.
    entry_convs = {
        "conv2": ("conv2_1.weight",),
        "conv3": ("conv3_0.weight", "conv3_1.weight"),
        "conv4": ("conv4_0.weight", "conv4_1.weight"),
    }
    upstream = {"conv2": "conv1", "conv3": "conv2", "conv4": "conv3"}

    def stage_of(param_name):
        """Return (group key, is_conv) for conv2-4 / bn2-4 names, else (None, False)."""
        for digit in "234":
            if "conv" + digit + "_" in param_name:
                return "conv" + digit, True
            if "bn" + digit + "_" in param_name:
                return "conv" + digit, False
        return None, False

    def leading_ids(total_filters, stage_number):
        """Return the kept indices of a group: narrowed only if ``lf >= stage_number``."""
        r = rate if lf >= stage_number else 1.0
        return list(range(max(1, int(total_filters * r))))

    conv_layer_index_by_block = {}
    subparams = []
    for name, tensor in global_model.state_dict().items():
        if "num_batches" in name:
            continue
        w = tensor.cpu()
        stage, is_conv = stage_of(name)
        if name == 'conv1.weight':
            kept_ids = leading_ids(w.shape[0], 1)
            conv_layer_index_by_block["conv1"] = kept_ids
            sub_param = torch.index_select(w, 0, torch.tensor(kept_ids))
        elif name in bn1_names:
            kept_ids = conv_layer_index_by_block["conv1"]
            sub_param = torch.index_select(w, 0, torch.tensor(kept_ids))
        elif stage is not None and is_conv:
            # The first conv of a group fixes the indices for the whole group.
            if stage not in conv_layer_index_by_block:
                conv_layer_index_by_block[stage] = leading_ids(w.shape[0], int(stage[-1]))
            kept_ids = conv_layer_index_by_block[stage]
            if name in entry_convs[stage]:
                lastindices = conv_layer_index_by_block[upstream[stage]]
            else:
                lastindices = kept_ids
            rows = torch.index_select(w, 0, torch.tensor(kept_ids))
            sub_param = torch.index_select(rows, 1, torch.tensor(lastindices))
        elif stage is not None:
            kept_ids = conv_layer_index_by_block[stage]
            sub_param = torch.index_select(w, 0, torch.tensor(kept_ids))
        elif name == 'fc.weight':
            kept_ids = conv_layer_index_by_block["conv4"]
            sub_param = torch.index_select(w, 1, torch.tensor(kept_ids))
        else:
            # Previously such a name silently reused the tensor and indices
            # left over from the preceding iteration.
            raise ValueError(f"generate_subnet_SLT: unsupported parameter name {name!r}")
        drop_information[name] = kept_ids
        subparams.append(sub_param.numpy())
    return drop_information, subparams

def compute_sampling_prob(model:Resnet20, device=DEVICE):
    """Compute norm-proportional sampling probabilities for conv weights.

    For every name containing ``"conv"`` in ``Param_by_Blocks[-1]``,
    ``[-2]`` and ``[-3]`` (the conv1, conv2_x and conv3_x groups), the
    Frobenius norm of each slice along dimension 1 of the weight is computed
    and the norms are normalised to sum to one.

    Args:
        model: model whose ``state_dict`` holds the listed conv weights.
        device: device the norms are computed on.  Each weight is moved there
            first, so the model may live on a different device.

    Returns:
        Dict mapping conv weight name to a float64 NumPy array with one
        probability per index along dimension 1.  If all norms of a weight
        are zero, a uniform distribution is returned for it.

    Raises:
        KeyError: if a listed conv weight is missing from the ``state_dict``.
    """
    Sampling_Probs = {}
    param_dict = model.state_dict()
    for group in (Param_by_Blocks[-1], Param_by_Blocks[-2], Param_by_Blocks[-3]):
        for pname in group:
            if "conv" not in pname:
                continue
            weight = param_dict[pname].to(device)
            # Bring dimension 1 to the front and flatten the rest: the L2
            # norm of each row is the Frobenius norm of that dim-1 slice.
            norms = weight.transpose(0, 1).flatten(1).norm(dim=1)
            norms = norms.cpu().numpy().astype(np.float64)
            normalizer = norms.sum()
            if normalizer == 0.0 and norms.size > 0:
                # Dividing by zero would produce NaN probabilities.
                Sampling_Probs[pname] = np.full(norms.size, 1.0 / norms.size)
            else:
                Sampling_Probs[pname] = norms / normalizer
    return Sampling_Probs

def approximate_convolution(globalmodel:Resnet20, sampling_probs:Dict, lf:int, prob=1.0) -> List[np.ndarray]:
    """Return the model parameters with selected conv layers randomly thinned.

    When ``prob >= 0.95`` or ``lf == 0`` the full ``get_filters`` list is
    returned.  Otherwise, for each key of ``sampling_probs`` chosen by ``lf``
    (keys containing ``conv1``/``conv2`` when ``lf >= 2``, additionally keys
    containing ``conv3`` when ``lf >= 3``), ``int(len(p) * min(1.0, prob))``
    distinct indices are drawn with ``np.random.choice`` weighted by that
    key's probabilities.  The ``state_dict`` is then walked in order
    (skipping ``num_batches`` entries):

    * a conv weight is sliced along dimension 0 with the indices drawn for
      the layer named ``nextname`` (all indices if none were drawn) and
      along dimension 1 with the indices kept for the previous conv;
    * conv4_x weights keep all of dimension 0;
    * a ``bn`` entry is sliced along dimension 0 with the indices kept for
      the most recent conv;
    * anything else is passed through unchanged.

    Args:
        globalmodel: model providing the ``state_dict``.
        sampling_probs: per-layer probability vectors, presumably the output
            of ``compute_sampling_prob`` -- confirm with the caller.
        lf: number of frozen blocks; selects which layers are thinned.
        prob: fraction of indices to keep per thinned layer.

    Returns:
        List of NumPy arrays in ``state_dict`` order.  The kept-index
        bookkeeping (``drop_info``) is not returned.
    """
    if prob >= 0.95 or lf == 0:
        return get_filters(globalmodel)
    frac = min(1.0, prob)
    # Pick the layers to thin, based on how many blocks are frozen.
    approximated_conv_layers = []
    for i in sampling_probs.keys():
        if lf >= 2 and ('conv1' in i or 'conv2' in i):
            approximated_conv_layers.append(i)
        elif lf >= 3 and 'conv3' in i:
            approximated_conv_layers.append(i)
    param_dict = globalmodel.state_dict()
    subparams = []
    drop_info = {}         # conv name -> dim-0 indices kept for that conv
    selected_filters = {}  # thinned layer name -> indices drawn for it
    # Draw the indices up front, one weighted draw without replacement per layer.
    # NOTE(review): draws for 'conv1.weight' and 'conv3_1.weight' do not appear
    # to be looked up anywhere below, although they still consume random
    # numbers -- confirm this is intended.
    for frozen_layer in approximated_conv_layers:
        num_selected_filters = int(len(sampling_probs[frozen_layer]) * frac)
        selected_filters[frozen_layer] = np.random.choice(list(range(len(sampling_probs[frozen_layer]))), size=num_selected_filters, replace=False, p=sampling_probs[frozen_layer])
    lastname = None  # name of the most recently processed conv weight
    for name in param_dict.keys():
        if "num_batches" not in name:
            w = param_dict[name].cpu()
            if name == 'conv1.weight':
                # conv1 keeps the dim-0 indices drawn for conv2_1 (or all of them).
                total_filters = w.shape[0]
                non_masked_filter_ids = list(range(total_filters)) if not ('conv2_1.weight' in approximated_conv_layers) else selected_filters['conv2_1.weight']
                drop_info[name] = non_masked_filter_ids
                sub_param = torch.index_select(w,0,torch.tensor(non_masked_filter_ids))
                lastname = name
            elif 'conv2_' in name:
                # nextname: the layer whose drawn indices decide this conv's dim 0.
                # A conv2_ name outside this list leaves nextname at its previous value.
                if name == "conv2_1.weight":
                    nextname = "conv2_2.weight"
                if name == "conv2_2.weight":
                    nextname = "conv2_3.weight"
                if name == "conv2_3.weight":
                    nextname = "conv2_4.weight"
                if name == "conv2_4.weight":
                    nextname = "conv2_5.weight"
                if name == "conv2_5.weight":
                    nextname = "conv2_6.weight"
                if name == "conv2_6.weight":
                    nextname = "conv3_0.weight"
                total_filters = w.shape[0]
                lastindices = drop_info[lastname]
                non_masked_filter_ids = list(range(total_filters)) if not (nextname in selected_filters.keys()) else selected_filters[nextname]
                # Last conv of the group when the next group is not thinned:
                # fall back to the indices drawn for this layer itself.
                if name == "conv2_6.weight" and name in selected_filters and not (nextname in selected_filters.keys()):
                    non_masked_filter_ids = selected_filters[name]
                sub_param_1 = torch.index_select(w, 0, torch.tensor(non_masked_filter_ids))
                sub_param = torch.index_select(sub_param_1, 1, torch.tensor(lastindices))
                lastname = name
                drop_info[name] = non_masked_filter_ids
            elif 'conv3_' in name:
                # conv3_0 and conv3_1 both take their dim-0 indices from conv3_2.
                if name == "conv3_0.weight":
                    nextname = "conv3_2.weight"
                if name == "conv3_1.weight":
                    nextname = "conv3_2.weight"
                if name == "conv3_2.weight":
                    nextname = "conv3_3.weight"
                if name == "conv3_3.weight":
                    nextname = "conv3_4.weight"
                if name == "conv3_4.weight":
                    nextname = "conv3_5.weight"
                if name == "conv3_5.weight":
                    nextname = "conv3_6.weight"
                if name == "conv3_6.weight":
                    nextname = "conv4_0.weight"
                total_filters = w.shape[0]
                # conv3_1 is sliced along dim 1 by conv2_6's indices rather than
                # by the conv processed just before it.
                lastindices = drop_info[lastname] if name != "conv3_1.weight" else drop_info["conv2_6.weight"]
                non_masked_filter_ids = list(range(total_filters)) if not (nextname in selected_filters.keys()) else selected_filters[nextname]
                # Same fallback as for conv2_6 above.
                if name == "conv3_6.weight" and name in selected_filters and not (nextname in selected_filters.keys()):
                    non_masked_filter_ids = selected_filters[name]
                sub_param_1 = torch.index_select(w, 0, torch.tensor(non_masked_filter_ids))
                sub_param = torch.index_select(sub_param_1, 1, torch.tensor(lastindices))
                lastname = name
                drop_info[name] = non_masked_filter_ids
            elif 'conv4_' in name:
                # nextname is assigned here but not read in this branch:
                # conv4_x always keeps every dim-0 index.
                if name == "conv4_0.weight":
                    nextname = "conv4_2.weight"
                if name == "conv4_1.weight":
                    nextname = "conv4_2.weight"
                if name == "conv4_2.weight":
                    nextname = "conv4_3.weight"
                if name == "conv4_3.weight":
                    nextname = "conv4_4.weight"
                if name == "conv4_4.weight":
                    nextname = "conv4_5.weight"
                if name == "conv4_5.weight":
                    nextname = "conv4_6.weight"
                if name == "conv4_6.weight":
                    nextname = "fc.weight"
                total_filters = w.shape[0]
                # conv4_1 is sliced along dim 1 by conv3_6's indices.
                lastindices = drop_info[lastname] if name != "conv4_1.weight" else drop_info["conv3_6.weight"]
                non_masked_filter_ids = list(range(total_filters))
                sub_param_1 = torch.index_select(w, 0, torch.tensor(non_masked_filter_ids))
                sub_param = torch.index_select(sub_param_1, 1, torch.tensor(lastindices))
                lastname = name
                drop_info[name] = non_masked_filter_ids
            elif 'bn' in name:
                # BN entries follow the dim-0 selection of the preceding conv.
                non_masked_filter_ids = drop_info[lastname]
                sub_param = torch.index_select(w,0,torch.tensor(non_masked_filter_ids))
            else:
                # Any other entry is sent whole.
                sub_param = w
            subparams.append(sub_param.numpy())
    return subparams