from collections import OrderedDict
from math import ceil, log
from typing import List, Optional

import torch
import torch.nn.utils.prune as prune

from utilities import LAMPUnstructured
import metrics.metrics
from utilities import FixedLR
from strategies.strategies import Dense


class IMP(Dense):
    """Iterative pruning strategy (IMP): repeated prune / fine-tune phases run at the end of training.

    Every phase passes the same ``pruning_sparsity`` to the inherited ``pruning_step``; the
    cumulative sparsity tracked after phase ``k`` is ``1 - (1 - pruning_sparsity) ** k``.
    """

    def __init__(self, desired_sparsity: float, n_phases: Optional[int] = 1, n_epochs_per_phase: int = 1) -> None:
        """Derive the per-phase pruning rate (and, if needed, the number of phases).

        Args:
            desired_sparsity: Overall target sparsity, a fraction in ``[0, 1]``.
            n_phases: Number of prune / fine-tune phases. If ``None``, the per-phase rate is
                fixed at 20% of the remaining weights (Renda's approach) and the number of
                phases is computed so that the cumulative sparsity reaches at least
                ``desired_sparsity``.
            n_epochs_per_phase: Number of fine-tuning epochs requested in every phase.

        Raises:
            ValueError: If ``desired_sparsity`` is outside ``[0, 1]``, if ``n_phases`` is given
                but smaller than 1, or if ``desired_sparsity`` is 1 while ``n_phases`` is
                ``None`` (no finite number of 20% phases reaches full sparsity).
        """
        super().__init__()
        if not 0 <= desired_sparsity <= 1:
            # Values above 1 would make the fractional power below complex-valued.
            raise ValueError(f"desired_sparsity must lie in [0, 1], got {desired_sparsity}.")
        if n_phases is not None and n_phases < 1:
            # Zero would divide by zero below; negative values would yield a negative pruning rate.
            raise ValueError(f"n_phases must be None or at least 1, got {n_phases}.")

        self.desired_sparsity = desired_sparsity
        self.n_phases = n_phases  # If None, compute this manually using Renda's approach of pruning 20% per phase
        self.n_epochs_per_phase = n_epochs_per_phase

        # Sparsity factor on remaining weights after each round, yields desired_sparsity after all rounds
        if self.n_phases is not None:
            self.pruning_sparsity = 1 - (1 - self.desired_sparsity) ** (1. / self.n_phases)
        else:
            if self.desired_sparsity == 1:
                raise ValueError("desired_sparsity must be below 1 when n_phases is None.")
            self.pruning_sparsity = 0.2
            # Number of rounds k needed so that (1 - pruning_sparsity) ** k <= 1 - desired_sparsity.
            self.n_phases = ceil(log(1 - self.desired_sparsity, 1 - self.pruning_sparsity))

    def at_train_end(self, model, finetuning_callback, restore_callback, save_model_callback, after_pruning_callback, opt):
        """Run all prune / fine-tune phases once regular training has finished.

        Args:
            model: Model handed to ``pruning_step`` in every phase.
            finetuning_callback: Forwarded to ``finetuning_step`` in every phase.
            restore_callback: Called once, before the first phase.
            save_model_callback: Called after each phase with a ``model_type`` label that
                embeds the cumulative sparsity.
            after_pruning_callback: Called after each pruning step with the cumulative sparsity.
            opt: Not used by this implementation.
        """
        restore_callback()  # Restore to checkpoint model
        prune_per_phase = self.pruning_sparsity
        for phase in range(1, self.n_phases + 1):
            self.pruning_step(model, pruning_sparsity=prune_per_phase)
            # Cumulative sparsity after `phase` rounds of pruning the same fraction of what remains.
            self.current_sparsity = 1 - (1 - prune_per_phase) ** phase
            after_pruning_callback(desired_sparsity=self.current_sparsity)
            self.finetuning_step(desired_sparsity=self.current_sparsity, finetuning_callback=finetuning_callback, phase=phase)
            save_model_callback(model_type=f"{self.current_sparsity}-sparse_final")  # removing of pruning hooks happens in restore_callback

    def finetuning_step(self, desired_sparsity, finetuning_callback, phase):
        """Invoke ``finetuning_callback`` for one phase.

        Args:
            desired_sparsity: Sparsity value forwarded to the callback.
            finetuning_callback: Callable receiving ``desired_sparsity``, ``n_epochs_finetune``,
                ``phase`` and ``n_phases_total`` as keyword arguments.
            phase: 1-based index of the current phase.
        """
        finetuning_callback(desired_sparsity=desired_sparsity, n_epochs_finetune=self.n_epochs_per_phase,
                            phase=phase, n_phases_total=self.n_phases)

    def get_pruning_method(self):
        """Return the pruning method class used by this strategy (``prune.L1Unstructured``)."""
        return prune.L1Unstructured

    def final(self, model, final_log_callback):
        """Run the parent class's ``final`` step, then call ``final_log_callback``."""
        super().final(model=model, final_log_callback=final_log_callback)
        final_log_callback()
