from __future__ import annotations
import math


def get_tau(step: int, total_steps: int, kind: str, t0: float, t1: float) -> float:
    """Return the value of an annealing schedule that moves from ``t0`` to ``t1``.

    Progress is ``alpha = step / total_steps`` clamped to ``[0, 1]``, so the
    schedule returns ``t0`` at (or before) step 0 and holds at ``t1`` once
    ``step >= total_steps``.

    Args:
        step: Current step.
        total_steps: Number of steps over which to anneal; values below 1 are
            treated as 1.
        kind: Interpolation shape:
            ``"linear"`` - straight line from ``t0`` to ``t1``;
            ``"exp"`` - geometric interpolation ``t0 * (t1 / t0) ** alpha``
            (constant ratio per unit of progress); ``t0`` is floored at
            ``1e-8`` so a zero or negative start does not divide by zero;
            ``"cosine"`` - half-cosine curve from ``t0`` to ``t1``.
        t0: Value at ``alpha == 0``.
        t1: Value at ``alpha == 1``.

    Returns:
        The scheduled value for ``step``.

    Raises:
        ValueError: If ``kind`` is unknown, or if ``kind == "exp"`` and ``t1``
            is negative (a fractional power of a negative ratio is complex).
    """
    # Clamp both ends: a negative step must not extrapolate beyond t0.
    alpha = min(1.0, max(0.0, step / max(1, total_steps)))
    if kind == "linear":
        return t0 + (t1 - t0) * alpha
    if kind == "exp":
        if t1 < 0:
            raise ValueError(f"exp tau_schedule requires t1 >= 0, got {t1}")
        # Floor the start once and use it as both base and divisor so the
        # curve ends exactly at t1 even when t0 is ~0; flooring only the
        # divisor would pin a t0 == 0 schedule at 0 forever.
        start = max(1e-8, t0)
        return start * (t1 / start) ** alpha
    if kind == "cosine":
        # 0.5 * (1 + cos(pi * alpha)) falls smoothly from 1 to 0 as alpha: 0 -> 1.
        return t1 + 0.5 * (t0 - t1) * (1 + math.cos(math.pi * alpha))
    raise ValueError(f"Unknown tau_schedule {kind}")
