from PIL import Image
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from diffusers import UniPCMultistepScheduler
import copy

import wandb

import sys
sys.path.append("/home//work/doob_apps/hug")

from src.models.SD_model import DiffusionPipeline
from src.models.SD_preference import preference_loss

def main():
    """Sample images, score them with a CLIP preference loss, and save the results.

    For each seed ``i`` in ``range(num_iterations)`` this:
      * rebuilds the text embeddings and draws seeded random latents through
        ``DiffusionPipeline``,
      * denoises and decodes the latents into images,
      * scores every image in the batch with ``preference_loss.clip_loss`` and
        writes a matplotlib figure titled with the loss to
        ``<dirname>/<timestamp>/output_<i>_<j>.png``.

    The final latents of every saved sample are stacked into one tensor of
    shape ``(num_saved, *latents[j].shape)`` and written to ``latents.pth`` in
    the same directory. The run configuration is logged to Weights & Biases.
    The wandb run is always finished, even if sampling raises.
    """
    import datetime
    import os

    from matplotlib import pyplot as plt

    device = "cuda" if torch.cuda.is_available() else "cpu"

    prompt = ["a picture of a planet"]  # @param
    prompt_preference = "saturn"  # @param

    batch_size = 2
    # One copy of the prompt per image in the batch.
    prompt = prompt * batch_size
    stable_diffusion_class = DiffusionPipeline(prompt)
    preference_loss_class = preference_loss(prompt_preference)

    dirname = "hug/outputs/SD_sample_ref"
    num_iterations = 10

    now = datetime.datetime.now()
    now_str = now.strftime("%Y%m%d_%H%M%S")
    # Each run writes into its own timestamped subdirectory.
    dirname = os.path.join(dirname, now_str)
    os.makedirs(dirname, exist_ok=True)

    # Final latents of every saved sample; stacked and written to latents.pth.
    latents_save = []

    wandb.init(
        project="stable-diffusion-doob-ref",
        config={
            "prompt": prompt,
            "prompt_preference": prompt_preference,
            "batch_size": batch_size,
            "num_iterations": num_iterations,
            "now": now_str,
        },
    )
    try:
        for i in range(num_iterations):
            stable_diffusion_class.create_text_embeddings()
            # Reseed per iteration so that iteration i is reproducible.
            stable_diffusion_class.generator = torch.manual_seed(i)
            latents = stable_diffusion_class.create_random_noise()
            print("shape of noisy latents: ", latents.shape)
            latents = stable_diffusion_class.denoise_image(latents)
            print("shape of latents after denoised: ", latents.shape)
            images_decoded = stable_diffusion_class.decode(latents)
            images = stable_diffusion_class.img_to_pil(images_decoded)

            # Build the input for the loss.
            # NOTE(review): assumes decode() returns channels-last
            # (batch, H, W, 3) images, as implied by the permute below; confirm.
            images_tensor = torch.as_tensor(
                images_decoded, dtype=torch.float32, device=device
            )
            images_tensor = images_tensor.permute(0, 3, 1, 2)  # -> (batch, 3, H, W)
            # Insert an axis and repeat each image 4x along it:
            # (batch, 4, 3, H, W). This is resolution-agnostic, unlike a
            # hard-coded 512x512 view.
            images_tensor = images_tensor.unsqueeze(1).repeat(1, 4, 1, 1, 1)

            # Save every 100th iteration plus the first 10.
            # With num_iterations = 10, that is every iteration.
            if i % 100 == 0 or i < 10:
                for j in range(batch_size):
                    filename = os.path.join(dirname, f"output_{i}_{j}.png")
                    img_j = images_tensor[j]
                    print("img_j.shape: ", img_j.shape)
                    loss_value = preference_loss_class.clip_loss(img_j).item()

                    # Plot the image with its loss in the title.
                    plt.figure()
                    plt.title(f"Loss: {loss_value}")
                    plt.imshow(images[j])
                    plt.savefig(filename)
                    plt.close()

                    # Keep this sample's latents, detached and on CPU, so the
                    # saved file can be loaded without a GPU.
                    latents_save.append(latents[j].detach().cpu())
                    print("latents.shape: ", latents.shape)
                    # wandb.log({"image": [wandb.Image(images[j], caption=f"Loss: {loss_value}")]})

        if latents_save:
            # Stack (not cat) so each sample keeps its own leading index;
            # cat along dim 0 would merge the samples' first axes together.
            torch.save(
                torch.stack(latents_save, dim=0),
                os.path.join(dirname, "latents.pth"),
            )
        else:
            print("No latents were saved; skipping latents.pth")
    finally:
        wandb.finish()

if __name__ == "__main__":
    # Run main() only when this file is executed directly, not when it is imported.
    main()