import torch
import numpy as np
import standardize
import order
import bisect
import copy
import max_ae 
import math
import meanse_meanae
import torch.nn as nn


def get_all_permutations_for_kernel_indices():
    """Return every ordering of the three kernel indices 0, 1, 2.

    The list is written out by hand and therefore only covers exactly three
    kernels. Its order is fixed; a caller that keeps the first strictly-best
    entry (as the brute-force search does) depends on it, so do not replace
    it with a generator that yields a different order.

    Returns:
        A new list of six ``[a, b, c]`` index lists on every call.
    """
    return [
        [0, 1, 2],
        [1, 0, 2],
        [2, 0, 1],
        [1, 2, 0],
        [0, 2, 1],
        [2, 1, 0],
    ]

def get_mean_abs_between_kernel(kernel_O, kernel_R):
    """Return the mean absolute error between two kernels.

    Both kernels are flattened before comparison, so only their element
    counts must agree (their shapes may differ).

    Args:
        kernel_O: kernel tensor from the original model.
        kernel_R: kernel tensor from the model being compared.

    Returns:
        0-d tensor equal to ``sum(|kernel_O - kernel_R|) / kernel_O.numel()``.
    """
    # No autograd graph is needed for a pure distance measurement.
    with torch.no_grad():
        return torch.nn.functional.l1_loss(
            kernel_R.flatten(), kernel_O.flatten(), reduction="mean"
        )

def heuristic_ordering_kernels_cnn(original_cnn_layer, model_to_align_cnn_layer):
    """Greedily match each original kernel to its closest kernel in the other layer.

    For every original kernel ``j`` this picks the kernel ``i`` of
    ``model_to_align_cnn_layer`` with the smallest mean absolute error and
    stores ``ordering[j] = i``. That is the convention used by
    ``order_kernels_cnn`` / ``order_fnn_weights`` (``new[j] = old[perm[j]]``),
    so applying the result moves the best match for original kernel ``j``
    into position ``j``.

    The match is greedy per kernel, so the result is not guaranteed to be a
    permutation: two original kernels may select the same kernel.

    Args:
        original_cnn_layer: layer with a ``weight`` tensor, one kernel per
            leading index.
        model_to_align_cnn_layer: layer whose kernels are being matched.

    Returns:
        list[int] with one entry per original kernel.
    """
    with torch.no_grad():
        # One flattened kernel per row.
        weights_original = original_cnn_layer.weight.flatten(start_dim=1)
        weights_align = model_to_align_cnn_layer.weight.flatten(start_dim=1)
        # errors[j, i] = mean |original_j - align_i|
        errors = (
            (weights_original[:, None, :] - weights_align[None, :, :])
            .abs()
            .mean(dim=2)
        )
        # argmin returns the first index on ties, matching the strict "<"
        # comparison a manual minimum search would use.
        best_align_index = torch.argmin(errors, dim=1)
    return [int(i) for i in best_align_index]
# Reorder conv kernels according to a permutation.
def order_kernels_cnn(permutation, network2_layer):
    """Reorder the output kernels (and biases) of a Conv2d layer in place.

    After the call, kernel ``i`` holds what kernel ``permutation[i]`` held
    before, for every ``i < len(permutation)``. Kernels at or beyond
    ``len(permutation)`` are left untouched.

    Args:
        permutation: sequence of kernel indices, or None for a no-op.
        network2_layer: layer to reorder. If its class is not named
            ``Conv2d`` it is returned unchanged.

    Returns:
        ``network2_layer`` itself, modified in place.
    """
    if type(network2_layer).__name__ != 'Conv2d' or permutation is None:
        return network2_layer

    num_reordered = len(permutation)
    with torch.no_grad():
        weight = network2_layer.weight
        index = torch.tensor(
            [int(p) for p in permutation], dtype=torch.long, device=weight.device
        )
        # Advanced indexing returns a copy, so every source kernel is read
        # before any destination is overwritten.
        weight[:num_reordered] = weight[index]

        bias = network2_layer.bias
        if bias is not None:
            bias[:num_reordered] = bias[index]
    return network2_layer

def order_fnn_weights(permutation, network2_layer):
    """Reorder contiguous input-column blocks of a Linear layer in place.

    The input columns are split into ``len(permutation)`` contiguous blocks
    (one per conv kernel), with boundaries
    ``int(i * in_features / len(permutation))``. After the call, block ``i``
    holds what block ``permutation[i]`` held before, matching the kernel
    reordering done by ``order_kernels_cnn``.

    Args:
        permutation: sequence of block indices, or None to leave the weights
            unchanged.
        network2_layer: layer to reorder. If its class is not named
            ``Linear`` it is returned unchanged.

    Returns:
        ``network2_layer`` itself, modified in place.
    """
    if type(network2_layer).__name__ != 'Linear':
        return network2_layer

    # match with inter-kernel alignment
    if permutation is not None:
        print("inter yes")
        number_input_neurons = int(network2_layer.weight.shape[1])
        num_blocks = len(permutation)
        bounds = [int(i * number_input_neurons / num_blocks) for i in range(num_blocks + 1)]
        with torch.no_grad():
            # Clone every block first. Plain slices are views into the weight,
            # so without cloning, a block overwritten earlier in the loop would
            # be read back as the (already modified) source of a later block.
            blocks = [
                network2_layer.weight[:, lo:hi].clone()
                for lo, hi in zip(bounds, bounds[1:])
            ]
            for dst, src in enumerate(permutation):
                network2_layer.weight[:, bounds[dst]:bounds[dst + 1]] = blocks[src]
    return network2_layer

# NOTE: intra-kernel alignment is not needed by cnn_align, which only permutes whole kernels.
def intra_kernel_alignment(cnn_layer, cnn_layer_model_to_align):
    """Rearrange the values inside each kernel to follow a reference kernel's ranking.

    For each kernel pair ``(k1, k2)``, the n-th smallest value of ``k2`` is
    moved to the position where ``k1`` holds its n-th smallest value.
    ``cnn_layer_model_to_align`` is modified in place. The set of values in
    each of its kernels is unchanged; only their positions move.

    Args:
        cnn_layer: reference layer. If its class is not named ``Conv2d``,
            nothing is done.
        cnn_layer_model_to_align: layer whose kernels are rearranged.

    Returns:
        If ``cnn_layer`` is not a ``Conv2d``: ``cnn_layer_model_to_align`` alone.
        Otherwise a tuple ``(cnn_layer_model_to_align, sorting_indices,
        reorder_indices)``. ``sorting_indices[k]`` is the argsort of the
        flattened k-th kernel being aligned (taken before modification).
        ``reorder_indices[k]`` is the rank of each flattened element of the
        k-th reference kernel.
    """
    if type(cnn_layer).__name__ != 'Conv2d':
        return cnn_layer_model_to_align

    target_weight = cnn_layer_model_to_align.weight
    # Snapshot both weights; .numpy() on a CPU tensor would otherwise alias
    # the parameter storage that is written below.
    weights_network1 = cnn_layer.weight.detach().cpu().numpy().copy()
    weights_network2 = target_weight.detach().cpu().numpy().copy()

    reorder_indices_for_each_kernel = []
    sorting_indices_for_each_kernel = []
    for weight_index, (weight_1, weight_2) in enumerate(zip(weights_network1, weights_network2)):
        weight_network1_flat = weight_1.flatten()
        weight_network2_flat = weight_2.flatten()

        sorting_indices_for_each_kernel.append(np.argsort(weight_network2_flat))
        sorted_network2_weights = np.sort(weight_network2_flat)

        # mapping_indices[p] = rank of weight_network1_flat[p] within its kernel.
        mapping_indices = np.argsort(np.argsort(weight_network1_flat))
        reorder_indices_for_each_kernel.append(mapping_indices)

        matched_weight_network2 = sorted_network2_weights[mapping_indices].reshape(weight_2.shape)

        # Write through the parameter's storage. Assigning `.data` on the
        # temporary view `weight[i]` only rebinds that view and leaves the
        # parameter unchanged.
        with torch.no_grad():
            target_weight[weight_index].copy_(
                torch.from_numpy(matched_weight_network2).to(
                    device=target_weight.device, dtype=target_weight.dtype
                )
            )

    return cnn_layer_model_to_align, sorting_indices_for_each_kernel, reorder_indices_for_each_kernel


def cnn_align(model: torch.nn.Module, model_to_align: torch.nn.Module, perm):
    """Apply the kernel permutation ``perm`` to ``model_to_align`` in place.

    The first layer's kernels are reordered with ``order_kernels_cnn``. The
    second layer's input-column blocks are reordered with
    ``order_fnn_weights``, and the returned layer is stored back into the
    list obtained from ``standardize.get_layers``. ``model`` is only looked
    up; it is not modified.

    Returns:
        ``model_to_align`` (the same object).
    """
    # The reference model's first layer is looked up but not otherwise used.
    _reference_cnn = standardize.get_layers(model)[0]
    target_layers = standardize.get_layers(model_to_align)
    target_cnn = target_layers[0]

    # inter kernel alignment
    print('inter-kernel alignment')
    order_kernels_cnn(perm, target_cnn)

    target_layers[1] = order_fnn_weights(perm, target_layers[1])

    return model_to_align


def standardize_scale_cnn(model: torch.nn.Module, tanh: bool = None):
    """Normalize each conv kernel's scale in place, compensating in the FNN layer.

    Uses the first two entries of ``standardize.get_layers(model)``: a conv
    layer and a fully connected (FNN) layer. For every conv output kernel
    ``k``, let ``F_k`` be the FNN input-column block fed by that kernel. The
    FNN inputs are split into equal contiguous blocks, one per kernel.
    NOTE(review): this assumes a channel-major flatten of the conv output;
    confirm against the model definition.

    1. ``s_k = ||[w_k, b_k]||_2``. Divide ``w_k`` and ``b_k`` by ``s_k``, and
       multiply ``F_k`` by ``s_k``.
    2. ``a_k = sqrt(mean_j ||[F_k[j, :], fnn_bias_j]||_2)`` over FNN rows
       ``j``. Multiply ``w_k`` and ``b_k`` by ``a_k``, and divide ``F_k`` by
       ``a_k``.

    The FNN bias is read but never modified. Rescaling by positive constants
    on both sides of the conv layer is presumably function-preserving only
    for positively homogeneous activations such as ReLU. ``tanh`` is accepted
    but currently ignored.

    Args:
        model: model whose first two layers are modified in place.
        tanh: unused.

    Raises:
        ValueError: if the FNN input width is not a positive multiple of the
            number of conv kernels.
    """
    layers = standardize.get_layers(model)
    cnn_layer = layers[0]
    fnn_layer = layers[1]

    num_kernels = cnn_layer.weight.shape[0]
    num_fnn_inputs = fnn_layer.weight.shape[1]
    if num_kernels == 0 or num_fnn_inputs % num_kernels != 0:
        raise ValueError(
            f"FNN input width {num_fnn_inputs} is not a multiple of the "
            f"number of conv kernels {num_kernels}"
        )
    block_width = num_fnn_inputs // num_kernels

    with torch.no_grad():
        # Shared by every kernel's block; included in the row norms only.
        fnn_bias_column = fnn_layer.bias.reshape(-1, 1)

        for k in range(num_kernels):
            cols = slice(k * block_width, (k + 1) * block_width)

            # Step 1: unit-normalize kernel k (weights and bias together) and
            # push the removed scale into the FNN columns it feeds.
            kernel_scale = torch.cat(
                (cnn_layer.weight[k].flatten(), cnn_layer.bias[k].view(1))
            ).norm(p=2)
            if kernel_scale == 0:
                # All-zero kernel: dividing by 0 would fill the model with NaNs.
                kernel_scale = torch.ones_like(kernel_scale)
            cnn_layer.weight[k] = cnn_layer.weight[k] / kernel_scale
            cnn_layer.bias[k] = cnn_layer.bias[k] / kernel_scale
            fnn_block = fnn_layer.weight[:, cols] * kernel_scale

            # Step 2: spread the average outgoing scale back onto the kernel.
            row_norms = torch.cat((fnn_block, fnn_bias_column), dim=1).norm(p=2, dim=1)
            out_scale = row_norms.mean() ** 0.5
            if out_scale == 0:
                out_scale = torch.ones_like(out_scale)
            cnn_layer.weight[k] = cnn_layer.weight[k] * out_scale
            cnn_layer.bias[k] = cnn_layer.bias[k] * out_scale
            fnn_layer.weight[:, cols] = fnn_block / out_scale


def get_mae(original, reconstructed):
    """Mean absolute error between the first two layers of two models.

    Compares the weights and biases of layers 0 and 1 as returned by
    ``standardize.get_layers``.

    Returns:
        ``([cnn_weight_mae, cnn_bias_mae, fnn_weight_mae, fnn_bias_mae],
        overall_mae)``. ``overall_mae`` is the total absolute error divided
        by the total number of compared elements.
    """
    original_layers = standardize.get_layers(original)
    reconstructed_layers = standardize.get_layers(reconstructed)

    tensor_pairs = [
        (original_layers[0].weight, reconstructed_layers[0].weight),
        (original_layers[0].bias, reconstructed_layers[0].bias),
        (original_layers[1].weight, reconstructed_layers[1].weight),
        (original_layers[1].bias, reconstructed_layers[1].bias),
    ]

    with torch.no_grad():
        abs_error_sums = [
            torch.nn.functional.l1_loss(orig.flatten(), recon.flatten(), reduction='sum')
            for orig, recon in tensor_pairs
        ]
        element_counts = [orig.numel() for orig, _ in tensor_pairs]
        # Normalize by the number of elements actually compared. Counting every
        # state_dict entry would dilute the mean whenever the model holds
        # anything beyond these four tensors.
        overall_error = sum(abs_error_sums) / sum(element_counts)

    per_tensor = [total / count for total, count in zip(abs_error_sums, element_counts)]
    return (per_tensor, overall_error)

def bruteforce_cnn_evaluate(model: torch.nn.Module, model_to_evaluate: torch.nn.Module, tanh: bool = None):
    """Try every kernel permutation and evaluate the best-aligned copy.

    Both models are first rescaled in place by ``standardize_scale_cnn``
    (always called with ``tanh=None``; the ``tanh`` argument is accepted but
    unused). Each permutation from ``get_all_permutations_for_kernel_indices``
    is applied to a fresh deep copy of ``model_to_evaluate``. The copy with
    the lowest ``max_ae.calculate_distance_mae`` wins; the first one wins on
    ties.

    Returns:
        ``(mean_ae, layers_mean_ae, max_overall_error)`` for the selected copy.

    Raises:
        ValueError: if no permutation yields an error below infinity
            (e.g. NaN weights).
    """
    print("bruteforce cnn eval")
    standardize_scale_cnn(model, tanh=None)
    standardize_scale_cnn(model_to_evaluate, tanh=None)

    min_max_abs_error = math.inf
    perm_model_w_lowest_max_error = None
    for perm in get_all_permutations_for_kernel_indices():
        # Permute a fresh copy each time so candidates do not compound.
        aligned_model_copy = cnn_align(model, copy.deepcopy(model_to_evaluate), perm)
        # get max error of all the permuted models; keep the lowest.
        mae = max_ae.calculate_distance_mae(model, aligned_model_copy)
        if mae < min_max_abs_error:
            min_max_abs_error = mae
            # The copy is never touched again, so no second deepcopy is needed.
            perm_model_w_lowest_max_error = aligned_model_copy

    if perm_model_w_lowest_max_error is None:
        raise ValueError("no kernel permutation produced a finite max abs error")

    low_max_error_model_layers = standardize.get_layers(perm_model_w_lowest_max_error)
    print("avg abs magnitude cnn_layer_weights", torch.mean(torch.abs(low_max_error_model_layers[0].weight.flatten())))
    print("avg abs magnitude cnn_layer_biases", torch.mean(torch.abs(low_max_error_model_layers[0].bias.flatten())))
    print("avg abs magnitute fnn_layer_weights", torch.mean(torch.abs(low_max_error_model_layers[1].weight.flatten())))
    print("avg abs magnitudefnn_layer_biases", torch.mean(torch.abs(low_max_error_model_layers[1].bias.flatten())))

    # now evaluate the selected (aligned) model.
    mean_ae, layers_mean_ae = meanse_meanae.calculate_distance_mse_or_mae('mae', model, perm_model_w_lowest_max_error)
    max_overall_error = max_ae.calculate_distance_mae(model, perm_model_w_lowest_max_error)

    # Report on the aligned copy that was actually evaluated, not on the
    # unaligned model_to_evaluate.
    print("get mae", get_mae(model, perm_model_w_lowest_max_error))

    return (mean_ae, layers_mean_ae, max_overall_error)


def cnn_evaluate(model: torch.nn.Module, model_to_evaluate: torch.nn.Module, tanh: bool = None):
    """Heuristically align ``model_to_evaluate`` to ``model`` and measure the distance.

    Both models are rescaled in place by ``standardize_scale_cnn`` (called
    with ``tanh=None``; the ``tanh`` argument is unused). The kernel ordering
    from ``heuristic_ordering_kernels_cnn`` is then applied to
    ``model_to_evaluate`` itself via ``cnn_align``; no copy is made.

    Returns:
        ``(mean_ae, layers_mean_ae, max_overall_error)``.
    """
    print('heuristic evaluate')
    for net in (model, model_to_evaluate):
        standardize_scale_cnn(net, tanh=None)

    ordering = heuristic_ordering_kernels_cnn(
        standardize.get_layers(model)[0],
        standardize.get_layers(model_to_evaluate)[0],
    )
    # cnn_align returns model_to_evaluate itself, now reordered.
    aligned = cnn_align(model, model_to_evaluate, ordering)

    mean_ae, layers_mean_ae = meanse_meanae.calculate_distance_mse_or_mae('mae', model, aligned)
    max_overall_error = max_ae.calculate_distance_mae(model, aligned)

    print("get mae", get_mae(model, model_to_evaluate))

    return (mean_ae, layers_mean_ae, max_overall_error)