import argparse
import json
import os
import tempfile
from math import ceil, sqrt

import open_clip
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline
from PIL import Image

from utils import save_image, concat_images_in_square_grid, TaskVector

# Command-line argument parsing.
def parse_args(argv=None):
    """Parse the command-line options of the sampling / CLIP-evaluation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            which is what the script uses; passing a list is convenient for tests.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_pretrained', type=str, default="stabilityai/stable-diffusion-2", help='pretrained model')
    # An empty string means "not given"; the main script branches on that.
    parser.add_argument('--model_finetuned', type=str, default="", help='finetuned model')
    parser.add_argument('--model_finetuned_lora', type=str, default="", help='finetuned model with lora layer')
    # Required: the script iterates over the prompts and cannot run without them.
    parser.add_argument('--prompts', nargs='+', type=str, required=True, help='list of prompts')
    parser.add_argument('--num_images', type=int, default=30, help='number of images')
    parser.add_argument('--output_dir', type=str, default="diffusers_ckpt/output", help='output directory')
    # Multiplier applied to the LoRA up-projection weights (negative values flip the update's sign).
    parser.add_argument('--lora_edit_alpha', type=float, default=-0.97, help='amount of edit to lora layer')
    # Scaling coefficient passed to TaskVector.apply_to for the negated task vector.
    parser.add_argument('--tv_edit_alpha', type=float, default=0.5, help='amount of edit to task vector layer')
    parser.add_argument('--create_grid', action='store_true', help='if set, create grid of images')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    return parser.parse_args(argv)

# Text that generated images are compared against with CLIP (image-text cosine similarity).
CLIP_EVAL_TEXT = "a painting in the style of Van Gogh"

# Upper bound on regenerating an image that save_image() reports as NSFW, so a prompt
# that is always flagged cannot hang the script.
MAX_NSFW_RETRIES = 10


def load_pipeline(model_id, dtype):
    """Load a Stable Diffusion pipeline with the safety checker disabled.

    The pipeline is not moved to any device; callers do that themselves.
    """
    return StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, safety_checker=None)


def negative_mean_abs_scores(vector):
    """Score each task-vector entry by the negated mean absolute value of its tensor.

    Combined with ``pick_n_highest`` this selects the entries whose deltas are the
    *smallest* in magnitude.
    """
    return {name: -torch.mean(torch.abs(delta.flatten())).item() for name, delta in vector.items()}


def pick_n_highest(scores, n):
    """Return a dict of the ``n`` items of ``scores`` with the largest values, in descending order."""
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True)[:n])


def set_zero_block(task_vector, block_l):
    """Return a copy of ``task_vector`` whose entries with a name containing any of ``block_l`` are zeroed.

    Zeros keep each entry's dtype and device; the input task vector is not modified.
    """
    vector = {
        key: torch.zeros_like(value) if any(block in key for block in block_l) else value
        for key, value in task_vector.vector.items()
    }
    return TaskVector(vector=vector)


def set_zero(task_vector, scores, n):
    """Return a copy of ``task_vector`` with its ``n`` highest-scoring entries zeroed.

    ``scores`` maps entry names to scores (e.g. from ``negative_mean_abs_scores``).
    The input task vector is not modified.
    """
    selected = pick_n_highest(scores, n)
    vector = {
        key: torch.zeros_like(value) if key in selected else value
        for key, value in task_vector.vector.items()
    }
    return TaskVector(vector=vector)


def build_task_vector_pipeline(pretrained_id, finetuned_id, alpha, device, dtype):
    """Return the pretrained pipeline with its UNet edited by the negated task vector.

    The task vector is built by ``utils.TaskVector`` from the pretrained and finetuned
    UNet checkpoints (presumably finetuned minus pretrained — see utils), negated, and
    applied to the pretrained UNet with ``scaling_coef=alpha``.
    """
    pipe_pretrained = load_pipeline(pretrained_id, dtype)
    pipe_finetuned = load_pipeline(finetuned_id, dtype)
    pipe_pretrained.to(device)
    pipe_finetuned.to(device)

    # TaskVector works from checkpoint files, so both UNets are serialized into a scratch
    # directory that is removed even if the edit fails.
    with tempfile.TemporaryDirectory(dir=".") as tmp_dir:
        pretrained_path = os.path.join(tmp_dir, "unet_pretrained.pt")
        finetuned_path = os.path.join(tmp_dir, "unet_finetuned.pt")
        torch.save(pipe_pretrained.unet, pretrained_path)
        torch.save(pipe_finetuned.unet, finetuned_path)
        # The finetuned pipeline is only needed for its UNet checkpoint.
        del pipe_finetuned

        task_vector = TaskVector(pretrained_checkpoint=pretrained_path, finetuned_checkpoint=finetuned_path)
        # Optional ablations:
        # task_vector = set_zero_block(task_vector, ["down_blocks.0", "down_blocks.1", "down_blocks.2"])
        # task_vector = set_zero(task_vector, negative_mean_abs_scores(task_vector.vector), 500)
        edited_unet = (-task_vector).apply_to(pretrained_path, scaling_coef=alpha)

    pipe_pretrained.unet = edited_unet
    pipe_pretrained.to(device)
    return pipe_pretrained


def build_lora_pipeline(pretrained_id, lora_path, alpha, device, dtype):
    """Return a pipeline with LoRA attention processors whose up-projection weights are scaled by ``alpha``.

    Assuming the usual ``up @ down`` LoRA factorization, scaling only the up weights scales
    the LoRA update by ``alpha``; a negative ``alpha`` flips its sign.
    """
    pipe = load_pipeline(pretrained_id, dtype)
    pipe.unet.load_attn_procs(lora_path)
    pipe.to(device)
    with torch.no_grad():
        for name, param in pipe.unet.named_parameters():
            if "_lora.up.weight" in name:
                print("Editing lora layer: ", name)
                param.mul_(alpha)
    return pipe


def first_flag(nsfw):
    """Reduce save_image()'s NSFW result to a single value; a list yields its first item (empty -> False)."""
    if isinstance(nsfw, list):
        return nsfw[0] if nsfw else False
    return nsfw


def clip_scores(image_paths, text, device):
    """Return ``{file name: CLIP cosine similarity between that image and text}``."""
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
    tokenizer = open_clip.get_tokenizer('ViT-B-32')
    model.to(device)
    model.eval()

    scores = {}
    with torch.no_grad():
        # The text embedding is the same for every image, so encode it once.
        text_features = model.encode_text(tokenizer([text]).to(device))
        for path in image_paths:
            with Image.open(path) as img:
                img_input = preprocess(img).unsqueeze(0).to(device)
            image_features = model.encode_image(img_input)
            scores[os.path.basename(path)] = F.cosine_similarity(image_features, text_features).item()
    return scores


def summarize_scores(scores):
    """Return ``(mean, max, name of the best-scoring image)`` for a non-empty score dict."""
    best_name = max(scores, key=scores.get)
    return sum(scores.values()) / len(scores), scores[best_name], best_name


if __name__ == "__main__":
    # Sample images from a (possibly edited) Stable Diffusion model, optionally build
    # per-prompt grids, then score the generated images against CLIP_EVAL_TEXT.
    args = parse_args()
    if not args.prompts:
        raise SystemExit("error: at least one prompt must be given with --prompts")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision only on GPU; float16 support for CPU kernels in PyTorch is limited.
    dtype = torch.float16 if device == "cuda" else torch.float32
    gen = torch.Generator(device=device)
    gen.manual_seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)

    print("Generating images ...")
    print("Edit prompt: ", args.prompts)

    if args.model_finetuned:  # task vector edit
        print("Sampling from standard finetuning edited model")
        pipe = build_task_vector_pipeline(
            args.model_pretrained, args.model_finetuned, args.tv_edit_alpha, device, dtype
        )
    elif args.model_finetuned_lora:
        print("Sampling from lora finetuning edited model")
        pipe = build_lora_pipeline(
            args.model_pretrained, args.model_finetuned_lora, args.lora_edit_alpha, device, dtype
        )
    else:
        # No edit requested: sample from the pretrained model itself.
        print("Sampling from pretrained model")
        pipe = load_pipeline(args.model_pretrained, dtype)
        pipe.to(device)

    generated_paths = []
    for prompt in args.prompts:
        for i in range(args.num_images):
            out_path = os.path.join(args.output_dir, f"{prompt}_{i}.png")
            for _ in range(MAX_NSFW_RETRIES):
                if not first_flag(save_image(pipe, prompt, out_path, gen)):
                    break
            else:
                print(f"Warning: {out_path} was flagged as NSFW on all {MAX_NSFW_RETRIES} attempts; using the last one.")
            generated_paths.append(out_path)

    if args.create_grid:
        for prompt in args.prompts:
            concat_images_in_square_grid(args.output_dir, prompt, os.path.join(args.output_dir, f"grid {prompt}.png"))

    # The diffusion pipeline is no longer needed; release it before loading CLIP.
    del pipe
    if device == "cuda":
        torch.cuda.empty_cache()

    if generated_paths:
        # Only this run's samples are scored (not grids or stale files in output_dir).
        score = clip_scores(generated_paths, CLIP_EVAL_TEXT, device)
        average, best_value, best_name = summarize_scores(score)
        print("Average cosine similarity score: ", average)
        print("Max cosine similarity score: ", best_value)
        print(score)
        print("Index of max score: ", best_name)
    else:
        print("No images were generated; skipping CLIP scoring.")
    print("Done!")
    