import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
import torchvision
from datasets import load_dataset
from torchvision import transforms

import matplotlib.pyplot as plt

import base64
from io import BytesIO

import datetime

import os

from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline

# Side length in pixels of the square images produced by `preprocess`.
image_size = 32

def show_images(x, nrow=-1, mode="butterfly"):
    """Tile a batch of image tensors into one grid and return it as a PIL image.

    Args:
        x: Batch of images as a tensor accepted by
            ``torchvision.utils.make_grid`` (presumably ``(B, C, H, W)`` --
            confirm against callers).
        nrow: Number of images per grid row, forwarded to ``make_grid``.
            Any value ``<= 0`` leaves ``make_grid``'s own default in effect.
        mode: ``"butterfly"`` means the values lie in (-1, 1) and are mapped
            back to (0, 1) before tiling; any other value leaves ``x``
            unscaled.

    Returns:
        PIL.Image.Image: The grid, clipped to (0, 1) and scaled to 8-bit.
    """
    if mode == "butterfly":
        x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)

    # Only override make_grid's default row length when the caller asked to.
    grid_kwargs = {"nrow": nrow} if nrow > 0 else {}
    grid = torchvision.utils.make_grid(x, **grid_kwargs)

    # Move the first axis last (the layout PIL expects) and scale to 0..255.
    pixels = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255

    # Let Pillow infer the image mode from the array shape instead of forcing
    # mode="RGB" and retrying on failure: the explicit ``mode`` argument is
    # deprecated in recent Pillow, and the old broad ``except Exception``
    # merely printed the error and fell back to this exact call.
    return Image.fromarray(np.array(pixels).astype(np.uint8))

def make_grid(images, size=64):
    """Lay out PIL images side by side in a single horizontal strip.

    Args:
        images: Sequence of PIL images; each one is resized to
            ``size`` x ``size`` before being pasted.
        size: Edge length in pixels of every tile in the strip.

    Returns:
        A new RGB PIL image of width ``size * len(images)`` and height
        ``size``.
    """
    strip = Image.new("RGB", (size * len(images), size))
    x_offset = 0
    for tile in images:
        strip.paste(tile.resize((size, size)), (x_offset, 0))
        x_offset += size  # next tile starts right after this one
    return strip

# Preprocessing / augmentation pipeline applied to each PIL image by `transform`.
preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),  # Resize to image_size x image_size
        transforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)
        transforms.ToTensor(),  # Convert to tensor (0, 1)
        transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
    ]
)

# Define the transformation function
def transform(examples):
    """Convert a batch of dataset examples into preprocessed image tensors.

    Args:
        examples: Mapping with an ``"image"`` entry holding an iterable of
            objects that expose ``.convert("RGB")`` (presumably PIL images --
            confirm against the dataset schema).

    Returns:
        dict: ``{"images": [...]}`` with one ``preprocess`` result per input
        image, in the same order.
    """
    return {
        "images": [preprocess(img.convert("RGB")) for img in examples["image"]]
    }

def main():
    """Save a preview grid of the butterfly dataset to a timestamped folder.

    Loads the ``huggan/smithsonian_butterflies_subset`` training split,
    applies ``transform`` on access, prints the shape of one batch, then
    draws eight images from another batch and writes them to
    ``hug/outputs/tests<timestamp>/test.png``.
    """
    # Use CUDA when available (reportedly it is not from the notebook).
    compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(compute_device)

    # Load the dataset and apply the transform lazily on access.
    butterflies = load_dataset("huggan/smithsonian_butterflies_subset", split="train")
    butterflies.set_transform(transform)

    # DataLoader serving the transformed images in shuffled batches.
    batch_size = 64
    train_dataloader = torch.utils.data.DataLoader(
        butterflies,
        batch_size=batch_size,
        shuffle=True,
    )

    # Peek at one batch; expected to print torch.Size([64, 3, 32, 32]).
    first_batch = next(iter(train_dataloader))
    print(first_batch["images"].size())

    # Draw another batch and keep its first eight images for the preview.
    sample = next(iter(train_dataloader))["images"]
    xb = sample.to(compute_device)[:8]
    print("X shape:", xb.shape)

    preview = show_images(xb)
    preview = preview.resize((8 * 64, 64), resample=Image.NEAREST)
    plt.imshow(preview)
    plt.axis("off")

    # Name the output folder after the current date and time.
    # NOTE(review): the timestamp is appended directly to "tests" with no
    # path separator -- confirm this is intended.
    now = datetime.datetime.now()
    stamp = now.strftime("%Y%m%d_%H%M%S")
    out_dir = f"hug/outputs/tests{stamp}"
    os.makedirs(out_dir, exist_ok=True)  # do not fail if it already exists
    plt.savefig(f"{out_dir}/test.png")

    # Saving the dataset (currently disabled):
    # dirname_data = "hug/src/datasets/" + now.strftime("%Y%m%d_%H%M%S")
    # os.makedirs(dirname_data, exist_ok=True)  # do not fail if it already exists
    # pathname_data = dirname_data + "/dataset.pth"
    # torch.save(dataset, pathname_data)

# Run the preview only when this file is executed as a script, not on import.
if __name__  == "__main__":
    main()