import pytorch_lightning as pl
import os, torch, json
from functools import partial
from data.io import read_fasta
from data.vocab import load_esm_alphabet
from data.datasets import TemporalFastaDataset, TemporalMultiFastaDataset, PairwiseDataset, PairwiseRandomDataset
from data.utils import default_lm_collate_func, build_fake_pairwise_dataset, build_fake_multisource_batch_dataset, multiprot_lm_collate_func, default_seq2seq_collate_func
from torch.utils.data import DataLoader, random_split
import numpy as np
from data.data_modules.base_dm import ProteinDataModule, ProteinGISAIDDataModule
from data.data_modules import register_dm
from utils.args import str2bool
from copy import deepcopy
from collections import defaultdict

@register_dm("lm")
class ProteinLMDataModule(ProteinDataModule):
    """Language-model data module over temporal protein FASTA files.

    ``setup`` builds datasets for the stages "fit", "test", "test_multisource",
    "predict" and "predict_seq2seq"; the ``*_dataloader`` methods wrap them in
    ``DataLoader``s whose collate function is ``default_lm_collate_func`` bound
    to this vocabulary's batch converter.
    """

    def __init__(self, args, vocab=None):
        super().__init__(args)
        if vocab is not None:
            self.vocab = vocab
        else:
            self.vocab = load_esm_alphabet(self.args.vocab, self.args.mol_type)
        # Special tokens: ['<eos>', '<unk>', '<pad>', '<cls>', '<mask>']

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Register the LM-specific command-line options on ``parent_parser``."""
        parent_parser = super(ProteinLMDataModule, cls).add_argparse_args(parent_parser)
        parent_parser.add_argument('--pairwise_index_cache', type=str, default="")
        parent_parser.add_argument('--sample_src_size', type=int, default=1)
        parent_parser.add_argument('--history_length', type=int, default=1)
        # A non-positive value means "keep every history record".
        parent_parser.add_argument('--max_history_samples', type=int, default=-1)
        parent_parser.add_argument('--history_data_paths', nargs="+", type=str, default=[])
        parent_parser.add_argument('--predict_starting_time', type=float, default=None, help="When should we start predicting") # 731
        parent_parser.add_argument('--prediction_steps', type=int, default=1, help="How many steps are we going to predict.") # 731
        return parent_parser

    # ------------------------------------------------------------------ helpers

    def _limit_history(self, records, shuffle=False):
        """Return ``records`` (optionally shuffled) truncated to ``max_history_samples``.

        A non-positive ``max_history_samples`` disables truncation. Slicing with
        the default of -1 would otherwise silently drop the last record.
        """
        if shuffle:
            records = [records[i] for i in torch.randperm(len(records)).tolist()]
        if self.args.max_history_samples > 0:
            records = records[:self.args.max_history_samples]
        return records

    def _load_history_records(self, downsample):
        """Concatenate the raw FASTA records of every ``--history_data_paths`` file.

        When ``downsample`` is true, each file is randomly shuffled and truncated
        to ``max_history_samples`` records before concatenation.
        """
        records = []
        for path in self.args.history_data_paths:
            file_records = read_fasta(path)
            if downsample:
                file_records = self._limit_history(file_records, shuffle=True)
            records += file_records
        return records

    def _lm_collate_fn(self):
        """Build the LM collate function bound to this vocabulary's batch converter."""
        return partial(
            default_lm_collate_func,
            batch_converter=self.vocab.get_batch_converter(max_positions=self.args.max_position_embeddings),
        )

    # -------------------------------------------------------------------- setup

    def setup(self, stage):
        """Build the datasets required for ``stage`` (called on every DDP process).

        Supported stages: "fit" (or None), "test_multisource", "test",
        "predict" and "predict_seq2seq". Other values only set ``tokenizer``.
        """
        self.tokenizer = self.vocab  # consistent with TransformerLM
        if stage == "fit" or stage is None:
            self._setup_fit()
        elif stage == "test_multisource":
            self._setup_test_multisource()
        elif stage == "test":
            self._setup_test()
        elif stage == "predict":
            self._setup_predict()
        elif stage == "predict_seq2seq":
            self._setup_predict_seq2seq()

    def _setup_fit(self):
        """Randomly split ``--data_path`` into train/validation by ``valid_size``."""
        full_dataset = TemporalFastaDataset(read_fasta(self.args.data_path), self.vocab)
        valid_size = int(len(full_dataset) * self.args.valid_size)
        train_size = len(full_dataset) - valid_size
        self.train_dataset, self.val_dataset = random_split(full_dataset, [train_size, valid_size])

    def _setup_test_multisource(self):
        """Build ``sample_src_size`` multi-source batched test sets from the first test file."""
        test_records = read_fasta(self.args.test_data_paths[0])
        print(len(test_records))
        test_dataset = TemporalFastaDataset(test_records, self.vocab)
        if self.args.history_data_paths:
            history_records = self._load_history_records(downsample=False)
            print(len(history_records))
            self.history_dataset = TemporalFastaDataset(history_records, self.vocab)
        self.test_datasets = [
            build_fake_multisource_batch_dataset(
                self.history_dataset, test_dataset, self.args.batch_size, self.args.history_length,
                ignore_time=False, sort_target=True,
            )
            for _ in range(self.args.sample_src_size)
        ]

    def _setup_test(self):
        """Build the test dataset(s); keeps ``test_datasets`` if it already exists."""
        # TODO: at prediction time test_datasets has already been set; do not rebuild it.
        if hasattr(self, "test_datasets"):
            return

        def record_time(record):
            # Time stamp = last whitespace-separated token of the record's last field.
            return int(record[-1].split()[-1])

        self.test_datasets = []
        test_records = read_fasta(self.args.test_data_paths[0])
        if len(self.args.test_data_paths) > 1:
            # Keep the most recent history records (newest first).
            history_records = sorted(read_fasta(self.args.test_data_paths[1]), key=record_time, reverse=True)
            history_records = self._limit_history(history_records)
            self.history_dataset = history_records
            self.history_times = sorted({record_time(x) for x in history_records})
            self.num_history_samples = len(history_records)
        else:
            self.num_history_samples = 0
            self.history_times = []
        self.test_datasets.append(TemporalFastaDataset(test_records, self.vocab))
        if self.args.history_data_paths:
            self.history_dataset = TemporalFastaDataset(self._load_history_records(downsample=True), self.vocab)

    def _setup_predict(self):
        """Build ``predict_dataset``: each item is history context + one "gen_*" target.

        Raises:
            ValueError: if neither ``--predict_starting_time`` nor
                ``--history_data_paths`` is given.
        """
        if self.args.history_data_paths is not None:
            self.history_dataset = TemporalFastaDataset(self._load_history_records(downsample=True), self.vocab)

        if self.args.predict_starting_time is None:
            if not self.args.history_data_paths:
                raise ValueError("--predict_starting_time or --history_data_paths is required for prediction")
            # Default: start predicting at the latest time seen in the history.
            self.history_times = sorted({int(x["src_time"]) for x in self.history_dataset})
            predict_starting_time = self.history_times[-1]
        else:
            predict_starting_time = self.args.predict_starting_time

        pred_records = []
        for step in range(self.args.prediction_steps):
            step_time = predict_starting_time + step * self.args.normalize_time_a
            pred_records.extend(
                ("gen_" + str(i), "", "gen_" + str(i) + " " + str(step_time))
                for i in range(self.args.predict_sample_num)
            )
        predict_dataset = TemporalFastaDataset(pred_records, self.vocab)

        # Each item: (batch_size - 1) history records drawn with replacement, then the target.
        combined = []
        for i in range(len(predict_dataset)):
            src_index = np.random.choice(len(self.history_dataset), self.args.batch_size - 1)
            combined.append([self.history_dataset[j] for j in src_index] + [predict_dataset[i]])
        self.predict_dataset = combined

    def _setup_predict_seq2seq(self):
        """Build the seq2seq prediction dataset (see NOTE(review) below)."""
        history_dataset = TemporalFastaDataset(read_fasta(self.args.data_path), self.vocab)
        target_dataset = TemporalFastaDataset(read_fasta(self.args.test_data_paths[0]), self.vocab)  # TODO:
        if os.path.exists(self.args.pairwise_index_cache):
            # NOTE: torch.load unpickles the file; only use caches from trusted sources.
            src_index = torch.load(self.args.pairwise_index_cache)  # [len(tgt_dataset), sample_number]
        else:
            src_index = build_fake_pairwise_dataset(history_dataset, target_dataset, sample_src_size=self.args.sample_src_size)
            if self.args.pairwise_index_cache:
                torch.save(src_index, self.args.pairwise_index_cache)
        # NOTE(review): src_index and target_dataset are not used below, and the
        # already-built history dataset is wrapped in a second TemporalFastaDataset.
        # This branch looks unfinished; confirm the intended predict_dataset.
        self.predict_dataset = TemporalFastaDataset(history_dataset, self.vocab)

    # -------------------------------------------------------------- dataloaders

    def train_dataloader(self):
        """Shuffled training loader over ``train_dataset``."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True,
            pin_memory=self.args.pin_memory,
            num_workers=self.args.num_workers,
            persistent_workers=self.args.persistent_workers,
            collate_fn=self._lm_collate_fn(),
        )

    def val_dataloader(self):
        """Unshuffled validation loader over ``val_dataset``."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.args.batch_size,
            shuffle=False,
            pin_memory=self.args.pin_memory,
            num_workers=self.args.num_workers,
            persistent_workers=self.args.persistent_workers,
            collate_fn=self._lm_collate_fn(),
        )

    def test_dataloader(self, repeat_size=1, load_history=False, batched=False):
        """Return test loaders.

        Args:
            repeat_size: number of loaders created per entry of ``test_datasets``.
            load_history: if true, return a single loader over ``history_dataset``.
            batched: use ``batch_size=None`` (automatic batching disabled), e.g.
                when dataset items are already batches.

        Returns:
            A single ``DataLoader`` when ``load_history`` is true, otherwise a
            list of ``len(test_datasets) * repeat_size`` loaders.
        """
        batch_size = None if batched else self.args.batch_size

        def make_loader(dataset):
            return DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=self.args.pin_memory,
                num_workers=self.args.num_workers,
                persistent_workers=self.args.persistent_workers,
                collate_fn=self._lm_collate_fn(),
            )

        if load_history:
            return make_loader(self.history_dataset)
        return [make_loader(dataset) for dataset in self.test_datasets for _ in range(repeat_size)]

    def predict_dataloader(self):
        """Loader over ``predict_dataset``; each item is passed to the collate fn as-is."""
        return DataLoader(
            self.predict_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=self.args.num_workers,
            persistent_workers=self.args.persistent_workers,
            collate_fn=self._lm_collate_fn(),
        )

@register_dm("lm_weighted")
class ProteinLMWeightedDataModule(ProteinGISAIDDataModule):
    """GISAID LM data module that can synthesize records over a range of time bins.

    ``--set_data_properties`` (a JSON object) optionally overrides property
    values on every generated or expanded record.
    """

    def __init__(self, args, vocab=None):
        super().__init__(args)
        self.vocab = vocab if vocab is not None else load_esm_alphabet(self.args.vocab, self.args.mol_type)

    def build_predict_datasets(self, properties, properties_dict):
        """Set ``self.pred_datasets`` to a single synthetic prediction set."""
        extra_kwargs = {}
        if self.args.set_data_properties is not None:
            extra_kwargs["set_data_properties"] = json.loads(self.args.set_data_properties)
        self.pred_datasets = [self.build_predicting_set(None, properties, properties_dict, **extra_kwargs)]

    def build_predicting_set(self, pred_data_path, properties, *args, **argv):
        """Create ``generation_seq_number`` empty "gen" records per testing time bin.

        ``pred_data_path`` and extra positional args are ignored. The optional
        keyword ``set_data_properties`` (dict) is appended to every description
        as ``|key=value`` pairs.
        """
        overrides = argv.get("set_data_properties", {})
        suffix = "".join("|%s=%s" % (key, value) for key, value in overrides.items())

        fake_fasta = []
        for time in range(self.args.min_testing_time, self.args.max_testing_time + 1):
            desc = "time_bin=%d|freq=1.0" % (time) + suffix
            offset = len(fake_fasta)  # keeps "gen" ids unique across time bins
            fake_fasta.extend(
                ("gen%d" % (offset + i), "", "gen%d %s" % (offset + i, desc))
                for i in range(self.args.generation_seq_number)
            )

        pred_set = TemporalFastaDataset(fake_fasta, self.vocab, get_time_method="kw", properties=properties)
        if self.args.remap_continent:
            pred_set = self.remap_continent(pred_set)
        return pred_set

    def build_testing_set(self, test_data_path, properties):
        """Load ``test_data_path``; if a testing time range is set, expand over it.

        When both ``min_testing_time`` and ``max_testing_time`` are not -1, every
        record gets the ``set_data_properties`` overrides and is then copied
        once per time in the inclusive range with ``src_time`` set to that time.
        Otherwise the dataset is returned unchanged.
        """
        if self.args.set_data_properties is None:
            overrides = {}
        else:
            overrides = json.loads(self.args.set_data_properties)
        test_set = TemporalFastaDataset(read_fasta(test_data_path), self.vocab, get_time_method="kw", properties=properties)

        if self.args.min_testing_time == -1 or self.args.max_testing_time == -1:
            return test_set

        time_range = range(self.args.min_testing_time, self.args.max_testing_time + 1)
        extended_test_set = []
        for record in test_set:
            for key, value in overrides.items():
                record[key] = value
            for time in time_range:
                copy_ = deepcopy(record)
                copy_["src_time"] = time
                extended_test_set.append(copy_)
        return extended_test_set
         



class RandomSampler(object):
    """Draw a source for a target from a weighted candidate table.

    ``joint_prob`` maps a target id to a sequence of ``(source_id, weight)``
    pairs (e.g. the mapping returned by ``load_pairwise_weight``).
    """

    def __init__(self, joint_prob) -> None:
        self.joint_prob = joint_prob

    def sample(self, target):
        """Sample one source for ``target`` with probability proportional to its weight.

        Returns:
            ``(source_id, prob)``, where ``prob`` is the normalized weight of the
            chosen source. A target without candidates is returned as its own
            source with probability ``0.0``.
        """
        # .get: do not insert empty entries into a defaultdict, and accept plain dicts.
        candidates = self.joint_prob.get(target, ())
        if len(candidates) == 0:
            return target, 0.0
        # Force a floating dtype so integer weights are accepted by torch.multinomial.
        weights = torch.tensor([x[1] for x in candidates], dtype=torch.get_default_dtype())
        sample_id = torch.multinomial(weights, 1, replacement=True).item()
        probs = weights / torch.sum(weights)  # Normalized to one
        return candidates[sample_id][0], probs[sample_id].item()


@register_dm("sample_pair_lm")
class ProteinSamplePairLMWeightedDataModule(ProteinLMWeightedDataModule):
    """Data module serving (source, target) pairs for seq2seq-style training.

    Sources are drawn per target by :class:`RandomSampler` from the pairwise
    weights in ``--pairwise_weight_path`` and served through
    ``PairwiseRandomDataset``.
    """

    def __init__(self, args, vocab=None):
        super().__init__(args, vocab)

    def build_property_dict(self, ori_dataset, properties_dict):
        """Build value->index maps for every property in ``args.data_properties``.

        Values are indexed in order of first appearance, so the mapping is
        reproducible across processes. Iterating a ``set`` of strings depends
        on hash randomization. Properties whose only value is ``None`` are
        skipped. Results go to ``self.args.<prop>_dict`` /
        ``self.args.<prop>_list`` and ``properties_dict[prop]``.
        """
        for prop in self.args.data_properties:
            prop_list = list(dict.fromkeys(x.get(prop, None) for x in ori_dataset))
            if len(prop_list) == 1 and prop_list[0] is None:  # property absent everywhere
                continue
            prop_dict = {v: idx for idx, v in enumerate(prop_list)}
            setattr(self.args, "%s_dict" % prop, prop_dict)
            setattr(self.args, "%s_list" % prop, prop_list)
            properties_dict[prop] = prop_dict

    def set_collate_func(self, properties_dict, *args, **kwargs):
        """Use ``default_seq2seq_collate_func`` as the collate function."""
        # NOTE(review): --remove_gaps_from_source is parsed but not forwarded here.
        self.collate_fn = partial(
            default_seq2seq_collate_func,
            batch_converter=self.vocab.get_batch_converter(max_positions=self.args.max_position_embeddings),
            padding_idx=self.vocab.pad(),
            properties_dict=properties_dict,
            remove_gaps=self.args.remove_gaps,
        )

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Register the pairwise-sampling command-line options."""
        parent_parser = super(ProteinSamplePairLMWeightedDataModule, cls).add_argparse_args(parent_parser)
        parent_parser.add_argument('--pairwise_index_cache', type=str, default=None)
        parent_parser.add_argument('--pairwise_weight_path', type=str, default=None, help="A file saves the sampling weight for each pair.")
        parent_parser.add_argument('--sample_src_size', type=int, default=1)

        parent_parser.add_argument('--tgt_data_path', type=str)
        parent_parser.add_argument('--test_src_path', type=str)
        parent_parser.add_argument('--test_tgt_path', type=str)
        parent_parser.add_argument('--normalize_evolution_time_by_edit_distance', type=str2bool, default='false')

        parent_parser.add_argument('--remove_gaps', type=str2bool, default='true')
        parent_parser.add_argument('--remove_gaps_from_source', type=str2bool, default='true', help="")

        parent_parser.add_argument('--split_pairs_by_source', type=str2bool, default='false')

        return parent_parser

    def build_pairwise_dataset(self, dataset, method):
        """Pair records with sources from the preceding time bin.

        Only ``method == "sample_prev"`` is supported. Every record at time ``t``
        is paired with a source drawn, with replacement and weighted by
        ``"freq"``, from the records at time ``t - 1``. Time bins without a
        ``t - 1`` bin are skipped.

        NOTE(review): the original implementation was unfinished (it read a
        leaked loop variable and returned None); confirm the intended pairing.

        Returns:
            ``(source_index, target_index)``: equal-length lists of indices into ``dataset``.

        Raises:
            NotImplementedError: for any other ``method``.
        """
        if method != "sample_prev":
            raise NotImplementedError("unsupported pairing method: %r" % (method,))

        time2records = defaultdict(list)
        for i, record in enumerate(dataset):
            time2records[record["src_time"]].append(i)

        source_index = []
        target_index = []
        for t in sorted(time2records):
            prev_records = time2records.get(t - 1)
            if not prev_records:
                continue
            targets = time2records[t]
            freqs = torch.tensor([float(dataset[j]["freq"]) for j in prev_records])
            picks = torch.multinomial(freqs, len(targets), replacement=True)
            source_index.extend(prev_records[k] for k in picks.tolist())
            target_index.extend(targets)
        return source_index, target_index

    def load_pairwise_weight(self, pairwise_weight_path):
        """Read a CSV with ``src_id``, ``tgt_id`` and ``prob`` columns.

        Returns:
            ``defaultdict(list)`` mapping each ``tgt_id`` to a list of
            ``(src_id, weight)`` tuples in file order. Blank lines are skipped.
        """
        import csv

        target2sources = defaultdict(list)
        with open(pairwise_weight_path, newline="") as fin:
            for row in csv.DictReader(fin):
                target2sources[row["tgt_id"]].append((row["src_id"], float(row["prob"])))
        return target2sources

    def _build_sampler(self):
        """Return a RandomSampler over ``--pairwise_weight_path``; raise ValueError if unset."""
        if self.args.pairwise_weight_path is None:
            raise ValueError("pairwise_weight_path is required")
        return RandomSampler(self.load_pairwise_weight(self.args.pairwise_weight_path))

    def build_training_set(self, properties, properties_dict):
        """Build the pairwise training set over ``--data_path`` and its property maps."""
        sampler = self._build_sampler()
        dataset = TemporalFastaDataset(read_fasta(self.args.data_path), self.vocab, get_time_method="kw", properties=properties)
        full_dataset = PairwiseRandomDataset(dataset, sampler.sample)
        self.build_property_dict(dataset, properties_dict)
        return full_dataset

    def build_test_datasets(self, properties, properties_dict, *args, **argv):
        """Build ``self.test_datasets`` from ``test_data_paths == [source_fasta, target_fasta]``.

        If a testing time range is configured, every target is copied once per
        time in the inclusive range with ``src_time`` overwritten. Missing
        ``freq``/``bin_size`` values default to 1.0. The resulting dataset object is
        repeated ``sample_src_size`` times.

        Raises:
            ValueError: if no pairwise weight file is set or ``test_data_paths``
                does not contain exactly two paths.
        """
        # TODO: sample source size > 1?
        self.test_datasets = []
        sampler = self._build_sampler()
        if len(self.args.test_data_paths) != 2:
            raise ValueError(
                "expected exactly two test_data_paths (source, target), got %d" % len(self.args.test_data_paths)
            )

        src_dataset = TemporalFastaDataset(read_fasta(self.args.test_data_paths[0]), self.vocab, get_time_method="kw", properties=properties)
        tgt_dataset = TemporalFastaDataset(read_fasta(self.args.test_data_paths[1]), self.vocab, get_time_method="kw", properties=properties)

        if self.args.min_testing_time != -1 and self.args.max_testing_time != -1:
            # Overwrite the testing time.
            expanded = []
            for x in tgt_dataset:
                for t in range(self.args.min_testing_time, self.args.max_testing_time + 1):
                    x_new = deepcopy(x)
                    x_new["src_time"] = t
                    if x_new["freq"] is None:
                        x_new["freq"] = 1.0
                    if x_new["bin_size"] is None:
                        x_new["bin_size"] = 1.0
                    expanded.append(x_new)
            tgt_dataset = expanded

        full_dataset = PairwiseRandomDataset(tgt_dataset, sampler.sample, src_dataset=src_dataset)
        # The same dataset object is repeated; presumably sources are re-sampled per access.
        self.test_datasets = [full_dataset] * self.args.sample_src_size


    # def build_testing_set(self, test_data_path, properties):
        
        
        
    #     self.build_property_dict(dataset, properties_dict)

    #     history_dataset = TemporalFastaDataset(read_fasta(self.args.data_path), self.vocab, get_time_method="kw", properties=properties)
    #     test_dataset = TemporalFastaDataset(read_fasta(test_data_path), self.vocab, get_time_method="kw", properties=properties)
    #     # print(test_dataset[0])
    #     if self.args.min_testing_time != -1 and self.args.max_testing_time != -1:
    #         # overwrite the testing time
    #         test_dataset_new = []
    #         for x in test_dataset:
    #             for t in range(self.args.min_testing_time, self.args.max_testing_time+1):
    #                 x_new = deepcopy(x)
    #                 x_new["src_time"] = t
    #                 if x_new["freq"] is None:
    #                     x_new["freq"] = 1.0
    #                 if x_new["bin_size"] is None:
    #                     x_new["bin_size"] = 1.0
    #                 test_dataset_new.append(x_new)
    #         # print(test_dataset_new[0], len(test_dataset_new), len(test_dataset))
    #         test_dataset = test_dataset_new
    #     # print(test_dataset[0], test_dataset[1], test_dataset[2])

    #     if self.args.pairwise_index_cache and os.path.exists(self.args.pairwise_index_cache):
    #         src_index= torch.load(self.args.pairwise_index_cache) # [len(tgt_dataset), sample_number]
    #         pairwise_dataset, src_index = build_fake_pairwise_dataset(history_dataset, test_dataset, src_index=src_index)
    #     else:
    #         src_index, tgt_index, sample_probs = build_fake_pairwise_dataset(history_dataset, test_dataset, sample_src_size=self.args.sample_src_size)
    #         # print(src_index, tgt_index)
    #         if self.args.pairwise_index_cache:
    #             torch.save([src_index, tgt_index, sample_probs], self.args.pairwise_index_cache)
    #     source_dataset = []
    #     for (x, p) in zip(src_index.view(-1), sample_probs.view(-1)):
    #         item = history_dataset[x.item()]
    #         item["freq"] = p.item()
    #         source_dataset.append(item)
    #     target_dataset = [test_dataset[x.item()] for x in tgt_index.view(-1)]
    #     # print(target_dataset[0])
    #     # print(source_dataset[0])
    #     # print(history_dataset[src_index[0]])
    #     # print(sample_probs[0])
    #     pairwise_dataset = PairwiseDataset(source_dataset, target_dataset, aligned=True)
    #     # print(pairwise_dataset[0]["tgt_seq"] == pairwise_dataset[1]["tgt_seq"])
    #     # print(pairwise_dataset[2]["tgt_seq"] == pairwise_dataset[3]["tgt_seq"])
    #     # print(pairwise_dataset[4]["tgt_seq"] == pairwise_dataset[5]["tgt_seq"])
    #     # print(pairwise_dataset[1])

    #     # print(len(pairwise_dataset))
    #     # exit()
    #     return pairwise_dataset


@register_dm("multi_lm_weighted")
class MultiProteinLMWeightedDataModule(ProteinGISAIDDataModule):
    """GISAID LM data module over multi-gene FASTA files (``TemporalMultiFastaDataset``)."""

    def __init__(self, args, vocab=None):
        super().__init__(args)
        self.vocab = load_esm_alphabet(self.args.vocab) if vocab is None else vocab

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Register the ``--genes`` option."""
        parser = super(MultiProteinLMWeightedDataModule, cls).add_argparse_args(parent_parser)
        parser.add_argument('--genes', nargs="+", type=str, default=["ha", "na"])
        return parser

    def _read_multi_fasta(self, path, properties):
        """Load ``path`` as a ``TemporalMultiFastaDataset`` with keyword-based times."""
        return TemporalMultiFastaDataset(read_fasta(path), self.vocab, get_time_method="kw", properties=properties)

    def build_testing_set(self, test_data_path, properties):
        return self._read_multi_fasta(test_data_path, properties)

    def build_predicting_set(self, pred_data_path, properties):
        return self._read_multi_fasta(pred_data_path, properties)

    def build_training_set(self, properties, properties_dict):
        """Load ``--data_path``, register lineage/location maps and the gene vocabulary.

        The gene list is the dataset's genes followed by "<sep>", "<pad>" and,
        depending on the vocabulary, "<bos>" and "<eos>"; its index map is
        stored on ``self.gene_to_index`` and ``self.args.gene_to_index``.
        """
        full_dataset = self._read_multi_fasta(self.args.data_path, properties)
        self.load_lineage(full_dataset, properties_dict)
        self.load_location(full_dataset, properties_dict)

        special_tokens = ["<sep>", "<pad>"]
        if self.vocab.prepend_bos:
            special_tokens.append("<bos>")
        if self.vocab.append_eos:
            special_tokens.append("<eos>")
        self.genes = full_dataset.genes + special_tokens

        self.gene_to_index = {gene: idx for idx, gene in enumerate(self.genes)}
        self.args.gene_to_index = self.gene_to_index
        return full_dataset

    def set_collate_func(self, properties_dict, model_config=None):
        """Use ``multiprot_lm_collate_func``; gene map falls back to ``model_config``."""
        if getattr(self, "gene_to_index", None) is None:
            self.gene_to_index = getattr(model_config, "gene_to_index", None)
        properties_dict["genes"] = self.gene_to_index

        converter = self.vocab.get_batch_converter(max_positions=self.args.max_position_embeddings)
        self.collate_fn = partial(
            multiprot_lm_collate_func,
            batch_converter=converter,
            padding_idx=self.vocab.pad(),
            properties_dict=properties_dict,
        )

    # def setup(self, stage, model_config=None):
    #     properties_dict = {}
    #     # Load datasets
    #     self.tokenizer = self.vocab # consistent with TransformerLM
    #     properties = ['time_bin', 'freq', 'bin_size']
    #     if self.args.load_location:
    #         properties += ["loc"]
    #     if self.args.load_lineage:
    #         properties += ["lineage"]

    #     if stage == "fit" or stage is None:
    #         full_dataset = self.build_training_set(properties, properties_dict)
    #         valid_size = int(len(full_dataset) * self.args.valid_size)
    #         train_size = len(full_dataset) - valid_size
    #         self.train_dataset, self.val_dataset = random_split(full_dataset, [train_size, valid_size])

    #     if stage == "test" or stage == "predict":
    #         self.test_datasets = []
    #         for test_data_path in self.args.test_data_paths:
    #             test_dataset = self.build_testing_set(test_data_path, properties)
    #             self.test_datasets.append(test_dataset)
    #         if model_config:
    #             if self.args.load_lineage:
    #                 lineage_to_index = getattr(model_config, "lineage_to_index", None)
    #                 if lineage_to_index:
    #                     properties_dict["lineage"] = lineage_to_index
    #             if self.args.load_location:
    #                 location_to_index = getattr(model_config, "location_to_index", None) # .get()
    #                 if location_to_index:
    #                     properties_dict["loc"] = location_to_index

    #     self.set_collate_func(properties_dict)