import torch
import clip
from PIL import Image
import json 
import pandas as pd 
from pathlib import Path
from reward import ImageRewardScorer
import csv 

# ------------------------
# Setup
# ------------------------
# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP ViT-L/14 weights (non-JIT) onto that device and put the model
# into eval mode.
# NOTE(review): `model` and `preprocess` are not referenced again in the
# visible part of this script -- confirm they are still needed.
model, preprocess = clip.load("ViT-L/14", device=device, jit=False)
model.eval()

def get_prompt_by_id(data, target_id):
    """Return the "prompt" of the first entry in `data` whose "id" equals `target_id`.

    `data` is iterated in order and each entry is indexed with the keys
    "id" and "prompt". Returns None when no entry matches.
    """
    matching_prompts = (
        entry["prompt"] for entry in data if entry["id"] == target_id
    )
    return next(matching_prompts, None)



# Parse prompts.json (UTF-8) into `data`; get_prompt_by_id() expects it to be
# an iterable of entries with "id" and "prompt" keys.
prompts_text = Path("prompts.json").read_text(encoding="utf-8")
data = json.loads(prompts_text)



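# Disabled helper: writes the "id" of every entry in `data` as one row of
# complete_prompts.csv.
# NOTE(review): it writes no header row, while the scoring code below reads an
# "id" column from that file -- confirm before re-enabling.
# with open('complete_prompts.csv', 'w', newline='') as csvfile:
#     csvwriter = csv.writer(csvfile, delimiter=' ',
#                             quotechar='|', quoting=csv.QUOTE_MINIMAL)
#     for line in data: 
#         csvwriter.writerow([line['id']])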
# ------------------------
# Score the generated images of every prompt id and write image_rewards.csv
# ------------------------
model_name = 'IR/30/v1.4'
ids_csv = 'complete_prompts.csv'
all_image_path = 'all_images/'
k = 4  # number of images scored per prompt id

# Read the ids and build the scorer BEFORE opening the output file: opening
# with mode 'w' truncates it, so a failure while loading inputs would
# otherwise wipe the results of a previous run.
ids_frame = pd.read_csv(ids_csv)
all_ids = ids_frame["id"].tolist()

# Reuse the device chosen at the top of the script rather than a hard-coded
# 'cuda', so the script does not break on CPU-only machines.
# NOTE(review): assumes ImageRewardScorer accepts the same device strings as
# torch ("cuda"/"cpu") -- confirm against reward.py.
image_reward_scorer = ImageRewardScorer(device)

with open('image_rewards.csv', 'w', newline='') as csvfile:
    csvwriter = csv.writer(csvfile, delimiter=' ',
                           quotechar='|', quoting=csv.QUOTE_MINIMAL)

    for img_id in all_ids:
        image_paths = [f'{model_name}/{all_image_path}{img_id}_img_{i}_k_4_chains_3_.jpg' for i in range(k)]

        # Validate every image of the group (not only the first one) and raise
        # a real exception: `assert` is stripped when Python runs with -O.
        missing_paths = [p for p in image_paths if not Path(p).exists()]
        if missing_paths:
            raise FileNotFoundError(
                f'Missing image(s) for id {img_id}: {missing_paths}'
            )

        # get_prompt_by_id returns None for an unknown id; fail with a clear
        # message instead of handing None to the scorer.
        prompt = get_prompt_by_id(data, img_id)
        if prompt is None:
            raise ValueError(f'No prompt found in prompts.json for id {img_id}')

        score = image_reward_scorer.score_images(image_paths, prompt)

        # One row per entry along the first dimension of `score`.
        # NOTE(review): each row is a single text field and the score is
        # written using its default string form (presumably a torch tensor
        # repr, given `.size(0)` was used here) -- confirm whether a plain
        # numeric column is wanted instead.
        for image_score in score:
            csvwriter.writerow([f'ID: {img_id}, IR: {image_score}'])
