from diffusers import StableDiffusionPipeline
import torch
import os
import json
from math import ceil, sqrt
from PIL import Image
from utils import save_image, concat_images_in_square_grid, TaskVector
import argparse
import torch.nn.functional as F
import open_clip
import time
from tqdm import tqdm

#add parser function
def parse_args():
    """Parse the command-line options for the task-vector negation / pruning sweep.

    ``--model_finetuned`` and ``--inversed_models`` are required. The script
    loads a pipeline from each of them, and the old defaults ("" and None)
    could only fail, and only after the expensive model loading had happened.

    Returns:
        argparse.Namespace with model_pretrained, model_finetuned,
        inversed_models (list of str), output_dir, seed, initial_prompt and
        pruning_criteria ("block" or "value").
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_pretrained', type=str, default="stabilityai/stable-diffusion-2", help='pretrained model')
    parser.add_argument('--model_finetuned', type=str, required=True, help='finetuned model')
    parser.add_argument('--inversed_models', nargs='+', type=str, required=True, help='list of inversed model')
    parser.add_argument('--output_dir', type=str, default="diffusers_ckpt/output", help='output directory')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--initial_prompt', type=str, default="", help='initial prompt')
    parser.add_argument('--pruning_criteria', type=str, default="block", help='pruning criteria', choices=["block", "value"])
    return parser.parse_args()

if __name__ == "__main__":
    import atexit

    args = parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe_pretrained = StableDiffusionPipeline.from_pretrained(args.model_pretrained, torch_dtype=torch.float16, safety_checker=None)
    pipe_pretrained.to(device)
    # A single generator, seeded once here, is shared by every generation call
    # later in the script.
    gen = torch.Generator(device=device)
    gen.manual_seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)

    pipe_finetuned = StableDiffusionPipeline.from_pretrained(args.model_finetuned, torch_dtype=torch.float16, safety_checker=None)
    pipe_finetuned.to(device)

    #edit process
    unet_pretrained = pipe_pretrained.unet
    unet_finetuned = pipe_finetuned.unet

    # TaskVector (from utils) is built from checkpoint *paths*, so both UNets
    # are pickled to the working directory under a timestamped name. Later code
    # reloads the pretrained checkpoint via f"unet_pretrained_{int(curr_time)}.pt".
    curr_time = int(time.time())
    pretrained_ckpt = f"unet_pretrained_{curr_time}.pt"
    finetuned_ckpt = f"unet_finetuned_{curr_time}.pt"

    def _remove_temp_checkpoints(paths):
        """Delete the temporary UNet checkpoints, ignoring ones already removed."""
        for path in paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass  # already deleted by the normal end-of-run cleanup

    # Registered before saving so that the checkpoints are removed even if the
    # run dies part-way (exception or Ctrl-C).
    atexit.register(_remove_temp_checkpoints, [pretrained_ckpt, finetuned_ckpt])

    #save model unet
    torch.save(unet_pretrained, pretrained_ckpt)
    torch.save(unet_finetuned, finetuned_ckpt)

    task_vector_unet = TaskVector(pretrained_checkpoint=pretrained_ckpt,
                                  finetuned_checkpoint=finetuned_ckpt)

    # Only the finetuned UNet weights were needed (now on disk). Drop the whole
    # finetuned pipeline so its memory can be reused by the evaluation below.
    del pipe_finetuned, unet_finetuned

    # Score per task-vector entry: the negated mean absolute value of the
    # entry. Sorting these scores in descending order therefore ranks the
    # smallest-magnitude entries first.
    dict_ = {name: -torch.mean(torch.abs(delta.flatten())).item()
             for name, delta in task_vector_unet.vector.items()}
    # Select the n entries of a score dict with the largest values.
    def pick_n_highest(d, n):
        """Return a new dict with the ``n`` items of ``d`` that have the highest values.

        Items are ordered from highest to lowest value. ``sorted`` stays stable
        with ``reverse=True``, so equal values keep their original relative order.
        """
        ranked = sorted(d.items(), key=lambda kv: kv[1], reverse=True)
        top = ranked[:n]
        return {key: value for key, value in top}
    
    def set_zero(task_vector, dict_, n):
        """Return a copy of ``task_vector`` with the top-``n`` entries of ``dict_`` zeroed.

        Entries are chosen with ``pick_n_highest`` (highest score first). The
        input is not modified: its name -> tensor dict is shallow-copied, and
        the selected entries are replaced rather than changed in place.

        Args:
            task_vector: TaskVector whose ``.vector`` maps parameter names to tensors.
            dict_: parameter name -> ranking score.
            n: number of entries to zero.

        Returns:
            A new TaskVector.
        """
        task_vector_cp = TaskVector(vector=task_vector.vector.copy())
        for name in pick_n_highest(dict_, n):
            # zeros_like matches the replaced entry's dtype and device. The
            # entries are presumably float16, since the pipelines are loaded
            # that way. A bare torch.zeros would be float32 and allocated on
            # the CPU first.
            task_vector_cp.vector[name] = torch.zeros_like(task_vector_cp.vector[name])
        return task_vector_cp
    
    def set_zero_block(task_vector, block_l):
        """Return a copy of ``task_vector`` with whole UNet blocks zeroed.

        An entry is zeroed when its parameter name contains any of the
        substrings in ``block_l`` (e.g. "down_blocks.0"). An empty ``block_l``
        yields an unpruned copy. The input is not modified: its name -> tensor
        dict is shallow-copied, and the selected entries are replaced.

        Args:
            task_vector: TaskVector whose ``.vector`` maps parameter names to tensors.
            block_l: iterable of name substrings selecting the entries to zero.

        Returns:
            A new TaskVector.
        """
        task_vector_cp = TaskVector(vector=task_vector.vector.copy())
        # Iterate over a snapshot of the keys because entries are reassigned.
        for key in list(task_vector_cp.vector):
            if any(block in key for block in block_l):
                # zeros_like keeps the entry's dtype and device; a bare
                # torch.zeros would be float32.
                task_vector_cp.vector[key] = torch.zeros_like(task_vector_cp.vector[key])
        return task_vector_cp
    
    # Prompts fed to the edited model. Entry 0 is --initial_prompt. The
    # "<object-i>" prompts presumably refer to placeholder tokens learned by
    # the inverted models (TODO confirm).
    target_task_prompt = [args.initial_prompt]
    for i in range(1, 21):
        target_task_prompt.append(f"a photo of <object-{i}>")

    # Unrelated prompts, used to check how much the edit affects other concepts.
    control_task_prompt = ["a painting in the style of Kilian Eng", "a painting in the style of Picasso",
                           "a photo of a garbage truck", "a photo of a chain saw",
                           "a photo of Brad Pitt", "a photo of Angelina Jolie"]

    # Scaling coefficients for the negated task vector: 0.5, 0.75, ..., 5.25.
    alphas = [round(0.5 + 0.25 * i, 2) for i in range(20)]

    # CLIP model used to score generated images against text prompts.
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
    tokenizer = open_clip.get_tokenizer('ViT-B-32')
    model.to(device)

    def _clip_text_features(prompt):
        """Return the L2-normalised CLIP text embedding of ``prompt`` (no autograd graph)."""
        with torch.no_grad():
            features = model.encode_text(tokenizer([prompt]).to(device))
            features /= features.norm(dim=-1, keepdim=True)
        return features

    def _clip_score(img, text_features):
        """Return the cosine similarity between PIL image ``img`` and ``text_features`` in CLIP space."""
        img_input = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = model.encode_image(img_input)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return F.cosine_similarity(image_features, text_features).item()

    control_dir = os.path.join(args.output_dir, "control_task_prompt")
    target_dir = os.path.join(args.output_dir, "target_task_prompt")
    os.makedirs(control_dir, exist_ok=True)
    os.makedirs(target_dir, exist_ok=True)

    if args.pruning_criteria == "block":
        # Each entry lists UNet block-name substrings whose task-vector entries
        # are zeroed before applying it. [] is the unpruned baseline.
        # pruning_thres_list = [["down_blocks.0"],
        #                       ["down_blocks.0", "down_blocks.1"],
        #                       ["down_blocks.0", "down_blocks.1", "down_blocks.2"],
        #                       ["down_blocks.0", "down_blocks.1", "down_blocks.2", "down_blocks.3"],
        #                       ["down_blocks.0", "down_blocks.1", "down_blocks.2", "down_blocks.3", "up_blocks.0"],
        #                       ["down_blocks.0", "down_blocks.1", "down_blocks.2", "down_blocks.3", "up_blocks.0", "up_blocks.1"]]

        pruning_thres_list = [[],
                              ["down_blocks.0", "down_blocks.1", "down_blocks.2"],
                              ["down_blocks.0", "down_blocks.1", "down_blocks.2", "down_blocks.3"],
                              ["down_blocks.0", "down_blocks.1", "down_blocks.2", "down_blocks.3", "up_blocks.0"],
                              ["down_blocks.0", "down_blocks.1", "down_blocks.2", "down_blocks.3", "up_blocks.0", "up_blocks.1"]]

    elif args.pruning_criteria == "value":
        # Number of top-ranked (by dict_) entries to zero: 100, 150, ..., 450.
        pruning_thres_list = list(range(100, 500, 50))

    # Tag used in the score-file names: the second "/"-separated component of
    # output_dir (e.g. "diffusers_ckpt/output" -> "output"). Falls back to the
    # whole value when there is no "/" instead of raising IndexError.
    dir_parts = args.output_dir.split('/')
    run_name = dir_parts[1] if len(dir_parts) > 1 else dir_parts[0]

    # The CLIP text embeddings never change, so compute them once. Every
    # target image is compared with the embedding of the initial prompt.
    target_text_features = _clip_text_features(target_task_prompt[0])
    control_text_features = [_clip_text_features(p) for p in control_task_prompt]

    # pipe_finetuned is rebound to pipe_pretrained inside the loop (same
    # object), so swapping its tokenizer/text encoder also swaps
    # pipe_pretrained's. Keep the original components here so they can
    # actually be restored before the control prompts are generated.
    base_tokenizer = pipe_pretrained.tokenizer
    base_text_encoder = pipe_pretrained.text_encoder

    for pruning_thres in pruning_thres_list:
        avg_score_target = {}
        avg_score_control = {}
        max_score_target = {}
        max_score_control = {}

        neg_task_vector_unet = -task_vector_unet

        if args.pruning_criteria == "block":
            neg_task_vector_unet = set_zero_block(neg_task_vector_unet, pruning_thres)
        elif args.pruning_criteria == "value":
            neg_task_vector_unet = set_zero(neg_task_vector_unet, dict_, pruning_thres)

        for alpha in alphas:
            score_target = []
            score_control = []

            unet_edited = neg_task_vector_unet.apply_to(f"unet_pretrained_{int(curr_time)}.pt", scaling_coef=alpha)

            pipe_pretrained.unet = unet_edited

            # Alias, not a copy: both names refer to the same pipeline object.
            pipe_finetuned = pipe_pretrained
            pipe_finetuned.set_progress_bar_config(disable=True)

            #generating target task prompt and store avg score
            print("Alpha: ", alpha)
            print("Calculating target task score")
            for i, inversed_models_path in enumerate(tqdm(args.inversed_models)):
                # Only the tokenizer and text encoder of each inverted model are
                # used. Load the pipeline without moving it, and send only the
                # text encoder to the device.
                pipe_inversed = StableDiffusionPipeline.from_pretrained(inversed_models_path, torch_dtype=torch.float16, safety_checker=None)
                pipe_finetuned.tokenizer = pipe_inversed.tokenizer
                pipe_finetuned.text_encoder = pipe_inversed.text_encoder.to(device)
                del pipe_inversed

                images = pipe_finetuned(prompt=target_task_prompt, generator=gen, guidance_scale=7.5).images
                for j, img in enumerate(images):
                    img.save(os.path.join(target_dir, f"{target_task_prompt[j]}_{i}_{j}.png"))
                    score_target.append(_clip_score(img, target_text_features))

            avg_score_target[alpha] = sum(score_target) / len(score_target)
            max_score_target[alpha] = max(score_target)

            # Restore the pretrained tokenizer/text encoder for the control prompts.
            pipe_finetuned.tokenizer = base_tokenizer
            pipe_finetuned.text_encoder = base_text_encoder

            print("Calculating control task score")
            #generating control task prompt and store avg score
            for i, p in enumerate(tqdm(control_task_prompt)):
                images = pipe_finetuned(prompt=[p]*7, generator=gen, guidance_scale=7.5).images
                for j, img in enumerate(images):
                    # The sample index j keeps the 7 images per prompt from
                    # overwriting each other.
                    img.save(os.path.join(control_dir, f"{p}_{i}_{j}.png"))
                    score_control.append(_clip_score(img, control_text_features[i]))

            avg_score_control[alpha] = sum(score_control) / len(score_control)
            max_score_control[alpha] = max(score_control)

            print("Saving avg score for alpha: ", alpha)
            # Rewrite all four score files after every alpha so that partial
            # results survive an interrupted run.
            score_files = (("avg_score_target", avg_score_target),
                           ("avg_score_control", avg_score_control),
                           ("max_score_target", max_score_target),
                           ("max_score_control", max_score_control))
            for prefix, scores in score_files:
                name = "{}_{}_{}_{}.json".format(prefix, run_name, args.pruning_criteria, pruning_thres)
                file_name = os.path.join("./", name)
                with open(file_name, "w") as f:
                    json.dump(scores, f)

    os.remove(f"unet_finetuned_{int(curr_time)}.pt")
    os.remove(f"unet_pretrained_{int(curr_time)}.pt")