import pandas as pd


class Question_Passage_Match_Classifier(object):
    """Look up precomputed question/passage match predictions.

    The predictions come from a T5-3b model fine-tuned on the train set of
    MMQA. They are read once from a CSV file with a ``question`` column and a
    ``prediction`` column and are kept in memory as a dict; no model is run
    here.
    """

    # Prediction table used when no explicit path is given (resolved relative
    # to the current working directory).
    DEFAULT_CSV_PATH = "./qpmc_mmqa_dev.csv"

    def __init__(self, csv_path=DEFAULT_CSV_PATH):
        """Load the prediction table.

        Args:
            csv_path: Path of the CSV file holding the predictions. Defaults
                to ``DEFAULT_CSV_PATH``.
        """
        self.csv_path = csv_path
        # Maps the lower-cased ``question`` cell to True/False.
        self.qa_pairs_should_retrieve = None
        self.load_retrieve_info()

    def load_retrieve_info(self, csv_path=None):
        """(Re)build ``self.qa_pairs_should_retrieve`` from the CSV file.

        A row counts as a match only when its ``prediction`` cell is exactly
        the string ``"['yes']"``. When the same lower-cased question appears
        more than once, the last row wins.

        Args:
            csv_path: Optional path overriding the one given at construction
                time for this call only.
        """
        if csv_path is None:
            csv_path = getattr(self, "csv_path", self.DEFAULT_CSV_PATH)
        df = pd.read_csv(csv_path)
        # Iterate the two columns directly; iterrows() would build a Series
        # per row just to read these two values.
        self.qa_pairs_should_retrieve = {
            question.lower(): prediction == "['yes']"
            for question, prediction in zip(df['question'], df['prediction'])
        }

    def judge_match(self, question, passage):
        """Return the stored prediction for a question/passage pair.

        The lookup key is ``'qa: <question> \\n <passage>'`` with both parts
        lower-cased, which must equal a lower-cased ``question`` cell of the
        loaded CSV.

        Args:
            question: Question text.
            passage: Passage text.

        Returns:
            True if the stored prediction was ``"['yes']"``, else False.

        Raises:
            KeyError: If the pair is not present in the loaded table.
        """
        key = 'qa: {} \n {}'.format(question.lower(), passage.lower())
        return self.qa_pairs_should_retrieve[key]