import torch.optim as optim
from utils.cosine import CosineLRScheduler


def get_scheduler(
        sched: str,
        optimizer,
        epoch: int,
        decay_t=0,
        gamma=0.1,
        **kwargs
        ):
    """Build a learning-rate scheduler for ``optimizer``.

    Args:
        sched: Scheduler name; one of ``'step'``, ``'cosine'``,
            ``'cosine_warmup'``, ``'cosine_timm'`` or ``'onecycle'``.
        optimizer: The optimizer whose learning rate is scheduled.
        epoch: Schedule length. Used as the milestone base for ``'step'``,
            ``T_max`` for ``'cosine'``, ``t_initial`` for ``'cosine_timm'``
            and ``total_steps`` for ``'onecycle'`` (unused by
            ``'cosine_warmup'``).
        decay_t: ``'step'`` only -- number of evenly spaced decay milestones
            in ``(0, epoch]``. ``0`` yields no milestones (LR never decays).
        gamma: ``'step'`` only -- multiplicative decay factor applied at
            each milestone.
        **kwargs: Extra per-scheduler options.
            ``args`` (required for ``'cosine_timm'`` only): namespace that
            must provide ``min_lr``, ``warmup_lr`` and ``warmup_epochs`` and
            may provide ``lr_noise_pct``, ``lr_noise_std``, ``seed``,
            ``lr_cycle_mul``, ``lr_cycle_decay``, ``lr_cycle_limit`` and
            ``lr_k_decay`` (defaults are used when absent).
            ``max_lr``, ``lr``, ``min_lr`` and ``pct_start`` (required for
            ``'onecycle'`` only).

    Returns:
        The constructed scheduler.

    Raises:
        NotImplementedError: If ``sched`` is not a supported name.
        KeyError: If a keyword argument required by ``sched`` is missing.
    """
    if sched == 'step':
        # ``decay_t`` milestones evenly spaced over the schedule,
        # e.g. epoch=90, decay_t=3 -> [30, 60, 90].
        milestones = [int(epoch * t / decay_t) for t in range(1, decay_t + 1)]
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=gamma,
        )
    elif sched == 'cosine':
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=epoch,
            eta_min=0)
    elif sched == 'cosine_warmup':
        # First restart after 10 scheduler steps; each later period is twice
        # as long (T_mult=2). ``verbose`` is deliberately not passed: it is
        # deprecated in recent PyTorch and passing it emits a warning.
        lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2, eta_min=0, last_epoch=-1)
    elif sched == 'cosine_timm':
        # Only this branch needs the argparse-style namespace, so it is looked
        # up here instead of being required for every scheduler type.
        args = kwargs['args']
        noise_range = None
        noise_args = dict(
            noise_range_t=noise_range,
            noise_pct=getattr(args, 'lr_noise_pct', 0.67),
            noise_std=getattr(args, 'lr_noise_std', 1.),
            noise_seed=getattr(args, 'seed', 42),
        )
        cycle_args = dict(
            cycle_mul=getattr(args, 'lr_cycle_mul', 1.),
            cycle_decay=getattr(args, 'lr_cycle_decay', 0.1),
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
        )
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=epoch,
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            k_decay=getattr(args, 'lr_k_decay', 1.0),
            **cycle_args,
            **noise_args,
        )
    elif sched == 'onecycle':
        # OneCycleLR starts at max_lr / div_factor (== lr) and ends at
        # initial_lr / final_div_factor (== min_lr).
        lr_scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=kwargs['max_lr'],
            total_steps=epoch,
            pct_start=kwargs['pct_start'],
            div_factor=kwargs['max_lr']/kwargs['lr'],
            final_div_factor=kwargs['lr']/kwargs['min_lr'],
            anneal_strategy='cos', cycle_momentum=False
        )
    else:
        raise NotImplementedError(f'check your lr scheduler : {sched}')

    # Every branch above either assigns a scheduler or raises.
    print('scheduler type: {}'.format(sched))

    return lr_scheduler
