
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR, SequentialLR
from transformers import get_constant_schedule_with_warmup

# Throwaway one-weight layer: it only exists to give the optimizer a
# parameter group whose learning rate the schedulers can drive.
_dummy_layer = nn.Linear(in_features=1, out_features=1)
optimizer = torch.optim.Adam(params=_dummy_layer.parameters(), lr=1e-4)


# Two-phase schedule chained with SequentialLR: a warmup phase for the first
# 100 scheduler steps, then cosine annealing towards eta_min over 900 steps.
# The schedulers are built in this order on purpose: each constructor touches
# the optimizer's param groups, so the order is kept as-is.
scheduler1 = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=100)
scheduler2 = CosineAnnealingLR(optimizer, eta_min=1e-5, T_max=900)

# The single milestone is the step at which control passes from the first
# scheduler to the second.
_phases = [scheduler1, scheduler2]
scheduler = SequentialLR(optimizer, schedulers=_phases, milestones=[100])



# Plot the learning-rate schedule.
import matplotlib.pyplot as plt

# lrs[k] is the learning rate of the first param group after k scheduler
# steps. The starting value (k == 0) is recorded before the loop so that the
# curve includes the LR in effect for the very first optimizer step and the
# x-axis of the plot lines up with the step count.
lrs = [optimizer.param_groups[0]["lr"]]
print(lrs[-1])
for _ in range(1000):
    # No gradients are ever computed here; optimizer.step() is only called so
    # that the optimizer is stepped before the scheduler, the order PyTorch
    # expects.
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])
    print(lrs[-1])

plt.plot(lrs)
plt.savefig('lr_schedule.pdf')