'''Local (layerwise) magnitude pruning: every prunable layer is pruned by the same fraction of its remaining weights, so all layers keep the same pruning ratio.'''
import dataclasses
import numpy as np

from foundations import hparams
import models.base
from pruning import base
from pruning.mask import Mask


@dataclasses.dataclass
class PruningHparams(hparams.PruningHparams):
    """Hyperparameters for layerwise (local) magnitude pruning.

    Each field's description lives in a class attribute named ``_<field name>``
    (e.g. ``_pruning_fraction`` describes ``pruning_fraction``).
    """

    pruning_fraction: float = 0.2
    # NOTE(review): kept as plain ``str`` (not Optional[str]); the hparams framework
    # may dispatch on the declared field type, so confirm before changing it.
    pruning_layers_to_ignore: str = None

    # The attributes below are deliberately unannotated so that ``dataclasses``
    # treats them as plain class attributes rather than dataclass fields.
    _name = 'Hyperparameters for Layerwise (Local) Pruning'
    _description = 'Hyperparameters that modify the way pruning occurs.'
    _pruning_fraction = 'The fraction of remaining weights to prune from each prunable layer.'
    # Renamed from ``_layers_to_ignore`` so the name matches the ``pruning_layers_to_ignore`` field.
    _pruning_layers_to_ignore = 'A comma-separated list of additional tensors that should not be pruned.'
    _layers_to_ignore = _pruning_layers_to_ignore  # backward-compatible alias for the old name


class Strategy(base.Strategy):
    """Local magnitude pruning: each prunable tensor loses the same fraction of its remaining weights.

    Unlike a single network-wide threshold, every layer gets its own cut-off, so the
    pruning ratio stays uniform across layers.
    """

    @staticmethod
    def get_pruning_hparams() -> type:
        """Return the hyperparameter dataclass that configures this strategy."""
        return PruningHparams

    @staticmethod
    def _prune_layer(weight: np.ndarray, mask: np.ndarray, fraction: float) -> np.ndarray:
        """Return a new mask for one tensor with its smallest-magnitude unpruned weights removed.

        Args:
            weight: The tensor's trained weights.
            mask: The tensor's current 0/1 mask (same shape as ``weight``).
            fraction: Fraction of the currently unpruned weights (``mask == 1``) to remove.

        Returns:
            An array shaped like ``weight``. It is 0 wherever the weight was already pruned or
            its magnitude is at or below the cut-off, and equals ``mask`` elsewhere.

        ``ceil(fraction * remaining)`` weights are targeted, clamped to ``[0, remaining]``.
        Weights tied with the cut-off magnitude are pruned together, so ties may remove slightly more.
        """
        remaining = np.abs(weight[mask == 1])
        number_to_prune = int(np.ceil(fraction * remaining.size))
        number_to_prune = max(0, min(number_to_prune, remaining.size))

        if number_to_prune == 0:
            # Nothing to prune (e.g. an already fully pruned layer): keep the mask as-is.
            keep = np.ones(weight.shape, dtype=bool)
        else:
            # The (number_to_prune)-th smallest magnitude; removing everything <= it
            # prunes exactly number_to_prune weights when magnitudes are distinct.
            threshold = np.sort(remaining)[number_to_prune - 1]
            keep = np.abs(weight) > threshold

        return np.where(keep, mask, np.zeros_like(weight))

    @staticmethod
    def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None):
        """Prune each prunable tensor of ``trained_model`` by ``pruning_hparams.pruning_fraction``.

        Args:
            pruning_hparams: Supplies ``pruning_fraction`` and the optional comma-separated
                ``pruning_layers_to_ignore``.
            trained_model: The model whose weight magnitudes decide what is pruned.
            current_mask: The existing mask. When None, pruning starts from
                ``Mask.ones_like(trained_model)``.

        Returns:
            A ``Mask`` with updated entries for the pruned tensors. Every other key from the
            current mask is copied through unchanged.
        """
        current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy()

        # Determine which layers can be pruned; tolerate whitespace around the comma-separated names.
        prunable_tensors = set(trained_model.prunable_layer_names)
        if pruning_hparams.pruning_layers_to_ignore:
            prunable_tensors -= {name.strip() for name in pruning_hparams.pruning_layers_to_ignore.split(',')}

        # Snapshot the prunable weights as numpy arrays.
        weights = {k: v.clone().cpu().detach().numpy()
                   for k, v in trained_model.state_dict().items()
                   if k in prunable_tensors}

        # Each tensor is paired with its own mask by name (not by position), so ignored
        # layers or differing key orders cannot misalign per-layer prune counts.
        new_mask = Mask({k: Strategy._prune_layer(v, current_mask[k], pruning_hparams.pruning_fraction)
                         for k, v in weights.items()})

        # Tensors that were not pruned this round keep their existing mask.
        for k in current_mask:
            if k not in new_mask:
                new_mask[k] = current_mask[k]

        return new_mask
