import numpy as np
import random
import torch
from models import AlexNet
from typing import Dict, List
from copy import deepcopy
from collections import OrderedDict
from flwr.server.strategy.aggregate import aggregate
from flwr.common import parameters_to_ndarrays, FitRes

NUM_CHANNELS = 3  # input channels of conv1 (dropout_aggregation uses range(NUM_CHANNELS) as conv1's input indices)
CLASSES = 10  # number of outputs of the classifier `fc`; sub-models always keep all of them

# Parameter names grouped per layer, ordered from the output side (index 0, the
# classifier `fc`) back to the input side (index 7, conv1 with its batch norm).
# get_updated_layers keeps the first len(Param_by_layers) - lf groups, i.e. it
# leaves out the lf groups closest to the input.
Param_by_layers = [
    ['fc.weight', 'fc.bias'],
    ['fc2.weight'],
    ['fc1.weight'],
] + [
    [f'conv{i}.weight', f'conv{i}.bias', f'bn{i}.weight', f'bn{i}.bias', f'bn{i}.running_mean', f'bn{i}.running_var']
    for i in (5, 4, 3, 2, 1)
]

def get_parameters(net:torch.nn.Module) -> List[np.ndarray]:
    """Return every ``state_dict`` entry of ``net`` as a CPU NumPy array.

    Unlike ``get_filters``, buffers such as ``num_batches_tracked`` are
    included. Order follows ``net.state_dict()``.
    """
    state = net.state_dict()
    return list(map(lambda tensor: tensor.cpu().numpy(), state.values()))

def get_filters(net:torch.nn.Module) -> List[np.ndarray]:
    """Return ``net``'s state as CPU NumPy arrays, skipping ``num_batches*`` buffers.

    Order follows ``net.state_dict()``; this is the layout ``set_filters``
    expects back.
    """
    return [
        tensor.cpu().numpy()
        for name, tensor in net.state_dict().items()
        if "num_batches" not in name
    ]

def set_filters(net:torch.nn.Module, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net`` in ``get_filters`` order.

    ``parameters[i]`` is assigned to the i-th ``state_dict`` entry whose name
    does not contain ``"num_batches"``. Those skipped buffers keep their
    current value, which is why loading uses ``strict=False``. Extra trailing
    arrays are ignored; too few raise ``IndexError``.
    """
    names = [name for name in net.state_dict() if "num_batches" not in name]
    state_dict = OrderedDict(
        (name, torch.Tensor(parameters[position]))
        for position, name in enumerate(names)
    )
    net.load_state_dict(state_dict, strict=False)

def get_updated_layers(model:AlexNet, lf:int) -> Dict:
    """Collect the parameters that were trained when ``lf`` layers were frozen.

    ``Param_by_layers`` is ordered from the classifier back to conv1, so
    keeping its first ``len(Param_by_layers) - lf`` groups leaves out the
    ``lf`` groups closest to the input (the frozen ones).

    Args:
        model: locally trained model.
        lf: number of frozen layer groups, counted from the input side.

    Returns:
        Dict mapping each non-frozen parameter name to a CPU NumPy array, in
        ``state_dict`` order.
    """
    trainable = set()
    for group_idx in range(len(Param_by_layers) - lf):
        trainable.update(Param_by_layers[group_idx])
    return {
        name: tensor.cpu().numpy()
        for name, tensor in model.state_dict().items()
        if name in trainable
    }

def get_random_updated_layers(model:AlexNet, frozen_layers:List[str]) -> Dict:
    """Collect the parameters of every layer that is not listed in ``frozen_layers``.

    Layer names understood here are ``'conv1'`` .. ``'conv5'`` (conv plus its
    batch norm), ``'fc1'`` and ``'fc2'`` (weights only). The classifier
    ``fc.weight`` / ``fc.bias`` is always included.

    Args:
        model: locally trained model.
        frozen_layers: names of the frozen layers (tested with ``in``).

    Returns:
        Dict mapping each sent parameter name to a CPU NumPy array, in
        ``state_dict`` order.
    """
    layer_params = {
        'conv1': ['conv1.weight', 'conv1.bias', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var'],
        'conv2': ['conv2.weight', 'conv2.bias', 'bn2.weight', 'bn2.bias', 'bn2.running_mean', 'bn2.running_var'],
        'conv3': ['conv3.weight', 'conv3.bias', 'bn3.weight', 'bn3.bias', 'bn3.running_mean', 'bn3.running_var'],
        'conv4': ['conv4.weight', 'conv4.bias', 'bn4.weight', 'bn4.bias', 'bn4.running_mean', 'bn4.running_var'],
        'conv5': ['conv5.weight', 'conv5.bias', 'bn5.weight', 'bn5.bias', 'bn5.running_mean', 'bn5.running_var'],
        'fc1': ['fc1.weight'],
        'fc2': ['fc2.weight'],
    }
    sent = {'fc.weight', 'fc.bias'}  # the classifier head is always sent
    for layer, names in layer_params.items():
        if layer not in frozen_layers:
            sent.update(names)
    return {
        name: tensor.cpu().numpy()
        for name, tensor in model.state_dict().items()
        if name in sent
    }

def get_layer_gradients(old_layerdict:Dict, new_layerdict:Dict) -> Dict:
    """Return ``new - old`` for every key of ``new_layerdict``.

    ``new_layerdict`` values are expected to be NumPy arrays and
    ``old_layerdict`` values torch tensors (they are moved to CPU and converted
    before subtracting). Keys present only in ``old_layerdict`` are ignored.
    """
    return {
        name: updated - old_layerdict[name].cpu().numpy()
        for name, updated in new_layerdict.items()
    }

def generate_filters_random(global_model:AlexNet, rate):
    """Carve a sub-model with randomly chosen units out of ``global_model``.

    Each layer that owns its output units (conv1-conv5, fc1, fc2) keeps a
    random, sorted subset of ``max(1, int(out_dim * rate))`` output indices
    drawn with ``random.sample``, so results follow the global ``random`` seed.
    Companion parameters reuse their layer's selection: the conv bias, the
    batch-norm weight/bias/running statistics, and ``fc2.bias`` if the model
    has one (consistent with generate_subnet_ordered / generate_subnet_SLT).
    The next layer's input dimension (dim 1) is sliced to match. ``fc1`` reads
    the flattened conv5 output, so each kept conv5 channel contributes 7*7
    consecutive input columns. The classifier ``fc`` keeps all ``CLASSES``
    outputs and only has its inputs sliced.

    Args:
        global_model: model whose ``state_dict`` is sliced.
        rate: fraction of output units to keep per layer. ``rate >= 0.99``
            returns the unsliced model.

    Returns:
        ``(drop_information, subparams)``.
        ``drop_information`` maps every parameter name (``num_batches*``
        buffers skipped) to the kept output indices used for it; ``fc.weight``
        records the fc2 selection. It is empty for the unsliced model.
        ``subparams`` holds the sliced NumPy arrays in ``state_dict`` order.
    """
    drop_information = {}
    if rate >= 0.99:
        return drop_information, get_filters(global_model)

    # Weights that draw their own output selection, in the order they are met.
    producers = ('conv1.weight', 'conv2.weight', 'conv3.weight', 'conv4.weight',
                 'conv5.weight', 'fc1.weight', 'fc2.weight')
    # Weight whose output selection fixes this weight's input dimension (dim 1).
    fed_by = {
        'conv2.weight': 'conv1.weight',
        'conv3.weight': 'conv2.weight',
        'conv4.weight': 'conv3.weight',
        'conv5.weight': 'conv4.weight',
        'fc2.weight': 'fc1.weight',
    }
    # Parameters that reuse the output selection (dim 0) of their layer's weight.
    follows = {'fc2.bias': 'fc2.weight'}
    for owner, companions in (
        ('conv1.weight', ['conv1.bias', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var']),
        ('conv2.weight', ['conv2.bias', 'bn2.weight', 'bn2.bias', 'bn2.running_mean', 'bn2.running_var']),
        ('conv3.weight', ['conv3.bias', 'bn3.weight', 'bn3.bias', 'bn3.running_mean', 'bn3.running_var']),
        ('conv4.weight', ['conv4.bias', 'bn4.weight', 'bn4.bias', 'bn4.running_mean', 'bn4.running_var']),
        ('conv5.weight', ['conv5.bias', 'bn5.weight', 'bn5.bias', 'bn5.running_mean', 'bn5.running_var']),
    ):
        follows.update(dict.fromkeys(companions, owner))
    spatial = 7 * 7  # fc1 input columns per conv5 channel

    def random_rows(num_rows):
        """Sorted random subset of ``max(1, int(num_rows * rate))`` indices in [0, num_rows)."""
        keep = max(1, int(num_rows * rate))
        return sorted(random.sample(list(range(num_rows)), keep))

    param_dict = global_model.state_dict()
    subparams = []
    for name in param_dict.keys():
        if "num_batches" in name:
            continue
        w = param_dict[name].cpu()
        if name in producers:
            kept = random_rows(w.shape[0])
            sub_param = torch.index_select(w, 0, torch.tensor(kept))
            if name in fed_by:
                sub_param = torch.index_select(sub_param, 1, torch.tensor(drop_information[fed_by[name]]))
            elif name == 'fc1.weight':
                columns = [col for ch in drop_information['conv5.weight']
                           for col in range(ch * spatial, (ch + 1) * spatial)]
                sub_param = torch.index_select(sub_param, 1, torch.tensor(columns))
        elif name in follows:
            kept = drop_information[follows[name]]
            sub_param = torch.index_select(w, 0, torch.tensor(kept))
        elif name == 'fc.weight':
            kept = drop_information['fc2.weight']
            sub_param = torch.index_select(w, 1, torch.tensor(kept))
        else:  # fc.bias: the classifier keeps every class
            sub_param = w
            kept = list(range(CLASSES))
        drop_information[name] = kept
        subparams.append(sub_param.numpy())
    return drop_information, subparams

def generate_subnet_ordered(global_model:AlexNet, rate):
    """Carve an ordered-width sub-model out of ``global_model``.

    Each layer that owns its output units (conv1-conv5, fc1, fc2) keeps its
    FIRST ``max(1, int(out_dim * rate))`` output indices. Companion parameters
    reuse that selection: the conv bias, the batch-norm weight/bias/running
    statistics, and ``fc2.bias``. The next layer's input dimension (dim 1) is
    sliced to match. ``fc1`` reads the flattened conv5 output (7*7 columns per
    kept channel). The classifier ``fc`` keeps all ``CLASSES`` outputs.

    Args:
        global_model: model whose ``state_dict`` is sliced.
        rate: fraction of output units to keep; ``rate >= 0.99`` returns the
            unsliced model with an empty ``drop_information``.

    Returns:
        ``(drop_information, subparams)``: kept output indices per parameter
        name (``num_batches*`` skipped; ``fc.weight`` records the fc2
        selection) and the sliced NumPy arrays in ``state_dict`` order.
    """
    drop_information = {}
    if rate >= 0.99:
        return drop_information, get_filters(global_model)

    producers = ('conv1.weight', 'conv2.weight', 'conv3.weight', 'conv4.weight',
                 'conv5.weight', 'fc1.weight', 'fc2.weight')
    fed_by = {
        'conv2.weight': 'conv1.weight',
        'conv3.weight': 'conv2.weight',
        'conv4.weight': 'conv3.weight',
        'conv5.weight': 'conv4.weight',
        'fc2.weight': 'fc1.weight',
    }
    follows = {'fc2.bias': 'fc2.weight'}
    for owner, companions in (
        ('conv1.weight', ['conv1.bias', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var']),
        ('conv2.weight', ['conv2.bias', 'bn2.weight', 'bn2.bias', 'bn2.running_mean', 'bn2.running_var']),
        ('conv3.weight', ['conv3.bias', 'bn3.weight', 'bn3.bias', 'bn3.running_mean', 'bn3.running_var']),
        ('conv4.weight', ['conv4.bias', 'bn4.weight', 'bn4.bias', 'bn4.running_mean', 'bn4.running_var']),
        ('conv5.weight', ['conv5.bias', 'bn5.weight', 'bn5.bias', 'bn5.running_mean', 'bn5.running_var']),
    ):
        follows.update(dict.fromkeys(companions, owner))

    state = global_model.state_dict()
    pieces = []
    for key in state:
        if "num_batches" in key:
            continue
        tensor = state[key].cpu()
        if key in producers:
            kept = list(range(max(1, int(tensor.shape[0] * rate))))
            piece = torch.index_select(tensor, 0, torch.tensor(kept))
            if key in fed_by:
                piece = torch.index_select(piece, 1, torch.tensor(drop_information[fed_by[key]]))
            elif key == 'fc1.weight':
                # Each kept conv5 channel maps to 7*7 consecutive fc1 input columns.
                columns = [col for ch in drop_information['conv5.weight']
                           for col in range(ch * 7 * 7, (ch + 1) * 7 * 7)]
                piece = torch.index_select(piece, 1, torch.tensor(columns))
        elif key in follows:
            kept = drop_information[follows[key]]
            piece = torch.index_select(tensor, 0, torch.tensor(kept))
        elif key == 'fc.weight':
            kept = drop_information['fc2.weight']
            piece = torch.index_select(tensor, 1, torch.tensor(kept))
        else:  # fc.bias
            piece = tensor
            kept = list(range(CLASSES))
        drop_information[key] = kept
        pieces.append(piece.numpy())
    return drop_information, pieces

def generate_subnet_SLT(global_model:AlexNet, rate, lf=0):
    """Carve a sub-model in which only the first ``lf`` layers are narrowed.

    Layers are numbered from the input: conv1..conv5 are 1..5, fc1 is 6 and
    fc2 is 7. A layer whose number is ``<= lf`` keeps its first
    ``max(1, int(out_dim * rate))`` output units. Deeper layers keep all of
    them, but their input dimension (dim 1) is still sliced to match the
    previous layer. Companion parameters reuse their layer's selection: the
    conv bias, the batch-norm weight/bias/running statistics, and
    ``fc2.bias``. The classifier ``fc`` keeps all ``CLASSES`` outputs.

    Args:
        global_model: model whose ``state_dict`` is sliced.
        rate: fraction of output units kept in narrowed layers;
            ``rate >= 0.99`` returns the unsliced model.
        lf: number of leading layers to narrow.

    Returns:
        ``(drop_information, subparams)``: kept output indices per parameter
        name (``num_batches*`` skipped; ``fc.weight`` records the fc2
        selection; empty for the unsliced model) and the sliced NumPy arrays
        in ``state_dict`` order.
    """
    drop_information = {}
    if rate >= 0.99:
        return drop_information, get_filters(global_model)

    # Layer number of each weight that chooses its own output units.
    depth = {
        'conv1.weight': 1,
        'conv2.weight': 2,
        'conv3.weight': 3,
        'conv4.weight': 4,
        'conv5.weight': 5,
        'fc1.weight': 6,
        'fc2.weight': 7,
    }
    # Weight whose output selection fixes this weight's input dimension (dim 1).
    fed_by = {
        'conv2.weight': 'conv1.weight',
        'conv3.weight': 'conv2.weight',
        'conv4.weight': 'conv3.weight',
        'conv5.weight': 'conv4.weight',
        'fc2.weight': 'fc1.weight',
    }
    # Parameters that reuse the output selection (dim 0) of their layer's weight.
    follows = {'fc2.bias': 'fc2.weight'}
    for owner, companions in (
        ('conv1.weight', ['conv1.bias', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var']),
        ('conv2.weight', ['conv2.bias', 'bn2.weight', 'bn2.bias', 'bn2.running_mean', 'bn2.running_var']),
        ('conv3.weight', ['conv3.bias', 'bn3.weight', 'bn3.bias', 'bn3.running_mean', 'bn3.running_var']),
        ('conv4.weight', ['conv4.bias', 'bn4.weight', 'bn4.bias', 'bn4.running_mean', 'bn4.running_var']),
        ('conv5.weight', ['conv5.bias', 'bn5.weight', 'bn5.bias', 'bn5.running_mean', 'bn5.running_var']),
    ):
        follows.update(dict.fromkeys(companions, owner))

    state = global_model.state_dict()
    pieces = []
    for key in state:
        if "num_batches" in key:
            continue
        tensor = state[key].cpu()
        if key in depth:
            keep_rate = rate if lf >= depth[key] else 1.0
            kept = list(range(max(1, int(tensor.shape[0] * keep_rate))))
            piece = torch.index_select(tensor, 0, torch.tensor(kept))
            if key in fed_by:
                piece = torch.index_select(piece, 1, torch.tensor(drop_information[fed_by[key]]))
            elif key == 'fc1.weight':
                # Each kept conv5 channel maps to 7*7 consecutive fc1 input columns.
                columns = [col for ch in drop_information['conv5.weight']
                           for col in range(ch * 7 * 7, (ch + 1) * 7 * 7)]
                piece = torch.index_select(piece, 1, torch.tensor(columns))
        elif key in follows:
            kept = drop_information[follows[key]]
            piece = torch.index_select(tensor, 0, torch.tensor(kept))
        elif key == 'fc.weight':
            kept = drop_information['fc2.weight']
            piece = torch.index_select(tensor, 1, torch.tensor(kept))
        else:  # fc.bias
            piece = tensor
            kept = list(range(CLASSES))
        drop_information[key] = kept
        pieces.append(piece.numpy())
    return drop_information, pieces

def dropout_aggregation(Fit_res:List[FitRes], global_param:List[np.ndarray]):
    """Average client updates of heterogeneous width into the global model.

    Each result carries ``metrics["drop_info"]``. An empty mapping means the
    client trained the full model. Otherwise it is the ``drop_information``
    produced when the client's sub-model was carved: parameter name -> kept
    output indices, in the same order as the returned parameter arrays.

    Every global entry becomes the ``num_examples``-weighted mean over all
    clients whose (sub-)model contained that entry, so full-width and
    sub-model clients are averaged together. Entries no client trained keep
    their value from ``global_param``.

    Args:
        Fit_res: client fit results.
        global_param: current global parameters, one array per layer in
            ``get_filters`` order.

    Returns:
        A new list of arrays with the shapes and dtypes of ``global_param``.
    """
    fc1_spatial = 7 * 7  # fc1 input columns per conv5 channel (flattened feature map)
    weighted_sums = [np.zeros(np.shape(g), dtype=np.float64) for g in global_param]
    weight_totals = [np.zeros(np.shape(g), dtype=np.float64) for g in global_param]

    for fit_res in Fit_res:
        params = parameters_to_ndarrays(fit_res.parameters)
        num = fit_res.num_examples
        merge_info = fit_res.metrics["drop_info"]

        if len(merge_info) == 0:
            # Full-width client: contributes to every entry of every layer.
            for layer_idx, layer in enumerate(params):
                weighted_sums[layer_idx] += num * np.asarray(layer, dtype=np.float64)
                weight_totals[layer_idx] += num
            continue

        # Inputs of the first conv layer are all image channels.
        last_layer_indices = np.arange(NUM_CHANNELS)
        for layer_idx, (name, selected) in enumerate(merge_info.items()):
            rows = np.asarray(selected, dtype=np.int64)
            prev = np.asarray(last_layer_indices, dtype=np.int64)
            if "weight" not in name or not ("conv" in name or "fc" in name):
                # 1-D parameters (biases, batch-norm params/stats) follow the output selection.
                region = (rows,)
            elif name == 'fc1.weight':
                # Each kept conv5 channel owns fc1_spatial consecutive input columns.
                cols = (prev[:, None] * fc1_spatial + np.arange(fc1_spatial)).ravel()
                region = np.ix_(rows, cols)
            elif name == 'fc.weight':
                # The classifier keeps every class; only its inputs were sliced.
                region = np.ix_(np.arange(CLASSES), prev)
            else:
                region = np.ix_(rows, prev)
            weighted_sums[layer_idx][region] += num * np.asarray(params[layer_idx], dtype=np.float64)
            weight_totals[layer_idx][region] += num
            last_layer_indices = selected

    aggregated = []
    for g, wsum, wtot in zip(global_param, weighted_sums, weight_totals):
        g = np.asarray(g)
        covered = wtot > 0
        mean = wsum / np.where(covered, wtot, 1.0)  # avoid 0/0 on untouched entries
        aggregated.append(np.where(covered, mean, g).astype(g.dtype, copy=False))
    return aggregated


def compute_sampling_prob(model:AlexNet, device="cpu"):
    """Sampling probabilities over the input channels of conv2-conv5.

    For each layer, input channel ``i`` gets a probability proportional to the
    Frobenius norm of ``weight[:, i, ...]``, i.e. of all kernel slices that
    read that channel. Probabilities are normalised to sum to 1. If a layer's
    norms sum to zero (or are not finite), a uniform distribution is returned
    for it instead of NaNs, so ``np.random.choice`` can still use it.

    Args:
        model: model exposing ``conv2`` .. ``conv5`` modules with ``weight``.
        device: kept for backward compatibility. Norms are computed on
            whatever device the weights live on, so it is not needed.

    Returns:
        Dict ``{'conv2.weight': p2, ..., 'conv5.weight': p5}`` of float64
        NumPy arrays, one probability per input channel.
    """
    probs = {}
    for key, module in (('conv2.weight', model.conv2), ('conv3.weight', model.conv3),
                        ('conv4.weight', model.conv4), ('conv5.weight', model.conv5)):
        weight = module.weight.detach()
        # Reduce over every axis except the input-channel axis (dim 1).
        other_dims = tuple(d for d in range(weight.dim()) if d != 1)
        norms = weight.double().pow(2).sum(dim=other_dims).sqrt().cpu().numpy()
        total = norms.sum()
        if total > 0:
            probs[key] = norms / total
        else:
            probs[key] = np.full(norms.shape, 1.0 / norms.size)
    return probs

def approximate_convolution(globalmodel:AlexNet, sampling_probs:Dict, lf:int, prob=1.0) -> List[np.ndarray]:
    """Return the global model's parameters with conv channels randomly sub-sampled.

    For each approximated layer conv{k} (k = 2..5, enabled when ``lf >= k``), a
    fraction ``prob`` of its input channels is sampled without replacement,
    weighted by ``sampling_probs['conv{k}.weight']``. The same channel ids are
    removed from conv{k-1}'s output side: its weight rows, its bias and its
    BatchNorm parameters/statistics. Thus the returned tensors stay
    shape-consistent with one another. conv5's output channels are never
    dropped.

    Args:
        globalmodel: model whose ``state_dict`` contains ``conv1..conv5`` and
            ``bn1..bn4`` entries with the names used below.
        sampling_probs: mapping with keys ``'conv2.weight'`` .. ``'conv5.weight'``.
            Each value is a probability vector over that layer's input channels
            (presumably produced by the norm-based sampler earlier in this
            file).
        lf: how many layers are approximated. 0 disables approximation, and
            1 approximates no conv layer. Values of 2..5 approximate
            conv2..conv{lf}.
        prob: fraction of input channels kept in each approximated layer.
            Values >= 0.95 disable approximation.

    Returns:
        One numpy array per ``state_dict`` entry, skipping ``num_batches*``
        entries, in ``state_dict`` order. Sampled channels appear in the order
        ``np.random.choice`` drew them (not sorted).

    Raises:
        ValueError: propagated from ``np.random.choice``, e.g. when a
            probability vector has fewer non-zero entries than the number of
            channels requested.
    """
    if prob >= 0.95 or lf == 0:
        return get_filters(globalmodel)
    frac = min(1.0, prob)

    # conv2..conv5 get their input channels sub-sampled once lf reaches 2..5.
    approximated_conv_layers = [f'conv{i}.weight' for i in range(2, 6) if lf >= i]

    # Input-channel ids kept by each approximated layer (weighted, no replacement).
    selected_filters = {}
    for layer in approximated_conv_layers:
        num_channels = len(sampling_probs[layer])
        # Keep at least one channel so a small `prob` cannot zero out a layer.
        num_selected = max(1, int(num_channels * frac))
        selected_filters[layer] = np.random.choice(
            list(range(num_channels)), size=num_selected, replace=False,
            p=sampling_probs[layer])

    param_dict = globalmodel.state_dict()

    # out_ids[i]: surviving output channels of conv{i}. conv{i} keeps exactly the
    # channels conv{i+1} still reads; conv5 (no conv6 sampled) keeps all of them.
    out_ids = {}
    for i in range(1, 6):
        next_layer = f'conv{i + 1}.weight'
        if next_layer in selected_filters:
            out_ids[i] = selected_filters[next_layer]
        else:
            out_ids[i] = list(range(param_dict[f'conv{i}.weight'].shape[0]))

    # in_ids[i]: surviving input channels of conv{i} (i = 2..5). A layer that is
    # not approximated reads everything conv{i-1} still outputs (this fallback
    # also covers conv2 when lf < 2, which previously raised KeyError).
    in_ids = {i: selected_filters.get(f'conv{i}.weight', out_ids[i - 1])
              for i in range(2, 6)}

    conv_weights = {f'conv{i}.weight': i for i in range(1, 6)}
    # Per-output-channel tensors of conv1..conv4 (conv5's are left untouched).
    channel_params = {
        f'{prefix}{i}.{attr}': i
        for i in range(1, 5)
        for prefix, attr in (('conv', 'bias'), ('bn', 'weight'), ('bn', 'bias'),
                             ('bn', 'running_mean'), ('bn', 'running_var'))
    }

    subparams = []
    for name, tensor in param_dict.items():
        if "num_batches" in name:
            continue
        w = tensor.cpu()

        if name in conv_weights:
            i = conv_weights[name]
            sub_param = torch.index_select(w, 0, torch.tensor(out_ids[i]))
            if i >= 2:
                # conv1's input is the image, so only conv2..conv5 lose inputs.
                sub_param = torch.index_select(sub_param, 1, torch.tensor(in_ids[i]))
        elif name in channel_params:
            sub_param = torch.index_select(w, 0, torch.tensor(out_ids[channel_params[name]]))
        else:
            sub_param = w

        subparams.append(sub_param.numpy())
    return subparams