import json
import argparse

def get_top_prob(prediction, probability, prob_threshold=0.5):
    """Keep the predicted words whose probability is at least ``prob_threshold``.

    Args:
        prediction: Space-separated predicted words.
        probability: Space-separated probabilities aligned with the words in
            ``prediction``, optionally wrapped in square brackets
            (e.g. ``"[0.9 0.2]"``). An empty or blank value is treated as a
            single probability of 1.0.
        prob_threshold: Minimum probability for a word to be kept.

    Returns:
        The kept words joined by single spaces. If no word passes the
        threshold, the first predicted word is returned instead.
    """
    preds = prediction.split(' ')
    # split() with no argument tolerates leading/repeated spaces such as
    # "[ 0.9  0.1]", which would otherwise yield '' tokens that float() rejects.
    probs = probability.replace('[', '').replace(']', '').split()
    if not probs:
        probs = ['1.0']
    # Consider every (word, probability) pair -- not just the first one.
    # Unpaired trailing items on either side are ignored.
    res = [pred for pred, prob in zip(preds, probs)
           if float(prob) >= prob_threshold]
    if not res:
        res = preds[:1]
    return ' '.join(res)

def split_words(line):
    """Lower-case ``line``, trim surrounding whitespace, and split on single spaces.

    Consecutive inner spaces produce empty-string entries, and an empty or
    blank line produces ``['']``.
    """
    normalized = line.lower()
    normalized = normalized.strip()
    return normalized.split(' ')

def get_correct_predictions_word_cluster(target, prediction, word_cluster):
    """
    Calculate predictions based on word cluster generated by CodeWordNet.

    A predicted word that does not appear in ``target`` but shares at least
    one cluster with some target word is treated as that target word before
    scoring.

    Args:
        target: List of target words.
        prediction: List of predicted words. It is not modified.
        word_cluster: Mapping from a word to an iterable of cluster ids; two
            words are related when their cluster collections intersect.
            (Presumably loaded from the CodeWordNet JSON file -- confirm.)

    Returns:
        ``(true_positive, false_positive, false_negative)``. An exact list
        match (after replacements) scores ``len(target)`` true positives;
        otherwise counts are over the distinct words of each side.
    """
    # Work on a copy so the caller's list is not rewritten by replacements.
    prediction = list(prediction)
    target_words = set(target)
    # Predicted words that already appear verbatim in the target stay as-is.
    skip = {j for j, p in enumerate(prediction) if p in target_words}
    replacement = {}
    for t in target:
        if t not in word_cluster:
            continue
        # Build the target's cluster set once instead of once per prediction.
        t_cluster = set(word_cluster[t])
        for j, p in enumerate(prediction):
            if t != p and j not in replacement and j not in skip and p in word_cluster:
                if t_cluster.intersection(word_cluster[p]):
                    replacement[j] = t
    for j, t in replacement.items():
        prediction[j] = t

    if target == prediction:
        return len(target), 0, 0

    target_set = set(target)
    prediction_set = set(prediction)
    true_positive = len(target_set & prediction_set)
    false_negative = len(target_set - prediction_set)
    false_positive = len(prediction_set - target_set)
    return true_positive, false_positive, false_negative

def calculate_results(true_positive, false_positive, false_negative):
    """Return ``(precision, recall, f1)`` from aggregate word counts.

    When nothing was predicted (``true_positive + false_positive == 0``) all
    three scores are 0. Recall and F1 are likewise 0 when their denominators
    are 0, instead of raising ZeroDivisionError.
    """
    predicted = true_positive + false_positive
    actual = true_positive + false_negative
    # Avoid division by zero when nothing was predicted.
    if predicted == 0:
        return 0, 0, 0
    precision = true_positive / predicted
    # With false positives but no target words, recall is undefined; report 0.
    recall = true_positive / actual if actual else 0
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0
    return precision, recall, f1

def main():
    """Compute word-level F1 for a prediction file and print it.

    Each line of ``--evaluation-input`` is split on commas and must have
    exactly three fields: target words, predicted words, and the predicted
    words' probabilities; other lines are skipped. Predicted words are first
    filtered by ``--prob-threshold`` (via ``get_top_prob``), then scored
    against the target with ``get_correct_predictions_word_cluster`` using the
    word clusters loaded from ``--word-cluster``.

    The printed label is underscore-separated fields 2-5 of the input file
    name when it has at least six such fields; otherwise the full path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--evaluation-input', type=str, required=True,
            help='Path to the evaluation input file')
    parser.add_argument('--prob-threshold', type=float, default=0.3,
            help='Probability threshold for selecting the predicted words')
    parser.add_argument('--word-cluster', type=str, default='word_cluster.json',
            help='Path to the CodeWordNet word cluster JSON file')
    args = parser.parse_args()
    input_file = args.evaluation_input
    threshold = args.prob_threshold

    # JSON text is UTF-8 by specification; don't depend on the locale default.
    with open(args.word_cluster, 'r', encoding='utf-8') as f:
        word_cluster = json.load(f)

    true_positive, false_positive, false_negative = 0, 0, 0
    with open(input_file, 'r') as f:
        for line in f:
            fields = line.strip('\n').split(',')
            if len(fields) != 3:
                continue
            target_text, prediction_text, probability_text = fields
            prediction_text = get_top_prob(prediction_text, probability_text,
                                           prob_threshold=threshold)
            target = split_words(target_text)
            if not prediction_text:
                # An empty prediction recovers none of the target words, so
                # each of them is a false negative (not a false positive).
                false_negative += len(target)
                continue
            prediction = split_words(prediction_text)
            tp, fp, fn = get_correct_predictions_word_cluster(
                target, prediction, word_cluster)
            true_positive += tp
            false_positive += fp
            false_negative += fn

    _, _, f1 = calculate_results(true_positive, false_positive, false_negative)

    # Label the score with parts of the input file name when it follows the
    # expected underscore-separated naming; otherwise fall back to the path.
    name_parts = input_file.split('_')
    label = ' '.join(name_parts[2:6]) if len(name_parts) >= 6 else input_file
    print("{}: {}".format(label, f1))


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()
