from enum import Enum
from transformers import GPT2LMHeadModel
from ._datasets import Wikitext
import torch


class LanguageModel(str, Enum):
    """Supported GPT-2 language-model checkpoints.

    Each member's value is the checkpoint identifier handed to
    ``GPT2LMHeadModel.from_pretrained`` by :meth:`load`. The enum mixes in
    ``str``, so a member compares equal to its plain-string value
    (``LanguageModel.GPT2 == "gpt2"``). ``LanguageModel("gpt2")`` looks a
    member up by that value.
    """

    # Member values are used verbatim as checkpoint identifiers in `load`.
    GPT2 = "gpt2"
    GPT2_XL = "gpt2-xl"

    def load(self) -> torch.nn.Module:
        """Build the pretrained ``GPT2LMHeadModel`` for this checkpoint.

        Returns:
            The model returned by ``GPT2LMHeadModel.from_pretrained`` when it
            is called with this member's value.
        """
        checkpoint_name = self.value
        model = GPT2LMHeadModel.from_pretrained(checkpoint_name)  # type: ignore
        return model

    def get_dataset(self):
        """Return a ``Wikitext`` dataset constructed for this model member.

        The member itself, not its string value, is passed to ``Wikitext``.
        NOTE(review): what ``Wikitext`` does with the member (e.g. choosing a
        tokenizer) is defined elsewhere; confirm before relying on it.
        """
        dataset = Wikitext(self)
        return dataset
