import argparse, os, sys, glob, datetime, yaml
import torch
import time
import numpy as np
from tqdm import tqdm

from PIL import Image
import datasets
from datasets import load_dataset

from torchvision import transforms


def custom_to_np(x):
    """Convert a batch of [0, 1] image tensors to uint8 channel-last layout.

    Follows the ADM sample-saving convention
    (https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py),
    except that the input is scaled by 255 directly, i.e. it is expected to
    already lie in [0, 1] rather than [-1, 1].

    Args:
        x: 4-D tensor indexed as (batch, channel, height, width).

    Returns:
        A contiguous, detached CPU ``torch.uint8`` tensor indexed as
        (batch, height, width, channel). Note that despite the function name
        the result is a torch tensor, not a NumPy array. Out-of-range values
        are clamped to [0, 255]; fractional values are truncated by the cast.
    """
    pixels = x.detach().cpu().mul(255).clamp(0, 255).to(torch.uint8)
    # Reorder NCHW -> NHWC and materialise the new memory layout.
    return pixels.permute(0, 2, 3, 1).contiguous()

def _parse_args():
    """Parse and validate the command-line arguments of the export script."""
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that HF Datasets can understand."
        ),
    )
    # NOTE: --resolution and --center_crop are accepted for CLI compatibility
    # but currently have no effect, because no resize/crop transform is applied.
    parser.add_argument(
        "--resolution",
        type=int,
        default=64,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "-n",
        "--n_samples",
        type=int,
        nargs="?",
        help="number of samples to draw",
        default=50000
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help=(
            "The output directory where the model predictions and checkpoints will be written."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    args = parser.parse_args()

    if args.dataset_name is None and args.train_data_dir is None:
        parser.error("one of --dataset_name or --train_data_dir is required")
    # nargs="?" lets a bare `-n` through as None; reject it here instead of
    # failing later with a TypeError in the sampling loop.
    if args.n_samples is None or args.n_samples < 1:
        parser.error("--n_samples must be a positive integer")
    return args


def main():
    """Export up to ``--n_samples`` randomly chosen training images to an .npz file.

    The images are stored as one uint8 array (NumPy's default ``arr_0`` key)
    laid out as (sample, height, width, channel), written to
    ``<output_dir>/train-<dataset>-<n_samples>-samples.npz``.
    """
    args = _parse_args()

    if args.dataset_name is not None:
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            split="train",
        )
    else:
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
        dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")

    os.makedirs(args.output_dir, exist_ok=True)

    # ToTensor yields floats in [0, 1], which is the range custom_to_np
    # expects (it multiplies by 255). The images must NOT additionally be
    # normalised to [-1, 1]: every value below 0 would be clamped to black.
    # No resize/crop is applied, so all images must already share one size
    # for the default batch collation to work.
    augmentations = transforms.Compose([transforms.ToTensor()])

    def transform_images(examples):
        # Datasets name their image column either "image" or "img".
        column = "image" if "image" in examples else "img"
        return {"input": [augmentations(image.convert("RGB")) for image in examples[column]]}

    print(f"Dataset size: {len(dataset)}")

    dataset.set_transform(transform_images)
    # drop_last=False keeps the final partial batch, so small datasets
    # (fewer than one full batch) still produce output.
    train_dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=500, shuffle=True, drop_last=False
    )

    all_images = []
    savedcount = 0
    with tqdm(total=args.n_samples, desc="Generating images") as pbar:
        for batch in train_dataloader:
            images = batch["input"]
            all_images.append(custom_to_np(images))
            batch_count = images.shape[0]
            # tqdm.update expects an increment; cap it so the bar ends at total.
            pbar.update(min(batch_count, args.n_samples - savedcount))
            savedcount += batch_count
            if savedcount >= args.n_samples:
                break

    if not all_images:
        raise SystemExit("The dataset is empty; nothing to save.")
    if savedcount >= args.n_samples:
        print(f"Saved {savedcount} images")

    all_img = np.concatenate(all_images, axis=0)[:args.n_samples]
    if all_img.shape[0] < args.n_samples:
        print(
            f"Warning: only {all_img.shape[0]} images available, "
            f"fewer than the requested {args.n_samples}"
        )

    # Use the last path component as a short dataset label for the file name;
    # fall back to the data folder name when no dataset name was given.
    if args.dataset_name:
        dataset_short_name = args.dataset_name.rstrip("/").split("/")[-1]
    else:
        dataset_short_name = os.path.basename(os.path.normpath(args.train_data_dir))

    nppath = os.path.join(args.output_dir, f"train-{dataset_short_name}-{args.n_samples}-samples.npz")
    print(f"Saving {all_img.shape} to {nppath}")
    np.savez(nppath, all_img)


if __name__ == "__main__":
    main()


        