import torch
from timm.scheduler.cosine_lr import CosineLRScheduler
from omegaconf import DictConfig
from typing import List

class CosineLRSchedulerWrapper(CosineLRScheduler):
    """Cosine LR schedule measured in optimizer steps, with warmup given in epochs.

    Adapter over timm's ``CosineLRScheduler``: the whole run
    (``total_training_opt_steps`` optimizer steps) is used as the cosine
    cycle length (``t_initial``), and the warmup length is converted from
    epochs to optimizer steps using ``trainer.max_epochs``.

    Args:
        optimizer: Optimizer whose param-group learning rates are scheduled.
        total_training_opt_steps: Total number of optimizer steps in the run;
            passed to timm as ``t_initial``.
        trainer: Config node; must expose a positive ``max_epochs``.
        warmup_epochs: Warmup length in epochs; converted to steps as
            ``warmup_epochs * (total_training_opt_steps // max_epochs)``.
        min_lr: Lower bound of the cosine curve (timm ``lr_min``).
        warmup_lr_init: Initial learning rate at the start of warmup.
        t_in_epochs: Forwarded to timm. Since ``t_initial`` and ``warmup_t``
            are computed here in optimizer steps, the default ``False`` is
            the consistent choice (in timm this means the schedule is driven
            by ``step_update(num_updates)`` rather than ``step(epoch)``).

    Raises:
        ValueError: If ``trainer.max_epochs`` or ``total_training_opt_steps``
            is not positive.
    """

    def __init__(self, optimizer: torch.optim.Optimizer, total_training_opt_steps: int, trainer: DictConfig, warmup_epochs: int, min_lr: float, warmup_lr_init: float, t_in_epochs: bool = False):
        max_epochs = trainer.max_epochs
        # None / 0 / -1 ("train indefinitely") would otherwise raise a
        # TypeError or ZeroDivisionError, or produce a negative warmup length.
        if not max_epochs or max_epochs <= 0:
            raise ValueError(
                f"trainer.max_epochs must be a positive integer, got {max_epochs!r}"
            )
        # A zero-length cosine cycle cannot be scheduled; fail here rather
        # than deep inside timm on the first step.
        if total_training_opt_steps <= 0:
            raise ValueError(
                "total_training_opt_steps must be positive, "
                f"got {total_training_opt_steps!r}"
            )

        self.optimizer = optimizer
        self.trainer = trainer
        self.min_lr = min_lr
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        # Floor division: remainder steps are not attributed to any epoch,
        # so the warmup may be a few steps shorter than warmup_epochs exactly.
        self.num_opt_steps_per_epoch = total_training_opt_steps // max_epochs
        # Deliberately NOT named ``warmup_steps``: timm's constructor assigns
        # ``self.warmup_steps`` to its own per-param-group list of warmup
        # increments, which would silently replace this integer.
        self.num_warmup_steps = warmup_epochs * self.num_opt_steps_per_epoch
        self.total_steps = total_training_opt_steps

        super().__init__(
            optimizer=self.optimizer,
            t_initial=self.total_steps,
            lr_min=self.min_lr,
            warmup_lr_init=self.warmup_lr_init,
            warmup_t=self.num_warmup_steps,
            t_in_epochs=self.t_in_epochs,
        )
