from diffusers import StableDiffusionPipeline
import torch
import os
import json
from math import ceil, sqrt
from PIL import Image
from utils import save_image, concat_images_in_square_grid, get_random_prompt, get_clip_score
import argparse

#add parser function
def parse_args(argv=None):
    """Parse the command-line options of the image-generation script.

    Args:
        argv: Optional list of argument strings. When None (the default), the
            arguments are read from ``sys.argv`` as before. Passing a list makes
            the parser usable from other code and from tests.

    Returns:
        The ``argparse.Namespace`` holding every option below.
    """
    parser = argparse.ArgumentParser()

    # Model / generation settings.
    parser.add_argument('--model_pretrained', type=str, default="CompVis/stable-diffusion-v1-4", help='pretrained model')
    parser.add_argument('--prompts', nargs='+', type=str, help='edit prompt')
    parser.add_argument('--num_images', type=int, default=30, help='number of images')
    parser.add_argument('--output_dir', type=str, default="diffusers_ckpt/output", help='output directory')
    parser.add_argument('--clip_filtering_threshold', type=float, default=0.25, help='clip filtering threshold')
    parser.add_argument('--clip_filtering_tolerance', type=int, default=10, help='how many iterations before continue')
    parser.add_argument('--seed', type=int, default=0, help='random seed')

    # Boolean switches: each is False unless the flag is given.
    parser.add_argument('--create_metadata', action='store_true', help='if set, create json file with metadata')
    parser.add_argument('--create_grid', action='store_true', help='if set, create grid of images')
    parser.add_argument('--random_prompt', action='store_true', help='if set, select random prompt from prompts.json file')
    parser.add_argument('--clip_filtering', action='store_true', help='filter images based on clip similarity')

    # ESD options: an empty string means "use the pretrained UNet unchanged".
    parser.add_argument('--esd_checkpoint', type=str, default="", help='esd checkpoint')

    return parser.parse_args(argv)

def _load_pipeline(model_name, esd_checkpoint, device):
    """Load the Stable Diffusion pipeline, optionally swap in ESD UNet weights, and move it to *device*."""
    # Half-precision kernels are largely unavailable on CPU, so fp16 is only
    # used when running on CUDA.
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=dtype, safety_checker=None)

    if esd_checkpoint:
        # map_location="cpu" lets a checkpoint saved from a GPU load on a
        # CPU-only machine; the weights follow the pipeline to *device* below.
        pipe.unet.load_state_dict(torch.load(esd_checkpoint, map_location="cpu"))

    pipe.to(device)
    return pipe


def _first_flag(nsfw):
    """Return the NSFW flag for one image; save_image may report it as a one-element list."""
    return nsfw[0] if isinstance(nsfw, list) else nsfw


def _delete_candidates(candidates):
    """Delete every rejected candidate image in *candidates* (path -> (clip score, nsfw flag))."""
    for path, (clip_score, _) in candidates.items():
        print(f"Deleting image with clip score {clip_score}")
        os.remove(path)


def _generate_with_clip_filtering(pipe, prompt, out_path, gen, threshold, tolerance):
    """Generate an image for *prompt* at *out_path*, retrying while its CLIP score is too low.

    At most ``tolerance + 1`` images are generated, and every one of them is
    scored. The first image whose score exceeds *threshold* is kept. If none
    passes, the best-scoring candidate is moved to *out_path*. All other
    candidates are deleted.

    Returns:
        The NSFW flag that ``save_image`` reported for the image that was kept.
    """
    stem, ext = os.path.splitext(out_path)
    rejected = {}  # score-tagged path -> (clip score, nsfw flag)

    for _ in range(max(tolerance, 0) + 1):
        nsfw = save_image(pipe, prompt, out_path, gen=gen)
        clip_score = get_clip_score(prompt, out_path)
        print(f"Prompt: {prompt}, clip score: {clip_score}")

        if clip_score > threshold:
            print("Accepting image")
            _delete_candidates(rejected)
            return nsfw

        print("Rejecting image")
        # Keep the rejected image under a score-tagged name so it can still be
        # chosen as the fallback if no later image passes the threshold.
        rejected_path = f"{stem}_{clip_score}{ext}"
        os.replace(out_path, rejected_path)
        rejected[rejected_path] = (clip_score, nsfw)

    print("Tolerance reached")
    best_path = max(rejected, key=lambda path: rejected[path][0])
    best_score, best_nsfw = rejected.pop(best_path)
    print(f"Accepting image with clip score {best_score}")
    os.replace(best_path, out_path)
    _delete_candidates(rejected)
    return best_nsfw


def _generate_random_prompt_images(pipe, args, gen, metadata):
    """Generate ``args.num_images`` images from random prompts and append a metadata row for each."""
    for i in range(args.num_images):
        while True:
            # NOTE(review): the original passed an undefined name `p` here. This
            # assumes get_random_prompt builds a prompt from one of the --prompts
            # concepts (cycled so each concept gets an equal share) — confirm
            # against utils.get_random_prompt.
            concept = args.prompts[i % len(args.prompts)]
            prompt = get_random_prompt(concept)
            file_name = f"train/{prompt}_{i}.png"
            out_path = os.path.join(args.output_dir, file_name)

            if args.clip_filtering:
                nsfw = _generate_with_clip_filtering(
                    pipe, prompt, out_path, gen,
                    args.clip_filtering_threshold, args.clip_filtering_tolerance,
                )
            else:
                nsfw = save_image(pipe, prompt, out_path, gen=gen)

            if not _first_flag(nsfw):
                break
            # Remove the flagged image before retrying with a new prompt, so it
            # is not left in train/ without a metadata row.
            os.remove(out_path)

        metadata.append({'file_name': file_name, 'text': prompt})


def _generate_preset_prompt_images(pipe, args, gen, metadata):
    """Generate ``args.num_images // len(args.prompts)`` images per prompt and append metadata rows."""
    for i in range(args.num_images // len(args.prompts)):
        # One batched call yields one image per prompt, in prompt order.
        images = pipe(prompt=args.prompts, generator=gen, guidance_scale=7.5).images
        for prompt, image in zip(args.prompts, images):
            file_name = f"train/{prompt}_{i}.png"
            image.save(os.path.join(args.output_dir, file_name))
            metadata.append({'file_name': file_name, 'text': prompt})


if __name__ == "__main__":
    args = parse_args()

    # Validate before the (slow) model load. --prompts is used by both
    # generation modes and by --create_grid.
    if not args.prompts:
        raise SystemExit("--prompts is required")
    if not args.random_prompt and args.num_images % len(args.prompts) != 0:
        raise SystemExit("Number of images must be divisible by number of prompts")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    gen = torch.Generator(device=device)
    gen.manual_seed(args.seed)

    pipe_pretrained = _load_pipeline(args.model_pretrained, args.esd_checkpoint, device)

    train_dir = os.path.join(args.output_dir, "train")
    # makedirs also creates output_dir itself, as a missing parent.
    os.makedirs(train_dir, exist_ok=True)

    print("Generating images of pretrained model")
    print("Edit prompt: ", args.prompts)

    # Rows are always collected (cheap) but only written with --create_metadata.
    metadata = []
    if args.random_prompt:
        _generate_random_prompt_images(pipe_pretrained, args, gen, metadata)
    else:
        _generate_preset_prompt_images(pipe_pretrained, args, gen, metadata)

    if args.create_grid:
        for p in args.prompts:
            concat_images_in_square_grid(train_dir, p, os.path.join(args.output_dir, f"grid {p}.png"))

    if args.create_metadata:
        # JSON Lines: one {"file_name", "text"} object per generated image.
        with open(os.path.join(args.output_dir, 'metadata.jsonl'), 'w', encoding='utf-8') as f:
            for row in metadata:
                f.write(json.dumps(row) + "\n")