import numpy as np
import string
import re
from collections import Counter
import re
from typing import List

# nltk.download('punkt_tab')

def convert_to_capitalized(s):
    if s.isupper():
        return s.capitalize()
    return s

def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))

def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)

def accuracy(preds, labels):
    match_count = 0

    for pred, label in zip(preds, labels):
        target = label[0]
        if pred == target or pred[0]==target:
            match_count += 1
    if match_count == 0:
        print(repr(pred))
        print(target)
    return 100 * (match_count / len(preds))


def f1(decoded_preds, decoded_labels):
    f1_all = []
    for prediction, answers in zip(decoded_preds, decoded_labels):
        if type(answers) == list:
            if len(answers) == 0:
                return 0
            f1_all.append(np.max([qa_f1_score(prediction, gt)
                          for gt in answers]))
        else:
            f1_all.append(qa_f1_score(prediction, answers))
    return 100 * np.mean(f1_all)


def qa_f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def normalize_answer(s):
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))

def find_entity_tags(sentence):
    entity_regex = r'(.+?)(?=\s<|$)'
    tag_regex = r'<(.+?)>'
    entity_names = re.findall(entity_regex, sentence)
    tags = re.findall(tag_regex, sentence)

    results = {}
    for entity, tag in zip(entity_names, tags):
        if "<" in entity:
            results[entity.split("> ")[1]] = tag
        else:
            results[entity] = tag
    return results

def match(question, prediction, ground_truth):
    prediction = prediction.replace('\n','').strip().lower()
    for gt in ground_truth:
        if prediction in gt.lower():
            return 1
    return 0


def rouge(rouge, prediction: str, ground_truths: List[str]) -> float:
    if prediction=="":
        return 0.0
    
    highest = 0.0
    
    for ref in ground_truths:
        if ref=="":
            return 0.0
        score = rouge.get_scores(ref, prediction)
        highest = max(highest, score[0]['rouge-l']['f'])

    return highest