import math
from torch.optim.lr_scheduler import LambdaLR
from functools import partial


def get_scheduler(
    optimizer,
    start_lr,
    max_lr,
    min_lr,
    warmup_epochs,
    sustain_epochs,
    total_epochs,
    decay,
    mode="cosine",
):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (max_lr - start_lr) / warmup_epochs * epoch + start_lr

        elif epoch < warmup_epochs + sustain_epochs:
            return max_lr

        elif mode == "exponential":
            return (max_lr - min_lr) * decay ** (
                epoch - warmup_epochs - sustain_epochs
            ) + min_lr

        elif mode == "step":
            return max_lr * decay ** ((epoch - warmup_epochs - sustain_epochs) // 2)

        elif mode == "cosine":
            decay_total_epochs = total_epochs - warmup_epochs - sustain_epochs + 3
            decay_epoch_index = epoch - warmup_epochs - sustain_epochs
            phase = math.pi * decay_epoch_index / decay_total_epochs
            cosine_decay = 0.5 * (1 + math.cos(phase))
            return (max_lr - min_lr) * cosine_decay + min_lr

        else:
            raise ValueError(
                f"Unsupported mode '{mode}'. Supported modes are 'exp', 'step', 'cosine'."
            )

    return LambdaLR(optimizer, lr_lambda)


def _get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float,
):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps)
    )
    return max(
        0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    )


def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    lr_lambda = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)
