# Calculate several importance vectors (pruning orders) for structured pruning
import time
import copy
import numpy as np
import torch
from utilities import acc_net

def create_impoprtance_vector_by_name(algorithm_name,first_order, second_order, net, p_names, criterion,device,dataloader):
    """
    Dispatch to the pruning-importance function called ``algorithm_name``.

    The name is looked up among this module's global names and the function
    found there is called with all remaining arguments, passed through
    unchanged and in the same order.

    :param algorithm_name: A string naming a pruning method defined in this module
    :param first_order: forwarded to the selected method
    :param second_order: forwarded to the selected method
    :param net: forwarded to the selected method
    :param p_names: forwarded to the selected method
    :param criterion: forwarded to the selected method
    :param device: forwarded to the selected method
    :param dataloader: forwarded to the selected method
    :return: whatever the selected method returns
             (importance vector of size (num of structures))
    :raises KeyError: if this module defines nothing called ``algorithm_name``
    :raises TypeError: if the name exists but does not refer to a callable
    """
    module_globals = globals()
    if algorithm_name not in module_globals:
        # Same exception type as a plain dict lookup, but with a usable message.
        raise KeyError("Unknown pruning algorithm: {!r}".format(algorithm_name))

    algorithm_importance_vector = module_globals[algorithm_name]
    if not callable(algorithm_importance_vector):
        # e.g. the name of an imported module such as 'np' or 'time'.
        raise TypeError("{!r} is not a callable pruning algorithm".format(algorithm_name))

    return algorithm_importance_vector(first_order, second_order, net, p_names, criterion,device,dataloader)

# calculate importance for iterative structured pruning
def iterative_lin_obd_abs(first_order,second_order, net, p_names, criterion,device,dataloader):
    """
    Greedily build a pruning order from ``-first_order + second_order``.

    With ``M = -first_order + second_order``, each step scores every structure
    ``j`` as ``M[j, j] + 2 * sum(M[i, j] for i already pruned)`` and appends
    the structure with the smallest absolute score. Structures that were
    already picked are masked with ``inf`` so they are not picked again.

    :param first_order: square array (num_strucs x num_strucs)
    :param second_order: square array of the same shape as ``first_order``
    :param net: unused; accepted so all algorithms share one signature
    :param p_names: unused
    :param criterion: unused
    :param device: unused
    :param dataloader: unused
    :return: ``np.ndarray`` of structure indices in the order they are pruned
    :raises RuntimeError: if the resulting order does not contain every
        structure exactly once (possible when scores are NaN/inf)
    """
    Second_Order_Matrix = -first_order + second_order

    start = time.time()
    num_strucs = Second_Order_Matrix.shape[0]
    diagonal = np.diag(Second_Order_Matrix)
    # Running sum of the rows of all structures pruned so far. Updating it
    # once per step replaces re-summing every pruned row in every iteration.
    pruned_rows_sum = np.zeros_like(diagonal)
    pruned_idx_abs = []
    for pruned_strucs in range(num_strucs):
        if pruned_strucs % 1000 == 0 and pruned_strucs != 0:
            print('Pruned structures included: ', pruned_strucs)
            print('Time needed: ', time.time() - start)

        # Fresh, writable array (np.diag itself only returns a read-only view).
        imp_mat = diagonal + pruned_rows_sum * 2
        if pruned_idx_abs:
            # np.inf instead of np.Infinity: the alias was removed in NumPy 2.0.
            imp_mat[pruned_idx_abs] = np.inf

        # argsort (not argmin) on purpose: it orders NaN entries last.
        next_to_prune_abs = np.argsort(np.abs(imp_mat))[0]

        pruned_idx_abs.append(next_to_prune_abs)
        pruned_rows_sum += Second_Order_Matrix[next_to_prune_abs]
    print('Time needed: ', time.time() - start)
    #check if all structures are included
    if len(set(pruned_idx_abs)) != num_strucs:
        # The original bare `raise` surfaced as a RuntimeError as well.
        raise RuntimeError('Not all structures are considered for pruning!')
    return np.array(pruned_idx_abs)




def extreme_iterative_lin_obd_abs_non_cum(first_order,second_order, net, p_names, criterion,device,dataloader):
    """
    Greedily build a pruning order from ``|first_order| + |second_order|``.

    With ``A = |first_order| + |second_order|`` (element-wise), each step
    scores every structure ``j`` as
    ``A[j, j] + 2 * sum(A[i, j] for i already pruned)`` and appends the
    structure with the smallest score. Structures that were already picked
    are masked with ``inf`` so they are not picked again.

    :param first_order: square array (num_strucs x num_strucs)
    :param second_order: square array of the same shape as ``first_order``
    :param net: unused; accepted so all algorithms share one signature
    :param p_names: unused
    :param criterion: unused
    :param device: unused
    :param dataloader: unused
    :return: list of structure indices in the order they are pruned
    """
    # faster implementation of old method (much faster :))
    Second_Order_Matrix_abs = np.abs(first_order) + np.abs(second_order)

    # calculate importance for iterative structured pruning
    start = time.time()
    num_strucs = Second_Order_Matrix_abs.shape[0]
    diagonal = np.diag(Second_Order_Matrix_abs)
    # Running sum of the rows of all structures pruned so far. Updating it
    # once per step replaces re-summing every pruned row in every iteration.
    pruned_rows_sum = np.zeros_like(diagonal)
    pruned_idx_extreme_abs = []
    for pruned_strucs in range(num_strucs):
        if pruned_strucs % 1000 == 0 and pruned_strucs != 0:
            print('Time needed: ', time.time() - start)
            print('Pruned structures: ', pruned_strucs)

        # Fresh, writable array (np.diag itself only returns a read-only view).
        imp_mat = diagonal + pruned_rows_sum * 2
        if pruned_idx_extreme_abs:
            # np.inf instead of np.Infinity: the alias was removed in NumPy 2.0.
            imp_mat[pruned_idx_extreme_abs] = np.inf

        # argsort (not argmin) on purpose: it orders NaN entries last.
        next_to_prune_ext_abs = np.argsort(imp_mat)[0]

        pruned_idx_extreme_abs.append(next_to_prune_ext_abs)
        pruned_rows_sum += Second_Order_Matrix_abs[next_to_prune_ext_abs]
    print('Time needed: ', time.time() - start)

    return pruned_idx_extreme_abs

def kernel_scale_vector(net,p_names):
    """
    Build a per-structure vector holding each layer's kernel area.

    For every parameter named in ``p_names`` (looked up through
    ``net.named_parameters()``), the product of its last two dimensions is
    repeated ``len(parameter)`` times; the per-layer pieces are joined in
    the order given by ``p_names``.

    :param net: object exposing ``named_parameters()``
    :param p_names: names of the parameters to include, in output order
    :return: 1-D array with one entry per structure, or ``None`` when
             ``p_names`` is empty
    """
    params_by_name = dict(net.named_parameters())
    pieces = []
    for layer_name in p_names:
        weight = params_by_name[layer_name]
        kernel_area = weight.shape[-1] * weight.shape[-2]
        pieces.append(np.full(shape=len(weight), fill_value=kernel_area))

    if not pieces:
        return None
    return np.concatenate(pieces, axis=0)

def kernel_variable_scale_vector(net,p_names,scale):
    """
    Build a per-structure scale vector that depends on the kernel size.

    For every parameter named in ``p_names`` (looked up through
    ``net.named_parameters()``), each of its ``len(parameter)`` structures
    gets ``scale`` when the product of the last two dimensions differs from
    1, and ``1`` otherwise. Pieces are joined in the order of ``p_names``.

    :param net: object exposing ``named_parameters()``
    :param p_names: names of the parameters to include, in output order
    :param scale: value assigned to structures whose kernel area is not 1
    :return: 1-D array with one entry per structure, or ``None`` when
             ``p_names`` is empty
    """
    params_by_name = dict(net.named_parameters())
    pieces = []
    for layer_name in p_names:
        weight = params_by_name[layer_name]
        kernel_area = weight.shape[-1] * weight.shape[-2]
        layer_value = 1 if kernel_area == 1 else scale
        pieces.append(np.full(shape=len(weight), fill_value=layer_value))

    if not pieces:
        return None
    return np.concatenate(pieces, axis=0)

def extreme_iterative_lin_obd_abs_non_cum_kernel_scale(first_order,second_order, net, p_names, criterion,device,dataloader,scale=None):
    """
    Greedy pruning order from ``|first_order| + |second_order|``, with each
    score divided by the square root of the structure's kernel area.

    With ``A = |first_order| + |second_order|`` (element-wise), each step
    scores every structure ``j`` as
    ``(A[j, j] + 2 * sum(A[i, j] for i already pruned)) / sqrt(kernel_scale[j])``
    and appends the structure with the smallest score. ``kernel_scale`` comes
    from ``kernel_scale_vector(net, p_names)``. Structures that were already
    picked are masked with ``inf`` so they are not picked again.

    :param first_order: square array (num_strucs x num_strucs)
    :param second_order: square array of the same shape as ``first_order``
    :param net: network whose parameters define the kernel sizes
    :param p_names: names of the prunable parameters of ``net``
    :param criterion: unused; accepted so all algorithms share one signature
    :param device: unused
    :param dataloader: unused
    :param scale: unused. It now defaults to ``None`` so the function can be
        called with the same seven arguments as the other algorithms
        (``create_impoprtance_vector_by_name`` passes exactly seven).
    :return: list of structure indices in the order they are pruned
    """
    # faster implementation of old method (much faster :))
    Second_Order_Matrix_abs = np.abs(first_order) + np.abs(second_order)
    kernel_scale = kernel_scale_vector(net,p_names)
    # Loop-invariant, so computed once. kernel_scale is None when p_names is
    # empty; in that case the divisor is only needed if there is work to do.
    sqrt_kernel_scale = None if kernel_scale is None else np.sqrt(kernel_scale)
    # calculate importance for iterative structured pruning
    start = time.time()
    num_strucs = Second_Order_Matrix_abs.shape[0]
    diagonal = np.diag(Second_Order_Matrix_abs)
    # Running sum of the rows of all structures pruned so far. Updating it
    # once per step replaces re-summing every pruned row in every iteration.
    pruned_rows_sum = np.zeros_like(diagonal)
    pruned_idx_extreme_abs = []
    for pruned_strucs in range(num_strucs):
        if pruned_strucs % 1000 == 0 and pruned_strucs != 0:
            print('Time needed: ', time.time() - start)
            print('Pruned structures: ', pruned_strucs)

        # Fresh, writable array (np.diag itself only returns a read-only view).
        imp_mat = diagonal + pruned_rows_sum * 2
        if pruned_idx_extreme_abs:
            # np.inf instead of np.Infinity: the alias was removed in NumPy 2.0.
            imp_mat[pruned_idx_extreme_abs] = np.inf

        if sqrt_kernel_scale is None:
            # Mirrors the original failure for an empty p_names.
            sqrt_kernel_scale = np.sqrt(kernel_scale)
        # argsort (not argmin) on purpose: it orders NaN entries last.
        next_to_prune_ext_abs = np.argsort(imp_mat/sqrt_kernel_scale)[0]

        pruned_idx_extreme_abs.append(next_to_prune_ext_abs)
        pruned_rows_sum += Second_Order_Matrix_abs[next_to_prune_ext_abs]
    print('Time needed: ', time.time() - start)

    return pruned_idx_extreme_abs




def c_obd_abs_non_cum(first_order,second_order, net, p_names, criterion,device,dataloader):
    """
    Rank structures by the diagonal of ``|first_order| + |second_order|``.

    Only ``first_order`` and ``second_order`` are used; the remaining
    arguments are accepted so all algorithms share one signature.

    :return: indices that sort the per-structure losses in ascending order
    """
    combined_abs = np.abs(first_order) + np.abs(second_order)
    per_structure_loss = np.diag(combined_abs)
    ascending_order = np.argsort(np.abs(per_structure_loss))
    return ascending_order



######## The following methods consider the full Hessian

def del_attr(obj, names):
    '''
    Delete the attribute reached by following the path ``names`` from ``obj``.

    All but the last entry of ``names`` are resolved with ``getattr``; the
    last entry is removed from the object reached that way.

    From: https://discuss.pytorch.org/t/combining-functional-jvp-with-a-nn-module/81215
    '''
    owner = obj
    for attr_name in names[:-1]:
        owner = getattr(owner, attr_name)
    delattr(owner, names[-1])


def set_attr(obj, names, val):
    '''
    Assign ``val`` to the attribute reached by following ``names`` from ``obj``.

    All but the last entry of ``names`` are resolved with ``getattr``; the
    last entry is set on the object reached that way.

    From: https://discuss.pytorch.org/t/combining-functional-jvp-with-a-nn-module/81215
    '''
    owner = obj
    for attr_name in names[:-1]:
        owner = getattr(owner, attr_name)
    setattr(owner, names[-1], val)


def make_functional(mod):
    '''
    Strip every parameter attribute from ``mod`` and hand the parameters back.

    Returns a tuple of the module's original parameters together with the
    list of their dotted names, in ``named_parameters()`` order.

    From: https://discuss.pytorch.org/t/combining-functional-jvp-with-a-nn-module/81215
    '''
    orig_params = tuple(mod.parameters())
    # Snapshot the names first; the attributes are deleted right below.
    names = [param_name for param_name, _ in list(mod.named_parameters())]
    # Remove all the parameters in the model
    for param_name in names:
        del_attr(mod, param_name.split("."))
    return orig_params, names



def hessian_sum_abs_non_cum_kernel_scale(first_order, secon_order, net, p_names, criterion,device,dataloader, steps=20,batch_size = 32,
                              total_number_of_datapoints=1000):
    """
    Rank structures using a Hessian-vector product, scaled by kernel size.

    For each processed batch, ``torch.autograd.functional.hvp`` is evaluated
    on the loss with the prunable parameters used both as inputs and as the
    vector ``v``. For every structure (first-dimension slice) of every
    prunable parameter with more than one dimension, the dot product of the
    structure's weights with the matching slice of the hvp result is
    accumulated. The final score per structure is
    ``(|accumulated value| + |diag(first_order)|) / sqrt(kernel_scale)``.

    :param first_order: array whose diagonal is added (in absolute value) to the score
    :param secon_order: only its ``shape[0]`` is read (number of structures)
    :param net: network; a deep copy is used, the original is not modified here
    :param p_names: names of the prunable parameters of ``net``
    :param criterion: loss called as ``criterion(output, labels)``
    :param device: device the batches are moved to
    :param dataloader: only ``dataloader.dataset`` is used; a new, unshuffled
                       DataLoader with ``batch_size`` is built from it
    :param steps: not used in this function
    :param batch_size: batch size of the internally built DataLoader
    :param total_number_of_datapoints: batches are processed while
                       ``i * batch_size`` is below this value; also the divisor
                       of the accumulated sum
    :return: indices that sort the scaled scores in ascending order
    """
    data_set = dataloader.dataset
    num_strucs = secon_order.shape[0]
    # Work on a copy: make_functional() below deletes the parameter attributes.
    net_hvp = copy.deepcopy(net)
    kernel_scale = kernel_scale_vector(net, p_names)

    #print(acc_net(net_hvp,dataloader,device))
    net_hvp.eval()

    # Must be collected before make_functional() removes the parameters.
    prunable_parameters = tuple(map(dict(net_hvp.named_parameters()).get, p_names))
    orig_params, names = make_functional(net_hvp)

    # Both closures read `data_in` and `labels` from the enclosing scope at
    # call time, i.e. they always see the batch of the current loop iteration.
    def functional_loss_dummy(*params):
        # Re-attach ALL parameters to the stripped network, then evaluate the loss.
        for name, p in zip(names, params):
            set_attr(net_hvp, name.split("."), p)
        return criterion(net_hvp(data_in.to(device)), labels.to(device))

    def functional_loss(*p_params):
        # Only names whose second-to-last dotted component starts with 'c' are
        # re-attached. NOTE(review): presumably this selects conv layers by
        # naming convention -- confirm against the model definitions.
        for name, p in zip(p_names, p_params):
            if name.split(".")[-2][0] == 'c':
                set_attr(net_hvp, name.split("."), p)
        return criterion(net_hvp(data_in.to(device)), labels.to(device))

    # The `dataloader` argument is replaced by an unshuffled loader over the same dataset.
    dataloader = torch.utils.data.DataLoader(data_set, batch_size=batch_size, shuffle=False, num_workers=4)
    time1 = time.time()
    sum_hessian = np.zeros(num_strucs)
    print(total_number_of_datapoints)
    for i, t_data in enumerate(dataloader, 0):
        data_in = t_data[0]
        labels = t_data[1]

        # stopping criterion
        if i * batch_size >= total_number_of_datapoints:
            break
        if i == 0:
            # still don't know why i need this
            # NOTE(review): this call re-attaches every parameter (including
            # those functional_loss skips) to the stripped network before the
            # first real hvp -- likely the reason it is needed; confirm.
            _, _ = torch.autograd.functional.hvp(functional_loss_dummy, orig_params, v=orig_params)

        # hvp returns (function output, Hessian-vector product); here v is the parameters themselves.
        _, hess_sum = torch.autograd.functional.hvp(functional_loss, prunable_parameters, v=prunable_parameters)

        sum_hessian_temp = []

        for layer_idx in range(len(hess_sum)):
            # Parameters with a single dimension contribute no structures.
            if len(hess_sum[layer_idx].shape) > 1:
                for struc_idx in range(hess_sum[layer_idx].shape[0]):
                    # Dot product of the structure's flattened weights with its flattened hvp slice.
                    sum_hessian_temp.append(torch.dot(prunable_parameters[layer_idx][struc_idx].reshape(-1),
                                                      hess_sum[layer_idx][struc_idx].reshape(-1)).cpu().data.numpy())
        # NOTE(review): every batch is weighted by the nominal batch_size, even
        # if the last batch of the dataset is smaller.
        sum_hessian += np.array(sum_hessian_temp) * float(batch_size)

    # NOTE(review): the divisor is the requested number of datapoints, which can
    # differ from the number actually processed (e.g. 32 batches of 32 = 1024
    # for the defaults, or fewer if the dataset is smaller) -- confirm intent.
    sum_hessian =  sum_hessian / float(total_number_of_datapoints)
    print(' Finished in: ' + str(time.time() - time1) +' seconds')
    imp_vector_unsort_scaled = (np.abs(sum_hessian) + np.abs(np.diag(first_order)))/np.sqrt(kernel_scale)
    return np.argsort(imp_vector_unsort_scaled)


def hessian_sum_abs_non_cum(first_order, secon_order, net, p_names, criterion,device,dataloader, steps=20,batch_size = 32,
                              total_number_of_datapoints=1000):
    """
    Rank structures using a Hessian-vector product.

    For each processed batch, ``torch.autograd.functional.hvp`` is evaluated
    on the loss with the prunable parameters used both as inputs and as the
    vector ``v``. For every structure (first-dimension slice) of every
    prunable parameter with more than one dimension, the dot product of the
    structure's weights with the matching slice of the hvp result is
    accumulated. The final score per structure is
    ``|accumulated value| + |diag(first_order)|``.

    :param first_order: array whose diagonal is added (in absolute value) to the score
    :param secon_order: only its ``shape[0]`` is read (number of structures)
    :param net: network; a deep copy is used, the original is not modified here
    :param p_names: names of the prunable parameters of ``net``
    :param criterion: loss called as ``criterion(output, labels)``
    :param device: device the batches are moved to
    :param dataloader: only ``dataloader.dataset`` is used; a new, unshuffled
                       DataLoader with ``batch_size`` is built from it
    :param steps: not used in this function
    :param batch_size: batch size of the internally built DataLoader
    :param total_number_of_datapoints: batches are processed while
                       ``i * batch_size`` is below this value; also the divisor
                       of the accumulated sum
    :return: indices that sort the scores in ascending order
    """
    data_set = dataloader.dataset
    num_strucs = secon_order.shape[0]
    # Work on a copy: make_functional() below deletes the parameter attributes.
    net_hvp = copy.deepcopy(net)
    #print(acc_net(net_hvp,dataloader,device))
    net_hvp.eval()

    # Must be collected before make_functional() removes the parameters.
    prunable_parameters = tuple(map(dict(net_hvp.named_parameters()).get, p_names))
    orig_params, names = make_functional(net_hvp)

    # Both closures read `data_in` and `labels` from the enclosing scope at
    # call time, i.e. they always see the batch of the current loop iteration.
    def functional_loss_dummy(*params):
        # Re-attach ALL parameters to the stripped network, then evaluate the loss.
        for name, p in zip(names, params):
            set_attr(net_hvp, name.split("."), p)
        return criterion(net_hvp(data_in.to(device)), labels.to(device))

    def functional_loss(*p_params):
        # Only names whose second-to-last dotted component starts with 'c' are
        # re-attached. NOTE(review): presumably this selects conv layers by
        # naming convention -- confirm against the model definitions.
        for name, p in zip(p_names, p_params):
            if name.split(".")[-2][0] == 'c':
                set_attr(net_hvp, name.split("."), p)
        return criterion(net_hvp(data_in.to(device)), labels.to(device))

    # The `dataloader` argument is replaced by an unshuffled loader over the same dataset.
    dataloader = torch.utils.data.DataLoader(data_set, batch_size=batch_size, shuffle=False, num_workers=4)
    time1 = time.time()
    sum_hessian = np.zeros(num_strucs)
    print(total_number_of_datapoints)
    for i, t_data in enumerate(dataloader, 0):
        data_in = t_data[0]
        labels = t_data[1]

        # stopping criterion
        if i * batch_size >= total_number_of_datapoints:
            break
        if i == 0:
            # still don't know why i need this
            # NOTE(review): this call re-attaches every parameter (including
            # those functional_loss skips) to the stripped network before the
            # first real hvp -- likely the reason it is needed; confirm.
            _, _ = torch.autograd.functional.hvp(functional_loss_dummy, orig_params, v=orig_params)

        # hvp returns (function output, Hessian-vector product); here v is the parameters themselves.
        _, hess_sum = torch.autograd.functional.hvp(functional_loss, prunable_parameters, v=prunable_parameters)

        sum_hessian_temp = []

        for layer_idx in range(len(hess_sum)):
            # Parameters with a single dimension contribute no structures.
            if len(hess_sum[layer_idx].shape) > 1:
                for struc_idx in range(hess_sum[layer_idx].shape[0]):
                    # Dot product of the structure's flattened weights with its flattened hvp slice.
                    sum_hessian_temp.append(torch.dot(prunable_parameters[layer_idx][struc_idx].reshape(-1),
                                                      hess_sum[layer_idx][struc_idx].reshape(-1)).cpu().data.numpy())
        # NOTE(review): every batch is weighted by the nominal batch_size, even
        # if the last batch of the dataset is smaller.
        sum_hessian += np.array(sum_hessian_temp) * float(batch_size)

    # NOTE(review): the divisor is the requested number of datapoints, which can
    # differ from the number actually processed (e.g. 32 batches of 32 = 1024
    # for the defaults, or fewer if the dataset is smaller) -- confirm intent.
    sum_hessian =  sum_hessian / float(total_number_of_datapoints)
    print(' Finished in: ' + str(time.time() - time1) +' seconds')
    return np.argsort(np.abs(sum_hessian) + np.abs(np.diag(first_order)))




def zero_prunable_parameters(prunable_parameters,num_strucs,list_of_structures):
    """
    Return a deep copy of ``prunable_parameters`` with selected structures zeroed.

    Structures are numbered consecutively across layers (all structures of
    the first layer, then the second, ...). Every structure whose running
    number is contained in ``list_of_structures`` is set to 0 in the copy;
    the input itself is left untouched. ``num_strucs`` is not used.
    """
    zeroed = copy.deepcopy(prunable_parameters)
    flat_idx = 0
    for layer in zeroed:
        for struc_idx in range(len(layer)):
            if flat_idx in list_of_structures:
                layer[struc_idx] = 0
            flat_idx += 1
    return zeroed




def random(first_order,second_order, net, p_names, criterion,device,dataloader):
    """
    Return a random pruning order.

    The indices ``0 .. second_order.shape[0] - 1`` are shuffled in place with
    NumPy's global random state (``np.random.shuffle``). All other arguments
    are accepted only so every algorithm shares one signature.
    """
    structure_count = second_order.shape[0]
    shuffled_order = np.array(range(structure_count))
    np.random.shuffle(shuffled_order)
    return shuffled_order






