import torch
import os
import argparse
def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    """Load ``state_dict`` into ``model`` without raising, printing any mismatches.

    Walks ``model`` and every non-None submodule, calling each one's
    ``_load_from_state_dict`` hook (with ``strict=True`` so missing and
    unexpected keys are collected). Afterwards it prints, in this order:
      * missing keys that do not match ``ignore_missing``,
      * keys in ``state_dict`` that no module consumed,
      * missing keys that do match ``ignore_missing`` (reported as ignored),
      * any error messages collected by the hooks.

    Args:
        model: module to load into; its parameters/buffers are updated in place.
        state_dict: mapping of names to tensors. It is shallow-copied first, so
            the caller's mapping object itself is not mutated.
        prefix: key prefix under which ``model``'s entries live in ``state_dict``.
        ignore_missing: ``'|'``-separated substrings; a missing key containing
            any of them is reported as "ignored" instead of "not initialized".
    """
    collected_missing, collected_unexpected, collected_errors = [], [], []

    # dict.copy() does not carry over the `_metadata` attribute, so stash it and
    # re-attach it to the copy the hooks will read from.
    meta = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if meta is not None:
        state_dict._metadata = meta

    def _visit(module, key_prefix=''):
        # Per-module metadata is stored under the prefix without its trailing '.'.
        module_meta = meta.get(key_prefix[:-1], {}) if meta is not None else {}
        module._load_from_state_dict(
            state_dict, key_prefix, module_meta, True,
            collected_missing, collected_unexpected, collected_errors)
        for child_name, child in module._modules.items():
            if child is None:
                continue
            _visit(child, key_prefix + child_name + '.')

    _visit(model, key_prefix=prefix)

    patterns = ignore_missing.split('|')

    def _is_ignored(key):
        return any(pattern in key for pattern in patterns)

    ignored = [key for key in collected_missing if _is_ignored(key)]
    still_missing = [key for key in collected_missing if not _is_ignored(key)]

    cls_name = model.__class__.__name__
    if still_missing:
        print("Weights of {} not initialized from pretrained model: {}".format(
            cls_name, still_missing))
    if collected_unexpected:
        print("Weights from pretrained model not used in {}: {}".format(
            cls_name, collected_unexpected))
    if ignored:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            cls_name, ignored))
    if collected_errors:
        print('\n'.join(collected_errors))
        
def parse_args(argv=None):
    """Parse the command-line options of this image-generation script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is the behaviour of the original no-argument
            call; passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with ``ckpt_dir``, ``output_dir``, ``seed`` (``None``
        unless given) and ``local_rank``. A ``LOCAL_RANK`` environment variable
        other than ``-1`` overrides ``--local_rank``.
    """
    parser = argparse.ArgumentParser(
        description="Generate sample images from Stable Diffusion checkpoints.")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="sd-model-finetuned-lora",
        help="Directory containing the trained checkpoints to load (e.g. agg_unet.ckpt).",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned-lora",
        help="Directory under which the generated images will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible generation.")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args(argv)

    # A launcher-provided LOCAL_RANK takes precedence over the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args

import torch
from diffusers import StableDiffusionPipeline
from peft import PeftConfig, PeftModel
from peft import LoraConfig, get_peft_model
import random
# Machine-specific path of the base Stable Diffusion v1.5 weights.
BASE_MODEL_PATH = "/home/yangmingzhao/2024_5/sd_v1_5"
# Alternative base model previously used:
# BASE_MODEL_PATH = "/home/yangmingzhao/2024_5/bk-sdm-tiny"

# Prompt used for both the forgotten and the retained ("save") concept.
PROMPT_TEMPLATE = "an image of {}"
# Artist variant: PROMPT_TEMPLATE = "An artwork in {} style."


def _load_base_pipeline():
    """Load the fp16 base Stable Diffusion pipeline from BASE_MODEL_PATH (not yet on GPU)."""
    return StableDiffusionPipeline.from_pretrained(BASE_MODEL_PATH, torch_dtype=torch.float16)


def _load_unet_lora_pipeline(unet_ckpt_path):
    """Load the base pipeline, add a UNet LoRA adapter and fill it from ``unet_ckpt_path``.

    NOTE(review): the LoRA config below presumably has to match the one used when
    the checkpoint was trained; load_state_dict prints any key mismatches.
    """
    pipe = _load_base_pipeline()
    unet_lora_config = LoraConfig(
        r=4,
        lora_alpha=2,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q"],
    )
    pipe.unet.add_adapter(unet_lora_config)
    state_dict = torch.load(unet_ckpt_path, map_location='cpu')
    load_state_dict(pipe.unet, state_dict)
    return pipe


def _generate_images(pipe, prompt, name, prefix, out_dir, num_images, seed):
    """Generate ``num_images`` images for ``prompt``; save them as ``{prefix}{name}_{i}.png``.

    A fresh CUDA generator seeded with ``seed`` is created per call, so each
    prompt's image sequence does not depend on what was generated before it.
    """
    generator = torch.Generator(device="cuda").manual_seed(seed)
    for i in range(num_images):
        image = pipe(prompt, generator=generator).images[0]
        image.save(os.path.join(out_dir, prefix + name + '_' + str(i) + ".png"))


args = parse_args()
ckpt_dir = args.ckpt_dir

image_num = 20
# --seed defaults to None, which would break `seed + idx` and manual_seed().
seed = args.seed if args.seed is not None else 0
output_dir = os.path.join(args.output_dir, 'agg')

# os.listdir order is deliberately kept (not sorted).
# NOTE(review): the first `client_num` entries are taken as the clients' forget
# concepts, so this order presumably has to match the training run — confirm.
concepts = [c for c in os.listdir("./data/celebs") if c != 'paths.txt']
client_num = 50
client_concepts = concepts[:client_num]
#client_concepts = ['Elon Musk','Donald Trump','Barack Obama','Tom Hiddleston','Rihanna','Arnold Schwarzenegger','Tom Cruise','Leonardo Dicaprio','Andrew Garfield','Joe Biden']

# Artist variant:
#concepts = [c for c in os.listdir("./data/artists") if c != 'paths.txt']
#client_num = 10
#client_concepts = ['Vincent van Gogh','Leonardo da Vinci','Claude Monet','Wassily Kandinsky','J.M.W. Turner','Albrecht Anker','Francisco Goya','Henri Matisse','Hilma af Klint','Paul Gauguin']

unet_ckpt_path = os.path.join(ckpt_dir, "agg_unet.ckpt")
# Fail before spending time on base-model generation if the LoRA weights are absent.
if not os.path.isfile(unet_ckpt_path):
    raise FileNotFoundError("UNet LoRA checkpoint not found: " + unet_ckpt_path)

# For each forgotten concept pick a deterministic retained ("save") concept.
# rng.choice always stays in range (randint's upper bound is inclusive), and the
# forgotten concept itself is excluded from the candidates.
# NOTE(review): other clients' concepts may still be chosen; consider excluding
# all of client_concepts when evaluating the aggregated model.
jobs = []
for idx, concept in enumerate(client_concepts):
    rng = random.Random(seed + idx)
    save_concept = rng.choice([c for c in concepts if c != concept])
    out_dir = os.path.join(output_dir, concept)
    os.makedirs(out_dir, exist_ok=True)
    jobs.append((concept, save_concept, out_dir))

# Each pipeline is loaded once (not once per concept) and only one is resident
# on the GPU at a time.
# Text-encoder LoRA variants ("text_only_", "unet_" outputs) were removed from
# this script as disabled code; recover them from version control if needed.
variants = [
    ("ori_", _load_base_pipeline),
    ("only_unet_", lambda: _load_unet_lora_pipeline(unet_ckpt_path)),
]
for prefix, load_pipe in variants:
    pipe = load_pipe().to("cuda")
    for concept, save_concept, out_dir in jobs:
        _generate_images(pipe, PROMPT_TEMPLATE.format(concept), concept,
                         prefix, out_dir, image_num, seed)
        _generate_images(pipe, PROMPT_TEMPLATE.format(save_concept), save_concept,
                         prefix, out_dir, image_num, seed)
    del pipe
    torch.cuda.empty_cache()