import os
import json
import copy
from tqdm import tqdm

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

# https://github.com/gnobitab/InstaFlow/tree/main/code
# from pipeline_rf import RectifiedFlowPipeline


def load_anns(orig_train_file):
    """Load and return the parsed contents of an annotation JSON file.

    The file is decoded as UTF-8 explicitly (JSON text is UTF-8 by spec), so
    non-ASCII captions load the same way regardless of the platform locale.
    """
    with open(orig_train_file, 'r', encoding='utf-8') as f:
        ann = json.load(f)
    return ann



def output_stem(image_id, cap_idx, gen_idx):
    """Return the file stem shared by a generated image and its annotation id.

    >>> output_stem("123", 0, 1)
    'img123_cap0_1'
    """
    return f"img{image_id}_cap{cap_idx}_{gen_idx}"


def generate_batch(pipe, batch, image_output_dir, gen_indices):
    """Generate and save one image per caption in ``batch`` for each generation index.

    Args:
        pipe: Text-to-image pipeline; ``pipe(prompts).images`` is expected to
            yield one image (with a ``save(path)`` method) per prompt, in order.
        batch: List of ``(image_id, cap_idx, caption)`` tuples. May be empty.
        image_output_dir: Directory the ``.jpg`` files are written to.
        gen_indices: Generation indices to produce; each one becomes the
            trailing suffix of the file name (see ``output_stem``).

    Existing files are never overwritten. When every file of a pass already
    exists, the pipeline is not called at all, so an interrupted run can be
    resumed without re-spending GPU time.
    """
    if not batch:
        return
    captions = [caption for _, _, caption in batch]
    for gen_idx in gen_indices:
        paths = [
            os.path.join(image_output_dir, output_stem(image_id, cap_idx, gen_idx) + ".jpg")
            for image_id, cap_idx, _ in batch
        ]
        missing = [not os.path.exists(p) for p in paths]
        for p, is_missing in zip(paths, missing):
            if not is_missing:
                print("Already exists: ", p)
        if not any(missing):
            # Nothing would be written for this pass; skip the expensive call.
            continue
        images = pipe(captions).images
        for image, p, is_missing in zip(images, paths, missing):
            if is_missing:
                image.save(p)


def dump_split_anns(anns_dict, new_train_dir):
    """Write ``anns_dict[cap_key][j]`` to ``flickr30k_train_{cap_key}-{j}.json`` in ``new_train_dir``."""
    for cap_key, per_gen in anns_dict.items():
        for j, split_anns in per_gen.items():
            name = f"flickr30k_train_{cap_key}-{j}.json"
            with open(os.path.join(new_train_dir, name), "w", encoding="utf-8") as f:
                json.dump(split_anns, f)


def build_many_to_one_anns(anns, out_dir_name, data_root, gen_indices, cap_idx=0):
    """Build the many-to-one annotation list for the generated images.

    For every original image, one entry is emitted per generation index in
    ``gen_indices``. Each entry pairs the generated file with that image's
    ``cap_idx``-th caption (in annotation order). Entries whose image file is
    missing under ``data_root`` are reported and skipped. Images with fewer
    than ``cap_idx + 1`` captions are skipped. The output is ordered by the
    integer value of the original image id, then by generation index.

    Args:
        anns: Original annotations; each needs ``"image_id"`` and ``"caption"``.
        out_dir_name: Directory of generated images, relative to ``data_root``.
        data_root: Dataset root used to check that each image file exists.
        gen_indices: Generation indices to include, in output order.
        cap_idx: Which caption of each image the images were generated from.

    Returns:
        List of ``{"image", "image_id", "caption"}`` dicts.

    Raises:
        ValueError: If an original image id is not an integer-like string
            (it is converted with ``int`` for sorting).
    """
    captions_by_image = {}
    for ann in anns:
        captions_by_image.setdefault(ann["image_id"], []).append(ann["caption"])

    new_anns = []
    for image_id in sorted(captions_by_image, key=int):
        captions = captions_by_image[image_id]
        if cap_idx >= len(captions):
            continue
        for gen_idx in gen_indices:
            stem = output_stem(image_id, cap_idx, gen_idx)
            image_file = f"{out_dir_name}/{stem}.jpg"
            img_path = os.path.join(data_root, image_file)
            if not os.path.exists(img_path):
                print("Does not exist: ", img_path)
                continue
            new_anns.append({
                "image": image_file,
                "image_id": stem,
                "caption": captions[cap_idx],
            })
    return new_anns


def main():
    """Generate Stable Diffusion 3 images for Flickr30k captions and write annotation files."""
    # DiffusionPipeline selects the concrete pipeline class from the repo's
    # model_index.json. SD3 must be loaded from the "-diffusers" repo layout.
    # SD 2.x ids (e.g. "stabilityai/stable-diffusion-2-1-base") work unchanged.
    # An InstaFlow RectifiedFlowPipeline can be swapped in instead:
    # https://github.com/gnobitab/InstaFlow/tree/main/code
    from diffusers import DiffusionPipeline

    model_id = "stabilityai/stable-diffusion-3-medium-diffusers"

    data_root = "/data/dataset/Flickr30k"
    out_dir_name = "flickr30k-images-stableDiffusion"
    image_output_dir = os.path.join(data_root, out_dir_name)
    orig_train_file = "/data/dataset/dataset_json/data/flickr30k_train.json"
    new_train_dir = "/data/dataset/dataset_json/data_stableDiffusion"

    batch_size = 32
    num_per_caption = 4  # images generated per caption in this run
    num_caps = 1  # only the first `num_caps` captions of each image are used
    # Generation index 0 was produced by an earlier run, so this run writes 1..num_per_caption.
    gen_indices = list(range(1, num_per_caption + 1))

    os.makedirs(image_output_dir, exist_ok=True)
    os.makedirs(new_train_dir, exist_ok=True)

    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    anns = load_anns(orig_train_file)

    # anns_dict[f"cap{k}"][j] holds annotations for the k-th caption of each image,
    # pointing at generation index gen_indices[j]. Split files keep the "-{j}" naming.
    anns_dict = {f"cap{k}": {j: [] for j in range(num_per_caption)} for k in range(num_caps)}

    caption_count = {}  # image_id -> index of the most recent caption seen for it
    batch = []  # pending (image_id, cap_idx, caption) tuples
    for i, ann in enumerate(tqdm(anns)):
        image_id = ann["image_id"]
        cap_n = caption_count.get(image_id, -1) + 1
        caption_count[image_id] = cap_n
        if cap_n >= num_caps:
            continue

        batch.append((image_id, cap_n, ann["caption"]))
        if len(batch) == batch_size:
            generate_batch(pipe, batch, image_output_dir, gen_indices)
            batch = []

        for j, gen_idx in enumerate(gen_indices):
            stem = output_stem(image_id, cap_n, gen_idx)
            new_ann = copy.deepcopy(ann)
            new_ann["image_id"] = stem
            new_ann["image"] = f"{out_dir_name}/{stem}.jpg"
            anns_dict[f"cap{cap_n}"][j].append(new_ann)

        if i % 500 == 0:
            # Periodic checkpoint so partial progress survives an interruption.
            print(f"Processed {i} captions")
            dump_split_anns(anns_dict, new_train_dir)

    # Captions left over after the last full batch still need their images.
    generate_batch(pipe, batch, image_output_dir, gen_indices)
    dump_split_anns(anns_dict, new_train_dir)

    ####################################
    ###### Create annotation file ######
    ####################################
    new_anns_all = build_many_to_one_anns(anns, out_dir_name, data_root, gen_indices, cap_idx=0)
    name = f"many-to-one_cap=0_useCap0_{gen_indices[0]}to{gen_indices[-1]}.json"
    with open(os.path.join(new_train_dir, name), "w", encoding="utf-8") as f:
        json.dump(new_anns_all, f, indent=4)


if __name__ == "__main__":
    main()


