import os
import sys
sys.path.append("../LLaVA")
os.chdir("../LLaVA")
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import json
import random
import shutil

import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from experiments.blip_experiments.transferable.transform_image import DIM, SIM, SGA, SIA, TIM, Admix, render_typos, AIP

import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

import pickle
from scipy.stats import circmean
from scipy.spatial.distance import pdist


def dd_score(enhanced_points, original_point, alpha=0.5):
    """
    Calculate the distance and diversity components of the Distance-Diversity Score.

    :param enhanced_points: array-like of shape (n, d), one row per enhanced embedding;
        a single point given with shape (d,) is treated as n = 1
    :param original_point: array-like of shape (d,) representing the original image embedding
    :param alpha: currently unused; kept so existing call sites stay valid. The two
        components are returned separately instead of being combined with this weight.
    :return: tuple (avg_distance, avg_pairwise_distance) where
        avg_distance is the mean Euclidean distance from each enhanced point to
        original_point (nan when no points are given), and
        avg_pairwise_distance is the mean Euclidean distance over all pairs of
        enhanced points (0 when fewer than two points are given)
    """
    points = np.asarray(enhanced_points)
    if points.ndim == 1 and points.size > 0:
        # A single point passed as a flat vector.
        points = points.reshape(1, -1)
    if points.shape[0] == 0:
        # Nothing to measure; avoid np.mean([]) and its RuntimeWarning.
        return (float('nan'), 0)

    # Average distance from enhanced points to the original point
    distances_to_original = np.linalg.norm(points - np.asarray(original_point), axis=1)
    avg_distance = np.mean(distances_to_original)

    # Diversity: average pairwise distance among the enhanced points
    if len(points) > 1:
        avg_pairwise_distance = np.mean(pdist(points))
    else:
        avg_pairwise_distance = 0

    return (avg_distance, avg_pairwise_distance)

def angle_score(enhanced_points, original_point):
    """
    Summarise the directions in which enhanced points lie around the original point.

    Only the first two coordinates are used (the planar angle of each offset is
    ``arctan2(dy, dx)``), so this is intended for 2-D projections.

    :param enhanced_points: numpy array of shape (n, d), d >= 2
    :param original_point: numpy array of shape (d,) representing the original image embedding
    :return: tuple (mean_direction, spread) where mean_direction is the
        scipy.stats.circmean of the angles (radians) and spread is the circular
        standard deviation sqrt(-2 * ln(mean(cos(angle - mean_direction))))
    """
    offsets = enhanced_points - original_point
    dx = offsets[:, 0]
    dy = offsets[:, 1]
    theta = np.arctan2(dy, dx)

    centre = circmean(theta)
    # Mean cosine of the deviations from the circular mean (the mean resultant length).
    resultant = np.mean(np.cos(theta - centre))
    return (centre, np.sqrt(-2 * np.log(resultant)))

def calculate_direction_diversity(all_enhanced_points, original_point):
    """
    Compute per-method mean directions and their angular separation between method groups.

    :param all_enhanced_points: sequence of 8 arrays, each of shape (n_i, d) with d >= 2,
        in the order TATM, DIM, BC, SIM, SIA, TIM, ADMIX, AIP
    :param original_point: array of shape (d,) used as the origin of every direction vector
    :return: dict mapping method name -> {'mean_direction': circular mean angle (radians)}.
        ADMIX, AIP and TATM additionally get
        'avg_angular_diff_to_SIM,BC,TIM' and 'avg_angular_diff_to_DIM,BC,TIM,SIM,SIA':
        the mean absolute wrapped angle difference (in [0, pi]) to each listed method.
    :raises ValueError: if the number of point sets is not 8
    """
    methods = ['TATM', 'DIM', 'BC', 'SIM', 'SIA', 'TIM', 'ADMIX', 'AIP']
    point_sets = list(all_enhanced_points)
    if len(point_sets) != len(methods):
        raise ValueError(
            f"expected {len(methods)} point sets ({', '.join(methods)}), got {len(point_sets)}"
        )

    scores = {}
    direction_of = {}
    for method, points in zip(methods, point_sets):
        # Angle of each offset vector in the plane of the first two coordinates.
        vectors = np.asarray(points) - original_point
        angles = np.arctan2(vectors[:, 1], vectors[:, 0])
        direction_of[method] = circmean(angles)
        scores[method] = {'mean_direction': direction_of[method]}

    # Angular differences for specific method comparisons
    comparison_groups = [
        (['ADMIX', 'AIP', 'TATM'], ['SIM', 'BC', 'TIM']),
        (['ADMIX', 'AIP', 'TATM'], ['DIM', 'BC', 'TIM', 'SIM', 'SIA'])
    ]

    for group1, group2 in comparison_groups:
        key = f'avg_angular_diff_to_{",".join(group2)}'
        for method1 in group1:
            # exp/angle wraps the raw difference into (-pi, pi] before taking |.|
            angular_diffs = [
                np.abs(np.angle(np.exp(1j * (direction_of[method1] - direction_of[method2]))))
                for method2 in group2
            ]
            scores[method1][key] = np.mean(angular_diffs)

    return scores

def embeddings_to_numpy(embeddings):
    """Convert a sequence of torch tensors to host NumPy arrays and stack them with np.vstack."""
    def _to_host(tensor):
        # Drop autograd history and move to CPU before handing the data to NumPy.
        return tensor.detach().cpu().numpy()

    return np.vstack([_to_host(t) for t in embeddings])

def remove_image_extensions(text):
    """
    Strip trailing ".jpg" / ".png" extensions from a file name.

    Only suffixes are removed (repeatedly, so "a.png.jpg" -> "a"); the same
    strings appearing elsewhere in the name are left intact. Matching is
    case-sensitive.

    :param text: file name such as "000000255965.jpg"
    :return: the name without its trailing image extension(s)
    """
    suffixes = (".jpg", ".png")
    while text.endswith(suffixes):
        matched = next(s for s in suffixes if text.endswith(s))
        text = text[:-len(matched)]
    return text


# Pre-computed adversarial images, one folder per input-transformation method.
adv_typo_folder = "dataset/transferable/describe-mscoco-llava-v1.5-7b-response_suicide-iter1000-random3typo-fs25-wordnetnoun"
adv_dim_folder = "dataset/transferable/describe-mscoco-llava-v1.5-7b-response_suicide-iter1000-dim"
adv_sim_folder = "dataset/transferable/describe-mscoco-llava-v1.5-7b-response_suicide-iter1000-sim"
adv_sga_folder = "dataset/transferable/describe-mscoco-llava-v1.5-7b-response_suicide-iter1000-sga"
adv_sia_folder = "dataset/transferable/describe-mscoco-llava-v1.5-7b-response_suicide-iter1000-sia"
adv_tim_folder = "dataset/transferable/describe-mscoco-llava-v1.5-7b-response_suicide-iter1000-tim"
adv_admix_folder = "dataset/transferable/describe-mscoco-llava-v1.5-7b-response_suicide-iter1000-admix"
adv_aip_folder = "dataset/transferable/describe-mscoco-llava-v1.5-7b-response_suicide-iter1000-eps16-aip"

# Clean source images (relative to the working directory set by os.chdir at the top).
image_folder = "dataset/transferable/mscoco_clean300_crop336"

# Typographic-augmentation parameters (passed to render_typos).
# NOTE(review): the adversarial typo folder name says "fs25" but typos are
# rendered here at size 15 -- confirm this difference is intended.
typo_num, typo_size = 3, 15
typo_color = (255, 255, 255)
typo_font = 'fonts/arial_bold.ttf'
noun_file = 'dataset/transferable/nouns.json'

# Number of embeddings gathered per method and image.
emb_num = 400
# Projection switches; only is_pca is read in this script.
is_pca, is_tsne, is_umap = True, False, False
# Also plot the adversarial embeddings when True.
with_adv = False

# Output directory: wiped and recreated on every run.
log_dir = 'semantic_visualization_including_origin'
if os.path.exists(log_dir):
    shutil.rmtree(log_dir)
os.makedirs(log_dir)

# Pool of words for random typographic augmentation: one JSON value per non-empty line.
with open(noun_file, 'r', encoding='utf-8') as file:
    class_pool = [json.loads(line.strip()) for line in file if line.strip()]

# Sorted so the processing order does not depend on os.listdir's arbitrary order.
image_files = sorted(os.listdir(image_folder))

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "openai/clip-vit-large-patch14-336"
model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)

# Map each local image file name to its COCO 'answer' field (the caption text
# fed to CLIP). Only the two needed columns are read: row indexing
# (mscoco['val'][i]) materialises every column of the row, and membership is
# checked against a set instead of scanning the file list.
mscoco = load_dataset("lmms-lab/COCO-Caption2017")
val_split = mscoco['val']
wanted_files = set(image_files)
descriptions = {
    question_id: answer
    for question_id, answer in zip(val_split['question_id'], val_split['answer'])
    if question_id in wanted_files
}

total_origin_logits, total_typo_logits, total_dim_logits, total_sim_logits, total_sga_logits, total_sia_logits, total_tim_logits, total_admix_logits, total_aip_logits = [], [], [], [], [], [], [], [], []
total_adv_typo_logits, total_adv_dim_logits, total_adv_sim_logits, total_adv_sga_logits, total_adv_sia_logits, total_adv_tim_logits, total_adv_admix_logits, total_adv_aip_logits = [], [], [], [], [], [], [], []

# Method keys in processing order. Stochastic transforms are sampled in this
# order for every image, so random-number draws happen in a fixed sequence.
_METHOD_KEYS = ['typo', 'dim', 'sim', 'sga', 'sia', 'tim', 'admix', 'aip']
# Plot labels and colours, aligned index-by-index with _METHOD_KEYS.
_METHOD_LABELS = ['TATM', 'DIM', 'BC', 'SIM', 'SIA', 'TIM', 'ADMIX', 'AIP']
_METHOD_COLORS = ['red', 'blue', 'green', 'purple', 'orange', 'brown', 'pink', 'gray']

# Running per-image mean logits, accumulated into the module-level lists.
_TOTAL_AUG_LOGITS = {
    'typo': total_typo_logits,
    'dim': total_dim_logits,
    'sim': total_sim_logits,
    'sga': total_sga_logits,
    'sia': total_sia_logits,
    'tim': total_tim_logits,
    'admix': total_admix_logits,
    'aip': total_aip_logits,
}
_TOTAL_ADV_LOGITS = {
    'typo': total_adv_typo_logits,
    'dim': total_adv_dim_logits,
    'sim': total_adv_sim_logits,
    'sga': total_adv_sga_logits,
    'sia': total_adv_sia_logits,
    'tim': total_adv_tim_logits,
    'admix': total_adv_admix_logits,
    'aip': total_adv_aip_logits,
}
_ADV_FOLDERS = {
    'typo': adv_typo_folder,
    'dim': adv_dim_folder,
    'sim': adv_sim_folder,
    'sga': adv_sga_folder,
    'sia': adv_sia_folder,
    'tim': adv_tim_folder,
    'admix': adv_admix_folder,
    'aip': adv_aip_folder,
}


def _clip_forward(image, text):
    """Run CLIP on one (image, text) pair.

    Returns ``outputs.image_embeds`` and the mean of ``outputs.logits_per_text``
    as a Python float. Callers wrap this in ``torch.no_grad()``.
    """
    inputs = processor(text=text, images=image, return_tensors="pt", padding=True).to(device)
    outputs = model(**inputs)
    return outputs.image_embeds, outputs.logits_per_text.mean().item()


def _open_rgb(folder, filename):
    """Open ``folder/filename`` with PIL and convert it to RGB."""
    return Image.open(os.path.join(folder, filename)).convert('RGB')


def _embed_fixed(folder, filename, text, n):
    """Embed one deterministic image once and repeat the result ``n`` times.

    Returns ``(embeddings, logits)`` lists of length ``n`` holding the same
    tensor / float repeated, so the image contributes ``n`` rows to the stacked
    embedding matrix used for PCA.
    """
    if n <= 0:
        return [], []
    with torch.no_grad():
        emb, logit = _clip_forward(_open_rgb(folder, filename), text)
    return [emb] * n, [logit] * n


def _embed_augmented(make_image, text, n):
    """Embed ``n`` views, calling ``make_image()`` once per view; return (embeddings, logits)."""
    embs, logits = [], []
    with torch.no_grad():
        for _ in range(n):
            emb, logit = _clip_forward(make_image(), text)
            embs.append(emb)
            logits.append(logit)
    return embs, logits


def _augmentations(image_file):
    """Return {method key: zero-argument factory producing one augmented view of image_file}."""
    def clean():
        return _open_rgb(image_folder, image_file)

    def partner():
        # Second image for the mixing transforms, drawn from the whole clean pool.
        return _open_rgb(image_folder, random.choice(image_files))

    def typo():
        image = clean()
        typos = random.sample(class_pool, k=typo_num)
        return render_typos(image, typos, typo_font, typo_size, typo_color)

    return {
        'typo': typo,
        'dim': lambda: DIM(clean()),
        'sim': lambda: SIM(clean()),
        'sga': lambda: SGA(clean()),
        'sia': lambda: SIA(clean()),
        'tim': lambda: TIM(clean()),
        'admix': lambda: Admix(clean(), partner()),
        'aip': lambda: AIP(clean(), partner()),
    }


def _pca_2d(embedding_groups):
    """Standardise the stacked groups, project to 2-D with PCA, and split the result back per group.

    Returns ``(principal_components, splits)`` where ``splits[i]`` holds the rows of group ``i``.
    """
    arrays = [embeddings_to_numpy(group) for group in embedding_groups]
    data_scaled = StandardScaler().fit_transform(np.vstack(arrays))
    principal_components = PCA(n_components=2).fit_transform(data_scaled)
    bounds = np.cumsum([0] + [len(a) for a in arrays])
    splits = [principal_components[start:end] for start, end in zip(bounds[:-1], bounds[1:])]
    return principal_components, splits


def _plot_pca(splits, colors, labels, markers, sizes, legend_x, name_stem):
    """Scatter the per-group PCA projections and save PNG and PDF copies under log_dir."""
    plt.figure(figsize=(12, 8))
    for data, color, label, marker, size in zip(splits, colors, labels, markers, sizes):
        plt.scatter(data[:, 0], data[:, 1], color=color, label=label, alpha=0.5, s=size, marker=marker)
    plt.xlabel('Principal Component 1', fontsize=25)
    plt.ylabel('Principal Component 2', fontsize=25)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.legend(bbox_to_anchor=(legend_x, 1), loc='upper right', fontsize=15, frameon=True, shadow=True)
    for ext in ('png', 'pdf'):
        plt.savefig(os.path.join(log_dir, f"multi_pca_visualization_{name_stem}.{ext}"), bbox_inches='tight')
    plt.show()
    plt.close()


def _print_direction_diversity(splits):
    """Print direction statistics of the method groups relative to the clean image (last group)."""
    original_point = splits[-1][0]
    scores_by_method = calculate_direction_diversity(splits[:-1], original_point)
    for method, scores in scores_by_method.items():
        print(f"\n{method}:")
        print(f"  Mean Direction: {scores['mean_direction']:.4f}")
        for key, value in scores.items():
            if key.startswith('avg_angular_diff'):
                print(f"  {key.replace('_', ' ').title()}: {value:.4f}")
    print()


for image_file in tqdm(image_files):
    caption = descriptions[image_file]

    # Clean image: deterministic, so it is encoded once and repeated emb_num times.
    origin_emb, origin_logits = _embed_fixed(image_folder, image_file, caption, emb_num)
    total_origin_logits.append(sum(origin_logits) / len(origin_logits))

    # Stochastic augmentations: emb_num independent draws per method.
    aug_emb = {}
    for key, make_image in _augmentations(image_file).items():
        embs, logits = _embed_augmented(make_image, caption, emb_num)
        _TOTAL_AUG_LOGITS[key].append(sum(logits) / len(logits))
        aug_emb[key] = embs

    # Adversarial images are stored as "<stem>.png" in each method's folder.
    adv_name = os.path.splitext(image_file)[0] + '.png'
    adv_emb = {}
    for key, folder in _ADV_FOLDERS.items():
        embs, logits = _embed_fixed(folder, adv_name, caption, emb_num)
        _TOTAL_ADV_LOGITS[key].append(sum(logits) / len(logits))
        adv_emb[key] = embs

    # Running averages over all images processed so far.
    print()
    print(image_file)
    for key in _METHOD_KEYS:
        print(f"average logits ({key}):", sum(_TOTAL_AUG_LOGITS[key]) / len(_TOTAL_AUG_LOGITS[key]))
    print("average logits (origin):", sum(total_origin_logits) / len(total_origin_logits))
    for key in _METHOD_KEYS:
        print(f"average logits (adv-{key}):", sum(_TOTAL_ADV_LOGITS[key]) / len(_TOTAL_ADV_LOGITS[key]))
    print()

    if is_pca:
        name_stem = remove_image_extensions(image_file)
        groups = [aug_emb[key] for key in _METHOD_KEYS] + [origin_emb]
        labels = _METHOD_LABELS + ['ORIGIN']
        colors = _METHOD_COLORS + ['black']
        markers = ['o'] * len(_METHOD_KEYS) + ['*']
        sizes = [10] * len(_METHOD_KEYS) + [300]
        if with_adv:
            groups += [adv_emb[key] for key in _METHOD_KEYS]
            labels += [f'ADV({label})' for label in _METHOD_LABELS]
            colors += _METHOD_COLORS
            markers += ['^'] * len(_METHOD_KEYS)
            sizes += [300] * len(_METHOD_KEYS)

        principal_components, splits = _pca_2d(groups)
        # The adversarial legend has more entries and needs a little more room.
        _plot_pca(splits, colors, labels, markers, sizes, 1.25 if with_adv else 1.2, name_stem)

        if not with_adv:
            # Here the clean image is the last group, as _print_direction_diversity expects.
            _print_direction_diversity(splits)
            pkl_filename = os.path.join(log_dir, f"principal_components_{name_stem}.pkl")
            with open(pkl_filename, 'wb') as f:
                pickle.dump(principal_components, f)

    sys.stdout.flush()