from mlwiz.training.callback.optimizer import Optimizer
from mlwiz.training.callback.scheduler import Scheduler
from mlwiz.training.event.state import State
from torch.optim.lr_scheduler import ConstantLR


class LinearAlphaScheduler(Scheduler):
    """
    Linearly anneals the model's ``alpha_prior_scale`` across training epochs.

    This callback does not schedule the learning rate. It only holds a dummy
    :class:`torch.optim.lr_scheduler.ConstantLR` (factor 1.0), so it can be
    plugged in wherever a :class:`Scheduler` is expected. At the start of
    every training epoch it does the following, if the model has an
    ``alpha_prior_scale`` (i.e. it is not ``None``):

    * ``epoch < start_epoch``: disables the alpha prior
      (``apply_alpha_prior(False)``);
    * ``start_epoch <= epoch < max_epochs``: subtracts a fixed step from
      ``alpha_prior_scale`` and enables the alpha prior;
    * ``epoch >= max_epochs``: leaves the scale untouched and enables the
      alpha prior.

    The step is computed once, at the first epoch ``< max_epochs`` seen by
    this callback, as
    ``alpha_prior_scale * annealing_amount / (max_epochs - start_epoch + 1)``.

    NOTE(review): ``_step_size`` lives only on this instance. If training is
    resumed from a checkpoint and this callback is re-created, the step is
    recomputed from the already-annealed scale. Confirm this is intended.

    Args:
        scheduler_class_name (str): unused; kept so the signature matches
            the other :class:`Scheduler` callbacks.
        optimizer (:class:`Optimizer`): passed to the dummy ``ConstantLR``.
        kwargs: must contain ``annealing_amount`` (must lie in ``[0, 1)``),
            ``start_epoch`` and ``max_epochs``.

    Raises:
        KeyError: if a required keyword argument is missing.
        ValueError: if ``annealing_amount`` is not in ``[0, 1)``.
    """

    def __init__(
        self, scheduler_class_name: str, optimizer: Optimizer, **kwargs: dict
    ):
        # super().__init__ is deliberately not called. kwargs carry
        # annealing settings, not LR-scheduler arguments. Presumably the base
        # constructor would forward them to an LR scheduler (TODO confirm).
        # ConstantLR scales the learning rate by ``factor``, so factor must
        # stay 1.0 to leave the optimizer's learning rate untouched.
        self.scheduler = ConstantLR(optimizer, factor=1.0)  # dummy, not used

        # Linear decrement, computed lazily on the first eligible epoch.
        self._step_size = None
        self.annealing_amount: float = kwargs["annealing_amount"]
        self.start_epoch: int = kwargs["start_epoch"]
        self.max_epochs: float = kwargs["max_epochs"]
        # Explicit check instead of `assert`, which is removed under -O.
        # The upper bound (< 1) keeps the total decrement below the scale
        # observed when the step is computed.
        if not 0.0 <= self.annealing_amount < 1.0:
            raise ValueError(
                "annealing_amount must be in [0, 1), "
                f"got {self.annealing_amount}"
            )

    def on_training_epoch_start(self, state: State):
        """
        Updates ``alpha_prior_scale`` and toggles the alpha prior for this
        epoch.

        Args:
            state (:class:`~training.event.state.State`):
                object holding training information; ``state.model`` and
                ``state.epoch`` are read.
        """
        super().on_training_epoch_start(state)

        model = state.model

        # Models without an alpha prior have nothing to anneal.
        if model.alpha_prior_scale is None:
            return

        # Annealing is over: keep the current scale and keep the prior on.
        if state.epoch >= self.max_epochs:
            model.apply_alpha_prior(True)
            return

        if self._step_size is None:
            # Computed once so that every decrement is the same (linear).
            self._step_size = (
                model.alpha_prior_scale.data * self.annealing_amount
            ) / (self.max_epochs - self.start_epoch + 1)

        if state.epoch >= self.start_epoch:
            model.alpha_prior_scale.data = (
                model.alpha_prior_scale.data - self._step_size
            )
            model.apply_alpha_prior(True)
        else:
            # Before start_epoch the alpha prior is switched off.
            model.apply_alpha_prior(False)

    def on_training_epoch_end(self, state: State):
        """
        Intentionally does nothing, so this hook does not step the dummy
        ``ConstantLR`` created in ``__init__``.

        Args:
            state (:class:`~training.event.state.State`):
                object holding training information (unused)
        """
        pass


class CifarScheduler(Scheduler):
    """
    Learning-rate scheduler whose milestones sit at 50% and 75% of training.

    ``max_epochs`` is removed from ``kwargs`` and replaced by
    ``milestones=[max_epochs // 2, 3 * max_epochs // 4]``. Everything is then
    forwarded to :class:`Scheduler`. This presumes that
    ``scheduler_class_name`` names a scheduler accepting a ``milestones``
    argument (e.g. ``torch.optim.lr_scheduler.MultiStepLR``). TODO: confirm
    against the configs.

    Args:
        scheduler_class_name (str): scheduler to build, forwarded to
            :class:`Scheduler`.
        optimizer (:class:`Optimizer`): optimizer to schedule.
        kwargs: must contain ``max_epochs``; the remaining entries are
            forwarded unchanged to :class:`Scheduler`.

    Raises:
        KeyError: if ``max_epochs`` is not provided.
    """

    def __init__(
        self, scheduler_class_name: str, optimizer: Optimizer, **kwargs: dict
    ):
        max_epochs = kwargs.pop("max_epochs")
        # Both milestones use floor division, so they are derived the same
        # way (the previous mix of `/` and `//` gave equal results only for
        # non-negative values).
        kwargs["milestones"] = [
            int(max_epochs // 2),
            int(3 * max_epochs // 4),
        ]

        super().__init__(scheduler_class_name, optimizer, **kwargs)
        self.optimizer = optimizer

    def on_training_epoch_end(self, state: State):
        """
        Performs a scheduler's step at the end of the training epoch.

        Args:
            state (:class:`~training.event.state.State`):
                object holding training information
        """
        self.scheduler.step()
