import torch
import math

class WarmUpCosineLR(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup followed by cosine annealing down to a floor learning rate.

    The schedule is expressed in absolute learning rates: the optimizer's own
    ``lr`` values are not used, and every param group receives the same rate.

    * Steps ``0 .. warmup_steps - 1``: linear ramp from ``warmup_lr`` towards
      ``max_lr``.
    * Steps ``warmup_steps .. total_steps``: half-cosine decay from ``max_lr``
      down to ``min_lr``.
    * Steps past ``total_steps``: held at ``min_lr``. The cosine is not allowed
      to wrap around past pi, which would otherwise raise the rate again.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of scheduler steps spent in linear warmup.
        total_steps: Step index at which the cosine decay reaches ``min_lr``.
            If it is not greater than ``warmup_steps``, the decay collapses to
            a single step.
        warmup_lr: Learning rate at step 0.
        max_lr: Peak learning rate, reached at step ``warmup_steps``.
        min_lr: Final learning rate after the decay. Defaults to 0.0.
        last_step: Index of the last completed step, forwarded to the base
            class as ``last_epoch``. The default -1 starts a fresh schedule.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, warmup_lr, max_lr, min_lr=0.0, last_step=-1):
        # These must be set before super().__init__: the base-class constructor
        # performs an initial step(), which calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.warmup_lr = warmup_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super(WarmUpCosineLR, self).__init__(optimizer, last_step)

    def get_lr(self):
        """Return one learning rate per param group for step ``self.last_epoch``."""
        step = self.last_epoch
        if step < self.warmup_steps:
            # Linear warmup: warmup_lr at step 0, approaching max_lr.
            lr = self.warmup_lr + (self.max_lr - self.warmup_lr) * (step / self.warmup_steps)
        else:
            # Cosine annealing. max(1, ...) avoids ZeroDivisionError when
            # total_steps <= warmup_steps. Clamping progress to 1.0 holds the
            # rate at min_lr after total_steps instead of letting the cosine
            # climb back up.
            decay_steps = max(1, self.total_steps - self.warmup_steps)
            progress = min(1.0, (step - self.warmup_steps) / decay_steps)
            lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (math.cos(math.pi * progress) + 1)
        return [lr for _ in self.optimizer.param_groups]




if __name__ == "__main__":
    # Demo: plot the learning rate WarmUpCosineLR produces over `total_steps` steps.
    import torch.nn as nn
    import torch.optim as optim
    import matplotlib.pyplot as plt

    warmup_steps = 100
    total_steps = 1000
    warmup_lr = 0.000001
    max_lr = 0.00001
    min_lr = 0.0000001

    model = nn.Linear(10, 10)
    optimizer = optim.SGD(model.parameters(), lr=max_lr)
    scheduler = WarmUpCosineLR(optimizer, warmup_steps, total_steps, warmup_lr, max_lr, min_lr)

    lrs = []
    for _ in range(total_steps):
        # Record the LR in effect for this step before advancing, so step 0
        # (warmup_lr) is included. get_last_lr() is the public accessor;
        # calling get_lr() directly outside step() is discouraged by PyTorch.
        lrs.append(scheduler.get_last_lr()[0])
        # PyTorch expects optimizer.step() before scheduler.step(). No backward
        # pass is run here, so there are no gradients to apply.
        optimizer.step()
        scheduler.step()

    plt.plot(lrs)
    plt.xlabel("step")
    plt.ylabel("learning rate")
    plt.title("WarmUpCosineLR")
    plt.show()
