import argparse
import pprint
import re
import string
import collections
from collections import Counter
import os
import json

# Utility: best metric score of a prediction over a list of ground truths
def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


# Answer normalization
def normalize_answer(s):
    """Canonicalize an answer string before token comparison.

    The steps run in this order: lowercase the text, delete every character
    found in ``string.punctuation``, replace the standalone words "a", "an"
    and "the" with a space, then collapse whitespace runs into single spaces
    and trim both ends.
    """
    lowered = s.lower()

    punctuation = set(string.punctuation)
    unpunctuated = "".join(ch for ch in lowered if ch not in punctuation)

    # Articles are matched on word boundaries, so "theend" is left untouched.
    without_articles = re.sub(r"\b(a|an|the)\b", " ", unpunctuated)

    # split() with no argument also discards leading/trailing whitespace.
    return " ".join(without_articles.split())


# F1 score definition
def _f1_score(prediction, ground_truth):
    """Token-overlap F1 between one prediction and one reference string.

    Both strings are passed through ``normalize_answer`` and split on
    whitespace. Overlap is the size of the multiset intersection of the two
    token lists. Returns the integer ``0`` when no token is shared, otherwise
    the harmonic mean of precision and recall as a float.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()

    # Counter & Counter keeps, per token, the smaller of the two counts.
    shared = Counter(pred_tokens) & Counter(gold_tokens)
    overlap = sum(shared.values())
    if not overlap:
        # Also covers the case where either token list is empty.
        return 0

    precision = 1.0 * overlap / len(pred_tokens)
    recall = 1.0 * overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)

def _calculate_metrics(gold_records, guess_records):

    assert len(gold_records) == len(
        guess_records
    ), "different size gold: {} guess: {}".format(len(gold_records), len(guess_records))

    total_count = 0

    normalized_f1 = 0
    for guess_item, gold_item in zip(guess_records, gold_records):
        total_count += 1
 
        local_f1 = _metric_max_over_ground_truths(
            _f1_score, guess_item, [gold_item]
        )
        normalized_f1 += local_f1

    normalized_f1 /= total_count
    return {
        "f1": normalized_f1 * 100
    }

def evaluate_f1(pred_file, reference_file):
    """Score a prediction file against a reference file, line by line.

    Each file is read as text with one record per line; surrounding
    whitespace is stripped from every line. The two record lists are handed
    to ``_calculate_metrics`` and its result is returned unchanged.
    """
    with open(pred_file, 'r') as pred_handle:
        guess_records = [line.strip() for line in pred_handle]

    with open(reference_file, 'r') as ref_handle:
        gold_records = [line.strip() for line in ref_handle]

    return _calculate_metrics(gold_records, guess_records)

def open_reference_file(args):
    """Write the test set's labels to ``references.txt`` and return its path.

    Reads ``args.test_file`` inside ``args.data_dir`` as JSON Lines (one JSON
    object per line), takes the ``"label"`` field of every record, and writes
    one label per line to ``<data_dir>/references.txt``. An existing
    ``references.txt`` is overwritten.

    Args:
        args: object exposing ``data_dir`` and ``test_file`` attributes.

    Returns:
        The path of the written reference file.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``"label"`` field.
    """
    reference_file = os.path.join(args.data_dir, "references.txt")
    data_path = os.path.join(args.data_dir, args.test_file)

    with open(data_path, 'r') as f:
        # Blank lines (e.g. a stray empty line at end of file) hold no record;
        # skip them rather than letting json.loads fail on empty input.
        labels = [json.loads(line)["label"] for line in f if line.strip()]

    # The file is only written here, never read back, so plain "w" suffices.
    with open(reference_file, 'w') as f:
        f.writelines(label + '\n' for label in labels)

    return reference_file


if __name__ == "__main__":
    # Both paths currently point at the same file, so this run scores the
    # references against themselves. The commented-out lines are the
    # alternative candidate-file paths.
    # pred_file = "./save/no_name/candidate.txt"
    pred_file = "./dataset/wow/references.txt"
    reference_file = "./dataset/wow/references.txt"
    # reference_file = "./save/no_name/candidate.txt"
    result = evaluate_f1(pred_file, reference_file)
    print(result)