import numpy as np    

# def cosine_scheduler(optimizer, base_lr, num_warmup_steps, total_steps):
#     def _scheduler(current_step):
#         if(current_step < num_warmup_steps):
#             lr = base_lr * (current_step + 1) / num_warmup_steps
#         else:
#             n = current_step - num_warmup_steps
#             d = total_steps - num_warmup_steps
#             lr = 0.5 * (1 + np.cos(np.pi * n / d)) * base_lr

#         for param_group in optimizer.param_groups:
#             param_group["lr"] = lr
            
#     return _scheduler


def cosine_scheduler(optimizer, base_lr, num_warmup_steps, total_steps):
    """Build a learning-rate setter with linear warmup and cosine decay.

    The returned callable takes a step index, computes the learning rate
    for that step and writes it into ``param_group["lr"]`` for every group
    in ``optimizer.param_groups``. It returns nothing.

    Schedule:
        * ``current_step < num_warmup_steps``: linear ramp,
          ``base_lr * (current_step + 1) / num_warmup_steps``.
        * ``num_warmup_steps <= current_step <= total_steps``: half-cosine
          decay from ``base_lr`` down to 0.
        * ``current_step > total_steps``: held at the final decay value
          instead of following the cosine back up.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``"lr"`` entry is overwritten on every call.
        base_lr: Peak learning rate, reached at the end of warmup.
        num_warmup_steps: Number of linear warmup steps.
        total_steps: Step index at which the cosine decay ends.

    Returns:
        The scheduler callable ``_scheduler(current_step)``, carrying three
        function attributes:

        * ``state_dict()`` / ``save_state_dict()``: return a copy of the
          scheduler state (``{"current_step": ...}``).
        * ``load_state_dict(new_state_dict)``: restore the state and
          re-apply the learning rate of the stored step. Raises
          ``KeyError`` if ``"current_step"`` is missing.
    """
    state_dict = {"current_step": 0}

    # A decay span of zero (total_steps == num_warmup_steps) would divide by
    # zero below; flooring it at one keeps the schedule well-defined.
    decay_steps = max(1, total_steps - num_warmup_steps)

    def _scheduler(current_step):
        state_dict["current_step"] = current_step
        if current_step < num_warmup_steps:
            lr = base_lr * (current_step + 1) / num_warmup_steps
        else:
            # Clamp so steps past total_steps stay at the end of the decay
            # rather than riding the cosine back up towards base_lr.
            elapsed = min(current_step - num_warmup_steps, decay_steps)
            lr = 0.5 * (1 + np.cos(np.pi * elapsed / decay_steps)) * base_lr

        # np.cos yields a NumPy scalar; store a plain float so both branches
        # write the same type into the optimizer.
        lr = float(lr)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

    def save_state_dict():
        # Return a copy so a held snapshot is not mutated by later steps.
        return dict(state_dict)

    def load_state_dict(new_state_dict):
        nonlocal state_dict
        # Read the step first: a missing key raises before any state changes.
        restored_step = new_state_dict["current_step"]
        # Copy so the caller's dict is not aliased and mutated by later steps.
        state_dict = dict(new_state_dict)
        _scheduler(restored_step)

    _scheduler.save_state_dict = save_state_dict
    _scheduler.load_state_dict = load_state_dict
    _scheduler.state_dict = save_state_dict

    return _scheduler