# Modified from DPLM-2:
#    DPLM-2: https://github.com/bytedance/dplm/blob/main/src/byprot/utils/scaffold_utils.py


import os
import random
from copy import deepcopy
from pprint import pprint

import esm
import esm.inverse_folding
import torch
import numpy as np
import pandas as pd

from byprot import utils
from byprot.datamodules.dataset.data_utils import PDBDataProcessor

# Reference motif PDB codes whose motif segments are all single residues.
# create_idxs_list() tests membership here and, on a hit, calls
# get_intervals() with single_res_domain=True so that every motif position is
# reported as its own one-residue segment.
single_res = ["1qjg"]

# Length of the scaffold on the left side of the motif, per scaffolding task.
# Each value is an inclusive [min, max] range: create_init_seq() draws the
# left scaffold length with random.randint(min, max).
scaffold_left = {
    "1PRW": [5, 20],
    "1BCF": [8, 15],
    "5TPN": [10, 40],
    "5IUS": [0, 30],
    "3IXT": [10, 40],
    "5YUI": [5, 30],
    "1QJG": [10, 20],
    "1YCR": [10, 40],
    "2KL8": [0, 0],
    "7MRX_60": [0, 38],
    "7MRX_85": [0, 63],
    "7MRX_128": [0, 122],
    "4JHW": [10, 25],
    "4ZYP": [10, 40],
    "5WN9": [10, 40],
    "5TRV_short": [0, 35],
    "5TRV_med": [0, 65],
    "5TRV_long": [0, 95],
    "6E6R_short": [0, 35],
    "6E6R_med": [0, 65],
    "6E6R_long": [0, 95],
    "6EXZ_short": [0, 35],
    "6EXZ_med": [0, 65],
    "6EXZ_long": [0, 95],
}

# Length of the scaffold on the right side of the motif, per scaffolding task.
# Each value is an inclusive [min, max] range. create_init_seq() samples from
# it directly when total_length[task] == -1; otherwise the range is first
# intersected with the allowed total length for the task.
scaffold_right = {
    "1PRW": [5, 20],
    "1BCF": [8, 15],
    "5TPN": [10, 40],
    "5IUS": [0, 30],
    "3IXT": [10, 40],
    "5YUI": [10, 30],
    "1QJG": [10, 20],
    "1YCR": [10, 40],
    "2KL8": [0, 0],
    "7MRX_60": [0, 38],
    "7MRX_85": [0, 63],
    "7MRX_128": [0, 122],
    "4JHW": [10, 25],
    "4ZYP": [10, 40],
    "5WN9": [10, 40],
    "5TRV_short": [0, 35],
    "5TRV_med": [0, 65],
    "5TRV_long": [0, 95],
    "6E6R_short": [0, 35],
    "6E6R_med": [0, 65],
    "6E6R_long": [0, 95],
    "6EXZ_short": [0, 35],
    "6EXZ_med": [0, 65],
    "6EXZ_long": [0, 95],
}

# Maps each scaffolding task name to the reference motif PDB code.
# Keys are the task names used by scaffold_left / scaffold_right /
# scaffold_interval / total_length; values are the lower-case codes used as
# keys of start_idx_dict / end_idx_dict / chain_dict. Several tasks may share
# one reference motif (e.g. the 7MRX_* and *_short/_med/_long variants).
motif_name_mapping = {
    "1PRW": "1prw",
    "1BCF": "1bcf",
    "5TPN": "5tpn",
    "5IUS": "5ius",
    "3IXT": "3ixt",
    "5YUI": "5yui",
    "1QJG": "1qjg",
    "1YCR": "1ycr",
    "2KL8": "2kl8",
    "7MRX_60": "7mrx",
    "7MRX_85": "7mrx",
    "7MRX_128": "7mrx",
    "4JHW": "4jhw",
    "4ZYP": "4zyp",
    "5WN9": "5wn9",
    "5TRV_short": "5trv",
    "5TRV_med": "5trv",
    "5TRV_long": "5trv",
    "6E6R_short": "6e6r",
    "6E6R_med": "6e6r",
    "6E6R_long": "6e6r",
    "6EXZ_short": "6exz",
    "6EXZ_med": "6exz",
    "6EXZ_long": "6exz",
}
# Spacer lengths between consecutive motif segments, per scaffolding task.
# Entry k is the inclusive [min, max] range for the gap between motif segment
# k and segment k + 1; get_motif() samples each gap with
# random.randint(min, max). Only tasks whose reference motif has more than
# one segment are looked up here.
scaffold_interval = {
    "1PRW": [[10, 25]],
    "1BCF": [[16, 30], [16, 30], [16, 30]],
    "5IUS": [[15, 40]],
    "5YUI": [[5, 20], [10, 35]],
    "1QJG": [[15, 30], [15, 30]],
    "2KL8": [[20, 20]],
    "4JHW": [[15, 30]],
}
# Allowed total sequence length (left scaffold + motif incl. spacers + right
# scaffold, special tokens not counted) per scaffolding task, as an inclusive
# [min, max] range. The sentinel -1 means "no total-length constraint":
# create_init_seq() then samples the right scaffold from scaffold_right alone.
total_length = {
    "1PRW": -1,
    "1BCF": -1,
    "5TPN": [50, 75],
    "5IUS": -1,
    "3IXT": [50, 75],
    "5YUI": [50, 100],
    "1QJG": -1,
    "1YCR": [40, 100],
    "2KL8": -1,
    "7MRX_60": [60, 60],
    "7MRX_85": [85, 85],
    "7MRX_128": [128, 128],
    "4JHW": [60, 90],
    "4ZYP": [30, 50],
    "5WN9": [35, 50],
    "5TRV_short": [56, 56],
    "5TRV_med": [86, 86],
    "5TRV_long": [116, 116],
    "6E6R_short": [48, 48],
    "6E6R_med": [78, 78],
    "6E6R_long": [108, 108],
    "6EXZ_short": [50, 50],
    "6EXZ_med": [80, 80],
    "6EXZ_long": [110, 110],
}

# Start index of every motif segment, keyed by reference motif PDB code.
# get_motif() uses these as slice starts into prot_dict["aatype"] and
# prot_dict["latent"] and concatenates the segments in the order listed here
# (which is not necessarily ascending).
# If the motif has a single domain, len(start_idx_dict[pdb]) == 1;
# if the motif has multiple domains, len(start_idx_dict[pdb]) > 1.
start_idx_dict = {
    "1prw": [15, 51],
    "1bcf": [90, 122, 46, 17],
    "5tpn": [108],
    "3ixt": [0],
    "4jhw": [144, 37],
    "4zyp": [357],
    "5wn9": [1],
    "5ius": [88, 34],
    "5yui": [89, 114, 194],
    "6vw1": [5, 45],
    "1qjg": [37, 13, 98],
    "1ycr": [2],
    "2kl8": [0, 27],
    "7mrx": [25],
    "5trv": [45],
    "6e6r": [22],
    "6exz": [25],
}

# End index of every motif segment, keyed by reference motif PDB code and
# aligned position-by-position with start_idx_dict. The values are inclusive:
# get_motif() adds 1 to each before using it as a slice bound.
# If the motif has a single domain, len(end_idx_dict[pdb]) == 1;
# if the motif has multiple domains, len(end_idx_dict[pdb]) > 1.
end_idx_dict = {
    "1prw": [34, 70],
    "1bcf": [98, 129, 53, 24],
    "5tpn": [126],
    "3ixt": [23],
    "4jhw": [159, 43],
    "4zyp": [371],
    "5wn9": [20],
    "5ius": [109, 53],
    "5yui": [93, 116, 196],
    "6vw1": [23, 63],
    "1qjg": [37, 13, 98],
    "1ycr": [10],
    "2kl8": [6, 78],
    "7mrx": [46],
    "5trv": [69],
    "6e6r": [34],
    "6exz": [39],
}

# Chain identifier for each reference motif PDB code (the chain that was
# extracted from the PDB file).
# NOTE(review): not referenced by any function in the visible part of this
# file - presumably consumed by callers elsewhere; confirm before removing.
chain_dict = {
    "1prw": "A",
    "1bcf": "A",
    "5tpn": "A",
    "3ixt": "P",
    "4jhw": "F",
    "4zyp": "A",
    "5wn9": "A",
    "5ius": "A",
    "5yui": "A",
    "6vw1": "A",
    "1qjg": "A",
    "1ycr": "B",
    "2kl8": "A",
    "7mrx": "B",
    "5trv": "A",
    "6e6r": "A",
    "6exz": "A",
}


def get_intervals(list, single_res_domain=False):
    """Turn sorted motif positions into per-segment start and end indices.

    A segment is a maximal run of consecutive positions. Unlike a scan that
    only reacts to gaps after the second element, this handles a leading
    one-residue segment and a single-position input correctly.

    Args:
        list: 1-D sequence (typically a ``torch.Tensor``) of non-masked
            residue positions in ascending order. The name shadows the
            builtin; it is kept so existing keyword callers keep working.
        single_res_domain: When True, every position is reported as its own
            one-residue segment regardless of adjacency.

    Returns:
        A ``(start, stop)`` tuple of two equal-length lists holding the first
        and last (inclusive) position of every segment.
    """
    # Unwrap 0-d tensors / numpy scalars once; plain Python numbers pass through.
    positions = [p.item() if hasattr(p, "item") else p for p in list]

    if single_res_domain:
        # Return an independent copy so mutating one list cannot corrupt the other.
        return positions, positions.copy()

    start = []
    stop = []
    last = len(positions) - 1
    for i, pos in enumerate(positions):
        # A segment opens at the first position or right after a gap ...
        if i == 0 or positions[i - 1] + 1 != pos:
            start.append(pos)
        # ... and closes at the last position or right before a gap.
        if i == last or pos + 1 != positions[i + 1]:
            stop.append(pos)
    return start, stop

# One-letter residue codes: 20 amino-acid letters followed by 'X' (21 entries).
# get_motif() indexes this array with prot_dict["aatype"] to turn residue
# type indices into letters.
restype_with_x = np.array(['A', 'R', 'N', 'D', 'C', 'Q', 'E', \
                           'G', 'H', 'I', 'L', 'K', 'M', 'F', \
                            'P', 'S', 'T', 'W', 'Y', 'V', 'X'])

def get_motif(pdb_name, ori_pdb_name, prot_dict, mask_token):
    """Assemble the motif sequence and structure latent for one sample.

    The motif segments listed in ``start_idx_dict`` / ``end_idx_dict`` are cut
    out of the protein and concatenated in listed order. Between consecutive
    segments a spacer of random length (drawn from ``scaffold_interval``) is
    inserted: mask tokens in the sequence, all-zero rows in the structure.

    Args:
        pdb_name: Scaffolding task name; key into ``scaffold_interval``
            (only looked up when the motif has more than one segment).
        ori_pdb_name: Reference motif PDB code; key into ``start_idx_dict``
            and ``end_idx_dict``.
        prot_dict: Mapping with ``"aatype"`` (used to index
            ``restype_with_x``) and ``"latent"`` (per-residue structure
            features, sliced along its first axis).
        mask_token: Token string used for the spacer positions.

    Returns:
        ``(motif_seq, motif_struct)``: a list of residue letters / mask
        tokens (not yet tokenized) and a tensor with one row per entry of
        ``motif_seq``.
    """
    start_idxs = start_idx_dict[ori_pdb_name]
    end_idxs = end_idx_dict[ori_pdb_name]
    assert len(start_idxs) == len(end_idxs)

    sequence = restype_with_x[prot_dict["aatype"]]
    structure = prot_dict["latent"]

    # Stored end indices are inclusive; +1 turns them into slice bounds.
    end_idxs = [i + 1 for i in end_idxs]
    motif_seq = list(sequence[start_idxs[0] : end_idxs[0]])
    motif_struct = torch.tensor(structure[start_idxs[0] : end_idxs[0]])

    # Single-segment motifs skip this loop entirely, so no spacer is added.
    for i in range(1, len(start_idxs)):
        interval = scaffold_interval[pdb_name][i - 1]
        spacer_num = random.randint(interval[0], interval[1])

        motif_seq += [mask_token] * spacer_num
        # Zero rows for the spacer. The feature width follows the motif tensor
        # instead of a hard-coded 20, so other latent sizes concatenate too.
        spacer_struct = torch.zeros(
            (spacer_num, *motif_struct.shape[1:]), dtype=motif_struct.dtype
        )
        motif_struct = torch.cat((motif_struct, spacer_struct))

        motif_seq += list(sequence[start_idxs[i] : end_idxs[i]])
        motif_struct = torch.cat(
            (motif_struct, torch.tensor(structure[start_idxs[i] : end_idxs[i]]))
        )

    # The sequence is returned as raw tokens; tokenization happens in the caller.
    return motif_seq, motif_struct

# ====================================================================
# =============== For Our Construct Motif-Scaffolding ================
# ====================================================================

def get_initial_prot(prot_dict, tokenizer, pdb, ori_pdb, num, cover_ori_motif, device):
    """Build the initial motif-scaffolding inputs for ``num`` samples.

    Runs three steps: sample the masked sequences and structure latents
    around the motif, collate them into batches, and locate the motif inside
    every sampled sequence.

    Returns:
        ``(batches, start_idxs_list, end_idxs_list, scaffold_length_list)``
        as produced by ``create_batches``, ``create_idxs_list`` and
        ``create_init_seq`` respectively.
    """
    # 1) motif tokens with randomly sized mask-token scaffolds on both sides
    sampled_seqs, sampled_structs, scaffold_lengths = create_init_seq(
        pdb, ori_pdb, prot_dict, tokenizer, num
    )

    # 2) collate the samples into batches placed on `device`
    batches = create_batches(
        sampled_seqs, sampled_structs, tokenizer, cover_ori_motif, device
    )

    # 3) motif start / end indices inside every sampled sequence
    motif_starts, motif_ends = create_idxs_list(ori_pdb, tokenizer, sampled_seqs)

    return batches, motif_starts, motif_ends, scaffold_lengths

def create_batches(init_seq, init_struct, tokenizer, cover_ori_motif, device):
    """Group samples by length and collate every group into one batch.

    Args:
        init_seq: List of token-id lists (one per sample, special tokens
            included).
        init_struct: List of structure tensors aligned with ``init_seq``;
            the first axis is the sequence axis.
        tokenizer: Object exposing ``mask_token_id`` and ``pad_token_id``.
        cover_ori_motif: If True, samples are grouped by full length minus
            the two special tokens; otherwise by their number of mask tokens.
        device: Device for the collated tensors.

    Returns:
        A list of dicts, one per distinct group length in ascending order,
        with keys ``"idxs"`` (sample indices into ``init_seq``),
        ``"seq_len"`` (the grouping length), ``"input_ids"`` (long tensor,
        right-padded with the pad id) and ``"struct_latent"`` (stacked
        structures, right-padded with zero rows).
    """
    mask_id = tokenizer.mask_token_id
    pad_id = tokenizer.pad_token_id

    if cover_ori_motif:
        # Every position counts, except the two special tokens.
        seq_lens = [len(seq) - 2 for seq in init_seq]
    else:
        # Only the positions that are still masked count.
        seq_lens = [(np.array(seq) == mask_id).sum() for seq in init_seq]
    seq_len_series = pd.Series(seq_lens)

    batches = []
    # Grouping the series by its own values yields one group per length.
    for seq_len, group in seq_len_series.groupby(seq_len_series):
        group_idxs = group.index.tolist()
        batch_len = max(len(init_seq[i]) for i in group_idxs)

        batch_init_seq = []
        batch_init_struct = []
        for i in group_idxs:
            seq = init_seq[i]
            struct = init_struct[i]
            # Concatenation builds new lists / tensors; the inputs stay untouched.
            batch_init_seq.append(seq + [pad_id] * (batch_len - len(seq)))
            # Padding follows the tensor's own trailing shape, dtype and device
            # (a fixed width of 20 breaks torch.cat for any other latent size,
            # even when zero rows are added).
            pad_struct = torch.zeros(
                (batch_len - struct.shape[0], *struct.shape[1:]),
                dtype=struct.dtype,
                device=struct.device,
            )
            batch_init_struct.append(torch.cat((struct, pad_struct), dim=0))

        batches.append(
            {
                "idxs": group_idxs,
                "seq_len": seq_len,
                "input_ids": torch.tensor(
                    batch_init_seq, dtype=torch.long, device=device
                ),
                "struct_latent": torch.stack(batch_init_struct).to(device),
            }
        )

    return batches

def create_init_seq(pdb, ori_pdb, prot_dict, tokenizer, num):
    """Sample ``num`` initial sequences and structure latents around a motif.

    For every sample a left scaffold length and a motif (with random spacers)
    are drawn; the right scaffold length is then drawn so that the total
    length respects ``total_length[pdb]`` when that is not ``-1``. Draws with
    no admissible total length are rejected and repeated.

    Args:
        pdb: Scaffolding task name; key into ``scaffold_left``,
            ``scaffold_right`` and ``total_length``.
        ori_pdb: Reference motif PDB code, forwarded to ``get_motif``.
        prot_dict: Protein features, forwarded to ``get_motif``.
        tokenizer: Object exposing ``mask_token`` and
            ``encode(text, add_special_tokens=True)``; the encoding is
            asserted to add exactly two tokens to the sequence length.
        num: Number of samples to generate.

    Returns:
        ``(init_seq, init_struct, scaffold_length_list)``: the encoded
        sequences, the matching structure tensors (zero rows at the special
        token and scaffold positions, scaled by 0.1875) and, per sample,
        left + right scaffold length.
    """
    mask_token = tokenizer.mask_token

    init_seq = []
    init_struct = []
    scaffold_length_list = []
    for _ in range(num):
        # Rejection sampling: redraw left scaffold and motif until a right
        # scaffold length compatible with the total-length constraint exists.
        while True:
            scaffold_left_length = random.randint(
                scaffold_left[pdb][0], scaffold_left[pdb][1]
            )
            motif_seq, motif_struct = get_motif(
                pdb_name=pdb,
                ori_pdb_name=ori_pdb,
                prot_dict=prot_dict,
                mask_token=mask_token,
            )
            motif_overall_length = len(motif_seq)
            fixed_length = scaffold_left_length + motif_overall_length

            if total_length[pdb] == -1:
                # No total-length constraint: sample the right side freely.
                scaffold_right_length = random.randint(
                    scaffold_right[pdb][0], scaffold_right[pdb][1]
                )
                break

            # Intersect the reachable total lengths with the allowed ones.
            shortest = max(
                fixed_length + scaffold_right[pdb][0], total_length[pdb][0]
            )
            longest = min(
                fixed_length + scaffold_right[pdb][1], total_length[pdb][1]
            )
            if shortest <= longest:
                scaffold_right_length = (
                    random.randint(shortest, longest) - fixed_length
                )
                break

        overall_length = fixed_length + scaffold_right_length
        scaffold_length_list.append(scaffold_left_length + scaffold_right_length)

        ## motif aa seq initialization
        seq = "".join(
            [mask_token] * scaffold_left_length
            + motif_seq
            + [mask_token] * scaffold_right_length
        )
        seq = tokenizer.encode(seq, add_special_tokens=True)
        assert len(seq) == (overall_length + 2)
        init_seq.append(seq)

        # Zero rows for the scaffolds plus one extra row on each side for the
        # two special tokens. The feature width follows motif_struct instead
        # of a hard-coded 20, so other latent sizes concatenate too.
        latent_shape = tuple(motif_struct.shape[1:])
        struct = torch.cat(
            (
                torch.zeros(
                    (scaffold_left_length + 1, *latent_shape),
                    dtype=motif_struct.dtype,
                ),
                motif_struct,
                torch.zeros(
                    (scaffold_right_length + 1, *latent_shape),
                    dtype=motif_struct.dtype,
                ),
            )
        )
        # In-place scaling of the freshly concatenated tensor (constant kept
        # as-is from the original implementation).
        struct.mul_(0.1875)
        assert struct.shape[0] == (overall_length + 2)
        init_struct.append(struct)

    # seq is a list processed by the tokenizer, struct is a tensor
    return init_seq, init_struct, scaffold_length_list


def create_idxs_list(pdb, tokenizer, init_seq):
    """Locate the motif segments inside every encoded sequence.

    Every token that is not mask, BOS (the tokenizer's ``cls_token``), EOS
    or pad counts as a motif position. Positions are shifted by -1 and then
    grouped into segments by ``get_intervals``.

    Args:
        pdb: Motif PDB code; membership in ``single_res`` switches
            ``get_intervals`` to single-residue mode.
        tokenizer: Object exposing ``mask_token``, ``cls_token``,
            ``eos_token``, ``pad_token_id`` and ``added_tokens_encoder``.
        init_seq: Iterable of token-id lists.

    Returns:
        ``(start_idxs_list, end_idxs_list)`` with one list of segment starts
        and one list of segment ends per input sequence.
    """
    is_single_res = pdb in single_res

    # ids of the special tokens that must not be counted as motif residues
    pad_id = tokenizer.pad_token_id
    special_lookup = tokenizer.added_tokens_encoder
    mask_id = special_lookup[tokenizer.mask_token]
    bos_id = special_lookup[tokenizer.cls_token]
    eos_id = special_lookup[tokenizer.eos_token]

    start_idxs_list, end_idxs_list = [], []
    for token_ids in init_seq:
        ids = torch.tensor(token_ids, dtype=torch.long)
        is_special = (
            (ids == mask_id) | (ids == bos_id) | (ids == eos_id) | (ids == pad_id)
        )
        # shift by one so that indices do not count the leading special token
        motif_positions = torch.nonzero(~is_special).flatten() - 1
        seg_starts, seg_ends = get_intervals(
            motif_positions, single_res_domain=is_single_res
        )
        start_idxs_list.append(seg_starts)
        end_idxs_list.append(seg_ends)

    return start_idxs_list, end_idxs_list

