from diffusers import StableDiffusionPipeline
import torch
import os
import json
from math import ceil, sqrt
from PIL import Image
from utils import save_image, concat_images_in_square_grid, TaskVector
import argparse
import torch.nn.functional as F
import open_clip
import time
from tqdm import tqdm

def parse_args(argv=None):
    """Parse the command-line options for the task-vector negation sweep.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        argparse.Namespace with ``model_pretrained``, ``model_finetuned``,
        ``inversed_models`` (list of str), ``output_dir``, ``seed`` and
        ``initial_prompt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_pretrained', type=str, default="stabilityai/stable-diffusion-2", help='pretrained model')
    # Required: the script always loads this checkpoint, and an empty path
    # can never be loaded, so failing at parse time gives a clear message.
    parser.add_argument('--model_finetuned', type=str, required=True, help='finetuned model')
    # Required: the sweep iterates over this list; if it were optional,
    # argparse would return None and the script would crash much later.
    parser.add_argument('--inversed_models', nargs='+', type=str, required=True, help='list of inversed model')
    parser.add_argument('--output_dir', type=str, default="diffusers_ckpt/output", help='output directory')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--initial_prompt', type=str, default="", help='initial prompt')
    return parser.parse_args(argv)

def pick_n_highest(d, n):
    """Return a new dict with the ``n`` items of ``d`` that have the largest values."""
    return dict(sorted(d.items(), key=lambda item: item[1], reverse=True)[:n])


def neg_mean_square_by_param(task_vector):
    """Map each entry name in ``task_vector.vector`` to minus its mean squared value.

    Squares are taken in float32 so small fp16 deltas do not underflow to
    zero. Because of the sign flip, ``pick_n_highest`` on the result selects
    the entries with the *smallest* mean squared change.
    """
    return {
        name: -torch.mean(delta.float().flatten() ** 2).item()
        for name, delta in task_vector.vector.items()
    }


def set_zero(task_vector, dict_, n):
    """Zero the ``n`` entries of ``task_vector`` whose score in ``dict_`` is highest.

    Mutates ``task_vector.vector`` in place and returns ``task_vector``.
    Each zero tensor keeps the dtype and device of the entry it replaces.
    """
    for name in pick_n_highest(dict_, n):
        task_vector.vector[name] = torch.zeros_like(task_vector.vector[name])
    return task_vector


def set_zero_block(task_vector, block_l):
    """Zero every entry of ``task_vector`` whose name contains any substring in ``block_l``.

    Mutates ``task_vector.vector`` in place and returns ``task_vector``.
    Each zero tensor keeps the dtype and device of the entry it replaces.
    """
    for key in list(task_vector.vector):
        if any(block in key for block in block_l):
            task_vector.vector[key] = torch.zeros_like(task_vector.vector[key])
    return task_vector


def results_tag(output_dir):
    """Return the tag embedded in the score JSON file names.

    Keeps the established convention of using the second '/'-separated
    component (``"diffusers_ckpt/output"`` -> ``"output"``) but falls back to
    the path's base name instead of raising IndexError when there is no '/'.
    """
    parts = output_dir.split("/")
    return parts[1] if len(parts) > 1 else os.path.basename(output_dir)


def dump_json(path, data):
    """Write ``data`` to ``path`` as JSON, replacing any existing file."""
    with open(path, "w") as f:
        json.dump(data, f)


def clip_text_features(model, tokenizer, prompt, device):
    """Return the L2-normalised CLIP text embedding of a single prompt."""
    with torch.no_grad():
        features = model.encode_text(tokenizer([prompt]).to(device))
    return features / features.norm(dim=-1, keepdim=True)


def clip_image_score(model, preprocess, image, text_features, device):
    """Return the CLIP cosine similarity between a PIL image and ``text_features``."""
    with torch.no_grad():
        features = model.encode_image(preprocess(image).unsqueeze(0).to(device))
    features = features / features.norm(dim=-1, keepdim=True)
    return F.cosine_similarity(features, text_features).item()


def _run_sweep(args, pipe, gen, device, pretrained_ckpt, finetuned_ckpt):
    """Run the alpha sweep for ``main`` and rewrite the score JSON files after each alpha."""
    target_task_prompt = [args.initial_prompt] + [
        f"a photo of <object-{i}>" for i in range(1, 21)
    ]
    control_task_prompt = [
        "a painting in the style of Kilian Eng", "a painting in the style of Picasso",
        "a photo of a garbage truck", "a photo of a chain saw",
        "a photo of Brad Pitt", "a photo of Angelina Jolie",
    ]
    alphas = [round(0.5 + 0.25 * i, 2) for i in range(20)]

    model, _, preprocess = open_clip.create_model_and_transforms(
        'ViT-B-32', pretrained='laion2b_s34b_b79k'
    )
    clip_tokenizer = open_clip.get_tokenizer('ViT-B-32')
    model.to(device)
    model.eval()  # used for scoring only

    control_dir = os.path.join(args.output_dir, "control_task_prompt")
    target_dir = os.path.join(args.output_dir, "target_task_prompt")
    os.makedirs(control_dir, exist_ok=True)
    os.makedirs(target_dir, exist_ok=True)

    # The CLIP text embeddings depend on neither alpha nor the inversed model,
    # so compute them once. Every target image is scored against the initial
    # prompt; each control image is scored against its own prompt.
    target_text = clip_text_features(model, clip_tokenizer, target_task_prompt[0], device)
    control_texts = [
        clip_text_features(model, clip_tokenizer, p, device) for p in control_task_prompt
    ]

    # Snapshot the pretrained text stack. The target phase swaps in each
    # inversed model's tokenizer/text encoder, and the control phase must run
    # on these originals again.
    base_tokenizer = pipe.tokenizer
    base_text_encoder = pipe.text_encoder
    pipe.set_progress_bar_config(disable=True)

    tag = results_tag(args.output_dir)
    avg_score_target, avg_score_control = {}, {}
    max_score_target, max_score_control = {}, {}

    for alpha in alphas:
        task_vector_unet = TaskVector(pretrained_checkpoint=pretrained_ckpt,
                                      finetuned_checkpoint=finetuned_ckpt)
        neg_task_vector_unet = -task_vector_unet
        pipe.unet = neg_task_vector_unet.apply_to(pretrained_ckpt, scaling_coef=alpha)

        print("Alpha: ", alpha)
        print("Calculating target task score")
        score_target = []
        for i, inversed_models_path in enumerate(tqdm(args.inversed_models)):
            pipe_inversed = StableDiffusionPipeline.from_pretrained(
                inversed_models_path, torch_dtype=torch.float16, safety_checker=None
            )
            # Only the text stack is borrowed, so only it is moved to the device.
            pipe.tokenizer = pipe_inversed.tokenizer
            pipe.text_encoder = pipe_inversed.text_encoder.to(device)
            del pipe_inversed

            images = pipe(prompt=target_task_prompt, generator=gen, guidance_scale=7.5).images
            for j, img in enumerate(images):
                img.save(os.path.join(target_dir, f"{target_task_prompt[j]}_{i}_{j}.png"))
                score_target.append(clip_image_score(model, preprocess, img, target_text, device))

        avg_score_target[alpha] = sum(score_target) / len(score_target)
        max_score_target[alpha] = max(score_target)

        # Restore the pretrained tokenizer/text encoder for the control prompts.
        pipe.tokenizer = base_tokenizer
        pipe.text_encoder = base_text_encoder

        print("Calculating control task score")
        score_control = []
        for i, p in enumerate(tqdm(control_task_prompt)):
            images = pipe(prompt=[p] * 7, generator=gen, guidance_scale=7.5).images
            for j, img in enumerate(images):
                # j keeps the 7 samples of one prompt from overwriting each other.
                img.save(os.path.join(control_dir, f"{p}_{i}_{j}.png"))
                score_control.append(
                    clip_image_score(model, preprocess, img, control_texts[i], device)
                )

        avg_score_control[alpha] = sum(score_control) / len(score_control)
        max_score_control[alpha] = max(score_control)

        print("Saving avg score for alpha: ", alpha)
        # Rewrite every result file after each alpha so a partial sweep is kept.
        for prefix, scores in (
            ("avg_score_target", avg_score_target),
            ("avg_score_control", avg_score_control),
            ("max_score_target", max_score_target),
            ("max_score_control", max_score_control),
        ):
            dump_json(os.path.join("./", f"{prefix}_{tag}.json"), scores)


def main():
    """Sweep negated UNet task-vector scales and record CLIP scores per alpha.

    For each alpha in 0.5, 0.75, ..., 5.25, the pretrained pipeline's UNet is
    replaced by ``(-TaskVector(pretrained, finetuned)).apply_to(pretrained,
    scaling_coef=alpha)``. TaskVector comes from ``utils``; presumably this
    subtracts alpha times the fine-tuning delta, which should be confirmed
    there. Each alpha then has two phases:

    * Target: for each ``--inversed_models`` path, the pipeline borrows that
      model's tokenizer and text encoder. It generates one image per target
      prompt, and CLIP ViT-B-32 scores each image against
      ``--initial_prompt``.
    * Control: with the pretrained text stack restored, 7 images per control
      prompt are generated and scored against that prompt.

    Mean and max scores per alpha go to four JSON files in the working
    directory. Images go under ``--output_dir``; their names do not include
    alpha, so later alphas overwrite earlier ones. The two temporary UNet
    checkpoints are always deleted, even on failure.
    """
    args = parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained(
        args.model_pretrained, torch_dtype=torch.float16, safety_checker=None
    )
    pipe.to(device)
    gen = torch.Generator(device=device)
    gen.manual_seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)

    pipe_finetuned = StableDiffusionPipeline.from_pretrained(
        args.model_finetuned, torch_dtype=torch.float16, safety_checker=None
    )
    pipe_finetuned.to(device)

    # TaskVector is constructed from checkpoint paths, so both UNets are
    # snapshotted to disk first.
    curr_time = int(time.time())
    pretrained_ckpt = f"unet_pretrained_{curr_time}.pt"
    finetuned_ckpt = f"unet_finetuned_{curr_time}.pt"
    try:
        torch.save(pipe.unet, pretrained_ckpt)
        torch.save(pipe_finetuned.unet, finetuned_ckpt)
        # Nothing else from the fine-tuned pipeline is used; release it.
        del pipe_finetuned
        _run_sweep(args, pipe, gen, device, pretrained_ckpt, finetuned_ckpt)
    finally:
        for path in (finetuned_ckpt, pretrained_ckpt):
            if os.path.exists(path):
                os.remove(path)


if __name__ == "__main__":
    main()

    
