import os
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from scipy.linalg import sqrtm
from PIL import Image
import pandas as pd
from tqdm import tqdm

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Feature extractor: only the ``.features`` submodule of a pretrained VGG19
# (presumably ImageNet weights, matching the Normalize stats used below),
# switched to eval mode and moved to ``device``. Loaded once at import time.
# NOTE(review): ``pretrained=True`` is deprecated in newer torchvision releases
# in favor of ``weights=...`` — confirm against the installed version.
vgg = models.vgg19(pretrained=True).features.eval().to(device)

# Preprocessing shared by every image: resize to 224x224, convert to a
# float tensor, and normalize with the per-channel mean/std below. Built once
# here instead of on every call.
_IMAGE_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def get_image_tensor(image_path):
    """Load an image file and return it as a preprocessed batch of one.

    Args:
        image_path: Path of an image file readable by PIL.

    Returns:
        A tensor of shape ``(1, 3, 224, 224)`` on ``device``: the image is
        converted to RGB, resized, normalized, and given a leading batch axis.
    """
    # ``Image.open`` is lazy and keeps the file handle open; the context
    # manager closes it once ``convert`` has produced an independent copy.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return _IMAGE_TRANSFORM(image).unsqueeze(0).to(device)

# File extensions (compared case-insensitively) treated as images.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def get_features_from_folder(folder_path):
    """Run every image directly inside *folder_path* through ``vgg``.

    Only the folder's own entries are scanned (not subfolders). Entries are
    processed in sorted filename order so that any downstream truncation
    (e.g. keeping the first N rows) selects a deterministic subset.

    Args:
        folder_path: Directory containing ``.png``/``.jpg``/``.jpeg`` files
            (any letter case).

    Returns:
        A NumPy array with one flattened feature vector per image, i.e. shape
        ``(n_images, D)``. When the folder has no matching files, the
        result is ``np.array([])`` (shape ``(0,)``).
    """
    # os.listdir order is arbitrary; sort for reproducible results.
    filenames = sorted(
        name for name in os.listdir(folder_path)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    )
    features = []
    with torch.no_grad():
        for filename in filenames:
            image_tensor = get_image_tensor(os.path.join(folder_path, filename))
            # Flatten all non-batch dims, then drop the batch axis of size 1.
            feature = vgg(image_tensor).view(image_tensor.size(0), -1)
            features.append(feature.cpu().numpy().squeeze())
    return np.array(features)

def calculate_fid(real_features, generated_features, include_covariance=False):
    """Return a Fréchet distance between two sets of feature vectors.

    By default only the mean term ``||mu1 - mu2||^2`` is computed, which is
    what this function has always returned. The covariance term needs two
    ``D x D`` covariance matrices and a matrix square root, which may be
    impractical for very high-dimensional features or a handful of samples.

    Args:
        real_features: Array of shape ``(n_real, D)``.
        generated_features: Array of shape ``(n_generated, D)``.
        include_covariance: When True, return the full Fréchet distance
            ``||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2))``.

    Returns:
        The distance as a NumPy scalar.
    """
    real = np.asarray(real_features)
    generated = np.asarray(generated_features)

    mu1 = real.mean(axis=0)
    mu2 = generated.mean(axis=0)
    ssd = np.sum((mu1 - mu2) ** 2)
    if not include_covariance:
        return ssd

    # atleast_2d keeps the single-feature case (np.cov returns a 0-d array).
    sigma1 = np.atleast_2d(np.cov(real, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(generated, rowvar=False))
    cov_sqrt = sqrtm(sigma1 @ sigma2)
    # sqrtm can return a complex result with tiny imaginary parts caused by
    # numerical error; the Fréchet distance uses the real part.
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real
    return ssd + np.trace(sigma1 + sigma2 - 2 * cov_sqrt)

def compute_fid(reference_folder, generated_folder, max_images=6):
    """Extract features from two image folders and compare them.

    Args:
        reference_folder: Folder of reference (real) images.
        generated_folder: Folder of generated images.
        max_images: Keep at most this many feature rows from each folder,
            taken from the start of ``get_features_from_folder``'s output.
            ``None`` keeps every image. Defaults to 6, the previously
            hard-coded cap.

    Returns:
        The value returned by ``calculate_fid`` for the two feature sets.
    """
    print("compute fid")
    real_features = get_features_from_folder(reference_folder)[:max_images]
    generated_features = get_features_from_folder(generated_folder)[:max_images]
    return calculate_fid(real_features, generated_features)

def save_fid_to_csv(foldername, fid_value, output_csv):
    """Append a single ``foldername,FID`` row to *output_csv*.

    The header line is written only when the CSV does not exist yet, so
    repeated calls accumulate rows under one header.
    """
    row = pd.DataFrame({"foldername": foldername, 'FID': [fid_value]})
    needs_header = not os.path.exists(output_csv)
    row.to_csv(output_csv, mode='a', index=False, header=needs_header)

def get_all_subfolders(parent_folder):
    """List every directory nested at any depth below *parent_folder*.

    Paths are joined onto *parent_folder* and appear in ``os.walk``'s
    top-down order; *parent_folder* itself is not included.
    """
    return [
        os.path.join(root, child)
        for root, children, _files in os.walk(parent_folder)
        for child in children
    ]

reference_folder = 'outputs/style/wikiart/vangogh_ensemble_ASPL_style_loss_upscaling/image_clean/image_van_gogh_small'
target_folder = 'evaluate/generate'
output_csv = 'fid_results.csv'
# Per-folder cap on images, the same cap of 6 that compute_fid applies.
MAX_IMAGES = 6

all_folders = get_all_subfolders(target_folder)

# The reference set is identical for every generated folder, so run VGG on it
# once here instead of re-extracting it on every loop iteration.
real_features = get_features_from_folder(reference_folder)[:MAX_IMAGES]
if len(real_features) == 0:
    raise SystemExit(f'No reference images found in {reference_folder}')

for dirpath in tqdm(all_folders):
    foldername = os.path.basename(dirpath)

    generated_features = get_features_from_folder(dirpath)[:MAX_IMAGES]
    # get_all_subfolders also returns intermediate directories, which may hold
    # no images; scoring them would only produce NaN rows in the CSV.
    if len(generated_features) == 0:
        print(f'Skipping {foldername}: no images found')
        continue

    fid_value = calculate_fid(real_features, generated_features)
    print(f'FID for {foldername}: {fid_value}')

    save_fid_to_csv(foldername, fid_value, output_csv)
    print(f'Results saved to {output_csv}')
    