import argparse
import os

import numpy as np
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoProcessor

def get_average_embedding(folder_path, artist_name, model, processor):
    """Return the mean image embedding of the matching images in a folder.

    Every entry of ``folder_path`` whose filename contains ``artist_name`` is
    opened as an RGB image, passed through ``processor`` and ``model``, and
    the model's ``pooler_output`` is collected. Entries that cannot be opened
    as images are reported on stdout and skipped.

    Args:
        folder_path: Directory to scan (only its direct entries are listed).
        artist_name: Substring a filename must contain to be included
            (case-sensitive substring match).
        model: Callable invoked as ``model(**inputs)``; its result must
            expose a ``pooler_output`` tensor.
        processor: Callable invoked as
            ``processor(images=image, return_tensors="pt")``; its result is
            unpacked as keyword arguments into ``model``.

    Returns:
        The element-wise mean (``np.mean(..., axis=0)``) of the collected
        embeddings.

    Raises:
        ValueError: If no matching entry produced an embedding.
    """
    embeddings = []
    # os.listdir yields entries in arbitrary order; sorting keeps the
    # floating-point mean independent of directory ordering.
    for filename in sorted(os.listdir(folder_path)):
        if artist_name not in filename:
            continue
        image_path = os.path.join(folder_path, filename)
        try:
            # Image.open is lazy; the context manager makes sure the
            # underlying file is closed once the RGB copy has been made.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
        except Exception as e:
            # Best effort: report the unreadable entry and keep going.
            print(f"Error opening image {image_path}: {e}")
            continue
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
            # Use the pooled output as the image embedding. ``.cpu()`` is a
            # no-op for CPU tensors and keeps ``.numpy()`` valid when the
            # output lives on another device.
            embedding = outputs.pooler_output.squeeze().cpu().numpy()
        embeddings.append(embedding)
    if not embeddings:
        raise ValueError(f"No images found with artist name '{artist_name}' in {folder_path}")
    return np.mean(embeddings, axis=0)

def compute_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity between two embedding vectors.

    Each vector is wrapped in a one-element list so that
    ``cosine_similarity`` receives two single-row inputs; the only entry of
    the resulting matrix is returned.
    """
    similarity_matrix = cosine_similarity([embedding1], [embedding2])
    first_row = similarity_matrix[0]
    return first_row[0]

if __name__ == "__main__":
    # Command-line entry point: load the CSD-ViT-L model, average the
    # embeddings of the matching images in each directory, and print the
    # cosine similarity between the two averages.
    parser = argparse.ArgumentParser(
        description='Compute the cosine similarity between the average image embeddings of two directories.'
    )
    parser.add_argument('--dir1', type=str, required=True, help='Path to the first directory')
    parser.add_argument('--dir2', type=str, required=True, help='Path to the second directory')
    # NOTE(review): --results_dir and --row_prefix are parsed but not used
    # anywhere in this script. The "$(whoami)" in the default is shell syntax
    # and is not expanded by Python; confirm intent before relying on it.
    parser.add_argument("--results_dir", type=str, default='/data/cluster_name/scratch/$(whoami)/projects/MACE-Update/experiments/experimental_results.csv')
    parser.add_argument("--row_prefix", type=str, default='GCD')
    # The filename filters used to be hard-coded; the defaults keep the
    # previous placeholder values so existing invocations behave the same.
    parser.add_argument('--artist_name1', type=str, default="artist_name_in_filenames_1",
                        help='Substring that filenames in --dir1 must contain')
    parser.add_argument('--artist_name2', type=str, default="artist_name_in_filenames_2",
                        help='Substring that filenames in --dir2 must contain')
    args = parser.parse_args()

    # Load the model and processor
    model = AutoModel.from_pretrained("tomg-group-umd/CSD-ViT-L")
    processor = AutoProcessor.from_pretrained("tomg-group-umd/CSD-ViT-L")

    # Compute the average embeddings
    embedding1 = get_average_embedding(args.dir1, args.artist_name1, model, processor)
    embedding2 = get_average_embedding(args.dir2, args.artist_name2, model, processor)

    # Calculate cosine similarity
    cos_sim = compute_cosine_similarity(embedding1, embedding2)
    print(f"Cosine similarity between the average embeddings: {cos_sim}")
