import matplotlib.pyplot as plt

def creat_1D_CNN_structure(base_structure,extend_term):
    """Build a scaled 1-D CNN structure description from *base_structure*.

    The first tuple of the first layer of *base_structure* is read: its
    element ``[-1]`` is treated as the kernel size and element ``[1]`` as the
    channel count. Both are scaled, as is the number of layers, and then
    truncated with ``int``.

    Args:
        base_structure: list of layers, each a list of tuples. Presumably
            ``(in_channels, out_channels, kernel_size)``; only the first
            tuple's ``[1]``/``[-1]`` entries and ``len(base_structure)`` are used.
        extend_term: indexable multipliers ``[kernel_scale, channel_scale,
            depth_scale]``.

    Returns:
        A new list of layers. The first layer is ``[(1, C, K)]`` and every
        following layer is ``[(C, C, K)]``. There is always at least one
        layer, even when the scaled depth is below 1.
    """
    first_kernel = base_structure[0][0]
    scaled_kernel = int(first_kernel[-1] * extend_term[0])
    scaled_channels = int(first_kernel[1] * extend_term[1])
    scaled_depth = int(len(base_structure) * extend_term[2])

    # The input layer takes a single channel; the remaining layers keep the
    # width constant.
    input_layer = [(1, scaled_channels, scaled_kernel)]
    hidden_layers = [
        [(scaled_channels, scaled_channels, scaled_kernel)]
        for _ in range(scaled_depth - 1)
    ]
    return [input_layer] + hidden_layers


def trim_merged_dict(merged_dict,adjustment_rate):
    """Divide every ``bn_happy.running_var`` entry by *adjustment_rate*.

    Matching is by substring on the key. The dict is modified in place and
    the same object is returned; all other entries are left untouched.
    """
    running_var_keys = [name for name in merged_dict if 'bn_happy.running_var' in name]
    for name in running_var_keys:
        merged_dict[name] = merged_dict[name] / adjustment_rate
    return merged_dict


def zero_weight(state_dict):
    """Multiply every value of *state_dict* by zero, in place.

    Each value is replaced by ``value * 0``, which keeps the value's own type
    (e.g. a tensor stays a tensor). Returns the same mapping object.
    """
    for name, value in list(state_dict.items()):
        state_dict[name] = value * 0
    return state_dict

def samller_weight(state_dict, scale=0.1):
    """Scale every value of *state_dict* in place and return the same dict.

    Args:
        state_dict: mapping whose values support ``value * scale`` (numbers,
            tensors, arrays, ...).
        scale: multiplier applied to every value. Defaults to 0.1, the factor
            this helper has always used.

    Returns:
        The same mapping object, mutated in place.
    """
    for key in state_dict:
        # Only existing keys are reassigned, so the dict's size does not
        # change while iterating.
        state_dict[key] = state_dict[key] * scale
    return state_dict


def model_distance(model_dict1,model_dict2):
    """Pair up the ``conv1d.weight`` entries of two state dicts.

    For every key of *model_dict1* containing ``'conv1d.weight'`` (iterated
    in *model_dict1*'s order), the key is printed and
    ``[model_dict1[key], model_dict2[key]]`` is collected. No distance is
    computed here; that is presumably left to the caller.

    Raises:
        KeyError: if a matching key is missing from *model_dict2*.
    """
    pairs = []
    for name in model_dict1:
        if 'conv1d.weight' not in name:
            continue
        print(name)
        pairs.append([model_dict1[name], model_dict2[name]])
    return pairs


def calculate_calculation(structure, n_class=None):
    """Return a size/workload estimate for a CNN *structure*.

    For every tuple ``k`` in every layer, ``k[0] * k[1] * k[2] + 2 * k[1]`` is
    added. This looks like conv weights plus two per-output-channel terms
    (presumably batch-norm affine parameters); confirm. Then
    ``structure[-1][-1][1] * n_class`` is added for the final classifier.

    Args:
        structure: list of layers, each a list of tuples indexed as
            ``[0]``, ``[1]``, ``[2]`` (presumably in_channels, out_channels,
            kernel_size).
        n_class: number of output classes. When omitted, a module-level
            ``n_class`` global is used, which was the original behaviour.

    Raises:
        TypeError: if *n_class* is not given and no module-level ``n_class``
            exists. Previously this surfaced as an unexplained NameError.
        IndexError: if *structure* or its last layer is empty.
    """
    if n_class is None:
        # The original code read a global ``n_class`` that this module never
        # defines. Honour one if a caller has injected it.
        n_class = globals().get('n_class')
        if n_class is None:
            raise TypeError(
                "calculate_calculation() needs n_class: pass it explicitly "
                "or define a module-level n_class"
            )

    cal = 0
    for layer in structure:
        for kernel in layer:
            cal = cal + kernel[0] * kernel[1] * kernel[2] + 2 * kernel[1]
    # Classifier term: output channels of the last kernel times class count.
    cal = cal + structure[-1][-1][1] * n_class
    return cal

def plot_result(dataset_name, test_log_list, structure_list,log_path= 'log'):
    """Plot test curves per epoch and against accumulated workload.

    Left panel: one curve per entry of *test_log_list*, labelled "1", "2", ...

    Right panel: for each stage ``i``, ``test_log_list[i]`` is plotted against
    a cumulative workload axis. Each epoch of stage ``i`` costs
    ``calculate_calculation(structure_list[i])``, and a stage starts where the
    previous stage's last epoch ended. One extra curve (label
    ``len(structure_list) + 1``) shows the last structure's log plotted from
    zero workload, as a "trained from scratch" reference.

    Args:
        dataset_name: currently unused.
        test_log_list: per-stage sequences of test metric values; must have
            at least ``len(structure_list)`` entries.
        structure_list: per-stage structures accepted by
            ``calculate_calculation``; must be non-empty.
        log_path: currently unused.

    NOTE(review): ``calculate_calculation(structure)`` as called here relies
    on an ``n_class`` value being resolvable by that function. Confirm this
    before calling.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 4))
    fig.suptitle('Horizontally stacked subplots')

    for label, test_log in enumerate(test_log_list, start=1):
        ax1.plot(test_log, label=str(label))
    ax1.set_title('small large')
    ax1.legend()

    stage_start = 0
    for i, structure in enumerate(structure_list):
        per_epoch = calculate_calculation(structure)
        test_log = test_log_list[i]
        workload = [per_epoch * epoch + stage_start for epoch in range(len(test_log))]
        if workload:
            # The next stage continues from this stage's last epoch. An empty
            # log leaves the start unchanged instead of crashing on [-1].
            stage_start = workload[-1]
        ax2.plot(workload, test_log, label=str(i + 1))

    # Reference curve: the final structure's log as if trained from zero.
    per_epoch = calculate_calculation(structure_list[-1])
    baseline = [per_epoch * epoch for epoch in range(len(test_log_list[-1]))]
    ax2.plot(baseline, test_log_list[-1], label=str(len(structure_list) + 1))
    ax2.set_title('small large')
    ax2.legend()
    fig.show()

    
def smooth(scalars, weight=0.9):  # Weight between 0 and 1
    """Return an exponentially smoothed copy of *scalars*.

    Each output value is ``previous * weight + (1 - weight) * point``. The
    running value starts at the first input, so the first output equals the
    first input up to float rounding. A larger *weight* gives a smoother curve.

    Args:
        scalars: iterable of numbers (list, tuple, array, generator, ...),
            consumed in iteration order.
        weight: smoothing factor, expected to be between 0 and 1.

    Returns:
        A list with one smoothed value per input; ``[]`` for empty input.
    """
    # Materialise once so generators work, and so a pandas Series is read
    # positionally (``scalars[0]`` on a Series is a label lookup).
    values = list(scalars)
    if not values:
        return []

    last = values[0]
    smoothed = []
    for point in values:
        last = last * weight + (1 - weight) * point
        smoothed.append(last)
    return smoothed