import random
import warnings

import numpy as np
import torch
from tqdm import tqdm

# Channel count used for the noise / sample tensors in this module (presumably RGB):
# default `in_channels` of generate_noises and the channel dim of generate_samples' shape.
IN_CHANNELS: int = 3


def generate_noises(number_of_samples, diffusion_args, seed=420, in_channels=IN_CHANNELS, device="cuda"):
    """Draw a reproducible batch of standard-normal starting noises.

    Side effects (process-wide, they persist after the call): seeds Python's
    ``random``, NumPy's global RNG and torch's CPU and CUDA RNGs with ``seed``,
    and sets cuDNN to ``deterministic=True`` / ``benchmark=False``.

    Args:
        number_of_samples: Number of noise tensors (size of dim 0).
        diffusion_args: Mapping that provides ``"image_size"`` (used for both
            spatial dims).
        seed: Seed applied to every RNG listed above.
        in_channels: Size of dim 1.
        device: Where the returned tensor is moved to.

    Returns:
        Tensor of shape ``(number_of_samples, in_channels, image_size, image_size)``
        on ``device``.
    """
    for seeder in (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    side = diffusion_args["image_size"]
    # Drawn from the (seeded) default CPU generator, then moved; presumably so the
    # values do not depend on the target device.
    noise = torch.randn((number_of_samples, in_channels, side, side))
    return noise.to(device)


def generate_samples(
    random_noises,
    number_of_samples,
    batch_size,
    diffusion_pipeline,
    ddim_model,
    diffusion_args,
    device="cuda",
    from_each_t=False,
):
    """Run DDIM sampling from pre-generated noises, one batch at a time.

    Only full batches are processed: ``number_of_samples // batch_size`` calls
    to ``diffusion_pipeline.ddim_sample_loop`` are made. If
    ``number_of_samples`` is not a multiple of ``batch_size``, the remainder is
    skipped and a ``UserWarning`` is emitted, so the drop is not silent.

    Args:
        random_noises: Starting noises indexed along dim 0 (e.g. the tensor from
            ``generate_noises``). It needs at least
            ``(number_of_samples // batch_size) * batch_size`` rows.
        number_of_samples: Requested number of samples.
        batch_size: Samples per ``ddim_sample_loop`` call; must be positive.
        diffusion_pipeline: Object exposing ``ddim_sample_loop``.
        ddim_model: Model passed through to ``ddim_sample_loop``.
        diffusion_args: Mapping that provides ``"image_size"``.
        device: Device passed through to ``ddim_sample_loop``.
        from_each_t: If True, ``ddim_sample_loop`` is expected to return
            ``(sample, t_samples)``, and the per-step samples are returned too.

    Returns:
        CPU tensor of final samples, concatenated across batches along dim 0.
        If ``from_each_t`` is True, a tuple ``(samples, t_samples)`` where
        ``t_samples`` stacks ``torch.stack(t_samples)`` of each batch, so its
        leading dim is the number of batches.

    Raises:
        ValueError: If ``batch_size`` is not positive, if ``number_of_samples``
            is smaller than ``batch_size`` (no full batch), or if
            ``random_noises`` has too few rows.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    num_batches = number_of_samples // batch_size
    if num_batches == 0:
        raise ValueError(
            f"number_of_samples={number_of_samples} is smaller than batch_size={batch_size}; "
            "no full batch can be generated"
        )
    num_used = num_batches * batch_size
    if len(random_noises) < num_used:
        raise ValueError(f"random_noises has {len(random_noises)} rows but {num_used} are needed")
    if num_used < number_of_samples:
        # Batch semantics are kept (the from_each_t output stacks equal-sized batches),
        # but the dropped remainder is made visible to the caller.
        warnings.warn(
            f"number_of_samples={number_of_samples} is not a multiple of batch_size={batch_size}; "
            f"only {num_used} samples will be generated",
            stacklevel=2,
        )

    image_size = diffusion_args["image_size"]
    shape = (batch_size, IN_CHANNELS, image_size, image_size)
    sample_batches = []
    all_t_samples = []
    for i in tqdm(range(num_batches), desc="Generating samples from noises"):
        outs = diffusion_pipeline.ddim_sample_loop(
            ddim_model,
            shape,
            clip_denoised=True,
            device=device,
            noise=random_noises[i * batch_size : (i + 1) * batch_size],
            from_each_t=from_each_t,
        )
        if from_each_t:
            sample, t_samples = outs
            all_t_samples.append(torch.stack(t_samples).cpu())
        else:
            sample = outs
        sample_batches.append(sample.cpu())

    ddim_samples = torch.cat(sample_batches)
    if from_each_t:
        return ddim_samples, torch.stack(all_t_samples)
    return ddim_samples


def generate_latents(ddim_generations, batch_size, diffusion_pipeline, ddim_model, device="cuda", from_each_t=False):
    """Map samples back to latents by running DDIM reverse steps, batch by batch.

    For each batch, ``diffusion_pipeline.ddim_reverse_sample`` is applied for
    ``t = 0, 1, ..., diffusion_pipeline.num_timesteps - 1`` with
    ``clip_denoised=True``. Each step's ``"sample"`` is fed into the next step.
    The whole inversion runs under ``torch.no_grad()``.

    Only full batches are processed. If the number of rows is not a multiple
    of ``batch_size``, the trailing rows are skipped and a ``UserWarning`` is
    emitted, so the drop is not silent.

    Args:
        ddim_generations: Samples to invert, indexed along dim 0.
        batch_size: Rows per reverse-sampling batch; must be positive.
        diffusion_pipeline: Object exposing ``num_timesteps`` and
            ``ddim_reverse_sample``.
        ddim_model: Model passed through to ``ddim_reverse_sample``.
        device: Device each batch and the timestep tensor are placed on.
        from_each_t: If True, also return the result of every reverse step.

    Returns:
        CPU tensor of final latents, concatenated across batches along dim 0.
        If ``from_each_t`` is True, a tuple ``(latents, t_latents)``, where
        ``t_latents`` has leading dims ``(num_batches, num_timesteps, ...)``.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if there are fewer
            rows than ``batch_size`` (no full batch).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    x = ddim_generations
    num_rows = x.shape[0]
    num_batches = num_rows // batch_size
    if num_batches == 0:
        raise ValueError(
            f"got {num_rows} samples, fewer than batch_size={batch_size}; no full batch can be inverted"
        )
    if num_batches * batch_size < num_rows:
        warnings.warn(
            f"{num_rows} samples is not a multiple of batch_size={batch_size}; "
            f"the last {num_rows - num_batches * batch_size} will be skipped",
            stacklevel=2,
        )

    latent_batches = []
    all_t_latents = []
    # A single no_grad context for the whole inversion instead of one per timestep.
    with torch.no_grad():
        for j in tqdm(range(num_batches), desc="Generating latents from samples"):
            xj = x[j * batch_size : (j + 1) * batch_size]
            step_latents = []
            for i in range(diffusion_pipeline.num_timesteps):
                xj = xj.to(device)
                t = torch.full((xj.shape[0],), i, dtype=torch.long, device=device)
                out = diffusion_pipeline.ddim_reverse_sample(
                    ddim_model,
                    xj,
                    t,
                    clip_denoised=True,
                )
                xj = out["sample"]
                if from_each_t:
                    step_latents.append(xj.cpu())
            if from_each_t:
                all_t_latents.append(torch.stack(step_latents))
            latent_batches.append(xj.cpu())

    latents = torch.cat(latent_batches)
    if from_each_t:
        return latents, torch.stack(all_t_latents)
    return latents
