from argparse import ArgumentParser
from os.path import splitext
from numpy import arange
from numpy import trapz
from numpy import nan_to_num
import random


def process(predictions_with_scores_file):
    """Read a predictions file and keep the best score seen for each fact.

    Each non-blank line is comma-separated as
    ``predicate,real_head,tail,proposed_head_1,...,proposed_head_k,score``
    with k >= 1. The fact ``(real_head, predicate, tail)`` is mapped to the
    maximum ``score`` found across all lines that mention it. The proposed
    heads do not influence the result.

    Args:
        predictions_with_scores_file: path to the predictions file.

    Returns:
        dict mapping ``(head, predicate, tail)`` tuples to float scores. The
        running maximum starts at 0, so negative scores are stored as 0.

    Raises:
        ValueError: if a non-blank line has fewer than five fields, or its
            last field is not a number.
    """
    facts_to_max_scores = {}

    with open(predictions_with_scores_file, "r") as predictions:
        for line_number, line in enumerate(predictions, start=1):
            line = line.rstrip('\n')
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            entry = line.split(',')
            # Every line contains (at least) the triple query, one proposed
            # head, and the score of the real head.
            if len(entry) < 5:
                raise ValueError(
                    "{}:{}: expected at least 5 comma-separated fields, got {}".format(
                        predictions_with_scores_file, line_number, len(entry)))
            predicate, real_head, tail = entry[0], entry[1], entry[2]
            score_real_head = float(entry[-1])

            fact = (real_head, predicate, tail)
            # We store the lower bound for the score of the fact, given by the
            # score of the real head, because DRUM does not output the score
            # for all derived facts.
            facts_to_max_scores[fact] = max(facts_to_max_scores.get(fact, 0), score_real_head)

    return facts_to_max_scores


def evaluate(facts_to_scores_dict, truths, output_file):
    """Compare fact scores with ground truth and write per-threshold metrics.

    Every fact listed in ``truths`` is classified at threshold 0 (predicted
    true iff its score is > 0) and at each threshold in ``threshold_list``
    (predicted true iff its score is >= the threshold). A fact missing from
    ``facts_to_scores_dict`` is given score 0.

    For each threshold, one tab-separated row of precision, recall, accuracy
    and F1 is written to ``output_file``. A final line gives the area under
    the precision-recall curve. A short summary is also printed to stdout.

    Args:
        facts_to_scores_dict: dict mapping ``(head, relation, tail)`` tuples
            to scores, e.g. the result of ``process``.
        truths: path to a file whose non-blank lines are
            ``head relation tail truth`` separated by whitespace, where
            ``truth`` is ``1`` (positive) or ``0`` (negative).
        output_file: path of the report to write; overwritten if present.

    Raises:
        ValueError: if a truth line does not have exactly four fields, or its
            truth value is neither ``1`` nor ``0``.
    """
    # Fine-grained thresholds near 0, then 0.01 .. 0.99. Each value must
    # appear only once, otherwise its bucket is updated twice per fact.
    threshold_list = [0.0000000001, 0.000000001, 0.00000001, 0.0000001, 0.000001,
                      0.00001, 0.0001, 0.001] + arange(0.01, 1, 0.01).tolist()
    # Rounding removes floating-point noise introduced by arange.
    threshold_list = [round(elem, 10) for elem in threshold_list]

    # Each threshold is mapped to [true_pos, false_pos, true_neg, false_neg].
    entry_for = {"true_positives": 0, "false_positives": 1, "true_negatives": 2, "false_negatives": 3}
    threshold_to_counter = {0: [0, 0, 0, 0]}
    for threshold in threshold_list:
        threshold_to_counter[threshold] = [0, 0, 0, 0]

    num_scored_facts = 0
    num_unscored_facts = 0
    num_unscored_negative_facts = 0
    warned_about_unscored = False

    with open(truths, 'r') as truths_file:
        for line in truths_file:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines
            if len(fields) != 4:
                raise ValueError("Expected 'head relation tail truth', got: {!r}".format(line))
            head, relation, tail, truth = fields
            if truth not in ('0', '1'):
                raise ValueError("ERROR: No truth value detected for line {}".format(line))
            fact = (head, relation, tail)
            is_positive = truth == '1'

            if fact in facts_to_scores_dict:
                num_scored_facts += 1
            else:
                # Only the first unscored fact is printed, to keep output short.
                if not warned_about_unscored:
                    print("WARNING: No score detected for fact: \n {} \n {} \n {}".format(head, relation, tail))
                    warned_about_unscored = True
                num_unscored_facts += 1
                if not is_positive:
                    num_unscored_negative_facts += 1

            # Counter slot for "predicted true" (hit) vs "predicted false" (miss).
            if is_positive:
                hit, miss = entry_for["true_positives"], entry_for["false_negatives"]
            else:
                hit, miss = entry_for["false_positives"], entry_for["true_negatives"]

            score = facts_to_scores_dict.get(fact, 0)
            # At threshold 0 the comparison is strict, so facts scoring 0
            # (including unscored ones) are predicted false there.
            threshold_to_counter[0][hit if score > 0 else miss] += 1
            for threshold in threshold_list:
                threshold_to_counter[threshold][hit if score >= threshold else miss] += 1

    print("DATASET: {}".format(output_file))
    print("Number of unscored facts: {}, of which {} were negative".format(num_unscored_facts, num_unscored_negative_facts))
    print("Number of scored facts: {}".format(num_scored_facts))

    # Compute and write results, one row per threshold (ascending order).
    recall_vector = []
    precision_vector = []
    with open(output_file, 'w') as f:
        # Header spelling ("Accuraccy") is kept as-is for existing consumers of the report.
        f.write("Threshold\tPrecision\tRecall\tAccuraccy\tF1 Score\n")
        for threshold, (tp, fp, tn, fn) in threshold_to_counter.items():
            prec = precision(tp, fp, tn, fn)
            rec = recall(tp, fp, tn, fn)
            f.write("{}\t{}\t{}\t{}\t{}\n".format(threshold, prec, rec,
                                                  accuracy(tp, fp, tn, fn), f1score(tp, fp, tn, fn)))
            recall_vector.append(rec)
            precision_vector.append(prec)
        # Anchor the curve at (recall=1, precision=0) and (recall=0, precision=1)
        # so it spans the whole recall range. Without these points, a perfect
        # classifier (recall 1 at every threshold) would get area 0.
        # Undefined (NaN) metrics are treated as 0.
        precision_vector = nan_to_num([0] + precision_vector + [1])
        recall_vector = nan_to_num([1] + recall_vector + [0])
        f.write("Area under precision recall curve: {}".format(auprc(precision_vector, recall_vector)))


def precision(tp, fp, tn, fn):
    """Return precision ``tp / (tp + fp)``, or NaN when ``tp + fp == 0``.

    ``tn`` and ``fn`` are unused. They are accepted so that every metric
    function shares the same signature.
    """
    # Catch only the division-by-zero case. The previous bare ``except`` with
    # ``return`` in ``finally`` also hid unrelated errors (even KeyboardInterrupt).
    try:
        return tp / (tp + fp)
    except ZeroDivisionError:
        return float("NaN")


def recall(tp, fp, tn, fn):
    """Return recall ``tp / (tp + fn)``, or NaN when ``tp + fn == 0``.

    ``fp`` and ``tn`` are unused. They are accepted so that every metric
    function shares the same signature.
    """
    # Only division by zero maps to NaN; any other error propagates.
    try:
        return tp / (tp + fn)
    except ZeroDivisionError:
        return float("NaN")


def accuracy(tp, fp, tn, fn):
    """Return accuracy ``(tp + tn) / (tp + fp + tn + fn)``, or NaN if all counts are 0."""
    # Only division by zero maps to NaN; any other error propagates.
    try:
        return (tn + tp) / (tp + fp + tn + fn)
    except ZeroDivisionError:
        return float("NaN")


def f1score(tp, fp, tn, fn):
    """Return the F1 score ``tp / (tp + (fp + fn) / 2)``, or NaN when that denominator is 0.

    This equals the harmonic mean of precision and recall. ``tn`` is unused;
    it is accepted so that every metric function shares the same signature.
    """
    # Only division by zero maps to NaN; any other error propagates.
    try:
        return tp / (tp + 0.5 * (fp + fn))
    except ZeroDivisionError:
        return float("NaN")

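# NOTE(review): the formula below, fp/(fp+tn), is the false-positive rate
# (1 - specificity). Specificity itself is tn/(tn+fp); fix before re-enabling.
#def specificity(tp,fp,tn,fn):
#    value = 0
#    try:
#        value = fp/(fp+tn)
#    except: 
#        value = float("NaN")
#    finally:
#        return value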

def auprc(precision_vector, recall_vector):
    """Return the area under a precision-recall curve via the trapezoidal rule.

    ``precision_vector[i]`` is paired with ``recall_vector[i]``. The result is
    negated because the caller (``evaluate``) builds ``recall_vector`` from
    high recall to low recall, which makes the raw integral negative.
    """
    import numpy
    # NumPy 2.0 renamed ``trapz`` to ``trapezoid`` and deprecated the old name.
    # Use the new name when available to avoid the DeprecationWarning.
    integrate = getattr(numpy, "trapezoid", None) or trapz
    return -1 * integrate(precision_vector, recall_vector)

if __name__ == '__main__':
    # Read the arguments from the command line. All three paths are required:
    # without them, process()/evaluate() would be handed None and fail inside open().
    parser = ArgumentParser(description="Compute precision/recall metrics for scored facts.")
    parser.add_argument('--scores', required=True,
                        help="File containing scores for each positive and negative example.")
    parser.add_argument("--truths", required=True,
                        help="File containing the truth values for each positive and negative example.")
    parser.add_argument("--output", required=True,
                        help="File where the result of the metrics calculation will be stored.")
    args = parser.parse_args()

    dict_facts_to_scores = process(args.scores)
    evaluate(dict_facts_to_scores, args.truths, args.output)
