'''
Taken from the Instruction Induction paper: https://arxiv.org/pdf/2205.10782.pdf
'''

import re
import string
from collections import Counter
import torch
import random
import numpy as np

# Names of every task known to this module, in a fixed order.
TASKS = [
    'antonyms',
    'common_concept',
    'first_word_letter',
    'informal_to_formal',
    'larger_animal',
    'letters_list',
    'taxonomy_animal',
    'negation',
    'num_to_verbal',
    'active_to_passive',
    'singular_to_plural',
    'rhymes',
    'second_word_letter',
    'sentiment',
    'orthography_starts_with',
    'sum',
    'synonyms',
    'translation_en-de',
    'translation_en-es',
    'translation_en-fr',
    'word_in_context',
    'auto_categorization',
    'auto_debugging',
    'ascii',
    'cs_algorithms',
    'periodic_elements',
    'word_sorting',
    'word_unscrambling',
    'odd_one_out',
    'object_counting',
]

# TODO: add some more metrics here for the new tasks.

# Per-task metric overrides. The metric keys ('f1', 'es', 'contains', ...)
# are resolved by code outside this view; tasks missing from this table
# presumably fall back to `default_metric` -- confirm against the caller.
TASK_TO_METRIC = {
    'common_concept': 'f1',
    'informal_to_formal': 'f1',
    'orthography_starts_with': 'es',
    'taxonomy_animal': 'es',
    'synonyms': 'contains',
}

# Tasks that are registered below under the 'multichoice' metric.
bbh_multi_choice = [
    'cause_and_effect',
    'date_understanding',
    'disambiguation_qa',
    'geometric_shapes',
    'hyperbaton',
    'logical_deduction_five_objects',
    'logical_deduction_seven_objects',
    'logical_deduction_three_objects',
    'movie_recommendation',
    'sentence_similarity',
    'penguins_in_a_table',
    'reasoning_about_colored_objects',
    'ruin_names',
    'salient_translation_error_detection',
    'snarks',
    'temporal_sequences',
    'tracking_shuffled_objects_five_objects',
    'tracking_shuffled_objects_seven_objects',
    'tracking_shuffled_objects_three_objects',
    'diff',
]

# Tasks that are registered below under the 'binarychoice' metric.
# NOTE(review): 'casual_judgement' looks like a misspelling of "causal";
# it is a runtime key, so it is kept verbatim -- confirm against the data.
bbh_binary_choice = [
    'boolean_expressions',
    'casual_judgement',
    'dyck_languages',
    'formal_fallacies',
    'multistep_arithmetic_two',
    'navigate',
    'sports_understanding',
    'web_of_lies',
]

# Extend the override table with the two choice-style task groups.
for task in bbh_multi_choice:
    TASK_TO_METRIC[task] = 'multichoice'
for task in bbh_binary_choice:
    TASK_TO_METRIC[task] = 'binarychoice'

default_metric = 'em'


def normalize_prediction(prediction, lowercase=True):
    """Reduce a raw answer string to a canonical form for comparison.

    The steps run in this order:
      1. replace ' and ', 'Sentence 1:' and 'Sentence 2:' with a space;
      2. strip surrounding whitespace;
      3. keep only the text before the first newline, then only the text
         before the first '.';
      4. lowercase the result when `lowercase` is true;
      5. turn every '-' into a space and delete all other ASCII punctuation.

    Args:
        prediction: the string to normalize.
        lowercase: whether to lowercase the text (default True).

    Returns:
        The normalized string.
    """
    for filler in (' and ', 'Sentence 1:', 'Sentence 2:'):
        prediction = prediction.replace(filler, ' ')

    first_line = prediction.strip().partition("\n")[0]
    text = first_line.partition(".")[0]
    if lowercase:
        text = text.lower()

    # One translation pass: hyphens become spaces, every other punctuation
    # character is dropped.
    droppable = string.punctuation.replace('-', '')
    return text.translate(str.maketrans('-', ' ', droppable))


def get_f1_score(prediction, ground_truth):
    """Token-level F1 between a prediction and one reference answer.

    Both strings are normalized (lowercased) and split on whitespace; the
    overlap is the multiset intersection of the two token lists.

    Returns:
        0 when no token is shared, otherwise the harmonic mean of token
        precision and recall as a float.
    """
    pred_tokens = normalize_prediction(prediction, lowercase=True).split()
    gold_tokens = normalize_prediction(ground_truth, lowercase=True).split()

    shared = Counter(pred_tokens) & Counter(gold_tokens)
    overlap = sum(shared.values())
    # A zero overlap also covers empty token lists, so the divisions below
    # never see a zero denominator.
    if overlap == 0:
        return 0

    precision = 1.0 * overlap / len(pred_tokens)
    recall = 1.0 * overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)


def get_em_score(prediction, ground_truth):
    """Exact match: True when both strings are equal after normalization.

    Both sides go through `normalize_prediction` with lowercasing enabled.
    """
    return (
        normalize_prediction(prediction, lowercase=True)
        == normalize_prediction(ground_truth, lowercase=True)
    )


def get_exact_set_score(prediction, ground_truth):
    """Return 1 when both strings contain the same set of tokens, else 0.

    Tokens come from the normalized, lowercased strings split on whitespace;
    order and repetition are ignored.
    """
    predicted = set(normalize_prediction(prediction, lowercase=True).split())
    expected = set(normalize_prediction(ground_truth, lowercase=True).split())
    return 1 if predicted == expected else 0


def get_contains_score(prediction, ground_truth):
    """Return 1 when the reference occurs as whole word(s) in the prediction.

    Both strings are normalized (lowercased) first; the normalized reference
    is then searched for in the normalized prediction, anchored on word
    boundaries at both ends.

    Args:
        prediction: the model output.
        ground_truth: one reference answer.

    Returns:
        1 on a match, 0 otherwise.
    """
    prediction_normalized = normalize_prediction(prediction, lowercase=True)
    ground_truth_normalized = normalize_prediction(
        ground_truth, lowercase=True)
    # Normalization already strips ASCII punctuation, so escaping is purely
    # defensive: the reference must always be matched as literal text.
    pattern = r'\b({0})\b'.format(re.escape(ground_truth_normalized))
    if re.search(pattern, prediction_normalized):
        return 1
    # Explicit miss value; falling off the end would return None.
    return 0


def get_multi_answer_em(prediction, answers):
    """Return 1 if the prediction exactly matches any reference, else 0.

    References are tried in order and the search stops at the first match.
    """
    return int(any(get_em_score(prediction, ref) == 1 for ref in answers))


def get_binary_choice_score(prediction, ground_truth):
    """Check whether the reference appears at the start or end of the output.

    Two candidate snippets are taken from the raw (un-normalized) prediction:
    the text of the first line up to its first '.', and the text of the last
    line after its last '.'. The result is True when `ground_truth` is a
    substring of either snippet.
    """
    opening = prediction.partition("\n")[0].partition(".")[0]
    closing = prediction.rpartition("\n")[2].rpartition(".")[2]
    return any(ground_truth in snippet for snippet in (opening, closing))

def get_multi_choice_score(prediction, ground_truth):
    """Check whether the reference appears in the tail of the prediction.

    The tail is the text of the last line of the raw prediction after its
    last '.'. The reference is normalized with case preserved and the
    result is True when it is a substring of that tail.
    """
    last_line = prediction.rpartition("\n")[2]
    tail = last_line.rpartition(".")[2]
    target = normalize_prediction(ground_truth, lowercase=False)
    return target in tail


def get_multi_choice(prediction, answers):
    """Return 1 if any reference passes `get_multi_choice_score`, else 0."""
    return int(any(
        get_multi_choice_score(prediction, ref) == 1 for ref in answers
    ))

def get_binary_choice(prediction, answers):
    """Return 1 if any reference passes `get_binary_choice_score`, else 0."""
    matched = any(
        get_binary_choice_score(prediction, ref) == 1 for ref in answers
    )
    return 1 if matched else 0

def get_multi_answer_f1(prediction, answers):
    """Return the best token-level F1 of the prediction over all references.

    Args:
        prediction: the model output.
        answers: iterable of reference answers.

    Returns:
        The maximum `get_f1_score` over `answers`, or 0 when `answers` is
        empty.
    """
    # `default=0` gives an empty reference list a score of 0, like the
    # other multi-answer aggregators, instead of raising ValueError.
    return max(
        (get_f1_score(prediction, answer) for answer in answers),
        default=0,
    )


def get_multi_answer_exact_set(prediction, answers):
    """Return 1 if the prediction's token set equals any reference's, else 0."""
    hit = any(get_exact_set_score(prediction, ref) == 1 for ref in answers)
    return 1 if hit else 0


def get_multi_answer_contains(prediction, answers):
    """Return 1 if any reference is contained in the prediction, else 0.

    Containment is decided by `get_contains_score`; only a score equal to 1
    counts as a hit.
    """
    return int(any(
        get_contains_score(prediction, ref) == 1 for ref in answers
    ))

def set_all_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Calls, in order: `random.seed`, `np.random.seed`, `torch.manual_seed`,
    `torch.cuda.manual_seed` and `torch.cuda.manual_seed_all`, each with the
    same `seed`.

    Returns:
        A confirmation message that includes the seed.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    return f"Set all the seeds to {seed} successfully!"