from diffusers import UNet2DModel_H, DDPMScheduler, DDIMScheduler, VQModel, DDIMInverseScheduler, DDPMPipeline_H
from torchvision.utils import save_image, make_grid
from torchvision.io import read_image
from ddpm_unrolled import slerp, image_process, parse_args
from torch import lerp
from eval import find_seeds_2d
from utils import local_basis

import torch
import PIL.Image
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import argparse, ast

args = parse_args()

# --- Experiment configuration (taken from the command line) ---------------
# Seeds iterated by every stage below. Fall back to the default range the
# script previously hard-coded when no seeds are supplied.
seeds = args.seeds if args.seeds is not None else range(100, 200)
noises = []  # initial noise tensors collected by the image_save stage
h_features = {seed: [] for seed in seeds}  # per-seed h-features, appended once per denoising step
interp_cond = args.interp_cond  # run the 2-D local-basis grid stage
find_seeds = args.find_seeds  # run the find_seeds_2d stage
image_save = args.image_save  # sample and save an image pair per seed
path = args.save_path  # output directory; file names are appended as path + "/..."
step = 0.5  # grid spacing for the interpolation stage (offsets span [-1, 1])
N = 10  # replaced by len(seeds) when find_seeds is set

print(f"Saving path= {path}")
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# Alternative checkpoints tried previously:
#   UNet2DModel_H.from_pretrained("anton-l/ddpm-ema-flowers-64", subfolder='unet')
#   VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
unet = UNet2DModel_H.from_pretrained(args.unet_path).to(torch_device)
scheduler = DDIMScheduler.from_pretrained("google/ddpm-celebahq-256")
# NOTE(review): inverse_scheduler is not used by any stage in this script.
inverse_scheduler = DDIMInverseScheduler.from_pretrained("google/ddpm-celebahq-256")
scheduler.set_timesteps(num_inference_steps=20)

if find_seeds:
    # Draw one sample per requested seed through the full sampling pipeline.
    N = len(seeds)
    pipeline = DDPMPipeline_H(unet=unet, scheduler=scheduler)
    pipeline.set_progress_bar_config(disable=True)
    sampling_shape = [
        3,
        unet.config.in_channels,
        unet.sample_size,
        unet.sample_size,
    ]
    # NOTE(review): N (the number of seeds) is also passed as num_inference_steps;
    # confirm this is intended rather than the scheduler's 20 inference steps.
    find_seeds_2d(
        sampler=pipeline,
        sampling_shape=sampling_shape,
        n_samples=N,
        n_gpus=1,
        num_inference_steps=N,
        device=torch_device,
    )

if image_save:
    for seed in seeds:
        # Seed the generator on the same device as the noise so the stage
        # also runs when torch_device fell back to "cpu".
        generator = torch.Generator(device=torch_device)
        generator.manual_seed(seed)
        noise = torch.randn(
            (2, unet.config.in_channels, unet.sample_size, unet.sample_size),
            device=torch_device, generator=generator,
        )
        noises.append(noise)

        # Deterministic DDIM sampling (eta=0); record the h-feature at every step.
        image = noise
        for t in scheduler.timesteps:
            with torch.no_grad():
                # One forward pass yields both the prediction ([0]) and the h-feature ([1]).
                model_output = unet(image, t)
                residual = model_output[0]["sample"]
                h_features[seed].append(model_output[1])
            image = scheduler.step(residual, t, image, eta=0.0)["prev_sample"]

        # Map [-1, 1] -> uint8 [0, 255] in NHWC layout for PIL.
        image_processed = image.cpu().permute(0, 2, 3, 1)
        image_processed = (image_processed + 1.0) * 127.5
        image_processed = image_processed.clamp(0, 255).numpy().astype(np.uint8)
        PIL.Image.fromarray(image_processed[0]).save(path + f"/image_{seed}.png")
        PIL.Image.fromarray(image_processed[1]).save(path + f"/image_{seed}_next.png")


if interp_cond:
    # Offsets per axis, evenly spaced over [-1, 1] with spacing `step`.
    # linspace yields exactly n_per_axis points, whereas a float-step arange can
    # gain or lose its endpoint. The count is reused as make_grid's nrow so the
    # saved grid is always square.
    n_per_axis = 2 * round(1 / step) + 1
    offsets = torch.linspace(-1, 1, n_per_axis, device=torch_device)
    a = 3  # scale applied to both basis directions
    idx_x = 20  # local-basis direction varied across each grid row (inner loop)
    idx_y = 3  # local-basis direction varied down the grid (outer loop)

    for seed in seeds:
        # Generator on torch_device so the stage also runs without CUDA.
        generator = torch.Generator(device=torch_device)
        generator.manual_seed(seed)
        z0 = torch.randn(
            (1, unet.config.in_channels, unet.sample_size, unet.sample_size),
            device=torch_device, generator=generator,
        )

        # NOTE(review): assumes x_basis stacks basis vectors along its last axis
        # and s holds the matching scales; confirm against utils.local_basis.
        _, _, s, x_basis = local_basis(unet, pooling_kernel=4, sample=z0)
        dx = x_basis[:, :, :, idx_x] / s[idx_x]
        dy = x_basis[:, :, :, idx_y] / s[idx_y]

        # Row-major grid of perturbed latents: v selects the row, w the column.
        grid_latents = torch.cat(
            [z0 + a * v * dy + a * w * dx for v in offsets for w in offsets],
            dim=0,
        )

        # Denoise the whole grid as one batch with deterministic DDIM (eta=0).
        for t in scheduler.timesteps:
            with torch.no_grad():
                model_pred = unet(grid_latents, t)[0]["sample"]
            grid_latents = scheduler.step(model_pred, t, grid_latents, eta=0.0)["prev_sample"]

        # Map [-1, 1] -> [0, 1] for save_image.
        grid_x_slerp = make_grid(grid_latents / 2 + 0.5, nrow=n_per_axis, padding=0)
        save_image(grid_x_slerp, path + f"/2d_grid_x_slerp_{seed}->{seed}.png")
        print(f'Saved {seed}->{seed}.png to {args.save_path}.')

print("Done!")