import torch
import copy
from tqdm import tqdm
from .helpers import ModelCheckpoint
from .helpers import ok_layer, compute_threshold_by_layer, update_masks, apply_threshold_to_model
from .onebase import Trainer as OneBaseTrainer


class Trainer(OneBaseTrainer):
    """
    A Trainer that enables oneshot training followed by a rewind

    Each pruning iteration trains the model twice:

    1. "adaptive" phase: starting from the same initial weights every time, the
       model is trained while a running prune fraction is ramped up (see
       ``on_iter_end``) and applied to the model in ``on_model_change``
       (presumably invoked by the base ``training_loop``).
    2. "rewind" phase: masks are computed from the adaptive run's weights, the
       weights are rewound to the checkpoint's rewind state, multiplied by the
       masks, and retrained with gradient hooks that multiply every gradient
       by its mask.

    Parameters
    ----------

    train_cfg : TrainConfig
        Training configuration with training regime information
    model : nn.Module
        Initialized Model that should be trained
    loaders : tuple
        Contains (train, val, test) data loaders

    prune_perc : float
        Fraction of the remaining weights to prune at each iteration. The
        cumulative target at pruning iteration ``i`` is computed as
        ``1 - (1 - self.prune_perc_per_iter) ** i`` (``prune_perc_per_iter`` is
        set by the base trainer, presumably from this value), so it grows
        geometrically and is 0 at ``i == 0``.
    use_best_model : bool
        If True, masks are computed from the checkpoint's best weights
        (early stopping); otherwise from its last weights.
    start_prune_iter : int
        Index of the first pruning iteration; iterations
        ``start_prune_iter .. start_prune_iter + n_prunes - 1`` are run.
    n_prunes : int
        Number of pruning iterations
    rewind_epoch : int
        Epoch to which the weights should rewind
    pruning_method : str
        One of {l1, l1_layer, lamp, l1_normalized}

    adaptive_start : int
        At which epoch to start the adaptive pruning
    adaptive_end : int
        At which epoch to stop the adaptive pruning. A negative value means
        ``train_cfg.n_epochs``.
    """
    def __init__(self, train_cfg, model, loaders, prune_perc=0.2, use_best_model=True, start_prune_iter=0,
                 n_prunes=20, rewind_epoch=4, pruning_method="l1", adaptive_start=0, adaptive_end=-1):
        super().__init__(train_cfg, model, loaders, prune_perc=prune_perc, iterative=False,
                         use_best_model=use_best_model, n_prunes=n_prunes, rewind_epoch=rewind_epoch,
                         pruning_method=pruning_method)
        self.adaptive_start = adaptive_start
        # A negative adaptive_end means "ramp until the end of training".
        self.adaptive_end = adaptive_end if adaptive_end >= 0 else train_cfg.n_epochs
        self.start_prune_iter = start_prune_iter

    def start(self):
        """Run every pruning iteration: adaptive training, then rewind training.

        Log directories are ``<logdir>/<iteration>--<percent>/{adaptive,rewind}``
        and must not already exist (``mkdir`` is called without ``exist_ok``).
        """
        self.model.to(self.cfg.device)
        # Every adaptive phase starts from this same initialization.
        init_state = copy.deepcopy(self.model.state_dict())
        iterations = range(self.start_prune_iter, self.start_prune_iter + self.n_prunes)
        for iteration in tqdm(iterations, desc="Pruning iteration", leave=False):
            # Cumulative prune fraction targeted by this iteration (0 at iteration 0).
            prune_perc = 1 - (1 - self.prune_perc_per_iter) ** iteration
            iter_dir = self.cfg.logdir / f"{iteration:02d}--{prune_perc*100:07.3f}"

            # -- First training with adaptive pruning during training
            self.adaptive_training = True
            self._final_prune_perc = prune_perc
            self._running_prune_perc = 0

            logdir = iter_dir / "adaptive"
            logdir.mkdir(parents=True)

            with torch.no_grad():  # Same init for everyone
                self.model.load_state_dict(init_state)
            checkpoint = ModelCheckpoint(
                filepath=logdir / "model.pt",
                model=self.model,
                period=1,
                rewind=self.rewind_epoch,
                # Presumably: only save from adaptive_end on, i.e. once the
                # running fraction has reached the target (see on_iter_end).
                save_from=self.adaptive_end
            )
            self.training_loop(logdir, checkpoint)

            rewind_state = copy.deepcopy(checkpoint.get_rewind())

            # -- Rewind floating point training
            self._reset_model()
            self.adaptive_training = False

            ref_weights = (
                checkpoint.best_weights if self.use_best_model else checkpoint.last_weights
            )
            masks, hook_handles = self.add_mask_hooks()
            try:
                # Masks are updated in place, so the registered hooks see the result.
                update_masks(ref_weights, masks, prune_perc, self.pruning_method)

                with torch.no_grad():
                    self.model.load_state_dict(rewind_state)
                    for n, p in self.model.named_parameters():
                        if not ok_layer(n):
                            continue
                        p[:] = masks[n] * p

                logdir = iter_dir / "rewind"
                logdir.mkdir(parents=True)
                checkpoint = ModelCheckpoint(
                    filepath=logdir / "model.pt",
                    model=self.model,
                    period=1,
                    rewind=self.rewind_epoch
                )
                self.training_loop(logdir, checkpoint)
            finally:
                # Always detach the gradient-masking hooks, even when mask
                # computation or training raises, so they cannot leak into
                # later iterations or later use of the model.
                for hook_handle in hook_handles:
                    hook_handle.remove()

    @torch.no_grad()
    def add_mask_hooks(self):
        """Register gradient hooks that multiply each prunable gradient by a mask.

        Returns
        -------
        masks : dict
            Parameter name -> all-ones tensor shaped like the parameter, on
            ``self.cfg.device``. Callers must modify these tensors in place,
            because the hooks keep references to these exact objects.
        hook_handles : list
            Handles returned by ``register_hook``; call ``.remove()`` on each
            to detach the hooks.
        """
        masks = {}
        hook_handles = []
        for n, p in self.model.named_parameters():
            if not ok_layer(n):
                continue
            mask = torch.ones_like(p, device=self.cfg.device)

            # Bind `mask` as a default argument to avoid late-binding of the loop variable.
            hook_handles.append(
                p.register_hook(lambda grad, mask=mask: grad*mask)
            )
            masks[n] = mask
        return masks, hook_handles

    def on_iter_end(self, epoch, iteration):
        """Update the running prune fraction after a training iteration.

        Only active during the adaptive phase. Before ``adaptive_start`` the
        fraction is 0; from ``adaptive_end`` on it equals the iteration's
        target. In between it grows linearly with the global step and reaches
        the target after the last step of epoch ``adaptive_end - 1``.

        Parameters
        ----------
        epoch : int
            Current epoch index.
        iteration : int
            Batch index within the epoch (assumed 0-based, given the ``+ 1``
            below; TODO confirm against the base training loop).
        """
        if not self.adaptive_training:
            return
        if epoch < self.adaptive_start:
            self._running_prune_perc = 0
        elif epoch >= self.adaptive_end:
            self._running_prune_perc = self._final_prune_perc
        else:
            # Reached only when adaptive_start <= epoch < adaptive_end, so the
            # ramp length below is positive for a non-empty train loader.
            iters_per_epoch = len(self.data["train"])
            n_iters = iters_per_epoch * (self.adaptive_end-self.adaptive_start)
            act_iter = (epoch-self.adaptive_start) * iters_per_epoch + iteration + 1
            self._running_prune_perc = self._final_prune_perc * act_iter/n_iters

    @torch.no_grad()
    def on_model_change(self):
        """During the adaptive phase, apply the current running prune fraction.

        Computes per-layer thresholds from the current ``state_dict`` with
        ``self.pruning_method`` and applies them to the model. Does nothing
        outside the adaptive phase or while the running fraction is 0.
        """
        if not self.adaptive_training:
            return

        if self._running_prune_perc == 0:
            # FIXME (original author): should self._reset_model() be called here?
            return
        threshold_per_layer = compute_threshold_by_layer(
            self.model.state_dict(), self._running_prune_perc, self.pruning_method
        )
        apply_threshold_to_model(self.model, threshold_per_layer, self._running_prune_perc)

    def _reset_model(self):
        """Call ``set_prune(-1, 0)`` on every submodule that has a ``th_prune`` attribute.

        This presumably disables threshold pruning in those modules (TODO
        confirm against the ``set_prune`` implementation).
        """
        self.model.apply(
            lambda x: (x.set_prune(-1, 0) if hasattr(x, "th_prune") else None)
        )
