from __future__ import absolute_import, division, print_function

import collections
import logging
import math
import os
import timeit

import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from transformers import (WEIGHTS_NAME, BertConfig,
                          BertForQuestionAnswering, BertTokenizer,
                          XLMConfig, XLMForQuestionAnswering,
                          XLMTokenizer, XLNetConfig,
                          XLNetForQuestionAnswering,
                          XLNetTokenizer,
                          DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer,
                          AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer)

from utils_squad import (get_predictions, read_squad_example,
                         convert_example_to_features, to_list, convert_examples_to_features, get_all_predictions)

# Lightweight record pairing an identifier with a pair of start/end logits.
# NOTE(review): presumably one record per model input feature, consumed by
# the utils_squad prediction helpers -- confirm against the call sites.
RawResult = collections.namedtuple(
    "RawResult",
    "unique_id start_logits end_logits",
)






def load_model(model_path: str, do_lower_case: bool = True, *,
               tokenizer_name: str = 'albert-large-v2'):
    """Load an ALBERT question-answering model and a matching tokenizer.

    Args:
        model_path: Location of the model checkpoint; a ``config.json`` is
            read from inside it, so it is presumably a local directory.
        do_lower_case: Forwarded unchanged to
            ``AlbertTokenizer.from_pretrained``.
        tokenizer_name: Name or path handed to
            ``AlbertTokenizer.from_pretrained``. Defaults to the
            ``albert-large-v2`` vocabulary this function always used; pass
            ``model_path`` instead when the checkpoint ships its own
            tokenizer files.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # os.path.join handles a trailing separator on model_path and the
    # platform's path separator, unlike manual "/" concatenation.
    config = AlbertConfig.from_pretrained(os.path.join(model_path, "config.json"))
    tokenizer = AlbertTokenizer.from_pretrained(tokenizer_name, do_lower_case=do_lower_case)
    # from_tf=False: load PyTorch weights rather than a TensorFlow checkpoint.
    model = AlbertForQuestionAnswering.from_pretrained(model_path, from_tf=False, config=config)
    return model, tokenizer
