import os
import clip
import argparse
from PIL import Image
from tqdm import tqdm
from pathlib import Path

import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from appendix_experiments.generate_images_from_paraphrases import load_paraphrases
from torch.nn.functional import cosine_similarity
import torch
import pandas as pd


def _prompt_slug(prompt: str) -> str:
    """Return the directory name under which images for ``prompt`` are stored.

    Spaces and path separators are replaced with underscores and the result
    is truncated to its first 30 characters.
    """
    return prompt.replace(" ", "_").replace("/", "_").replace("\\", "_")[:30]


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of this evaluation script.

    Every option has a default, and the path/device defaults reproduce the
    values this script historically hard-coded.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate CLIP alignment for Wanda vs No-Wanda generations for a prompt and its rephrases."
    )
    parser.add_argument(
        "--paraphrases_file",
        default="paraphrases.json",
        type=str,
        help="Path to the JSON file containing paraphrases.",
    )
    parser.add_argument(
        "--num_samples",
        default=10,
        type=int,
        help="Number of image samples generated per prompt variation.",
    )
    # NOTE: --rtpt_name is accepted for CLI compatibility but is not read
    # anywhere in this script.
    parser.add_argument(
        "--rtpt_name",
        default="EvalAlign",
        type=str,
        help="RTPT user name initials for this evaluation script.",
    )
    parser.add_argument(
        "--prompts_file",
        default="prompts/memorized_laion_prompts.csv",
        type=str,
        help="Semicolon-separated CSV with 'Caption' and 'type' columns.",
    )
    parser.add_argument(
        "--wanda_dir",
        default="concept_testing_all",
        type=str,
        help="Root directory of the Wanda generations (one subdirectory per prompt slug).",
    )
    parser.add_argument(
        "--nowanda_dir",
        default="concept_testing_nowanda",
        type=str,
        help="Root directory of the No-Wanda generations (one subdirectory per prompt slug).",
    )
    parser.add_argument(
        "--device",
        default="cuda",
        type=str,
        help="Torch device used for CLIP inference.",
    )
    return parser.parse_args()


def _load_image_batch(
    directory: Path, prompt_id: int, num_samples: int, preprocess
) -> torch.Tensor:
    """Load, preprocess and stack the images ``{img_id}_{prompt_id}.png``.

    ``img_id`` runs over ``range(num_samples)``. Each image goes through
    ``preprocess`` (the transform returned by ``clip.load``), and the results
    are stacked along a new leading dimension. Each file is opened in a
    ``with`` block so its handle is released as soon as it has been
    preprocessed.

    Raises:
        FileNotFoundError: if an expected image file does not exist.
    """
    batch = []
    for img_id in range(num_samples):
        with Image.open(directory / f"{img_id}_{prompt_id}.png") as image:
            batch.append(preprocess(image))
    return torch.stack(batch)


def _median_and_mad(scores: torch.Tensor):
    """Return ``(median, median absolute deviation)`` of ``scores`` as floats.

    NOTE: ``torch.median`` returns the lower of the two middle values when the
    number of elements is even; it does not average them.
    """
    median = scores.median().item()
    mad = (scores - median).abs().median().item()
    return median, mad


@torch.no_grad()
def main():
    """Compare CLIP text-image alignment of Wanda vs. No-Wanda generations.

    For every caption in the prompts CSV:
    - each prompt variant returned by ``load_paraphrases`` is encoded with
      CLIP ViT-B/32;
    - the variant is compared by cosine similarity with the ``num_samples``
      images stored as ``{img_id}_{prompt_id}.png`` in both the Wanda and
      No-Wanda directories for that caption.

    The per-caption score is the maximum similarity over all variants and
    samples. The median and median absolute deviation of these per-caption
    scores are printed for each CSV ``type`` ("VM", "TM") and over all
    captions ("all").

    Captions are skipped (with a message) when either generation directory
    is missing or has fewer than ``num_samples`` entries, or when none of
    their variants could be tokenized.
    """
    args = _parse_args()

    df = pd.read_csv(args.prompts_file, sep=";")
    model, preprocess = clip.load("ViT-B/32", device=args.device)

    # Per-caption maximum similarities, keyed by the CSV "type" column plus
    # "all". A "type" value other than "VM"/"TM" raises KeyError.
    best_wanda = {"VM": [], "TM": [], "all": []}
    best_nowanda = {"VM": [], "TM": [], "all": []}

    for _, row in tqdm(df.iterrows(), total=len(df)):
        original_prompt = row["Caption"]
        slug = _prompt_slug(original_prompt)
        wanda_dir = Path(args.wanda_dir) / slug
        nowanda_dir = Path(args.nowanda_dir) / slug

        if not wanda_dir.exists():
            print(f"Skipping {slug} because it doesn't exist")
            continue
        if len(os.listdir(wanda_dir)) < args.num_samples:
            print(f"Skipping {slug} because it doesn't have enough samples")
            continue
        # Guard the No-Wanda side the same way; otherwise a missing directory
        # only surfaces later as a FileNotFoundError from Image.open.
        if (
            not nowanda_dir.exists()
            or len(os.listdir(nowanda_dir)) < args.num_samples
        ):
            print(
                f"Skipping {slug} because its No-Wanda generations are missing or incomplete"
            )
            continue

        paraphrased_variants = load_paraphrases(original_prompt, args.paraphrases_file)

        # One 1-D tensor of per-sample similarities per successfully
        # tokenized variant.
        sims_wanda = []
        sims_nowanda = []

        for prompt_id, prompt in enumerate(paraphrased_variants):
            # Tokenize first so an unusable prompt does not cost two image
            # batches worth of loading and encoding.
            try:
                text = clip.tokenize([prompt], truncate=True).to(args.device)
            except Exception as e:
                print(f"Error tokenizing prompt: {prompt}")
                print(e)
                continue
            text_features = model.encode_text(text)

            for directory, sims in (
                (wanda_dir, sims_wanda),
                (nowanda_dir, sims_nowanda),
            ):
                images = _load_image_batch(
                    directory, prompt_id, args.num_samples, preprocess
                )
                image_features = model.encode_image(images.to(args.device))
                # A single text embedding is broadcast against every image
                # embedding, giving one similarity per sample.
                sims.append(cosine_similarity(image_features, text_features).cpu())

        if not sims_wanda:
            print(f"Skipping {slug} because it doesn't have any valid samples")
            continue

        max_wanda = torch.cat(sims_wanda).max()
        max_nowanda = torch.cat(sims_nowanda).max()
        for key in (row["type"], "all"):
            best_wanda[key].append(max_wanda)
            best_nowanda[key].append(max_nowanda)

    # Summarize the per-caption maxima for each group.
    for key in ("VM", "TM", "all"):
        if not best_wanda[key]:
            continue
        for label, per_caption in (
            ("with WANDA", best_wanda[key]),
            ("without WANDA", best_nowanda[key]),
        ):
            median, mad = _median_and_mad(torch.stack(per_caption))
            print(
                f"Median similarity score {label} ({key}): {median:.4f}±{mad:.2f}"
            )


if __name__ == "__main__":
    main()
