import os
import argparse
import torch
from pathlib import Path
import subprocess
import diffusers
from diffusers import StableDiffusionPipeline
import torch
import sys


# Root of the diffusers DreamBooth example tree. Every dataset, class-image,
# and experiment directory below is derived from this one value. Previously
# some dirs were created under "/path/..." while "path/..." was actually used.
DEFAULT_BASE_DIR = "path/diffusers/examples/dreambooth"
DEFAULT_PRETRAINED_MODEL = "CompVis/stable-diffusion-v1-4"
# Rare identifier token bound to the subject during fine-tuning.
UNIQUE_TOKEN = "sks"
# Number of samples written per prompt.
IMAGES_PER_PROMPT = 4

# Prompt templates used to sample from the fine-tuned model:
# "{0}" is replaced by the identifier token and "{1}" by the class noun.
# (The originals carried stray literal single quotes that leaked into the
# prompt text.)
PROMPT_TEMPLATES = (
    "a {0} {1} in the jungle",
    "a {0} {1} in the snow",
    "a {0} {1} on the beach",
    "a {0} {1} on a cobblestone street",
    "a {0} {1} on top of pink fabric",
    "a {0} {1} on top of a wooden floor",
    "a {0} {1} with a city in the background",
    "a {0} {1} with a mountain in the background",
    "a {0} {1} with a blue house in the background",
    "a {0} {1} on top of a purple rug in a forest",
    "a {0} {1} with a wheat field in the background",
    "a {0} {1} with a tree and autumn leaves in the background",
    "a {0} {1} with the Eiffel Tower in the background",
    "a {0} {1} floating on top of water",
    "a {0} {1} floating in an ocean of milk",
    "a {0} {1} on top of green grass with sunflowers around it",
    "a {0} {1} on top of a mirror",
    "a {0} {1} on top of the sidewalk in a crowded street",
    "a {0} {1} on top of a dirt road",
    "a {0} {1} on top of a white rug",
    "a red {0} {1}",
    "a purple {0} {1}",
    "a shiny {0} {1}",
    "a wet {0} {1}",
    "a cube shaped {0} {1}",
)


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as true, so an explicit parser is needed.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command line (``argv=None`` means ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--subject_name', type=str, default='dog',
                        help='folder name of the subject under <base_dir>/dataset')
    parser.add_argument('--class_name', type=str, default='dog',
                        help='class noun paired with the identifier token')
    # nargs='?' + const=True keeps "--use_natural_loss True" working and
    # also accepts the bare flag "--use_natural_loss".
    parser.add_argument('--use_natural_loss', type=str2bool, nargs='?',
                        const=True, default=False)
    parser.add_argument('--base_dir', type=str, default=DEFAULT_BASE_DIR,
                        help='root of the DreamBooth example directory tree')
    parser.add_argument('--pretrained_model_name_or_path', type=str,
                        default=DEFAULT_PRETRAINED_MODEL)
    return parser.parse_args(argv)


def build_prompts(class_name, unique_token=UNIQUE_TOKEN):
    """Return one inference prompt per template for ``<unique_token> <class_name>``.

    The class noun (not the dataset folder name) is used so the prompts match
    the instance prompt the model was trained on.
    """
    return [
        "Photo of " + template.format(unique_token, class_name)
        for template in PROMPT_TEMPLATES
    ]


def build_train_command(script, instance_dir, class_dir, output_dir,
                        instance_prompt, class_prompt,
                        pretrained_model=DEFAULT_PRETRAINED_MODEL):
    """Return the argv list that runs *script* with the fixed hyper-parameters.

    The current interpreter (``sys.executable``) is used instead of whatever
    ``python`` happens to be first on PATH.
    """
    return [
        sys.executable, script,
        '--pretrained_model_name_or_path', pretrained_model,
        '--instance_data_dir', str(instance_dir),
        '--class_data_dir', str(class_dir),
        '--output_dir', str(output_dir),
        '--with_prior_preservation', '--prior_loss_weight', '1.0',
        '--instance_prompt', instance_prompt,
        '--class_prompt', class_prompt,
        '--resolution', '512',
        '--train_batch_size', '1',
        '--gradient_accumulation_steps', '1', '--gradient_checkpointing',
        '--learning_rate', '5e-6',
        '--lr_scheduler', 'constant',
        '--lr_warmup_steps', '0',
        '--num_class_images', '10',
        '--max_train_steps', '800',
        '--mixed_precision', 'fp16',
    ]


def generate_images(model_dir, prompts, exp_dir, images_per_prompt=IMAGES_PER_PROMPT):
    """Sample from the fine-tuned pipeline and save PNGs.

    The pipeline is loaded once (previously it was reloaded for every prompt).
    Images are written to ``<exp_dir>/prompt_<i>/<j>.png``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        str(model_dir), torch_dtype=torch.float16
    ).to("cuda")
    for i, prompt in enumerate(prompts):
        save_dir = Path(exp_dir) / f"prompt_{i}"
        save_dir.mkdir(parents=True, exist_ok=True)
        for j in range(images_per_prompt):
            image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
            image.save(str(save_dir / f"{j}.png"))


def main(argv=None):
    """Train one DreamBooth variant for the chosen subject, then sample from it."""
    opt = parse_args(argv)
    print('subject name', opt.subject_name)

    base_dir = Path(opt.base_dir)
    instance_dir = base_dir / "dataset" / opt.subject_name
    class_dir = base_dir / "generated_samples_10" / opt.subject_name

    # The two variants differ only in the training script and experiment folder.
    if opt.use_natural_loss:
        script, exp_name = 'train_dreambooth_natural_constraint.py', "Exp_Dreambooth_Nat"
    else:
        script, exp_name = 'train_dreambooth.py', "Exp_Dreambooth"
    exp_dir = base_dir / exp_name / opt.subject_name
    output_dir = exp_dir / "output"

    class_dir.mkdir(parents=True, exist_ok=True)
    output_dir.mkdir(parents=True, exist_ok=True)

    instance_prompt = f"a photo of {UNIQUE_TOKEN} {opt.class_name}"
    class_prompt = f"a photo of {opt.class_name}"

    # check=True aborts here if training fails, rather than trying to load
    # a model that was never written.
    subprocess.run(
        build_train_command(script, instance_dir, class_dir, output_dir,
                            instance_prompt, class_prompt,
                            opt.pretrained_model_name_or_path),
        check=True,
    )

    generate_images(output_dir, build_prompts(opt.class_name), exp_dir)


if __name__ == '__main__':
    main()
        



