import os
import argparse
from tqdm import tqdm
from PIL import Image
import numpy as np

from dsg.csv_loader import load_dsg_file
from eval.ClipScore import ClipScore
from eval.PickScore import PickScore


def main(argv=None):
    """Score generated images against their prompts and print the mean score.

    Prompts are read from a DSG CSV file via ``load_dsg_file``. Each prompt
    key ``k`` in the selected text field is paired with the image
    ``<img_dir>/<k>.png``. Every (prompt, image) pair is scored with the
    selected metric, and the mean over all pairs is printed.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before. Passing
            a list allows programmatic use.

    Raises:
        SystemExit: on invalid arguments (via argparse) or when the selected
            text field contains no prompts to score.
    """
    # Metric name -> (label used in the printed summary, scorer class).
    # Only the chosen class is instantiated. Scorers presumably load model
    # weights in their constructor — TODO confirm.
    metrics = {
        "clipscore": ("CLIP Score", ClipScore),
        "pickscore": ("PickScore", PickScore),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_file', type=str, default="./sdv21_new_output.csv", help='Path to the DSG CSV data file. ')
    parser.add_argument('--img_dir', type=str, default="./sdv21_new_output", help='Directory of image files. ')
    parser.add_argument('--metric', type=str, default="clipscore", choices=sorted(metrics), help='Metric type. ')
    parser.add_argument('--text_field', type=str, default="text",
                        choices=["text", "rewritten_text", "decorated_text"],
                        help='Text field for the prompts to be evaluated. ')
    args = parser.parse_args(argv)

    # argparse `choices` validates --metric/--text_field. The extension check
    # needs an explicit error because `assert` is stripped under `python -O`.
    if not args.data_file.endswith(".csv"):
        parser.error(f"--data_file must be a .csv file, got {args.data_file!r}")

    dsg_data = load_dsg_file(args.data_file)
    prompt_dict = dsg_data[args.text_field]

    label, metric_cls = metrics[args.metric]
    metric = metric_cls()

    scores = []
    for key, entry in tqdm(prompt_dict.items()):
        # Assumes each entry is a dict whose first value is the prompt text.
        # Mirrors the original `list(v.values())[0]`; verify against
        # load_dsg_file.
        text = next(iter(entry.values()))
        img_path = os.path.join(args.img_dir, key + ".png")
        # Close the file handle once scored. The original leaked one per image.
        with Image.open(img_path) as img:
            score = metric(text, img)
        # Presumably the metric returns a scalar torch tensor — TODO confirm.
        scores.append(score.detach().item())

    # np.mean([]) would silently yield nan plus a RuntimeWarning.
    if not scores:
        raise SystemExit(
            f"No prompts found in field {args.text_field!r} of {args.data_file}; nothing to score."
        )

    mean_score = np.mean(scores)
    # Label matches the selected metric (previously always "CLIP Score").
    print(f"{label}: ", mean_score)

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()