from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import SGD
import torch
import warnings

class PolynomialLRWarmup(_LRScheduler):
    def __init__(self, optimizer, warmup_iters, total_iters=5, power=1.0, last_epoch=-1, verbose=False):
        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
        self.total_iters = total_iters
        self.power = power
        self.warmup_iters = warmup_iters


    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
            return [group["lr"] for group in self.optimizer.param_groups]

        if self.last_epoch <= self.warmup_iters:
            return [base_lr * self.last_epoch / self.warmup_iters for base_lr in self.base_lrs]
        else:        
            l = self.last_epoch
            w = self.warmup_iters
            t = self.total_iters
            decay_factor = ((1.0 - (l - w) / (t - w)) / (1.0 - (l - 1 - w) / (t - w))) ** self.power
        return [group["lr"] * decay_factor for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):

        if self.last_epoch <= self.warmup_iters:
            return [
                base_lr * self.last_epoch / self.warmup_iters for base_lr in self.base_lrs]
        else:
            return [
                (
                    base_lr * (1.0 - (min(self.total_iters, self.last_epoch) - self.warmup_iters) / (self.total_iters - self.warmup_iters)) ** self.power
                )
                for base_lr in self.base_lrs
            ]

    
if __name__ == "__main__":

    class TestModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(32, 32)
        
        def forward(self, x):
            return self.linear(x)

    test_module = TestModule()
    test_module_pfc = TestModule()
    lr_pfc_weight = 1 / 3
    base_lr = 10
    total_steps = 32000
    
    sgd = SGD([
        {"params": test_module.parameters(), "lr": base_lr},
        {"params": test_module_pfc.parameters(), "lr": base_lr * lr_pfc_weight}
        ], base_lr)

    scheduler = PolynomialLRWarmup(sgd, total_steps//10, total_steps, power=2)

    x = []
    y = []
    y_pfc = []
    for i in range(total_steps):
        scheduler.step()
        lr = scheduler.get_last_lr()[0]
        lr_pfc = scheduler.get_last_lr()[1]
        x.append(i)
        y.append(lr)
        y_pfc.append(lr_pfc)

    import matplotlib.pyplot as plt
    fontsize=15
    plt.figure(figsize=(6, 6))
    plt.plot(x, y, linestyle='-', linewidth=2, )
    plt.plot(x, y_pfc, linestyle='-', linewidth=2, )
    plt.xlabel('Iterations')     # x_label
    plt.ylabel("Lr")             # y_label
    plt.savefig("tmp.png", dpi=600, bbox_inches='tight')
