import os 
import json
# https://github.com/commonsense/conceptnet5/wiki/API
# import requests
import torch
import random 
random.seed(0)
import math
import argparse

# Example invocations (--end_class_id is exclusive and is clamped to the
# dataset's number of classes):
# python run_diffusion.py --dataset cifar100 --start_class_id 0 --end_class_id 100
# python run_diffusion.py --dataset eurosat --start_class_id 0 --end_class_id 10

# ImageNet is split into chunks of 100 classes, one run per chunk:
# python run_diffusion.py --dataset imagenet --start_class_id 0 --end_class_id 100
# python run_diffusion.py --dataset imagenet --start_class_id 100 --end_class_id 200
# python run_diffusion.py --dataset imagenet --start_class_id 200 --end_class_id 300
# python run_diffusion.py --dataset imagenet --start_class_id 300 --end_class_id 400
# ... (continue in steps of 100) ...
# python run_diffusion.py --dataset imagenet --start_class_id 900 --end_class_id 1000

# Per-dataset configuration tables, all keyed by the ``--dataset`` CLI value.

# JSON file containing the cleaned per-class prompt records for each dataset.
dataset2filename = dict(
    cifar100="cifar100_wikiRAG_all_clean.json",
    imagenet="imagenet_wikiRAG_all_clean.json",
    eurosat="eurosat_wikiRAG_all_clean.json",
    dtd="dtd_wikiRAG_all_clean.json",
)

# How many images to generate for every individual prompt of a class.
dataset2num_per_sentence = dict(
    cifar100=4,
    imagenet=4,
    eurosat=25,
    dtd=8,
)

# Total number of classes per dataset; used to clamp --end_class_id.
dataset2num_class = dict(
    cifar100=100,
    imagenet=1000,
    eurosat=10,
    dtd=47,
)


def parse_args(argv=None):
    """Parse the command-line options of the data-generation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets tests and other Python code call this function.

    Returns:
        argparse.Namespace with ``dataset``, ``start_class_id``,
        ``end_class_id`` and ``save_folder`` attributes.
    """
    parser = argparse.ArgumentParser('data generation')
    parser.add_argument('--dataset', type=str, default="imagenet",
                        help="dataset key, e.g. cifar100, imagenet, eurosat, dtd")
    parser.add_argument('--start_class_id', type=int, default=0,
                        help="first class id to generate (inclusive)")
    parser.add_argument('--end_class_id', type=int, default=100,
                        help="class id to stop at (exclusive)")
    parser.add_argument('--save_folder', type=str, default="generated_img",
                        help="root directory for the generated images")
    return parser.parse_args(argv)

# Loaded (pipeline, compel) pairs keyed by (model_id, device). Loading the
# Stable Diffusion weights and moving them to the device is by far the most
# expensive step, so it is done once per process instead of once per image.
# NOTE: cached pipelines stay resident on the device for the process lifetime.
_PIPELINE_CACHE = {}


def _get_pipeline(model_id, device):
    """Return a cached ``(pipe, compel)`` pair for *model_id* on *device*."""
    key = (model_id, device)
    if key not in _PIPELINE_CACHE:
        # Imported lazily so the module can be imported without diffusers/compel.
        from diffusers import StableDiffusionPipeline
        from compel import Compel

        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        pipe = pipe.to(device)
        # truncate_long_prompts=False keeps prompts longer than the CLIP token
        # limit instead of silently cutting them off.
        compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder,
                        truncate_long_prompts=False)
        _PIPELINE_CACHE[key] = (pipe, compel)
    return _PIPELINE_CACHE[key]


def text_to_image(prompt, guidance_scale, model_id="runwayml/stable-diffusion-v1-5",
                  device="cuda", negative_prompt="", generator=None):
    """Generate a single image for *prompt* with Stable Diffusion.

    The prompt and negative prompt are encoded with Compel and padded to the
    same length, then passed to the pipeline as embeddings.

    Args:
        prompt: Text prompt to render.
        guidance_scale: Classifier-free guidance scale forwarded to the pipeline.
        model_id: Model to load; defaults to the previously hard-coded SD v1.5.
        device: Torch device the pipeline runs on (default ``"cuda"``).
        negative_prompt: Negative prompt; empty by default, as before.
        generator: Optional ``torch.Generator`` forwarded to the pipeline for
            reproducible sampling; ``None`` keeps the previous unseeded behavior.

    Returns:
        The first image of the pipeline output (a PIL image with the pipeline's
        default output type; callers ``.save()`` it).
    """
    pipe, compel = _get_pipeline(model_id, device)

    # Inference only: no autograd graph is needed for the text embeddings.
    with torch.no_grad():
        conditioning = compel.build_conditioning_tensor(prompt)
        negative_conditioning = compel.build_conditioning_tensor(negative_prompt)
        # Long (non-truncated) prompts yield embeddings of different lengths;
        # the pipeline requires both embeddings to have the same shape.
        conditioning, negative_conditioning = compel.pad_conditioning_tensors_to_same_length(
            [conditioning, negative_conditioning]
        )
        images = pipe(
            prompt_embeds=conditioning,
            negative_prompt_embeds=negative_conditioning,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images
    return images[0]


# Characters stripped from GPT-written facts before they are used as prompts.
_FACT_STRIP_TABLE = str.maketrans('', '', '\n"\\')


def _clean_fact(item):
    """Return the prompt text for one prompt record.

    Prefers ``item["gpt_fact"]`` with newlines, double quotes and backslashes
    removed. Falls back to ``item["fact"]`` unchanged when ``gpt_fact`` is
    missing or is not a string.
    """
    try:
        return item["gpt_fact"].translate(_FACT_STRIP_TABLE)
    except (KeyError, AttributeError, TypeError):
        return item["fact"]


def _prompt_to_fname(prompt):
    """Turn *prompt* into a filename fragment.

    Spaces and '/' become '_'. Fragments longer than 100 characters are cut to
    99 characters. This cutoff is kept unchanged so that names match images
    written by earlier runs and the "already exists" resume check keeps working.
    """
    fname = prompt.replace(' ', '_').replace('/', '_')
    if len(fname) > 100:
        fname = fname[:99]
    return fname


def main():
    """Generate ``num_per_sentence`` images per prompt for the selected classes."""
    args = parse_args()
    if args.dataset not in dataset2filename:
        raise SystemExit(f"Unknown dataset {args.dataset!r}; expected one of {sorted(dataset2filename)}")
    num_per_sentence = dataset2num_per_sentence[args.dataset]

    # Never run past the dataset's last class.
    args.end_class_id = min(args.end_class_id, dataset2num_class[args.dataset])

    with open(dataset2filename[args.dataset], encoding="utf-8") as f:
        cls_to_prompts = json.load(f)

    run_classes = list(range(args.start_class_id, args.end_class_id))
    print(run_classes)

    for class_id in run_classes:
        folder_path = os.path.join(args.save_folder, f'{args.dataset}_randomscale', str(class_id))
        os.makedirs(folder_path, exist_ok=True)
        per_class_prompt_items = cls_to_prompts[str(class_id)]

        prompts = [_clean_fact(item) for item in per_class_prompt_items]
        sample_ids = [int(item["id"]) for item in per_class_prompt_items]
        print(class_id, prompts, len(prompts))

        for sample_id, prompt in zip(sample_ids, prompts):
            prompt_for_fname = _prompt_to_fname(prompt)

            for repeat_num in range(num_per_sentence):
                fname = "{}_{}_repeat{}.png".format(sample_id, prompt_for_fname, repeat_num)
                save_path = os.path.join(folder_path, fname)
                # Makes the script resumable: images from earlier runs are kept.
                if os.path.exists(save_path):
                    print(f"Image {save_path} already exists, skipping...")
                    continue
                # Integer guidance scale drawn uniformly from 1..8 per image.
                # Drawn only for images that are actually generated, so the
                # sequence after a resume differs from a fresh run.
                guidance_scale = random.choice(range(1, 9))
                image = text_to_image(prompt, guidance_scale)  # size 512 x 512
                image.save(save_path)


if __name__ == '__main__':
    main()

