import string
from tqdm import tqdm

# Words that count as a yes/no answer once punctuation and case are removed.
_YES_NO_WORDS = frozenset(('yes', 'no'))

# Leading verbs that mark a question as a yes/no question.
# 'can' / 'could' are deliberately left out, since questions may be like
# 'Can you name a United Nations observer state?'
_YES_NO_QUESTION_VERBS = frozenset((
    'is', 'am', 'are', 'was', 'were', 'do', 'does', 'did', 'have', 'has',
    'had', 'will', 'would', 'may', 'might', 'shall', 'should', 'must',
))

# Translation table that deletes every ASCII punctuation character in one pass.
_STRIP_PUNCTUATION = str.maketrans('', '', string.punctuation)


def _first_token(text):
    """Return the first whitespace-delimited token of *text*, or '' if there is none."""
    tokens = text.split(None, 1)
    return tokens[0] if tokens else ''


def _is_yes_no_word(token):
    """Return True if *token*, ignoring ASCII punctuation and case, is 'yes' or 'no'."""
    return token.translate(_STRIP_PUNCTUATION).lower() in _YES_NO_WORDS


def filter_yes_no_answer(questions_list, answers_list):
    """Return the indices of the question/answer pairs that look like yes/no QA.

    Three strategies are tried in order; the first one that yields at least
    one index wins and the later ones are skipped:

    1. The answer is a single token equal to "yes"/"no" (e.g. ``"Yes."``).
    2. The answer starts with "yes"/"no" (e.g. ``"No, it was not"``).
    3. The question starts with one of the verbs in
       ``_YES_NO_QUESTION_VERBS`` and its answer does not contain "unknown".

    Tokens are split on any run of whitespace, so leading/trailing spaces,
    tabs or newlines around an answer do not prevent it from being matched.

    Args:
        questions_list: Sequence of question strings.
        answers_list: Sequence of answer strings; ``answers_list[i]`` is the
            answer to ``questions_list[i]``.

    Returns:
        A list of integer indices, in ascending order. Empty if no strategy
        matches anything.
    """
    # First, find answers that consist of nothing but "yes"/"no".
    ret = []
    for i, cur_answer in enumerate(answers_list):
        tokens = cur_answer.split()
        if len(tokens) == 1 and _is_yes_no_word(tokens[0]):
            ret.append(i)
    if ret:
        return ret

    # Second, include answers like "Yes, xxxxx" or "No, xxxxx".
    ret = [
        i for i, cur_answer in enumerate(answers_list)
        if _is_yes_no_word(_first_token(cur_answer))
    ]
    if ret:
        return ret

    # Third, no answer begins with yes/no: fall back to questions whose first
    # word is a yes/no-question verb. The verb test comes first so the answer
    # is only looked up for questions that already qualify.
    return [
        i for i, cur_question in enumerate(questions_list)
        if _first_token(cur_question).lower() in _YES_NO_QUESTION_VERBS
        and 'unknown' not in answers_list[i].lower()
    ]

def generate_yes_no_index_list(data):
    """Compute, for every row of *data*, the indices of its yes/no QA pairs.

    Args:
        data: Object exposing ``num_rows`` and column access by name, with a
            ``'questions'`` column and an ``'answers'`` column whose entries
            carry an ``'input_text'`` field.
            NOTE(review): presumably a Hugging Face ``datasets.Dataset`` --
            confirm against the caller.

    Returns:
        A list with one entry per row; each entry is the index list returned
        by ``filter_yes_no_answer`` for that row's questions and answers.
    """
    # Look the columns up once: they do not change during the loop, and
    # column access may be expensive for some dataset types.
    questions_column = data['questions']
    answers_column = data['answers']

    yn_qa_idx_list = []
    for i in tqdm(range(data.num_rows)):
        questions_list = questions_column[i]
        answers_list = answers_column[i]['input_text']
        yn_qa_idx_list.append(filter_yes_no_answer(questions_list, answers_list))
    return yn_qa_idx_list