import argparse
import ast
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import torch
import tqdm
from diffusers import UNet2DModel_H, DDPMScheduler, DDIMScheduler, VQModel, DDIMInverseScheduler
from torchvision.io import read_image
from torchvision.utils import save_image, make_grid

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Spherically interpolate between two arrays ``v0`` and ``v1``.

    Args:
        t: Interpolation weight. ``t=0`` gives ``v0`` and ``t=1`` gives ``v1``;
            other values interpolate/extrapolate along the arc. May be a Python
            number, a NumPy scalar/array, or a torch tensor (on any device).
        v0, v1: Arrays of the same shape, both NumPy arrays or both torch
            tensors. The angle between them is measured over all elements,
            i.e. they are treated as flat vectors.
        DOT_THRESHOLD: If ``|cos(angle)|`` exceeds this, the vectors are
            (nearly) colinear and plain linear interpolation is used instead,
            avoiding a division by ``sin(angle) ~ 0``.

    Returns:
        The interpolated array. For torch inputs, a tensor on ``v0``'s device
        with ``v0``'s dtype; otherwise a NumPy array.
    """
    inputs_are_torch = isinstance(v0, torch.Tensor)
    if inputs_are_torch:
        input_device = v0.device
        input_dtype = v0.dtype
        v0 = v0.detach().cpu().numpy()
        v1 = v1.detach().cpu().numpy()
    # The weight may be a plain float even when v0/v1 are tensors
    # (e.g. slerp(0.5, a, b)), so only convert it when it is a tensor.
    if isinstance(t, torch.Tensor):
        t = t.detach().cpu().numpy()

    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        # Cast back so a higher-precision weight cannot silently upcast the
        # result (the UNet expects the same dtype it was given).
        v2 = torch.from_numpy(np.asarray(v2)).to(device=input_device, dtype=input_dtype)

    return v2

def image_process(image):
    """Convert a batch of images in [-1, 1] to channels-last ``uint8`` pixels.

    Axis 1 is moved to the end (presumably NCHW -> NHWC; confirm with the
    caller). Values are mapped from [-1, 1] to [0, 255], clamped, and cast,
    with truncation, to ``torch.uint8``. The result stays on the input's device.
    """
    channels_last = image.permute(0, 2, 3, 1)
    scaled = 127.5 * (channels_last + 1.0)
    return scaled.clamp(min=0, max=255).to(torch.uint8)

def arg_as_list(s):
    """argparse ``type=`` converter: parse ``s`` as a Python literal list.

    Examples: ``"[1, 2]"`` -> ``[1, 2]``; ``"(1, 2)"`` -> ``[1, 2]``; a bare
    scalar such as ``"42"`` is wrapped as ``[42]``, so a single value can be
    passed without brackets. Only Python literals are accepted
    (``ast.literal_eval``), never arbitrary expressions.

    Raises:
        argparse.ArgumentTypeError: if ``s`` is not a valid Python literal.
            argparse does not catch the ``SyntaxError`` that ``literal_eval``
            can raise, so it is translated here to get a clean usage error
            instead of a traceback.
    """
    try:
        value = ast.literal_eval(s)
    except (ValueError, TypeError, SyntaxError, RecursionError) as err:
        raise argparse.ArgumentTypeError(f"invalid list literal {s!r}: {err}") from err
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]

def parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``, for programmatic use or testing.

    Returns:
        argparse.Namespace with ``save_path``, ``unet_path``, ``seeds``,
        ``step``, ``interp_cond``, ``image_save`` and ``find_seeds``.
        ``find_seeds`` is parsed but not read by ``main`` in this file.

    Exits via ``parser.error`` (status 2) if ``--step`` is not positive,
    because the interpolation sweep divides by and steps by it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="output/default/",
        help="Directory where generated images and interpolation grids are written.",
    )
    parser.add_argument(
        "--unet_path",
        type=str,
        default="google/ddpm-celebahq-256",
        help="Hub id or local path of the pretrained UNet to sample with.",
    )
    parser.add_argument(
        "--seeds",
        type=arg_as_list,
        default=[],
        help="Random seeds as a Python list literal, e.g. '[0, 1, 2]'.",
    )
    parser.add_argument(
        "--step",
        type=float,
        default=0.1,
        help="Spacing between interpolation weights (must be > 0).",
    )
    parser.add_argument("--interp_cond", default=False, action="store_true", help="interpolation")
    parser.add_argument("--image_save", default=False, action="store_true", help="save image")
    parser.add_argument("--find_seeds", default=False, action="store_true", help="find unstable seeds")

    args = parser.parse_args(argv)
    if args.step <= 0:
        parser.error("--step must be positive")

    return args
##############################################################################

def _seeded_noise(seed, unet, device):
    """Return a reproducible batch of two noise images for ``seed``.

    Shape is (2, unet.config.in_channels, unet.sample_size, unet.sample_size).
    The generator is created on ``device``, the same device as the tensor, so
    sampling also works on CPU-only machines.
    """
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)
    return torch.randn(
        (2, unet.config.in_channels, unet.sample_size, unet.sample_size),
        device=device,
        generator=generator,
    )


def _ddim_sample(unet, scheduler, image, h_features=None):
    """Denoise ``image`` through every timestep of ``scheduler`` with eta=0.

    The UNet is called once per timestep. ``out[0]["sample"]`` is fed to
    ``scheduler.step``. When ``h_features`` is a list, the UNet's second
    output ``out[1]`` (the "h feature") is appended to it.

    Returns the sample produced by the last scheduler step.
    """
    with torch.no_grad():
        for t in scheduler.timesteps:
            out = unet(image, t)
            if h_features is not None:
                h_features.append(out[1])
            image = scheduler.step(out[0]["sample"], t, image, eta=0.0)["prev_sample"]
    return image


def _save_pair(image, path, seed):
    """Write the two samples of ``image`` as image_{seed}.png and image_{seed}_next.png."""
    pixels = image_process(image).cpu().numpy()
    PIL.Image.fromarray(pixels[0]).save(os.path.join(path, f"image_{seed}.png"))
    PIL.Image.fromarray(pixels[1]).save(os.path.join(path, f"image_{seed}_next.png"))


def main():
    """Sample images and/or noise-space slerp sweeps for each ``--seeds`` entry.

    * ``--image_save``: for every seed, draw two noise images, run the DDIM
      loop (20 steps, eta=0), and write ``image_{seed}.png`` and
      ``image_{seed}_next.png`` to ``--save_path``.
    * ``--interp_cond``: for every seed, slerp between the two noise images
      at weights starting at ``1 - 10*step`` and advancing by ``step`` up to
      about 1. Each interpolated noise is denoised, and all results are saved
      as one row in ``x_slerp_{seed}->{seed}.png``.
    """
    args = parse_args()

    seeds = args.seeds
    path = args.save_path
    # NOTE(review): `noises` and `h_features` are filled during --image_save
    # for latent / h-space interpolation experiments but are not consumed in
    # this script. h_features keeps one UNet feature per timestep per seed
    # (presumably on the GPU), so it grows with len(seeds).
    noises = []
    h_features = {seed: [] for seed in seeds}

    print(f"Saving path= {path}")
    if path and (args.image_save or args.interp_cond):
        # Saving into a missing directory would fail.
        os.makedirs(path, exist_ok=True)

    unet = UNet2DModel_H.from_pretrained(args.unet_path)
    # The sampling schedule always comes from this config, regardless of
    # --unet_path.
    scheduler = DDIMScheduler.from_pretrained("google/ddpm-celebahq-256")
    scheduler.set_timesteps(num_inference_steps=20)

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    unet.to(torch_device)

    if args.image_save:
        for seed in seeds:
            noise = _seeded_noise(seed, unet, torch_device)
            noises.append(noise)
            image = _ddim_sample(unet, scheduler, noise, h_features=h_features[seed])
            _save_pair(image, path, seed)

    if args.interp_cond:
        step = args.step
        for seed in seeds:
            noise = _seeded_noise(seed, unet, torch_device)
            z0, z1 = noise[0].unsqueeze(0), noise[1].unsqueeze(0)

            weights = torch.arange(1 - 10 * step + 1e-4, 1 + step, step, device=torch_device)
            images_x_slerp = []
            for w in weights:
                image_x_slerp = _ddim_sample(unet, scheduler, slerp(w, z0, z1))
                images_x_slerp.append(image_x_slerp / 2 + 0.5)

            # Lay out every interpolation step in a single row, whatever --step is.
            grid_x_slerp = make_grid(
                torch.cat(images_x_slerp, 0), nrow=len(images_x_slerp), padding=0
            )
            save_image(grid_x_slerp, os.path.join(path, f"x_slerp_{seed}->{seed}.png"))
            print(f'Saved {seed}->{seed}.png to {args.save_path}.')

    print("Done!")


if __name__ == '__main__':
    main()