import enum

import torch

from transformers import (
    BertConfig,
    BertForMultipleChoice,
    BertForSequenceClassification,
    BertTokenizer,
    RobertaConfig,
    RobertaForMultipleChoice,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    T5Config,
    T5Tokenizer,
    XLNetConfig,
    XLNetForMultipleChoice,
    XLNetForSequenceClassification,
    XLNetTokenizer,
)

try:
    from t5_models import T5ForSequenceClassification, T5ForMultipleChoice
except ImportError as e:
    T5ForSequenceClassification = None
    T5ForMultipleChoice = None

import utils_multiple_choice
import utils_text_classification
import utils_seq2choices

class TaskType(enum.Enum):
    """Kinds of task this module knows how to route.

    The string value of each member is a short tag for the task family.
    """

    # Task registered in utils_multiple_choice.processors.
    MULTIPLE_CHOICE = "mc"
    # Task registered in utils_text_classification.processors.
    TEXT_CLASSIFICATION = "tc"
    # "seq2mc-<subtask>" wrapper around a multiple-choice or text-classification task.
    SEQUENCE_MULTIPLE_CHOICE = "seq2mc"

# Registry: task type -> model type -> (config class, model class, tokenizer class).
MODEL_CLASSES = {
    TaskType.MULTIPLE_CHOICE: dict(
        bert=(BertConfig, BertForMultipleChoice, BertTokenizer),
        xlnet=(XLNetConfig, XLNetForMultipleChoice, XLNetTokenizer),
        roberta=(RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer),
    ),
    # A T5 text-classification entry (T5ForSequenceClassification) is
    # intentionally not registered here.
    TaskType.TEXT_CLASSIFICATION: dict(
        bert=(BertConfig, BertForSequenceClassification, BertTokenizer),
        xlnet=(XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
        roberta=(RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
    ),
    # NOTE(review): T5ForMultipleChoice is None when `t5_models` could not be
    # imported, so this entry may hold None as its model class.
    TaskType.SEQUENCE_MULTIPLE_CHOICE: dict(
        t5=(T5Config, T5ForMultipleChoice, T5Tokenizer),
    ),
}

def get_task_type(task):
    """Classify a task name into a :class:`TaskType`.

    Lookup order:
      1. a key of ``utils_multiple_choice.processors`` -> MULTIPLE_CHOICE;
      2. a key of ``utils_text_classification.processors`` -> TEXT_CLASSIFICATION;
      3. ``'seq2mc-<subtask>'`` where ``<subtask>`` is a key of either registry
         -> SEQUENCE_MULTIPLE_CHOICE.

    Args:
        task: Task name string.

    Returns:
        The matching ``TaskType`` member.

    Raises:
        ValueError: If ``task`` matches none of the rules above.
    """
    # TODO: Check if there is key overlap and raise an error if so
    if task in utils_multiple_choice.processors:
        return TaskType.MULTIPLE_CHOICE
    if task in utils_text_classification.processors:
        return TaskType.TEXT_CLASSIFICATION
    if task.startswith('seq2mc-'):
        subtask = get_seq2mc_subtask(task)
        # Test the registries directly rather than materialising and unioning
        # their key sets on every call.
        if subtask in utils_text_classification.processors or subtask in utils_multiple_choice.processors:
            return TaskType.SEQUENCE_MULTIPLE_CHOICE
    raise ValueError("Unknown task {}".format(task))

def get_all_task_keys():
    """Return the names of all directly registered tasks as a new list.

    Multiple-choice task names come first, then text-classification task
    names. Derived ``seq2mc-<subtask>`` names are not included.
    """
    return [
        *utils_multiple_choice.processors.keys(),
        *utils_text_classification.processors.keys(),
    ]

def get_seq2mc_subtask(task):
    """Return ``task`` with the leading ``'seq2mc-'`` prefix removed.

    The prefix length is sliced off unconditionally; callers are expected to
    have checked ``task.startswith('seq2mc-')`` beforehand.
    """
    prefix = 'seq2mc-'
    return task[len(prefix):]

def get_task_processor(task):
    """Instantiate the data processor for ``task``.

    Plain multiple-choice and text-classification tasks get their registered
    processor. For ``seq2mc-<subtask>`` tasks, the subtask's processor is wrapped
    in the matching ``utils_seq2choices`` wrapper.

    Raises:
        ValueError: If the task (or its seq2mc subtask) cannot be resolved.
    """
    task_type = get_task_type(task)
    registries = {
        TaskType.MULTIPLE_CHOICE: utils_multiple_choice.processors,
        TaskType.TEXT_CLASSIFICATION: utils_text_classification.processors,
    }
    if task_type in registries:
        return registries[task_type][task]()

    if task_type == TaskType.SEQUENCE_MULTIPLE_CHOICE:
        inner = get_seq2mc_subtask(task)
        inner_type = get_task_type(inner)
        if inner_type == TaskType.MULTIPLE_CHOICE:
            return utils_seq2choices.Seq2McMultipleChoiceWrapper(registries[inner_type][inner]())
        if inner_type == TaskType.TEXT_CLASSIFICATION:
            return utils_seq2choices.Seq2McTextClassificationWrapper(registries[inner_type][inner]())

    raise ValueError("Unknown task {}".format(task))

def load_and_cache_examples_for_task(task, *args, **kwargs):
    """Forward to the ``load_and_cache_examples_from_file`` of the utils module for ``task``.

    ``task`` is used only to pick the module; ``*args`` and ``**kwargs`` are
    passed through unchanged and the loader's result is returned as-is.

    Raises:
        ValueError: If ``task`` does not resolve to a supported task type.
    """
    task_type = get_task_type(task)
    modules_by_type = {
        TaskType.MULTIPLE_CHOICE: utils_multiple_choice,
        TaskType.TEXT_CLASSIFICATION: utils_text_classification,
        TaskType.SEQUENCE_MULTIPLE_CHOICE: utils_seq2choices,
    }
    loader_module = modules_by_type.get(task_type)
    if loader_module is None:
        raise ValueError("Unknown task {}".format(task))
    return loader_module.load_and_cache_examples_from_file(*args, **kwargs)

def examples_to_dataset(task_name, model_type, max_seq_length, examples, label_list, tokenizer):
    """Convert examples into a ``TensorDataset`` ready for training or evaluation.

    Features are built with the task family's ``convert_examples_to_features``.
    The fields ``input_ids``, ``input_mask`` and ``segment_ids`` are then extracted
    with that family's ``select_field``, and the tensors are stacked together with
    each feature's ``label``.

    Args:
        task_name: Task name; must resolve to MULTIPLE_CHOICE or TEXT_CLASSIFICATION.
        model_type: Model family name. ``"xlnet"`` switches to left padding with
            pad segment id 4; all other values use right padding with segment id 0.
        max_seq_length: Maximum sequence length forwarded to the feature converter.
        examples: Examples to convert (format defined by the task's utils module).
        label_list: Label list forwarded to the feature converter.
        tokenizer: Tokenizer forwarded to the feature converter.

    Returns:
        ``torch.utils.data.TensorDataset`` of four ``torch.long`` tensors in the
        order (input_ids, input_mask, segment_ids, label_ids).

    Raises:
        ValueError: If the task is unknown, or its type (e.g. seq2mc) has no
            feature converter in this function.
    """
    tasktype = get_task_type(task_name)
    if tasktype == TaskType.MULTIPLE_CHOICE:
        task_utils = utils_multiple_choice
    elif tasktype == TaskType.TEXT_CLASSIFICATION:
        task_utils = utils_text_classification
    else:
        # get_task_type() already accepted the name, so the task itself is
        # known; only its type is unsupported here.
        raise ValueError(
            "Task {} of type {} is not supported by examples_to_dataset".format(task_name, tasktype)
        )
    select_field_fn = task_utils.select_field

    # XLNet pads on the left and uses a dedicated padding segment id.
    is_xlnet = model_type == "xlnet"
    features = task_utils.convert_examples_to_features(
        examples,
        label_list,
        max_seq_length,
        tokenizer,
        pad_on_left=is_xlnet,
        pad_token_segment_id=4 if is_xlnet else 0,
    )

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor(select_field_fn(features, "input_ids"), dtype=torch.long)
    all_input_mask = torch.tensor(select_field_fn(features, "input_mask"), dtype=torch.long)
    all_segment_ids = torch.tensor(select_field_fn(features, "segment_ids"), dtype=torch.long)
    all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long)

    return torch.utils.data.TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

