import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
import torchvision
from datasets import load_dataset
from torchvision import transforms

import matplotlib.pyplot as plt

import base64
from io import BytesIO

import datetime

import os

from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline

from torch.autograd.functional import jacobian
from concurrent.futures import ThreadPoolExecutor

# Side length (in pixels) that every training image is resized to by
# `preprocess`; the UNet built in `main` also uses it as its `sample_size`.
# (The batch size is configured separately inside `main`.)
image_size = 32

def show_images(x):
    """Tile a batch of images into one grid and return it as a PIL image.

    Args:
        x: Image batch tensor accepted by ``torchvision.utils.make_grid``
            (presumably ``(B, C, H, W)`` -- TODO confirm) whose values lie
            in ``[-1, 1]``.

    Returns:
        A ``PIL.Image.Image`` of the grid. Values are mapped back to
        ``[0, 1]``, clipped, scaled to ``[0, 255]`` and truncated to
        ``uint8``.
    """
    # Undo the (-1, 1) normalisation applied during preprocessing.
    rescaled = x * 0.5 + 0.5
    tiled = torchvision.utils.make_grid(rescaled)
    # Move channels last so the array has the (H, W, C) layout PIL expects.
    channels_last = tiled.detach().cpu().permute(1, 2, 0)
    pixel_values = channels_last.clip(0, 1) * 255
    return Image.fromarray(np.array(pixel_values).astype(np.uint8))

def make_grid(images, size=64):
    """Paste PIL images side by side into a single horizontal strip.

    Args:
        images: Sized sequence of PIL images; each one is resized to
            ``size`` x ``size`` before being pasted.
        size: Edge length in pixels of every tile (and the strip's height).

    Returns:
        A new RGB ``PIL.Image.Image`` of width ``size * len(images)`` and
        height ``size``.
    """
    strip = Image.new("RGB", (len(images) * size, size))
    x_offset = 0
    for picture in images:
        strip.paste(picture.resize((size, size)), (x_offset, 0))
        x_offset += size
    return strip

# Preprocessing / augmentation pipeline applied to every training image:
#   1. resize to (image_size, image_size);
#   2. random horizontal flip (data augmentation, torchvision default p=0.5);
#   3. PIL image -> float tensor with values in [0, 1];
#   4. normalise with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5, giving [-1, 1].
# The single-element mean/std is broadcast over every channel, so the same
# normalisation applies to all three RGB channels.
preprocess = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Batch transform registered with `Dataset.set_transform` in `main`.
def transform(examples):
    """Convert a batch of dataset rows into preprocessed image tensors.

    Args:
        examples: Batch dict from the Hugging Face dataset. Its ``"image"``
            entry is iterated, and each element must support
            ``.convert("RGB")`` (i.e. a PIL image).

    Returns:
        A dict with the single key ``"images"``: one tensor per input image,
        produced by ``preprocess``.
    """
    # Force 3-channel RGB so every tensor has the same channel count
    # (the UNet in `main` is built with in_channels=3).
    return {"images": [preprocess(image.convert("RGB")) for image in examples["image"]]}

def _make_output_dir():
    """Create (if needed) and return a timestamped directory for this run's figures."""
    # The timestamp is appended directly to "tests" (no separator), e.g.
    # hug/outputs/tests20240101_120000.
    now = datetime.datetime.now()
    dirname = "hug/outputs/tests" + now.strftime("%Y%m%d_%H%M%S")
    os.makedirs(dirname, exist_ok=True)  # do not fail if it already exists
    return dirname


def _save_image(img, path, figsize=None):
    """Draw ``img`` on a fresh, axis-less figure, save it to ``path`` and close it."""
    # A new figure per image prevents drawing on top of whatever plot was
    # current (previously noisy.png was overlaid on the noise-schedule plot).
    fig = plt.figure(figsize=figsize)
    plt.imshow(img)
    plt.axis("off")
    fig.savefig(path)
    plt.close(fig)


def _save_noise_schedule(scheduler, path):
    """Plot sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t) of ``scheduler`` and save to ``path``."""
    alphas_cumprod = scheduler.alphas_cumprod.cpu()
    fig = plt.figure(figsize=(12, 6))
    plt.plot(alphas_cumprod ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
    plt.plot((1 - alphas_cumprod) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
    plt.legend(fontsize="x-large")
    fig.savefig(path)
    plt.close(fig)


def _build_unet():
    """Construct the (untrained) UNet2DModel used for the shape check in ``main``."""
    return UNet2DModel(
        sample_size=image_size,  # the target image resolution
        in_channels=3,  # the number of input channels, 3 for RGB images
        out_channels=3,  # the number of output channels
        layers_per_block=2,  # how many ResNet layers to use per UNet block
        block_out_channels=(64, 128, 128, 256),  # more channels -> more parameters
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D",
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D",  # a regular ResNet upsampling block
        ),
    )


@torch.no_grad()
def _sample_from_pipeline(pipeline, batch_size, device):
    """Run the full reverse-diffusion loop of ``pipeline``, starting from Gaussian noise."""
    scheduler = pipeline.scheduler
    scheduler.set_timesteps(1000)
    sample = torch.randn(batch_size, 3, 32, 32, device=device)
    for t in scheduler.timesteps:
        noise_pred = pipeline.unet(sample, t)["sample"]
        sample = scheduler.step(noise_pred, t, sample).prev_sample
    return sample


def main():
    """Walk through DDPM basics on the Smithsonian butterflies subset.

    Steps:
      1. load the dataset and save a preview of 8 images (test.png);
      2. plot the DDPM noise schedule (noise_schedule.png);
      3. noise those 8 images at 8 timesteps evenly spaced in [0, 999]
         (noisy.png);
      4. run an untrained UNet2DModel once to check output shapes;
      5. sample 8 images with the pretrained
         "johnowhitaker/ddpm-butterflies-32px" pipeline (butterfly.png);
      6. compute the Jacobian of that pipeline's UNet output w.r.t. its
         input at timestep 0 and print its shape.

    Figures are written to a timestamped directory under ``hug/outputs``.

    Returns:
        Always 0.
    """
    # Use CUDA when available (it reportedly is not available from the notebook).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # ---- Data ----
    dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")
    dataset.set_transform(transform)  # images are preprocessed on access
    batch_size = 64
    train_dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True
    )

    # Load a single batch and reuse it for both the size print and the preview.
    first_batch = next(iter(train_dataloader))["images"]
    print(first_batch.size())  # expected torch.Size([64, 3, 32, 32]) for a full batch
    xb = first_batch[:8].to(device)
    print("X shape:", xb.shape)

    dirname = _make_output_dir()
    _save_image(
        show_images(xb).resize((8 * 64, 64), resample=Image.NEAREST),
        os.path.join(dirname, "test.png"),
    )

    # ---- Forward (noising) process ----
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    _save_noise_schedule(noise_scheduler, os.path.join(dirname, "noise_schedule.png"))

    # One timestep per preview image, evenly spaced from 0 to 999.
    timesteps = torch.linspace(0, 999, 8).long().to(device)
    noise = torch.randn_like(xb)
    noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
    print("Noisy X shape", noisy_xb.shape)
    _save_image(
        show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST),
        os.path.join(dirname, "noisy.png"),
    )

    # ---- Untrained model: output-shape check only ----
    model = _build_unet().to(device)
    with torch.no_grad():
        model_prediction = model(noisy_xb, timesteps).sample
    print("model_prediction.shape:", model_prediction.shape)

    print("Noise scheduler shape:", noise_scheduler.alphas_cumprod.shape)

    # ---- Pretrained pipeline: sampling ----
    butterfly_pipeline = DDPMPipeline.from_pretrained(
        "johnowhitaker/ddpm-butterflies-32px"
    ).to(device)
    n_samples = 8
    samples = _sample_from_pipeline(butterfly_pipeline, n_samples, device)
    _save_image(
        show_images(samples),
        os.path.join(dirname, "butterfly.png"),
        figsize=(6 * n_samples + 2, 6),
    )

    # ---- Jacobian of the UNet output w.r.t. its input ----
    print("-" * 80)
    print("calculate jacobian")
    jacobian_t = torch.tensor([0], device=device)

    def output_wrt_x(x):
        """UNet prediction at timestep 0, viewed as a function of the input only."""
        return butterfly_pipeline.unet(x, jacobian_t)["sample"]

    # jacobian() sets requires_grad on its own copy of the input, so a plain
    # tensor suffices. With the default vectorize=False it does one backward
    # pass per output element (3 * 32 * 32 = 3072), so this step can be slow.
    x0 = torch.randn(1, 3, 32, 32, device=device)
    jacobian_matrix_x = jacobian(output_wrt_x, x0)

    # The first 4 dims follow output_wrt_x's output shape; the last 4 follow x0's shape.
    print("jacobian_matrix_x.shape:", jacobian_matrix_x.shape)

    return 0

if __name__ == "__main__":
    # Run the demo when executed as a script; main()'s return value (always 0)
    # becomes the process exit status.
    raise SystemExit(main())