import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os
import argparse
from diffusers import StableDiffusionPipeline           
import numpy as np  
from torchvision import transforms      
from utils import normalize, set_dataset_specs
from tqdm import tqdm   
from PIL import ImageFile
from synthesize import chop_the_patches
import copy

ImageFile.LOAD_TRUNCATED_IMAGES = True

@torch.no_grad()
def run_inference(rank, world_size, prompt_addresses, syn_save_folders,
                  place_holders, ipc, offset, gen_seed, ignore_embs=False, prompt_template="A photo of a "):
    """Generate ``ipc`` images per class on one GPU and save them as JPEG files.

    Designed as the ``mp.spawn`` worker target: ``rank`` selects both the CUDA
    device and the contiguous shard of classes (``np.array_split``) this process
    handles.

    Args:
        rank: Index of this worker; also used as the CUDA device index.
        world_size: Total number of workers; the class list is split into this
            many shards.
        prompt_addresses: Per-class paths passed to
            ``pipeline.load_textual_inversion`` (unused when ``ignore_embs``).
        syn_save_folders: Per-class output directories. They are NOT created here.
        place_holders: Per-class placeholder strings, aligned index-by-index with
            ``prompt_addresses`` and ``syn_save_folders``.
        ipc: Number of images to generate per class.
        offset: Added to the image counter when naming files ``img_XXXX.jpg``.
        gen_seed: Base seed. The two generation passes are seeded with
            ``gen_seed`` and ``gen_seed + 1``.
        ignore_embs: If true, skip loading embeddings and build the prompt from
            ``prompt_template`` plus the placeholder with ``<``/``>`` removed and
            spaces replaced by underscores.
        prompt_template: Prompt prefix, used only when ``ignore_embs`` is true.
    """
    # init_process_group defaults to env:// rendezvous, which requires these two
    # variables. Only fill them in when the launcher has not already set them.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        torch.cuda.set_device(rank)

        # Contiguous, near-equal shard of class indices for this rank. The shard
        # is empty when world_size exceeds the number of classes.
        shard = np.array_split(np.arange(len(prompt_addresses)), world_size)[rank]
        if len(shard) == 0:
            return
        prompt_addresses = [prompt_addresses[i] for i in shard]
        place_holders = [place_holders[i] for i in shard]
        syn_save_folders = [syn_save_folders[i] for i in shard]

        pipeline = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
        pipeline.safety_checker = None
        pipeline.requires_safety_checker = False
        pipeline.to('cuda')

        # Split ipc over two seeded passes, putting the extra image (if ipc is odd)
        # in the first pass so that exactly ipc images are produced. For even ipc
        # this is (ipc//2, ipc//2), the same batches as two ipc//2 passes.
        batch_sizes = (ipc - ipc // 2, ipc // 2)

        for p_ind, prompt_emb_add in enumerate(prompt_addresses):
            if ignore_embs:
                strip_place_holder = (
                    place_holders[p_ind].replace("<", "").replace(">", "").replace(" ", "_")
                )
                prompt = f"{prompt_template}{strip_place_holder}"
                print(f"Ignored embs Prompt: {prompt}")
            else:
                pipeline.load_textual_inversion(prompt_emb_add)
                prompt = f"A photo of {place_holders[p_ind]}"

            img_lst = []
            for infer_num, n_imgs in enumerate(batch_sizes):
                if n_imgs == 0:
                    continue  # e.g. ipc == 1: nothing to generate in the second pass
                # A fresh generator per pass keeps each pass reproducible from its seed.
                rnd_gen = torch.Generator(device="cuda")
                rnd_gen.manual_seed(gen_seed + infer_num)
                img_lst.extend(pipeline(prompt=prompt, num_inference_steps=50,
                                        num_images_per_prompt=n_imgs,
                                        generator=rnd_gen).images)

            for img_cnt, img in enumerate(img_lst):
                save_path = os.path.join(syn_save_folders[p_ind],
                                         f"img_{str(img_cnt + offset).zfill(4)}.jpg")
                img.save(save_path)

        del pipeline
        torch.cuda.empty_cache()
    finally:
        # Each spawned worker owns its own process group, so tear it down here.
        dist.destroy_process_group()


def synth_dataset_parallel(args, gen_seed):
    """Synthesize images for every class under ``args.emb_root`` on ``args.ngpu`` GPUs.

    Every non-hidden entry of ``args.emb_root`` (sorted by name) is one class.
    Its embedding path is ``<emb_root>/<cls>/learned_embeds.safetensors``. Line i
    of ``args.place_holder_path`` (whitespace-stripped) is the placeholder for the
    i-th class. Output folders ``<args.collage_save_dir>/<cls>`` are created if
    missing. Generation is then delegated to ``run_inference`` via ``mp.spawn``.

    Args:
        args: Parsed CLI namespace. Reads emb_root, place_holder_path,
            collage_save_dir, ngpu, ipc, offset, ignore_embs and prompt_template.
        gen_seed: Base random seed forwarded to every worker.

    Raises:
        ValueError: If the placeholder file has fewer lines than there are class
            entries. This is checked before any folder is created or worker spawned.
    """
    prompt_cls = sorted(n for n in os.listdir(args.emb_root) if not n.startswith("."))

    # One placeholder per line; strip() removes the trailing newline.
    with open(args.place_holder_path, "r", encoding="utf-8") as f:
        place_holders = [line.strip() for line in f]

    # Fail fast in the parent. Otherwise a short file would surface as an
    # IndexError deep inside one of the spawned workers.
    if len(place_holders) < len(prompt_cls):
        raise ValueError(
            f"{args.place_holder_path} has {len(place_holders)} placeholder line(s) "
            f"but {args.emb_root} has {len(prompt_cls)} class entries"
        )

    prompt_addresses = [
        os.path.join(args.emb_root, cls, "learned_embeds.safetensors") for cls in prompt_cls
    ]

    syn_save_folders = [os.path.join(args.collage_save_dir, cls) for cls in prompt_cls]
    for folder in syn_save_folders:
        os.makedirs(folder, exist_ok=True)

    torch.cuda.empty_cache()

    # mp.spawn passes the worker's rank as the first positional argument, followed
    # by `args`. The lists are pickled for each child process.
    mp.spawn(run_inference,
             args=(args.ngpu, prompt_addresses, syn_save_folders, place_holders,
                   args.ipc, args.offset, gen_seed, args.ignore_embs, args.prompt_template),
             nprocs=args.ngpu, join=True)

    # Workers run in separate processes. This only matters if the calling
    # process itself had initialized a process group.
    if dist.is_initialized():
        dist.destroy_process_group()

    torch.cuda.empty_cache()

    # chop_the_patches(args)
    
    # chop_the_patches(args)
    
    
if __name__ == "__main__":
    # NOTE: torch.multiprocessing.spawn (used by synth_dataset_parallel) always starts
    # workers with the "spawn" method. This global setting only affects other
    # multiprocessing users, e.g. possibly chop_the_patches (unverified).
    mp.set_start_method('fork')
    parser = argparse.ArgumentParser()

    parser.add_argument("--chopped_save_dir", type=str)
    parser.add_argument("--emb_root", type=str, required=True,
                        help="directory with one sub-folder per class holding learned_embeds.safetensors")
    parser.add_argument("--collage_save_dir", type=str, required=True,
                        help="root directory for generated images (one sub-folder per class)")
    parser.add_argument("--subset", type=str, default="imagenette")
    parser.add_argument("--place_holder_path", type=str, required=True,
                        help="text file with one placeholder token per line, in sorted class order")

    parser.add_argument("--ipc", type=int, default=10, help="images generated per class")
    parser.add_argument("--ngpu", type=int, default=4, help="number of GPU worker processes")
    parser.add_argument("--diff_input_size", type=int, default=512)
    parser.add_argument("--input_size", type=int, default=512)

    parser.add_argument("--factor", type=int, default=4)
    parser.add_argument("--st_cls", type=int, default=0)
    parser.add_argument("--end_cls", type=int, default=None)
    parser.add_argument("--nclass", type=int, default=1000)
    # type=list would split a single string into characters ("abc" -> ['a','b','c']).
    # Accept one or more whitespace-separated values instead.
    parser.add_argument("--classes", type=str, nargs="+")
    parser.add_argument("--offset", type=int, default=0,
                        help="starting index for generated image file names")
    parser.add_argument("--ignore_embs", action="store_true", default=False,
                        help="do not load textual-inversion embeddings; use --prompt_template instead")
    parser.add_argument("--prompt_template", type=str, default="A photo of a ",
                        help="prompt prefix used with --ignore_embs")

    args = parser.parse_args()

    set_dataset_specs(args)

    # synth_dataset_parallel(args, gen_seed=0)
    chop_the_patches(args)
