# From https://github.com/BenMussay/Data-Independent-Neural-Pruning-via-Coresets/blob/master/coreset.py

from typing import Callable, Tuple
import sys
import torch
import numpy as np

def compress_fc_layer(layer1: Tuple[torch.Tensor, torch.Tensor],
                      layer2: torch.Tensor,
                      epsilon,
                      sparsity,
                      delta,
                      activation: Callable,
                      upper_bound) -> Tuple[np.ndarray, torch.Tensor, int, float]:
    """Sample a weighted subset (coreset) of the neurons feeding ``layer2``.

    Each neuron is a row of ``layer1[0]`` with the matching entry of
    ``layer1[1]`` appended (presumably the weight and bias of a fully
    connected layer -- confirm against the caller).  A neuron's sensitivity
    is the largest absolute entry of its column in ``layer2`` times
    ``|activation(upper_bound * row_norm)|``.  Neurons are drawn with
    replacement with probability proportional to sensitivity, and every
    column of ``layer2`` is rescaled by ``count / (prob * samples)``.

    Args:
        layer1: Pair of tensors; the first is 2-D with one row per neuron,
            the second holds one value per neuron.
        layer2: 2-D tensor with one column per neuron of ``layer1``.
        epsilon: If given, the number of samples is derived from it via the
            sample-complexity bound.  Exclusive with ``sparsity``.
        sparsity: If given, sampling continues until
            ``(1 - sparsity) * num_neurons`` distinct neurons were drawn.
            Exclusive with ``epsilon``.
        delta: Enters the sample-complexity bound as ``log(1 / delta)``.
        activation: Callable applied to the tensor of scaled row norms.
        upper_bound: Multiplier applied to the row norms before ``activation``.

    Returns:
        ``(mask, new_layer2, samples, implied_epsilon)`` where ``mask`` is a
        float array holding 1 for every neuron drawn at least once,
        ``new_layer2`` is ``layer2`` with its columns rescaled (columns of
        neurons never drawn become zero), ``samples`` is the number of draws
        and ``implied_epsilon`` is ``epsilon`` itself or, in sparsity mode,
        the value obtained by inverting the bound for ``samples`` draws.
        If the epsilon bound asks for more than 1,000,000 samples, nothing
        is pruned: an all-ones mask and the untouched ``layer2`` are returned.

    Raises:
        ValueError: If neither or both of ``epsilon`` / ``sparsity`` are
            given, if ``layer2`` does not have one column per neuron, or if
            the total sensitivity is not a positive finite number.
    """
    if epsilon is None and sparsity is None:
        raise ValueError("Either epsilon or sparsity must not be None")
    if epsilon is not None and sparsity is not None:
        raise ValueError("Only one of epsilon or sparsity may be given")

    # All sensitivity math happens on detached CPU copies so the inputs may
    # live on any device and may require grad.
    rows, offsets = layer1
    num_neurons = offsets.shape[0]
    points = torch.cat(
        (rows.detach().cpu(), offsets.detach().cpu().reshape(num_neurons, 1)),
        dim=1)
    num_points = points.shape[0]
    points_norm = points.norm(dim=1)

    # Largest absolute entry of each column of layer2 (one per neuron).
    weights = layer2.detach().abs().max(dim=0)[0].cpu()
    if weights.shape[0] != num_points:
        raise ValueError(
            "layer2 has {} columns but layer1 has {} neurons".format(
                weights.shape[0], num_points))

    sensitivity = weights * torch.abs(activation(upper_bound * points_norm))
    sensitivity = sensitivity.detach().cpu().numpy()
    total_sensitivity = sensitivity.sum()
    if not (np.isfinite(total_sensitivity) and total_sensitivity > 0):
        raise ValueError(
            "Total sensitivity must be positive and finite, got {}".format(
                total_sensitivity))
    prob = sensitivity / total_sensitivity

    # TODO: what is C?
    C = 1
    d = points.shape[1]
    # NOTE: np.log(total_sensitivity) is negative when total_sensitivity < 1,
    # so this bound (and the sample count derived from it) can be <= 0.
    bound = C * total_sensitivity * (d * np.log(total_sensitivity) + np.log(1 / delta))

    if epsilon is not None:
        samples = int(np.ceil(bound / epsilon**2))
        if samples > 1000000:
            # Don't do any sampling, keep the network as is.
            return np.ones(num_points), layer2, samples, epsilon
        idxs = np.random.choice(num_points, size=samples, p=prob)
        implied_epsilon = epsilon
    else:
        # Neurons with zero probability can never be drawn; cap the target so
        # the loop below always terminates.
        coreset_size = min((1 - sparsity) * num_points, np.count_nonzero(prob))
        distinct = set()
        idxs = []
        while len(distinct) < coreset_size:
            i = int(np.random.choice(a=num_points, size=1, p=prob)[0])
            idxs.append(i)
            distinct.add(i)
        samples = len(idxs)
        implied_epsilon = np.sqrt(bound / samples)

    # counts[i] = how many times neuron i was drawn.
    counts = np.bincount(np.asarray(idxs, dtype=np.intp), minlength=num_points)
    picked = counts > 0

    mask = np.zeros(num_points)
    mask[picked] = 1

    # Only touch drawn neurons: undrawn ones may have prob == 0.
    scalars = np.zeros(num_points)
    scalars[picked] = counts[picked] / (prob[picked].astype(np.float64) * samples)

    # Build the scale vector on layer2's device; it is float64, so the result
    # dtype follows torch's type promotion.
    scale = torch.as_tensor(scalars, device=layer2.device)
    return mask, layer2 * scale, samples, implied_epsilon
