import os
from PIL import Image
import pandas as pd
import re
import clip
from tqdm import tqdm
from argparse import ArgumentParser
import torch
import csv
import json


# Metadata columns carried, in order, by the comma-separated ``row_prefix`` argument.
_ROW_PREFIX_FIELDS = (
    'algo_name',
    'change',
    'task',
    'config',
    'finetune_algo',
    'finetune_task',
    'finetune_config',
    'prompts_csv',
    'random_seed',
)


def _ground_truth_label(image_name):
    """Return the class word encoded in an image file name.

    The label is the last space-separated word of the part before the first
    underscore, e.g. ``"a photo of the dog_3.png" -> "dog"``.
    """
    return image_name.split('_')[0].split(' ')[-1]


def _image_index(image_name):
    """Return the integer index of a ``<prompt>_<index>.png`` file name (used as sort key)."""
    return int(image_name.split('_')[1].split('.png')[0])


@torch.no_grad()
def calculate_mean_prob(image_dir, object_ls, object_dic, save_path, row_prefix):
    """Classify every image in ``image_dir`` with CLIP and log the accuracy.

    Each image is scored against the prompts ``"a photo of the <object>"`` for
    every entry of ``object_ls``. A softmax over those prompts gives one
    probability per object. An image counts as correctly classified when its
    highest-probability object equals the label encoded in its file name. For
    the ``synonyms`` image type, it counts when that object is a key of
    ``object_dic`` whose value contains the label.

    Side effects:
      * prints the per-image probability table;
      * writes it to ``save_path + 'clip_acc_prob_<type>.csv'``;
      * appends one data row (no header) to ``save_path + 'metrics.csv'``.

    Args:
        image_dir: folder of ``<prompt>_<index>.png`` images (str). Only
            files ending in ``.png`` are used. The image "type" is
            ``os.path.basename(os.path.dirname(image_dir))``.
        object_ls: class names to classify against [str, str, ...].
        object_dic: mapping consulted only for the ``synonyms`` type. A key
            is accepted when the file-name label is ``in`` its value
            (presumably a list of synonyms/paraphrases -- confirm).
        save_path: string prefix for the output files. It is concatenated
            as-is, so a directory must end with a path separator.
        row_prefix: comma-separated metadata: algo_name, change, task,
            config, finetune_algo, finetune_task, finetune_config,
            prompts_csv, random_seed. Extra fields are ignored.

    Returns:
        The per-image DataFrame. Its columns are ``ImageName``, one
        probability column per object, ``MaxProbObject``, ``Key`` (synonyms
        type only) and ``CorrectClassification``.

    Raises:
        ValueError: if ``row_prefix`` has fewer than nine fields or
            ``image_dir`` contains no ``.png`` files.
    """
    # Validate the cheap inputs before paying for model loading and inference.
    fields_list = row_prefix.split(',')
    if len(fields_list) < len(_ROW_PREFIX_FIELDS):
        raise ValueError(
            f"row_prefix must have at least {len(_ROW_PREFIX_FIELDS)} comma-separated "
            f"fields ({','.join(_ROW_PREFIX_FIELDS)}), got {len(fields_list)}: {row_prefix!r}"
        )

    # Skip non-image entries (e.g. hidden files) that would break the index sort key.
    image_filenames = sorted(
        (name for name in os.listdir(image_dir) if name.endswith('.png')),
        key=_image_index,
    )
    if not image_filenames:
        raise ValueError(f"no .png images found in {image_dir!r}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    # The prompts are identical for every image, so encode and normalise them once.
    texts_ls = [f'a photo of the {obj}' for obj in object_ls]
    text_tokens = clip.tokenize(texts_ls).to(device)
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

    prob_results = []
    for image_name in tqdm(image_filenames):
        # Close the underlying file once preprocessing has produced the tensor.
        with Image.open(os.path.join(image_dir, image_name)) as img:
            image = preprocess(img).unsqueeze(0).to(device)

        image_features = model.encode_image(image).float()
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Cosine similarities scaled by 100, then softmax across the prompts.
        probs = (100.0 * image_features @ text_features.T).softmax(dim=-1).cpu().tolist()[0]

        prob_result = {"ImageName": image_name}
        prob_result.update(zip(object_ls, probs))
        prob_results.append(prob_result)

    prob_df = pd.DataFrame(prob_results)
    print(prob_df)

    extracted_fields = dict(zip(_ROW_PREFIX_FIELDS, fields_list))
    # NOTE(review): without a trailing slash this is the name of image_dir's
    # parent; with a trailing slash it is image_dir's own name -- confirm intent.
    extracted_fields['type'] = os.path.basename(os.path.dirname(image_dir))
    extracted_fields['metric'] = 'clip_acc'

    # Compute whether each image was classified correctly.
    labels = prob_df['ImageName'].map(_ground_truth_label)
    # At this point every column after ImageName is a per-object probability.
    prob_df['MaxProbObject'] = prob_df.iloc[:, 1:].idxmax(axis=1)
    if extracted_fields['type'] != 'synonyms':
        prob_df['CorrectClassification'] = prob_df['MaxProbObject'] == labels
    else:
        # Every key whose value contains the file-name label is an acceptable answer.
        prob_df['Key'] = pd.Series(
            [[key for key, value in object_dic.items() if label in value] for label in labels],
            index=prob_df.index,
        )
        prob_df['CorrectClassification'] = pd.Series(
            [pred in keys for pred, keys in zip(prob_df['MaxProbObject'], prob_df['Key'])],
            index=prob_df.index,
        )

    prob_df.to_csv(save_path + 'clip_acc_prob_{}.csv'.format(extracted_fields['type']))

    extracted_fields['value'] = prob_df['CorrectClassification'].mean()

    with open(save_path + 'metrics.csv', 'a', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=extracted_fields.keys())
        writer.writerow(extracted_fields)  # data row only; no header is written

    return prob_df


if __name__ == '__main__':
    import getpass

    def _expand_whoami(path):
        """Replace a literal ``$(whoami)`` with the current user name.

        Python does not perform shell command substitution, so without this a
        default path would name a directory literally called ``$(whoami)``.
        """
        if '$(whoami)' in path:
            return path.replace('$(whoami)', getpass.getuser())
        return path

    parser = ArgumentParser()
    parser.add_argument(
        "--base_folder", type=str, default='path/to/generated_images',
        help="Folder of <prompt>_<index>.png images to classify.",
    )
    parser.add_argument(
        "--results_dir", type=str,
        default='/data/cluster_name/scratch/$(whoami)/projects/MACE-Update/experiments/experimental_results.csv',
        help="Prefix for output files; it is concatenated with the file names, "
             "so a directory must end with '/'.",
    )
    parser.add_argument(
        "--row_prefix", type=str, default='GCD',
        help="Comma-separated metadata: algo_name,change,task,config,finetune_algo,"
             "finetune_task,finetune_config,prompts_csv,random_seed.",
    )
    parser.add_argument(
        "--object_json", type=str,
        default='/data/cluster_name/scratch/$(whoami)/projects/MACE-Update/tasks/object/10_objects_paraphrase.json',
        help="JSON object whose keys are the class names to classify against.",
    )
    args = parser.parse_args()

    # Kept as module-level names: calculate_mean_prob may read results_dir.
    base_folder = args.base_folder
    results_dir = _expand_whoami(args.results_dir)
    row_prefix = args.row_prefix

    with open(_expand_whoami(args.object_json), 'r', encoding='utf-8') as file:
        object_dic = json.load(file)
    object_ls = list(object_dic.keys())

    result = calculate_mean_prob(base_folder, object_ls, object_dic, results_dir, row_prefix)

