def get_alpha(epoch: int,
              total_epochs: int,
              max_alpha: float = 0.7,
              warmup_frac: float = 0.0,
              ramp_frac: float = 0.5) -> float:
    """
    Return alpha for ``epoch`` under a warmup-then-linear-ramp schedule.

    - warmup_frac: fraction of ``total_epochs`` (truncated to whole epochs)
      during which alpha is 0.
    - ramp_frac: fraction of ``total_epochs`` (truncated to whole epochs)
      over which alpha grows linearly from 0; the last ramp epoch yields
      ``max_alpha * (n_ramp - 1) / n_ramp``, and every epoch after the ramp
      returns exactly ``max_alpha``.
    """
    n_warmup = int(total_epochs * warmup_frac)
    n_ramp = int(total_epochs * ramp_frac)
    ramp_end = n_warmup + n_ramp

    # Warmup phase: alpha is switched off.
    if epoch < n_warmup:
        return 0.0

    # Ramp finished (or empty ramp): hold alpha at its ceiling.
    if epoch >= ramp_end:
        return max_alpha

    # Inside the ramp n_ramp >= 1; the max() is a defensive divide-by-zero guard.
    steps_into_ramp = epoch - n_warmup
    return max_alpha * (steps_into_ramp / max(n_ramp, 1))
