from collections import defaultdict
from skbio import Protein
from skbio.alignment import local_pairwise_align_ssw, make_identity_substitution_matrix
import torch
import numpy as np
import time
from copy import deepcopy
# from models.utils import calc_time_sample_prob

def calc_time_sample_prob(time_diff, alpha = -0.01):
    """Turn time differences into exponential-decay sampling weights.

    Every entry is flattened to 1-D; positive differences ``n`` map to
    ``exp(alpha * n)`` and non-positive ones map to 0.

    Args:
        time_diff: tensor of time differences of any shape (flattened).
        alpha: decay rate applied to each positive difference.

    Returns:
        1-D floating-point tensor of weights with ``time_diff.numel()`` entries.
        Integer inputs are promoted to ``torch.get_default_dtype()`` so the
        fractional weights are not truncated to 0.
    """
    # reshape (not view) so non-contiguous inputs such as transposes work too
    time_diff = time_diff.reshape(-1)
    positive = time_diff > 0
    out_dtype = time_diff.dtype if time_diff.is_floating_point() else torch.get_default_dtype()
    prob = time_diff.new_zeros(time_diff.size(), dtype=out_dtype)
    prob[positive] = (time_diff[positive].to(out_dtype) * alpha).exp()
    return prob

def discretize_time(time, one_step, normalize_time_a, normalize_time_b=0, discrete=True):
    """Map raw times to model time steps.

    Args:
        time: tensor of raw time values.
        one_step: if true, return ``torch.argsort(time)`` and ignore the other
            options.
        normalize_time_a: divisor applied after the shift.
        normalize_time_b: offset subtracted before dividing.
        discrete: when true, round the normalized values up with ``ceil``.

    Returns:
        The argsort indices, or ``(time - b) / a`` (ceiled if ``discrete``).
    """
    if one_step:
        return torch.argsort(time)
    scaled = (time - normalize_time_b) / (normalize_time_a)
    return torch.ceil(scaled) if discrete else scaled

def year_month_to_int(input, start_date="2019-12"):
    """Return the number of months from ``start_date`` to ``input``.

    Both arguments are ``"YYYY-MM"`` strings; the result is negative when
    ``input`` precedes ``start_date``.
    """
    year, month = map(int, input.split("-"))
    base_year, base_month = map(int, start_date.split("-"))
    return 12 * (year - base_year) + (month - base_month)

def tree_collate_func(list_of_dict, batch_converter, tree, padding_idx=None):
    """Collate samples that are nodes of a tree into one subtree batch.

    The ``src_id`` of every sample is passed to ``tree.extract_subtree``. The
    returned mapping decides which samples are emitted and in which order
    (``"observed_revserse_index"`` holds positions into ``list_of_dict``), and
    supplies the edge/layer fields, which are converted to tensors as-is.

    Args:
        list_of_dict: samples with ``"src_id"``, ``"src_seq"`` and ``"src_time"``.
        batch_converter: callable whose third return value is the token tensor.
        tree: object exposing ``extract_subtree(ids)``.
        padding_idx: token id excluded by ``attention_mask``; when None every
            position is attended to.

    Returns:
        dict with ``input_ids``/``labels`` (the same tensor), ``input_time``,
        ``attention_mask``, ``edges_index``, ``edges_length``,
        ``node_observed_masks`` and ``layers``.
    """
    subtree = tree.extract_subtree([sample["src_id"] for sample in list_of_dict])

    sequences = [
        ("", remove_redundant_gaps(list_of_dict[pos]["src_seq"]))
        for pos in subtree["observed_revserse_index"]
    ]
    _, _, tokens = batch_converter(sequences)
    times = torch.tensor(
        [list_of_dict[pos]['src_time'] for pos in subtree["observed_revserse_index"]]
    )

    if padding_idx is None:
        mask = tokens.new_ones(tokens.size())
    else:
        mask = (tokens != padding_idx).float()

    batch = {
        "input_ids": tokens,  # [B x L]
        "input_time": times,  # [B]
        "labels": tokens,  # [B x L]
        "attention_mask": mask,  # [B x L]
    }
    for field in ("edges_index", "edges_length", "node_observed_masks", "layers"):
        batch[field] = torch.tensor(subtree[field])
    return batch

def collate_properties(list_of_dict, keys, properties_dict, return_dict=None):
    """Stack per-sample properties into tensors.

    The type of each key is decided from the first sample only:
      * Python/NumPy int or float scalars are stacked directly;
      * keys present in ``properties_dict`` are mapped through that lookup
        table first;
      * lists are stacked as nested lists;
      * anything else is skipped.

    Args:
        list_of_dict: samples sharing the same keys.
        keys: keys to collate.
        properties_dict: ``{key: {value: id}}`` lookup tables.
        return_dict: dict to add results to. If it is falsy (None or empty),
            a new dict is created instead.

    Returns:
        The dict holding the collated tensors.
    """
    out = return_dict if return_dict else {}
    scalar_types = (np.int64, np.float64, float, int)
    for key in keys:
        head = list_of_dict[0][key]
        if isinstance(head, scalar_types):
            out[key] = torch.tensor([sample[key] for sample in list_of_dict])
        elif key in properties_dict:
            lookup = properties_dict[key]
            out[key] = torch.tensor([lookup[sample[key]] for sample in list_of_dict])
        elif isinstance(head, list):
            out[key] = torch.tensor([sample[key] for sample in list_of_dict])
    return out

def default_lm_collate_func(list_of_dict, batch_converter, padding_idx=None, properties_dict={}):
    """Collate language-model samples.

    ``src_seq`` may be a string (gaps are stripped) or a list of sequences
    (passed on as one ``("", seq)`` group). Samples of any other type are
    silently left out of the converter input.

    Args:
        list_of_dict: samples with ``"src_seq"`` and optional ``"src_time"``.
        batch_converter: callable whose third return value is the token tensor.
        padding_idx: token id excluded by ``attention_mask``; None keeps all.
        properties_dict: lookup tables forwarded to ``collate_properties``.

    Returns:
        dict with ``input_ids``/``labels`` (the same tensor), ``attention_mask``,
        ``input_time`` (when the first sample has ``src_time``; None becomes 1.0)
        and any extra keys handled by ``collate_properties``.
    """
    inputs = []
    for sample in list_of_dict:
        seq = sample["src_seq"]
        if isinstance(seq, str):
            inputs.append(("", remove_redundant_gaps(seq)))
        elif isinstance(seq, list):
            inputs.append([("", member) for member in seq])

    _, _, tokens = batch_converter(inputs)
    if padding_idx is None:
        mask = tokens.new_ones(tokens.size())
    else:
        mask = (tokens != padding_idx).float()

    ret = {
        "input_ids": tokens,  # [B x L]
        "labels": tokens,  # [B x L]
        "attention_mask": mask,  # [B x L]
    }
    first = list_of_dict[0]
    if "src_time" in first:
        ret["input_time"] = torch.tensor(
            [1.0 if sample["src_time"] is None else sample["src_time"] for sample in list_of_dict]
        )

    extra_keys = [key for key in first.keys() if key not in ("src_time", "src_seq")]
    return collate_properties(list_of_dict, extra_keys, properties_dict, return_dict=ret)

def multiprot_lm_collate_func(list_of_dict, batch_converter, padding_idx=None, properties_dict={}):
    """Collate multi-protein LM samples together with per-position gene ids.

    Each sample needs ``src_seq`` (gaps are stripped), ``src_time`` and
    ``gene_annotation`` (an iterable of keys into ``properties_dict["genes"]``).
    Gene ids get ``<bos>``/``<eos>`` added to mirror
    ``batch_converter.alphabet.prepend_bos``/``append_eos``. They are then
    right-padded with ``<pad>`` up to the token width.

    NOTE(review): rows are only padded, never truncated, so an annotation longer
    than the gap-stripped sequence yields ragged rows that ``torch.tensor``
    rejects. Also, if ``gene_annotation`` is a list, ``collate_properties``
    will try to tensorize it as well. Confirm the expected annotation format.

    Returns:
        dict with ``input_ids``/``labels`` (the same tensor), ``input_time``,
        ``genes``, ``attention_mask`` plus extra keys from
        ``collate_properties``.
    """
    _, _, tokens = batch_converter(
        [("", remove_redundant_gaps(sample["src_seq"])) for sample in list_of_dict]
    )
    width = tokens.size(1)

    vocab = properties_dict["genes"]
    gene_rows = []
    for sample in list_of_dict:
        row = [vocab[g] for g in sample['gene_annotation']]
        if batch_converter.alphabet.prepend_bos:
            row.insert(0, vocab["<bos>"])
        if batch_converter.alphabet.append_eos:
            row.append(vocab["<eos>"])
        shortfall = width - len(row)
        if shortfall > 0:
            row.extend([vocab["<pad>"]] * shortfall)
        gene_rows.append(row)
    genes = torch.tensor(gene_rows)

    times = torch.tensor([sample['src_time'] for sample in list_of_dict])
    if padding_idx is None:
        mask = tokens.new_ones(tokens.size())
    else:
        mask = (tokens != padding_idx).float()

    ret = {
        "input_ids": tokens,  # [B x L]
        "input_time": times,  # [B]
        "labels": tokens,  # [B x L]
        "genes": genes,
        "attention_mask": mask,  # [B x L]
    }
    extra_keys = [key for key in list_of_dict[0].keys() if key not in ("src_time", "src_seq")]
    return collate_properties(list_of_dict, extra_keys, properties_dict, return_dict=ret)

def default_msa2seq_collate_func(list_of_dict, src_batch_converter, tgt_batch_converter, padding_idx=None, properties_dict={}):
    """Collate MSA-to-sequence samples.

    The source of each sample is a list of sequences (an MSA). The target is
    one reference sequence, given either as a string (gaps are stripped) or as
    an already-encoded token list.

    Args:
        list_of_dict: samples with ``src_seq`` (list), ``src_time`` (list
            parallel to ``src_seq``), ``tgt_seq``, ``tgt_time`` and either
            ``tgt_id`` (``"EPI_ISL_<n>"``-style) or ``tgt_index``.
        src_batch_converter: callable turning a list of MSAs into a
            ``[B, M, L']`` token tensor (third return value).
        tgt_batch_converter: callable producing the ``[B, L]`` target tokens.
        padding_idx: token id excluded by ``attention_mask``; None keeps all.
        properties_dict: ``{key: {value: id}}`` tables for categorical keys.

    Returns:
        dict with ``input_ids`` [B, M, L'], ``input_time`` [B, M] (zero padded),
        ``target_time`` [B], ``labels``, ``attention_mask`` [B, L'] (from the
        first MSA row), ``source_masks`` [B, M], ``tgt_isl_id``. Extra keys are
        added as follows:
          * ``src_*`` list keys are padded to [B, M];
          * other keys are stacked as scalars or via ``properties_dict``.
    """
    batch_size = len(list_of_dict)
    scalar_types = (np.int64, np.float64, float, int)

    if isinstance(list_of_dict[0]["tgt_seq"], list):
        tgt_seqs = [x["tgt_seq"] for x in list_of_dict]
        _, _, tgt_tokens = tgt_batch_converter(tgt_seqs, seq_encoded_list=tgt_seqs)
    else:
        tgt_seqs = [("", remove_redundant_gaps(x["tgt_seq"])) for x in list_of_dict]
        _, _, tgt_tokens = tgt_batch_converter(tgt_seqs)

    max_src_num = max(len(x["src_seq"]) for x in list_of_dict)
    src_seqs = [[("", y) for y in x["src_seq"]] for x in list_of_dict]
    _, _, src_tokens = src_batch_converter(src_seqs)  # [B, M, L']

    src_time = torch.zeros(batch_size, max_src_num)
    for i, x in enumerate(list_of_dict):
        src_time[i, :len(x['src_time'])] = torch.tensor(x['src_time'])
    tgt_time = torch.tensor([x['tgt_time'] for x in list_of_dict])

    # Both branches describe the first MSA row, so both are [B, L'].
    # (Previously the no-padding branch used the target shape by mistake.)
    first_row = src_tokens[:, 0, :]
    if padding_idx is not None:
        attention_mask = (first_row != padding_idx).float()
    else:
        attention_mask = first_row.new_ones(first_row.size())

    # 1 for every real MSA member, 0 for padding rows.
    source_masks = torch.zeros(batch_size, max_src_num)
    for i, x in enumerate(list_of_dict):
        source_masks[i, :len(x["src_seq"])] = 1

    # Prefer a numeric id derived from the GISAID accession; "1" is prepended
    # so leading zeros survive. Fall back to the dataset index otherwise.
    try:
        tgt_isl_id = torch.tensor([int("1" + x['tgt_id'].replace("EPI_ISL_", "").replace("EPI", "")) for x in list_of_dict])
    except Exception:
        tgt_isl_id = torch.tensor([x["tgt_index"] for x in list_of_dict])

    ret = {
        "input_ids": src_tokens,  # [B x M x L]
        "input_time": src_time,  # [B x M]
        "target_time": tgt_time,  # [B]
        "labels": tgt_tokens,  # [B x L]
        "attention_mask": attention_mask,  # [B x L']
        "source_masks": source_masks,  # [B x M]
        "tgt_isl_id": tgt_isl_id
    }

    other_keys = [k for k in list_of_dict[0].keys() if "_time" not in k and "_seq" not in k]
    for key in other_keys:
        first = list_of_dict[0][key]
        if "src_" in key and key != "src_index":
            # Per-source lists: pad each sample's list out to [B, M].
            if isinstance(first[0], scalar_types):
                rows = [x[key] for x in list_of_dict]
            elif key in properties_dict:
                lookup = properties_dict[key]
                # A list (not a generator) is required by torch.tensor.
                rows = [[lookup[y] for y in x[key]] for x in list_of_dict]
            else:
                continue
            ret[key] = torch.zeros(batch_size, max_src_num)
            for i, row in enumerate(rows):
                ret[key][i, :len(row)] = torch.tensor(row)
        else:
            if isinstance(first, scalar_types):
                ret[key] = torch.tensor([x[key] for x in list_of_dict])
            elif key in properties_dict:
                ret[key] = torch.tensor([properties_dict[key][x[key]] for x in list_of_dict])
    return ret

def default_masked_lm_collate_func(list_of_dict, batch_converter, padding_idx=None):
    """Collate masked-LM samples (gap-stripped ``src_seq`` plus ``src_time``).

    Extra keys are added only when the first sample holds a Python/NumPy int or
    float for them. Unlike the other collators, no ``properties_dict`` lookup is
    done here.

    Returns:
        dict with ``input_ids``/``labels`` (the same tensor), ``input_time``,
        ``attention_mask`` and the numeric extra keys.
    """
    _, _, tokens = batch_converter(
        [("", remove_redundant_gaps(sample["src_seq"])) for sample in list_of_dict]
    )
    times = torch.tensor([sample['src_time'] for sample in list_of_dict])
    if padding_idx is None:
        mask = tokens.new_ones(tokens.size())
    else:
        mask = (tokens != padding_idx).float()

    ret = {
        "input_ids": tokens,  # [B x L]
        "input_time": times,  # [B]
        "labels": tokens,  # [B x L]
        "attention_mask": mask,  # [B x L]
    }
    first = list_of_dict[0]
    scalar_types = (np.int64, np.float64, float, int)
    extra_keys = [key for key in first.keys() if key not in ("src_time", "src_seq")]
    for key in extra_keys:
        if isinstance(first[key], scalar_types):
            ret[key] = torch.tensor([sample[key] for sample in list_of_dict])
    return ret

def msa_downsampling(seqs, num_seqs):
    """Randomly keep ``num_seqs`` members of ``seqs`` (without replacement).

    The input object itself is returned when it is already small enough.
    Sampling uses NumPy's global RNG.
    """
    if len(seqs) > num_seqs:
        chosen = np.random.permutation(len(seqs))[:num_seqs]
        return [seqs[k] for k in chosen]
    return seqs

def default_msa_collate_func(list_of_dict, batch_converter, padding_idx=None, max_msa_seq_num=128):
    """Collate windowed MSA samples into padded source/target tensors.

    Each sample is expected to provide (inferred from the accesses below):
      * ``tgt_seq``: list of ``(label, sequence)`` pairs. Gaps are stripped and
        every pair is encoded as its own one-sequence MSA.
      * ``tgt_time``: one value per sample.
      * ``src_seq``: list with one or two entries (``direction_num``). Each entry
        is passed to ``batch_converter`` as-is and yields a ``[T x M x L]``
        token tensor.
      * ``src_time``: list parallel to ``src_seq``; each entry is a list of times.

    Args:
        list_of_dict: batch of samples.
        batch_converter: callable whose third return value is the token tensor.
        padding_idx: fill value for padded source tokens.
            NOTE(review): ``fill_(None)`` cannot succeed, so the None default
            looks unusable here — confirm callers always pass it.
        max_msa_seq_num: not used in this function.

    Returns:
        dict with the following entries:
          * ``src_tokens``: ``[batch, direction_num, window_size, msa_num,
            seq_len]``, padded with ``padding_idx``;
          * ``src_time``: ``[batch, direction_num, window_size]``, padded
            with -100;
          * ``src_msa_masks``: float 0/1 mask over MSA rows,
            ``[batch, direction_num, window_size, msa_num]``;
          * ``tgt_tokens``: ``[batch_size x msa_seq_num x max_seq_length]``;
          * ``tgt_time``: ``[batch_size]``;
          * ``tgt_msa_masks``: bool mask of real target rows.
    """
    # NOTE: remove extra gaps in target tokens.
    # Every target sequence becomes its own single-member MSA.
    tgt_seqs_flattern = [[(y[0], remove_redundant_gaps(y[1]))] for x in list_of_dict for y in x["tgt_seq"]]
    _, _, tgt_tokens = batch_converter(tgt_seqs_flattern)
    # NOTE(review): this view assumes every sample has the same number of target
    # sequences; otherwise rows from different samples get regrouped.
    tgt_tokens = tgt_tokens.view(len(list_of_dict), -1, tgt_tokens.size(-1)) # [batch_size x msa_seq_num x max_seq_length]
    tgt_time = torch.tensor([x['tgt_time'] for x in list_of_dict]) # [batch_size]
    tgt_msa_masks = torch.tensor(
        [[1] * len(x["tgt_seq"]) + [0] * (tgt_tokens.size(1) - len(x["tgt_seq"]))  for x in list_of_dict],
        dtype=torch.bool
        )

    # build sources: one entry per (sample, direction), flattened in that order
    src_seqs = [] # len(src_seqs[i]) <= window size, src_seqs[i][j]: blocks in each time
    src_times = []
    src_msa_masks = [] # [B x T x max_msa_num]
    for x in list_of_dict:
        _, _, tokens = batch_converter(x["src_seq"][0]) # [T x M x L]
        src_seqs.append(tokens)
        src_times.append(x["src_time"][0])
        # 1 for each real MSA row of every time block, 0 for MSA padding rows
        src_msa_masks.append(torch.tensor([[1] * len(y) + [0] * (tokens.size(1) - len(y)) for y in x["src_seq"][0]])) # [T x M]
        # optional second direction, handled the same way
        if len(x["src_seq"]) == 2:
            _, _, tokens = batch_converter(x["src_seq"][1]) # [T x M x L]
            src_seqs.append(tokens)
            src_times.append(x["src_time"][1])
            src_msa_masks.append(torch.tensor([[1] * len(y) + [0] * (tokens.size(1) - len(y)) for y in x["src_seq"][1]]))

    # Allocate tensors large enough for the biggest entry on every axis, then copy in.
    src_seqs_tensor = torch.zeros(len(src_seqs), max([x.size(0) for x in src_seqs]), max([x.size(1) for x in src_seqs]), max([x.size(-1) for x in src_seqs]), dtype=torch.long)
    src_seqs_tensor.fill_(padding_idx)
    src_times_tensor = torch.zeros(len(src_times), max([len(x) for x in src_times]))
    src_times_tensor.fill_(-100)
    src_msa_masks_tensor = torch.zeros(len(src_msa_masks), max([x.size(0) for x in src_msa_masks]), max([x.size(1) for x in src_msa_masks]))

    for i, tokens in enumerate(src_seqs):
        src_seqs_tensor[i, :tokens.size(0), :tokens.size(1), :tokens.size(2)] = tokens
    # NOTE(review): the loop variable ``time`` shadows the ``time`` module inside this function.
    for i, time in enumerate(src_times):
        src_times_tensor[i, :len(time)] = torch.tensor(time)
    for i, masks in enumerate(src_msa_masks):
        src_msa_masks_tensor[i, :masks.size(0), :masks.size(1)] = masks

    # The direction count is read from the first sample only; the views below
    # assume every sample has the same number of directions.
    direction_num = len(list_of_dict[0]["src_seq"])
    src_seqs_tensor = src_seqs_tensor.view(-1, direction_num, src_seqs_tensor.size(-3), src_seqs_tensor.size(-2), src_seqs_tensor.size(-1)) # [batch, direction_num, window_size, msa_num, seq_len]
    src_times_tensor = src_times_tensor.view(-1, direction_num, src_times_tensor.size(-1)) # [batch, direction_num, window_size]
    src_msa_masks_tensor = src_msa_masks_tensor.view(-1, direction_num, src_msa_masks_tensor.size(-2), src_msa_masks_tensor.size(-1))

    ret = {
        "src_tokens": src_seqs_tensor, 
        "src_time": src_times_tensor, 
        "src_msa_masks": src_msa_masks_tensor,
        "tgt_tokens": tgt_tokens,
        "tgt_time": tgt_time,
        "tgt_msa_masks": tgt_msa_masks
    }
    return ret

# def get_isl_id(list_of_dict, key="src_id", index_key="index"):
#     try:
#         tgt_isl_id = torch.tensor([int("1" + x[key].replace("EPI_ISL_", "").replace("EPI", "")) for x in list_of_dict])
#     except Exception as e:
#         tgt_isl_id = torch.tensor([x[index_key] for x in list_of_dict])

def default_seq2seq_collate_func(
    list_of_dict, 
    src_batch_converter, 
    tgt_batch_converter, 
    src_padding_idx=None, 
    tgt_padding_idx=None,
    properties_dict={}, 
    remove_gaps=True,
    remove_gaps_from_source=True,
    *args, **kwargs):
    """Collate source/target sequence pairs for seq2seq training.

    With ``remove_gaps`` true, source and target are gap-stripped
    independently, and a missing ``tgt_seq`` becomes an empty string.
    Otherwise the pair is cleaned jointly by ``remove_redundant_gaps``, which
    needs ``tgt_seq`` to be present. In that joint mode,
    ``remove_gaps_from_source`` chooses between dropping source-gap columns and
    dropping only double-gap columns.

    ``tgt_padding_idx`` and any extra ``*args``/``**kwargs`` are accepted but
    unused.

    Returns:
        dict with ``input_ids``, ``labels``, ``attention_mask``,
        ``source_time``, ``target_time`` (None/missing times become 0.0) and
        any extra keys handled by ``collate_properties``.
    """
    if remove_gaps:
        src_seqs = [("", remove_redundant_gaps(sample["src_seq"])) for sample in list_of_dict]
        tgt_seqs = [("", remove_redundant_gaps(sample.get("tgt_seq", ""))) for sample in list_of_dict]
    else:
        cleaned = [
            remove_redundant_gaps(sample["src_seq"], sample["tgt_seq"], remove_gaps_from_source=remove_gaps_from_source)
            for sample in list_of_dict
        ]
        src_seqs = [("", src) for src, _ in cleaned]
        tgt_seqs = [("", tgt) for _, tgt in cleaned]

    _, _, src_tokens = src_batch_converter(src_seqs)
    _, _, tgt_tokens = tgt_batch_converter(tgt_seqs)

    src_time = torch.tensor(
        [0.0 if sample["src_time"] is None else sample['src_time'] for sample in list_of_dict]
    )
    # .get() returns None for a missing key, which maps to 0.0 just like an explicit None.
    tgt_time = torch.tensor(
        [0.0 if sample.get("tgt_time") is None else sample["tgt_time"] for sample in list_of_dict]
    )

    if src_padding_idx is None:
        attention_mask = src_tokens.new_ones(src_tokens.size())
    else:
        attention_mask = (src_tokens != src_padding_idx).float()

    ret = {
        "input_ids": src_tokens,  # [B x L]
        "labels": tgt_tokens,  # [B x L]
        "attention_mask": attention_mask,  # [B x L]
        "source_time": src_time,
        "target_time": tgt_time,
    }
    handled = ("src_seq", "tgt_seq", "src_time", "tgt_time")
    extra_keys = [key for key in list_of_dict[0].keys() if key not in handled]
    return collate_properties(list_of_dict, extra_keys, properties_dict, return_dict=ret)

def structure2seq_collate_func(list_of_dict, coords, batch_converter, padding_idx=None, properties_dict={}, remove_gaps=True):
    """Collate sequences that all share one structure, ``coords``.

    Every sample is paired with the same ``coords`` argument and no confidence
    values (``None``). ``src_seq`` is gap-stripped when ``remove_gaps`` is true.
    ``padding_idx`` is accepted but unused.

    Returns:
        dict with ``coords``, ``confidence``, ``tokens``, ``padding_mask`` (as
        returned by ``batch_converter``), ``src_time`` (None becomes 0.0) and
        any extra keys handled by ``collate_properties``.
    """
    seqs = [sample["src_seq"] for sample in list_of_dict]
    if remove_gaps:
        seqs = [remove_redundant_gaps(seq) for seq in seqs]

    coords, confidence, _strs, tokens, padding_mask = batch_converter(
        [(coords, None, seq) for seq in seqs]
    )
    src_time = torch.tensor(
        [0.0 if sample['src_time'] is None else sample['src_time'] for sample in list_of_dict]
    )

    ret = {
        "coords": coords,
        "confidence": confidence,
        "tokens": tokens,
        "src_time": src_time,
        "padding_mask": padding_mask,
    }
    handled = ("src_seq", "tgt_seq", "src_time", "tgt_time")
    extra_keys = [key for key in list_of_dict[0].keys() if key not in handled]
    return collate_properties(list_of_dict, extra_keys, properties_dict, return_dict=ret)

def structure_pairs_collate_func(list_of_dict, batch_converter, padding_idx=None, properties_dict={}, remove_gaps=True):
    """Collate pairs of structures as one concatenated structure per sample.

    The two chains are joined by stacking ``coords1`` and ``coords2`` along
    axis 0 and concatenating ``seq1 + seq2``. No confidence values are passed.
    ``padding_idx`` and ``remove_gaps`` are accepted but unused.

    NOTE(review): ``coords1``/``coords2``/``seq1``/``seq2`` are not excluded from
    the extra keys, so ``collate_properties`` will try to tensorize them if they
    are Python lists — confirm they are arrays/strings.

    Returns:
        dict with ``coords``, ``confidence``, ``tokens``, ``padding_mask`` and
        any extra keys handled by ``collate_properties``.
    """
    entries = []
    for sample in list_of_dict:
        stacked = np.concatenate([sample["coords1"], sample["coords2"]], axis=0)
        entries.append((stacked, None, sample["seq1"] + sample["seq2"]))

    coords, confidence, _strs, tokens, padding_mask = batch_converter(entries)

    ret = {
        "coords": coords,
        "confidence": confidence,
        "tokens": tokens,
        "padding_mask": padding_mask,
    }
    handled = ("src_seq", "tgt_seq", "src_time", "tgt_time")
    extra_keys = [key for key in list_of_dict[0].keys() if key not in handled]
    return collate_properties(list_of_dict, extra_keys, properties_dict, return_dict=ret)

def customized_collate_func(list_of_dict, batch_converter, remove_gaps=True, aligned=True):
    """Collate source (and optional target) sequences plus optional encoder outputs.

    Samples that contain ``tgt_seq`` are handled in one of two ways:
      * ``aligned=True``: the pair is assumed to be already aligned. Double
        gaps are removed when ``remove_gaps`` is set, and the two lengths must
        match.
      * ``aligned=False``: both sequences are gap-stripped and locally aligned
        with ``align_seqs``. Any ``src_enc_out`` is cropped to the aligned
        source span.

    The caller's sample dicts are never modified. Cropped encoder outputs are
    kept locally, so re-collating a cached sample gives the same result.

    Returns:
        dict with ``src_token``, ``src_time`` and, when any sample has a target,
        ``tgt_token``/``tgt_time``. When the first sample has ``src_enc_out``,
        a zero-padded ``[B, max_len, dim]`` ``src_enc_out`` tensor is added too.
    """
    src_seqs, tgt_seqs, enc_outs = [], [], []
    for x in list_of_dict:
        enc_out = x.get("src_enc_out", None)
        if "tgt_seq" in x:
            if aligned:
                if remove_gaps:
                    new_src, new_tgt = remove_redundant_gaps(x["src_seq"], x["tgt_seq"])
                else:
                    new_src, new_tgt = x["src_seq"], x["tgt_seq"]
                assert len(new_src) == len(new_tgt)
            else:
                new_src, new_tgt = remove_redundant_gaps(x["src_seq"]), remove_redundant_gaps(x["tgt_seq"])
                # Don't keep any gaps in seq1 so it stays aligned with src_enc_out.
                new_src, new_tgt, src_start_pos, src_end_pos = align_seqs(new_src, new_tgt, remove_gaps_in_seq1=True)
                assert len(new_src) == len(new_tgt)
                assert len(new_src) == (src_end_pos - src_start_pos + 1)
                if enc_out is not None:
                    # Crop locally instead of writing back into the caller's sample.
                    enc_out = enc_out[src_start_pos:src_end_pos + 1]
            tgt_seqs.append(("", new_tgt))
        else:
            new_src = remove_redundant_gaps(x["src_seq"]) if remove_gaps else x["src_seq"]
        src_seqs.append(("", new_src))
        enc_outs.append(enc_out)

    if tgt_seqs:
        _, _, tgt_tokens = batch_converter(tgt_seqs)
        _, _, src_tokens = batch_converter(src_seqs)
        ret = {
            "src_token": src_tokens,
            "tgt_token": tgt_tokens,
            "src_time": torch.tensor([x['src_time'] for x in list_of_dict]),
            "tgt_time": torch.tensor([x['tgt_time'] for x in list_of_dict]),
        }
    else:
        _, _, src_tokens = batch_converter(src_seqs)
        ret = {
            "src_token": src_tokens,
            "src_time": torch.tensor([x['src_time'] for x in list_of_dict]),
        }

    if enc_outs and enc_outs[0] is not None:
        max_len = max(e.size(0) for e in enc_outs)
        src_enc_out = torch.zeros(len(enc_outs), max_len, enc_outs[0].size(1))
        for i, e in enumerate(enc_outs):
            src_enc_out[i, :e.size(0)] = e
        ret["src_enc_out"] = src_enc_out

    return ret

def remove_redundant_gaps(src_seq, tgt_seq=None, remove_gaps_from_source=False):
    """Strip ``"-"`` gap characters from one sequence or an aligned pair.

    Modes:
      * ``tgt_seq is None``: every gap is removed from ``src_seq``. An empty
        input is returned unchanged.
      * ``remove_gaps_from_source=True``: every column where ``src_seq`` has a
        gap is dropped from both sequences. Columns are paired with ``zip``, so
        the shorter input bounds the result.
      * otherwise: only columns where *both* sequences have a gap are dropped.

    Returns:
        A string, or a ``(new_src, new_tgt)`` tuple of strings in pair mode.

    Raises:
        ValueError: in double-gap mode when the two sequences differ in length.
    """
    if tgt_seq is None:
        if len(src_seq) == 0:
            return src_seq
        return "".join(c for c in src_seq if c != "-")

    if remove_gaps_from_source:
        columns = [(s, t) for s, t in zip(src_seq, tgt_seq) if s != "-"]
    else:
        # Remove double gaps; the pair must be column-aligned.
        if len(src_seq) != len(tgt_seq):
            raise ValueError(
                "aligned sequences differ in length: %d vs %d" % (len(src_seq), len(tgt_seq))
            )
        columns = [(s, t) for s, t in zip(src_seq, tgt_seq) if not (s == "-" and t == "-")]
    return "".join(s for s, _ in columns), "".join(t for _, t in columns)

def align_seqs(seq1, seq2, remove_oov=True, remove_gaps_in_seq1=False, remove_gaps_in_seq2=False):
    """Locally align two protein sequences with Striped Smith-Waterman.

    Args:
        seq1, seq2: protein strings.
        remove_oov: drop characters outside the supported amino-acid alphabet
            before aligning.
        remove_gaps_in_seq1: drop aligned columns where seq1 has a gap.
        remove_gaps_in_seq2: drop aligned columns where seq2 has a gap.

    Returns:
        ``(aligned_seq1, aligned_seq2, start, end)``, where ``start``/``end`` are
        the first entry of skbio's ``start_end_positions``. That entry is the
        aligned span of seq1; its ``end`` is consumed inclusively by callers
        in this file.
    """
    residues = "LAGVSERTIDPKQNFYMHWCXBUZO"
    if remove_oov:
        seq1 = "".join(c for c in seq1 if c in residues)
        seq2 = "".join(c for c in seq2 if c in residues)

    query, target = Protein(seq1), Protein(seq2)
    scoring = make_identity_substitution_matrix(5, -1, alphabet=residues + "*")
    alignment, _score, spans = local_pairwise_align_ssw(query, target, substitution_matrix=scoring)

    kept1, kept2 = [], []
    for a, b in zip(str(alignment[0]), str(alignment[1])):
        skip = (remove_gaps_in_seq1 and a == "-") or (remove_gaps_in_seq2 and b == "-")
        if not skip:
            kept1.append(a)
            kept2.append(b)

    return "".join(kept1), "".join(kept2), spans[0][0], spans[0][1]

def get_query_to_targets(targets):
    """Group target records by the query id stored in their description.

    Each record is ``(id, seq, desc)``. The second whitespace-separated token of
    ``desc`` is a ``|``-separated list of ``key=value`` items. ``src`` names the
    query, and the optional ``src_start``/``src_end`` give the matched span.

    Returns:
        ``defaultdict(list)`` mapping query id to
        ``[(id, seq, start, end), ...]``. ``start``/``end`` are ints, or None
        when no span is given (the whole sequence).
    """
    grouped = defaultdict(list)
    for target_id, seq, desc in targets:
        fields = {}
        for item in desc.split()[1].split("|"):
            parts = item.split("=")
            fields[parts[0]] = parts[1]
        query_id = fields["src"]
        if "src_start" not in fields:
            grouped[query_id].append((target_id, seq, None, None))  # all seqs
            continue
        grouped[query_id].append((target_id, seq, int(fields["src_start"]), int(fields["src_end"])))
    return grouped

# def align_seqs(seq1, seq2, remove_oov=True):
#     if remove_oov:
#         seq1 = "".join([x for x in seq1 if x in "LAGVSERTIDPKQNFYMHWCXBUZO"])
#         seq2 = "".join([x for x in seq2 if x in "LAGVSERTIDPKQNFYMHWCXBUZO"])

#     alignment, score, start_end_positions = local_pairwise_align_ssw(
#         Protein(seq1),
#         Protein(seq2),
#         substitution_matrix=make_identity_substitution_matrix(5, -1, alphabet="LAGVSERTIDPKQNFYMHWCXBUZO*")
#     )

#     aseq1 = str(alignment[0])
#     aseq2 = str(alignment[1])

#     new_seq1, new_seq2 = [], []
#     for c1, c2 in zip(aseq1, aseq2):
#         if c1 != "-":
#             new_seq1.append(c1)
#             new_seq2.append(c2)

#     return "".join(new_seq1), "".join(new_seq2), start_end_positions[0][0], start_end_positions[0][1]

def remove_gaps_from_source(src_tokens, tgt_tokens, gap_idx, pad_idx):
    """Drop source-gap columns from a batch of aligned token tensors.

    Within each row, the columns where ``src_tokens == gap_idx`` are moved to
    the end and overwritten with ``pad_idx`` in both tensors. The remaining
    columns keep their order. Both results are then cut to the longest count
    of non-pad source tokens in the batch.

    Args:
        src_tokens, tgt_tokens: equally-shaped ``[B, L]`` token tensors.
        gap_idx: gap token id (e.g. the alphabet's ``gap_idx``).
        pad_idx: padding token id.

    Returns:
        ``(src, tgt)`` tensors of shape ``[B, L']`` with ``L' <= L``.
    """
    assert src_tokens.size() == tgt_tokens.size()
    width = src_tokens.size(1)

    columns = torch.arange(width, device=src_tokens.device).expand_as(src_tokens)
    # A gap column gets key ``width`` so the sort pushes it behind every real column.
    sort_keys = columns.masked_fill(src_tokens == gap_idx, width)
    order = torch.sort(sort_keys, dim=-1).indices

    new_src = src_tokens.gather(1, order)
    new_tgt = tgt_tokens.gather(1, order)
    moved_gaps = new_src == gap_idx
    new_src = new_src.masked_fill(moved_gaps, pad_idx)
    new_tgt = new_tgt.masked_fill(moved_gaps, pad_idx)

    keep = (new_src != pad_idx).sum(-1).max().item()
    return new_src[:, :keep], new_tgt[:, :keep]


def build_fake_pairwise_dataset(src_dataset, tgt_dataset, src_index=None, sample_src_size=5):
    """Pair every target with source indices drawn in proportion to ``freq``.

    Sources are sampled with replacement, weighted by each source's ``"freq"``.
    Targets are processed in chunks of 2000 rows. Time-decay weighting
    (``calc_time_sample_prob``) is currently disabled, so target/source times
    do not influence sampling.

    Args:
        src_dataset: indexable of dicts with a ``"freq"`` weight.
        tgt_dataset: sized collection of targets.
        src_index: optional precomputed ``[len(tgt_dataset), k]`` index tensor.
            When given, no sampling is done. (Previously this path raised
            ``UnboundLocalError``.)
        sample_src_size: sources drawn per target when sampling.

    Returns:
        ``(src_index, tgt_index, sample_probs)``, each ``[len(tgt_dataset), k]``:
        the sampled source indices, the matching target row indices, and the
        ``freq`` of each sampled source.
    """
    src_prob = torch.tensor([x["freq"] for x in src_dataset])
    num_tgt = len(tgt_dataset)

    if src_index is None:
        batch_size = 2000
        chunks = []
        for start in range(0, num_tgt, batch_size):
            rows = min(batch_size, num_tgt - start)
            weights = src_prob.unsqueeze(0).repeat(rows, 1)  # [rows, N]
            chunks.append(torch.multinomial(weights, sample_src_size, replacement=True))
        if chunks:
            src_index = torch.cat(chunks, dim=0)
        else:
            src_index = torch.zeros(0, sample_src_size, dtype=torch.long)
    else:
        src_index = torch.as_tensor(src_index)

    assert src_index.size(0) == num_tgt
    sample_probs = src_prob[src_index]
    tgt_index = torch.arange(num_tgt).unsqueeze(1).repeat(1, src_index.size(-1))
    return src_index, tgt_index, sample_probs

def build_fake_pairwise_dataset_old(src_dataset, tgt_dataset, src_index=None, sample_src_size=5):
    """Pair every target record with sampled source records (legacy version).

    Args:
        src_dataset: indexable collection of mapping records with a
            ``"src_time"`` field; sampled records are copied into the output.
        tgt_dataset: indexable collection of mapping records with
            ``"src_seq"`` and ``"src_time"`` fields.
        src_index: optional precomputed index tensor into ``src_dataset``,
            one row per target. When ``None`` it is sampled row-wise (without
            replacement) from ``calc_time_sample_prob`` of the time gaps.
        sample_src_size: number of sources drawn per target when
            ``src_index`` is ``None``.

    Returns:
        ``(new_datasets, src_index)``. ``new_datasets`` is a flat, target-major
        list of shallow copies of the sampled source records, each extended
        with the paired target's ``"tgt_seq"`` and ``"tgt_time"``. The input
        records are not modified.
    """
    import copy

    if src_index is None:
        src_times = torch.tensor([x["src_time"] for x in src_dataset])
        tgt_times = torch.tensor([x["src_time"] for x in tgt_dataset])
        time_diff = tgt_times.unsqueeze(1) - src_times.unsqueeze(0)  # [T, S]
        # calc_time_sample_prob stores float probabilities in a tensor of the
        # input's dtype and torch.multinomial needs float weights, so integral
        # times must be promoted first.
        if not time_diff.is_floating_point():
            time_diff = time_diff.float()
        batch_size = 2000  # rows per chunk, bounds peak memory
        chunks = []
        for start in range(0, time_diff.size(0), batch_size):
            batch = time_diff[start:start + batch_size]
            sample_probs = calc_time_sample_prob(batch).view(batch.size())
            chunks.append(torch.multinomial(sample_probs, sample_src_size))
        src_index = torch.cat(chunks, dim=0)
        assert src_index.size(0) == len(tgt_times)

    src_dataset_new = [src_dataset[x.item()] for x in src_index.view(-1)]
    tgt_dataset_new = [tgt_dataset[i] for i in range(len(tgt_dataset)) for _ in range(src_index.size(-1))]
    new_datasets = []
    for src, tgt in zip(src_dataset_new, tgt_dataset_new):
        # Copy: the same source may be sampled for several targets. Writing
        # into the shared record would make every such pair carry the target
        # that was assigned last (and would mutate the caller's dataset).
        pair = copy.copy(src)
        pair["tgt_seq"] = tgt["src_seq"]
        pair["tgt_time"] = tgt["src_time"]
        new_datasets.append(pair)

    return new_datasets, src_index

def build_fake_multisource_pairwise_dataset(src_dataset, tgt_dataset, src_num, ignore_time=False, src_index=None, sample_src_size=5):
    """Build ``sample_src_size`` datasets that pair each target with ``src_num`` sources.

    Args:
        src_dataset: indexable collection of mapping records with ``"src_time"``.
        tgt_dataset: iterable of mapping records with ``"src_seq"`` and ``"src_time"``.
        src_num: number of source records grouped with each target.
        ignore_time: if True, sample sources uniformly; otherwise weight them
            with ``calc_time_sample_prob`` of the target-minus-source time gap.
        src_index: optional precomputed index tensor reshapeable to
            ``[T, sample_src_size, src_num]``; sampled when ``None``.
        sample_src_size: number of independent source groups per target.

    Returns:
        ``(new_datasets, src_index)``. ``new_datasets[i]`` is a list with one
        ``{"src": [records...], "tgt_seq": ..., "tgt_time": ...}`` dict per
        target. ``src_index`` is the ``[T, sample_src_size, src_num]`` view.
    """
    if src_index is None:
        src_times = torch.tensor([x["src_time"] for x in src_dataset])
        tgt_times = torch.tensor([x["src_time"] for x in tgt_dataset])
        time_diff = tgt_times.unsqueeze(1) - src_times.unsqueeze(0)  # [T, S]
        # torch.multinomial only accepts floating-point weights, and
        # calc_time_sample_prob writes float probabilities into a tensor of
        # the input's dtype: promote integral times (e.g. month indices).
        if not time_diff.is_floating_point():
            time_diff = time_diff.float()
        batch_size = 2000  # rows per chunk, bounds peak memory
        chunks = []
        for start in range(0, time_diff.size(0), batch_size):
            batch = time_diff[start:start + batch_size]
            if ignore_time:
                sample_probs = batch.new_ones(batch.size())
            else:
                sample_probs = calc_time_sample_prob(batch).view(batch.size())
            # Drawn without replacement, so each row needs at least
            # sample_src_size * src_num sources with non-zero weight.
            chunks.append(torch.multinomial(sample_probs, sample_src_size * src_num))
        src_index = torch.cat(chunks, dim=0)
        assert src_index.size(0) == len(tgt_times)

    src_index = src_index.view(-1, sample_src_size, src_num)
    new_datasets = []
    for i in range(sample_src_size):
        src_groups = [[src_dataset[j.item()] for j in row] for row in src_index[:, i]]
        new_datasets.append([
            {"src": group, "tgt_seq": tgt["src_seq"], "tgt_time": tgt["src_time"]}
            for group, tgt in zip(src_groups, tgt_dataset)
        ])

    return new_datasets, src_index

def build_fake_multisource_batch_dataset(src_dataset, tgt_dataset, batch_size, src_num, ignore_time=False, sort_target=False):
    """Chunk targets into batches, each prefixed with ``src_num`` sampled sources.

    Args:
        src_dataset: indexable collection of mapping records with ``"src_time"``.
        tgt_dataset: sequence of mapping records with ``"src_time"``.
        batch_size: total records per batch (sources + targets); must exceed
            ``src_num``.
        src_num: number of source records placed at the front of each batch.
        ignore_time: if True, sample sources uniformly; otherwise weight them
            with ``calc_time_sample_prob`` of the time gap to the batch's
            first target.
        sort_target: if True, order targets by ``"src_time"`` before batching.

    Returns:
        List of batches. Each batch is ``src_num`` source records followed by
        up to ``batch_size - src_num`` consecutive target records.

    Raises:
        AssertionError: if ``src_num >= batch_size``.
    """
    assert src_num < batch_size, "src_num< batch_size"
    tgt_per_batch = batch_size - src_num

    if sort_target:
        tgt_dataset = sorted(tgt_dataset, key=lambda x: x["src_time"])

    src_times = torch.tensor([x["src_time"] for x in src_dataset])
    tgt_times = torch.tensor([x["src_time"] for x in tgt_dataset])
    time_diff = tgt_times.unsqueeze(1) - src_times.unsqueeze(0)  # [T, S]
    # torch.multinomial only accepts floating-point weights, and
    # calc_time_sample_prob writes float probabilities into a tensor of the
    # input's dtype: promote integral times (e.g. month indices).
    if not time_diff.is_floating_point():
        time_diff = time_diff.float()

    num_targets = time_diff.size(0)
    batched_new_dataset = []
    for start in range(0, num_targets, tgt_per_batch):
        # Sources are drawn once per batch from the first target's time gaps.
        first_row = time_diff[start].unsqueeze(0)  # [1, S]
        if ignore_time:
            sample_probs = first_row.new_ones(first_row.size())
        else:
            sample_probs = calc_time_sample_prob(first_row).view(first_row.size())
        samples = torch.multinomial(sample_probs, src_num)[0]  # without replacement
        sources = [src_dataset[j.item()] for j in samples]
        targets = [tgt_dataset[j] for j in range(start, min(start + tgt_per_batch, num_targets))]
        batched_new_dataset.append(sources + targets)
    return batched_new_dataset

def build_fake_pairwise_dataset_2(dataset, time_diff=1, src_index=None, sample_src_size=5, tgt_index=None):
    """Pair each record with records exactly ``time_diff`` earlier in the same dataset.

    For every distinct ``"src_time"`` t (except the smallest) that has records
    at ``t - time_diff``, each record at t is paired with ``sample_src_size``
    records drawn uniformly, with replacement, from time ``t - time_diff``.

    Args:
        dataset: indexable collection of mapping records with ``"src_seq"``
            and ``"src_time"``.
        time_diff: time gap between a source and its target.
        src_index: optional precomputed source index tensor; requires
            ``tgt_index`` of the same shape.
        sample_src_size: number of sources per target when sampling.
        tgt_index: target index tensor matching ``src_index``. It is only used
            (and required) when ``src_index`` is given.

    Returns:
        ``(new_datasets, src_index)``. ``new_datasets`` holds shallow copies of
        the source records extended with the target's ``"tgt_seq"`` and
        ``"tgt_time"``. The input records are not modified. If no time has a
        matching earlier time, returns ``([], empty [0, sample_src_size] tensor)``.

    Raises:
        ValueError: if ``src_index`` is given without ``tgt_index``.
    """
    import copy

    if src_index is None:
        time2index = defaultdict(list)
        for i in range(len(dataset)):
            time2index[dataset[i]["src_time"]].append(i)
        src_chunks, tgt_chunks = [], []
        for tgt_time in sorted(time2index)[1:]:
            src_time = tgt_time - time_diff
            if src_time not in time2index:
                continue
            src_pool = time2index[src_time]
            tgt_ids = time2index[tgt_time]
            samples = torch.multinomial(torch.ones(len(src_pool)), sample_src_size * len(tgt_ids), replacement=True)
            sampled_index = torch.tensor([src_pool[x.item()] for x in samples]).view(-1, sample_src_size)
            src_chunks.append(sampled_index)
            # [len(tgt_ids), sample_src_size]: each target repeated once per source
            tgt_chunks.append(torch.tensor(tgt_ids).unsqueeze(-1).repeat(1, sample_src_size))
        if not src_chunks:
            # torch.cat rejects an empty list; report "no pairs" explicitly.
            return [], torch.empty((0, sample_src_size), dtype=torch.long)
        src_index = torch.cat(src_chunks, dim=0)
        tgt_index = torch.cat(tgt_chunks, dim=0)
    elif tgt_index is None:
        raise ValueError("tgt_index must be provided together with src_index")

    new_datasets = []
    for s, t in zip(src_index.view(-1), tgt_index.view(-1)):
        tgt = dataset[t.item()]
        # Copy: sampling uses replacement, so one source record can back many
        # pairs; writing into it would leave all of them with the last target.
        pair = copy.copy(dataset[s.item()])
        pair["tgt_seq"] = tgt["src_seq"]
        pair["tgt_time"] = tgt["src_time"]
        new_datasets.append(pair)

    return new_datasets, src_index


def normalize_evolution_time_by_edit_distance(dataset):
    """Attach a ``"modified_evolution_time"`` proportional to each record's edit distance.

    Fits a least-squares line of ``tgt_time - src_time`` against ``src_dis``
    and stores ``rate * src_dis`` for every record via
    ``dataset.add_attributes``. The fitted intercept is discarded (the offset
    is forced to 0.0), so zero edit distance maps to zero evolution time.

    Args:
        dataset: iterable of mapping records with ``"src_dis"``,
            ``"src_time"`` and ``"tgt_time"``, exposing ``add_attributes``.
    """
    # np.float was removed in NumPy 1.24; the builtin float is the same dtype.
    edit_distance = np.asarray([x["src_dis"] for x in dataset], dtype=float)
    evolution_time_diff = np.asarray([x["tgt_time"] - x["src_time"] for x in dataset], dtype=float)
    # np.polyfit returns coefficients highest degree first: [slope, intercept].
    rate, _intercept = np.polyfit(edit_distance, evolution_time_diff, 1)
    offset = 0.0
    dataset.add_attributes("modified_evolution_time", [offset + x * rate for x in edit_distance])


def split_dataset_by_sources_or_targets(src_dataset, full_dataset, valid_size):
    """Split ``full_dataset`` into train/valid lists so no group key lands in both.

    ``src_dataset[i][0]`` is the grouping key of ``full_dataset[i]``. The
    sorted distinct keys are shuffled with ``torch.randperm``. The last
    ``int(n_keys * valid_size)`` keys of that permutation go to validation and
    the rest to training. Rows are emitted group by group in permutation
    order, keeping their original order inside each group.

    Returns:
        ``(train_dataset, val_dataset)`` lists of ``full_dataset`` items.
    """
    rows_by_key = defaultdict(list)
    for row, item in enumerate(src_dataset):
        rows_by_key[item[0]].append(row)
    keys = sorted(rows_by_key)

    n_valid = int(len(keys) * valid_size)
    n_train = len(keys) - n_valid
    order = torch.randperm(len(keys)).tolist()

    train_keys = [keys[k] for k in order[:n_train]]
    valid_keys = [keys[k] for k in order[n_train:]]
    train_dataset = [full_dataset[r] for key in train_keys for r in rows_by_key[key]]
    val_dataset = [full_dataset[r] for key in valid_keys for r in rows_by_key[key]]
    return train_dataset, val_dataset


def fit_rms(ref_c, c):
    """Compute the rotation that best superimposes ``c`` onto ``ref_c`` (Kabsch).

    Args:
        ref_c: reference coordinates, assumed to be an ``(N, D)`` array of N points.
        c: coordinates to rotate onto ``ref_c``, same shape as ``ref_c``.

    Returns:
        ``(c_trans, U, ref_trans)``: the centroids of ``c`` and ``ref_c`` and
        a ``(D, D)`` proper rotation ``U`` such that ``(c - c_trans) @ U``
        matches ``ref_c - ref_trans`` in the least-squares sense.
    """
    # Move both geometric centres to the origin.
    ref_trans = np.average(ref_c, axis=0)
    ref_c = ref_c - ref_trans
    c_trans = np.average(c, axis=0)
    c = c - c_trans

    # Covariance matrix between the centred point sets, then its SVD.
    C = np.dot(c.T, ref_c)
    r1, _, r2 = np.linalg.svd(C)

    # Remove mirroring: flip the axis of the smallest singular value when
    # r1 @ r2 would be a reflection. det(r1) * det(r2) is used instead of
    # det(C) because it stays informative when C is singular (planar or
    # collinear points). r2[-1] works for any dimensionality, not only 3-D.
    if np.linalg.det(r1) * np.linalg.det(r2) < 0:
        r2[-1, :] *= -1.0
    U = np.dot(r1, r2)
    return (c_trans, U, ref_trans)


def calc_rmsd(xyz1, xyz2):
    """Return the RMSD between two coordinate sets after optimal superposition.

    ``xyz1`` is the reference. ``xyz2`` is centred and rotated with the
    rotation from ``fit_rms``; ``xyz1`` is only centred (the reference
    centroid is not added back).

    Returns:
        ``(rmsd, new_c1, new_c2)``: the RMSD, the centred reference, and the
        centred and rotated second set.
    """
    centroid2, rotation, centroid1 = fit_rms(xyz1, xyz2)
    aligned_ref = xyz1 - centroid1
    aligned_mobile = (xyz2 - centroid2) @ rotation
    squared_dist = ((aligned_ref - aligned_mobile) ** 2).sum(axis=1)
    return np.sqrt(squared_dist.mean()), aligned_ref, aligned_mobile


class RMSDcalculator:
    """Superimpose coordinate sets with a rotation fitted at construction.

    The constructor fits ``xyz2`` onto the reference ``xyz1`` with
    ``fit_rms`` and stores the result on the instance:
    ``U`` (rotation), ``c_trans`` / ``ref_trans`` (centroids of ``xyz2`` /
    ``xyz1``) and ``rmsd`` (RMSD after superposition).
    """

    def __init__(self, xyz1, xyz2):
        # xyz1 is the reference structure; calc_rmsd stores the fit on self.
        self.calc_rmsd(xyz1, xyz2)

    def apply(self, xyz):
        """Centre ``xyz`` on its own mean and rotate it with the fitted ``U``.

        The stored centroids are not used, so the result is centred at the
        origin.
        """
        xyz = xyz - xyz.mean(axis=0, keepdims=True)
        return np.dot(xyz, self.U)

    def calc_rmsd(self, xyz1, xyz2):
        """Fit ``xyz2`` onto ``xyz1``, store the fit on ``self`` and return ``(rmsd, new_c1, new_c2)``."""
        c_trans, U, ref_trans = fit_rms(xyz1, xyz2)
        new_c2 = np.dot(xyz2 - c_trans, U)  # reference centroid not added back
        new_c1 = xyz1 - ref_trans
        rmsd = np.sqrt(np.average(np.sum((new_c1 - new_c2) ** 2, axis=1)))
        self.U = U
        self.ref_trans = ref_trans
        # Previously discarded; kept so callers can read the fit results.
        self.c_trans = c_trans
        self.rmsd = rmsd
        return rmsd, new_c1, new_c2