from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch
from PIL import Image
import copy
from utils_tv import get_sd_tv, clip_embed, cosine_similarity, clip_text_embed, get_many_tvs, get_non_linear_many_tvs, iqa_score
import argparse
import wandb
import os
import shutil
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import random
from eval_dynamic_tv_as_function import eval_dynamic_which_vector

def update_params(pipeline, unet_edited):
    """Overwrite the pipeline UNet's weights with those of ``unet_edited``.

    Parameters are matched positionally, in ``.parameters()`` order. Each
    edited tensor is cast to the dtype of, and moved to the device of, the
    parameter it replaces. This keeps the pipeline's UNet where the pipeline
    was placed even if ``unet_edited`` lives elsewhere (e.g. on the CPU).

    Args:
        pipeline: object exposing a ``unet`` ``torch.nn.Module`` attribute
            (here a diffusers ``StableDiffusionPipeline``).
        unet_edited: module whose parameters replace the pipeline UNet's. It
            must have the same number of parameters, with matching shapes.

    Returns:
        The same ``pipeline`` object, updated in place.

    Raises:
        ValueError: if the parameter counts or any parameter shape differ. The
            check happens before any weight is touched, so a failed call
            leaves the pipeline unchanged.
    """
    target_params = list(pipeline.unet.parameters())
    source_params = list(unet_edited.parameters())

    # zip() would silently stop at the shorter list and leave the remaining
    # pipeline weights untouched, so validate everything up front.
    if len(target_params) != len(source_params):
        raise ValueError(
            f"unet_edited has {len(source_params)} parameters, "
            f"pipeline.unet has {len(target_params)}"
        )
    for idx, (param, edited_param) in enumerate(zip(target_params, source_params)):
        if param.shape != edited_param.shape:
            raise ValueError(
                f"parameter {idx}: shape {tuple(edited_param.shape)} in unet_edited "
                f"does not match {tuple(param.shape)} in pipeline.unet"
            )

    with torch.no_grad():
        for param, edited_param in zip(target_params, source_params):
            # Rebinding .data (instead of copy_) avoids an extra copy when the
            # device and dtype already match; the weight then shares storage
            # with unet_edited.
            param.data = edited_param.data.to(device=param.device, dtype=param.dtype)
    return pipeline

# Command-line options. Restricting --get_tv_option to the values the main
# loop actually handles makes a typo fail at parse time. Otherwise it would
# only surface later as a NameError on the edited UNet.
parser = argparse.ArgumentParser(description='Example script with argparse')

parser.add_argument(
    '--prompt', type=str, default="#",
    help='prompt template for the control prompts; every "#" is replaced by the control prompt',
)
parser.add_argument(
    '--wandb_name', type=str, default="try",
    help='wandb project name; also used as the results file name',
)
parser.add_argument(
    '--edit_alpha', type=float, default=1.0,
    help='coefficient forwarded as edit_alpha to the task-vector builders in utils_tv',
)
parser.add_argument(
    '--get_tv_option', type=str, default="join_linear",
    choices=["join_linear", "non_linear", "co_train", "ours"],
    help='how the pair of task vectors is turned into an edited UNet',
)
parser.add_argument(
    '--clip_sim_ths', type=float, default=0.8,
    help='for "ours": a vector is applied when its eval_dynamic_which_vector score is below this value',
)

args = parser.parse_args()

# The parsed arguments are stored as the run config.
wandb.init(
    dir="wandb_directory",
    project=args.wandb_name,
    config=args,
)

# Prompts outside the evaluated concept pairs: artist names followed by the
# CIFAR-10 class names. Their CLIP image-text similarity under the edited
# model is reported as control_similarity.
control_prompts = [
    "Alphonse Mucha",
    "H.R. Giger",
    "Gustav Klimt",
    "Hayao Miyazaki",
    "M.C. Escher",
    "Yoshitaka Amano",
    "Salvador Dalí",
    "James Gurney",
    "Jean Giraud (Moebius)",
    "John Singer Sargent",
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Load Stable Diffusion v1.4 without the safety checker, then replace its
# scheduler with an Euler discrete scheduler built from the same config.
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)

# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline.to(device)

# Handle to the pipeline's VAE; not referenced by the evaluation code below.
vae = pipeline.vae

# Pairs of single-concept task vectors evaluated together. The other per-pair
# tables below are derived from this one so the three can never drift apart.
tv_vector_names = [
    ("ajin_demi_human", "kelly_mckernan"),
    ("kilian_eng", "thomas_kinkade"),
    ("thomas_kinkade", "tyler_edlin"),
    ("tyler_edlin", "van_gogh"),
    ("van_gogh", "ajin_demi_human"),
    ("kelly_mckernan", "kilian_eng"),
]

# Checkpoint names of the jointly fine-tuned ("co_train") models: "<a>+<b>_1000".
vector_pairs = ["+".join(pair) + "_1000" for pair in tv_vector_names]

# Plain-text concept names (underscores -> spaces), used directly as the
# target prompts for each pair.
tv_concept_names = [
    tuple(name.replace("_", " ") for name in pair) for pair in tv_vector_names
]

num_pairs = len(tv_vector_names)

num_inference_steps = 50

# CLIP similarities of every generated image, accumulated over ALL pairs
# (despite the name, they are flat lists, not one entry per pair).
control_acc_per_pair = []
target_acc_per_pair = []

# Main evaluation. For each concept pair:
#   1. build an edited UNet according to --get_tv_option and load it into the
#      pipeline ("ours" instead re-selects vectors per prompt);
#   2. generate one image per control prompt and per target concept name;
#   3. record the CLIP image-text similarity of each image.
# The averages are printed, temp directories produced by the task-vector
# builders are removed, and the averages are written to a results file.

# Anything else would previously surface as a NameError on `unet_edited`.
_TV_OPTIONS = ("join_linear", "non_linear", "co_train", "ours")
if args.get_tv_option not in _TV_OPTIONS:
    raise ValueError(
        f"Unknown --get_tv_option {args.get_tv_option!r}; "
        f"expected one of: {', '.join(_TV_OPTIONS)}"
    )


def _ft_path(vec_name):
    """Return the checkpoint directory of the single-concept model ``vec_name``."""
    return f"path/to/{vec_name}_1000/"


def _apply_static_tv(pipe, pair_idx):
    """Load the edited UNet for pair ``pair_idx`` (any option except "ours").

    Returns ``(pipe, tmp_path)``, where ``tmp_path`` is the third value
    returned by the utils_tv builder.
    """
    option = args.get_tv_option
    if option == "co_train":
        builder_out = get_sd_tv(args.edit_alpha, ft_path=f"path/to/{vector_pairs[pair_idx]}/")
    else:
        ft_paths_of_pair = [_ft_path(vec) for vec in tv_vector_names[pair_idx]]
        build = {"join_linear": get_many_tvs, "non_linear": get_non_linear_many_tvs}[option]
        builder_out = build(args.edit_alpha, ft_paths_of_pair)
    _, unet_edited, tmp_path = builder_out
    return update_params(pipe, unet_edited), tmp_path


def _apply_dynamic_tv(pipe, text, pair_idx):
    """"ours": choose which of the pair's vectors to apply for ``text`` and load them.

    A vector is kept when its score from ``eval_dynamic_which_vector`` is below
    ``--clip_sim_ths``. Returns ``(pipe, tmp_path)``.
    """
    names = tv_vector_names[pair_idx]
    vec_scores = eval_dynamic_which_vector(text, vec1=f"{names[0]}_1000/", vec2=f"{names[1]}_1000/")
    indices = [j for j, score in enumerate(vec_scores) if score.item() < args.clip_sim_ths]

    print("vec_scores", vec_scores)
    print("indices", indices)

    # NOTE(review): `indices` can be empty; what get_many_tvs does with an
    # empty path list is defined in utils_tv — confirm it yields the base UNet.
    ft_paths_of_pair = [_ft_path(names[j]) for j in indices]
    _, unet_edited, tmp_path = get_many_tvs(args.edit_alpha, ft_paths_of_pair)
    return update_params(pipe, unet_edited), tmp_path


def _clip_similarity(image, text):
    """Return the CLIP cosine similarity between ``image`` and ``text`` as NumPy."""
    emb_image = clip_embed(image)
    emb_text = clip_text_embed(text)
    similarity = cosine_similarity(emb_image[0], emb_text[0])
    # Assumes a torch tensor (the result exposes .numpy()). .cpu() lets this
    # work when it lives on the GPU and is a no-op otherwise.
    return similarity.cpu().numpy()


# Every distinct temp directory returned by the builders, in creation order.
# All of them are removed at the end, not just the last one.
_tmp_paths = []


def _track_tmp_path(path):
    """Remember ``path`` for cleanup (ignores ``None`` and duplicates)."""
    if path is not None and path not in _tmp_paths:
        _tmp_paths.append(path)


try:
    for i in range(num_pairs):
        tv_concept_names_of_pair = tv_concept_names[i]

        if args.get_tv_option != "ours":
            pipeline, tmp_path = _apply_static_tv(pipeline, i)
            _track_tmp_path(tmp_path)

        print("\ncontrol_prompt\n")

        for control_prompt in control_prompts:
            if args.get_tv_option == "ours":
                pipeline, tmp_path = _apply_dynamic_tv(pipeline, control_prompt, i)
                _track_tmp_path(tmp_path)

            prompt = args.prompt.replace("#", control_prompt)
            print("prompt", prompt)

            with torch.no_grad():
                final_image = pipeline(prompt, num_inference_steps=num_inference_steps).images[0]
                control_acc_per_pair.append(_clip_similarity(final_image, prompt))

        print("\ntv_concept_name\n")

        for tv_concept_name in tv_concept_names_of_pair:
            if args.get_tv_option == "ours":
                print("tv_vector_names[i]", tv_vector_names[i])
                print("tv_concept_name", tv_concept_name)
                pipeline, tmp_path = _apply_dynamic_tv(pipeline, tv_concept_name, i)
                _track_tmp_path(tmp_path)

            with torch.no_grad():
                final_image = pipeline(tv_concept_name, num_inference_steps=num_inference_steps).images[0]
                target_acc_per_pair.append(_clip_similarity(final_image, tv_concept_name))

    control_mean = np.mean(control_acc_per_pair)
    target_mean = np.mean(target_acc_per_pair)
    print("control_similarity", control_mean)
    print("target_similarity", target_mean)
finally:
    # Runs even if generation fails, so temp directories are not leaked.
    for tmp_path in _tmp_paths:
        if os.path.exists(tmp_path):
            shutil.rmtree(tmp_path)
            print(f"Directory '{tmp_path}' and all its contents have been deleted.")
        else:
            print(f"Directory '{tmp_path}' does not exist.")

results_dir = os.path.join("results_positive", "combination_comparison")
os.makedirs(results_dir, exist_ok=True)
file_path = os.path.join(results_dir, args.wandb_name)

with open(file_path, "w", encoding="utf-8") as f:
    f.write("control_acc_per_i:\n")
    f.write(str(control_mean) + "\n")
    f.write("target_acc_per_i:\n")
    f.write(str(target_mean) + "\n")
