import torch
import os
import argparse
def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    """Copy matching entries of ``state_dict`` into ``model`` and report mismatches.

    Every module in ``model`` (visited parent-first, children in registration
    order) gets its own ``_load_from_state_dict`` call. Key mismatches are
    collected and printed rather than raised.

    Args:
        model: root ``torch.nn.Module`` receiving the weights.
        state_dict: mapping from parameter/buffer names to tensors. It is
            shallow-copied, so the caller's mapping is left untouched.
        prefix: key prefix under which ``model``'s own entries are stored.
        ignore_missing: ``'|'``-separated substrings. A missing key containing
            any of them is listed as "ignored" instead of being warned about.

    Returns:
        None.
    """
    saved_metadata = getattr(state_dict, '_metadata', None)
    # _load_from_state_dict may modify the mapping it receives, so work on a copy
    # (re-attaching the version metadata that a plain copy would drop).
    state_dict = state_dict.copy()
    if saved_metadata is not None:
        state_dict._metadata = saved_metadata

    missing_keys, unexpected_keys, error_msgs = [], [], []

    # Explicit-stack pre-order walk: pushing children in reverse keeps the same
    # visiting order as a recursive parent-then-children traversal.
    pending = [(model, prefix)]
    while pending:
        module, module_prefix = pending.pop()
        if saved_metadata is None:
            local_metadata = {}
        else:
            local_metadata = saved_metadata.get(module_prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, module_prefix, local_metadata, True,
            missing_keys, unexpected_keys, error_msgs)
        children = [
            (child, module_prefix + child_name + '.')
            for child_name, child in module._modules.items()
            if child is not None
        ]
        pending.extend(reversed(children))

    ignore_patterns = ignore_missing.split('|')
    warned, ignored = [], []
    for key in missing_keys:
        bucket = ignored if any(pattern in key for pattern in ignore_patterns) else warned
        bucket.append(key)

    class_name = model.__class__.__name__
    if warned:
        print("Weights of {} not initialized from pretrained model: {}".format(
            class_name, warned))
    if unexpected_keys:
        print("Weights from pretrained model not used in {}: {}".format(
            class_name, unexpected_keys))
    if ignored:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            class_name, ignored))
    if error_msgs:
        print('\n'.join(error_msgs))
        
def parse_args(argv=None):
    """Parse the command-line options of the aggregation script.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``output_dir`` (str), ``round`` (int) and
        ``local_rank`` (int). If the ``LOCAL_RANK`` environment variable is set
        to something other than -1 and differs from ``--local_rank``, the
        environment value wins.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--round",
        type=int,
        default=0,
        # The previous help text was a copy of --output_dir's and described the wrong option.
        help="Round index; checkpoints are read from and written to <output_dir>/round<ROUND>.",
    )

    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args(argv)
    # Distributed launchers export LOCAL_RANK; prefer it over the CLI flag when present.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args

import torch
from diffusers import StableDiffusionPipeline
from peft import PeftConfig, PeftModel
from peft import LoraConfig, get_peft_model

args = parse_args()

# Directory for this round. Each client's checkpoints are read from
# <output_dir>/<client concept name>/.
output_dir = os.path.join(args.output_dir, 'round' + str(args.round))

# Alternative client set (celebrity concepts, raw data under ./data/celebs):
# client_concepts = ['Elon Musk', 'Donald Trump', 'Barack Obama', 'Tom Hiddleston', 'Rihanna',
#                    'Arnold Schwarzenegger', 'Tom Cruise', 'Leonardo Dicaprio', 'Andrew Garfield',
#                    'Joe Biden']

# One client per artist concept. Only the concept names are needed to locate the
# client checkpoints, so the raw ./data/artists directory is no longer listed
# (that unused listing crashed the script when the directory or its paths.txt
# was missing).
client_concepts = ['Vincent van Gogh', 'Leonardo da Vinci', 'Claude Monet', 'Wassily Kandinsky',
                   'J.M.W. Turner', 'Albrecht Anker', 'Francisco Goya', 'Henri Matisse',
                   'Hilma af Klint', 'Paul Gauguin']
client_num = len(client_concepts)

# Uniform averaging weights: every client contributes 1/client_num.
weights = [1 / client_num] * client_num

# DISABLED: the triple-quoted string below is a bare string expression, so none of
# the code inside it runs. It is an earlier variant that averaged the clients'
# PEFT text-encoder weights with the same `weights` and saved the result to
# <output_dir>/agg_text_encoder.ckpt. If you turn it back into code, note that it
# loads the base pipeline from a hard-coded local path and rebuilds the whole
# pipeline once per client.
'''
pipe = StableDiffusionPipeline.from_pretrained("/home/yangmingzhao/2024_5/sd_v1_5", torch_dtype=torch.float16)
text_encoder = pipe.text_encoder
pipe.text_encoder = PeftModel.from_pretrained(text_encoder, os.path.join(output_dir,client_concepts[0]))
agg_text_encoder_state_dict = pipe.text_encoder.state_dict()
del pipe

for st in agg_text_encoder_state_dict:
    agg_text_encoder_state_dict[st] = agg_text_encoder_state_dict[st]*weights[0]
                    
for i in range(1,len(client_concepts)):
    print(i)
    pipe = StableDiffusionPipeline.from_pretrained("/home/yangmingzhao/2024_5/sd_v1_5", torch_dtype=torch.float16)
    text_encoder = pipe.text_encoder
    pipe.text_encoder = PeftModel.from_pretrained(text_encoder, os.path.join(output_dir,client_concepts[i]))
    client_text_encoder_state_dict = pipe.text_encoder.state_dict()
    del pipe
    
    for st in agg_text_encoder_state_dict:
        agg_text_encoder_state_dict[st] += client_text_encoder_state_dict[st]*weights[i]   
        

torch.save(agg_text_encoder_state_dict,os.path.join(output_dir,"agg_text_encoder.ckpt"))
'''

# Weighted average of the clients' UNet checkpoints for this round.
# Each client's checkpoint is read from <output_dir>/<concept>/<UNET_CKPT_NAME>;
# the weighted sum is written to <output_dir>/agg_unet.ckpt.
UNET_CKPT_NAME = "checkpoint-125.ckpt"

if not client_concepts:
    raise ValueError("client_concepts is empty; nothing to aggregate")
if len(weights) != len(client_concepts):
    raise ValueError(
        "got {} weights for {} clients".format(len(weights), len(client_concepts)))

agg_unet_state_dict = None
saved_dtypes = {}  # dtype each entry was saved with; restored before writing
for i, (concept, weight) in enumerate(zip(client_concepts, weights)):
    print(i)
    ckpt_path = os.path.join(output_dir, concept, UNET_CKPT_NAME)
    # torch.load unpickles the file: only point this at checkpoints you produced.
    client_unet_state_dict = torch.load(ckpt_path, map_location='cpu')

    if agg_unet_state_dict is None:
        # Reuse the first client's mapping as the accumulator (keeps its key order and type).
        # Assumes every value is a tensor.
        agg_unet_state_dict = client_unet_state_dict
        for st in list(agg_unet_state_dict):
            tensor = agg_unet_state_dict[st]
            saved_dtypes[st] = tensor.dtype
            # Accumulate in at least float32 so fp16/bf16 checkpoints do not lose
            # precision over many additions (float64 stays float64).
            acc_dtype = torch.promote_types(tensor.dtype, torch.float32)
            agg_unet_state_dict[st] = tensor.to(acc_dtype) * weight
        continue

    # Every client must provide exactly the same keys; otherwise the average is meaningless.
    if client_unet_state_dict.keys() != agg_unet_state_dict.keys():
        mismatched = sorted(set(client_unet_state_dict) ^ set(agg_unet_state_dict))
        raise ValueError(
            "keys of {} differ from those of client '{}' (first mismatches: {})".format(
                ckpt_path, client_concepts[0], mismatched[:10]))
    for st in agg_unet_state_dict:
        acc = agg_unet_state_dict[st]
        agg_unet_state_dict[st] += client_unet_state_dict[st].to(acc.dtype) * weight
    # Release this client's tensors before loading the next checkpoint.
    del client_unet_state_dict

# Cast floating entries back to the dtype they were saved with. Non-floating
# entries stay float32, which is what multiplying them by a Python float yields.
for st, dtype in saved_dtypes.items():
    if dtype.is_floating_point:
        agg_unet_state_dict[st] = agg_unet_state_dict[st].to(dtype)

torch.save(agg_unet_state_dict, os.path.join(output_dir, "agg_unet.ckpt"))