import torch
import numpy as np


def gen_calib_dit(pipeline, steps=250, tot=1000, device=None):
    """Build a calibration set of noised latents, timesteps and class labels.

    For every class label a batch of ``steps`` random latents is drawn, noised
    with ``pipeline.scheduler.add_noise`` at timesteps ``0 .. steps - 1`` and
    passed through ``pipeline.scheduler.scale_model_input``.

    Args:
        pipeline: Object exposing ``scheduler.add_noise(sample, noise, t)`` and
            ``scheduler.scale_model_input(sample, t)``.
        steps: Number of timesteps; row ``s`` of every output uses timestep ``s``.
        tot: Requested number of class labels. Values below 1000 give ``tot``
            labels evenly spread over ``[0, 1000]``; values of 1000 or more
            give the labels ``0 .. 999``.
        device: Device the scheduler calls run on. Defaults to ``"cuda"`` when
            available, otherwise ``"cpu"``.

    Returns:
        Tuple ``(xs, ts, cs)`` of CPU tensors, with ``n`` the number of labels:
        ``xs`` -- noised latents, shape ``(steps, n, 4, 64, 64)``;
        ``ts`` -- timesteps, shape ``(steps, n)``, dtype long;
        ``cs`` -- class labels, shape ``(steps, n)``, dtype long.

    Raises:
        ValueError: If ``tot`` is smaller than 1.
    """
    if tot < 1:
        raise ValueError(f"tot must be a positive integer, got {tot!r}")

    if tot < 1000:
        # NOTE(review): the upper end of the range (1000) is kept from the
        # original code; confirm that label 1000 is intended for this model.
        labels = [int(value) for value in np.linspace(0, 1000, tot)]
    else:
        labels = list(range(1000))
    n_labels = len(labels)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    latents = torch.randn(steps, n_labels, 4, 64, 64)
    # Row s holds timestep s for every label column.
    ts = torch.arange(steps, dtype=torch.long).unsqueeze(1).repeat(1, n_labels)
    # Column j holds labels[j] for every timestep row.
    cs = torch.tensor(labels, dtype=torch.long).unsqueeze(0).repeat(steps, 1)

    scheduler = pipeline.scheduler
    noisy_columns = []
    # Iterate over column positions, not label values: when tot < 1000 the
    # label values exceed the number of columns and must not be used as indices.
    for col in range(n_labels):
        clean = latents[:, col]
        timesteps = ts[:, col].to(device)
        noisy = scheduler.add_noise(
            clean.to(device),
            torch.randn_like(clean).to(device),
            timesteps,
        )
        noisy = scheduler.scale_model_input(noisy, timesteps)
        noisy_columns.append(noisy.unsqueeze(1).detach().cpu())
    return torch.cat(noisy_columns, dim=1), ts, cs

# Script entry point: load the DiT pipeline and configure its scheduler.
if __name__ == "__main__":
    # diffusers is imported here so that importing this module does not require it.
    from diffusers import DiTPipeline, DPMSolverMultistepScheduler
    # Load the pretrained DiT-XL-2-512 pipeline with float16 weights.
    model = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512", torch_dtype=torch.float16) # variant="fp16" is not needed for version >= 0.24.0
    # Replace the default scheduler with a DPMSolverMultistepScheduler built
    # from the existing scheduler's configuration.
    model.scheduler = DPMSolverMultistepScheduler.from_config(model.scheduler.config)