import math

import torch





def _wsd_lr_lambda(current_step: int, num_training_steps: int, warmup_ratio: float, decay_ratio: float) -> float:

    if num_training_steps <= 0:

        return 1.0



    warmup_steps = int(num_training_steps * warmup_ratio)

    decay_steps = int(num_training_steps * decay_ratio)

    warmup_steps = max(warmup_steps, 0)

    decay_steps = max(decay_steps, 0)



    stable_steps = num_training_steps - warmup_steps - decay_steps

    if stable_steps < 0:

        stable_steps = 0

        decay_steps = max(0, num_training_steps - warmup_steps)



    if warmup_steps > 0 and current_step < warmup_steps:

        return float(current_step + 1) / float(warmup_steps)



    if current_step < warmup_steps + stable_steps:

        return 1.0



    if decay_steps <= 0:

        return 1.0



    decay_step = current_step - warmup_steps - stable_steps + 1

    return 1.0 / math.sqrt(decay_step)





def get_wsd_scheduler(
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
    warmup_ratio: float = 0.1,
    decay_ratio: float = 0.2,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a ``LambdaLR`` scheduler driven by ``_wsd_lr_lambda``.

    Args:
        optimizer: Optimizer handed to ``LambdaLR``.
        num_training_steps: Total step count forwarded to ``_wsd_lr_lambda``.
        warmup_ratio: Warmup fraction forwarded to ``_wsd_lr_lambda``.
        decay_ratio: Decay fraction forwarded to ``_wsd_lr_lambda``.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` whose multiplier function
        evaluates ``_wsd_lr_lambda`` with the arguments given here.
    """

    def multiplier_at(step):
        # Closes over the schedule settings; only the step varies per call.
        return _wsd_lr_lambda(step, num_training_steps, warmup_ratio, decay_ratio)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, multiplier_at)
    return scheduler

