from typing import Dict, Optional
import torch.optim as optim
from antgine.callback import Callback


class LRSchedulerCallback(Callback):
    """
        Callback allowing for learning rate change during training.

        Two modes are supported:

        * ``epoch_begin=True`` (default): ``scheduler.step()`` is called with no
          arguments at the beginning of every epoch.
        * ``epoch_begin=False``: ``scheduler.step(metrics[metric])`` is called when
          the test phase of an epoch ends, passing the value of the named metric
          (the calling convention used by e.g. ``ReduceLROnPlateau``).
    """
    def __init__(self, scheduler: optim.lr_scheduler._LRScheduler,
                 epoch_begin: bool = True, metric: Optional[str] = None):
        """
        :param torch.optim.lr_scheduler._LRScheduler scheduler: Learning rate scheduler.
        :param bool epoch_begin: If True, ``scheduler.step()`` is called at the beginning
            of each epoch. If False, ``scheduler.step(value)`` is called at the end of
            each epoch's test phase instead, where ``value`` is ``metrics[metric]``.
        :param Optional[str] metric: Name of the metric whose value is sent to the
            scheduler. Required when ``epoch_begin`` is False; ignored otherwise.
        :raises ValueError: If ``epoch_begin`` is False and ``metric`` is None.
        """
        # A metric name is only needed in metric-driven mode. The previous
        # assertion had this condition inverted: it rejected the default
        # arguments and accepted the one combination that later fails with
        # ``metrics[None]``. An explicit raise also survives ``python -O``.
        if not epoch_begin and metric is None:
            raise ValueError("metric must be provided when epoch_begin is False")
        super().__init__()
        self._scheduler = scheduler
        self._epoch_begin = epoch_begin
        self._metric = metric

    def on_epoch_begin(self, epoch: int):
        """
            Scheduler step when epoch begins if epoch_begin=True.
            See :meth:`antgine.callback.Callback.on_epoch_begin`
        """
        if self._epoch_begin:
            self._scheduler.step()

    def on_epoch_test_end(self, epoch: int, metrics: Dict[str, float]):
        """
            Scheduler step with the configured metric value when the epoch's
            test phase ends, if epoch_begin=False.
            See :meth:`antgine.callback.Callback.on_epoch_test_end`

            :param int epoch: Current epoch index (unused here).
            :param Dict[str, float] metrics: Metric values for this epoch; must
                contain the key given as ``metric`` at construction time.
            :raises KeyError: If the configured metric is missing from ``metrics``.
        """
        if not self._epoch_begin:
            self._scheduler.step(metrics[self._metric])
