import csv
import itertools
import os
import traceback

import lpips
import numpy as np
import torch
from PIL import Image
from scipy.spatial.distance import pdist
from torchvision import transforms
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor

# Device configuration: use the GPU when torch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP ViT-L/14 image encoder plus its matching preprocessor; the model is
# moved onto `device` as soon as it is loaded.
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

# LPIPS perceptual-distance network with the AlexNet backbone, on `device`.
lpips_model = lpips.LPIPS(net='alex').to(device)

# Image preprocessing function
def preprocess_image(image_path):
    """Load an image from disk and convert it into CLIP pixel values.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        The ``pixel_values`` tensor produced by the module-level CLIP
        ``processor`` for this single image, with the leading batch
        dimension removed (``squeeze(0)``).
    """
    # Use a context manager so the file handle is closed deterministically
    # rather than whenever the garbage collector gets to it. convert() returns
    # a new, fully loaded image, so it stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return processor(images=image, return_tensors="pt")['pixel_values'].squeeze(0)

# Function to preprocess image for LPIPS
def preprocess_image_lpips(image_path):
    """Load an image and turn it into a batched tensor for the LPIPS model.

    The image is converted to RGB, resized to 224x224 and converted with
    ``ToTensor``, so the returned values lie in [0, 1]. LPIPS expects [-1, 1]
    by default, so callers must either pass ``normalize=True`` to the LPIPS
    model or rescale the tensor themselves.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    # Close the underlying file deterministically; convert() returns a new,
    # fully loaded image that remains usable after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return transform(image).unsqueeze(0)

# Function to calculate CLIP-based metrics and LPIPS
def _embed_images(image_files):
    """Preprocess every image for both CLIP and LPIPS.

    An image is kept only when BOTH preprocessing paths succeed, so the two
    returned lists always describe the same images in the same order.
    Failures are printed and that image is skipped.

    Args:
        image_files: Iterable of image paths.

    Returns:
        ``(embeddings, lpips_images)``: one CLIP image-feature numpy array per
        kept image, and the matching LPIPS input tensors on ``device``.
    """
    embeddings = []
    lpips_images = []
    for image_path in tqdm(image_files):
        try:
            # For CLIP
            pixel_values = preprocess_image(image_path).unsqueeze(0).to(device)
            with torch.no_grad():
                embedding = model.get_image_features(pixel_values).cpu().numpy().squeeze()
            # For LPIPS
            lpips_image = preprocess_image_lpips(image_path).to(device)
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            continue
        # Append only after both steps succeeded. Appending the embedding
        # earlier would desynchronise the two lists when the LPIPS step fails.
        embeddings.append(embedding)
        lpips_images.append(lpips_image)
    return embeddings, lpips_images


def _clip_pairwise_stats(embeddings):
    """Return the mean and standard error of pairwise cosine distances.

    Args:
        embeddings: 2-D array with one embedding per row (at least two rows).
    """
    pairwise_distances = pdist(embeddings, metric='cosine')
    mean_distance = np.mean(pairwise_distances)
    # NOTE(review): this treats the n*(n-1)/2 pairwise distances as
    # independent samples, but each image takes part in n-1 pairs, so this
    # understates the true uncertainty of the mean.
    std_error = np.std(pairwise_distances) / np.sqrt(pairwise_distances.size)
    return mean_distance, std_error


def _truncated_clip_entropy(embeddings, K):
    """Return the Gaussian entropy over the top-K principal directions.

    Computes ``k/2 * log(2*pi*e) + 1/2 * sum(log(lambda_i))`` over the ``k``
    largest eigenvalues of the sample covariance of ``embeddings``, where
    ``k = min(K, embedding dimension)``.

    With ``n <= K`` rows the sample covariance has rank at most ``n - 1 < K``.
    Some of the selected eigenvalues are then (numerically) zero or slightly
    negative, and the result becomes ``-inf`` or ``nan``.
    """
    covariance_matrix = np.cov(embeddings, rowvar=False)
    # eigvalsh returns eigenvalues in ascending order: the last K are largest.
    eigenvalues = np.linalg.eigvalsh(covariance_matrix)[-K:]
    # Use the number of eigenvalues actually summed, so the constant term
    # stays consistent if K exceeds the embedding dimension.
    k = eigenvalues.size
    return (k / 2) * np.log(2 * np.pi * np.e) + (1 / 2) * np.sum(np.log(eigenvalues))


def _lpips_stats(lpips_images):
    """Return the mean and standard deviation of LPIPS over all unordered pairs."""
    lpips_distances = []
    with torch.no_grad():
        for image_a, image_b in itertools.combinations(lpips_images, 2):
            # preprocess_image_lpips yields tensors in [0, 1] (ToTensor), while
            # LPIPS expects [-1, 1] by default; normalize=True rescales them.
            distance = lpips_model(image_a, image_b, normalize=True).item()
            lpips_distances.append(distance)
    return np.mean(lpips_distances), np.std(lpips_distances)


def calculate_metrics(image_folder, K=20):
    """Compute CLIP- and LPIPS-based diversity metrics for a run's images.

    Images are read from ``<image_folder>/eval_vis``. Files are selected by a
    png/jpg/jpeg name ending, and skipped if their name contains "ess" or
    "intermediate_rewards". Images that fail to preprocess are reported and
    skipped.

    Args:
        image_folder: Run directory whose ``eval_vis`` subfolder holds images.
        K: Number of largest covariance eigenvalues used for the Truncated
            CLIP Entropy. Must be at least 1.

    Returns:
        ``(mean_distance, std_error, TCE_K, mean_lpips, std_lpips)``, where:

        - ``mean_distance`` and ``std_error`` are the mean and standard error
          of pairwise cosine distances between CLIP image embeddings.
        - ``TCE_K`` is the truncated CLIP entropy.
        - ``mean_lpips`` and ``std_lpips`` are the mean and standard deviation
          of pairwise LPIPS distances.

    Raises:
        ValueError: If ``K < 1``, if no file matches the filter, or if fewer
            than two images could be processed.
    """
    if K < 1:
        # eigenvalues[-0:] would silently select *all* eigenvalues.
        raise ValueError(f"K must be at least 1, got {K}")

    image_folder = os.path.join(image_folder, "eval_vis")
    # Sorted so that processing order is reproducible (os.listdir order is arbitrary).
    image_files = [
        os.path.join(image_folder, file)
        for file in sorted(os.listdir(image_folder))
        if file.endswith(('png', 'jpg', 'jpeg'))
        and "ess" not in file
        and "intermediate_rewards" not in file
    ]

    if len(image_files) == 0:
        raise ValueError(f"No images found in the folder: {image_folder}")

    embeddings, lpips_images = _embed_images(image_files)

    if len(embeddings) == 0:
        raise ValueError("No embeddings were generated. Please check your images and preprocessing steps.")
    if len(embeddings) < 2:
        # Every metric below is defined over pairs or a sample covariance;
        # with a single image they would all come out as nan.
        raise ValueError("At least two images are needed to compute diversity metrics.")

    embeddings = np.array(embeddings)

    mean_distance, std_error = _clip_pairwise_stats(embeddings)
    TCE_K = _truncated_clip_entropy(embeddings, K)
    mean_lpips, std_lpips = _lpips_stats(lpips_images)

    return mean_distance, std_error, TCE_K, mean_lpips, std_lpips

# Folder containing your images
image_folder = 'logs/align-prop/aesthetic/2024.09.25_12.33.16'

# Calculate metrics, print them, and write them to a CSV in the run folder.
try:
    mean_distance, std_error, TCE, mean_lpips, std_lpips = calculate_metrics(image_folder, K=20)
    print(f"Finished evaluating images in {image_folder}")
    print(f"Mean Pairwise Distance (CLIP-based Diversity Metric): {mean_distance}")
    print(f"Standard Error of the Distance: {std_error}")
    print(f"Truncated CLIP Entropy (TCE): {TCE}")
    print(f"Mean LPIPS Distance: {mean_lpips}")
    print(f"Standard Deviation of LPIPS Distance: {std_lpips}")

    # Save the results to a CSV file: one header row of metric names and one
    # row of the corresponding values (same order as `names`).
    names = ["Mean Pairwise Distance (CLIP)", "Standard Error of the Distance (CLIP)",
             "Truncated CLIP Entropy (TCE)", "Mean LPIPS Distance", "Std Dev LPIPS Distance"]
    values = [mean_distance, std_error, TCE, mean_lpips, std_lpips]

    # Format the values to 5 decimal places
    formatted_values = [f"{v:.5f}" for v in values]

    with open(os.path.join(image_folder, "eval_diversity_results.csv"), "w", newline='') as f:
        writer = csv.writer(f)
        writer.writerow(names)
        writer.writerow(formatted_values)

except Exception as e:
    print(f"An error occurred: {e}")
    # The message alone loses where the failure happened (a bare KeyError
    # prints only the key), so also dump the full traceback to stderr.
    traceback.print_exc()