from transformers.data.processors.squad import SquadV2Processor
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch


def process_data(path="src/data/fine_tuning/squad/"):
    """Load the SQuAD 2.0 dev examples and build lookup tables keyed by question id.

    Args:
        path: Location passed to ``SquadV2Processor.get_dev_examples``; the
            file requested from it is ``dev-v2.0.json``.

    Returns:
        A 5-tuple ``(examples, qid_to_example_index, qid_to_has_answer,
        answer_qids, no_answer_qids)`` where

        * ``examples`` is whatever the processor returned,
        * ``qid_to_example_index`` maps each ``qas_id`` to its position in
          ``examples``,
        * ``qid_to_has_answer`` maps each ``qas_id`` to ``bool(example.answers)``,
        * ``answer_qids`` / ``no_answer_qids`` split the question ids by that flag.
    """
    examples = SquadV2Processor().get_dev_examples(path, filename="dev-v2.0.json")

    # One pass over the examples fills both per-question lookup tables.
    qid_to_example_index = {}
    qid_to_has_answer = {}
    for position, example in enumerate(examples):
        qid_to_example_index[example.qas_id] = position
        qid_to_has_answer[example.qas_id] = bool(example.answers)

    # Partition the question ids by whether the example carries any answer.
    answer_qids = []
    no_answer_qids = []
    for qas_id, has_answer in qid_to_has_answer.items():
        if has_answer:
            answer_qids.append(qas_id)
        else:
            no_answer_qids.append(qas_id)

    return examples, qid_to_example_index, qid_to_has_answer, answer_qids, no_answer_qids

def get_prediction(qid, examples, qid_to_example_index, model, tokenizer, device='cpu'):
    """Run the QA model on one example and return the predicted answer text.

    Args:
        qid: Question id (``qas_id``) of the example to answer.
        examples: Sequence of examples exposing ``question_text`` and
            ``context_text``.
        qid_to_example_index: Mapping from question id to index in ``examples``.
        model: Callable taking the encoded inputs as keyword arguments; its
            first two outputs are used as the start and end scores.
        tokenizer: Tokenizer providing ``encode_plus``,
            ``convert_ids_to_tokens`` and ``convert_tokens_to_string``.
        device: Device the encoded inputs are moved to before the forward pass.

    Returns:
        The decoded answer string. It is built from an empty token span when
        the predicted end does not lie after the predicted start.
    """
    example = examples[qid_to_example_index[qid]]
    question = example.question_text
    context = example.context_text

    # Only the context (second sequence) is truncated when the pair is too long.
    inputs = tokenizer.encode_plus(question, context, return_tensors='pt', truncation='only_second')

    # The return value is not used, so this relies on `.to` updating `inputs` in place.
    inputs.to(device)

    # Pure inference: disable autograd so no gradient graph is built per example.
    with torch.no_grad():
        outputs = model(**inputs)

    # Most likely start token, and one past the most likely end token so the
    # slice below includes the end token itself.
    answer_start = int(torch.argmax(outputs[0]))
    answer_end = int(torch.argmax(outputs[1])) + 1

    answer_ids = inputs['input_ids'][0][answer_start:answer_end]
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(answer_ids))

    return answer


def normalize_text(s):
    """Normalize answer text before comparison.

    The steps run in this order: lowercase, drop every character found in
    ``string.punctuation``, replace the standalone words "a", "an" and "the"
    with a space, then collapse all whitespace runs into single spaces and
    trim both ends.
    """
    import re
    import string

    lowered = s.lower()

    # Deleting via a translation table removes the same characters as
    # filtering them out one by one.
    punctuation_free = lowered.translate(str.maketrans("", "", string.punctuation))

    # Articles become a space (not an empty string) so neighbouring words
    # never get glued together; the surplus spaces are squeezed out below.
    article_free = re.sub(r"\b(a|an|the)\b", " ", punctuation_free, flags=re.UNICODE)

    return " ".join(article_free.split())

def compute_exact_match_single(prediction, truth):
    """Return 1 if ``prediction`` and ``truth`` are equal after ``normalize_text``, else 0."""
    normalized_prediction = normalize_text(prediction)
    normalized_truth = normalize_text(truth)
    return 1 if normalized_prediction == normalized_truth else 0

def compute_exact_match(predictions, ground_truths):
    """Return the fraction of predictions that exactly match a gold answer.

    Args:
        predictions: Sequence of predicted answer strings.
        ground_truths: Sequence aligned with ``predictions``; element ``i`` is
            an iterable of acceptable gold answer strings for prediction ``i``.

    Returns:
        The mean, over all predictions, of the best exact-match score (0 or 1)
        against that prediction's gold answers. Returns ``0.0`` when
        ``predictions`` is empty instead of dividing by zero.
    """
    if len(predictions) == 0:
        return 0.0

    total_count = 0
    for i, prediction in enumerate(predictions):
        gold_answers = ground_truths[i]
        # A prediction counts as correct if it matches any one of its gold answers.
        total_count += max(compute_exact_match_single(prediction, answer) for answer in gold_answers)
    return total_count / len(predictions)

def compute_f1_single(prediction, truth):
    """Return the token-level F1 between one prediction and one gold answer.

    Both strings are passed through ``normalize_text`` and split on whitespace.
    The overlap is counted as a multiset intersection, so a token repeated in
    both strings is credited once per shared occurrence. Counting the overlap
    with plain sets while dividing by the full token counts would score
    identical answers containing a repeated token below 1.

    Args:
        prediction: Predicted answer string.
        truth: Gold answer string.

    Returns:
        ``1`` or ``0`` when either side has no tokens (1 only if both are
        empty), ``0`` when no token is shared, otherwise the F1 as a float.
    """
    from collections import Counter

    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)

    # Multiset intersection: each token counts min(times in pred, times in truth).
    num_common = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())

    # if there are no common tokens then f1 = 0
    if num_common == 0:
        return 0

    prec = num_common / len(pred_tokens)
    rec = num_common / len(truth_tokens)

    return 2 * (prec * rec) / (prec + rec)

def get_match_count(prediction, truth):
    """Count overlapping and non-overlapping items between two collections.

    ``prediction`` and ``truth`` are used as given (no normalization or
    splitting happens here); both are turned into sets, so their elements
    must be hashable and duplicates are ignored on the main path.

    Returns:
        A tuple ``(true_positive, false_positive, false_negative)``:
        the number of distinct items in both, only in ``prediction`` and only
        in ``truth`` respectively. When either input is empty the tuple is
        ``(int(prediction == truth), len(prediction), len(truth))`` instead,
        i.e. the raw lengths are reported rather than distinct-item counts.
    """
    # Empty side: agreement is decided by plain equality of the two inputs.
    if len(prediction) == 0 or len(truth) == 0:
        agree = int(prediction == truth)
        return agree, len(prediction), len(truth)

    predicted_items = set(prediction)
    gold_items = set(truth)

    shared = predicted_items & gold_items
    only_predicted = predicted_items - gold_items
    only_gold = gold_items - predicted_items

    return len(shared), len(only_predicted), len(only_gold)

    
def compute_f1(predictions, ground_truths):
    """Compute micro-averaged F1, precision and recall over all pairs.

    Each ``predictions[i]`` / ``ground_truths[i]`` pair is handed to
    ``get_match_count`` unchanged; that helper builds sets from both, so each
    element should be a collection of hashable tokens. The true-positive,
    false-positive and false-negative counts are summed over every pair
    before the scores are derived.

    Args:
        predictions: Sequence of predicted token collections.
        ground_truths: Sequence of gold token collections aligned with
            ``predictions``.

    Returns:
        A tuple ``(f1, precision, recall)``. Any score whose denominator
        would be zero (no predictions, nothing predicted, nothing to find,
        or no overlap at all) is reported as ``0.0`` instead of raising
        ``ZeroDivisionError``.
    """
    total_true_positive = 0
    total_false_positive = 0
    total_false_negative = 0
    for i, prediction in enumerate(predictions):
        true_positive, false_positive, false_negative = get_match_count(prediction, ground_truths[i])

        total_true_positive += true_positive
        total_false_positive += false_positive
        total_false_negative += false_negative

    predicted_total = total_true_positive + total_false_positive
    gold_total = total_true_positive + total_false_negative

    # Guard every division: empty input or zero overlap yields 0.0 scores.
    prec = total_true_positive / predicted_total if predicted_total else 0.0
    rec = total_true_positive / gold_total if gold_total else 0.0

    f1 = 2 * (prec * rec) / (prec + rec) if (prec + rec) else 0.0
    return f1, prec, rec

def get_gold_answers(example):
    """Collect the non-empty gold answer texts of a SQuAD 2.0 example.

    Reads ``answer["text"]`` from every entry of ``example.answers`` and keeps
    the truthy ones. If nothing is kept, ``[""]`` is returned so that the
    empty string stands in as the only correct answer of a negative example.
    """
    gold_answers = []
    for answer in example.answers:
        text = answer["text"]
        if text:
            gold_answers.append(text)

    # No usable answer text: treat the example as unanswerable.
    return gold_answers if gold_answers else [""]

def evaluate_model(model, tokenizer, device):
    """Evaluate a QA model on the answerable questions of the SQuAD 2.0 dev set.

    The examples are loaded through ``process_data()`` with its default path.
    Only question ids listed in ``answer_qids`` (examples that have answers)
    are scored. For every such question the model prediction is obtained with
    ``get_prediction`` and compared against the example's gold answers.

    Args:
        model: Question-answering model forwarded to ``get_prediction``.
        tokenizer: Tokenizer forwarded to ``get_prediction``.
        device: Device forwarded to ``get_prediction``.

    Returns:
        A tuple ``(em_score, f1_score)`` where ``em_score`` is the value
        returned by ``compute_exact_match`` and ``f1_score`` is the
        ``(f1, precision, recall)`` tuple returned by ``compute_f1``. The same
        two values are printed.
    """
    examples, qid_to_example_index, _, answer_qids, _ = process_data()

    prediction_list = []
    ground_truth_list = []
    # Token-level inputs for compute_f1, built alongside the raw strings.
    prediction_tokens = []
    ground_truth_tokens = []
    for qid in answer_qids:
        prediction = get_prediction(qid, examples, qid_to_example_index, model, tokenizer, device)
        gold_answers = get_gold_answers(examples[qid_to_example_index[qid]])

        prediction_list.append(prediction)
        ground_truth_list.append(gold_answers)

        # compute_f1 forwards each pair to get_match_count, which builds sets
        # from both sides. Passing the raw string and the list of gold strings
        # would compare single characters against whole answers, so pass
        # normalized token lists instead, using the gold answer that matches
        # the prediction best.
        best_gold = max(gold_answers, key=lambda answer: compute_f1_single(prediction, answer))
        prediction_tokens.append(normalize_text(prediction).split())
        ground_truth_tokens.append(normalize_text(best_gold).split())

    em_score = compute_exact_match(prediction_list, ground_truth_list)
    # NOTE: f1_score is the full (f1, precision, recall) tuple.
    f1_score = compute_f1(prediction_tokens, ground_truth_tokens)

    print(f"EM: {em_score} \t F1: {f1_score}")
    return em_score, f1_score

# evaluate_model()

# tokenizer = AutoTokenizer.from_pretrained("twmkn9/distilbert-base-uncased-squad2")
# model = AutoModelForQuestionAnswering.from_pretrained("twmkn9/distilbert-base-uncased-squad2")

# examples, qid_to_example_index, qid_to_has_answer, answer_qids, no_answer_qids = process_data()


# prediction = get_prediction('57265e455951b619008f70bb', examples, qid_to_example_index, model, tokenizer)
# example = examples[qid_to_example_index['57265e455951b619008f70bb']]

# gold_answers = get_gold_answers(example)

# em_score = max((compute_exact_match_single(prediction, answer)) for answer in gold_answers)
# f1_score = max((compute_f1_single(prediction, answer)) for answer in gold_answers)

# print(f"Question: {example.question_text}")
# print(f"Prediction: {prediction}")
# print(f"True Answers: {gold_answers}")
# print(f"EM: {em_score} \t F1: {f1_score}")