import torch
import torch.nn.utils.prune as prune

from strategies import scratchStrategies
from utilities.utilities import LAMPUnstructured, GradientUnstructured, UndecayedUnstructured
from optimizers import losses

#### Base Class
class IMP(scratchStrategies.Dense):
    """Iterative Magnitude Pruning (IMP) base strategy.

    Once regular training is over, the model is pruned in ``n_phases`` rounds.
    Every round removes the same fraction of the weights that are still
    remaining and is followed by a finetuning run whose length is taken from
    the per-phase epoch schedule built in ``__init__``.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

        cfg = self.run_config
        self.n_phases = cfg['n_phases']
        self.n_epochs_per_phase = cfg['n_epochs_per_phase']
        self.n_epochs_to_split = cfg['n_epochs_to_split']

        total = self.n_epochs_to_split
        if total is None:
            # No total budget: every phase uses the explicitly configured epoch count.
            self.n_epochs_per_phase = dict.fromkeys(range(1, self.n_phases + 1),
                                                    self.n_epochs_per_phase)
        else:
            # A total budget is being split, so a per-phase budget must not be set too.
            assert self.n_epochs_per_phase in [None, 0]
            share, leftover = divmod(total, self.n_phases)
            schedule = dict.fromkeys(range(1, self.n_phases + 1), share)
            if leftover != 0:
                # Uneven split: the final phase absorbs the epochs the others don't get.
                schedule[self.n_phases] = total - (self.n_phases - 1) * share
            self.n_epochs_per_phase = schedule

    def at_train_end(self, **kwargs):
        """Run the prune -> (optional callback) -> finetune loop for every phase."""
        # Pruning a fraction r of the surviving weights in each of n_phases rounds
        # keeps (1 - r) ** n_phases of them; r = 1 - (1 - goal) ** (1 / n_phases)
        # therefore lands on goal_sparsity after the last round.
        per_round = 1 - (1 - self.goal_sparsity) ** (1. / self.n_phases)
        for phase in range(1, self.n_phases + 1):
            self.pruning_step(pruning_sparsity=per_round)
            self.current_sparsity = 1 - (1 - per_round) ** phase

            is_last_phase = phase == self.n_phases
            # run_config is only consulted when this is not the last phase (short-circuit).
            if is_last_phase or self.run_config['retrain_adaptive_in_every_cycle'] is True:
                self.callbacks['after_pruning_callback']()

            self.finetuning_step(pruning_sparsity=per_round, phase=phase)

    def finetuning_step(self, pruning_sparsity, phase):
        """Hand off to the finetuning callback with the epoch budget of ``phase``."""
        finetune = self.callbacks['finetuning_callback']
        finetune(pruning_sparsity=pruning_sparsity,
                 n_epochs_finetune=self.n_epochs_per_phase[phase],
                 phase=phase)

    def get_pruning_method(self):
        """Return the weight-selection method named by ``run_config['pruning_selector']``.

        Returns either a ``torch.nn.utils.prune`` class or a factory that takes
        ``amount`` and builds a project-specific unstructured pruning method.

        Raises:
            NotImplementedError: if the selector name is not recognised.
        """
        selector = self.run_config['pruning_selector']

        if selector in ['global', 'uniform']:
            # 'uniform' does not need anything special here; selection is always L1.
            return prune.L1Unstructured
        if selector == 'random':
            return prune.RandomUnstructured
        if selector == 'LAMP':
            return lambda amount: LAMPUnstructured(
                parameters_to_prune=self.parameters_to_prune,
                amount=amount,
            )
        if selector in ['gradient', 'gradient_uniform']:
            # Gradient estimates are gathered now, once, and captured by the factory.
            gradients = self.callbacks['gradient_estimation_callback']()
            return lambda amount: GradientUnstructured(
                parameters_to_prune=self.parameters_to_prune,
                amount=amount,
                gradients=gradients,
                uniform=(self.run_config['pruning_selector'] == 'gradient_uniform'),
            )
        if selector in ['undecayed', 'undecayed_uniform']:
            gradients = self.callbacks['gradient_estimation_callback']()
            return lambda amount: UndecayedUnstructured(
                parameters_to_prune=self.parameters_to_prune,
                amount=amount,
                gradients=gradients,
                uniform=(self.run_config['pruning_selector'] == 'undecayed_uniform'),
                wd=self.run_config['weight_decay'],
            )

        raise NotImplementedError

    def final(self):
        """Run the parent's final hook, then the final-log callback."""
        super().final()
        log_final = self.callbacks['final_log_callback']
        log_final()


class SoftRetraining(IMP):
    """Base class that smooths the pruning process out.

    Regular retraining is run with some method that forces weights towards zero
    (provided by the co-base strategy of a subclass, e.g. GSM or LC below); only
    afterwards is the network pruned to ``goal_sparsity`` in a single step.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The smoothing schedule is only defined for one retrain/prune cycle.
        assert self.n_phases == 1, "Works only with a single phase for now"

        # True only while the smoothing (retraining) run is in progress.
        # NOTE(review): presumably read by the co-base strategy during training — confirm.
        self.is_in_smoothing_phase = False

    def at_train_end(self, **kwargs):
        """Run the smoothing retraining phase, then prune once to ``goal_sparsity``."""
        # Currently works with a single phase only (enforced in __init__).
        self.is_in_smoothing_phase = True
        try:
            self.finetuning_step(pruning_sparsity=self.goal_sparsity, phase=1)
        finally:
            # Always clear the flag, even if retraining raises, so the strategy is
            # never left claiming to be mid-smoothing.
            self.is_in_smoothing_phase = False

        # Make current pruning permanent, if existing
        self.make_pruning_permant()

        # Do the final pruning
        self.pruning_step(pruning_sparsity=self.goal_sparsity)
        # Mirror IMP.at_train_end, which records the sparsity reached after a
        # pruning step before invoking after_pruning_callback; with a single
        # phase IMP would record goal_sparsity here.
        self.current_sparsity = self.goal_sparsity

        self.callbacks['after_pruning_callback']()


#### Pruning stable variants
class SoftGSM(SoftRetraining, scratchStrategies.GSM):
    """Soft retraining driven by GSM.

    Combines the single-phase ``SoftRetraining`` schedule with
    ``scratchStrategies.GSM``. ``SoftRetraining`` is the first base, so its
    ``at_train_end`` is the one that runs.
    """

    def __init__(self, **kwargs) -> None:
        # Cooperative initialisation: all keyword arguments travel along the MRO.
        super().__init__(**kwargs)

class SoftLC(SoftRetraining, scratchStrategies.LC):
    """Soft retraining driven by LC.

    Combines the single-phase ``SoftRetraining`` schedule with
    ``scratchStrategies.LC``. ``SoftRetraining`` is the first base, so its
    ``at_train_end`` is the one that runs.
    """

    def __init__(self, **kwargs) -> None:
        # Cooperative initialisation: all keyword arguments travel along the MRO.
        super().__init__(**kwargs)