"""
This script is used to post-process the predictions generated by a relation extraction model.
The processed predictions are ready to be evaluated using the external database (e.g., Wikidata).
It reads a CSV file containing the predictions and performs various operations to analyze the correctness of the predictions.
The main function of this script is `main()`, which takes the path to the predictions file and the name of the log file as input.
It processes the predictions, calculates recall and precision metrics, and writes the processed predictions to a new CSV file.
The recall and precision metrics are logged to the specified log file.
"""

import pandas as pd
import ast
import json
import argparse
import os
from os.path import dirname,basename
from collections import namedtuple

 

def get_logger(log_file):
    """Configure root logging and return a simple INFO-level log function.

    The root logger is configured through ``logging.basicConfig`` at DEBUG
    level with a bare ``%(message)s`` format, writing to ``log_file`` in
    ``'w'`` mode. A console ``StreamHandler`` at INFO level is then added to
    the root logger, so messages are written to the console as well.

    Args:
        log_file: Path of the log file to write.

    Returns:
        A callable ``log(s)`` that emits ``s`` through ``logging.info``.
    """
    # The module's import block does not import ``logging``; import it here
    # so this function does not fail with a NameError.
    import logging

    logging.basicConfig(level=logging.DEBUG, format='%(message)s', filename=log_file, filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)

    def log(s):
        logging.info(s)

    return log

def process_predictions(args, saved_args):
    """Load the predictions CSV and add parsed predictions and match counts.

    Args:
        args: Object with a ``path`` attribute pointing to the predictions CSV.
            The CSV must contain the columns ``objects``, ``subjects`` (both
            Python literals stored as strings) and ``outputs``.
        saved_args: Inference-time arguments, forwarded to ``parse_triplet``.

    Returns:
        The DataFrame with ``outputs`` cropped at the first blank line and the
        added columns ``predictions_subjects``, ``predictions_objects``,
        ``num_rels``, ``num_preds``, ``num_correct_objects``,
        ``num_correct_subjects`` and ``num_correct_both``.
    """
    df = pd.read_csv(args.path)

    # Ground truth is stored as stringified Python literals
    df["objects"] = df["objects"].apply(ast.literal_eval)
    df["subjects"] = df["subjects"].apply(ast.literal_eval)

    def crop_output(raw):
        # pandas reads an empty CSV cell as NaN (a float); treat any
        # non-string output as "no prediction" instead of crashing on .split
        if not isinstance(raw, str):
            return ""
        # By default, crop the answers until "\n\n"
        return raw.split("\n\n")[0]

    df["outputs"] = df["outputs"].apply(crop_output)

    # Adds the predictions_subjects / predictions_objects columns
    df = df.apply(parse_triplet, args=(saved_args,), axis=1)

    df["num_rels"] = df.objects.apply(len)
    df["num_preds"] = df.predictions_objects.apply(len)

    df["num_correct_objects"] = df.apply(check_correctness_object, axis=1)
    df["num_correct_subjects"] = df.apply(check_correctness_subject, axis=1)

    # Also check if both subject and object are correct for the same prediction
    df["num_correct_both"] = df.apply(check_correctness_both, axis=1)

    return df

#BELOW are triplets for (rel, sub, obj)
def parse_object_triplet(row, args):
    """Extract the object of every complete triplet line in ``row.outputs``.

    Each line of ``row.outputs`` is split on ``args.separator`` followed by a
    space. A line counts as complete when it yields at least three fields and
    ends with ``)``; incomplete lines (e.g. a truncated final line) are
    skipped.

    Returns:
        List with the third field of each complete line, cut at its first ``)``.
    """
    delimiter = args.separator + " "
    found_objects = []

    for line in row.outputs.split("\n"):
        fields = line.split(delimiter)
        # Keep only lines whose prediction is complete
        if len(fields) >= 3 and line[-1] == ")":
            found_objects.append(fields[2].split(")")[0])

    return found_objects


def parse_subject_triplet(row, args):
    """Extract the subject of every complete triplet line in ``row.outputs``.

    Each line of ``row.outputs`` is split on ``args.separator`` followed by a
    space. A line counts as complete when it yields at least three fields and
    ends with ``)``; incomplete lines (e.g. a truncated final line) are
    skipped.

    Returns:
        List with the second field of each complete line.
    """
    delimiter = args.separator + " "
    found_subjects = []

    for line in row.outputs.split("\n"):
        fields = line.split(delimiter)
        # Keep only lines whose prediction is complete
        if len(fields) >= 3 and line[-1] == ")":
            found_subjects.append(fields[1])

    return found_subjects

def parse_triplet(row, args):
    """Parse the predicted (subject, object) pairs out of ``row.outputs``.

    Each line of ``row.outputs`` is split on ``args.separator`` followed by a
    space. Lines with fewer than three fields, or not ending with ``)``, are
    treated as incomplete predictions and skipped. Duplicate pairs are removed
    while keeping the order of first appearance, so the result is identical
    across runs (a ``set`` would give a hash-dependent order, and the greedy
    matching done on these lists is order-sensitive).

    Args:
        row: Row-like object with an ``outputs`` string that supports item
            assignment (e.g. a pandas Series).
        args: Object with a ``separator`` attribute.

    Returns:
        The same ``row`` with two aligned lists added under
        ``predictions_subjects`` and ``predictions_objects``.
    """
    delimiter = args.separator + " "
    subject_object_pairs = []

    for text in row.outputs.split("\n"):
        fields = text.split(delimiter)
        # Skip the final line, if its prediction is not complete
        if len(fields) < 3 or text[-1] != ")":
            continue
        subject_object_pairs.append((fields[1], fields[2].split(")")[0]))

    # Some lines could be duplicates; dict keys de-duplicate and keep order
    unique_pairs = list(dict.fromkeys(subject_object_pairs))

    row["predictions_subjects"] = [sub for sub, _ in unique_pairs]
    row["predictions_objects"] = [obj for _, obj in unique_pairs]

    return row

def check_correctness_object(row):
    """Count ground-truth object entries matched by a predicted object.

    ``row.objects`` is walked entry by entry (each entry is a collection of
    accepted objects). For every entry, the first remaining prediction found
    in it counts as one hit and is removed, so a prediction can be credited at
    most once. Works on copies; the lists stored on ``row`` are not modified.

    Returns:
        Number of ground-truth entries that found a matching prediction.
    """
    gold_entries = row.objects.copy()  # list of lists of accepted objects
    remaining = row.predictions_objects.copy()  # predictions not yet credited
    hits = 0

    for accepted in gold_entries:
        for pos, candidate in enumerate(remaining):
            if candidate in accepted:
                hits += 1
                # Consume the prediction so it cannot match another entry
                del remaining[pos]
                break

    return hits


def check_correctness_subject(row):
    """Count ground-truth subject entries matched by a predicted subject.

    ``row.subjects`` is walked entry by entry (each entry is a collection of
    accepted subjects). For every entry, the first remaining prediction found
    in it counts as one hit and is removed, so a prediction can be credited at
    most once. Works on copies; the lists stored on ``row`` are not modified.

    Returns:
        Number of ground-truth entries that found a matching prediction.
    """
    gold_entries = row.subjects.copy()  # list of lists of accepted subjects
    remaining = row.predictions_subjects.copy()  # predictions not yet credited
    hits = 0

    for accepted in gold_entries:
        for pos, candidate in enumerate(remaining):
            if candidate in accepted:
                hits += 1
                # Consume the prediction so it cannot match another entry
                del remaining[pos]
                break

    return hits

def check_correctness_both(row):
    """Count relations whose subject AND object are matched by one prediction.

    Ground-truth entry ``k`` is the pair ``row.subjects[k]`` /
    ``row.objects[k]``; prediction ``j`` is the pair
    ``row.predictions_subjects[j]`` / ``row.predictions_objects[j]``. For each
    ground-truth entry, the first remaining prediction whose subject and
    object are both accepted counts as one hit and is removed from further
    search. Works on copies; the lists stored on ``row`` are not modified.

    Returns:
        Number of ground-truth entries matched on both subject and object.
    """
    gold_subjects = row.subjects.copy()  # list of lists of accepted subjects
    pending_subjects = row.predictions_subjects.copy()  # predicted subjects
    gold_objects = row.objects.copy()  # list of lists of accepted objects
    pending_objects = row.predictions_objects.copy()  # predicted objects

    matched = 0

    for rel_idx, accepted_subs in enumerate(gold_subjects):
        accepted_objs = gold_objects[rel_idx]

        # If there is no prediction left, there is nothing to search
        if len(pending_objects) == 0:
            continue

        pos = 0
        while True:
            cand_sub = pending_subjects[pos]
            cand_obj = pending_objects[pos]
            if cand_sub in accepted_subs and cand_obj in accepted_objs:
                matched += 1
                # Consume the prediction pair so it is credited only once
                del pending_subjects[pos]
                del pending_objects[pos]
                break
            pos += 1
            if pos >= len(pending_subjects):
                break

    return matched

def main(args):
    """Post-process a predictions file, save the result and log the metrics.

    Reads ``commandline_args.txt`` (JSON) from the directory of ``args.path``
    to recover the inference-time arguments, processes the predictions,
    writes them to ``processed_<file name>`` in that same directory, and logs
    recall (objects, subjects, both) and precision (both).

    Args:
        args: Object with ``path`` (predictions CSV) and ``log`` (log file
            name, placed under the experiment folder named in the saved
            arguments).
    """
    # The processed predictions are saved next to the input file
    predictions_dir = dirname(args.path)
    predictions_name = basename(args.path)

    # Recover the arguments saved by the model inference run
    with open(os.path.join(predictions_dir, 'commandline_args.txt'), 'r') as f:
        inference_config = json.load(f)

    SavedArgs = namedtuple("SavedArgs", inference_config.keys())
    saved_args = SavedArgs(*inference_config.values())

    # Configure our logger
    log_path = os.path.join(saved_args.experiments_main_folder, saved_args.experiment_folder, args.log)
    log = get_logger(log_path)

    results = process_predictions(args, saved_args)
    results.to_csv(os.path.join(predictions_dir, "processed_" + predictions_name), index=False)

    total_rels = results.num_rels.sum()
    total_preds = results.num_preds.sum()
    correct_both = results.num_correct_both.sum()

    rec_obj = round(results.num_correct_objects.sum() / total_rels, 4)
    rec_sub = round(results.num_correct_subjects.sum() / total_rels, 4)
    rec_both = round(correct_both / total_rels, 4)
    prec_both = round(correct_both / total_preds, 4)

    # (message prefix, score, count, message suffix), in reporting order.
    # NOTE(review): the first line reports object recall over the relation
    # count although its wording says "predictions"; text kept as-is.
    report_lines = [
        ("Recall is ", rec_obj, total_rels, " predictions."),
        ("Recall is ", rec_sub, total_rels, " predicted subjects."),
        ("Recall for 'subject and object' is ", rec_both, total_rels, " relations."),
        ("Precision for 'subject and object' is ", prec_both, total_preds, " predictions."),
    ]
    for prefix, score, count, suffix in report_lines:
        log(prefix + str(score * 100) + "% for " + str(count) + suffix)

if __name__ == "__main__":
    # Command-line entry point: parse the options and run the post-processing
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-p', '--path',
        type=str,
        default="experiment_P1001/seed0/predictions.csv",
        help='full path to the predictions.csv file',
    )
    cli.add_argument(
        '-l', '--log',
        type=str,
        default="eval.log",
        help='Name of the postprocess log file',
    )

    main(cli.parse_args())