import os
import json
import copy
from tqdm import tqdm

import torch

# from huggingface_hub import login
# login()
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusion3Pipeline

# https://github.com/gnobitab/InstaFlow/tree/main/code
# from pipeline_rf import RectifiedFlowPipeline


def load_anns(orig_train_file):
    """Load and return the parsed contents of a JSON annotation file.

    Args:
        orig_train_file: Path to the JSON annotation file.

    Returns:
        The decoded JSON value. The rest of this script iterates it as a list
        of dicts carrying at least "image_id" and "caption" keys.
    """
    # JSON text is UTF-8 (RFC 8259); do not depend on the locale's default encoding.
    with open(orig_train_file, "r", encoding="utf-8") as f:
        return json.load(f)



def _run_batch(pipe, batch, image_output_dir, num_per_caption,
               num_inference_steps=28, guidance_scale=7.0):
    """Generate and save images for one batch of captions.

    Args:
        pipe: Text-to-image pipeline. ``pipe(prompts, num_inference_steps=...,
            guidance_scale=...)`` must return an object whose ``.images`` holds
            one image (with a ``.save(path)`` method) per prompt, in order.
        batch: List of ``(image_id, caption_idx, caption)`` tuples.
        image_output_dir: Root directory for generated images.
        num_per_caption: Number of pipeline runs over the batch; run ``j`` is
            saved as ``img<image_id>/cap<caption_idx>_<j>.jpg``.
        num_inference_steps: Forwarded to ``pipe``.
        guidance_scale: Forwarded to ``pipe``.
    """
    prompts = [caption for _, _, caption in batch]
    for gen_idx in range(num_per_caption):
        images = pipe(
            prompts,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images
        for image, (image_id, cap_idx, _) in zip(images, batch):
            img_path = os.path.join(image_output_dir, f"img{image_id}/cap{cap_idx}_{gen_idx}.jpg")
            os.makedirs(os.path.dirname(img_path), exist_ok=True)
            image.save(img_path)


def _generate_images(pipe, anns, image_output_dir, batch_size, num_per_caption, num_caps,
                     num_inference_steps=28, guidance_scale=7.0):
    """Generate images for the first ``num_caps`` captions of every image.

    Captions are numbered per image in the order they appear in ``anns`` (an
    image's first caption is index 0). Captions with index >= ``num_caps`` are
    skipped; the rest are sent to ``pipe`` in batches of ``batch_size``,
    including a final partial batch.

    Args:
        pipe: Text-to-image pipeline (see ``_run_batch``).
        anns: Iterable of dicts with "image_id" and "caption" keys.
        image_output_dir: Root directory for generated images.
        batch_size: Number of captions per pipeline call.
        num_per_caption: Images to generate per caption.
        num_caps: Number of leading captions per image to render.
        num_inference_steps: Forwarded to ``pipe``.
        guidance_scale: Forwarded to ``pipe``.
    """
    last_cap_idx = {}  # image_id -> index of the latest caption seen for it
    batch = []
    for ann in anns:
        image_id = ann["image_id"]
        cap_idx = last_cap_idx.get(image_id, -1) + 1
        last_cap_idx[image_id] = cap_idx
        if cap_idx >= num_caps:
            continue
        batch.append((image_id, cap_idx, ann["caption"]))
        if len(batch) == batch_size:
            _run_batch(pipe, batch, image_output_dir, num_per_caption,
                       num_inference_steps, guidance_scale)
            batch = []
    # Flush the final partial batch; otherwise the last
    # (selected captions % batch_size) captions would never be rendered.
    if batch:
        _run_batch(pipe, batch, image_output_dir, num_per_caption,
                   num_inference_steps, guidance_scale)


def _build_annotation_sets(anns, image_output_dir, model_id, num_caps, img_gen_idx):
    """Build the annotation lists for one generation index.

    For every ``cap_idx`` in ``range(num_caps)`` and every ``use_caps_k`` in
    ``1..num_caps``, the set pairs caption ``cap_idx`` of each image with the
    images generated from that image's first ``use_caps_k`` captions. Generated
    images that do not exist on disk are reported and skipped, as are images
    with too few captions.

    Args:
        anns: Iterable of dicts with "image_id" and "caption" keys.
        image_output_dir: Root directory the images were written to.
        model_id: Tag used as the output filename prefix.
        num_caps: Number of leading captions per image that were rendered.
        img_gen_idx: Which generation run (the ``_<j>`` suffix) to reference.

    Returns:
        Dict mapping output JSON filename -> list of annotation dicts with
        "image", "image_id" and "caption" keys.
    """
    captions_by_image = {}
    for ann in anns:
        captions_by_image.setdefault(ann["image_id"], []).append(ann["caption"])

    ann_sets = {}
    # For each caption index, build e.g. the use_caps_k=1 file, the use_caps_k=2 file, ...
    for cap_idx in range(num_caps):
        for use_caps_k in range(1, 1 + num_caps):
            new_anns = []
            for image_id, captions in captions_by_image.items():
                if cap_idx >= len(captions):
                    continue  # no caption to pair with this image
                # Images exist only for captions the image actually has.
                for gen_cap_idx in range(min(use_caps_k, len(captions))):
                    rel_id = f"img{image_id}/cap{gen_cap_idx}_{img_gen_idx}"
                    image_file = os.path.join(image_output_dir, f"{rel_id}.jpg")
                    # Check the exact path the generator wrote to.
                    if not os.path.exists(image_file):
                        print("Does not exist: ", image_file)
                        continue
                    new_anns.append({
                        "image": image_file,
                        "image_id": rel_id,
                        "caption": captions[cap_idx],
                    })
            name = f"{model_id}_cap={cap_idx}_useCap{use_caps_k}_imgGenIdx{img_gen_idx}.json"
            ann_sets[name] = new_anns
    return ann_sets


def main():
    """Generate Flickr30k caption images with SD3 and write annotation files."""
    # Hard-coded configuration; candidates for argparse options
    # (--orig_train_file, --new_train_file, --image_output_dir).
    dataset = "flickr30k"
    model_tag = "SD3"
    model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
    model_path = None  # e.g. "/data/ckpts/stable_diffusion_lora/checkpoint-15000"

    out_dir_name = f"flickr30k-images-{model_tag}"
    image_output_dir = f"/data/dataset/Flickr30k/{out_dir_name}"
    orig_train_file = "/data/dataset/dataset_json/data/flickr30k_train.json"
    new_train_dir = f"/data/dataset/dataset_json/data_{dataset}_{model_tag}"

    os.makedirs(image_output_dir, exist_ok=True)
    os.makedirs(new_train_dir, exist_ok=True)

    # For SD 2.1 use "stabilityai/stable-diffusion-2-1-base" with StableDiffusionPipeline.
    pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    if model_path is not None:
        # NOTE(review): StableDiffusion3Pipeline presumably exposes a `transformer`
        # rather than a `unet`; this branch likely only works with the SD 1.x/2.x
        # pipeline — confirm before setting model_path.
        pipe.unet.load_attn_procs(model_path)
    pipe = pipe.to("cuda")

    batch_size = 16
    num_per_caption = 1
    num_caps = 1

    anns = load_anns(orig_train_file)

    print("Generating images...")
    _generate_images(pipe, tqdm(anns, total=len(anns)), image_output_dir,
                     batch_size, num_per_caption, num_caps,
                     num_inference_steps=28, guidance_scale=7.0)
    print("Done")

    print("Creating annotation files...")
    for img_gen_idx in range(num_per_caption):
        ann_sets = _build_annotation_sets(tqdm(anns, total=len(anns)), image_output_dir,
                                          model_tag, num_caps, img_gen_idx)
        for name, new_anns in ann_sets.items():
            with open(os.path.join(new_train_dir, name), 'w') as f:
                json.dump(new_anns, f, indent=4)

    print("Done")


if __name__ == "__main__":
    main()