from PIL import Image
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from diffusers import UniPCMultistepScheduler
import copy

def show_images(x):
    """Turn a batch of images in [-1, 1] into a single PIL grid image.

    Args:
        x: Image tensor batch accepted by ``torchvision.utils.make_grid``
           (presumably (B, C, H, W) with values in [-1, 1] — TODO confirm
           with callers).

    Returns:
        A ``PIL.Image`` of the tiled grid, values clipped to [0, 255] uint8.
    """
    # numpy / torchvision are only imported inside main() in this file, so
    # they must be imported here or this helper raises NameError.
    import numpy as np
    import torchvision

    x = x * 0.5 + 0.5  # map from (-1, 1) back to (0, 1)
    grid = torchvision.utils.make_grid(x)
    # (C, H, W) -> (H, W, C) for PIL, then scale to 0-255.
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    return Image.fromarray(grid_im.numpy().astype(np.uint8))

def make_grid(images, size=64):
    """Tile PIL images left to right into one horizontal strip.

    Every image is resized to ``size`` x ``size`` and pasted at consecutive
    ``size``-pixel offsets, so the result is ``len(images) * size`` wide and
    ``size`` tall (RGB).
    """
    canvas = Image.new("RGB", (len(images) * size, size))
    offset = 0
    for picture in images:
        canvas.paste(picture.resize((size, size)), (offset, 0))
        offset += size
    return canvas

class TestPipelineCu126:
    """Minimal hand-rolled Stable Diffusion v1-4 text-to-image pipeline.

    Loads the VAE, CLIP tokenizer and text encoder, UNet and a UniPC multistep
    scheduler from ``CompVis/stable-diffusion-v1-4``. The sampling stages are
    exposed as separate methods so callers can intervene between them:
    text embedding, initial noise, denoising and VAE decoding.
    """

    def __init__(self):
        model_id = "CompVis/stable-diffusion-v1-4"
        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
        self.tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
        self.unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

        self.scheduler = UniPCMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")

        # Fall back to CPU so the class can still be constructed without CUDA.
        self.torch_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.vae.to(self.torch_device)
        self.text_encoder.to(self.torch_device)
        self.unet.to(self.torch_device)

        self.prompt = ["a photograph of an astronaut riding a horse"]
        self.height = 512  # default height of Stable Diffusion
        self.width = 512  # default width of Stable Diffusion
        self.num_inference_steps = 25  # number of denoising steps
        # Classifier-free guidance scale. With a value of 1 the guided
        # prediction equals the text-conditioned prediction.
        self.guidance_scale = 1
        # Seeds torch's global RNG and returns the default CPU generator,
        # which is used to draw the initial latent noise.
        self.generator = torch.manual_seed(0)
        self.batch_size = len(self.prompt)  # number of images to generate

    def create_text_embeddings(self):
        """Encode ``self.prompt`` plus an empty prompt for classifier-free guidance.

        Sets ``self.text_embeddings`` to ``cat([unconditional, conditional])``
        along the batch axis. The denoising loops rely on that order when
        they call ``noise_pred.chunk(2)``. The tokenizer outputs, ``max_length``
        and ``uncond_embeddings`` are also stored on ``self``.
        """
        self.text_input = self.tokenizer(
            self.prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        self.max_length = self.text_input.input_ids.shape[-1]
        self.uncond_input = self.tokenizer(
            [""] * self.batch_size,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Both passes are inference-only. Without no_grad the unconditional
        # pass would record an autograd graph and keep its activations alive.
        with torch.no_grad():
            cond_embeddings = self.text_encoder(self.text_input.input_ids.to(self.torch_device))[0]
            self.uncond_embeddings = self.text_encoder(self.uncond_input.input_ids.to(self.torch_device))[0]
        self.text_embeddings = torch.cat([self.uncond_embeddings, cond_embeddings])

    def create_random_noise(self):
        """Draw initial latents of shape (batch, unet in_channels, height/8, width/8).

        The noise is sampled on CPU with ``self.generator`` for reproducibility,
        then moved to ``self.torch_device``.
        """
        latents = torch.randn(
            # config.in_channels: direct attribute access on the model is deprecated in diffusers.
            (self.batch_size, self.unet.config.in_channels, self.height // 8, self.width // 8),
            generator=self.generator,
        )
        return latents.to(self.torch_device)

    def _guided_noise_pred(self, scheduler, latents, t):
        """Return the classifier-free-guided UNet noise prediction at timestep ``t``."""
        # Duplicate the latents so the unconditional and conditional branches
        # run in a single UNet forward pass.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

        with torch.no_grad():
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=self.text_embeddings).sample

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        return noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

    def denoise_image(self, latents):
        """Run the full sampling loop starting from ``latents`` and return the final latents.

        Scales the input by ``init_noise_sigma`` first. Resets and advances
        ``self.scheduler`` in place.
        """
        from tqdm.auto import tqdm

        latents = latents * self.scheduler.init_noise_sigma
        print("shape of noise: ", latents.shape)

        scheduler = self.scheduler
        scheduler.set_timesteps(self.num_inference_steps)

        for t in tqdm(scheduler.timesteps):
            noise_pred = self._guided_noise_pred(scheduler, latents, t)
            # compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(noise_pred, t, latents).prev_sample
        return latents

    def denoise_image_conditional(self, latents, i_conditional, t_conditional):
        """Finish sampling from step index ``i_conditional`` onward.

        Steps whose index is below ``i_conditional`` are skipped. Unlike
        ``denoise_image``, the input is not scaled by ``init_noise_sigma``.
        The work runs on a deep copy of the scheduler, so ``self.scheduler``'s
        state is left untouched.

        ``t_conditional`` is unused. It is kept for interface compatibility.

        Returns:
            The latents after the last scheduler step.
        """
        from tqdm.auto import tqdm

        print("shape of conditional noise: ", latents.shape)

        # Deep copy so a caller iterating self.scheduler is not disturbed.
        scheduler = copy.deepcopy(self.scheduler)
        scheduler.set_timesteps(self.num_inference_steps)

        # max(..., 0): a negative start index meant "run every step" before.
        for t in tqdm(scheduler.timesteps[max(i_conditional, 0):]):
            noise_pred = self._guided_noise_pred(scheduler, latents, t)
            # compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        del scheduler
        torch.cuda.empty_cache()

        return latents

    def decode(self, latents):
        """Decode latents with the VAE into a uint8 NumPy batch (B, H, W, C) in [0, 255]."""
        # Undo the latent scaling applied at encode time. Read it from the VAE
        # config (0.18215 for SD v1.x) rather than hard-coding it.
        scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
        latents = latents / scaling_factor
        with torch.no_grad():
            image = self.vae.decode(latents).sample

        # VAE output is in [-1, 1]; map to [0, 1], channels-last, then 0-255.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        return (image * 255).round().astype("uint8")

    def img_to_pil(self, images):
        """Convert each array in ``images`` to a ``PIL.Image``."""
        return [Image.fromarray(image) for image in images]

def _save_pil(image, filename, hide_axis=True):
    """Render ``image`` with matplotlib and save the figure to ``filename``."""
    from matplotlib import pyplot as plt

    plt.figure()
    plt.imshow(image)
    if hide_axis:
        plt.axis("off")
    plt.savefig(filename)
    plt.close()


def main():
    """Sample a baseline image, then a CLIP-guided one, saving both as PNGs.

    Outputs go to ``outputs/test_pipeline_cu126/<timestamp>/``:

    - ``output0.png``: the plain text-to-image sample.
    - ``output_{i}_0.png``: the predicted final image at guidance step ``i``.
    - ``output1.png``: the CLIP-guided sample.
    """
    import datetime
    import os

    import torch.nn.functional as F
    import torchvision
    from tqdm.auto import tqdm

    device = "cuda" if torch.cuda.is_available() else "cpu"

    stable_diffusion_class = TestPipelineCu126()

    # Reuse the pipeline's UniPC scheduler, UNet and VAE.
    scheduler = stable_diffusion_class.scheduler
    unet = stable_diffusion_class.unet
    vae = stable_diffusion_class.vae

    # --- Baseline sample -------------------------------------------------
    stable_diffusion_class.create_text_embeddings()
    latents = stable_diffusion_class.create_random_noise()
    print("shape of noisy latents: ", latents.shape)
    latents = stable_diffusion_class.denoise_image(latents)
    print("shape of latents after denoised: ", latents.shape)
    images = stable_diffusion_class.img_to_pil(stable_diffusion_class.decode(latents))

    now_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    dirname = os.path.join("outputs/test_pipeline_cu126", now_str)
    os.makedirs(dirname, exist_ok=True)  # create the output directory
    _save_pil(images[0], os.path.join(dirname, "output0.png"), hide_axis=False)

    # --- CLIP model and guidance loss -------------------------------------
    import open_clip

    clip_model, _, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
    clip_model.to(device)

    # Random crop + augmentations, then CLIP's normalisation. Normalize's
    # mean/std assume pixel values in [0, 1].
    tfms = torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomResizedCrop(224),  # random crop each time
            torchvision.transforms.RandomAffine(5),  # small random skew
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )

    def clip_loss(image, text_features):
        """Mean squared great-circle distance between CLIP image and text embeddings.

        ``image`` is an (N, 3, H, W) float batch in [0, 1]. ``tfms`` is applied
        before encoding.
        """
        image_features = clip_model.encode_image(tfms(image))
        input_normed = F.normalize(image_features.unsqueeze(1), dim=2)
        embed_normed = F.normalize(text_features.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        return dists.mean()

    # --- CLIP-guided sampling ----------------------------------------------
    prompt = "an astronomer riding a unicorn"
    guidance_scale_FT = 0.001  # earlier experiments used 8
    n_cuts = 1  # earlier experiments used 4

    # Embed the prompt with CLIP as the guidance target.
    text = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=device == "cuda"):
        text_features = clip_model.encode_text(text)

    latents = stable_diffusion_class.create_random_noise()
    text_embeddings = stable_diffusion_class.text_embeddings
    # VAE latents must be multiplied by this factor to live in the UNet's
    # latent space (decode() divides by it).
    vae_scale = getattr(vae.config, "scaling_factor", 0.18215)

    scheduler.set_timesteps(stable_diffusion_class.num_inference_steps)

    for i, t in enumerate(tqdm(scheduler.timesteps)):
        # Classifier-free guidance: one UNet pass over [uncond, cond].
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + stable_diffusion_class.guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        # Predict the final image by finishing the run on a scheduler copy.
        # This is deterministic, so it is computed once per step, not once
        # per cut.
        latents_0 = stable_diffusion_class.denoise_image_conditional(latents, i, t)
        image_0_np = stable_diffusion_class.decode(latents_0)  # uint8 (B, H, W, C)
        image_0 = torch.tensor(image_0_np, dtype=torch.float32).to(device)
        print("image_0.shape:", image_0.shape)
        # (B, H, W, C) -> (B, C, H, W), rescaled to [0, 1] as CLIP expects.
        image_0 = image_0.permute(0, 3, 1, 2) / 255.0
        _save_pil(
            stable_diffusion_class.img_to_pil(image_0_np)[0],
            os.path.join(dirname, f"output_{i}_0.png"),
        )

        cond_grad = 0
        for cut in range(n_cuts):
            # Fresh leaf per cut; tfms applies a different random crop each time.
            pixels = image_0.detach().requires_grad_()
            loss = clip_loss(pixels, text_features) * guidance_scale_FT
            grad = torch.autograd.grad(loss, pixels)[0]
            print("grad[0].shape:", grad[0].shape)
            # Average over cuts; the minus sign descends the CLIP loss.
            cond_grad -= grad / n_cuts

        if i % 25 == 0:
            print("Step:", i, ", Guidance loss:", loss.item())

        print("cond_grad.shape:", cond_grad.shape)
        with torch.no_grad():
            # L2-normalise each pixel's channel vector to unit length. This
            # does NOT map values into [0, 1]. NOTE(review): it also discards
            # the gradient magnitude, so guidance_scale_FT barely affects the
            # update size.
            cond_grad = F.normalize(cond_grad, p=2, dim=1)
            # Move the pixel-space direction into the (scaled) latent space.
            cond_grad = vae.encode(cond_grad).latent_dist.sample() * vae_scale
        print("cond_grad.shape:", cond_grad.shape)
        print("latents.shape:", latents.shape)

        # Scale the update by sqrt(alpha_bar_t). alphas_cumprod is indexed by
        # the diffusion timestep t, not by the loop counter i.
        alpha_bar = scheduler.alphas_cumprod[int(t)]
        latents = latents.detach() + cond_grad * alpha_bar.sqrt()

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    images = stable_diffusion_class.img_to_pil(stable_diffusion_class.decode(latents))
    _save_pil(images[0], os.path.join(dirname, "output1.png"))

if '__main__' == __name__:
    # Run the experiment only when executed as a script, never on import.
    main()