import os
from datasets import load_dataset
from diffusers import StableDiffusionPipeline
import torch
import argparse
import shutil
import numpy as np
from tqdm import tqdm


def parse_args():
    """Parse and validate the command-line options for one generation chunk.

    Returns:
        argparse.Namespace with ``num_chunks``, ``chunk_idx``,
        ``path_to_indices`` and ``path_to_save_generated_images``.

    Exits via ``parser.error`` (``SystemExit`` with status 2) when
    ``--chunk_idx`` is missing or out of range, ``--num_chunks`` is not
    positive, or ``--path_to_indices`` does not name an existing file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_chunks",
        type=int,
        default=20,
        help="Total number of chunks the index array is split into.",
    )
    parser.add_argument(
        "--chunk_idx",
        type=int,
        required=True,
        help="Zero-based index of the chunk to process (0 <= chunk_idx < num_chunks).",
    )
    parser.add_argument(
        "--path_to_indices",
        type=str,
        default="/cmlscratch/XXX/t5_analysis/clean_fid_coco/coco_random_indices.npy",
        help="Path to a .npy file of dataset indices to generate images for.",
    )
    parser.add_argument(
        "--path_to_save_generated_images",
        type=str,
        default="/cmlscratch/XXX/t5_analysis/clean_fid_coco/sd/0",
        help="Directory where generated PNG images are written.",
    )

    args = parser.parse_args()

    # Use parser.error instead of assert: asserts are stripped under `python -O`,
    # and parser.error gives the user a proper usage message.
    if not os.path.isfile(args.path_to_indices):
        parser.error(f"--path_to_indices {args.path_to_indices!r} is not an existing file")
    if args.num_chunks < 1:
        parser.error(f"--num_chunks must be >= 1, got {args.num_chunks}")
    if not 0 <= args.chunk_idx < args.num_chunks:
        parser.error(
            f"--chunk_idx must satisfy 0 <= chunk_idx < {args.num_chunks}, got {args.chunk_idx}"
        )

    return args


def get_list_chunk(arr: np.ndarray, num_chunks: int, chunk_idx: int) -> "tuple[int, np.ndarray]":
    """Return the ``chunk_idx``-th of ``num_chunks`` contiguous slices of ``arr``.

    Each chunk holds ``ceil(len(arr) / num_chunks)`` items along axis 0.
    When the length does not divide evenly, the last chunk(s) are shorter
    or empty.

    Args:
        arr: Array to split along its first axis.
        num_chunks: Total number of chunks; must be >= 1.
        chunk_idx: Zero-based chunk to return; must satisfy
            ``0 <= chunk_idx < num_chunks``.

    Returns:
        ``(start_index, chunk)``: ``start_index`` is the position of the
        chunk's first element within ``arr`` (clamped to ``len(arr)`` for an
        empty trailing chunk), and ``chunk`` is ``arr[start_index:end_index]``.

    Raises:
        ValueError: if ``num_chunks`` < 1 or ``chunk_idx`` is out of range.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")
    if not 0 <= chunk_idx < num_chunks:
        raise ValueError(f"chunk_idx must be in [0, {num_chunks}), got {chunk_idx}")

    arr_len = arr.shape[0]

    # Ceiling division, so that num_chunks chunks always cover the whole array.
    chunk_size = (arr_len + num_chunks - 1) // num_chunks

    # Clamp both ends: with ceil-sized chunks, trailing chunks can start past
    # the end of the array and must come back empty instead of out of bounds.
    start_index = min(chunk_size * chunk_idx, arr_len)
    end_index = min(start_index + chunk_size, arr_len)

    print(f"Choosing chunk ({start_index}:{end_index})")
    if end_index > start_index:
        print(f"First item of the chunk: {arr[start_index]}")
        print(f"Last item of the chunk: {arr[end_index-1]}", flush=True)
    else:
        print("Chunk is empty", flush=True)

    return start_index, arr[start_index:end_index]


def main():
    """Generate one Stable Diffusion image per COCO caption in the selected chunk.

    Loads the index array from ``--path_to_indices`` and keeps the
    ``--chunk_idx``-th of ``--num_chunks`` slices. It selects those rows of
    the ``HuggingFaceM4/COCO`` train split and writes each generated image to
    ``<path_to_save_generated_images>/<NNNNNN>.png``. ``NNNNNN`` is the
    zero-padded position in the index array, so separate chunk runs produce
    non-overlapping file names.
    """
    args = parse_args()

    # Runs for different chunks may share this output directory and start at
    # the same time; exist_ok avoids a FileExistsError if another run creates
    # it between the check and the makedirs call.
    if not os.path.isdir(args.path_to_save_generated_images):
        print(f'Creating path \"{args.path_to_save_generated_images}\"')
    os.makedirs(args.path_to_save_generated_images, exist_ok=True)

    random_indices = np.load(args.path_to_indices)
    start_global_cnt, chunk_indices = get_list_chunk(random_indices, args.num_chunks, args.chunk_idx)
    dataset = load_dataset("HuggingFaceM4/COCO", split='train').select(chunk_indices)

    def transform(examples):
        # Build prompts from the raw captions: strip leading/trailing '.' and lowercase.
        return {'prompts': [x['raw'].strip('.').lower() for x in examples['sentences']]}

    dataset.set_transform(transform)
    # shuffle=False keeps batch order aligned with the global_cnt file numbering below.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=False)

    # Alternative checkpoint: "stabilityai/stable-diffusion-2-1".
    model_id = "CompVis/stable-diffusion-v1-4"

    # NOTE(review): no scheduler override is applied, so the checkpoint's
    # default scheduler is used. Swap it here if a specific one (e.g. Euler)
    # is wanted.
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    global_cnt = start_global_cnt
    for batch in tqdm(dataloader):
        images = pipe(batch['prompts'], num_inference_steps=50).images
        for img in images:
            img.save(os.path.join(args.path_to_save_generated_images, f"{global_cnt:06d}.png"))
            global_cnt += 1


if __name__ == "__main__":
    # Run the generation only when executed as a script, not when imported.
    main()
