


import torch
from transformers import BertTokenizer, BertForQuestionAnswering


def main():
    """Answer a fixed question about a fixed paragraph with a SQuAD-finetuned BERT.

    Loads the tokenizer and question-answering model, encodes the
    question/paragraph pair, takes the highest-scoring start and end token
    positions, and prints the text of the token span between them.
    """
    # BERT fine-tuned for the SQuAD task. The tokenizer is loaded from the
    # same checkpoint as the model so vocabulary and casing always match.
    checkpoint = 'bert-large-uncased-whole-word-masking-finetuned-squad'
    tokenizer = BertTokenizer.from_pretrained(checkpoint)
    model = BertForQuestionAnswering.from_pretrained(checkpoint)

    question = '''Why was the student group called "the Methodists?"'''

    paragraph = ''' The movement which would become The United Methodist Church began in the mid-18th century within the Church of England.
                A small group of students, including John Wesley, Charles Wesley and George Whitefield, met on the Oxford University campus.
                They focused on Bible study, methodical study of scripture and living a holy life.
                Other students mocked them, saying they were the "Holy Club" and "the Methodists", being methodical and exceptionally detailed in their Bible study, opinions and disciplined lifestyle.
                Eventually, the so-called Methodists started individual societies or classes for members of the Church of England who wanted to live a more religious life. '''

    # Encode question and paragraph together as one sequence pair.
    encoding = tokenizer(text=question, text_pair=paragraph, add_special_tokens=True)

    input_ids = encoding['input_ids']  # Vocabulary ids of the tokens
    token_type_ids = encoding['token_type_ids']  # Segment ids: question vs. paragraph
    tokens = tokenizer.convert_ids_to_tokens(input_ids)  # Input tokens as strings

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        result = model(input_ids=torch.tensor([input_ids]),
                       token_type_ids=torch.tensor([token_type_ids]))

    # Most likely start/end positions of the answer, as plain ints for slicing.
    start_index = int(torch.argmax(result.start_logits))
    end_index = int(torch.argmax(result.end_logits))

    # Merge WordPiece continuation pieces ("##...") back into whole words.
    # If the predicted end precedes the start, the slice is empty and the
    # answer is the empty string.
    answer = tokenizer.convert_tokens_to_string(tokens[start_index:end_index + 1])
    print(answer)

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()