import torch
import copy
import numpy as np

def merge_weight(client_list,selected_index_list = None):
    """Average the weights of the selected clients and load the result into every client.

    Args:
        client_list: Sequence of client objects. Each must provide
            ``get_model_weight()`` (returns a mapping of entry name to tensor)
            and ``load_model_weight(mapping)``.
        selected_index_list: Positions in ``client_list`` whose weights enter
            the average. ``None`` (the default) selects every client. Any
            iterable of indices is accepted, including a NumPy array.

    Returns:
        The averaged weight mapping. It is a deep copy of the first selected
        client's mapping whose entries are replaced by the element-wise mean
        over the selected clients, cast back to each entry's original tensor
        type. The same object is passed to ``load_model_weight`` of *every*
        client in ``client_list``, not only the selected ones.

    Raises:
        IndexError: If the selection is empty.
    """
    # Identity test on purpose: ``== None`` is evaluated element-wise when the
    # selection is a NumPy array, which makes the ``if`` itself raise.
    if selected_index_list is None:
        selected_index_list = range(len(client_list))

    client_state_dict_list = [
        client_list[selected_index].get_model_weight()
        for selected_index in selected_index_list
    ]

    # The first selected mapping supplies the key set and the tensor type that
    # each averaged entry is cast back to.
    merged_dict = copy.deepcopy(client_state_dict_list[0])

    for key in merged_dict:
        # ``* 1.0`` promotes integer entries to floating point so that
        # ``torch.mean`` accepts them.
        stacked = torch.stack(
            [client_state_dict[key] * 1.0 for client_state_dict in client_state_dict_list]
        )
        merged_dict[key] = torch.mean(stacked, 0).type(merged_dict[key].type())

    for client in client_list:
        client.load_model_weight(merged_dict)
    return merged_dict


def list_unique(example_list):
    """Group the positions of ``example_list`` by equal value.

    Items are compared with ``==`` only, so unhashable items such as nested
    lists are supported.

    Args:
        example_list: Sequence of items to group.

    Returns:
        A pair ``(unique_list, unique_index_list_list)``: the distinct items in
        order of first appearance, and, parallel to it, the list of positions
        at which each distinct item occurs.
    """
    distinct_items = []
    positions_by_item = []
    for position, candidate in enumerate(example_list):
        # Slots of already-seen items that compare equal to the candidate.
        matching_slots = [
            slot for slot, known in enumerate(distinct_items) if candidate == known
        ]
        for slot in matching_slots:
            positions_by_item[slot].append(position)
        if not matching_slots:
            # First occurrence: open a new group for it.
            distinct_items.append(candidate)
            positions_by_item.append([position])
    return distinct_items, positions_by_item

def merge_heterogeneous_weight(client_list,skip_layer_list = []):
    """Merge the weights of clients whose models may differ in structure.

    Clients are grouped by equality of ``client.model.layer_parameter_list``.
    Each group is first averaged with ``merge_weight``; the per-group results
    are then merged across groups with ``merge_heterogeneous_weight_dict`` and
    loaded back into the clients of the corresponding group.

    Args:
        client_list: Sequence of client objects exposing
            ``model.layer_parameter_list``, ``get_model_weight()`` and
            ``load_model_weight(mapping)``.
        skip_layer_list: Extra substrings; any weight entry whose name contains
            one of them is left out of the cross-group merge. ``'weight_mask'``
            and ``'bn.num_batches_tracked'`` are always added.

    Returns:
        A pair ``(merged_dict_of_each_group, client_index_of_each_group)``: one
        merged weight mapping per group, and the positions in ``client_list``
        of the clients belonging to each group.
    """
    layers_to_skip = skip_layer_list + ['weight_mask', 'bn.num_batches_tracked']

    # Find the indices in client_list of the clients whose
    # layer_parameter_list values are the same.
    structure_per_client = [client.model.layer_parameter_list for client in client_list]
    _, client_index_of_each_group = list_unique(structure_per_client)

    # Plain averaging inside each structurally identical group.
    group_dicts = [
        merge_weight([client_list[position] for position in member_positions])
        for member_positions in client_index_of_each_group
    ]

    merged_dict_of_each_group = merge_heterogeneous_weight_dict(
        group_dicts, client_index_of_each_group, layers_to_skip
    )

    for group_position, group_dict in enumerate(merged_dict_of_each_group):
        for position in client_index_of_each_group[group_position]:
            client_list[position].load_model_weight(group_dict)

    return merged_dict_of_each_group, client_index_of_each_group


def merge_a_layer(To_be_merged_list):
    """Merge differently sized versions of one layer by a weighted average.

    Every tensor is embedded into a NaN-filled tensor whose shape is the
    per-axis maximum over all inputs. Positions are then averaged using each
    entry's weight, with NaN (uncovered) positions ignored, and the region
    belonging to each input is cut back out of the averaged result.

    Args:
        To_be_merged_list: Sequence of entries where ``entry[0]`` is the layer
            tensor and ``entry[1]`` is its scalar weight in the average. All
            tensors must have the same number of dimensions.

    Returns:
        A list with one new tensor per entry, in input order, each with the
        shape of the corresponding input tensor.
    """
    # Per-axis maximum extent over all inputs.
    shape_rows = [torch.tensor(entry[0].shape) for entry in To_be_merged_list]
    envelope_shape = torch.max(torch.stack(shape_rows), 0).values.numpy().tolist()

    padded_values = [
        nan_padd_to_a_shape_with_phase_alignment(entry[0], envelope_shape)
        for entry in To_be_merged_list
    ]
    # Same layout as the values, but holding each entry's weight, so that
    # uncovered positions contribute nothing to the denominator either.
    padded_strengths = [
        nan_padd_to_a_shape_with_phase_alignment(
            torch.ones(entry[0].shape) * entry[1], envelope_shape
        )
        for entry in To_be_merged_list
    ]

    value_stack = torch.stack(padded_values)
    strength_stack = torch.stack(padded_strengths)
    weighted_total = np.nansum(value_stack * strength_stack, 0)
    strength_total = np.nansum(strength_stack, 0)
    averaged = weighted_total / strength_total

    return [
        torch.tensor(get_sub_weight_from_big_weight(list(entry[0].shape), averaged))
        for entry in To_be_merged_list
    ]
          


def merge_heterogeneous_weight_dict(merged_dict_list,client_index_of_each_group,skip_layer_list = []):
    """Merge per-group weight mappings across groups, entry by entry.

    Entry names are taken from the mapping with the most entries (see
    ``get_deepest_dict``). For each name, every mapping that contains it
    contributes its tensor, weighted by the number of clients in its group;
    the merged tensors are written back into those mappings in place.

    Args:
        merged_dict_list: One weight mapping per group. Modified in place.
        client_index_of_each_group: Parallel to ``merged_dict_list``; the
            length of each item is used as that group's weight.
        skip_layer_list: Substrings; entries whose name contains one of them
            are skipped. ``'num_batches_tracked'`` is always skipped.

    Returns:
        ``merged_dict_list`` itself, after the in-place update.
    """
    excluded_fragments = ['num_batches_tracked'] + skip_layer_list
    reference_dict = get_deepest_dict(merged_dict_list)

    for layer_name in reference_dict:
        if any(fragment in layer_name for fragment in excluded_fragments):
            continue

        # Mappings that own this entry, paired with their group size.
        holders = [
            (group_dict, len(member_positions))
            for group_dict, member_positions in zip(merged_dict_list, client_index_of_each_group)
            if layer_name in group_dict
        ]
        # A single owner has nothing to be merged with.
        if len(holders) == 1:
            continue

        merged_tensors = merge_a_layer(
            [(group_dict[layer_name], group_size) for group_dict, group_size in holders]
        )
        for (group_dict, _), merged_tensor in zip(holders, merged_tensors):
            group_dict[layer_name] = merged_tensor

    return merged_dict_list

############################################

def put_weight_A_to_B(merged_dict_list,dict_strength_list = [100,0.01],skip_layer_list = []):
    """Blend weight mappings entry by entry using a fixed strength per mapping.

    Works on a deep copy, so the input mappings are left untouched. Entry
    names are taken from the mapping with the most entries (see
    ``get_deepest_dict``). For each name, every mapping that contains it
    contributes its tensor with the strength at the same position in
    ``dict_strength_list``; ``merge_a_layer`` computes the weighted average
    and the results are stored in the copied mappings.

    Args:
        merged_dict_list: Weight mappings to blend.
        dict_strength_list: Averaging weight for each mapping, paired by
            position. With the defaults the first mapping has weight 100 and
            the second 0.01.
        skip_layer_list: Substrings; entries whose name contains one of them
            are skipped. ``'num_batches_tracked'`` is always skipped.

    Returns:
        The blended deep copy of ``merged_dict_list``.
    """
    working_dicts = copy.deepcopy(merged_dict_list)
    ignored_fragments = ['num_batches_tracked'] + skip_layer_list

    for layer_name in get_deepest_dict(working_dicts):
        if any(fragment in layer_name for fragment in ignored_fragments):
            continue

        owners = []
        contributions = []
        for candidate, strength in zip(working_dicts, dict_strength_list):
            if layer_name in candidate:
                owners.append(candidate)
                contributions.append((candidate[layer_name], strength))

        # A single owner has nothing to be blended with.
        if len(contributions) == 1:
            continue

        blended = merge_a_layer(contributions)
        for owner, blended_tensor in zip(owners, blended):
            owner[layer_name] = blended_tensor

    return working_dicts


    ############################################
def phase_alignment_index(large_kernel_size, small_kernel_size):
    """Return where a small kernel sits inside a large one along one axis.

    Each kernel's left padding is ``int((size - 1) / 2)``; the small kernel
    starts at the difference of the two paddings.

    Args:
        large_kernel_size: Extent of the enclosing kernel along the axis.
        small_kernel_size: Extent of the kernel to be placed along the axis.

    Returns:
        ``(start_index, end_index)``, a half-open range of length
        ``small_kernel_size`` inside the large kernel.
    """
    def left_padding(kernel_size):
        # int() truncates, so even sizes get the smaller half on the left.
        return int((kernel_size - 1) / 2)

    start_index = left_padding(large_kernel_size) - left_padding(small_kernel_size)
    return start_index, start_index + small_kernel_size

def get_deepest_dict(merged_dict_list):
    """Return the mapping with the most entries.

    Args:
        merged_dict_list: Sequence of mappings.

    Returns:
        The first mapping whose length is strictly greater than that of all
        mappings before it and not exceeded later; a new empty dict when the
        sequence is empty or contains only empty mappings.
    """
    deepest = {}
    for candidate in merged_dict_list:
        # Strict comparison: on ties the earliest mapping wins.
        if len(candidate) > len(deepest):
            deepest = candidate
    return deepest

def _phase_aligned_region(outer_shape, inner_shape):
    """Return the tuple of slices locating an ``inner_shape`` block in ``outer_shape``.

    The first two axes start at index 0. Every later axis is positioned with
    ``phase_alignment_index``. One slice is produced per axis of
    ``inner_shape``, so any number of dimensions (including zero) is handled.
    """
    region = []
    for axis, extent in enumerate(inner_shape):
        if axis < 2:
            region.append(slice(0, extent))
        else:
            start_index, end_index = phase_alignment_index(outer_shape[axis], extent)
            region.append(slice(start_index, end_index))
    return tuple(region)


def nan_padd_to_a_shape_with_phase_alignment(weight, shape):
    """Embed ``weight`` into a NaN-filled tensor of the given ``shape``.

    Placement follows ``_phase_aligned_region``: the first two axes are
    anchored at index 0 and all further axes are phase-aligned. Every position
    not covered by ``weight`` stays NaN.

    Tensors with one to four dimensions are placed exactly as before; other
    ranks, which previously came back entirely NaN, now follow the same rule.

    Args:
        weight: Tensor to embed. Each of its extents must not exceed the
            matching extent in ``shape``.
        shape: Target shape, as a sequence of ints.

    Returns:
        A new tensor of shape ``shape`` in the default floating-point dtype.
    """
    nan_padd = torch.full(shape, float('nan'))
    nan_padd[_phase_aligned_region(shape, weight.shape)] = weight
    return nan_padd


def get_sub_weight_from_big_weight(sub_weight_size,big_weight):
    """Cut the region of size ``sub_weight_size`` out of ``big_weight``.

    This is the inverse placement of
    ``nan_padd_to_a_shape_with_phase_alignment``: the first two axes are taken
    from index 0 and all further axes are phase-aligned.

    Shapes with one to four dimensions are extracted exactly as before; other
    ranks, which previously raised ``UnboundLocalError``, now follow the same
    rule.

    Args:
        sub_weight_size: Shape of the region to extract, as a sequence of ints.
        big_weight: Array or tensor to slice; it must expose ``.shape`` and
            accept a tuple of slices as an index.

    Returns:
        ``big_weight`` indexed by the computed region.
    """
    return big_weight[_phase_aligned_region(big_weight.shape, sub_weight_size)]