import torch
import fire
from lenet import get_layer_params_lenet, set_layer_params_lenet
from vgg import get_layer_params_vgg, set_layer_params_vgg
import numpy as np


def prune_datadep(model_dir, out_dir, data_vals, epsilon: float, sparsity: float, model_type, layers=(1,)):
    """Prune layers of a saved model with data-dependent importance sampling.

    Loads the parameters at ``model_dir`` and the sensitivity data at
    ``data_vals``, prunes every layer in ``layers`` with ``prune_layer`` and
    saves the updated parameters to ``out_dir``.  For each layer three lines
    are printed to stdout: the samples required, the sparsity achieved
    (fraction of zero entries in the mask, bias column included) and the
    implied epsilon.

    Args:
        model_dir: path of the parameter file passed to ``torch.load``.
        out_dir: path the pruned parameters are written to with ``torch.save``.
        data_vals: path of a ``torch.load``-able mapping with the keys
            ``'pos_sensitivities'``, ``'neg_sensitivities'``, ``'hidden_sizes'``,
            ``'delta_prob'`` and ``'layer_deltas'``.
        epsilon: error target; exactly one of ``epsilon``/``sparsity`` must be given.
        sparsity: sparsity target (``prune_layer`` raises NotImplementedError for it).
        model_type: ``"lenet"`` or ``"vgg"``.
        layers: 1-based indices of the layers to prune (default: only layer 1).

    Raises:
        ValueError: if not exactly one of ``epsilon``/``sparsity`` is given, if
            ``model_type`` is unknown, or if a layer's existing mask is not all ones.
    """
    # Validate cheap arguments before the (potentially slow) file loading;
    # a plain assert would disappear under ``python -O``.
    if (epsilon is None) == (sparsity is None):
        raise ValueError("Exactly one of epsilon and sparsity must be given")

    if model_type == "lenet":
        get_params, set_params = get_layer_params_lenet, set_layer_params_lenet
    elif model_type == "vgg":
        get_params, set_params = get_layer_params_vgg, set_layer_params_vgg
    else:
        # ValueError subclasses Exception, so existing ``except Exception`` still works.
        raise ValueError(f'Unknown model type {model_type}')

    cpu = torch.device('cpu')
    data = torch.load(data_vals, map_location=cpu)
    params = torch.load(model_dir, map_location=cpu)

    pos_sensitivities = data['pos_sensitivities']
    neg_sensitivities = data['neg_sensitivities']
    total_hidden = sum(data['hidden_sizes'])

    for layer in layers:
        idx = layer - 1  # layers are 1-based; the per-layer sequences are 0-based
        mask, scalars, samples_required, implied_eps = prune_layer(
            pos_sensitivities[idx], neg_sensitivities[idx], epsilon, sparsity, total_hidden,
            data['delta_prob'], data['layer_deltas'][idx].squeeze())
        print(f'{samples_required}')
        print(f'{np.sum(mask == 0) / mask.size}')
        print(f'{implied_eps}')

        weight, bias, old_mask = get_params(params, layer)
        if not torch.equal(torch.ones(old_mask.shape), old_mask):
            raise ValueError(f'Layer {layer} mask is not all ones; expected an unpruned layer')

        # The last column of mask/scalars belongs to the bias, the rest to the weights.
        weight_scalars, bias_scalars = scalars[:, :-1], scalars[:, -1]
        weight_mask, bias_mask = mask[:, :-1], mask[:, -1]
        # Convert the NumPy factors to tensors of the parameters' dtype/device so the
        # products keep the parameter dtype instead of being promoted to float64.
        new_weight = weight * torch.as_tensor(weight_scalars, dtype=weight.dtype, device=weight.device)
        new_bias = bias * torch.as_tensor(bias_mask * bias_scalars, dtype=bias.dtype, device=bias.device)
        set_params(params, layer, new_weight, new_bias, torch.Tensor(weight_mask))

    torch.save(params, out_dir)


def prune_layer(pos_sensitivities, neg_sensitivities, epsilon, sparsity, total_hidden, delta_prob, delta_output,
                max_samples=1000000):
    """Compute a pruning mask and rescaling scalars for one layer by importance sampling.

    The positive and negative sensitivities are sampled independently, row by
    row (one row per neuron).  The two masks are OR-ed and the two sets of
    scalars summed.  For a neuron with non-zero total sensitivity ``S`` in a
    half, the number of draws is

        ceil((6 + 2*epsilon) * delta_output[neuron]**2 * S * K
             * log(8 * total_hidden / delta_prob) / epsilon**2),   with K = 2,

    and column ``j`` is drawn with probability ``sensitivity[j] / S``.  Every
    drawn column is set in the mask and gets scalar ``count_j / draws / p_j``.
    A neuron needing more than ``max_samples`` draws in either half is kept
    unchanged: the whole row is masked in with scalar exactly 1.  A half whose
    total sensitivity for a neuron is zero contributes nothing for that neuron.

    Args:
        pos_sensitivities: 2-D tensor (anything with ``.cpu().numpy()``), one row per neuron.
        neg_sensitivities: same layout and shape as ``pos_sensitivities``.
        epsilon: error parameter of the sample bound; must be positive.
        sparsity: sparsity-first pruning; not implemented (only consulted when
            ``epsilon`` is None).
        total_hidden: value inside the log term (``prune_datadep`` passes the sum
            of the hidden sizes).
        delta_prob: value inside the log term; presumably a failure probability.
        delta_output: per-neuron factor, indexed by neuron.
        max_samples: draw count above which a neuron is kept as-is.

    Returns:
        ``(mask, scalars, total_samples, epsilon)``: a boolean mask and float
        scalars with the sensitivities' shape, the summed draw counts of both
        halves (including neurons kept as-is), and ``epsilon`` echoed back as the
        implied epsilon.

    Raises:
        NotImplementedError: if ``epsilon`` is None and ``sparsity`` is given.
        ValueError: if both are None, or ``epsilon`` is not positive.
    """
    # Fail fast on invalid arguments instead of only when some neuron happens
    # to have non-zero sensitivity.
    if epsilon is None:
        if sparsity is not None:
            # Mainly because I don't want to bother with the implied epsilon calculation (and we never use it)
            # We know how to do this, though, see notes 9/20
            raise NotImplementedError("Sparsity-first data-dependent weight pruning is not implemented")
        raise ValueError("Either sparsity or epsilon must not be None")
    if epsilon <= 0:
        raise ValueError(f"epsilon must be positive, got {epsilon}")

    # Loop-invariant pieces of the sample bound.
    K = 2
    log_term = np.log(8 * total_hidden / delta_prob)

    def sample(sensitivities):
        """Sample one half; return (mask, scalars, total draws, neurons over the cap)."""
        sensitivities = sensitivities.cpu().numpy()
        total_sensitivity = sensitivities.sum(axis=1)
        # Rows with zero total produce NaN/inf here; they are skipped below, so
        # silence the divide-by-zero warnings instead of emitting them.
        with np.errstate(divide='ignore', invalid='ignore'):
            probs = sensitivities / total_sensitivity[:, np.newaxis]

        num_neurons, num_columns = sensitivities.shape
        mask = np.zeros(sensitivities.shape)
        scalars = np.zeros(sensitivities.shape)
        over_cap = np.zeros(num_neurons, dtype=bool)
        neuron_samples = []
        for neuron in range(num_neurons):
            # If we never saw any reason to keep any of these weights, just skip the neuron
            if total_sensitivity[neuron] == 0:
                continue

            samples = np.ceil((6 + 2 * epsilon) * delta_output[neuron]**2 * total_sensitivity[neuron] * K * log_term * epsilon**(-2))
            neuron_samples.append(samples)
            if samples > max_samples:
                # Too many samples: the neuron is kept as-is once both halves are combined.
                over_cap[neuron] = True
                continue

            indices = np.random.choice(num_columns, int(samples), p=probs[neuron])
            # Vectorized importance-sampling estimator: count_j / draws / p_j.
            counts = np.bincount(indices, minlength=num_columns)
            hit = counts > 0
            mask[neuron, hit] = 1
            scalars[neuron, hit] = counts[hit] / float(samples) / probs[neuron, hit]

        return mask, scalars, sum(neuron_samples), over_cap

    pos_mask, pos_scalars, pos_samples, pos_over_cap = sample(pos_sensitivities)
    neg_mask, neg_scalars, neg_samples, neg_over_cap = sample(neg_sensitivities)

    mask = np.logical_or(pos_mask, neg_mask)
    scalars = pos_scalars + neg_scalars
    # Apply "keep as-is" after combining the halves.  Setting the row to 1 inside
    # each half and then summing would yield 1 + (other half's scalars), e.g. 2
    # when both halves exceed the cap.
    kept = pos_over_cap | neg_over_cap
    mask[kept, :] = True
    scalars[kept, :] = 1
    return mask, scalars, pos_samples + neg_samples, epsilon


if __name__ == "__main__":
    # Command-line entry point: python-fire maps the CLI arguments onto
    # prune_datadep's parameters.
    fire.Fire(component=prune_datadep)
