# implements various forms of greedy forward selection pruning

import copy
import itertools
import pickle

import torch

from models import TwoLayerMLP

def _prepare_batch(model_in, target, device):
    """Flatten all input dims after the batch dim and move the batch to `device`."""
    model_in = torch.flatten(model_in, start_dim=1).to(device)
    return model_in, target.to(device)


def _count_correct(model, model_in, target):
    """Return (num_correct, num_examples) of `model` on one prepared batch.

    A prediction is `model(model_in) > 0.5`, compared elementwise to `target`.
    """
    with torch.no_grad():
        output = model(model_in)
    preds = output > 0.5
    num_corr = float(torch.sum(preds.long().squeeze() == target))
    return num_corr, target.shape[0]


def _batches_accuracy(model, batches):
    """Return the pooled accuracy of `model` over an iterable of prepared batches."""
    num_corr = 0.
    num_ex = 0
    for model_in, target in batches:
        batch_corr, batch_ex = _count_correct(model, model_in, target)
        num_corr += batch_corr
        num_ex += batch_ex
    return num_corr / num_ex


def _dataset_accuracy(model, dl, device):
    """Return the accuracy of `model` over every batch in `dl`.

    The model is moved to `device` for evaluation and back to the CPU
    afterwards (`nn.Module.to` / `.cpu` act in place).
    """
    model.to(device)
    acc = _batches_accuracy(
            model, (_prepare_batch(x, y, device) for x, y in dl))
    model.cpu()
    return acc


def _sample_batches(dl, num_mb, device):
    """Draw up to `num_mb` batches from a fresh pass over `dl`, prepared for `device`.

    Raises ValueError if `dl` yields no batches at all.
    """
    batches = [_prepare_batch(x, y, device)
               for x, y in itertools.islice(iter(dl), num_mb)]
    if not batches:
        raise ValueError('data loader yielded no batches')
    return batches


def _select_best_neuron(model, pruned_model, slot, batches, device):
    """Greedily choose which neuron of `model` fills hidden `slot` of `pruned_model`.

    Every hidden neuron of `model` (including ones already selected) is
    tried in `slot`. The one with the highest accuracy on `batches` wins, and
    ties go to the lowest index. The winner is left installed in `slot`, and
    its index is returned.
    """
    # hoist the full model's weights onto the evaluation device once
    hidden_w = model.hidden_layer.weight.data.to(device)
    output_w = model.output_layer.weight.data.to(device)
    best_idx = None
    best_acc = None
    for cand_idx in range(model.num_hidden):
        pruned_model.hidden_layer.weight.data[slot, :] = hidden_w[cand_idx, :]
        pruned_model.output_layer.weight.data[:, slot] = output_w[:, cand_idx]
        acc = _batches_accuracy(pruned_model, batches)
        if best_acc is None or acc > best_acc:
            best_idx, best_acc = cand_idx, acc
    # the slot currently holds the last candidate tried; install the winner
    pruned_model.hidden_layer.weight.data[slot, :] = hidden_w[best_idx, :]
    pruned_model.output_layer.weight.data[:, slot] = output_w[:, best_idx]
    return best_idx


def _fine_tune(args, model, dl, device):
    """Fine-tune `model` in place with SGD on BCE loss for `args.prune_epochs` epochs.

    Uses `args.lr`, `args.momentum` and `args.verbose`. Returns the per-epoch
    mean training losses and training accuracies.
    """
    opt = torch.optim.SGD(
            model.parameters(), lr=args.lr,
            momentum=args.momentum)
    criterion = torch.nn.BCELoss()
    model.to(device)
    losses = []
    accs = []
    for e in range(args.prune_epochs):
        if args.verbose:
            print(f'Running Fine-Tuning Epoch {e + 1} / {args.prune_epochs}')
        agg_trn_loss = 0.
        num_corr = 0.
        num_ex = 0.
        for model_in, target in dl:
            model_in, target = _prepare_batch(model_in, target.float(), device)
            output = model(model_in).squeeze(dim=1)
            loss = criterion(output, target)
            # clear the previous step's gradients; without this they
            # accumulate across mini-batches and corrupt every update
            opt.zero_grad()
            loss.backward()
            opt.step()
            preds = output > 0.5
            num_corr += float(torch.sum(preds.long() == target))
            num_ex += target.shape[0]
            agg_trn_loss += loss.item()
        epoch_loss = agg_trn_loss / len(dl)
        epoch_acc = num_corr / num_ex
        losses.append(epoch_loss)
        accs.append(epoch_acc)
        if args.verbose:
            print(f'Fine-Tuning Loss: {epoch_loss:.4f}')
            print(f'Fine-Tuning Acc.: {epoch_acc:.4f}')
    return losses, accs


def prune_mlp_fixed_size(args, model, dl, num_neuron=100, num_mb=1):
    """Prune a 2-layer neural network to a fixed number of hidden neurons.

    Neurons are selected greedily. At step k, every hidden neuron of the full
    model is tried as the k-th neuron of the pruned network, and the one that
    maximizes accuracy on freshly sampled mini-batch(es) is kept. Selection
    is with replacement: the same neuron may be chosen more than once.

    Parameters
    args: namespace providing `verbose`, `prune_epochs`, `lr`, `momentum`
    model: the full model that is being pruned (moved to the CPU in place)
    dl: data loader from which you sample batches to prune
    num_neuron: the number of neurons in the final, pruned model (>= 1)
    num_mb: the number of mini-batches to sample at each selection step (>= 1)

    Returns
    (best_pruned_model, final_acc): the pruned model on the CPU and its
    accuracy over every batch of `dl` (after optional fine-tuning).

    Side effects: prints the full-data accuracy after each selection step and
    pickles the list of those accuracies to './results/20K_prune_accs.pckl'.

    Raises
    ValueError: if `num_neuron` or `num_mb` is < 1, the model has no hidden
    neurons, or `dl` yields no batches.
    """
    if num_neuron < 1:
        raise ValueError(f'num_neuron must be >= 1, got {num_neuron}')
    if num_mb < 1:
        raise ValueError(f'num_mb must be >= 1, got {num_mb}')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.cpu()
    if model.num_hidden < 1:
        raise ValueError('model has no hidden neurons to select from')

    best_pruned_model = None
    neuron_results = []
    for neuron_idx in range(num_neuron):
        pruned_model = TwoLayerMLP(model.input_size, neuron_idx + 1)
        if args.verbose:
            print(f'Selecting neuron {neuron_idx + 1} / {num_neuron}')

        # copy in the neurons selected in previous iterations
        # NOTE(review): only `.weight` is copied. If TwoLayerMLP layers carry
        # biases, they keep the fresh model's initialization. Confirm this
        # is intended.
        if best_pruned_model is not None:
            input_weights = best_pruned_model.hidden_layer.weight.data
            output_weights = best_pruned_model.output_layer.weight.data
            assert (
                    input_weights.shape[0] == neuron_idx
                    and output_weights.shape[1] == neuron_idx)
            pruned_model.hidden_layer.weight.data[:neuron_idx, :] = input_weights
            pruned_model.output_layer.weight.data[:, :neuron_idx] = output_weights

        # find the next best neuron on freshly sampled mini-batch(es)
        batches = _sample_batches(dl, num_mb, device)
        pruned_model = pruned_model.to(device)
        _select_best_neuron(model, pruned_model, neuron_idx, batches, device)
        best_pruned_model = pruned_model.cpu()

        # report training accuracy of the current pruned model on all of dl
        final_acc = _dataset_accuracy(best_pruned_model, dl, device)
        print(f'{neuron_idx} --> {final_acc:.4f}')
        neuron_results.append(final_acc)

    # NOTE: path is hard-coded; the './results' directory must already exist
    with open('./results/20K_prune_accs.pckl', 'wb') as f:
        pickle.dump(neuron_results, f)

    # NOTE: output weights are deliberately not rescaled here
    # (e.g. by total_neuron / num_neuron) after selection.

    # run fine tuning on the pruned model (optional)
    if args.prune_epochs > 0:
        _fine_tune(args, best_pruned_model, dl, device)

    # check trn acc of best pruned model
    final_acc = _dataset_accuracy(best_pruned_model, dl, device)
    if args.verbose:
        print(f'Final Pruning Acc.: {final_acc:.4f}')
    return best_pruned_model, final_acc
