from PIL import Image
import os
import io
import numpy as np
import torch
import torchvision
import rewards
import csv

# Run directory under evaluation: samples are read from its "eval_vis"
# subfolder and the summary CSV is written back into it.
img_folder = "logs/SMC/pick/2024.09.25_20.39.04"

# Every scorer runs in float32 on CUDA. Most constructors take the dtype as
# `inference_dtype`; the aesthetic scorer names the same argument `torch_dtype`.
_scorer_kwargs = dict(inference_dtype=torch.float32, device='cuda')

aesthetic_fn = rewards.aesthetic_score(torch_dtype=_scorer_kwargs['inference_dtype'],
                                       device=_scorer_kwargs['device'])
hps_fn = rewards.hps_score(**_scorer_kwargs)
imagereward = rewards.ImageReward(**_scorer_kwargs)
pick_fn = rewards.PickScore(**_scorer_kwargs)
clip_fn = rewards.clip_score(**_scorer_kwargs)

# Per-image scores, one list per reward model (same order as image_names).
aesthetic_score = []
hps_score = []
imagereward_score = []
pick_score = []
clip_score = []

eval_dir = os.path.join(img_folder, "eval_vis")

# Keep only image files and skip the diagnostic plots saved next to the samples
# (presumably effective-sample-size and intermediate-reward plots).
# NOTE(review): these are plain substring checks, so any sample whose file name
# happens to contain "ess" (e.g. a prompt mentioning "dress") is skipped too;
# confirm the plot naming scheme and tighten the match if needed.
# Sorted so the processing/log order is deterministic across file systems.
image_names = sorted(
    file
    for file in os.listdir(eval_dir)
    if file.endswith(('png', 'jpg', 'jpeg'))
    and "ess" not in file
    and "intermediate_rewards" not in file
)

# Build the PIL -> tensor converter once rather than once per image.
to_tensor = torchvision.transforms.ToTensor()

for image_name in image_names:
    image_path = os.path.join(eval_dir, image_name)

    # The context manager closes the file handle; convert() returns a new
    # in-memory image, so it stays usable after the file is closed.
    with Image.open(image_path) as pil_image:
        rgb_image = pil_image.convert("RGB")
    # Add a leading batch dimension of 1 and move to the GPU.
    image = to_tensor(rgb_image).unsqueeze(0).to('cuda')

    # The prompt is recovered from the file name: the text before the first "|",
    # after the last "_", with its final character dropped.
    # NOTE(review): relies on the file-naming layout used by the sampler; confirm.
    prompt = image_name.split("|")[0].split("_")[-1][:-1]
    print(prompt)

    with torch.no_grad():
        clip_score.append(clip_fn(image, prompt).item())
        aesthetic_score.append(aesthetic_fn(image, prompt).item())
        hps_score.append(hps_fn(image, prompt).item())
        imagereward_score.append(imagereward(image, prompt).item())
        pick_score.append(pick_fn(image, prompt).item())

# Single source of truth for the report: (label, per-image scores) in the
# column order of eval_results.csv. Printing and the CSV both derive from it,
# so labels and statistics cannot drift apart.
metrics = [
    ("Aesthetic score", aesthetic_score),
    ("HPS score", hps_score),
    ("Image reward score", imagereward_score),
    ("Pick score", pick_score),
    ("CLIP score", clip_score),
]

print(f"Finished evaluating images in {img_folder}")

# np.mean/np.std of an empty list yield NaN (with a RuntimeWarning). Stop here
# rather than overwrite eval_results.csv with a row of NaNs.
if not all(scores for _, scores in metrics):
    raise SystemExit(
        f"No images were scored in {os.path.join(img_folder, 'eval_vis')}; "
        "eval_results.csv was not written"
    )

# Each metric contributes a mean column followed by a std column.
names = []
values = []
for label, scores in metrics:
    names += [label, f"{label} std"]
    values += [np.mean(scores), np.std(scores)]

for name, value in zip(names, values):
    print(f"{name}: ", value)

# Save the results to a CSV file: a header row, then one row of values
# formatted to 5 decimal places.
formatted_values = [f"{v:.5f}" for v in values]

with open(os.path.join(img_folder, "eval_results.csv"), "w", newline='', encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(names)
    writer.writerow(formatted_values)