import torch
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from torchvision.transforms import transforms as T

from tqdm import tqdm
import time

import os

import numpy as np

def denoise(mu, C, model_id, local_model, dataset, data_folder, batch_size, num_iterations, ddpm_rec=False, float16=False):
    """Noise test images at a level set by ``C / mu`` and reconstruct them with a diffusion model.

    For each of the first ``num_iterations`` batches of ``dataset``:

    - every image is noised at the first diffusion timestep whose ``sigma_t``
      exceeds ``(C / mu) * C * kappa``, where ``kappa`` comes from ``get_kappa``;
    - it is then denoised from that timestep with DDIM (default) or DDPM
      (``ddpm_rec=True``);
    - the clean, rescaled-noisy and denoised batches are written to
      ``output/mu_<mu:04d>_C_<int(C)>_batch_<i>.npz``.

    Args:
        mu: noise parameter; the target noise level scales as ``1 / mu``. Must be
            an ``int`` because it is formatted with ``:04d`` in the file name.
        C: scale constant; also the norm threshold passed to ``get_kappa``.
        model_id: diffusers model id (Hugging Face Hub) or local path.
        local_model: if True, load the UNet and scheduler from the ``unet`` and
            ``scheduler`` sub-folders of ``model_id``.
        dataset: one of ``"CIFAR10"``, ``"CIFAR100"``, ``"CelebA"``,
            ``"ImageNet"``, ``"CheXpert"``.
        data_folder: dataset root directory.
        batch_size: images per batch.
        num_iterations: maximum number of batches to process.
        ddpm_rec: use a ``DDPMScheduler`` and ``denoise_image_ddpm`` instead of
            a ``DDIMScheduler`` and ``denoise_image_ddim``.
        float16: run the model in float16. NOTE: this also changes the global
            torch default dtype via ``torch.set_default_dtype``.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(device)

    output_folder = "output"
    # np.savez_compressed does not create missing directories.
    os.makedirs(output_folder, exist_ok=True)

    if float16:
        torch.set_default_dtype(torch.float16)
        dtype = torch.float16
    else:
        dtype = torch.float32

    test_set = _build_test_set(dataset, data_folder)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    model, noise_scheduler = _load_model_and_scheduler(model_id, local_model, ddpm_rec, dtype)
    model.to(device)
    model.eval()

    config = noise_scheduler.config
    schedule = compute_schedule(config.beta_start, config.beta_end, config.num_train_timesteps)
    sigma_t = schedule["sigma_t"].to(device)

    n_T = config.num_train_timesteps
    sigma_val = C / mu

    # Honour ddpm_rec for the reconstruction itself, not just the scheduler choice.
    denoiser = denoise_image_ddpm if ddpm_rec else denoise_image_ddim

    total = min(num_iterations, len(test_loader))
    for i, (x, _) in tqdm(enumerate(test_loader), total=total):
        if i >= num_iterations:
            break
        start = time.time()
        print(f"------------\nBatch {i}, Sigma {sigma_val}, C {C}, Mu {mu}\n------------")
        # Match the model's dtype so float16 inference does not hit a dtype mismatch.
        x = x.to(device=device, dtype=dtype)
        kappa = get_kappa(x, C)

        # Per-sample target noise std; its length follows the actual batch, so a
        # short final batch does not break broadcasting.
        sigma = sigma_val * C * kappa
        T_start = _start_timesteps(sigma_t, sigma, n_T)

        noise = torch.randn_like(x)
        x_noisy = noise_scheduler.add_noise(x, noise, T_start)
        # Since 1 + sigma_t**2 == 1 / alphabar_t, this rescales the noisy image by
        # 1 / sqrt(alphabar_t), putting it back on the scale of x for saving.
        x_noisy_scaled = (1 + sigma_t[T_start] ** 2).sqrt().view(-1, 1, 1, 1) * x_noisy

        image_denoised = denoiser(model, x_noisy, schedule, T_start)

        # Save images as numpy arrays to disk
        np.savez_compressed(
            os.path.join(output_folder, f"mu_{mu:04d}_C_{int(C)}_batch_{i}.npz"),
            img=x.cpu().numpy(),
            noisy=x_noisy_scaled.cpu().numpy(),
            denoised=image_denoised.cpu().numpy(),
        )

        end = time.time()
        print(f"Finished batch after {(end - start):.4f} seconds")


def _build_test_set(dataset, data_folder):
    """Return the test split of ``dataset`` with images normalised from [0, 1] to [-1, 1]."""
    normalize = T.Normalize((0.5, 0.5, 0.5,), (0.5, 0.5, 0.5,))
    if dataset in ("ImageNet", "CheXpert"):
        # Large, variable-size images are resized and center-cropped to 256x256.
        transforms = T.Compose([
            T.ToTensor(),
            T.Resize(size=256, antialias=True),
            T.CenterCrop(256),
            normalize,
        ])
    else:
        transforms = T.Compose([T.ToTensor(), normalize])

    if dataset == "CIFAR10":
        return CIFAR10(root=data_folder, train=False, transform=transforms, download=True)
    if dataset == "CIFAR100":
        return CIFAR100(root=data_folder, train=False, transform=transforms, download=True)
    if dataset in ("CelebA", "ImageNet", "CheXpert"):
        return ImageFolder(root=data_folder, transform=transforms)
    raise ValueError(f"Unsupported dataset: {dataset!r}")


def _load_model_and_scheduler(model_id, local_model, ddpm_rec, dtype):
    """Load the UNet and a DDPM or DDIM noise scheduler for ``model_id``."""
    # Local checkpoints keep the UNet and scheduler configs in sub-folders.
    unet_kwargs = {"subfolder": "unet"} if local_model else {}
    scheduler_kwargs = {"subfolder": "scheduler"} if local_model else {}
    model = UNet2DModel.from_pretrained(model_id, torch_dtype=dtype, **unet_kwargs)
    scheduler_cls = DDPMScheduler if ddpm_rec else DDIMScheduler
    noise_scheduler = scheduler_cls.from_pretrained(model_id, **scheduler_kwargs)
    return model, noise_scheduler


def _start_timesteps(sigma_t, sigma, n_T):
    """Return, per sample, the first timestep in ``[0, n_T)`` whose ``sigma_t`` exceeds ``sigma``.

    ``sigma_t`` from ``compute_schedule`` has ``n_T + 1`` entries, but the
    scheduler only defines timesteps ``0 .. n_T - 1``, so the search is limited
    to those. Samples whose ``sigma`` exceeds every candidate are clamped to the
    last timestep ``n_T - 1`` instead of falling back to 0.
    """
    exceeds = sigma_t[:n_T].unsqueeze(0) > sigma.view(-1, 1)
    # argmax returns the first maximal index, i.e. the first True in each row.
    first = exceeds.int().argmax(dim=-1)
    return torch.where(exceeds.any(dim=-1), first, torch.full_like(first, n_T - 1))

def get_kappa(x, C):
    """Return the per-sample factor ``max(||x_b||_2 / C, 1)``.

    Each sample along dim 0 is flattened and its L2 norm is divided by ``C``.
    Ratios below 1 are raised to 1, so every entry of the result is >= 1.
    """
    per_sample = x.view(x.shape[0], -1)
    ratio = per_sample.norm(dim=-1) / C
    return ratio.clamp(min=1.0)

@torch.no_grad()
def denoise_image_ddim(model, x_noisy, schedule, T_start):
    """Deterministically denoise a batch with DDIM (eta = 0) from per-sample start steps.

    Sample ``b`` of ``x_noisy`` is assumed to be at diffusion step
    ``T_start[b]``, i.e. ``sqrt(alphabar_t[t]) * x0 + sqrt(1 - alphabar_t[t]) * eps``.
    Steps ``t = max(T_start), ..., 1`` are run, and at each step only samples
    with ``T_start >= t`` are updated. Every sample therefore moves from its
    own start step down to step 0, the same end point as ``denoise_image_ddpm``.
    Samples with ``T_start == 0`` are returned unchanged.

    Args:
        model: callable ``model(x, t)`` returning an object whose ``.sample``
            attribute is the predicted noise (e.g. a diffusers ``UNet2DModel``).
        x_noisy: noisy batch, indexed along dim 0 by ``T_start``. Not modified.
        schedule: dict from ``compute_schedule``; only ``"alphabar_t"`` is read.
        T_start: 1-D integer tensor of start timesteps, one per sample.

    Returns:
        A new tensor with the shape of ``x_noisy`` holding the denoised batch.
    """
    alphabar_t = schedule["alphabar_t"]

    # Work on a copy so the caller's noisy batch is not overwritten in place.
    x_i = x_noisy.clone()
    if T_start.numel() == 0:
        return x_i

    max_T = int(T_start.max().item())

    # Exclusive bound 0: the last update (i = 1) lands on step 0.
    for i in range(max_T, 0, -1):
        active = T_start >= i
        x_i_part = x_i[active]

        t = torch.full((x_i_part.shape[0],), i, device=x_i_part.device, dtype=torch.long)
        eps = model(x_i_part, t).sample

        ab_t = alphabar_t[i]
        ab_prev = alphabar_t[i - 1]
        # Predicted clean image x_0 from the epsilon prediction.
        x0_t = (x_i_part - eps * (1 - ab_t).sqrt()) / ab_t.sqrt()
        # Deterministic DDIM step (eta = 0) to timestep i - 1.
        x_i[active] = ab_prev.sqrt() * x0_t + (1 - ab_prev).sqrt() * eps

    return x_i

@torch.no_grad()
def denoise_image_ddpm(model, x_noisy, schedule, T_start):
    """Denoise a batch with ancestral DDPM sampling from per-sample start steps.

    Steps ``t = max(T_start), ..., 1`` are run, and at each step only samples
    with ``T_start >= t`` are updated. Fresh Gaussian noise is injected on
    every step except the last (``t == 1``). Samples with ``T_start == 0`` are
    returned unchanged.

    Args:
        model: callable ``model(x, t)`` returning an object whose ``.sample``
            attribute is the predicted noise (e.g. a diffusers ``UNet2DModel``).
        x_noisy: noisy batch, indexed along dim 0 by ``T_start``. Not modified.
        schedule: dict from ``compute_schedule``; reads ``"sqrt_beta_t"``,
            ``"oneover_sqrta"`` and ``"mab_over_sqrtmab"``.
        T_start: 1-D integer tensor of start timesteps, one per sample.

    Returns:
        A new tensor with the shape of ``x_noisy`` holding the denoised batch.
    """
    sqrt_beta_t = schedule["sqrt_beta_t"]
    oneover_sqrta = schedule["oneover_sqrta"]
    mab_over_sqrtmab = schedule["mab_over_sqrtmab"]

    # Work on a copy so the caller's noisy batch is not overwritten in place.
    x_i = x_noisy.clone()
    if T_start.numel() == 0:
        return x_i

    max_T = int(T_start.max().item())

    for i in range(max_T, 0, -1):
        active = T_start >= i
        x_i_part = x_i[active]

        # The final step (i == 1) returns the posterior mean without added noise.
        z = torch.randn_like(x_i_part) if i > 1 else 0
        t = torch.full((x_i_part.shape[0],), i, device=x_i_part.device, dtype=torch.long)
        eps = model(x_i_part, t).sample

        x_i[active] = (
            oneover_sqrta[i] * (x_i_part - eps * mab_over_sqrtmab[i])
            + sqrt_beta_t[i] * z
        )

    return x_i

def compute_schedule(beta1: float, beta2: float, T: int):
    """Precompute the linear-beta diffusion schedule used for noising and sampling.

    Betas are spaced linearly from ``beta1`` (index 0) to ``beta2`` (index
    ``T``). Every returned tensor is float32 and has ``T + 1`` entries.

    Args:
        beta1: first beta; must be smaller than ``beta2``.
        beta2: last beta; must be smaller than 1.
        T: number of diffusion steps.

    Returns:
        dict with, per timestep ``t``:
            "sigma_t": sqrt(1 / alphabar_t - 1)
            "beta_t": beta_t
            "alpha_t": 1 - beta_t
            "oneover_sqrta": 1 / sqrt(alpha_t)
            "sqrt_beta_t": sqrt(beta_t)
            "alphabar_t": prod_{s <= t} alpha_s
            "sqrtab": sqrt(alphabar_t)
            "sqrtmab": sqrt(1 - alphabar_t)
            "mab_over_sqrtmab": (1 - alpha_t) / sqrt(1 - alphabar_t)
    """
    # NOTE(review): only beta1 < beta2 < 1 is checked; the lower bound 0 in the
    # message is not enforced, and the check is skipped under ``python -O``.
    assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"

    steps = torch.arange(0, T + 1, dtype=torch.float32)
    betas = (beta2 - beta1) * steps / T + beta1
    alphas = 1 - betas
    # Cumulative product via log-space sum: prod(alpha_s) = exp(sum(log alpha_s)).
    alpha_bars = torch.cumsum(torch.log(alphas), dim=0).exp()
    root_one_minus_ab = torch.sqrt(1 - alpha_bars)

    return {
        "sigma_t": ((1 / alpha_bars) - 1).sqrt(),
        "beta_t": betas,
        "alpha_t": alphas,
        "oneover_sqrta": 1 / torch.sqrt(alphas),
        "sqrt_beta_t": torch.sqrt(betas),
        "alphabar_t": alpha_bars,
        "sqrtab": torch.sqrt(alpha_bars),
        "sqrtmab": root_one_minus_ab,
        "mab_over_sqrtmab": (1 - alphas) / root_one_minus_ab,
    }


if __name__ == "__main__":
    batch_size = 200
    # Batches per mu: 5000 // 200 = 25 batches, i.e. up to 5000 test images.
    num_iterations = 5000 // batch_size

    C = 1.0
    # Keep these as ints: denoise() formats mu with ":04d" in output file names.
    mus = [100, 50, 30, 20, 10, 5, 3, 2, 1]

    # Select Diffusers model
    model_id = "google/ddpm-ema-celebahq-256"  # "/local_models/ddpm_ema_cifar10", "google/ddpm-ema-celebahq-256", "xutongda/adm_imagenet_256x256_unconditional"
    local_model = False  # True for a local folder with unet/ and scheduler/ sub-folders, False for a Hugging Face id

    # Select test data
    dataset = "CelebA"  # "CIFAR10", "CIFAR100", "CelebA", "ImageNet", "CheXpert"
    data_folder = "data/celebaHQ_256"  # "data/CIFAR10", "data/CIFAR100", "data/ImageNet1k", "data/chexpert", "data/celebaHQ_256"

    ddpm_rec = False  # Forwarded to denoise(); selects the DDPM variant instead of DDIM
    float16 = False  # Forwarded to denoise(); run the model in float16

    for mu in mus:
        denoise(
            mu=mu,
            C=C,
            model_id=model_id,
            local_model=local_model,
            dataset=dataset,
            data_folder=data_folder,
            batch_size=batch_size,
            num_iterations=num_iterations,
            ddpm_rec=ddpm_rec,
            float16=float16,
        )