import numpy as np
import PIL.Image
import torch
import tqdm
from diffusers import DDIMInverseScheduler, DDIMScheduler, UNet2DModel, VQModel
from einops import rearrange

# Device used by default throughout this module: CUDA when PyTorch reports an
# available GPU, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Default number of items per forward pass in the batched helpers of this module.
BATCH_SIZE = 64
# Seed used by get_noises when the caller does not pass one.
DEFAULT_SEED = 420


def get_noises(n_samples: int, seed: int = DEFAULT_SEED, *, shape=(3, 64, 64)):
    """Draw a reproducible batch of standard-normal noise.

    A private ``torch.Generator`` seeded with ``seed`` is used. The values match
    those drawn from the default CPU generator after ``torch.manual_seed(seed)``,
    but the process-wide RNG state (CPU and CUDA) is left untouched, so calling
    this function does not perturb other random draws in the program.

    Args:
        n_samples: Size of the leading (sample) dimension.
        seed: Seed for the private generator.
        shape: Per-sample shape (keyword-only); defaults to ``(3, 64, 64)``.

    Returns:
        A CPU tensor of shape ``(n_samples, *shape)`` with the default float dtype.
    """
    generator = torch.Generator().manual_seed(seed)
    return torch.randn((n_samples, *shape), generator=generator)


def _run_scheduler_loop(
    inputs: torch.Tensor,
    diffusion_unet: UNet2DModel,
    diffusion_scheduler: "DDIMScheduler | DDIMInverseScheduler",
    batch_size: int,
    device: "torch.device | str",
    from_each_t: bool,
):
    """Apply ``diffusion_scheduler`` step by step to ``inputs``, one mini-batch at a time.

    Shared implementation behind generate_samples and generate_latents, which
    differ only in the scheduler passed in. For each mini-batch, and for every
    ``t`` in ``diffusion_scheduler.timesteps``, the UNet prediction for the
    current tensor is passed to ``diffusion_scheduler.step`` and the tensor is
    replaced by the step's ``prev_sample``. The caller presumably has already
    configured the scheduler (e.g. via ``set_timesteps``) — TODO confirm.

    Returns:
        ``(final, trajectory)``.
        - ``final``: the per-batch results, concatenated on CPU along dim 0.
        - ``trajectory``: ``None`` unless ``from_each_t`` is true. Otherwise a
          CPU tensor laid out ``(t, b, ...)``: index 0 holds the inputs and
          index ``i`` holds the tensor after ``i`` scheduler steps.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    if inputs.shape[0] == 0:
        # torch.cat rejects an empty list, so build correctly shaped empty results directly.
        n_snapshots = len(diffusion_scheduler.timesteps) + 1
        final = inputs.new_empty(inputs.shape)
        trajectory = inputs.new_empty((n_snapshots, *inputs.shape)) if from_each_t else None
        return final, trajectory

    finals = []
    trajectories = []
    for idx_start in tqdm.tqdm(range(0, inputs.shape[0], batch_size)):
        with torch.no_grad():
            x = inputs[idx_start : idx_start + batch_size].to(device)
            # Snapshots go straight to CPU so that device memory does not grow
            # with the number of timesteps.
            snapshots = [x.to("cpu", copy=True)] if from_each_t else []
            for t in diffusion_scheduler.timesteps:
                residual = diffusion_unet(x, t)["sample"]
                x = diffusion_scheduler.step(residual, t, x)["prev_sample"]
                if from_each_t:
                    snapshots.append(x.to("cpu", copy=True))
        finals.append(x.cpu())
        if from_each_t:
            trajectories.append(torch.stack(snapshots))  # (t, batch, ...)

    final = torch.cat(finals)
    # Concatenating along the batch axis (dim 1) keeps the (t, b, ...) layout.
    trajectory = torch.cat(trajectories, dim=1) if from_each_t else None
    return final, trajectory


def generate_samples(
    noise: torch.Tensor,
    diffusion_unet: UNet2DModel,
    diffusion_scheduler: DDIMScheduler,
    batch_size: int = BATCH_SIZE,
    device: "torch.device | str" = DEVICE,
    from_each_t: bool = False,
):
    """Denoise ``noise`` by stepping ``diffusion_scheduler`` through its timesteps.

    Args:
        noise: Starting tensors; dim 0 indexes samples.
        diffusion_unet: Model whose ``(x, t)`` call returns a mapping with a
            ``"sample"`` entry that is fed to the scheduler.
        diffusion_scheduler: Scheduler whose ``timesteps`` are iterated in order.
        batch_size: Number of samples processed per forward pass.
        device: Device the batches are moved to for computation.
        from_each_t: Whether to also return the intermediate tensors.

    Returns:
        The final tensors on CPU, concatenated along dim 0. When ``from_each_t``
        is true, a tuple ``(final, trajectory)`` where ``trajectory`` is laid out
        ``(t, b, ...)`` with ``noise`` itself at index 0.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    final, trajectory = _run_scheduler_loop(
        noise, diffusion_unet, diffusion_scheduler, batch_size, device, from_each_t
    )
    if from_each_t:
        return final, trajectory
    return final


def generate_latents(
    samples: torch.Tensor,
    diffusion_unet: UNet2DModel,
    diffusion_scheduler: DDIMInverseScheduler,
    batch_size: int = BATCH_SIZE,
    device: "torch.device | str" = DEVICE,
    from_each_t: bool = False,
):
    """Map ``samples`` back toward noise by stepping an inverse scheduler.

    Same loop as generate_samples, but intended for an inverse scheduler such
    as ``DDIMInverseScheduler``.

    Args:
        samples: Tensors to invert; dim 0 indexes samples.
        diffusion_unet: Model whose ``(x, t)`` call returns a mapping with a
            ``"sample"`` entry that is fed to the scheduler.
        diffusion_scheduler: Scheduler whose ``timesteps`` are iterated in order.
        batch_size: Number of samples processed per forward pass.
        device: Device the batches are moved to for computation.
        from_each_t: Whether to also return the intermediate tensors.

    Returns:
        The final latents on CPU, concatenated along dim 0. When ``from_each_t``
        is true, a tuple ``(latents, trajectory)`` where ``trajectory`` is laid
        out ``(t, b, ...)`` with ``samples`` itself at index 0.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    latents, trajectory = _run_scheduler_loop(
        samples, diffusion_unet, diffusion_scheduler, batch_size, device, from_each_t
    )
    if from_each_t:
        return latents, trajectory
    return latents


def decode_image(unet_out: torch.Tensor, vqvae: VQModel, batch_size: int = BATCH_SIZE, device: "torch.device | str" = DEVICE):
    """Decode ``unet_out`` with ``vqvae`` in mini-batches.

    Args:
        unet_out: Non-empty tensor whose dim 0 indexes samples (e.g. the
            output of generate_samples).
        vqvae: Model whose ``decode(batch).sample`` gives the decoded batch.
        batch_size: Number of samples decoded per forward pass.
        device: Device the batches are moved to for decoding.

    Returns:
        The decoded batches moved to CPU and concatenated along dim 0.

    Raises:
        ValueError: if ``batch_size`` is not positive.
        RuntimeError: from ``torch.cat`` when ``unet_out`` is empty.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    decoded = []
    with torch.no_grad():
        for idx_start in tqdm.tqdm(range(0, unet_out.shape[0], batch_size)):
            batch = unet_out[idx_start : idx_start + batch_size].to(device)
            decoded.append(vqvae.decode(batch).sample.cpu())
    return torch.cat(decoded)


def to_image(vqvae_out):
    """Convert a batch of decoded images into a list of PIL images.

    Pixel values are mapped linearly from [-1, 1] to [0, 255], rounded to the
    nearest integer (rather than truncated, which would bias every pixel down
    by up to one level), and clamped to the uint8 range.

    Args:
        vqvae_out: 4-D tensor laid out ``(b, c, h, w)``. Values are presumably
            in [-1, 1], as produced by decode_image — TODO confirm for other
            callers.

    Returns:
        A list with one ``PIL.Image.Image`` per sample. Single-channel input
        yields 2-D (grayscale) images.
    """
    # detach()/float() let .numpy() work on tensors that require grad or are
    # stored in half / bfloat16 precision.
    pixels = vqvae_out.detach().cpu().float().permute(0, 2, 3, 1)
    pixels = ((pixels + 1.0) * 127.5).round().clamp(0, 255)
    arrays = pixels.numpy().astype(np.uint8)
    if arrays.shape[-1] == 1:
        # PIL cannot build an image from an (h, w, 1) array; drop the channel axis.
        arrays = arrays[..., 0]
    return [PIL.Image.fromarray(array) for array in arrays]
