import pandas as pd
import json, ast
from datetime import datetime, date

# Input workbooks: the model's extracted predictions and the gold annotations.
pred_extraction_file = "/Volumes/Academic/Projects/PRoMTd/evaluation/2_dateU/data/excel/2_extract_output.xlsx"
gold_extraction_file = "/Volumes/Academic/Projects/PRoMTd/evaluation/2_dateU/data/excel/1_extract_gt.xlsx"

# Load both sheets (predictions first, then gold) into DataFrames.
pred_date, gold_date = (
    pd.read_excel(workbook_path)
    for workbook_path in (pred_extraction_file, gold_extraction_file)
)

# Map each question index to its gold date string. Zipping the two columns
# avoids a positional row lookup per entry; if a question index is repeated,
# the last occurrence wins.
gold_dict = dict(zip(gold_date["question_idx"], gold_date["ground_truth"]))

# Characters left over from a stringified Python list (e.g. "['01/02/2020']")
# that must be removed before the value can be parsed as a date.
_LIST_MARKUP = str.maketrans("", "", "[]'")

# Name used for each attempt downstream -> spreadsheet column holding it.
_ATTEMPT_COLUMNS = {
    "attempt_1": "Prompt0_extracted_output",
    "attempt_2": "Prompt1_extracted_output",
    "attempt_3": "Prompt2_extracted_output",
}


def _clean_extracted_output(raw_value):
    """Return the cell content as text with list brackets and quotes removed.

    Non-string cells are converted instead of crashing: a missing value
    (pandas reports empty cells as NaN) becomes an empty string, anything
    else goes through ``str()``. Such values will simply fail date parsing
    later and be counted as invalid predictions.
    """
    if not isinstance(raw_value, str):
        raw_value = "" if pd.isna(raw_value) else str(raw_value)
    return raw_value.translate(_LIST_MARKUP)


# question index -> {"attempt_1": ..., "attempt_2": ..., "attempt_3": ...}
# If a question index is repeated, the last row wins.
pred_dict = {
    row["question_idx"]: {
        attempt: _clean_extracted_output(row[column])
        for attempt, column in _ATTEMPT_COLUMNS.items()
    }
    for _, row in pred_date.iterrows()
}


# Predictions and gold annotations must cover exactly the same questions.
# An explicit check is used instead of `assert` so that it still runs under
# `python -O` and reports which question indices are off.
if pred_dict.keys() != gold_dict.keys():
    raise ValueError(
        "Prediction and gold question indices differ: "
        f"missing predictions for {list(gold_dict.keys() - pred_dict.keys())}, "
        f"unexpected predictions for {list(pred_dict.keys() - gold_dict.keys())}"
    )

# All dates are in MM/DD/YYYY format: parse the gold date first, compare it with every attempt in pred_dict, and finally compute the accuracy counts.
def _parse_mdy(text):
    """Return `text` parsed as an MM/DD/YYYY date, or None if it is not one.

    Only parsing failures (ValueError for malformed text, TypeError for a
    non-string value) are treated as "invalid prediction"; any other
    exception is a real bug and is allowed to propagate.
    """
    try:
        return datetime.strptime(text, '%m/%d/%Y').date()
    except (TypeError, ValueError):
        return None


attempt_keys = ("attempt_1", "attempt_2", "attempt_3")
correct_counts = [0, 0, 0]  # per attempt: predictions equal to the gold date
valid_counts = [0, 0, 0]    # per attempt: predictions that parsed as a date
topk_count, topk_valid_count = 0, 0
error_indices = []          # one entry per unparseable attempt (may repeat)

# Loop variables are deliberately named differently from the `gold_date` /
# `pred_date` DataFrames so those are not overwritten.
for gold_index, gold_text in gold_dict.items():
    # A malformed gold date is not tolerated: it raises and stops the run.
    expected = datetime.strptime(gold_text, '%m/%d/%Y').date()
    attempts = pred_dict[gold_index]
    any_correct = False
    any_valid = False

    for position, key in enumerate(attempt_keys):
        predicted = _parse_mdy(attempts[key])
        if predicted is None:
            error_indices.append(gold_index)
            continue

        valid_counts[position] += 1
        any_valid = True
        if predicted == expected:
            correct_counts[position] += 1
            any_correct = True
        elif position == 0:
            # Mismatches are only echoed for the first attempt.
            print("Predicted Date 1: {}\nGold Date: {}".format(predicted, expected))

    # Top-k: the question counts once if any attempt was correct / parseable.
    if any_correct:
        topk_count += 1
    if any_valid:
        topk_valid_count += 1

# Expose the per-attempt totals under the names used by the report below.
a_1_count, a_2_count, a_3_count = correct_counts
a1_valid_count, a2_valid_count, a3_valid_count = valid_counts



def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero.

    A zero denominator occurs when no question exists or when no prediction
    of an attempt could be parsed; the report then shows `nan` rather than
    aborting with ZeroDivisionError.
    """
    return numerator / denominator if denominator else float("nan")


total_questions = len(gold_dict)

# "Full" accuracy: correct predictions over all questions.
print(f"Attempt 1 Full accuracy: {_ratio(a_1_count, total_questions)}")
print(f"Attempt 2 Full accuracy: {_ratio(a_2_count, total_questions)}")
print(f"Attempt 3 Full accuracy: {_ratio(a_3_count, total_questions)}")


# "Valid" accuracy: correct predictions over those that parsed as a date.
print(f"Attempt 1 Valid accuracy: {_ratio(a_1_count, a1_valid_count)}")
print(f"Attempt 2 Valid accuracy: {_ratio(a_2_count, a2_valid_count)}")
print(f"Attempt 3 Valid accuracy: {_ratio(a_3_count, a3_valid_count)}")



print(f"Top-K Full accuracy: {_ratio(topk_count, total_questions)}")
print(f"Top-K Valid accuracy: {_ratio(topk_count, topk_valid_count)}")


# De-duplicate while keeping first-seen order so the output is deterministic.
print("Error Indices: ", list(dict.fromkeys(error_indices)))
