import torch
import fire
from lenet import get_layer_params_lenet, set_layer_params_lenet
from vgg import get_layer_params_vgg, set_layer_params_vgg
import numpy as np


def prune_datadepdet(model_dir, out_dir, data_vals, epsilon: float, sparsity: float, model_type, layers=(1,)):
    """Prune selected layers of a saved model using precomputed data-dependent sensitivities.

    Exactly one of ``epsilon`` / ``sparsity`` should be set; the other must be None.
    For every layer in ``layers`` three lines are printed:
    the number of samples required, the fraction of zero entries in the computed
    mask (weight columns plus the trailing bias column), and the epsilon
    (the given one in epsilon mode, the implied one in sparsity mode).

    Args:
        model_dir: path of the saved model parameters (``torch.load`` onto the CPU).
        out_dir: path the pruned parameters are written to with ``torch.save``.
        data_vals: path of a ``torch.save``-d dict with the keys
            ``'pos_sensitivities'``, ``'neg_sensitivities'`` and ``'layer_deltas'``,
            each indexed by ``layer - 1``.
        epsilon: error budget for epsilon-mode pruning, or None.
        sparsity: per-neuron fraction to prune for sparsity-mode pruning, or None.
        model_type: ``"lenet"`` or ``"vgg"``; selects the parameter get/set helpers.
        layers: 1-based indices of the layers to prune. A single int is also
            accepted (fire parses ``--layers=2`` as an int, not a list).

    Raises:
        ValueError: if both ``epsilon`` and ``sparsity`` are set, if ``model_type``
            is unknown, or if a target layer's existing mask is not all ones.
            ``prune_layer`` also raises ValueError when both are None.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if epsilon is not None and sparsity is not None:
        raise ValueError("Only one of epsilon or sparsity may be set")

    if model_type == "lenet":
        get_params, set_params = get_layer_params_lenet, set_layer_params_lenet
    elif model_type == "vgg":
        get_params, set_params = get_layer_params_vgg, set_layer_params_vgg
    else:
        # ValueError subclasses Exception, so existing `except Exception` callers still work.
        raise ValueError(f'Unknown model type {model_type}')

    if isinstance(layers, int):
        layers = [layers]

    cpu = torch.device('cpu')
    data_vals = torch.load(data_vals, map_location=cpu)
    params = torch.load(model_dir, map_location=cpu)

    pos_sensitivities, neg_sensitivities = data_vals['pos_sensitivities'], data_vals['neg_sensitivities']

    for layer in layers:
        idx = layer - 1  # `layers` is 1-based; the stored per-layer lists are 0-based
        mask, samples_required, implied_eps = prune_layer(
            pos_sensitivities[idx], neg_sensitivities[idx], epsilon, sparsity, data_vals['layer_deltas'][idx].squeeze())
        print(f'{samples_required}')
        print(f'{np.sum(mask == 0) / mask.size}')
        print(f'{epsilon if epsilon is not None else implied_eps}')

        weight, bias, old_mask = get_params(params, layer)
        # Refuse to prune a layer that already carries a non-trivial mask.
        if not bool((old_mask == 1).all()):
            raise ValueError(f'Layer {layer} already has a non-trivial mask')

        # The last mask column gates the bias; the remaining columns gate the weights.
        weight_mask, bias_mask = mask[:, :-1], mask[:, -1]
        # Convert explicitly so the product keeps the bias's dtype/device rather than
        # depending on torch/numpy mixed-operand handling (assumes `bias` is a tensor).
        bias_mask = torch.as_tensor(bias_mask, dtype=bias.dtype, device=bias.device)
        set_params(params, layer, weight, bias * bias_mask, torch.Tensor(weight_mask))

    torch.save(params, out_dir)


def prune_layer(pos_sensitivities, neg_sensitivities, epsilon, sparsity, delta_output):
    """Compute a pruning mask for one layer from its positive/negative sensitivities.

    Both sensitivity inputs are 2-D (one row per neuron) torch tensors or numpy
    arrays; ``delta_output[neuron]`` scales each row's error. The inputs are
    never modified.

    Epsilon mode (``epsilon`` set): for each side (pos / neg) and each neuron,
    entries are pruned greedily in order of increasing sensitivity while the
    accumulated cost ``3 * K * s`` stays within ``epsilon / delta_output[neuron]``.
    An entry is kept if either side keeps it.

    Sparsity mode (``sparsity`` set): each neuron prunes
    ``int(sparsity * row_length)`` entries. It starts with entries whose pos and
    neg sensitivities are both zero, in index order. If more budget remains,
    ``floor(remaining / 2)`` entries are taken from each side. A side only
    considers entries whose other-side sensitivity is zero, lowest first.
    Entries with both sensitivities non-zero are never pruned. If a side has too
    few candidates, masked ``inf`` entries are selected and the implied
    epsilon becomes ``inf``.

    Returns:
        (mask, samples, implied_epsilon). ``mask`` has the sensitivities' shape
        (bool in epsilon mode, 0/1 floats in sparsity mode). In epsilon mode
        ``samples`` is the sum of both sides' kept counts; in sparsity mode it
        is ``mask.sum()``. ``implied_epsilon`` is the largest per-neuron
        accumulated cost times ``delta_output`` (0 for a layer with no neurons).

    Raises:
        ValueError: if neither ``epsilon`` nor ``sparsity`` is set.
    """
    K = 2  # NOTE(review): factor in the 3 * K error bound; its origin is not visible here.

    def to_array(sensitivities):
        # Private copy: `.numpy()` shares memory with the caller's tensor, and
        # eps_prune overwrites pruned entries with inf.
        if isinstance(sensitivities, torch.Tensor):
            sensitivities = sensitivities.detach().cpu().numpy()
        return np.array(sensitivities, copy=True)

    def eps_prune(sensitivities):
        # Mutates `sensitivities` (pruned entries -> inf); callers pass a private copy.
        mask = np.ones(sensitivities.shape)
        implied_epsilons = []

        for neuron in range(sensitivities.shape[0]):
            budget = epsilon / delta_output[neuron]
            implied_eps = 0

            while mask[neuron].sum() > 0:
                prune_idx = np.argmin(sensitivities[neuron])
                cost = 3 * K * sensitivities[neuron, prune_idx]
                if implied_eps + cost > budget:
                    break
                implied_eps += cost
                mask[neuron, prune_idx] = 0
                sensitivities[neuron, prune_idx] = float('inf')  # never selected again

            implied_epsilons.append(implied_eps * delta_output[neuron])

        return mask, mask.sum(), max(implied_epsilons, default=0)

    def prune_component(mask, neuron, sensitivities, other, zero_vals, k):
        """Zero the k cheapest one-sided entries of `neuron`; return their implied epsilon."""
        candidates = np.copy(sensitivities[neuron])
        candidates[zero_vals] = float('inf')  # already pruned as both-zero entries
        candidates[other[neuron] != 0] = float('inf')  # only entries the other side does not need
        _, prune_indices = torch.topk(torch.from_numpy(candidates), k, largest=False)
        prune_indices = prune_indices.numpy()
        mask[neuron, prune_indices] = 0
        return 3 * K * candidates[prune_indices].sum() * delta_output[neuron]

    pos_sensitivities, neg_sensitivities = to_array(pos_sensitivities), to_array(neg_sensitivities)

    if epsilon is not None:
        pos_mask, pos_samples, pos_epsilon = eps_prune(pos_sensitivities)
        neg_mask, neg_samples, neg_epsilon = eps_prune(neg_sensitivities)
        return np.logical_or(pos_mask, neg_mask), pos_samples + neg_samples, max(pos_epsilon, neg_epsilon)

    elif sparsity is not None:
        mask = np.ones(pos_sensitivities.shape)
        implied_epsilons = []
        for neuron in range(pos_sensitivities.shape[0]):
            prune_size = int(sparsity * pos_sensitivities[neuron].shape[0])
            zero_vals = np.logical_and(pos_sensitivities[neuron] == 0, neg_sensitivities[neuron] == 0)

            if zero_vals.sum() > prune_size:
                # Enough free (zero-cost) entries: prune the first `prune_size` of them.
                mask[neuron, zero_vals.nonzero()[0][:prune_size]] = 0
                implied_epsilons.append(0)
                continue

            mask[neuron, zero_vals] = 0
            component_prune_size = int((prune_size - zero_vals.sum()) / 2)

            implied_epsilons.append(prune_component(
                mask, neuron, pos_sensitivities, neg_sensitivities, zero_vals, component_prune_size))
            implied_epsilons.append(prune_component(
                mask, neuron, neg_sensitivities, pos_sensitivities, zero_vals, component_prune_size))

        return mask, mask.sum(), max(implied_epsilons, default=0)

    else:
        raise ValueError("Either sparsity or epsilon must not be None")


if __name__ == '__main__':
    # Command-line entry point: python-fire maps CLI flags onto prune_datadepdet's parameters.
    fire.Fire(component=prune_datadepdet)
