import torch
import numpy as np
import tqdm

from types import SimpleNamespace

from pruning.utils import *


def pfa_en_prune_loop(net, dataloader, Ds, ratios, loss_fn, device, pivots=None, args=SimpleNamespace()):
    """Prune every layer boundary by a variance-cutoff amount and evaluate.

    For each entry of ``ratios``, each prunable layer at position ``k >= 1``
    has its last ``get_n_nodes_for_variance_cutoff(Ds[k], ratio)`` input
    units pruned. When ``args.prune_by_removal`` is set, the matching output
    units of layer ``k - 1`` are pruned too. The network is then evaluated,
    every layer is reset, and the next ratio is tried.

    Accuracy, loss, parameter count and FLOPs per ratio are written to
    ``args.results_path + f'{args.pruning_algo}-{args.extra_args}'`` with the
    suffixes ``-accuracies.txt``, ``-losses.txt``, ``-pc.txt`` and ``-flops.txt``.

    Args:
        net: model wrapped by ``replace_layers`` and evaluated by
            ``measure_perf`` / ``get_network_stats``.
        dataloader: data used for evaluation and for the network statistics.
        Ds: indexed by prunable-layer position. ``Ds[k]`` is passed to
            ``get_n_nodes_for_variance_cutoff`` (presumably a per-layer
            spectrum; confirm against the caller).
        ratios: cutoff values; one result is recorded per entry.
        loss_fn: loss forwarded to ``measure_perf``.
        device: device the pruning index tensors are moved to.
        pivots: forwarded to ``replace_layers``.
        args: namespace read for ``results_path``, ``pruning_algo``,
            ``extra_args``, ``model`` and ``prune_by_removal``.
    """
    path = args.results_path + f'{args.pruning_algo}-{args.extra_args}'

    total_accs, total_losses, parameter_count, flops_count = (
        np.zeros(len(ratios)) for _ in range(4)
    )

    prunable_layers = replace_layers(net, None, None, pivots, args)
    # Adjacent (preceding, succeeding) layer pairs; position k pairs layer k-1 with layer k.
    layer_pairs = list(zip(prunable_layers, prunable_layers[1:]))

    for r_indx, ratio in tqdm(enumerate(ratios)):
        for position, (pre_layer, post_layer) in enumerate(layer_pairs, start=1):
            n_cut = get_n_nodes_for_variance_cutoff(Ds[position], ratio)
            print(f'Layer {position} pruned by {n_cut}/{len(Ds[position])} units.')

            # HACK: fixed width for layer 54 of resnet50. Presumably the preceding
            # layer's weight length does not match the real input width there; confirm.
            hardcoded_width = args.model == 'resnet50' and position == 54
            width = 512 if hardcoded_width else len(pre_layer.layer.weight)

            # Indices of the trailing n_cut units: width - n_cut, ..., width - 1.
            cut = torch.arange(-n_cut, 0).to(device) + width

            post_layer.prune_inputs(cut)
            if args.prune_by_removal:
                pre_layer.prune_outputs(cut)

        loss, acc = measure_perf(net, loss_fn, dataloader, device)
        pc, flops = get_network_stats(net, dataloader, device)

        total_accs[r_indx], total_losses[r_indx] = acc, loss
        parameter_count[r_indx], flops_count[r_indx] = pc, flops

        # Undo this ratio's pruning before trying the next one.
        for pre_layer, post_layer in layer_pairs:
            pre_layer.reset()
            post_layer.reset()

        print(total_accs, pc, flops)

    outputs = (
        ('-accuracies.txt', total_accs),
        ('-losses.txt', total_losses),
        ('-pc.txt', parameter_count),
        ('-flops.txt', flops_count),
    )
    for suffix, values in outputs:
        np.savetxt(path + suffix, values)

def layerwise_prune_loop(net, dataloader, Rs, ratios, loss_fn, device, means=None, pivots=None, args=SimpleNamespace()):
    """Prune one layer boundary at a time, for every ratio, and evaluate.

    For each adjacent pair (pre, post) of prunable layers, and for each
    ``ratio``, the last ``int(post.pre_dim * ratio)`` input units of ``post``
    are pruned. When ``args.prune_by_removal`` is set, the matching outputs
    of ``pre`` are pruned as well. The network is evaluated and then the pair
    is reset, so only one boundary is pruned at any moment.

    Results are 2-D arrays with one row per layer pair and one column per
    ratio. They are written to ``args.results_path +
    f'{args.pruning_algo}-{args.extra_args}'`` with the suffixes
    ``-accuracies.txt``, ``-losses.txt``, ``-pc.txt`` and ``-flops.txt``.

    Args:
        net: model wrapped by ``replace_layers`` and evaluated by
            ``measure_perf`` / ``get_network_stats``.
        dataloader: data used for evaluation and for the network statistics.
        Rs, means, pivots: forwarded to ``replace_layers``.
        ratios: fractions of each boundary's input width to prune.
        loss_fn: loss forwarded to ``measure_perf``.
        device: device the pruning index tensors are moved to.
        args: namespace read for ``results_path``, ``pruning_algo``,
            ``extra_args`` and ``prune_by_removal``.
    """
    path = args.results_path + f'{args.pruning_algo}-{args.extra_args}'

    prunable_layers = replace_layers(net, Rs, means, pivots, args)

    grid_shape = (len(prunable_layers) - 1, len(ratios))
    total_accs = np.zeros(grid_shape)
    total_losses, parameter_count, flops_count = (np.zeros_like(total_accs) for _ in range(3))

    for row, (pre_layer, post_layer) in enumerate(zip(prunable_layers, prunable_layers[1:])):
        for col, ratio in tqdm(enumerate(ratios)):
            width = post_layer.pre_dim
            n_cut = int(width * ratio)
            # Indices of the trailing n_cut input units.
            cut = torch.arange(-n_cut, 0).to(device) + width

            # Pruning the succeeding layer's inputs is what removes the units.
            post_layer.prune_inputs(cut)

            # Pruning the preceding layer's outputs is optional: those outputs
            # are no longer consumed once the succeeding inputs are gone.
            if args.prune_by_removal:
                pre_layer.prune_outputs(cut)

            loss, acc = measure_perf(net, loss_fn, dataloader, device)
            pc, flops = get_network_stats(net, dataloader, device)

            # Restore the pair before the next measurement.
            pre_layer.reset()
            post_layer.reset()

            total_accs[row, col] = acc
            total_losses[row, col] = loss
            parameter_count[row, col] = pc
            flops_count[row, col] = flops

        print(total_accs)

    outputs = (
        ('-accuracies.txt', total_accs),
        ('-losses.txt', total_losses),
        ('-pc.txt', parameter_count),
        ('-flops.txt', flops_count),
    )
    for suffix, values in outputs:
        np.savetxt(path + suffix, values)


def prune_loop_whole_net_equal_amount(net, dataloader, Rs, ratios, loss_fn, device, means=None, pivots=None, args=SimpleNamespace()):
    """Prune the same fraction at every layer boundary at once and evaluate.

    For each ``ratio``, every adjacent pair (pre, post) of prunable layers has
    the last ``int(ratio * post.pre_dim)`` input units of ``post`` pruned.
    When ``args.prune_by_removal`` is set, the matching outputs of ``pre``
    are pruned too. The whole network is evaluated once per ratio, then every
    layer is reset.

    Per-ratio accuracy, loss, parameter count and FLOPs are written to
    ``args.results_path + f'{args.pruning_algo}-{args.extra_args}'`` with the
    suffixes ``-accuracies.txt``, ``-losses.txt``, ``-pc.txt`` and ``-flops.txt``.

    Args:
        net: model wrapped by ``replace_layers`` and evaluated by
            ``measure_perf`` / ``get_network_stats``.
        dataloader: data used for evaluation and for the network statistics.
        Rs, means, pivots: forwarded to ``replace_layers``.
        ratios: fractions of each boundary's input width to prune.
        loss_fn: loss forwarded to ``measure_perf``.
        device: device the pruning index tensors are moved to.
        args: namespace read for ``results_path``, ``pruning_algo``,
            ``extra_args`` and ``prune_by_removal``.
    """
    path = args.results_path + f'{args.pruning_algo}-{args.extra_args}'

    total_accs = np.zeros(len(ratios))
    total_losses, parameter_count, flops_count = (np.zeros_like(total_accs) for _ in range(3))

    prunable_layers = replace_layers(net, Rs, means, pivots, args)
    boundaries = list(zip(prunable_layers, prunable_layers[1:]))

    for r_indx, ratio in tqdm(enumerate(ratios)):
        for pre_layer, post_layer in boundaries:
            width = post_layer.pre_dim
            n_cut = int(ratio * width)
            # Indices of the trailing n_cut input units.
            cut = torch.arange(-n_cut, 0).to(device) + width

            post_layer.prune_inputs(cut)
            if args.prune_by_removal:
                pre_layer.prune_outputs(cut)

        loss, acc = measure_perf(net, loss_fn, dataloader, device)
        pc, flops = get_network_stats(net, dataloader, device)

        # Undo this ratio's pruning before the next one.
        for pre_layer, post_layer in boundaries:
            pre_layer.reset()
            post_layer.reset()

        total_accs[r_indx] = acc
        total_losses[r_indx] = loss
        parameter_count[r_indx] = pc
        flops_count[r_indx] = flops

        print(total_accs)

    outputs = (
        ('-accuracies.txt', total_accs),
        ('-losses.txt', total_losses),
        ('-pc.txt', parameter_count),
        ('-flops.txt', flops_count),
    )
    for suffix, values in outputs:
        np.savetxt(path + suffix, values)

#### WIP ####
def decor_prune_loop_whole_net(net, dataloader, Rs, ratios, loss_fn, device, pruning_order, means=None, pivots=None, args=SimpleNamespace()):
    """Prune the whole network following a global unit ranking and evaluate (WIP).

    ``pruning_order`` is a sequence of prunable-layer indices. Each occurrence
    of index ``k`` accounts for one unit removed at the boundary between
    layers ``k`` and ``k + 1``. For each ``ratio`` the first
    ``int(ratio * len(pruning_order))`` entries are tallied per layer. That
    many trailing units are then pruned from layer ``k``'s outputs and layer
    ``k + 1``'s inputs. After evaluation only the touched layers are reset.

    Per-ratio accuracy, loss, parameter count and FLOPs are written to
    ``args.results_path + f'{args.pruning_algo}-{args.extra_args}'`` with the
    suffixes ``-accuracies.txt``, ``-losses.txt``, ``-pc.txt`` and ``-flops.txt``.

    Args:
        net: model wrapped by ``replace_layers`` and evaluated by
            ``measure_perf`` / ``get_network_stats``.
        dataloader: data used for evaluation and for the network statistics.
        Rs, means, pivots: forwarded to ``replace_layers``.
        ratios: fractions of ``pruning_order`` to apply.
        loss_fn: loss forwarded to ``measure_perf``.
        device: device the pruning index tensors are created on.
        pruning_order: sequence of layer indices, in pruning priority order.
        args: namespace read for ``results_path``, ``pruning_algo`` and
            ``extra_args``.
    """
    path = args.results_path + f'{args.pruning_algo}-{args.extra_args}'

    total_accs = np.zeros(len(ratios))
    total_losses = np.zeros_like(total_accs)
    parameter_count = np.zeros_like(total_losses)
    flops_count = np.zeros_like(parameter_count)

    prunable_layers = replace_layers(net, Rs, means, pivots, args)
    # NOTE(review): `tqdm` must be the callable (e.g. re-exported by pruning.utils);
    # a bare `import tqdm` binds the module, which is not callable.
    # Wrapping `ratios` directly (not the enumerate object) lets tqdm show a total.
    for r_indx, ratio in enumerate(tqdm(ratios)):
        number_of_nodes_pruned = int(ratio * len(pruning_order))

        # Count how many units each layer boundary loses at this ratio.
        layer_indices, prune_nums = np.unique(pruning_order[:number_of_nodes_pruned], return_counts=True)
        print(layer_indices, prune_nums)
        # Use plain Python ints for list indexing and torch.arange.
        layer_indices = [int(i) for i in layer_indices]
        prune_nums = [int(n) for n in prune_nums]

        for m_indx, m in zip(layer_indices, prune_nums):
            mod_pre = prunable_layers[m_indx]
            mod_post = prunable_layers[m_indx + 1]

            # Trailing m units of the preceding layer's output dimension.
            n_outputs = len(mod_pre.layer.weight)
            indices = torch.arange(n_outputs - m, n_outputs, device=device)

            mod_post.prune_inputs(indices)
            mod_pre.prune_outputs(indices)

        loss, acc = measure_perf(net, loss_fn, dataloader, device)
        pc, flops = get_network_stats(net, dataloader, device)

        # Reset only the layers touched at this ratio.
        for m_indx in layer_indices:
            prunable_layers[m_indx].reset()
            prunable_layers[m_indx + 1].reset()

        total_accs[r_indx] = acc
        total_losses[r_indx] = loss
        parameter_count[r_indx] = pc
        flops_count[r_indx] = flops

        print(total_accs)

    np.savetxt(path + '-accuracies.txt', total_accs)
    np.savetxt(path + '-losses.txt', total_losses)
    np.savetxt(path + '-pc.txt', parameter_count)
    np.savetxt(path + '-flops.txt', flops_count)


def decor_prune_loop_whole_net_var_ratio(net, dataloader, Rs, Ds, ratios, loss_fn, device, means=None, pivots=None, args=SimpleNamespace()):
    """Prune every boundary by a variance-cutoff amount (always by removal) and evaluate.

    For each ``ratio``, each prunable layer at position ``k >= 1`` loses its
    last ``get_n_nodes_for_variance_cutoff(Ds[k], ratio)`` input units. The
    preceding layer ``k - 1`` always loses the matching output units. After
    evaluation every layer is reset.

    Per-ratio accuracy, loss, parameter count and FLOPs are written to
    ``args.results_path + f'{args.pruning_algo}-{args.extra_args}'`` with the
    suffixes ``-accuracies.txt``, ``-losses.txt``, ``-pc.txt`` and ``-flops.txt``.

    Args:
        net: model wrapped by ``replace_layers`` and evaluated by
            ``measure_perf`` / ``get_network_stats``.
        dataloader: data used for evaluation and for the network statistics.
        Rs, means, pivots: forwarded to ``replace_layers``.
        Ds: indexed by prunable-layer position and passed to
            ``get_n_nodes_for_variance_cutoff`` (presumably per-layer spectra;
            confirm against the caller).
        ratios: cutoff values; one result is recorded per entry.
        loss_fn: loss forwarded to ``measure_perf``.
        device: device the pruning index tensors are moved to.
        args: namespace read for ``results_path``, ``pruning_algo`` and
            ``extra_args``.
    """
    path = args.results_path + f'{args.pruning_algo}-{args.extra_args}'

    total_accs, total_losses, parameter_count, flops_count = (
        np.zeros(len(ratios)) for _ in range(4)
    )

    prunable_layers = replace_layers(net, Rs, means, pivots, args)
    layer_pairs = list(zip(prunable_layers, prunable_layers[1:]))

    for r_indx, ratio in tqdm(enumerate(ratios)):
        for position, (pre_layer, post_layer) in enumerate(layer_pairs, start=1):
            n_cut = get_n_nodes_for_variance_cutoff(Ds[position], ratio)
            width = len(pre_layer.layer.weight)
            # Indices of the trailing n_cut units at this boundary.
            cut = torch.arange(-n_cut, 0).to(device) + width

            post_layer.prune_inputs(cut)
            pre_layer.prune_outputs(cut)

        loss, acc = measure_perf(net, loss_fn, dataloader, device)
        pc, flops = get_network_stats(net, dataloader, device)

        total_accs[r_indx], total_losses[r_indx] = acc, loss
        parameter_count[r_indx], flops_count[r_indx] = pc, flops

        # Undo this ratio's pruning before the next one.
        for pre_layer, post_layer in layer_pairs:
            pre_layer.reset()
            post_layer.reset()

        print(total_accs, pc, flops)

    outputs = (
        ('-accuracies.txt', total_accs),
        ('-losses.txt', total_losses),
        ('-pc.txt', parameter_count),
        ('-flops.txt', flops_count),
    )
    for suffix, values in outputs:
        np.savetxt(path + suffix, values)


def weight_magnitude_prune(net, dataloader, ratios, loss_fn, device, args=SimpleNamespace()):
    """Evaluate per-layer unstructured weight-magnitude pruning at several ratios.

    For each ratio ``r``, each prunable layer independently has the
    ``int(weight.numel() * r)`` entries of its ``weight`` with the smallest
    absolute value set to zero. The network is then evaluated, and the
    original weights are restored before the next ratio, even if evaluation
    raises.

    The parameter count recorded is ``parameter_count_nonzero(net)``, which
    replaces the count returned by ``get_network_stats``. Per-ratio results
    are written to ``args.results_path + f'{args.pruning_algo}-{args.extra_args}'``
    with the suffixes ``-accuracies.txt``, ``-losses.txt``, ``-pc.txt`` and
    ``-flops.txt``.

    Args:
        net: model whose prunable layers come from ``get_prunable_layers``.
        dataloader: data used for evaluation and for the network statistics.
        ratios: fractions of each layer's weight entries to zero.
        loss_fn: loss forwarded to ``measure_perf``.
        device: device forwarded to ``measure_perf`` / ``get_network_stats``.
        args: namespace read for ``results_path``, ``pruning_algo`` and
            ``extra_args``.
    """
    path = args.results_path + f'{args.pruning_algo}-{args.extra_args}'

    total_accs = np.zeros(len(ratios))
    total_losses = np.zeros_like(total_accs)
    parameter_count = np.zeros_like(total_losses)
    flops_count = np.zeros_like(parameter_count)

    prunable_layers = get_prunable_layers(net)

    # NOTE(review): `tqdm` must be the callable (e.g. re-exported by pruning.utils);
    # a bare `import tqdm` binds the module, which is not callable.
    for r_indx, r in enumerate(tqdm(ratios)):
        # Snapshot every layer before touching any of them. The copies are
        # detached, so no autograd history is kept. Taking all snapshots first
        # also keeps the restore correct when layers share (tie) a weight tensor.
        originals = [layer.weight.detach().clone() for layer in prunable_layers]
        try:
            with torch.no_grad():
                for layer, W_orig in zip(prunable_layers, originals):
                    # reshape + clone: works for non-contiguous weights and
                    # never aliases the snapshot used for restoration.
                    flat_W = W_orig.reshape(-1).clone()
                    m = int(flat_W.numel() * r)
                    indices = flat_W.abs().argsort(descending=False)
                    flat_W[indices[:m]] = 0.
                    layer.weight.data = flat_W.view(W_orig.shape)

            loss, acc = measure_perf(net, loss_fn, dataloader, device)
            _, flops = get_network_stats(net, dataloader, device)
            # Zeroed weights still count as parameters for get_network_stats,
            # so report only the non-zero entries instead.
            pc = parameter_count_nonzero(net)
        finally:
            # Restore the unpruned weights even if evaluation failed.
            for layer, W_orig in zip(prunable_layers, originals):
                layer.weight.data = W_orig

        total_accs[r_indx] = acc
        total_losses[r_indx] = loss
        parameter_count[r_indx] = pc
        flops_count[r_indx] = flops

        print(total_accs, pc, flops)

    np.savetxt(path + '-accuracies.txt', total_accs)
    np.savetxt(path + '-losses.txt', total_losses)
    np.savetxt(path + '-pc.txt', parameter_count)
    np.savetxt(path + '-flops.txt', flops_count)
