import json
import argparse


def split_words(line):
    """Lower-case *line*, trim surrounding whitespace, and split on single spaces.

    The separator is an explicit ``' '``, so runs of spaces produce empty-string
    tokens and an empty line yields ``['']``.
    """
    lowered = line.lower()
    trimmed = lowered.strip()
    return trimmed.split(sep=' ')


def get_correct_predictions_word_cluster(target, prediction, word_cluster):
    """
    Calculate predictions based on word cluster generated by CodeWordNet.

    Each predicted word that does not literally occur in ``target`` is replaced
    by the first word of ``target`` (in order) that shares at least one cluster
    with it in ``word_cluster``. The rewritten prediction is then scored
    against ``target``.

    Args:
        target: list of ground-truth words.
        prediction: list of predicted words. The caller's list is not modified.
        word_cluster: mapping from a word to an iterable of cluster ids. Words
            missing from the mapping are never replaced or used as replacements.

    Returns:
        ``(true_positive, false_positive, false_negative)``. If the rewritten
        prediction equals ``target`` exactly, this is ``(len(target), 0, 0)``.
        Otherwise the counts are over the distinct words (sets) of each side.
    """
    # Work on a copy so the caller's list is not rewritten in place.
    prediction = list(prediction)
    target_words = set(target)

    # Cluster sets for target words, built once instead of once per
    # (target, prediction) pair.
    target_clusters = [
        (t, set(word_cluster[t])) for t in target if t in word_cluster
    ]

    for j, p in enumerate(prediction):
        # Exact hits and words without a cluster entry are left untouched.
        # Because exact hits are skipped, p can never equal any target word here.
        if p in target_words or p not in word_cluster:
            continue
        p_cluster = set(word_cluster[p])
        for t, t_cluster in target_clusters:
            # The first target word (in target order) sharing a cluster wins.
            if t_cluster & p_cluster:
                prediction[j] = t
                break

    if target == prediction:
        return len(target), 0, 0

    target_set = set(target)
    prediction_set = set(prediction)
    true_positive = len(target_set & prediction_set)
    false_negative = len(target_set - prediction_set)
    false_positive = len(prediction_set - target_set)
    return true_positive, false_positive, false_negative


def calculate_results(true_positive, false_positive, false_negative):
    """Compute precision, recall and F1 from confusion counts.

    Args:
        true_positive: true-positive count.
        false_positive: false-positive count.
        false_negative: false-negative count.

    Returns:
        ``(precision, recall, f1)``. Returns ``(0, 0, 0)`` when
        ``true_positive + false_positive == 0``. Any other ratio whose
        denominator is zero is reported as 0 instead of raising
        ``ZeroDivisionError``.
    """
    # Avoid division by zero when nothing was predicted.
    if true_positive + false_positive == 0:
        return 0, 0, 0
    precision = true_positive / (true_positive + false_positive)
    # Guard recall separately: tp + fn == 0 with fp > 0 would otherwise divide by zero.
    relevant = true_positive + false_negative
    recall = true_positive / relevant if relevant else 0
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0
    return precision, recall, f1


def main(args):
    """Score function-name predictions and print averaged precision/recall/F1.

    Reads ``<args.input_file>/<args.arch>_<args.opt>_function.txt``. Each
    non-blank line is ``target,prediction``. Both fields are split into words
    and scored with the word clusters from ``CodeWordNet/word_cluster.json``.
    Per-line precision and recall are averaged over all lines. The printed F1
    is the harmonic mean of those two averages.

    Raises:
        ValueError: if a line has no prediction field, the prediction field is
            empty, or the file contains no data lines.
    """
    input_file = f'{args.input_file}/{args.arch}_{args.opt}_function.txt'

    with open('CodeWordNet/word_cluster.json', 'r', encoding='utf-8') as f:
        word_cluster = json.load(f)

    total = 0
    total_prec = 0.0
    total_recall = 0.0
    with open(input_file, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip('\n')
            if not line.strip():
                # Tolerate blank lines instead of crashing on the missing field.
                continue
            fields = line.split(',')
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if len(fields) < 2 or not fields[1]:
                raise ValueError(f"{input_file}:{line_no}: Don't give empty prediction")
            target = split_words(fields[0])
            prediction = split_words(fields[1])

            tp, fp, fn = get_correct_predictions_word_cluster(target, prediction, word_cluster)
            prec, recall, _ = calculate_results(tp, fp, fn)
            total_prec += prec
            total_recall += recall
            total += 1

    if total == 0:
        raise ValueError(f'{input_file}: no prediction lines to evaluate')

    prec = total_prec / total
    recall = total_recall / total
    # Harmonic mean of the averaged scores; 0 when both averages are 0.
    f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0
    print("Probability Precision: {} Recall: {} F1: {}".format(prec, recall, f1))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate function name prediction using the provided input file')
    # The evaluation file path is built as <input_file>/<arch>_<opt>_function.txt,
    # so --input_file names a directory, not the file itself.
    parser.add_argument('-i', '--input_file', type=str, required=True,
        help='Directory containing the <arch>_<opt>_function.txt evaluation file.')
    parser.add_argument('-o', '--opt', type=str, required=True,
        help='Value used for the <opt> part of the evaluation file name.')
    parser.add_argument('-a', '--arch', type=str, required=True,
        help='Value used for the <arch> part of the evaluation file name.')
    parser.add_argument('-n', '--name', type=str, required=False,
        help='Optional run name (not used by main()).')

    args = parser.parse_args()

    main(args)
