from typing import Tuple, Union

import torch
from diffusers import DDIMInverseScheduler, DDIMScheduler, UNet2DModel, VQModel

# Default number of inference steps handed to the schedulers' set_timesteps().
DEFAULT_T = 100
# Hugging Face Hub repo id of the CompVis CelebA-HQ latent diffusion model; the
# loaders in this module read its "unet", "vqvae" and "scheduler" subfolders.
CELEBA_LDM_256 = "CompVis/ldm-celebahq-256"
# Device string used by default: the GPU when torch can see one, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


def get_ldm_celeba(
    device: Union[str, torch.device] = DEVICE,
) -> Tuple[UNet2DModel, VQModel]:
    """Load the CelebA-HQ LDM UNet and VQ autoencoder and move them to ``device``.

    Both models are loaded with ``from_pretrained`` from the ``CELEBA_LDM_256``
    repository (``unet`` and ``vqvae`` subfolders respectively).

    Args:
        device: Where to place the models. Either a device string such as
            ``"cuda"`` / ``"cpu"`` or a ``torch.device``. Defaults to the
            module-level ``DEVICE`` string (evaluated once, at import time).

    Returns:
        ``(unet, vae)``: the ``UNet2DModel`` and the ``VQModel``, both already
        moved to ``device``.
    """
    unet = UNet2DModel.from_pretrained(CELEBA_LDM_256, subfolder="unet")
    vae = VQModel.from_pretrained(CELEBA_LDM_256, subfolder="vqvae")
    # nn.Module.to() returns the module itself, so the moved models can be
    # returned directly.
    return unet.to(device), vae.to(device)


def get_scheduler(T: int = DEFAULT_T) -> DDIMScheduler:
    """Build a DDIM scheduler from the CelebA-HQ LDM scheduler config.

    Args:
        T: Number of inference steps passed to ``set_timesteps``; must be >= 1.

    Returns:
        A ``DDIMScheduler`` whose timesteps have been set for ``T`` steps.

    Raises:
        ValueError: If ``T`` is less than 1. diffusers' ``set_timesteps`` may
            also raise ``ValueError`` when ``T`` exceeds the config's
            ``num_train_timesteps``.
    """
    if T < 1:
        # A non-positive step count cannot yield a usable schedule; T == 0 would
        # otherwise surface as an opaque ZeroDivisionError inside diffusers.
        raise ValueError(f"T must be a positive number of inference steps, got {T}")
    # from_pretrained is the supported way to load a scheduler from a repo id;
    # passing a repo id/path to from_config is deprecated in diffusers.
    scheduler = DDIMScheduler.from_pretrained(CELEBA_LDM_256, subfolder="scheduler")
    scheduler.set_timesteps(num_inference_steps=T)
    return scheduler


def get_inv_scheduler(T: int = DEFAULT_T) -> DDIMInverseScheduler:
    """Build a DDIM inverse scheduler from the CelebA-HQ LDM scheduler config.

    Args:
        T: Number of inference steps passed to ``set_timesteps``; must be >= 1.

    Returns:
        A ``DDIMInverseScheduler`` whose timesteps have been set for ``T`` steps.

    Raises:
        ValueError: If ``T`` is less than 1. diffusers' ``set_timesteps`` may
            also raise ``ValueError`` when ``T`` exceeds the config's
            ``num_train_timesteps``.
    """
    if T < 1:
        # A non-positive step count cannot yield a usable schedule; T == 0 would
        # otherwise surface as an opaque ZeroDivisionError inside diffusers.
        raise ValueError(f"T must be a positive number of inference steps, got {T}")
    # from_pretrained is the supported way to load a scheduler from a repo id;
    # passing a repo id/path to from_config is deprecated in diffusers.
    scheduler = DDIMInverseScheduler.from_pretrained(CELEBA_LDM_256, subfolder="scheduler")
    scheduler.set_timesteps(num_inference_steps=T)
    return scheduler
