import torch

def linear_scheduler(t: torch.Tensor, start: float, end: float) -> torch.Tensor:
    """Interpolate linearly from ``start`` (at ``t=0``) to ``end`` (at ``t=1``).

    Args:
        t: Progress values; every element must lie in ``[0, 1]``.
        start: Value returned where ``t == 0``.
        end: Value returned where ``t == 1``.

    Returns:
        ``start + t * (end - start)`` evaluated elementwise over ``t``.

    Raises:
        ValueError: If any element of ``t`` lies outside ``[0, 1]``
            (NaN elements are rejected as well).
    """
    # Explicit check rather than ``assert``: asserts are removed under
    # ``python -O``, which would let out-of-range t silently extrapolate.
    if not (t <= 1.0).all() or not (0 <= t).all():
        raise ValueError("t must lie in the interval [0, 1]")

    return start + t * (end - start)

def exponential_scheduler(t: torch.Tensor, start: float, end: float) -> torch.Tensor:
    """Interpolate geometrically from ``start`` (at ``t=0``) to ``end`` (at ``t=1``).

    Computes ``start * (end / start) ** t``. Equal steps in ``t`` therefore
    multiply the value by a constant ratio.

    Args:
        t: Progress values; every element must lie in ``[0, 1]``.
        start: Value returned where ``t == 0``. Must be non-zero and have
            the same sign as ``end``.
        end: Value returned where ``t == 1``. Must be non-zero and have
            the same sign as ``start``.

    Returns:
        ``start * (end / start) ** t`` evaluated elementwise over ``t``.

    Raises:
        ValueError: If any element of ``t`` lies outside ``[0, 1]``, or if
            ``start`` and ``end`` are not both non-zero with the same sign.
    """
    # Explicit checks rather than ``assert``: asserts are removed under
    # ``python -O``, which would turn these into silent NaNs / extrapolation.
    if not (t <= 1.0).all() or not (0 <= t).all():
        raise ValueError("t must lie in the interval [0, 1]")
    # start == 0 divides by zero; end == 0 collapses the schedule to 0; a
    # negative ratio raised to a fractional power yields NaN.
    if not (end * start > 0.0):
        raise ValueError("start and end must be non-zero and have the same sign")

    return start * (end / start) ** t

def oscillating_scheduler(t: torch.Tensor, start: float, end: float, num_cycles: int = 1) -> torch.Tensor:
    """Oscillate between ``start`` and ``end`` along a squared-sine wave.

    Computes ``start + (end - start) * sin(t * phi_max) ** 2`` with
    ``phi_max = pi/2 + (num_cycles - 1) * pi``. The value is ``start`` at
    ``t=0``. ``sin**2`` peaks at ``pi/2 + k*pi``, so for integer
    ``num_cycles`` the schedule reaches ``end`` exactly ``num_cycles``
    times, the last one at ``t=1``.

    Args:
        t: Progress values; every element must lie in ``[0, 1]``.
        start: Value at ``t = 0`` and at the troughs of the wave.
        end: Value at the peaks of the wave (including ``t = 1`` when
            ``num_cycles`` is an integer).
        num_cycles: Number of times the schedule reaches ``end``; must be
            positive.

    Returns:
        The schedule evaluated elementwise over ``t``.

    Raises:
        ValueError: If any element of ``t`` lies outside ``[0, 1]`` or if
            ``num_cycles`` is not positive.
    """
    # Explicit checks rather than ``assert``: asserts are removed under
    # ``python -O``.
    if not (t <= 1.0).all() or not (0 <= t).all():
        raise ValueError("t must lie in the interval [0, 1]")
    if not num_cycles > 0:
        raise ValueError("num_cycles must be positive")

    # First quarter-period goes start -> end; each extra cycle adds a full
    # sin**2 period (pi) that returns to start and back to end.
    phase_span = torch.pi / 2 + (num_cycles - 1) * torch.pi
    return start + (end - start) * torch.sin(t * phase_span) ** 2

def linear_oscillating_scheduler(
    t: torch.Tensor, start: float, end: float, num_cycles: int = 1
) -> torch.Tensor:
    """Multiply a linear ``start``->``end`` envelope by a wave between 1 and 0.

    The wave is ``oscillating_scheduler(t, 1.0, 0.0, num_cycles)``, i.e.
    ``1 - sin(t * phi_max) ** 2``. It is 1 at ``t=0`` and, for integer
    ``num_cycles``, 0 at ``t=1``. The product therefore starts at ``start``
    and ends at 0 whatever ``end`` is; ``end`` only sets the envelope that
    bounds the peaks.

    Input validation is left to the two delegated schedulers.
    """
    # Evaluate the wave first, as before, so the delegated checks fire in
    # the same order.
    wave = oscillating_scheduler(t, 1.0, 0.0, num_cycles)
    envelope = linear_scheduler(t, start, end)
    return wave * envelope

def exponential_oscillating_scheduler(
    t: torch.Tensor, start: float, end: float, num_cycles: int = 1
) -> torch.Tensor:
    """Multiply a geometric ``start``->``end`` envelope by a wave between 1 and 0.

    The wave is ``oscillating_scheduler(t, 1.0, 0.0, num_cycles)``, i.e.
    ``1 - sin(t * phi_max) ** 2``. It is 1 at ``t=0`` and, for integer
    ``num_cycles``, 0 at ``t=1``. The envelope is
    ``exponential_scheduler(t, start, end)``. The product starts at
    ``start`` and ends at 0 whatever ``end`` is.

    Input validation (``t`` range, ``num_cycles``, signs of ``start`` and
    ``end``) is left to the two delegated schedulers.
    """
    # Evaluate the wave first, as before, so the delegated checks fire in
    # the same order.
    wave = oscillating_scheduler(t, 1.0, 0.0, num_cycles)
    envelope = exponential_scheduler(t, start, end)
    return wave * envelope