import os
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.neighbors import NearestNeighbors

# Function to load images from a folder
def load_images_from_folder(folder_path):
    """Load every PNG/JPEG image in ``folder_path`` as one normalized tensor batch.

    Only the top level of ``folder_path`` is scanned (no recursion). Files are
    visited in sorted name order so the batch order is reproducible across runs
    and platforms, and the extension check is case-insensitive so files such as
    ``IMG_01.JPG`` are not silently skipped.

    Each image is converted to RGB, resized to 224x224, turned into a float
    tensor and normalized with the ImageNet per-channel mean/std.

    Args:
        folder_path: Directory containing the images.

    Returns:
        A tensor of shape ``(N, 3, 224, 224)`` where ``N`` is the number of
        matching image files.

    Raises:
        ValueError: If ``folder_path`` contains no matching image files
            (``torch.stack`` cannot build a batch from an empty list).
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # fixed spatial size so all images can be stacked
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    extensions = ('.png', '.jpg', '.jpeg')

    images = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(extensions):
            continue
        image_path = os.path.join(folder_path, filename)
        if not os.path.isfile(image_path):
            # A directory that happens to be named "*.png" is not an image.
            continue
        # The context manager closes the underlying file handle; otherwise PIL
        # keeps it open until garbage collection, leaking descriptors.
        with Image.open(image_path) as image:
            images.append(transform(image.convert("RGB")))

    if not images:
        raise ValueError(f"No .png/.jpg/.jpeg images found in {folder_path!r}")
    return torch.stack(images)

# Function to extract features using a pre-trained model
def extract_features(images, model, batch_size=None):
    """Run ``model`` over ``images`` without gradients and return the outputs as NumPy.

    Args:
        images: Batched input tensor; the first dimension indexes samples.
        model: A ``torch.nn.Module``. As a side effect it is switched to
            ``eval()`` mode (dropout off, batch-norm uses running statistics).
        batch_size: Optional maximum number of samples per forward pass.
            ``None`` (the default) feeds the whole batch at once, exactly as
            before; a positive value bounds peak memory for large folders.

    Returns:
        A 2-D ``numpy.ndarray`` of shape ``(len(images), D)``. Outputs with more
        than two dimensions are flattened per sample so they can be passed
        directly to distance-based estimators such as ``NearestNeighbors``.

    Raises:
        ValueError: If ``batch_size`` is given but is not positive.
    """
    if batch_size is not None and batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    model.eval()
    # Feed inputs on the same device as the model weights; a parameter-less
    # model leaves the inputs where they are.
    device = next((p.device for p in model.parameters()), None)
    if device is not None:
        images = images.to(device)

    chunks = [images] if batch_size is None else torch.split(images, batch_size)
    outputs = []
    with torch.no_grad():
        for chunk in chunks:
            out = model(chunk)
            # Collapse any trailing dims so every sample is one feature vector.
            outputs.append(out.reshape(out.shape[0], -1).cpu())
    return torch.cat(outputs).numpy()

# Function to compute precision
def compute_precision(real_images, generated_images, threshold, feature_model=None):
    """Return the fraction of generated images lying close to some real image.

    Both batches are embedded with ``extract_features``. For each generated
    embedding, the distance to its single nearest real embedding is found
    (``NearestNeighbors`` defaults to Euclidean distance). The generated image
    counts as "realistic" when that distance is strictly below ``threshold``.

    Args:
        real_images: Batch tensor of reference images.
        generated_images: Batch tensor of generated images.
        threshold: Distance cut-off in the feature space produced by the model.
        feature_model: Model used to embed both batches. ``None`` (default)
            falls back to the module-level ``model``, preserving the original
            behaviour while letting callers inject a model explicitly.

    Returns:
        Precision in ``[0, 1]`` (count of realistic images / number of
        generated images).

    Raises:
        ValueError: If ``generated_images`` is empty (precision is undefined).
    """
    if len(generated_images) == 0:
        raise ValueError("generated_images is empty; precision is undefined")
    if feature_model is None:
        # Resolved at call time: the module-level `model` is defined after
        # this function in the file.
        feature_model = model

    real_features = extract_features(real_images, feature_model)
    generated_features = extract_features(generated_images, feature_model)

    # Distance from each generated sample to its single nearest real sample.
    nbrs = NearestNeighbors(n_neighbors=1).fit(real_features)
    distances, _ = nbrs.kneighbors(generated_features)

    realistic_count = (distances < threshold).sum()
    return realistic_count / len(generated_images)

# Pre-trained Inception-v3 used as the default feature extractor by compute_precision.
# transform_input=False: the network does not re-normalize its input, so images are
# fed exactly as load_images_from_folder normalizes them (ImageNet mean/std).
# NOTE(review): torchvision turns transform_input on by default for pretrained
# weights; confirm that disabling it here is intended.
# NOTE(review): `pretrained=` is deprecated since torchvision 0.13 in favour of
# `weights=`; migrate once the pinned torchvision version allows it.
model = models.inception_v3(transform_input=False, pretrained=True)

def get_all_subfolders(parent_folder):
    """Return the paths of every directory nested anywhere below ``parent_folder``.

    The walk is recursive and top-down, so both direct children and deeper
    descendants are included; ``parent_folder`` itself is not. A missing
    ``parent_folder`` yields an empty list, because ``os.walk`` ignores
    listing errors by default.
    """
    return [
        os.path.join(current_dir, child)
        for current_dir, child_dirs, _files in os.walk(parent_folder)
        for child in child_dirs
    ]

import pandas as pd

def save_fid_to_csv(foldername, fid_value, output_csv, column_name='FID'):
    """Append one ``(foldername, value)`` row to ``output_csv``.

    The header row is written only when the file does not exist yet or is
    empty, so repeated calls accumulate rows under a single header.

    Args:
        foldername: Label stored in the ``foldername`` column.
        fid_value: Metric value recorded for that folder.
        output_csv: Path of the CSV file to create or append to.
        column_name: Header of the metric column. Defaults to ``'FID'`` for
            backward compatibility; pass e.g. ``'precision'`` when logging a
            different metric.
    """
    # An existing-but-empty file (e.g. created by `touch`) still needs a header.
    write_header = not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0
    row = pd.DataFrame({"foldername": [foldername], column_name: [fid_value]})
    row.to_csv(output_csv, mode='a', index=False, header=write_header)

reference_folder = 'outputs/style/wikiart/vangogh_ensemble_ASPL_style_loss_upscaling/image_clean/image_van_gogh_small'
target_folder = 'evaluate/generate'
output_csv = 'precision_results.csv'
# Feature-space distance below which a generated image counts as realistic.
threshold = 10

all_folders = get_all_subfolders(target_folder)

from tqdm import tqdm

# The reference set is the same for every generated folder: load it once
# instead of re-reading it from disk on every iteration.
real_images = load_images_from_folder(reference_folder)

for generated_folder in tqdm(all_folders):
    foldername = os.path.basename(generated_folder)

    try:
        generated_images = load_images_from_folder(generated_folder)
    except (RuntimeError, ValueError) as err:
        # get_all_subfolders recurses, so intermediate folders that hold only
        # sub-folders (no images) show up here; skip them rather than abort.
        print(f'Skipping {generated_folder}: {err}')
        continue

    precision_value = compute_precision(real_images, generated_images, threshold)
    print(foldername)
    print(f'Precision: {precision_value:.4f}')

    save_fid_to_csv(foldername, precision_value, output_csv)
    print(f'Results saved to {output_csv}')
    
