import os
import json
import copy
from tqdm import tqdm

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

# https://github.com/gnobitab/InstaFlow/tree/main/code
# from pipeline_rf import RectifiedFlowPipeline


def load_anns(orig_train_file):
    """Load the annotation JSON file and return its parsed content.

    Args:
        orig_train_file: Path to a JSON file. The caller in this script
            iterates the result as a sequence of dicts that carry
            ``"image_id"`` and ``"caption"`` entries.

    Returns:
        The deserialized top-level JSON value.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(orig_train_file, 'r', encoding='utf-8') as f:
        return json.load(f)



def group_captions_by_image(anns):
    """Group captions by image id.

    Args:
        anns: Iterable of dicts, each with an ``"image_id"`` and a
            ``"caption"`` entry.

    Returns:
        Dict mapping every image id to the list of its captions in input
        order. Keys appear in order of first occurrence.
    """
    image_id2caption = {}
    for ann in anns:
        image_id2caption.setdefault(ann["image_id"], []).append(ann["caption"])
    return image_id2caption


def build_annotations(image_id2caption, out_dir_name, image_root,
                      num_per_caption, cap_idx=0):
    """Build annotation entries for the generated images found on disk.

    For every generation index ``k`` in ``range(num_per_caption)`` and every
    image id, the generated image is expected at
    ``{image_root}/{out_dir_name}/img{image_id}_cap{cap_idx}_{k}.jpg``.
    Images that are missing are reported on stdout and skipped.

    Args:
        image_id2caption: Dict mapping image id to its list of captions.
        out_dir_name: Directory name (relative to ``image_root``) that holds
            the generated images; also the prefix of the ``"image"`` field.
        image_root: Root directory used only for the existence check.
        num_per_caption: Number of generated images per caption.
        cap_idx: Index of the caption whose generated images are used.
            Every image id must have at least ``cap_idx + 1`` captions.

    Returns:
        List of dicts with ``"image"``, ``"image_id"`` and ``"caption"``
        entries, ordered by generation index first, then by the iteration
        order of ``image_id2caption``.
    """
    new_anns = []
    for img_gen_idx in range(num_per_caption):
        for image_id, captions in image_id2caption.items():
            new_image_id = f"img{image_id}_cap{cap_idx}_{img_gen_idx}"
            new_image_file = out_dir_name + f"/{new_image_id}.jpg"

            img_path = os.path.join(image_root, new_image_file)
            if not os.path.exists(img_path):
                print("Does not exist: ", img_path)
                continue

            new_anns.append({
                "image": new_image_file,
                "image_id": new_image_id,
                "caption": captions[cap_idx],
            })
    return new_anns


def original_id_sort_key(new_image_id):
    """Return the sort key of a generated image id: its original image id.

    Generated ids have the form ``img{image_id}_cap{cap_idx}_{img_gen_idx}``.
    The original id may itself contain underscores (for instance when it
    carries the ``coco_`` prefix that this key strips), so the suffix is
    removed by splitting on the last ``_cap`` rather than on the first
    ``_``. Splitting on the first ``_`` would reduce every such id to the
    same key and turn the sort into a no-op.

    The key is a string, so the resulting order is lexicographic.
    """
    original_id = new_image_id.rsplit("_cap", 1)[0]
    if original_id.startswith("img"):
        original_id = original_id[len("img"):]
    return original_id.replace("coco_", "")


if __name__ == "__main__":

    # import argparse
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--orig_train_file", type=str, required=True)
    # parser.add_argument("--new_train_file", type=str, required=True)
    # parser.add_argument("--image_output_dir", type=str, required=True)
    # args = parser.parse_args()

    out_dir_name = "train-stableDiffusion"
    image_root = "/data/dataset/MSCOCO"
    image_output_dir = f"{image_root}/{out_dir_name}"
    # data_dir="/data/dataset/Flickr30k"
    orig_train_file = "/data/dataset/dataset_json/data/coco_train.json"
    new_train_dir = "/data/dataset/dataset_json/data_coco_stableDiffusion"

    os.makedirs(image_output_dir, exist_ok=True)
    os.makedirs(new_train_dir, exist_ok=True)

    ##############################################################
    # NOTE: the pipeline is only used by the (currently disabled)
    # image-generation pass further below.
    # https://huggingface.co/stabilityai/stable-diffusion-2-1-base
    model_id = "stabilityai/stable-diffusion-2-1-base"
    # model_id = "aaronb/dreamshaper-8-dmd-1kstep"

    # Use the DPMSolverMultistepScheduler (DPM-Solver++) scheduler here instead
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")

    # https://github.com/gnobitab/InstaFlow/tree/main/code
    # pipe = RectifiedFlowPipeline.from_pretrained(
    #     "XCLIU/instaflow_0_9B_from_sd_1_5",
    #     torch_dtype=torch.float16,
    #     safety_checker=None,
    #     requires_safety_checker=False,
    # )
    # pipe.to("cuda")  ### if GPU is not available, comment this line
    ##############################################################

    # prompt = "a photo of an astronaut riding a horse on mars"
    # image = pipe(prompt).images[0]

    BATCH_SIZE = 32        # only referenced by the disabled generation pass
    NUM_PER_CAPTION = 4    # generated images per caption
    NUM_CAPS = 1           # only referenced by the disabled generation pass
    CAP_IDX = 0            # caption index whose generated images are used

    anns = load_anns(orig_train_file)

    ####################################
    ###### Image generation pass #######
    ####################################
    # Disabled; kept for reference because it defines the file naming
    # (img{image_id}_cap{cap_idx}_{j}.jpg) that the annotation pass expects.
    #
    # image_id2n = {}
    # image_id_list = []
    # caption_idx_list = []
    # captions_list = []
    # for i, ann in tqdm(enumerate(anns), total=len(anns)):
    #     image_id = ann["image_id"]
    #     caption = ann["caption"]
    #
    #     if image_id not in image_id2n:
    #         image_id2n[image_id] = 0
    #     else:
    #         image_id2n[image_id] += 1
    #     cap_n = image_id2n[image_id]
    #     if cap_n >= NUM_CAPS:
    #         continue
    #
    #     image_id_list += [image_id]
    #     captions_list += [caption]
    #     caption_idx_list += [cap_n]
    #
    #     if len(captions_list) == BATCH_SIZE:
    #         for j in range(NUM_PER_CAPTION):
    #             images = pipe(captions_list).images
    #             # images = pipe(
    #             #     prompt=captions_list,
    #             #     num_inference_steps=1,
    #             #     guidance_scale=0.0
    #             # ).images
    #             for image, _id, cap_idx in zip(images, image_id_list, caption_idx_list):
    #                 p = os.path.join(image_output_dir, f"img{_id}_cap{cap_idx}_{j}.jpg")
    #                 if os.path.exists(p):
    #                     print("Already exists: ", p)
    #                     continue
    #                 image.save(p)
    #         captions_list = []
    #         image_id_list = []
    #         caption_idx_list = []

    ####################################
    ###### Create annotation file ######
    ####################################
    # The grouping does not depend on the generation index, so it is
    # computed once instead of once per generated image.
    image_id2caption = group_captions_by_image(tqdm(anns, total=len(anns)))

    # One entry per (image, generation index) for caption CAP_IDX, i.e.
    # NUM_PER_CAPTION generated images that all share the same caption.
    new_anns_all = build_annotations(
        image_id2caption,
        out_dir_name,
        image_root,
        NUM_PER_CAPTION,
        cap_idx=CAP_IDX,
    )

    # Sort by original image id. The sort is stable, so entries of the same
    # image stay ordered by generation index.
    new_anns_all.sort(key=lambda ann: original_id_sort_key(ann["image_id"]))

    name = "many-to-one_cap=0_useCap0_1to4.json"
    with open(os.path.join(new_train_dir, name), 'w', encoding='utf-8') as f:
        json.dump(new_anns_all, f, indent=4)

