from functools import partial
from itertools import chain
from typing import Callable, Optional, Union

from datasets import Dataset
from lightning_transformers.task.nlp.language_modeling import LanguageModelingDataModule
from transformers import PreTrainedTokenizerBase, default_data_collator

from data.vocab import load_esm_alphabet


class ProteinLanguageModelingDataModule(LanguageModelingDataModule):
    """Template listing the hooks a language-modeling data module can override.

    NOTE(review): a second class with this exact name is defined further down
    in this module. That later definition rebinds the name, so this template
    is not the class importers of the module receive.
    """

    def process_data(self, dataset: Dataset, stage: Optional[str] = None) -> Dataset:
        """Convert the dataset into features.

        Placeholder: the conversion steps are elided, so the dataset is
        returned unchanged.

        Args:
            dataset: The dataset to process.
            stage: Optional stage name; unused by this placeholder.

        Returns:
            The same ``dataset`` object that was passed in.
        """
        # The dataset is pre-loaded using `load_dataset`.
        ...
        return dataset

    @staticmethod
    def tokenize_function(
        examples,
        tokenizer: PreTrainedTokenizerBase,
        text_column_name: Optional[str] = None,
    ):
        """Tokenize a single column of ``examples``.

        Intended to be called by ``process_data``.

        Args:
            examples: Mapping indexed by column name.
            tokenizer: Callable applied to the selected column.
            text_column_name: Key of the column to tokenize. The default of
                ``None`` is used as the lookup key as-is, so callers are
                expected to pass a real column name.

        Returns:
            Whatever ``tokenizer`` returns for ``examples[text_column_name]``.
        """
        return tokenizer(examples[text_column_name])

    @staticmethod
    def convert_to_features(examples, block_size: Optional[int] = None):
        """Convert samples in the dataset into features.

        Intended to be called by ``process_data``. Placeholder: the body is
        elided, so this currently returns ``None``.
        """
        ...

    @property
    def collate_fn(self) -> Callable:
        """Return the callable that collates samples into a model batch."""
        return default_data_collator

class ProteinLanguageModelingDataModule(LanguageModelingDataModule):
    """Language-modeling data module that tokenizes with a protein alphabet.

    This definition shares its name with the template class declared earlier
    in this module and therefore replaces it at import time.
    """

    def __init__(self, config, alphabet, *args, **kwargs):
        """Set up the data module.

        Args:
            config: Configuration object; recorded via
                ``save_hyperparameters`` and kept on ``self.config``.
            alphabet: Object whose ``tokenize`` attribute is forwarded to the
                base class as its first positional argument.
            *args: Extra positional arguments for the base class.
            **kwargs: Extra keyword arguments for the base class.
        """
        super().__init__(alphabet.tokenize, *args, **kwargs)
        self.alphabet = alphabet
        self.save_hyperparameters(config)
        self.config = config
        # Prefix prepended to every block by `convert_to_features`. It must
        # exist because `process_data` reads it; an empty list means "no
        # condition term" and leaves the blocks untouched.
        self.tokenized_condition_term = []

    def load_dataset(self):
        """Placeholder for dataset loading; currently returns ``None``.

        NOTE(review): the base class presumably expects a dataset object from
        this hook -- confirm and implement before use.
        """
        return None

    def process_data(self, dataset: Dataset, stage: Optional[str] = None) -> Dataset:
        """Prepare the feature-conversion callable and return the dataset.

        Placeholder: the steps that would apply the conversion to ``dataset``
        are elided, so the dataset is returned unchanged.

        Args:
            dataset: The dataset to process.
            stage: Optional stage name; unused by the visible code.

        Returns:
            The same ``dataset`` object that was passed in.
        """
        ...
        # Pass in our additional condition term when converting to features
        convert_to_features = partial(
            self.convert_to_features,
            block_size=self.effective_block_size,
            tokenized_condition_term=self.tokenized_condition_term,
        )
        ...
        return dataset

    @staticmethod
    def convert_to_features(examples, block_size: int, **kwargs):
        """Concatenate tokenized examples and split them into fixed-size blocks.

        Each column of ``examples`` is flattened into one sequence, the
        trailing remainder that does not fill a whole block is dropped, and
        the sequence is cut into chunks of ``block_size`` items. The condition
        term is prepended to every chunk of every column, so each resulting
        block holds ``len(term) + block_size`` items.

        Args:
            examples: Mapping of column name to a list of per-example lists.
                Must contain an ``"input_ids"`` column.
            block_size: Number of items per block before the prefix is added.
            **kwargs: May carry ``tokenized_condition_term``, a list prepended
                to every block. Missing or ``None`` means no prefix.

        Returns:
            A dict with the same columns as ``examples``, each a list of
            blocks, plus ``"labels"`` as a shallow copy of ``"input_ids"``.

        Raises:
            KeyError: If ``examples`` has no ``"input_ids"`` column.
        """
        # Our argument is passed in via kwargs
        tokenized_condition_term = kwargs.get("tokenized_condition_term") or []

        # Flatten every column's per-example lists into one long sequence.
        concatenated_examples = {
            name: list(chain.from_iterable(examples[name])) for name in examples.keys()
        }
        # All columns are assumed to flatten to the same length, so the first
        # one is used as the reference; an empty mapping yields length 0.
        total_length = len(next(iter(concatenated_examples.values()), []))
        # Drop the small remainder so that every block is completely filled.
        total_length = (total_length // block_size) * block_size

        # Add the term to the tokenized blocks of text
        result = {
            name: [
                tokenized_condition_term + tokens[start:start + block_size]
                for start in range(0, total_length, block_size)
            ]
            for name, tokens in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result
    

# Script entry point: currently a no-op, so running this file directly does
# nothing beyond executing the definitions above.
if __name__ == "__main__":
    pass