from collections import defaultdict
import pytorch_lightning as pl
import os, torch, json, re
from functools import partial
from data.io import read_fasta, read_m8, read_newick_tree, read_msa, read_tree
from data.vocab import load_esm_alphabet
from data.datasets import TemporalFastaDataset, TemporalMSADataset, TemporalPairwiseFastaDataset, TemporalUnpairedFastaDataset, TemporalBlockedMSADataset, PairwiseDataset
from data.utils import customized_collate_func, default_lm_collate_func
from torch.utils.data import DataLoader, random_split
from data.lineage import CovLineage
from utils.args import str2bool
from copy import deepcopy
from collections.abc import Mapping
import numpy as np
import logging

class OurPropertyVocab(Mapping):
    """Read-only token -> index mapping that sends unknown tokens to ``unk_tok``.

    ``unk_tok`` is appended to ``all_toks`` when it is not already present, so a
    lookup never raises for a hashable token: unseen tokens get the index of
    ``unk_tok``. Because of that, the inherited ``Mapping.__contains__`` reports
    every hashable token as present.
    """

    def __init__(self, all_toks, unk_tok) -> None:
        self.all_toks = all_toks
        self.unk_tok = unk_tok
        if unk_tok not in all_toks:
            self.all_toks = self.all_toks + [unk_tok]
        self.tok2idx = {x: i for i, x in enumerate(self.all_toks)}

    def __iter__(self):
        return iter(self.all_toks)

    def __len__(self):
        return len(self.all_toks)

    def __getitem__(self, tok):
        if tok in self.tok2idx:
            return self.tok2idx[tok]
        return self.tok2idx[self.unk_tok]

    # ``get(key, default=None)`` is inherited from ``Mapping``: it returns
    # ``self[key]`` (falling back to ``default`` only on KeyError). The previous
    # override called ``__getitem__`` but dropped the result, returning None.


class CoutryVocab(Mapping):
    """``"continent/country"`` -> index mapping with a per-continent fallback.

    A key that is not in ``vocab`` is mapped to ``"<continent>/other_countries"``,
    where the continent is the text before the first ``/``. Keys without any
    ``/`` fail the assertion below.
    """

    def __init__(self, vocab) -> None:
        super().__init__()
        self.vocab = vocab

    def __iter__(self):
        return iter(self.vocab)

    def __len__(self):
        return len(self.vocab)

    def __getitem__(self, tok):
        assert "/" in tok, "Cannot find the continent and country information from %s" % tok
        # Split only once: keys with extra "/" segments used to crash on a
        # two-name unpack even though the assertion accepts them.
        continent = tok.split("/", 1)[0]
        if tok in self.vocab:
            return self.vocab[tok]
        return self.vocab["%s/other_countries" % continent]

    # ``get(key, default=None)`` is inherited from ``Mapping`` and returns
    # ``self[key]`` (or ``default`` on KeyError). The previous override dropped
    # the looked-up value and always returned None.

class GeneralLocationVocab(Mapping):
    """Location -> index mapping accepting ``"continent"`` or ``"continent/country"``.

    A key without ``/`` is looked up directly. A ``"continent/country"`` key is
    looked up as-is and, if absent, falls back to ``"<continent>/other_countries"``.
    Keys with more than one ``/`` fail the assertion.
    """

    def __init__(self, vocab) -> None:
        super().__init__()
        self.vocab = vocab

    def __iter__(self):
        return iter(self.vocab)

    def __len__(self):
        return len(self.vocab)

    def __getitem__(self, tok):
        parts = tok.split("/")
        if len(parts) <= 1:
            return self.vocab[tok]
        assert len(parts) == 2
        continent = parts[0]
        if tok in self.vocab:
            return self.vocab[tok]
        return self.vocab["%s/other_countries" % continent]

    # ``get(key, default=None)`` is inherited from ``Mapping``: it returns
    # ``self[key]`` or ``default`` when the lookup raises KeyError. The previous
    # override discarded the looked-up value and always returned None.

# class PropertyBalanceSampler(Sampler[int]):
#     def __init__(self, data, prop, sample_num=None, adjust_sample_num="max") -> None:
#         self.data = data
#         self.prop = prop
#         self.property_to_indices = defaultdict(list)
#         for i, x in enumerate(data):
#             self.property_to_indices[x[prop]].append(i)
        
#         if sample_num is not None:
#             self.sample_num = sample_num
#         else:
#             if adjust_sample_num == "max":
#                 self.sample_num = max([len(self.property_to_indices[prop]) for prop in self.property_to_indices])
#             elif adjust_sample_num == "min":
#                 self.sample_num = min([len(self.property_to_indices[prop]) for prop in self.property_to_indices])
#             elif adjust_sample_num == "mean":
#                 self.sample_num = len(self.data) // len(self.property_to_indices)

#     def __len__(self) -> int:
#         return self.sample_num * len(self.property_to_indices)

#     def __iter__(self): # -> Iterator[int]:
#         indices = []
#         for prop in self.property_to_indices
#         sizes = torch.tensor([len(x) for x in self.data])
#         yield from torch.argsort(sizes).tolist()
        
#         for _ in range(self.num_samples // 32):
#                 yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()


class ProteinDataModule(pl.LightningDataModule):
    """Base Lightning data module that keeps the parsed command-line ``args``.

    ``args`` is stored on ``self.args`` and recorded via ``save_hyperparameters``;
    subclasses build their datasets and dataloaders from it. ``add_argparse_args``
    declares the options shared by all protein data modules.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.save_hyperparameters(self.args)

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Register the shared data options on ``parent_parser`` and return it."""
        # (flag, add_argument keyword arguments) in registration order; the
        # order determines how the options are listed by ``--help``.
        option_specs = [
            ('--pin_memory', dict(type=str2bool, default="true")),
            ('--pre_tokenize', dict(type=str2bool, default="false", help="Tokenize sequences in the initialization?")),
            ('--cache_token', dict(type=str2bool, default="false", help="Saving the tokenization!")),
            ('--data_path', dict(type=str, default=None, nargs="+")),
            ('--test_data_paths', dict(nargs="+", type=str, default=None)),
            ('--disable_autobatch', dict(type=str2bool, default="false")),
            # TODO: decide whether this belongs with the data or the model options.
            ('--max_position_embeddings', dict(type=int, default=1024)),
            ('--vocab', dict(type=str, default="", help="If not specified, will be modified according to the model_name_or_path.")),
            ('--valid_size', dict(type=float, default=0.1)),
            ('--batch_size', dict(type=int, default=32)),
            ('--num_workers', dict(type=int, default=0)),
            ('--persistent_workers', dict(type=str2bool, default=False)),
            # Inputs for prediction / testing.
            ('--pred_data_paths', dict(nargs="+", default="", type=str)),
            ('--predict_src_file', dict(default="", type=str)),
            ('--predict_tgt_file', dict(default="", type=str)),
            ('--source_sample_num', dict(default=1, type=int)),
            ('--predict_sample_num', dict(default=1, type=int)),
            # Alphabet / molecule type.
            ('--mol_type', dict(type=str, default="protein", choices=["dna_codon", "rna_codon", "protein"])),
        ]
        for flag, options in option_specs:
            parent_parser.add_argument(flag, **options)
        return parent_parser

class PairwiseProteinDataModule(ProteinDataModule):
    """Data module for paired source/target FASTA files.

    ``setup`` builds, per Lightning stage:

    * ``"fit"``/``None``: ``train.src.fasta`` + ``train.tgt.fasta`` from
      ``args.data_dir`` as a ``TemporalPairwiseFastaDataset``, randomly split into
      ``train_dataset``/``val_dataset`` using ``args.valid_size``;
    * ``"test"``: one pairwise dataset per directory in ``args.test_data_dirs``
      (``test.src.fasta`` + ``test.tgt.fasta``), stored in ``test_datasets``;
    * ``"predict"``: ``predict_src_dataset`` from ``args.predict_src_file`` and,
      when ``args.predict_tgt_file`` is set, ``predict_tgt_dataset``;
    * ``"test_population"``: ``test_population_dataset`` built from
      ``predict_src_file``/``predict_tgt_file`` as an unpaired dataset.
    """

    def __init__(self, args):
        super().__init__(args)

    def prepare_data(self):
        # Lightning calls this on a single process only, so no state is
        # assigned here (it would be missing on the other DDP processes).
        pass

    def _read_pairwise(self, data_dir, split, **dataset_kwargs):
        """Load ``<split>.src.fasta``/``<split>.tgt.fasta`` from ``data_dir`` as a pairwise dataset."""
        src_dataset = read_fasta(os.path.join(data_dir, "%s.src.fasta" % split))
        tgt_dataset = read_fasta(os.path.join(data_dir, "%s.tgt.fasta" % split))
        return TemporalPairwiseFastaDataset(src_dataset, tgt_dataset, self.vocab, **dataset_kwargs)

    def setup(self, stage):
        # Called on every process in DDP; builds the datasets for ``stage``.
        self.vocab = load_esm_alphabet(self.args.vocab, self.args.mol_type)

        if stage == "fit" or stage is None:
            # NOTE(review): ``args.data_dir`` is not declared by
            # ProteinDataModule.add_argparse_args; it must be supplied elsewhere.
            full_dataset = self._read_pairwise(self.args.data_dir, "train")
            valid_size = int(len(full_dataset) * self.args.valid_size)
            train_size = len(full_dataset) - valid_size
            self.train_dataset, self.val_dataset = random_split(full_dataset, [train_size, valid_size])

        if stage == "test":
            # NOTE(review): the base class declares ``--test_data_paths``, not
            # ``test_data_dirs``; confirm where this attribute comes from.
            self.test_datasets = [
                self._read_pairwise(test_data_dir, "test", get_time_method="kw")
                for test_data_dir in self.args.test_data_dirs
            ]

        if stage == "predict":
            # predict_dataloader reads ``predict_src_dataset`` / ``predict_tgt_dataset``,
            # so those are the attributes that must be populated here.
            src_dataset = read_fasta(self.args.predict_src_file)
            self.predict_src_dataset = TemporalFastaDataset(src_dataset, self.vocab)
            # Also exposed under the ``predict_dataset`` name.
            self.predict_dataset = self.predict_src_dataset
            predict_tgt_file = getattr(self.args, "predict_tgt_file", "")
            if predict_tgt_file:
                tgt_dataset = read_fasta(predict_tgt_file)
                self.predict_tgt_dataset = TemporalFastaDataset(tgt_dataset, self.vocab)

        if stage == "test_population":
            src_dataset = read_fasta(self.args.predict_src_file)
            tgt_dataset = read_fasta(self.args.predict_tgt_file)
            self.test_population_dataset = TemporalUnpairedFastaDataset(
                src_dataset, tgt_dataset, self.vocab,
                source_sample_num=self.args.source_sample_num,
            )
            logging.info("test_population_dataset size: %d", len(self.test_population_dataset))

    def _make_loader(self, dataset, shuffle, batch_size=None, **collate_kwargs):
        """Build a DataLoader with the module's worker settings and collate function.

        ``collate_kwargs`` are forwarded to ``customized_collate_func``.
        """
        return DataLoader(
            dataset,
            batch_size=batch_size if batch_size is not None else self.args.batch_size,
            shuffle=shuffle,
            num_workers=self.args.num_workers,
            # DataLoader raises ValueError for persistent_workers=True with
            # num_workers == 0, so only request it when workers exist.
            persistent_workers=bool(self.args.persistent_workers) and self.args.num_workers > 0,
            collate_fn=partial(customized_collate_func, batch_converter=self.vocab.get_batch_converter(), **collate_kwargs),
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        # NOTE(review): only ``test_population_dataset`` (from
        # setup("test_population")) is served; ``test_datasets`` from
        # setup("test") are not used here — confirm this is intended.
        return [self._make_loader(self.test_population_dataset, shuffle=False, aligned=False)]

    def predict_dataloader(self, source="src", remove_gaps=True, batch_size=None):
        # Monolingual: serves either the source or the target prediction set.
        dataset = self.predict_src_dataset if source == "src" else self.predict_tgt_dataset
        return self._make_loader(dataset, shuffle=False, batch_size=batch_size, remove_gaps=remove_gaps)

    def teardown(self, stage, *args, **kwargs):
        # Clean up after fit or test; called on every process in DDP.
        pass

class ProteinGISAIDDataModule(ProteinDataModule):
    def __init__(self, args, vocab=None):
        """Load the continent -> countries mapping(s) and build ``country_to_continent``.

        ``args.continent_to_country_mapping_file`` is read as JSON and iterated as
        ``(continent, countries)`` pairs; ``"other_countries"`` is appended to every
        continent's country list. When the file name contains ``"truncated"``, the
        full mapping (the part of the path before ``"_truncated"`` plus ``".json"``)
        is loaded the same way into ``continent_to_country_full`` and used to build
        ``country_to_continent``; otherwise the loaded mapping itself is used.

        Args:
            args: parsed command-line namespace.
            vocab: optional vocabulary; stored as ``self.vocab`` when given.
        """
        super().__init__(args)
        if vocab is not None:
            # Previously accepted but ignored; other methods read ``self.vocab``.
            self.vocab = vocab

        mapping_file = args.continent_to_country_mapping_file
        self.continent_to_country = self._read_continent_mapping_file(mapping_file)
        if "truncated" in mapping_file:
            full_mapping_file = mapping_file.split("_truncated")[0] + ".json"
            self.continent_to_country_full = self._read_continent_mapping_file(full_mapping_file)
            lookup_source = self.continent_to_country_full
        else:
            lookup_source = self.continent_to_country

        # A country listed under several continents (notably the appended
        # "other_countries") keeps the LAST continent, as with a plain loop.
        self.country_to_continent = {
            country: continent
            for continent, countries in lookup_source
            for country in countries
        }

    @staticmethod
    def _read_continent_mapping_file(path):
        """Read a JSON list of ``(continent, countries)`` pairs, appending "other_countries" to each list."""
        # ``with`` closes the handle (json.load(open(...)) leaked it).
        with open(path, encoding="utf-8") as handle:
            mapping = json.load(handle)
        for _continent, countries in mapping:
            countries.append("other_countries")
        return mapping
        
    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Add the GISAID-specific options on top of the shared data options.

        Returns the parser produced by the base class's ``add_argparse_args``.
        """
        parser = super(ProteinGISAIDDataModule, cls).add_argparse_args(parent_parser)
        add_option = parser.add_argument
        add_option('--generation_seq_number', type=int, default=1)
        add_option('--continent_to_country_mapping_file', type=str, default="data/data_modules/continent2countries_minCnt1000.json")
        add_option('--split_valid_set_by_time', type=str2bool, default="false")
        add_option('--data_properties', nargs="+", type=str, default=[], help="What kind of information is stored in dataset.")
        add_option('--set_data_properties', type=str, default=None, help="Set up the property values when generating.")
        # Debug-only switch; "shuffle_property" shuffles the first data
        # property's labels across the training records.
        add_option('--debug_mess_data_property', type=str, default=None)
        # Balance the data count per property value.
        add_option('--property_weighted_random_sampler', type=str2bool, default="false")
        add_option('--remap_continent', type=str2bool, default="false")
        return parser
    
    def reset_testing_time(self, test_set, time_key="src_time"):
        """Replicate every test item across the configured time range.

        When both ``args.min_testing_time`` and ``args.max_testing_time`` differ
        from -1, each item is deep-copied once per time in
        ``[min_testing_time, max_testing_time]`` (inclusive) with ``time_key`` set
        to that time, item-major order. Otherwise ``test_set`` is returned as-is.
        """
        args = self.args
        if args.min_testing_time == -1 or args.max_testing_time == -1:
            return test_set

        first_time, last_time = args.min_testing_time, args.max_testing_time

        def _at_time(item, time_value):
            # Deep copy so the per-time copies do not share mutable state.
            clone = deepcopy(item)
            clone[time_key] = time_value
            return clone

        return [
            _at_time(item, time_value)
            for item in test_set
            for time_value in range(first_time, last_time + 1)
        ]

    # def load_lineage(self, full_dataset, properties_dict):
    #     if self.args.load_lineage:
    #     # self.setup_properties(full_dataset, properties_dict, "lineage", "lineage_cls")
    #         self.lineages = list(set([x["lineage"] for x in full_dataset]))
    #         self.lineage_to_index = {loc: idx for idx, loc in enumerate(self.lineages)}
    #         if not (len(self.lineages) == 1 and self.lineages[0] is None):
    #             self.lineage_cls = CovLineage(self.lineages, full_dataset)
    #             properties_dict["lineage"] = self.lineage_to_index
    #             setattr(self.args, "lineage_cls", self.lineage_cls)
    #             setattr(self.args, "lineage_to_index", self.lineage_to_index)
    
    # def load_location(self, full_dataset, properties_dict):
    #     if self.args.load_location:
    #         # self.setup_properties(full_dataset, properties_dict, "loc", "locations")
    #         self.locations = list(set([x["loc"] for x in full_dataset]))
    #         self.location_to_index = {loc: idx for idx, loc in enumerate(self.locations)}
    #         if not (len(self.locations) == 1 and self.locations[0] is None):
    #             properties_dict["loc"] = self.location_to_index
    #             setattr(self.args, "location_to_index", self.location_to_index)
        
    def build_property_dict(self, full_dataset, properties_dict):
        """Build a value -> index vocabulary for every property in ``args.data_properties``.

        The distinct values of each property are collected from ``full_dataset``
        (records lacking the key contribute ``None``). A property whose only value
        is ``None`` is skipped. Otherwise its vocabulary is:

        * ``continent``: the continents of ``self.continent_to_country`` in order;
        * ``country`` with ``continent`` also present: each country's index within
          its continent; without ``continent``: sorted global index of all countries;
        * ``location`` with ``continent`` present: ``"continent/country"`` keys mapped
          to the within-continent index, wrapped in ``CoutryVocab``;
        * ``location`` without ``continent``: if the first observed value has the
          form ``"continent/country"``, a global index over all configured pairs
          wrapped in ``GeneralLocationVocab``; otherwise the sorted observed values;
        * anything else: the sorted observed values.

        The vocabulary and value list are stored on ``self.args`` as ``<prop>_dict``
        and ``<prop>_list``, and the vocabulary is written into ``properties_dict``
        (mutated in place).
        """
        observed = defaultdict(set)
        for data in full_dataset:
            for prop in self.args.data_properties:
                observed[prop].add(data.get(prop, None))

        for prop, values in observed.items():
            # Sort None last: a plain sort raised TypeError when some records
            # lacked the property (None vs. str comparison). Non-None values keep
            # their natural order.
            prop_list = sorted(values, key=lambda v: (v is None, v))
            if len(prop_list) == 1 and prop_list[0] is None:
                continue  # the property never occurs in the data

            if prop == "continent":
                # Use the configured continent order, not the observed values.
                prop_list = [entry[0] for entry in self.continent_to_country]
                prop_dict = {v: idx for idx, v in enumerate(prop_list)}
            elif prop == "country" and "continent" in observed:
                # Index WITHIN the continent; a name listed under several
                # continents (e.g. "other_countries") keeps the last index.
                prop_list = self.continent_to_country  # nested [continent, countries] list
                prop_dict = {
                    country: idx
                    for _continent, countries in self.continent_to_country
                    for idx, country in enumerate(countries)
                }
                setattr(self.args, "contient2country", self.continent_to_country)
            elif prop == "country":
                # Global index over all configured countries, sorted.
                prop_list = sorted(country for entry in self.continent_to_country for country in entry[1])
                prop_dict = {v: idx for idx, v in enumerate(prop_list)}
            elif prop == "location" and "continent" in observed:
                # "continent/country" -> index WITHIN the continent.
                prop_list = self.continent_to_country  # nested [continent, countries] list
                prop_dict = CoutryVocab({
                    "%s/%s" % (continent, country): idx
                    for continent, countries in self.continent_to_country
                    for idx, country in enumerate(countries)
                })
                setattr(self.args, "contient2country", self.continent_to_country)
            elif prop == "location" and len(prop_list[0].split("/")) == 2:
                # continent+country locations: global index over all pairs.
                prop_list = [
                    "%s/%s" % (continent, country)
                    for continent, countries in self.continent_to_country
                    for country in countries
                ]
                prop_dict = GeneralLocationVocab({v: idx for idx, v in enumerate(prop_list)})
                setattr(self.args, "contient2country", self.continent_to_country)
            else:
                prop_dict = {v: idx for idx, v in enumerate(prop_list)}

            setattr(self.args, "%s_dict" % prop, prop_dict)
            setattr(self.args, "%s_list" % prop, prop_list)
            properties_dict[prop] = prop_dict

    def remap_continent(self, dataset):
        """Recompute each record's continent from its country via ``self.country_to_continent``.

        Countries missing from the lookup are treated as ``"other_countries"``,
        whose continent is ``self.country_to_continent["other_countries"]``.

        * Records with ``"location"``: the second ``/``-separated field is the
          country; ``location`` becomes ``"<continent>/<country>"`` and, if present,
          ``continent`` is overwritten with the new continent.
        * Otherwise, records with both ``"country"`` and ``"continent"``: only
          ``continent`` is recomputed.
        * All other records pass through unchanged.

        Records are modified in place; a new list holding all of them, in the
        original order, is returned.
        """
        lookup = self.country_to_continent

        def resolve(country):
            # Map an unknown country onto the catch-all entry.
            if country not in lookup:
                country = "other_countries"
            return lookup[country], country

        remapped = []
        for record in dataset:
            if "location" in record:
                continent, country = resolve(record["location"].split("/")[1])
                record["location"] = "%s/%s" % (continent, country)
                if "continent" in record:
                    record["continent"] = continent
            elif "country" in record and "continent" in record:
                record["continent"] = resolve(record["country"])[0]
            remapped.append(record)
        return remapped

    def build_training_set(self, properties, properties_dict):
        """Read every FASTA file in ``args.data_path`` into a single training dataset.

        Records from all files are concatenated and wrapped in
        ``TemporalFastaDataset`` (``get_time_method="kw"``, extracting
        ``properties``). Property vocabularies are then built into
        ``properties_dict`` (mutated in place). With ``args.remap_continent`` the
        records are remapped via ``remap_continent`` (the result is a list). With
        ``args.debug_mess_data_property == "shuffle_property"`` the first data
        property's values are randomly permuted across records (debug control).

        Returns:
            The dataset, or a list of records after remapping/shuffling.
        """
        records = []
        for data_path in self.args.data_path:
            records.extend(read_fasta(data_path))

        full_dataset = TemporalFastaDataset(records, self.vocab, get_time_method="kw", properties=properties)
        self.build_property_dict(full_dataset, properties_dict)

        if self.args.remap_continent:
            full_dataset = self.remap_continent(full_dataset)

        debug_mode = self.args.debug_mess_data_property
        if debug_mode == "shuffle_property":
            full_dataset = self._shuffle_property_labels(full_dataset, self.args.data_properties[0])
        elif debug_mode is not None:
            # Previously an unrecognised mode was silently ignored.
            logging.warning("Unknown debug_mess_data_property %r; ignoring it.", debug_mode)
        return full_dataset

    @staticmethod
    def _shuffle_property_labels(dataset, data_property):
        """Randomly permute ``data_property`` values across records (in place); return them as a list."""
        items = list(dataset)
        labels = [item[data_property] for item in items]
        # np.random.permutation(n) == arange(n) shuffled with the global RNG.
        order = np.random.permutation(len(items))
        for item, source_idx in zip(items, order):
            item[data_property] = labels[source_idx]
        return items
    
    def build_testing_set(self, test_data_path, properties):
        """Load one test FASTA file as a ``TemporalFastaDataset``.

        Time is parsed with ``get_time_method="kw"`` and ``properties`` are
        extracted, as for the training set. No continent remapping is applied.
        """
        records = read_fasta(test_data_path)
        return TemporalFastaDataset(
            records,
            self.vocab,
            get_time_method="kw",
            properties=properties,
        )

    def set_collate_func(self, properties_dict, *args, **kwargs):
        """Set ``self.collate_fn`` to ``default_lm_collate_func`` bound to this vocab.

        The batch converter is capped at ``args.max_position_embeddings``
        positions, padding uses ``self.vocab.pad()``, and ``properties_dict`` is
        forwarded. Extra positional/keyword arguments are accepted and ignored.
        """
        converter = self.vocab.get_batch_converter(max_positions=self.args.max_position_embeddings)
        pad_index = self.vocab.pad()
        self.collate_fn = partial(
            default_lm_collate_func,
            batch_converter=converter,
            padding_idx=pad_index,
            properties_dict=properties_dict,
        )

    # def setup_properties(self, dataset, properties_dict, property_key, property_name):
    #     # e.g., key = loc, property_name = location
    #     setattr(self, property_name, list(set([x[property_key] for x in dataset])))
    #     # self.locations = list(set([x[key] for x in dataset]))
    #     # self.location_to_index = {loc: idx for idx, loc in enumerate(self.locations)}
    #     setattr(self, "%s_to_index" % property_name, list(set([x[property_key] for x in dataset])))
    #     if not (len(self.locations) == 1 and self.locations[0] is None):
    #         properties_dict[property_key] = getattr(self, "%s_to_index" % property_name) # self.location_to_index
    #         # properties_dict["loc"] = self.location_to_index
    #         setattr(self.args, "%s_to_index" % property_key, properties_dict[property_key])

    def setup_properties(self):
        """Return the record fields to load: time fields followed by ``args.data_properties``."""
        time_fields = ['time_bin', 'freq', 'bin_size']
        return [*time_fields, *self.args.data_properties]

    def calc_total_sample_count(self, train_set):
        """Return the number of samples represented by ``train_set``.

        Each record contributes ``round(bin_size * freq)``; the total is logged at
        INFO level before being returned.
        """
        total_sample_count = sum(
            round(record["bin_size"] * record["freq"]) for record in train_set
        )
        logging.info("total_sample_count: %d" % total_sample_count)
        return total_sample_count
    
        
    def load_properties_from_config(self, model_config, properties_dict):
        """Copy property vocabularies stored on ``model_config`` into ``properties_dict``.

        For each name in ``model_config.data_properties`` (empty when the attribute
        is absent), ``model_config.<prop>_dict`` is copied into ``properties_dict``.
        Nothing happens when ``model_config`` is falsy.

        Returns:
            ``properties_dict`` (also mutated in place).

        Raises:
            AttributeError: if a listed property has no ``<prop>_dict`` attribute.
        """
        if model_config:
            for prop in getattr(model_config, "data_properties", []):
                # Log instead of a stray debug print.
                logging.info("Loading property vocabulary from model config: %s", prop)
                properties_dict[prop] = getattr(model_config, "%s_dict" % prop)
        return properties_dict

    def setup(self, stage, model_config=None):
        """Prepare datasets and the collate function for a Lightning ``stage``.

        * ``"fit"``/``None``: build the training data with ``build_training_set``
          (which also fills ``properties_dict``), split it into
          ``train_dataset``/``val_dataset`` and store
          ``total_sample_count_train``/``total_sample_count_valid``.
        * ``"test"``/``"predict"``: copy property vocabularies from
          ``model_config`` and build the test or prediction datasets.

        In every case ``self.collate_fn`` is rebuilt from ``properties_dict``.

        Args:
            stage: Lightning stage name (``"fit"``, ``"test"``, ``"predict"``) or None.
            model_config: optional config carrying ``data_properties`` and
                ``<prop>_dict`` attributes; used for test/predict.
        """
        properties_dict = {}
        self.tokenizer = self.vocab  # consistent with TransformerLM
        properties = self.setup_properties()

        if stage == "fit" or stage is None:
            full_dataset = self.build_training_set(properties, properties_dict)
            self.train_dataset, self.val_dataset = self._split_train_valid(full_dataset)
            self.total_sample_count_train = self.calc_total_sample_count(self.train_dataset)
            self.total_sample_count_valid = self.calc_total_sample_count(self.val_dataset)

        if stage == "test" or stage == "predict":
            properties_dict = self.load_properties_from_config(model_config, properties_dict)

        if stage == "test":
            self.build_test_datasets(properties, properties_dict)

        if stage == "predict":
            self.build_predict_datasets(properties, properties_dict)

        self.set_collate_func(properties_dict, model_config)

    def _split_train_valid(self, full_dataset):
        """Return ``(train, valid)`` derived from ``build_training_set``'s output."""
        # A pre-split (train, valid) pair is passed through. Records are
        # presumably dicts (this class reads them with ``.get``/``[...]``), so
        # requiring non-Mapping elements stops a training set of exactly two
        # records from being mistaken for such a pair.
        if len(full_dataset) == 2 and not any(isinstance(part, Mapping) for part in full_dataset):
            return full_dataset[0], full_dataset[1]

        if self.args.split_valid_set_by_time:
            # Hold out the latest ``valid_size`` fraction of distinct src_time
            # values (at least one) for validation.
            all_times = sorted({x["src_time"] for x in full_dataset})
            valid_size = max(int(len(all_times) * self.args.valid_size), 1)
            train_size = len(all_times) - valid_size
            # Sets give O(1) membership instead of scanning the time lists.
            train_times = set(all_times[:train_size])
            valid_times = set(all_times[train_size:])
            train = [x for x in full_dataset if x["src_time"] in train_times]
            valid = [x for x in full_dataset if x["src_time"] in valid_times]
            return train, valid

        valid_size = int(len(full_dataset) * self.args.valid_size)
        train_size = len(full_dataset) - valid_size
        train, valid = random_split(full_dataset, [train_size, valid_size])
        return train, valid
    
    def build_predict_datasets(self, properties, properties_dict):
        """Build one prediction dataset per entry of ``args.pred_data_paths``.

        Each dataset comes from ``build_predicting_set`` and is passed through
        ``remap_continent`` when ``args.remap_continent`` is truthy. The
        results are stored in path order in ``self.pred_datasets``.
        """
        collected = self.pred_datasets = []
        for path in self.args.pred_data_paths:
            dataset = self.build_predicting_set(path, properties, properties_dict)
            collected.append(
                self.remap_continent(dataset) if self.args.remap_continent else dataset
            )

    def build_test_datasets(self, properties, properties_dict, *args, **argv):
        """Build one test dataset per entry of ``args.test_data_paths``.

        Each dataset comes from ``build_testing_set(path, properties)`` and is
        passed through ``remap_continent`` when ``args.remap_continent`` is
        truthy. The results are stored in path order in ``self.test_datasets``.
        ``properties_dict`` and any extra arguments are accepted but not
        forwarded.
        """
        collected = self.test_datasets = []
        for path in self.args.test_data_paths:
            dataset = self.build_testing_set(path, properties)
            collected.append(
                self.remap_continent(dataset) if self.args.remap_continent else dataset
            )

    def build_predicting_set(self, pred_data_path, properties, *args, **argv):
        """Load a FASTA file as a ``TemporalFastaDataset`` for prediction.

        Args:
            pred_data_path: Path of the FASTA file passed to ``read_fasta``.
            properties: Property names forwarded to ``TemporalFastaDataset``.
            *args, **argv: Accepted but ignored. ``build_predict_datasets``
                passes ``properties_dict`` here positionally.

        Returns:
            A ``TemporalFastaDataset`` built with ``get_time_method="kw"``.
        """
        records = read_fasta(pred_data_path)
        return TemporalFastaDataset(
            records,
            self.vocab,
            get_time_method="kw",
            properties=properties,
        )

    def build_property_weighted_random_sampler(self,):
        """Return a sampler that rebalances training samples by their first data property.

        Each sample's weight is ``len(train_dataset) / n``, where ``n`` is the
        number of training samples sharing its value of
        ``args.data_properties[0]``. Rarer values are therefore drawn more
        often. One epoch draws ``len(weights)`` indices with replacement.

        Returns:
            ``torch.utils.data.WeightedRandomSampler`` over ``self.train_dataset``.
        """
        prop_name = self.args.data_properties[0]
        # Read every sample exactly once. Items may be produced lazily (e.g.
        # through a random_split Subset), so two passes would fetch each twice.
        keys = [data[prop_name] for data in self.train_dataset]

        property_size = defaultdict(int)
        for key in keys:
            property_size[key] += 1

        total = len(self.train_dataset)
        weights = [total / property_size[key] for key in keys]
        return torch.utils.data.WeightedRandomSampler(
            weights,
            len(weights),
            replacement=True,
        )
        

    def train_dataloader(self, ):
        """Return the DataLoader over ``self.train_dataset``.

        When ``args.property_weighted_random_sampler`` is set, indices come from
        ``build_property_weighted_random_sampler`` and shuffling is off,
        because DataLoader rejects ``shuffle=True`` combined with a sampler.
        Otherwise the data is reshuffled every epoch.
        """
        if self.args.property_weighted_random_sampler:
            sampler = self.build_property_weighted_random_sampler()
            shuffle = False
        else:
            sampler = None
            shuffle = True

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.batch_size,
            shuffle=shuffle,
            pin_memory=self.args.pin_memory,
            num_workers=self.args.num_workers,
            # DataLoader raises ValueError for persistent_workers=True with
            # num_workers=0, so only request it when worker processes exist.
            persistent_workers=bool(self.args.persistent_workers) and self.args.num_workers > 0,
            collate_fn=self.collate_fn,
            sampler=sampler,
        )

    def val_dataloader(self):
        """Return a non-shuffled DataLoader over ``self.val_dataset``.

        Uses the same batch size, memory pinning, worker and collate settings
        as the train, test and predict loaders.
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.args.batch_size,
            shuffle=False,
            # Kept consistent with the other loaders, which all honour pin_memory.
            pin_memory=self.args.pin_memory,
            num_workers=self.args.num_workers,
            # DataLoader raises ValueError for persistent_workers=True with
            # num_workers=0, so only request it when worker processes exist.
            persistent_workers=bool(self.args.persistent_workers) and self.args.num_workers > 0,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self, test_datasets=None, repeat_size=1, load_history=False, batched=False):
        test_loaders = []
        if test_datasets is None:
            test_datasets = self.test_datasets
            
        for test_dataset in test_datasets:
            test_loaders.append(
            DataLoader(
            test_dataset, 
            batch_size=self.args.batch_size, 
            shuffle=False, 
            pin_memory=self.args.pin_memory, 
            num_workers=self.args.num_workers, 
            persistent_workers=self.args.persistent_workers,
            collate_fn=self.collate_fn
            ))
        return test_loaders
    
    def predict_dataloader(self,):
        """Return one non-shuffled DataLoader per dataset in ``self.pred_datasets``.

        Returns:
            A list of DataLoaders in the same order as ``self.pred_datasets``.
        """
        # DataLoader raises ValueError for persistent_workers=True with
        # num_workers=0, so only request it when worker processes exist.
        persistent = bool(self.args.persistent_workers) and self.args.num_workers > 0
        return [
            DataLoader(
                pred_dataset,
                batch_size=self.args.batch_size,
                shuffle=False,
                pin_memory=self.args.pin_memory,
                num_workers=self.args.num_workers,
                persistent_workers=persistent,
                collate_fn=self.collate_fn,
            )
            for pred_dataset in self.pred_datasets
        ]

if __name__ == "__main__":
    # Parse the data-module CLI arguments and echo them as a dict.
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser = ProteinDataModule.add_argparse_args(parser)
    args = parser.parse_args()
    # argparse.Namespace is not iterable, so dict(args) raises TypeError;
    # vars() returns its attribute dictionary.
    print(vars(args))