from diffusers import StableDiffusionPipeline
import torch
import os
import json
from math import ceil, sqrt
from PIL import Image
from utils import save_image, concat_images_in_square_grid, TaskVector
import argparse
import torch.nn.functional as F
import open_clip
import time

# Command-line argument parsing.
def parse_args():
    """Parse the command-line options for the task-vector negation sweep.

    ``--prompts`` and ``--model_finetuned`` are required. The script always
    iterates over the prompts and loads the fine-tuned model. If either value
    was missing, the run used to fail only after the pretrained model had
    already been loaded.

    ``--model_finetuned_lora``, ``--lora_edit_alpha``, ``--tv_edit_alpha`` and
    ``--create_grid`` are accepted but not read anywhere in this script.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_pretrained', type=str, default="stabilityai/stable-diffusion-2", help='pretrained model')
    parser.add_argument('--model_finetuned', type=str, required=True, help='finetuned model')
    parser.add_argument('--model_finetuned_lora', type=str, default="", help='finetuned model with lora layer')
    parser.add_argument('--prompts', nargs='+', type=str, required=True, help='list of prompts')
    parser.add_argument('--num_images', type=int, default=30, help='number of images')
    parser.add_argument('--output_dir', type=str, default="diffusers_ckpt/output", help='output directory')
    parser.add_argument('--lora_edit_alpha', type=float, default=-0.97, help='amount of edit to lora layer')
    parser.add_argument('--tv_edit_alpha', type=float, default=0.5, help='amount of edit to task vector layer')
    parser.add_argument('--create_grid', action='store_true', help='if set, create grid of images')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    return parser.parse_args()

# Number of task-vector entries zeroed before the negated edit is applied.
NUM_PRUNED_LAYERS = 200
# Scaling coefficients swept for the negated task vector: 0.5, 0.75, ..., 5.25
# (20 values, step 0.25).
ALPHAS = [round(0.5 + 0.25 * i, 2) for i in range(20)]
# Directory (relative to the working directory) that receives the JSON scores.
SCORES_DIR = "src/diffusers/examples/text_to_image"


def pick_n_highest(d, n):
    """Return a dict holding the ``n`` entries of ``d`` with the largest values.

    Entries are ordered from largest to smallest value.
    """
    return dict(sorted(d.items(), key=lambda item: item[1], reverse=True)[:n])


def set_zero(task_vector, scores, n):
    """Zero the ``task_vector.vector`` entries for the ``n`` highest-scoring names.

    ``task_vector`` is modified in place and also returned. Each replacement
    keeps the shape, dtype and device of the tensor it replaces.
    """
    for name in pick_n_highest(scores, n):
        task_vector.vector[name] = torch.zeros_like(task_vector.vector[name])
    return task_vector


def set_zero_block(task_vector, block_l):
    """Zero every ``task_vector.vector`` entry whose name contains a substring in ``block_l``.

    ``task_vector`` is modified in place and also returned.
    """
    for key in list(task_vector.vector):
        if any(block in key for block in block_l):
            task_vector.vector[key] = torch.zeros_like(task_vector.vector[key])
    return task_vector


def _generate_images(pipe, prompts, num_images, output_dir, gen):
    """Generate ``num_images`` images per prompt with ``pipe`` into ``output_dir``.

    Returns a dict that maps each written image path to the prompt that
    produced it. Paths are ``<output_dir>/<prompt>_<i>.png``, so files from a
    previous call are overwritten.
    """
    generated = {}
    for prompt in prompts:
        for i in range(num_images):
            path = os.path.join(output_dir, f"{prompt}_{i}.png")
            # Regenerate until save_image reports a non-NSFW result. Assumes
            # save_image returns an NSFW flag, possibly wrapped in a list.
            # NOTE(review): there is no retry cap. With safety_checker=None the
            # flag is presumably always falsy; confirm in utils.save_image.
            while True:
                nsfw = save_image(pipe, prompt, path, gen)
                if isinstance(nsfw, list):
                    nsfw = nsfw[0]
                if not nsfw:
                    break
            generated[path] = prompt
    return generated


def _mean_clip_score(model, preprocess, text_features, generated, device):
    """Return the mean CLIP cosine similarity of each image with its own prompt.

    ``text_features`` maps each prompt to its L2-normalised text embedding.
    ``generated`` maps each image path to its prompt. Returns None when there
    are no images, instead of dividing by zero.
    """
    scores = []
    for path, prompt in generated.items():
        # Close the file handle once the image has been preprocessed.
        with Image.open(path) as img:
            img_input = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = model.encode_image(img_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            scores.append(F.cosine_similarity(image_features, text_features[prompt]).item())
    return sum(scores) / len(scores) if scores else None


def main(args):
    """Sweep negated-task-vector scales and record the mean CLIP score per scale.

    The sweep runs in these steps:

    1. Save the pretrained and fine-tuned UNets as temporary checkpoints and
       build a ``TaskVector`` from them. The vector arithmetic itself lives in
       ``utils.TaskVector``.
    2. Zero the ``NUM_PRUNED_LAYERS`` entries with the smallest mean absolute
       value, negate the vector, and apply it to the pretrained UNet at each
       scale in ``ALPHAS``.
    3. For each scale, generate ``args.num_images`` images per prompt and score
       each one against its own prompt with CLIP ViT-B-32.
    4. Write the list of per-scale averages as JSON into ``SCORES_DIR``.

    The temporary checkpoints are always deleted, even if the run fails.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe_pretrained = StableDiffusionPipeline.from_pretrained(
        args.model_pretrained, torch_dtype=torch.float16, safety_checker=None
    )
    pipe_pretrained.to(device)
    gen = torch.Generator(device=device)

    os.makedirs(args.output_dir, exist_ok=True)

    print("Generating images ...")
    print("Edit prompt: ", args.prompts)
    print("Sampling from standard finetuning edited model")

    pipe_finetuned = StableDiffusionPipeline.from_pretrained(
        args.model_finetuned, torch_dtype=torch.float16, safety_checker=None
    )
    # Use the detected device; a hard-coded "cuda" fails on CPU-only hosts.
    pipe_finetuned.to(device)

    curr_time = int(time.time())
    ckpt_pretrained = f"unet_pretrained_{curr_time}.pt"
    ckpt_finetuned = f"unet_finetuned_{curr_time}.pt"

    try:
        torch.save(pipe_pretrained.unet, ckpt_pretrained)
        torch.save(pipe_finetuned.unet, ckpt_finetuned)
        # Only the saved checkpoint is needed from here on. Drop the pipeline
        # so its weights can be freed before the sweep starts.
        del pipe_finetuned

        task_vector_unet = TaskVector(pretrained_checkpoint=ckpt_pretrained,
                                      finetuned_checkpoint=ckpt_finetuned)
        # Each entry's score is the NEGATED mean absolute value. pick_n_highest
        # (used by set_zero) therefore selects the entries with the SMALLEST
        # magnitude.
        layer_scores = {
            name: -torch.mean(torch.abs(delta.flatten())).item()
            for name, delta in task_vector_unet.vector.items()
        }
        del task_vector_unet

        # Load CLIP and encode the prompts once; neither depends on alpha.
        clip_model, _, preprocess = open_clip.create_model_and_transforms(
            'ViT-B-32', pretrained='laion2b_s34b_b79k'
        )
        tokenizer = open_clip.get_tokenizer('ViT-B-32')
        clip_model.to(device)
        text_features = {}
        with torch.no_grad():
            for prompt in args.prompts:
                feats = clip_model.encode_text(tokenizer([prompt]).to(device))
                text_features[prompt] = feats / feats.norm(dim=-1, keepdim=True)

        avg_scores = []
        for alpha in ALPHAS:
            # Rebuild from disk each time: set_zero mutates the vector in place.
            task_vector_unet = TaskVector(pretrained_checkpoint=ckpt_pretrained,
                                          finetuned_checkpoint=ckpt_finetuned)
            task_vector_unet = set_zero(task_vector_unet, layer_scores, NUM_PRUNED_LAYERS)
            neg_task_vector_unet = -task_vector_unet
            pipe_pretrained.unet = neg_task_vector_unet.apply_to(ckpt_pretrained, scaling_coef=alpha)

            # Re-seed so every alpha starts from the same noise. Otherwise score
            # differences would mix the effect of alpha with the effect of
            # different random draws.
            gen.manual_seed(args.seed)
            generated = _generate_images(pipe_pretrained, args.prompts, args.num_images,
                                         args.output_dir, gen)
            avg_scores.append(_mean_clip_score(clip_model, preprocess, text_features,
                                               generated, device))

        # Name the file after the output directory's last path component.
        # split('/')[1] raised IndexError for paths with no "/".
        run_name = os.path.basename(os.path.normpath(args.output_dir))
        os.makedirs(SCORES_DIR, exist_ok=True)
        file_name = os.path.join(
            SCORES_DIR, f"avg_scores_prune_{NUM_PRUNED_LAYERS}_layer_{run_name}.json"
        )
        with open(file_name, "w") as f:
            json.dump(avg_scores, f)
    finally:
        # Always remove the temporary UNet checkpoints, even on failure.
        for path in (ckpt_finetuned, ckpt_pretrained):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass


if __name__ == "__main__":
    main(parse_args())

    
