from diffusers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
import math

from diffusion.utils.logger import get_root_logger


def build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio):
    if not config.get('lr_schedule_args', None):
        config.lr_schedule_args = {}
    if config.get('lr_warmup_steps', None):
        config['num_warmup_steps'] = config.get('lr_warmup_steps')  # for compatibility with old version

    logger = get_root_logger()
    logger.info(
        f'Lr schedule: {config.lr_schedule}, ' + ",".join(
            [f"{key}:{value}" for key, value in config.lr_schedule_args.items()]) + '.')
    if config.lr_schedule == 'cosine':
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer=optimizer,
            **config.lr_schedule_args,
            num_training_steps=(len(train_dataloader) * config.num_epochs),
        )
    elif config.lr_schedule == 'constant':
        lr_scheduler = get_constant_schedule_with_warmup(
            optimizer=optimizer,
            **config.lr_schedule_args,
        )
    elif config.lr_schedule == 'cosine_decay_to_constant':
        assert lr_scale_ratio >= 1
        lr_scheduler = get_cosine_decay_to_constant_with_warmup(
            optimizer=optimizer,
            **config.lr_schedule_args,
            final_lr=1 / lr_scale_ratio,
            num_training_steps=(len(train_dataloader) * config.num_epochs),
        )
    else:
        raise RuntimeError(f'Unrecognized lr schedule {config.lr_schedule}.')
    return lr_scheduler


def get_cosine_decay_to_constant_with_warmup(optimizer: Optimizer,
                                             num_warmup_steps: int,
                                             num_training_steps: int,
                                             final_lr: float = 0.0,
                                             num_decay: float = 0.667,
                                             num_cycles: float = 0.5,
                                             last_epoch: int = -1
                                             ):
    """
    Create a schedule with a cosine annealing lr followed by a constant lr.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The number of total training steps.
        final_lr (`int`):
            The final constant lr after cosine decay.
        num_decay (`int`):
            The
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        num_decay_steps = int(num_training_steps * num_decay)
        if current_step > num_decay_steps:
            return final_lr

        progress = float(current_step - num_warmup_steps) / float(max(1, num_decay_steps - num_warmup_steps))
        return (
            max(
                0.0,
                0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)),
            )
            * (1 - final_lr)
        ) + final_lr

    return LambdaLR(optimizer, lr_lambda, last_epoch)
