import logging
import os
from typing import Union, Optional, List

import torch
import torch.nn as nn

from . import Callback
from torch.optim.lr_scheduler import _LRScheduler


class LRScheduler(Callback):
    """Callback that wraps a PyTorch learning-rate scheduler.

    Each time the ``on_train_epoch_end`` hook fires, the wrapped scheduler's
    ``step()`` is called with no arguments. (Presumably the ``Callback`` base /
    trainer fires this hook once per training epoch — confirm against caller.)

    Args:
        lr_scheduler: The scheduler to advance. It is stored unchanged on
            ``self.lr_scheduler``.
    """

    def __init__(self, lr_scheduler: _LRScheduler):
        super().__init__()
        # Kept as a public attribute so outside code can inspect the scheduler.
        self.lr_scheduler = lr_scheduler

    def on_train_epoch_end(self, log):
        """Advance the wrapped scheduler by one step.

        Args:
            log: Supplied by the callback interface. It is not read here.
                NOTE(review): schedulers such as ``ReduceLROnPlateau`` need a
                metric passed to ``step()``; this hook never passes one.
        """
        scheduler = self.lr_scheduler
        scheduler.step()
