from collections import defaultdict,Counter
import torch
import logging, csv, os
import numpy as np
from data.io import read_fasta
from tqdm import tqdm
from esm import inverse_folding
import biotite.structure.io as bsio
from data.utils import calc_rmsd, RMSDcalculator
from Bio import pairwise2


def get_time_simple(desc):
    """Parse the last whitespace-separated token of ``desc`` as a float.

    Raises IndexError when ``desc`` has no tokens and ValueError when the
    last token is not numeric.
    """
    last_token = desc.rsplit(maxsplit=1)[-1]
    return float(last_token)

def get_time_kw(desc, key="day"):
    """Return ``key``'s value as a float from a ``"k1=v1|k2=v2"`` string.

    Each ``|`` segment is split on ``=``; the value is the text between the
    first and second ``=``. A later duplicate key overrides an earlier one.
    Raises KeyError if ``key`` is absent and IndexError for a segment
    without ``=``.
    """
    fields = {}
    for segment in desc.split("|"):
        pieces = segment.split("=")
        fields[pieces[0]] = pieces[1]
    return float(fields[key])

def get_value_from_desc(desc, key="day"):
    """Look up ``key`` in a ``"k1=v1|k2=v2|..."`` description string.

    Args:
        desc: ``|``-separated ``name=value`` fields. Segments without ``=``
            (e.g. produced by a trailing ``|``) are ignored instead of
            raising, and a value keeps any further ``=`` characters.
        key: field name to look up.

    Returns:
        The value as a ``float`` when it parses as one, otherwise the raw
        string; ``None`` when ``desc`` is empty/None or ``key`` is absent.
    """
    if not desc:
        return None
    fields = {}
    for segment in desc.split("|"):
        # partition splits on the first "=" only, so "note=a=b" -> "a=b".
        name, sep, value = segment.partition("=")
        if sep:
            fields[name] = value
    if key not in fields:
        return None
    raw = fields[key]
    try:
        return float(raw)
    except ValueError:
        # Non-numeric fields (e.g. locations, lineages) are returned verbatim.
        return raw


class TemporalPairwiseFastaDataset(torch.utils.data.Dataset):
    """Parallel (source, target) sequence pairs with optional time stamps.

    ``src_dataset`` and ``tgt_dataset`` are parallel lists of
    ``(id, seq, desc)`` records (``seq`` may be a string or a token tensor);
    item ``i`` pairs ``src_dataset[i]`` with ``tgt_dataset[i]``.
    """

    def __init__(self, src_dataset, tgt_dataset, vocab, get_time_method="simple", sample_weights=None) -> None:
        """
        Args:
            src_dataset: source records ``(id, seq, desc)``.
            tgt_dataset: target records, parallel to ``src_dataset``.
            vocab: vocabulary object; stored on the instance, not used here.
            get_time_method: ``"simple"`` or ``"kw"`` selects the time parser.
                Any other value disables time parsing (a warning is logged),
                and ``__getitem__`` then omits the time and weight fields.
            sample_weights: optional per-item weights, indexed like the data.
        """
        super().__init__()
        self.src_dataset = src_dataset
        self.tgt_dataset = tgt_dataset
        self.sample_weights = sample_weights
        self.vocab = vocab
        if get_time_method == "simple":
            self.get_time_func = get_time_simple
        elif get_time_method == "kw":
            self.get_time_func = get_time_kw
        else:
            self.get_time_func = None
            # Report the rejected method name (get_time_func is None here).
            logging.warning("Wrong get_time_method : %s.", get_time_method)

    @staticmethod
    def _pad_records(records, max_len, pad_idx):
        """Return ``(id, padded_tokens, desc)`` records right-padded to ``max_len``."""
        padded = []
        for record in records:
            # new_full keeps the dtype/device of the original token tensor.
            tokens = record[1].new_full((max_len,), pad_idx)
            tokens[: len(record[1])] = record[1]
            padded.append((record[0], tokens, record[-1]))
        return padded

    def padding_all(self, pad_idx):
        """Right-pad every source and target token tensor to one common length.

        The common length is the longest sequence over *both* sides, so a
        target longer than every source still fits. Requires tensor
        sequences. This is a no-op when both datasets are empty.
        """
        lengths = [record[1].size(0) for record in self.src_dataset]
        lengths += [record[1].size(0) for record in self.tgt_dataset]
        if not lengths:
            return
        max_len = max(lengths)
        self.src_dataset = self._pad_records(self.src_dataset, max_len, pad_idx)
        self.tgt_dataset = self._pad_records(self.tgt_dataset, max_len, pad_idx)

    def __len__(self):
        return len(self.src_dataset)

    def __getitem__(self, index):
        """Return the ``index``-th pair as a dict.

        Times are parsed from the second whitespace-separated token of each
        record's ``desc``. ``src_time``, ``tgt_time`` and ``weights`` are only
        present when a time parser is configured.
        """
        src, tgt = self.src_dataset[index], self.tgt_dataset[index]
        if self.get_time_func is None:
            return {
                "src_seq": src[1],
                "tgt_seq": tgt[1],
                "index": index,
            }
        return {
            "src_seq": src[1],
            "tgt_seq": tgt[1],
            "src_time": self.get_time_func(src[-1].split()[1]),
            "tgt_time": self.get_time_func(tgt[-1].split()[1]),
            "index": index,
            "weights": self.sample_weights[index] if self.sample_weights is not None else 1.0,
        }

class TemporalUnpairedFastaDataset(torch.utils.data.Dataset):
    """Unpaired source/target data in which each target is matched with random sources.

    ``src_dataset`` and ``tgt_dataset`` are lists of ``(id, seq, desc)`` records.
    Every target appears ``source_sample_num`` times. Each appearance is paired
    with a source index drawn uniformly at random, with replacement, once at
    construction time.
    """

    def __init__(self, src_dataset, tgt_dataset, vocab, get_time_method="simple", source_sample_num=1, src_enc_out=None) -> None:
        super().__init__()
        self.src_dataset = src_dataset
        self.tgt_dataset = tgt_dataset
        self.source_sample_num = source_sample_num  # sources sampled per target
        self.vocab = vocab
        # Optional per-source extra outputs, indexed like src_dataset.
        self.src_enc_out = src_enc_out

        if get_time_method == "simple":
            self.get_time_func = get_time_simple
        elif get_time_method == "kw":
            self.get_time_func = get_time_kw
        else:
            raise ValueError("Please set the right get_time_method.")

        # Tensor of shape [tgt_size, source_sample_num] holding source indices.
        self.tgt_to_src_index = self.build_tgt_to_src_mapping(
            len(self.src_dataset), len(self.tgt_dataset), source_sample_num
        )

    def build_tgt_to_src_mapping(self, src_size, tgt_size, num_samples):
        """Sample ``num_samples`` source indices per target, uniformly with replacement.

        Returns an int64 tensor of shape ``(tgt_size, num_samples)`` with
        values in ``[0, src_size)``. ``torch.randint`` gives the same uniform
        distribution as ``torch.multinomial`` over a ``tgt_size x src_size``
        matrix of ones, without allocating that O(tgt_size * src_size) matrix.

        Raises:
            ValueError: if ``src_size`` is not positive (nothing to sample from).
        """
        if src_size <= 0:
            raise ValueError("Cannot sample sources from an empty src_dataset.")
        return torch.randint(0, src_size, (tgt_size, num_samples))

    def __len__(self):
        return len(self.tgt_dataset) * self.source_sample_num

    def __getitem__(self, index):
        """Return the ``index``-th (sampled source, target) pair.

        Times are parsed from the second whitespace-separated token of each
        record's ``desc``.
        """
        tgt_index, sample_slot = divmod(index, self.source_sample_num)
        src_index = self.tgt_to_src_index[tgt_index, sample_slot].item()
        src_record = self.src_dataset[src_index]
        tgt_record = self.tgt_dataset[tgt_index]
        ret = {
            "src_seq": src_record[1],
            "tgt_seq": tgt_record[1],
            "src_time": self.get_time_func(src_record[-1].split()[1]),
            "tgt_time": self.get_time_func(tgt_record[-1].split()[1]),
        }
        if self.src_enc_out is not None:
            ret["src_enc_out"] = self.src_enc_out[src_index]
        return ret

class TemporalFastaDataset(torch.utils.data.Dataset):
    """Monolingual sequence dataset with time (and other) metadata per record.

    ``src_dataset`` is a list of ``(id, seq, desc)`` records. Metadata is
    parsed from ``desc``, skipping its first whitespace-separated token.
    """

    def __init__(self, src_dataset, vocab, get_time_method="simple", properties=None, pre_tokenize=False, other_attributes=None) -> None:
        """
        Args:
            src_dataset: list of ``(id, seq, desc)`` records.
            vocab: object providing ``encode(seq)`` (used when pre-tokenizing).
            get_time_method: ``"simple"`` (``get_time_simple``) or ``"kw"``
                (``get_value_from_desc``). Other values raise ValueError.
            properties: keys read from ``desc`` in ``"kw"`` mode. The first one
                becomes ``src_time``. Defaults to ``["day"]``.
            pre_tokenize: encode every sequence once with ``vocab.encode``.
            other_attributes: optional ``{name: {record_id: value}}`` lookup
                tables whose values are added to items in ``"kw"`` mode.
        """
        super().__init__()
        if properties is None:
            # Avoid a shared mutable default list.
            properties = ["day"]
        self.src_dataset = src_dataset
        self.vocab = vocab
        self.get_time_method = get_time_method
        self.pre_tokenize = pre_tokenize
        self.other_attributes = other_attributes

        if pre_tokenize:
            self.src_seq_toks = [
                vocab.encode(seq) for _, seq, _ in tqdm(self.src_dataset, desc="Pre tokenize")
            ]

        if get_time_method == "simple":
            self.get_time_func = get_time_simple
        elif get_time_method == "kw":
            self.get_time_func = get_value_from_desc
            self.properties = properties
        else:
            raise ValueError("Please set the right get_time_method.")

    def padding_all(self, pad_idx):
        """Right-pad every token tensor in ``src_dataset`` to the longest length.

        This dataset has no target side, so only ``src_dataset`` is padded.
        Requires tensor sequences. ``src_seq_toks`` (pre-tokenized output) is
        not modified. This is a no-op on an empty dataset.
        """
        if not self.src_dataset:
            return
        max_len = max(record[1].size(0) for record in self.src_dataset)
        padded = []
        for record in self.src_dataset:
            # new_full keeps the dtype/device of the original token tensor.
            tokens = record[1].new_full((max_len,), pad_idx)
            tokens[: len(record[1])] = record[1]
            padded.append((record[0], tokens, record[-1]))
        self.src_dataset = padded

    def __len__(self):
        return len(self.src_dataset)

    def __getitem__(self, index):
        """Return the ``index``-th record as a dict.

        In ``"simple"`` mode, ``src_time`` is parsed from the second
        whitespace token of ``desc``. In ``"kw"`` mode, every key in
        ``properties`` is read from ``desc`` with its first token removed;
        ``properties[0]`` is stored as ``src_time`` and ``other_attributes``
        values are looked up by record id.
        """
        record = self.src_dataset[index]
        ret = {
            "index": index,
            "src_id": record[0],
            "src_seq": self.src_seq_toks[index] if self.pre_tokenize else record[1],
        }
        if self.get_time_method == "simple":
            ret["src_time"] = self.get_time_func(record[-1].split()[1])
            return ret

        desc = " ".join(record[-1].split()[1:])
        ret["src_time"] = self.get_time_func(desc, key=self.properties[0])
        for key in self.properties[1:]:
            ret[key] = self.get_time_func(desc, key=key)
        if self.other_attributes is not None:
            for key in self.other_attributes:
                ret[key] = self.other_attributes[key][record[0]]
        return ret

class TemporalMSAAlignDataset(torch.utils.data.Dataset):
    """Dataset whose items are whole MSAs (groups of aligned records).

    Input records are ``(src_id, seq, desc)`` with ``src_id`` of the form
    ``"<msa_id>#<record_id>"``. Records sharing an ``msa_id`` form one item.
    """

    def __init__(self, src_dataset, vocab, get_time_method="simple", properties=None, pre_tokenize=False, other_attributes=None) -> None:
        """
        Args:
            src_dataset: list of ``(msa_id#record_id, seq, desc)`` records.
            vocab: vocabulary object; stored on the instance, not used here.
            get_time_method: ``"simple"`` or ``"kw"``. Other values raise ValueError.
            properties: keys read from ``desc`` in ``"kw"`` mode. The first one
                becomes ``src_time``. Defaults to ``["day"]``.
            pre_tokenize: stored only; sequences are never pre-tokenized here.
            other_attributes: optional ``{name: {record_id: value}}`` tables
                whose values are collected per record in ``"kw"`` mode.
        """
        super().__init__()
        if properties is None:
            # Avoid a shared mutable default list.
            properties = ["day"]
        # List of MSAs; each MSA is a list of (record_id, seq, desc).
        self.src_dataset = self._group_by_msa(src_dataset)

        self.vocab = vocab
        self.get_time_method = get_time_method
        self.pre_tokenize = pre_tokenize
        self.other_attributes = other_attributes

        if get_time_method == "simple":
            self.get_time_func = get_time_simple
        elif get_time_method == "kw":
            self.get_time_func = get_value_from_desc
            self.properties = properties
        else:
            raise ValueError("Please set the right get_time_method.")

    @staticmethod
    def _group_by_msa(src_dataset):
        """Group ``(msa_id#record_id, seq, desc)`` records by ``msa_id``.

        Groups keep first-appearance order and records keep input order. A
        dict keeps each MSA's records together even when they are not
        contiguous in the input, and gives O(1) group lookup.
        """
        groups = {}
        for src_id, seq, desc in src_dataset:
            id_parts = src_id.split("#")
            msa_id, real_id = id_parts[0], id_parts[1]
            groups.setdefault(msa_id, []).append((real_id, seq, desc))
        return list(groups.values())

    def __len__(self):
        return len(self.src_dataset)

    def __getitem__(self, index):
        """Return one MSA as a dict of per-record lists.

        In ``"simple"`` mode, each record's ``src_time`` is parsed from the
        second whitespace token of its ``desc``, the same per-record rule
        ``TemporalFastaDataset`` uses. In ``"kw"`` mode, ``properties[1:]``
        and ``other_attributes`` are also collected per record.
        """
        records = self.src_dataset[index]
        ret = {
            "index": index,
            "src_id": [x[0] for x in records],
            "src_seq": [x[1] for x in records],
        }
        if self.get_time_method == "simple":
            ret["src_time"] = [self.get_time_func(x[-1].split()[1]) for x in records]
            return ret

        ret["src_time"] = []
        for x in records:
            desc = " ".join(x[-1].split()[1:])
            ret["src_time"].append(self.get_time_func(desc, key=self.properties[0]))

            for key in self.properties[1:]:
                if key not in ret:
                    ret[key] = []
                ret[key].append(self.get_time_func(desc, key=key))

            if self.other_attributes is not None:
                for key in self.other_attributes:
                    if key not in ret:
                        ret[key] = []
                    ret[key].append(self.other_attributes[key][x[0]])
        return ret


class MultiTaskClassificationDataset(torch.utils.data.Dataset):
    """Sequences with one label per classification task, looked up by record id.

    ``src_dataset`` holds ``(id, seq, ...)`` records and ``meta_data`` maps a
    record id to ``{task: label}``. In ``binary`` mode only the ``"Host"``
    task is labelled, collapsed to ``"Human"`` / ``"Not_human"``; other tasks
    get no label. With ``predict`` set, no labels are attached at all.
    """

    def __init__(self, src_dataset, vocab, meta_data, classification_tasks, binary, predict=False) -> None:
        super().__init__()
        self.src_dataset = src_dataset
        self.vocab = vocab
        self.meta_data = meta_data
        self.classification_tasks = classification_tasks
        self.binary = binary
        self.predict = predict

    def __len__(self):
        return len(self.src_dataset)

    def __getitem__(self, index):
        record_id = self.src_dataset[index][0]
        item = {
            "index": index,
            "src_id": record_id,
            "src_seq": self.src_dataset[index][1],
        }
        if self.predict:
            # Inference: no labels are available.
            return item

        for task in self.classification_tasks:
            if not self.binary:
                item[task] = self.meta_data[record_id][task]
            elif task == "Host":
                host = self.meta_data[record_id][task]
                item[task] = host if host == "Human" else "Not_human"
        return item

class PairwiseClassificationDataset(torch.utils.data.Dataset):
    """Labelled pairs of sequences referenced by id.

    ``index_data`` rows are ``(id1, id2, ..., value)``. Sequences come from
    ``src_dataset`` records ``(id, seq, ...)``, and the model input is
    ``vocab.concat(seq_of_id2, seq_of_id1)``, with the second id first.
    """

    def __init__(self, src_dataset, vocab, index_data, category=True) -> None:
        super().__init__()
        self.src_dataset = src_dataset
        self.vocab = vocab
        self.index_data = index_data
        self.category = category
        self.src_id_to_records = {record[0]: record for record in src_dataset}

    def __len__(self):
        return len(self.index_data)

    def __getitem__(self, index):
        row = self.index_data[index]
        first_id, second_id, label = row[0], row[1], row[-1]
        if not self.category:
            # Regression target.
            label = float(label)
        lookup = self.src_id_to_records
        second_record = lookup[second_id]
        first_record = lookup[first_id]
        merged = self.vocab.concat(second_record[1], first_record[1])
        return {
            "index": index,
            "src_id1": first_record[0],
            "src_id2": second_record[0],
            "src_seq": merged,
            "label": label,
        }

class StructsPairwiseClassificationDataset(torch.utils.data.Dataset):
    """Labelled pairs of PDB structures, aligned on their chain-A residues.

    Each record in ``src_dataset`` is ``(pdb1, pdb2, aln_seq1, aln_seq2, value)``.
    The stored alignment strings are not used; ``__getitem__`` re-aligns the
    native chain-A sequences itself.
    """

    def __init__(self, src_dataset, vocab, cache_memory=False) -> None:
        super().__init__()
        self.src_dataset = src_dataset
        self.vocab = vocab
        self.cache_memory = cache_memory
        # index -> item dict. A dict (not a list) so `index in self.cache`
        # checks keys and arbitrary indices can be stored.
        self.cache = {}

    def __len__(self):
        return len(self.src_dataset)

    def align_structure(self, coords1, coords2, aln_seq1, aln_seq2, native_seq1, native_seq2):
        """Select the coordinates of residues aligned (non-gap) in both sequences.

        Args:
            coords1, coords2: per-residue coordinate arrays indexable by residue
                position in ``native_seq1`` / ``native_seq2``.
            aln_seq1, aln_seq2: equal-length gapped alignment strings (``"-"`` = gap).
            native_seq1, native_seq2: ungapped sequences; each ungapped
                alignment string must occur in its native sequence.

        Returns:
            ``(coords1[idx1], coords2[idx2])`` for aligned residue positions.
            Empty selections are returned when no column is aligned.

        Raises:
            ValueError: if an ungapped alignment string is not a substring of
                its native sequence.
        """
        ungapped1 = aln_seq1.replace("-", "")
        ungapped2 = aln_seq2.replace("-", "")
        try:
            # Position just before where each aligned segment starts.
            i1 = native_seq1.index(ungapped1) - 1
            i2 = native_seq2.index(ungapped2) - 1
        except ValueError as err:
            raise ValueError(
                "Aligned sequence is not a contiguous substring of its native sequence."
            ) from err

        indices1, indices2 = [], []
        for c1, c2 in zip(aln_seq1, aln_seq2):
            if c1 != "-":
                i1 += 1
            if c2 != "-":
                i2 += 1
            if c1 != "-" and c2 != "-":
                indices1.append(i1)
                indices2.append(i2)

        # dtype=int so an empty selection is still a valid integer index.
        indices1 = np.asarray(indices1, dtype=int)
        indices2 = np.asarray(indices2, dtype=int)
        return coords1[indices1], coords2[indices2]

    def count_labels(self):
        """Log and return ``{label: count}`` over the dataset, most common first.

        Only the label (last field) of each record is read; no PDB is loaded.
        """
        all_values = dict(Counter(data[-1] for data in self.src_dataset).most_common())
        logging.info("Labels in training set: " + str(all_values))
        return all_values

    def __getitem__(self, index):
        """Load and align the ``index``-th structure pair.

        Chain ``"A"`` of each PDB is used. The native sequences are globally
        aligned with ``pairwise2.align.globalms`` (match 2, mismatch -1,
        gap open -2, gap extend -1). ``RMSDcalculator`` is fitted on the
        aligned residues and applied to structure 1; presumably this
        superposes it onto structure 2 (confirm in data.utils). Structure 2
        has its mean over axis 0 subtracted. When ``cache_memory`` is set,
        items are memoised in ``self.cache``.
        """
        if self.cache_memory and index in self.cache:
            return self.cache[index]

        pdb1, pdb2, _, _, value = self.src_dataset[index]
        structure_1 = inverse_folding.util.load_structure(pdb1)
        structure_2 = inverse_folding.util.load_structure(pdb2)
        coords1, native_seqs_1 = inverse_folding.multichain_util.extract_coords_from_complex(structure_1)
        coords2, native_seqs_2 = inverse_folding.multichain_util.extract_coords_from_complex(structure_2)
        coords1, coords2 = coords1["A"], coords2["A"]
        native_seqs_1, native_seqs_2 = native_seqs_1["A"], native_seqs_2["A"]

        best_alignment = pairwise2.align.globalms(native_seqs_1, native_seqs_2, 2, -1, -2, -1)[0]
        aln_seq1, aln_seq2 = best_alignment[0], best_alignment[1]
        aln_coords1, aln_coords2 = self.align_structure(
            coords1, coords2, aln_seq1, aln_seq2, native_seqs_1, native_seqs_2
        )

        # TODO: how to load the confidence score? (e.g. the b_factor field via
        # bsio.load_structure(pdb, extra_fields=["b_factor"]))

        coords1_new = RMSDcalculator(aln_coords1.reshape(-1, 3), aln_coords2.reshape(-1, 3)).apply(coords1)
        coords2_new = coords2 - np.mean(coords2, axis=0, keepdims=True)
        # NOTE(review): assumes 3 atoms x 3 dims per residue, per the original reshape.
        coords1_new, coords2_new = coords1_new.reshape(-1, 3, 3), coords2_new.reshape(-1, 3, 3)

        ret = {
            "index": index,
            "coords1": coords1_new,
            "coords2": coords2_new,
            "seq1": native_seqs_1,
            "seq2": native_seqs_2,
            "label": value,
        }
        if self.cache_memory:
            self.cache[index] = ret
        return ret

class PairwiseAlnClassificationDataset(torch.utils.data.Dataset):
    """Labelled pairs of (possibly multi-chain) sequences.

    Each record is ``(id1, id2, seq1, seq2, value)``. ``seq1`` and ``seq2``
    may hold several chains joined by ``"#"``. ``loss_weights``, if given,
    must have one entry per record.
    """

    def __init__(
        self,
        src_dataset,
        vocab,
        category=True,
        loss_weights=None,
        prepend_special_token_for_seq1=None,
        prepend_special_token_for_seq2=None,
        numerical=False,
        numerical_interval=None,
    ) -> None:
        super().__init__()
        if loss_weights is not None:
            assert len(loss_weights) == len(src_dataset), "Loss weights list (size=%d) should have the same size as src_dataset (size=%d)" % (len(loss_weights), len(src_dataset))
        self.src_dataset = src_dataset
        self.vocab = vocab
        self.category = category
        self.numerical_interval = numerical_interval
        self.numerical = numerical
        self.prepend_special_token_for_seq1 = prepend_special_token_for_seq1
        self.prepend_special_token_for_seq2 = prepend_special_token_for_seq2
        self.loss_weights = loss_weights

    def __len__(self):
        return len(self.src_dataset)

    @staticmethod
    def _chains(seq, prefix):
        """Split ``seq`` on ``"#"`` and optionally prepend ``prefix`` to each chain."""
        chains = seq.split("#")
        if prefix is None:
            return chains
        return [prefix + chain for chain in chains]

    def __getitem__(self, index):
        id1, id2, seq1, seq2, value = self.src_dataset[index]
        if self.numerical:
            # Bucket a continuous value into integer bins of width numerical_interval.
            value = int(float(value) // self.numerical_interval)
        elif not self.category:
            value = float(value)

        chains1 = self._chains(seq1, self.prepend_special_token_for_seq1)
        chains2 = self._chains(seq2, self.prepend_special_token_for_seq2)
        n1, n2 = len(chains1), len(chains2)
        return {
            "index": index,
            "src_id1": id1,
            "src_id2": id2,
            "src_seq": chains1 + chains2,
            "label": value,
            # 0 for chains of seq1, 1 for chains of seq2.
            "seq_label": [0] * n1 + [1] * n2,
            # Marks the first chain of each side.
            "ref_seq_label": [1] + [0] * (n1 - 1) + [1] + [0] * (n2 - 1),
            "loss_weight": 1.0 if self.loss_weights is None else self.loss_weights[index],
        }

class TemporalMultiFastaDataset(TemporalFastaDataset):
    """Dataset whose records carry several gene segments in one sequence.

    Each ``seq`` looks like ``"gene1:SEQ1|gene2:SEQ2|..."``. Gene names are
    taken from the first record. NOTE(review): every record is assumed to
    list the same genes in the same order; this is not validated.
    """

    def __init__(self, src_dataset, vocab, get_time_method="simple", properties=None, pre_tokenize=False) -> None:
        """
        Args:
            src_dataset: list of ``(acc_id, "gene:seq|gene:seq|...", desc)`` records.
            vocab: vocabulary object passed to the parent.
            get_time_method: passed to the parent. ``__getitem__`` reads
                ``self.properties``, which the parent only sets in ``"kw"`` mode.
            properties: keys read from ``desc``. Defaults to ``["day"]``.
            pre_tokenize: stored as an attribute only. The parent is told not
                to pre-tokenize, because it would iterate ``None`` here and
                this class never reads ``src_seq_toks``.
        """
        if properties is None:
            properties = ["day"]
        super().__init__(None, vocab, get_time_method, properties, pre_tokenize=False)
        self.pre_tokenize = pre_tokenize
        self.src_dataset = []
        if src_dataset:
            self.genes = [part.split(":")[0] for part in src_dataset[0][1].split("|")]
        else:
            self.genes = []
        logging.info("Genes: %s", self.genes)
        for acc_id, seq, desc in src_dataset:
            segments = [part.split(":")[1] for part in seq.split("|")]
            self.src_dataset.append((acc_id, segments, desc))

    def __getitem__(self, index):
        """Return a record with segments joined by ``"<sep>"`` plus gene labels.

        ``gene_annotation`` repeats each gene name once per character of its
        segment and inserts ``"<sep>"`` between segments, mirroring how
        ``src_seq`` is joined.
        """
        acc_id, segments, raw_desc = self.src_dataset[index]
        gene_annotation = []
        for i, gene in enumerate(self.genes):
            if i > 0:
                gene_annotation.append("<sep>")
            gene_annotation.extend([gene] * len(segments[i]))
        ret = {
            "index": index,
            "src_id": acc_id,
            "src_seq": "<sep>".join(segments),
            "gene_annotation": gene_annotation,
        }
        desc = " ".join(raw_desc.split()[1:])
        ret["src_time"] = self.get_time_func(desc, key=self.properties[0])
        for key in self.properties[1:]:
            ret[key] = self.get_time_func(desc, key=key)
        return ret

        
class TemporalMSADataset(torch.utils.data.Dataset):
    """Dataset of MSAs read lazily from ``.a3m`` files, one file per item.

    Every non-empty ``*.a3m`` file in ``data_dir`` is one item. The file stem
    is the accession id of the reference sequence within that MSA. Only
    ``get_time_method="kw"`` is supported by ``__getitem__``.
    """

    def __init__(self, data_dir, vocab, get_time_method="simple", properties=None, filtering_func=None, cache_token=True, representative_fasta_path=None) -> None:
        """
        Args:
            data_dir: directory scanned for ``*.a3m`` files.
            vocab: object providing ``encode(seq)``.
            get_time_method: ``"simple"`` or ``"kw"``. Other values raise ValueError.
            properties: keys read from each record's ``desc`` in ``"kw"`` mode.
                The first one becomes ``src_time``. Defaults to ``["day"]``.
            filtering_func: optional ``f(item, ref_idx=...)`` applied to each item.
            cache_token: cache ``vocab.encode`` output next to each ``.a3m`` file.
            representative_fasta_path: optional FASTA whose records are stored in
                ``representative_records`` / ``acc_id2representative_record``.
        """
        super().__init__()
        if properties is None:
            properties = ["day"]
        self.src_paths = []
        self.acc_ids = []
        self.cache_token = cache_token
        self.tokens = []  # NOTE(review): never populated here; kept for compatibility.
        if representative_fasta_path is not None:
            self.representative_records = read_fasta(representative_fasta_path)
            self.acc_id2representative_record = {x[0]: x for x in self.representative_records}

        # Sorted so item indices are reproducible (os.listdir order is arbitrary).
        for file in tqdm(sorted(os.listdir(data_dir))):
            if not file.endswith(".a3m"):
                continue
            path = os.path.join(data_dir, file)
            if os.path.getsize(path) == 0:
                continue
            self.src_paths.append(path)
            self.acc_ids.append(file.split(".a3m")[0])

        self.vocab = vocab
        self.get_time_method = get_time_method
        self.filtering_func = filtering_func

        if get_time_method == "simple":
            self.get_time_func = get_time_simple
        elif get_time_method == "kw":
            self.get_time_func = get_value_from_desc
            self.properties = properties
        else:
            raise ValueError("Please set the right get_time_method.")

    def __len__(self):
        return len(self.src_paths)

    def _load_or_encode_tokens(self, path, records):
        """Return ``[vocab.encode(seq) for each record]``, cached at ``path + ".cache"``.

        The cache is pickled to exactly the path that is checked for existence
        (``np.save`` would silently append ``.npy``). This also handles ragged
        per-sequence token lists.
        """
        import pickle

        cache_path = path + ".cache"
        if os.path.exists(cache_path):
            # SECURITY NOTE: pickle can execute code on load; only cache files
            # written by this method should live next to the data.
            with open(cache_path, "rb") as handle:
                return pickle.load(handle)
        toks = [self.vocab.encode(record[1]) for record in records]
        # Write to a per-process temp file, then rename atomically, so that
        # concurrent loader workers never read a half-written cache file.
        tmp_path = "%s.%d.tmp" % (cache_path, os.getpid())
        with open(tmp_path, "wb") as handle:
            pickle.dump(toks, handle)
        os.replace(tmp_path, cache_path)
        return toks

    def __getitem__(self, index):
        """Read the ``index``-th MSA and return per-record lists.

        Returns a dict with:
        - ``src_id`` and ``src_seq`` (tokens if ``cache_token``, else raw strings);
        - ``src_time`` (``properties[0]``) and ``mismatch_num`` (the ``msa_score`` field);
        - one list per ``properties[1:]`` key, parsed from each ``desc``.

        If ``filtering_func`` raises, the failure is logged and the unfiltered
        item is returned.

        Raises:
            NotImplementedError: in ``"simple"`` mode.
        """
        if self.get_time_method == "simple":
            raise NotImplementedError("TemporalMSADataset only supports get_time_method='kw'.")

        path = self.src_paths[index]
        data = read_fasta(path, quiet=True)
        if self.cache_token:
            src_seq = self._load_or_encode_tokens(path, data)
        else:
            src_seq = [x[1] for x in data]

        # Position of the reference (file-stem) sequence inside the MSA.
        ref_idx = [x[0] for x in data].index(self.acc_ids[index])

        ret = {
            "index": index,
            "src_id": [x[0] for x in data],
            "src_seq": src_seq,
            "src_time": [],
            "mismatch_num": [],
        }
        for key in self.properties[1:]:
            ret[key] = []
        for x in data:
            desc = " ".join(x[-1].split()[1:])
            ret["src_time"].append(self.get_time_func(desc, key=self.properties[0]))
            ret["mismatch_num"].append(self.get_time_func(desc, key="msa_score"))
            for key in self.properties[1:]:
                ret[key].append(self.get_time_func(desc, key=key))

        if self.filtering_func is not None:
            try:
                return self.filtering_func(ret, ref_idx=ref_idx)
            except Exception:
                # Best effort: fall back to the unfiltered MSA, but leave a trace.
                logging.exception("filtering_func failed for %s", path)
        return ret

class TemporalBlockedMSADataset(torch.utils.data.Dataset):
    """Fixed-size blocks of sequences grouped by time slot, with sampled context blocks.

    Each item is one target block taken from any time slot except the earliest one,
    together with one randomly sampled block from every existing time slot in the
    context window around it. Time keys are used with ``range()``, so they are
    assumed to be integers.
    """

    def __init__(self, src_dataset, block_size: int = 128, window_size: int = 5, bidirection_window: bool = True) -> None:
        super().__init__()
        # src_dataset: a dict, key is the time, value is the list of sequences at each time slot;
        # block_size: how many sequences within each block
        # window_size: how many time slots are considered as the context
        # bidirection_window: if true, we use previous and afterward "window_size" time slots; else we only use previous;

        self.window_size = window_size
        self.bidirection_window = bidirection_window
        self.src_dataset = src_dataset

        sorted_keys = sorted(src_dataset.keys())
        self.min_time = sorted_keys[0]
        self.max_time = sorted_keys[-1]

        # Block coordinates: (time slot, start index, end index).
        block_coords = []
        for time in sorted_keys:
            datasize = len(self.src_dataset[time])
            if datasize <= block_size:
                block_coords.append((time, 0, datasize))
            else:
                for start in range(0, datasize, block_size):
                    block_coords.append((time, start, min(start + block_size, datasize)))
        # The earliest slot has no previous context, so it never serves as a target.
        self.target_block_coords = [x for x in block_coords if x[0] != self.min_time]
        self.time2blocks = defaultdict(list)
        for time, start, end in block_coords:
            self.time2blocks[time].append((start, end))

    def shuffle(self, ):
        # NOTE: shuffle the sequences within each time slot
        # for time in self.src_dataset:
        raise NotImplementedError

    def __len__(self, ):
        return len(self.target_block_coords)

    def _sample_block(self, time):
        """Return one uniformly sampled block of slot ``time``, or None if the slot is absent."""
        # .get avoids inserting empty entries into the defaultdict for missing slots,
        # and np.random.choice(0, 1) would raise on an empty list.
        blocks = self.time2blocks.get(time)
        if not blocks:
            return None
        start, end = blocks[np.random.choice(len(blocks), 1)[0]]
        return self.src_dataset[time][start:end]

    def __getitem__(self, index):
        """Return the target block ``index`` with its sampled context blocks.

        ``src_seq``/``src_time`` hold one list for the previous slots and, when
        ``bidirection_window`` is set, a second list for the following slots.
        Time slots missing from ``src_dataset`` are skipped.
        """
        target_time, target_start, target_end = self.target_block_coords[index]
        target_block = self.src_dataset[target_time][target_start:target_end]

        src_seqs = [[]]
        src_times = [[]]
        for previous_time in range(max(target_time - self.window_size, self.min_time), target_time):
            src_block = self._sample_block(previous_time)
            if src_block is None:
                continue
            src_seqs[0].append(src_block)
            src_times[0].append(previous_time)

        if self.bidirection_window:
            src_seqs.append([])
            src_times.append([])
            # Upper bound is inclusive of max_time (hence the +1), mirroring the
            # inclusive min_time bound of the backward window.
            for afterwards_time in range(target_time + 1, min(target_time + self.window_size, self.max_time) + 1):
                src_block = self._sample_block(afterwards_time)
                if src_block is None:
                    continue
                src_seqs[1].append(src_block)
                src_times[1].append(afterwards_time)

        return {
            "src_seq": src_seqs,
            "src_time": src_times,
            "tgt_seq": target_block,
            "tgt_time": target_time
        }

class PairwiseDataset(torch.utils.data.Dataset):
    """Combine a source and a target dataset into (source, target) pairs.

    With ``aligned=True`` the i-th source pairs with the i-th target; a single
    source record is broadcast to every target. Otherwise records are matched on
    the values of ``align_keys``, joined with "_" (one-to-many is allowed).
    Items merge both records: source keys are prefixed ``src_``, target keys
    ``tgt_``.
    """

    def __init__(self, src_dataset, tgt_dataset, align_keys=("src_id", ), aligned=False) -> None:
        super().__init__()
        self.src_dataset = src_dataset
        self.tgt_dataset = tgt_dataset
        self.attributes = {}

        if aligned:
            if len(src_dataset) == 1:
                # Broadcast the single source record to every target record.
                self.src_dataset = [src_dataset[0]] * len(tgt_dataset)
            assert len(self.src_dataset) == len(tgt_dataset), "aligned datasets must have equal length"
            self.src_valid_index = list(range(len(self.src_dataset)))
            self.tgt_valid_index = list(range(len(self.tgt_dataset)))
            return

        self.src_acc_id2tgt_indices = defaultdict(list)
        for i, record in enumerate(self.tgt_dataset):
            self.src_acc_id2tgt_indices[self._make_key(record, align_keys)].append(i)

        self.src_valid_index = []
        self.tgt_valid_index = []

        if len(align_keys) == 1 and align_keys[0] == "src_id":
            # Fast path: match directly on the source accession ids.
            for ori_src_index, acc_id in enumerate(src_dataset.acc_ids):
                for ori_tgt_index in self.src_acc_id2tgt_indices.get(acc_id, ()):
                    self.src_valid_index.append(ori_src_index)
                    self.tgt_valid_index.append(ori_tgt_index)
            return

        if getattr(src_dataset, "acc_id2representative_record", None) is not None:
            repr_records = src_dataset.acc_id2representative_record
        else:
            raise NotImplementedError("You should read all data.")

        # Build the same "_"-joined key for every source record from the
        # "k=v|k=v" field (second whitespace token) of its representative header.
        self.src_acc_id2src_indices = defaultdict(list)
        for ori_src_index, acc_id in enumerate(src_dataset.acc_ids):
            desc = {x.split("=")[0]: x.split("=")[1] for x in repr_records[acc_id][-1].split()[1].split("|")}
            self.src_acc_id2src_indices[self._make_key(desc, align_keys)].append(ori_src_index)

        for acc_id, src_indices in self.src_acc_id2src_indices.items():
            tgt_indices = self.src_acc_id2tgt_indices.get(acc_id)
            if not tgt_indices:
                continue
            for ori_src_index in src_indices:
                for ori_tgt_index in tgt_indices:
                    self.src_valid_index.append(ori_src_index)
                    self.tgt_valid_index.append(ori_tgt_index)

    @staticmethod
    def _make_key(record, align_keys):
        """Join the ``align_keys`` values of ``record`` with "_" (missing keys -> "")."""
        # str() so non-string values (e.g. float times) can be joined; no-op for strings.
        return "_".join(str(record.get(prop, "")) for prop in align_keys)

    def __len__(self, ):
        return len(self.src_valid_index)

    def add_attributes(self, name, data):
        """Attach per-pair data; a tuple entry is split into (src value, tgt value)."""
        self.attributes[name] = data

    def __getitem__(self, index):
        original_index_for_src = self.src_valid_index[index]
        original_index_for_tgt = self.tgt_valid_index[index]

        # Shallow copy: the source dataset may return the same dict object on
        # repeated calls (e.g. the broadcast list in __init__, or any list of
        # dicts); renaming keys in place would corrupt every later access.
        ret = dict(self.src_dataset[original_index_for_src])
        for key in list(ret.keys()):
            if "src_" not in key:
                ret["src_%s" % key] = ret.pop(key)

        tgt_dict = self.tgt_dataset[original_index_for_tgt]
        for key in list(tgt_dict.keys()):
            if "src_" in key:
                ret["tgt_%s" % key.split("src_")[1]] = tgt_dict[key]
            else:
                ret["tgt_%s" % key] = tgt_dict[key]

        ret["index"] = index

        for att, values in self.attributes.items():
            value = values[index]
            if isinstance(value, tuple):
                ret["src_%s" % att] = value[0]
                ret["tgt_%s" % att] = value[1]
            else:
                ret["src_%s" % att] = value
                ret["tgt_%s" % att] = value

        return ret

class PairwiseRandomDataset(torch.utils.data.Dataset):
    """Pair every record of ``dataset`` (target) with a sampled source record.

    ``sampling_method(tgt_src_id)`` must return ``(sampled_src_id, pair_prob)``;
    the source record is looked up by ``src_id`` in ``src_dataset`` (defaults to
    ``dataset``). ``pair_prob`` is exposed as ``src_pair_freq``/``tgt_pair_freq``.
    """

    def __init__(self, dataset, sampling_method, src_dataset=None) -> None:
        super().__init__()
        self.dataset = dataset
        self.src_dataset = dataset if src_dataset is None else src_dataset
        # Reads every source record once to map its src_id to its position.
        self.epiid2index = {x["src_id"]: i for i, x in enumerate(self.src_dataset)}
        self.sampling_method = sampling_method
        self.attributes = {}

    def add_attributes(self, name, data):
        """Attach per-item data; a tuple entry is split into (src value, tgt value)."""
        self.attributes[name] = data

    def __len__(self, ):
        return len(self.dataset)

    def __getitem__(self, index):
        tgt_dict = self.dataset[index]
        sampled_src_epi_id, pair_prob = self.sampling_method(tgt_dict["src_id"])
        original_index_for_src = self.epiid2index[sampled_src_epi_id]

        # Shallow copy: when src_dataset is dataset (or any list of dicts) the
        # same dict object is returned again later, possibly as a target; the
        # in-place key renaming below must not leak into it.
        ret = dict(self.src_dataset[original_index_for_src])
        for key in list(ret.keys()):
            if "src_" not in key:
                ret["src_%s" % key] = ret.pop(key)

        for key in list(tgt_dict.keys()):
            if "src_" in key:
                ret["tgt_%s" % key.split("src_")[1]] = tgt_dict[key]
            else:
                ret["tgt_%s" % key] = tgt_dict[key]

        ret["index"] = index

        for att, values in self.attributes.items():
            value = values[index]
            if isinstance(value, tuple):
                ret["src_%s" % att] = value[0]
                ret["tgt_%s" % att] = value[1]
            else:
                ret["src_%s" % att] = value
                ret["tgt_%s" % att] = value

        ret["src_pair_freq"] = ret["tgt_pair_freq"] = pair_prob

        return ret