import torch
import copy
from tqdm import tqdm
from .helpers import ModelCheckpoint, ok_layer, update_masks
from .normal import Trainer as NormalTrainer


class Trainer(NormalTrainer):
    """
    A Trainer that repeatedly trains, prunes and rewinds a model, either
    iteratively or in a one-shot fashion based on one original full training.

    Parameters
    ----------
    train_cfg : TrainConfig
        Training configuration with training regime information
    model : nn.Module
        Initialized Model that should be trained
    loaders : tuple
        Contains (train, val, test) data loaders

    prune_perc : float
        Fraction of the *remaining* weights to prune at each iteration.
        After ``k`` pruning steps the cumulative pruned fraction passed to
        ``update_masks`` is ``1 - (1 - prune_perc) ** k``.
    n_prunes : int
        Number of train/prune iterations
    pruning_method : str
        One of {l1, l1_layer, l1_normalized, lamp}; forwarded unchanged to
        ``update_masks``.
    rewind_epoch : int
        Epoch to which the weights should rewind. Only the checkpoint of the
        first iteration is asked to record a rewind state.
    iterative : bool
        If True the weights of the most recent training run are used as the
        reference for pruning.
        If False the weights of the first training run are always used.
    use_best_model : bool
        If True the checkpoint's ``best_weights`` are used as pruning
        reference, otherwise its ``last_weights``.

    """
    def __init__(self, train_cfg, model, loaders, prune_perc=0.2, iterative=False, use_best_model=True,
                 n_prunes=20, rewind_epoch=4, pruning_method="l1"):
        super().__init__(train_cfg, model, loaders)

        self.prune_perc_per_iter = prune_perc
        self.n_prunes = n_prunes
        self.iterative = iterative
        self.use_best_model = use_best_model
        self.rewind_epoch = rewind_epoch
        self.pruning_method = pruning_method

    def start(self):
        """
        Run ``n_prunes`` rounds of: rewind weights -> apply masks -> train -> prune.

        Each round logs into its own sub-directory of ``self.cfg.logdir``
        named ``"<iteration>--<cumulative prune percentage>"``. The directory
        is created with a plain ``mkdir()``, so (assuming ``logdir`` is a
        ``pathlib.Path``) an already existing directory raises
        ``FileExistsError``.

        The gradient hooks installed for masking are removed again when this
        method finishes, whether it returns normally or raises.
        """
        masks = {}
        hook_handles = []
        try:
            # One all-ones mask per prunable parameter, plus a gradient hook
            # that multiplies the incoming gradient by that mask.
            for n, p in self.model.named_parameters():
                if not ok_layer(n):
                    continue
                mask = torch.ones_like(p, device=self.cfg.device)

                # ``mask=mask`` binds the current tensor as a default argument
                # so every hook keeps its own mask (no late-binding closure).
                hook_handles.append(
                    p.register_hook(lambda grad, mask=mask: grad*mask)
                )
                masks[n] = mask
            self.model.to(self.cfg.device)

            # Repeated training-loop
            prune_perc = 0
            checkpoint = None
            reinit_state = None
            for iteration in tqdm(range(self.n_prunes), desc="Pruning iteration", leave=False):
                # Go back to rewind state
                with torch.no_grad():
                    # The rewind state only becomes available once the first
                    # training run has produced a checkpoint; it is copied
                    # once and reused for every later iteration.
                    if reinit_state is None and checkpoint is not None:
                        reinit_state = copy.deepcopy(checkpoint.get_rewind())
                    if reinit_state is not None:
                        self.model.load_state_dict(reinit_state)

                    # Zero out the pruned weights of the (rewound) model.
                    for n, p in self.model.named_parameters():
                        if not ok_layer(n):
                            continue
                        p[:] = masks[n]*p

                # Apply a full training loop
                logdir = self.cfg.logdir / f"{iteration:02d}--{prune_perc*100:07.3f}"
                logdir.mkdir()
                new_checkpoint = ModelCheckpoint(
                    filepath=logdir / "model.pt",
                    model=self.model,
                    period=1,
                    # Only the first run records a rewind state.
                    rewind=self.rewind_epoch if iteration == 0 else -1
                )
                self.training_loop(logdir, new_checkpoint)

                # The first run always provides the reference checkpoint;
                # later runs replace it only in iterative mode.
                if checkpoint is None or self.iterative:
                    checkpoint = new_checkpoint

                # Prune weights: add this iteration's share of the remaining
                # weights to the cumulative pruned fraction.
                prune_perc += self.prune_perc_per_iter * (1.-self.prune_perc_per_iter)**iteration
                ref_weights = checkpoint.best_weights if self.use_best_model else checkpoint.last_weights
                # NOTE(review): the return value is ignored, so update_masks
                # presumably modifies the mask tensors in place (the gradient
                # hooks hold references to those same tensors) - confirm.
                update_masks(ref_weights, masks, prune_perc, self.pruning_method)
        finally:
            # Detach the masking hooks so they do not outlive this run; left
            # in place they would keep masking gradients with stale masks on
            # any later use of the model (including a second ``start()``).
            for handle in hook_handles:
                handle.remove()
