"""Learning rate lambda schedulers"""

from collections.abc import Callable
from math import cos, pi
from typing import Literal


def get_scheduler(
    kind: Literal["linear", "cte", "cosine"] = "linear", *, warmup_steps: int = 0, total_steps: int = 1
) -> Callable[[int], float]:
    """Build a single-argument ``step -> factor`` callable for the requested schedule.

    The returned callable forwards ``step`` to the matching ``*_scheduler`` function in
    this module, with ``warmup_steps`` and ``total_steps`` bound to the values given here.

    Args:
        kind: Which schedule to use: ``"linear"``, ``"cte"`` (constant after warmup)
            or ``"cosine"``.
        warmup_steps: Warmup length passed through to the schedule function.
        total_steps: Total schedule length passed through to the schedule function.

    Returns:
        A callable taking the current step and returning the schedule's factor.

    Raises:
        ValueError: If ``kind`` is not one of the supported names.
    """
    # Lambdas (rather than pre-bound references) keep the schedule function lookup
    # deferred until the returned callable is actually invoked.
    if kind == "linear":
        return lambda step: linear_scheduler(step, warmup_steps=warmup_steps, total_steps=total_steps)
    if kind == "cte":
        return lambda step: cte_scheduler(step, warmup_steps=warmup_steps, total_steps=total_steps)
    if kind == "cosine":
        return lambda step: cosine_scheduler(step, warmup_steps=warmup_steps, total_steps=total_steps)
    msg = f"Unknown scheduler type: {kind}"
    raise ValueError(msg)


def linear_scheduler(step: int, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup, then linear decay from 1.0 down to 0.0.

    Args:
        step: Current step index.
        warmup_steps: Warmup length; while ``step < warmup_steps`` the factor is
            ``step / warmup_steps``. A value of ``0`` disables warmup.
        total_steps: Step at which the decay reaches 0.0.

    Returns:
        The factor. After warmup it falls linearly and is clamped at 0.0 once
        ``step`` reaches ``total_steps``.
    """
    in_warmup = bool(warmup_steps) and step < warmup_steps
    if in_warmup:
        return step / warmup_steps
    # Guard against a zero/negative decay span (total_steps <= warmup_steps).
    decay_span = total_steps - warmup_steps
    if decay_span < 1:
        decay_span = 1
    remaining = 1.0 - (step - warmup_steps) / decay_span
    return remaining if remaining > 0.0 else 0.0


def cte_scheduler(step: int, warmup_steps: int, total_steps: int) -> float:  # noqa: ARG001
    """Linear warmup, then a constant factor of 1.0.

    Args:
        step: Current step index.
        warmup_steps: Warmup length; while ``step < warmup_steps`` the factor is
            ``step / warmup_steps``. A value of ``0`` disables warmup.
        total_steps: Unused; accepted so every schedule shares one signature.

    Returns:
        ``step / warmup_steps`` during warmup, otherwise ``1.0``.
    """
    if not warmup_steps or step >= warmup_steps:
        return 1.0
    return step / warmup_steps


def cosine_scheduler(step: int, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup, then a half-cosine decay from 1.0 down to 0.0.

    Args:
        step: Current step index.
        warmup_steps: Warmup length; while ``step < warmup_steps`` the factor is
            ``step / warmup_steps``. A value of ``0`` disables warmup.
        total_steps: Step at which the decay reaches 0.0.

    Returns:
        The factor as a float. After warmup it follows ``0.5 * (1 + cos(pi * progress))``.
        Once ``step`` reaches ``total_steps`` it holds at ``0.0`` (like
        ``linear_scheduler``) instead of following the cosine back up.
    """
    if warmup_steps and step < warmup_steps:
        return step / warmup_steps
    # Clamp progress at 1.0: without it, steps past total_steps enter the next cosine
    # period and the factor climbs back towards 1.0.
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = (1.0 + cos(pi * progress)) * 0.5  # cosine domain [1,-1] -> +1 [2,0] -> *.5 [1,0]
    # Float literal keeps the return a float (max(0, 0.0) would return the int 0).
    return max(0.0, cosine)
