import os
from PIL import Image
import pandas as pd
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm
from argparse import ArgumentParser
import torch
import csv


# Column names, in order, expected in the comma-separated ``row_prefix`` argument.
_ROW_PREFIX_FIELDS = (
    'algo_name',
    'change',
    'task',
    'config',
    'finetune_algo',
    'finetune_task',
    'finetune_config',
    'prompts_csv',
    'random_seed',
)


def _parse_row_prefix(row_prefix):
    """Map the comma-separated ``row_prefix`` onto the leading metrics columns.

    Missing trailing fields are filled with empty strings (rather than raising
    IndexError after a whole evaluation has already run); extra fields are
    ignored.
    """
    values = row_prefix.split(',')
    values += [''] * (len(_ROW_PREFIX_FIELDS) - len(values))
    return dict(zip(_ROW_PREFIX_FIELDS, values))


def _append_metrics_row(csv_path, row):
    """Append ``row`` as one CSV line (no header) to ``csv_path``."""
    # UTF-8 explicitly: the value column contains '±', which the locale
    # default encoding may not be able to represent.
    with open(csv_path, 'a', newline='', encoding='utf-8') as csvfile:
        csv.DictWriter(csvfile, fieldnames=list(row)).writerow(row)


@torch.no_grad()
def mean_clip_score(image_dir, prompts_path, results_dir, row_prefix):
    """Compute the mean CLIP score of generated images against their prompts.

    Prompt *i* (from the ``prompt`` column of ``prompts_path``) is paired with
    the *i*-th ``.png`` file in ``image_dir``, after sorting the files by the
    integer that follows the first underscore in the name (``img_12.png`` ->
    12). Each pair is scored with ``openai/clip-vit-base-patch32``; the score
    is the model's ``logits_per_image`` entry (CLIP's scaled image-text
    similarity).

    Args:
        image_dir: Directory containing the generated ``<prefix>_<index>.png``
            images.
        prompts_path: CSV file with a ``prompt`` column, one row per image.
        results_dir: Directory to which ``metrics.csv`` is appended.
        row_prefix: Comma-separated metadata written as the leading columns of
            the result row, in the order of ``_ROW_PREFIX_FIELDS``.

    Side effects:
        Prints the image/prompt counts and the mean ± std score, and appends
        one header-less row to ``<results_dir>/metrics.csv``.

    Raises:
        ValueError: if the number of prompts and ``.png`` images differ.
    """
    # Parse the row metadata before the expensive loop; missing fields are
    # padded so an incomplete prefix cannot crash the run at the very end.
    extracted_fields = _parse_row_prefix(row_prefix)

    texts = list(pd.read_csv(prompts_path)['prompt'])
    image_filenames = [name for name in os.listdir(image_dir) if name.endswith('.png')]
    print(len(texts))
    print(len(image_filenames))
    # Validate before loading the model so a mismatch fails fast; an explicit
    # raise (unlike ``assert``) is not stripped under ``python -O``.
    if len(texts) != len(image_filenames):
        raise ValueError("Number of images and prompts don't match")

    sorted_image_filenames = sorted(
        image_filenames, key=lambda name: int(name.split('_')[1].split('.png')[0])
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval().to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    similarities = []
    for text, imagename in tqdm(zip(texts, sorted_image_filenames), total=len(texts)):
        # The context manager closes the underlying file; without it one
        # handle stays open per image.
        with Image.open(os.path.join(image_dir, imagename)) as image:
            # truncation=True: CLIP's text encoder has a fixed 77-token
            # context, so longer prompts would otherwise crash the forward pass.
            inputs = processor(
                text=text,
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
        outputs = model(**{k: v.to(device) for k, v in inputs.items()})
        # One image and one text per call, so [0, 0] is this pair's score.
        similarities.append(outputs.logits_per_image[0, 0].item())
    similarities = np.asarray(similarities, dtype=float)

    mean_similarity = np.mean(similarities)
    std_similarity = np.std(similarities)

    print('-------------------------------------------------')
    print('\n')
    print(f"Mean CLIP score ± Standard Deviation: {mean_similarity:.4f}±{std_similarity:.4f}")

    extracted_fields['type'] = "mscoco-10k"
    extracted_fields['metric'] = 'clip_score'
    extracted_fields['value'] = f"{mean_similarity:.4f}±{std_similarity:.4f}"

    # Join rather than concatenate so a missing trailing slash on results_dir
    # does not produce a sibling file such as ``resultsmetrics.csv``.
    _append_metrics_row(os.path.join(results_dir, 'metrics.csv'), extracted_fields)
    

if __name__ == '__main__':
    parser = ArgumentParser(
        description="Compute the mean CLIP score of generated images against their prompts."
    )
    parser.add_argument(
        "--image_dir", type=str, default='path/to/generated_images',
        help="Directory of generated images named <prefix>_<index>.png.",
    )
    parser.add_argument(
        "--prompts_path", type=str, default='./prompts_csv/coco_30k.csv',
        help="CSV file with a 'prompt' column, one row per generated image.",
    )
    # NOTE: '$(whoami)' in this default is a literal string -- Python does not
    # shell-expand it -- so this is only a placeholder; pass --results_dir.
    parser.add_argument(
        "--results_dir", type=str,
        default='/data/cluster_name/scratch/$(whoami)/projects/MACE-Update/experiments/experimental_results.csv',
        help="Directory that receives metrics.csv (include a trailing slash).",
    )
    parser.add_argument(
        "--row_prefix", type=str, default='GCD',
        help=(
            "Comma-separated metadata for the result row: algo_name,change,task,"
            "config,finetune_algo,finetune_task,finetune_config,prompts_csv,random_seed."
        ),
    )
    args = parser.parse_args()

    mean_clip_score(args.image_dir, args.prompts_path, args.results_dir, args.row_prefix)