# Data Processing following TankBind
import os

import numpy as np
import pandas as pd
import scipy
import torch
from Bio.PDB import PDBParser
from rdkit import Chem
from torch_geometric.data import Dataset, HeteroData
from torch_geometric.utils import to_dense_adj
from tqdm import tqdm
import lmdb
import pickle

from feature import get_protein_feature, extract_torchdrug_feature_from_mol, get_keepNode, \
    get_protein_edge_features_and_index


#adj - > n_hops connections adj
from utils import read_molecule, uniform_random_rotation


def construct_data_from_graph_gvp(protein_coords, protein_seq, protein_node_s,
                                  protein_node_v, protein_edge_index, protein_edge_s, protein_edge_v,
                                  compound_coords, compound_node_features, input_atom_edge_list,
                                  input_atom_edge_attr_list, LAS_edge_index, rdkit_coords=None, setting=None,
                                  contactCutoff=8.0, pocket_radius=20,
                                  add_noise_to_com=None, use_whole_protein=False, use_compound_com_as_pocket=True, chosen_pocket_com=None,
                                  ):
    """Build the pocket/compound ``HeteroData`` graph for one complex.

    Residues are selected with ``get_keepNode``, protein features and edges are
    restricted to that selection, and compound features plus an optional
    initial pose are attached.

    Args:
        protein_coords: torch tensor of residue coordinates, indexed by the
            boolean ``keepNode`` mask (presumably ``[n_residue, 3]`` -- confirm).
        protein_seq, protein_node_s, protein_node_v: per-residue data indexed
            by the same mask.
        protein_edge_index, protein_edge_s, protein_edge_v: protein edges and
            edge features, filtered by ``get_protein_edge_features_and_index``.
        compound_coords: ligand atom coordinates; must support
            ``.mean(axis=0)`` and be accepted by ``scipy``'s ``cdist``
            (presumably a ``[n_atom, 3]`` NumPy array -- confirm).
        compound_node_features: stored as ``data['compound'].x``.
        input_atom_edge_list: tensor whose first two columns are used as the
            ``c2c`` edge endpoints.
        input_atom_edge_attr_list: stored as the ``c2c`` edge attributes.
        LAS_edge_index: stored as the ``LAS`` edge index.
        rdkit_coords: optional conformer coordinates. When given they are
            passed through ``uniform_random_rotation`` and stored as
            ``data['compound'].rdkit_coords``; when ``None`` that field is
            omitted.
        setting: how ``data['compound'].init_coords`` is produced. One of
            ``"perturb 3A"``, ``"perturb 4A"``, ``"perturb 5A"``,
            ``"compound center"``, ``"pocket center"`` or
            ``"pocket center from rdkit"``. Any other value (including
            ``None``) leaves ``init_coords`` unset.
        contactCutoff: distance threshold below which a residue/atom pair is
            labelled as a contact in ``data.y``.
        pocket_radius, add_noise_to_com, use_whole_protein,
        use_compound_com_as_pocket, chosen_pocket_com: forwarded unchanged to
            ``get_keepNode``.

    Returns:
        Tuple ``(data, input_node_xyz, keepNode)``: the graph, the coordinates
        of the kept residues, and the residue mask.

    Raises:
        ValueError: if ``setting`` is ``"pocket center from rdkit"`` but
            ``rdkit_coords`` is ``None``.
    """
    n_protein_node = protein_coords.shape[0]
    # Centroid (unweighted mean), not a mass-weighted centre.
    compound_center = compound_coords.mean(axis=0)
    keepNode = get_keepNode(compound_center, protein_coords.numpy(), n_protein_node, pocket_radius, use_whole_protein,
                            use_compound_com_as_pocket, add_noise_to_com, chosen_pocket_com)

    if keepNode.sum() < 5:
        # Fewer than 5 residues selected: additionally keep the first 100 residues.
        keepNode[:100] = True
    input_node_xyz = protein_coords[keepNode]
    input_edge_idx, input_protein_edge_s, input_protein_edge_v = get_protein_edge_features_and_index(
        protein_edge_index, protein_edge_s, protein_edge_v, keepNode)

    data = HeteroData()

    # The contact labels are only meaningful when compound_coords is a real bound pose.
    dis_map = scipy.spatial.distance.cdist(input_node_xyz.cpu().numpy(), compound_coords)
    y_contact = dis_map < contactCutoff
    data.dis_map = torch.tensor(dis_map, dtype=torch.float).flatten()
    data.y = torch.tensor(y_contact, dtype=torch.float).flatten()

    # Protein side.
    data.seq = protein_seq[keepNode]
    data['protein'].coords = input_node_xyz
    data['protein'].node_s = protein_node_s[keepNode]  # [num_protein_nodes, num_protein_feature]
    data['protein'].node_v = protein_node_v[keepNode]
    data['protein', 'p2p', 'protein'].edge_index = input_edge_idx
    data['protein', 'p2p', 'protein'].edge_s = input_protein_edge_s
    data['protein', 'p2p', 'protein'].edge_v = input_protein_edge_v
    pocket_center = data['protein'].coords.mean(axis=0)

    # Compound side.
    data['compound'].x = compound_node_features
    true_coords = torch.tensor(compound_coords, dtype=torch.float)
    data['compound'].true_coords = true_coords
    if rdkit_coords is not None:
        # A fresh random rotation is applied on every call.
        rdkit_coords = uniform_random_rotation(rdkit_coords)
        data['compound'].rdkit_coords = torch.tensor(rdkit_coords, dtype=torch.float)

    # Initial compound pose. Each perturbation draws uniform noise in
    # [-scale, scale] per coordinate.
    perturb_scale = {"perturb 3A": 3, "perturb 4A": 4, "perturb 5A": 5}
    if setting in perturb_scale:
        data['compound'].init_coords = \
            perturb_scale[setting] * (2 * torch.rand(compound_coords.shape) - 1) + true_coords
    elif setting == "compound center":
        # Use a float32 tensor centroid so init_coords has the same dtype as
        # true_coords instead of mixing torch and NumPy operands.
        center = torch.as_tensor(compound_center, dtype=torch.float).reshape(1, 3)
        data["compound"].init_coords = 10 * (2 * torch.rand(compound_coords.shape) - 1) + center
    elif setting == "pocket center":
        data["compound"].init_coords = 5 * (2 * torch.rand(compound_coords.shape) - 1) + pocket_center.reshape(1, 3)
    elif setting == "pocket center from rdkit":
        if rdkit_coords is None:
            raise ValueError('setting "pocket center from rdkit" requires rdkit_coords')
        rdkit_tensor = data['compound'].rdkit_coords
        # Translate the rotated conformer so that its centroid sits on the pocket centre.
        data["compound"].init_coords = rdkit_tensor \
                                       - rdkit_tensor.mean(axis=0).reshape(1, 3) \
                                       + pocket_center.reshape(1, 3)

    data['compound', 'c2c', 'compound'].edge_index = input_atom_edge_list[:, :2].long().t().contiguous()
    data['compound', 'c2c', 'compound'].edge_weight = torch.ones(input_atom_edge_list.shape[0])
    data['compound', 'c2c', 'compound'].edge_attr = input_atom_edge_attr_list
    data['compound', 'LAS', 'compound'].edge_index = LAS_edge_index
    return data, input_node_xyz, keepNode

class PDBbind(Dataset):
    """PDBbind complexes cached in LMDB and served as ``HeteroData`` graphs.

    ``process`` reads every complex id listed in ``index_file`` (looked up in
    ``refine_folder`` first, then ``other_folder``), extracts ligand and protein
    features and pickles them into an LMDB environment keyed by the id.
    ``get`` turns one cached record into a graph with
    ``construct_data_from_graph_gvp``.
    """

    def __init__(self, root, refine_folder, other_folder, index_file, mode,
                 pocket_radius=20, add_noise_to_com=None,
                 transform=None, pre_transform=None, pre_filter=None, setting=None,
                 ):
        # All configuration is set before the base constructor runs, because
        # torch_geometric's Dataset constructor is expected to invoke process()
        # when the processed files are missing.
        self.index_list = []
        self.refine_folder = refine_folder
        self.other_folder = other_folder
        self.index_file = index_file
        self.mode = mode
        self.setting = setting
        self.add_noise_to_com = add_noise_to_com
        self.pocket_radius = pocket_radius
        super().__init__(root, transform, pre_transform, pre_filter)
        print(self.processed_paths)
        self.lmdb_env = lmdb.open(self.processed_paths[0], readonly=True, lock=False, readahead=False, meminit=False)
        self.index_list = torch.load(self.processed_paths[1])

    def process(self):
        """Featurise every complex in ``index_file`` and write it to LMDB.

        Complexes that fail are skipped and recorded in
        ``<root>/log_<mode>.txt``; the ids that succeeded are saved to the
        second processed path.
        """
        env = lmdb.open(self.processed_paths[0], map_size=int(1e11))
        txn = env.begin(write=True)
        parser = PDBParser(QUIET=True)
        log_path = os.path.join(self.root, f"log_{self.mode}.txt")

        # Read the ids up front so the index file is not held open while processing.
        with open(self.index_file, "r") as index_handle:
            index_lines = index_handle.readlines()

        for index in tqdm(index_lines):
            index = index.strip()
            if not index:
                # Blank line: there is no complex to look up.
                continue
            try:
                if os.path.isdir(os.path.join(self.refine_folder, index)):
                    root_folder = os.path.join(self.refine_folder, index)
                elif os.path.isdir(os.path.join(self.other_folder, index)):
                    root_folder = os.path.join(self.other_folder, index)
                else:
                    raise Exception("File Not Found")

                compound_file_sdf = os.path.join(root_folder, f"{index}_ligand.sdf")
                compound_file_mol2 = os.path.join(root_folder, f"{index}_ligand.mol2")
                compound_name = f"{index}_ligand"
                # Prefer the SDF file and fall back to MOL2.
                mol = read_molecule(compound_file_sdf, sanitize=True, remove_hs=True)
                if mol is None:
                    mol = read_molecule(compound_file_mol2, sanitize=True, remove_hs=True)
                if mol is None:
                    raise Exception("Read Mol Failed")
                compound_feat = extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)

                try:
                    protein_file = os.path.join(root_folder, f"{index}_protein.pdb")
                    pocket_name = f"{index}_protein"
                    s = parser.get_structure("x", protein_file)
                    res_list = list(s.get_residues())
                    protein_feat = get_protein_feature(res_list)

                    info = [index, compound_name, pocket_name]
                    data_to_store = {
                        "compound_feat": compound_feat,
                        "protein_feat": protein_feat,
                        "info": info
                    }
                    txn.put(index.encode('ascii'), pickle.dumps(data_to_store))
                    self.index_list.append(index)
                    # Commit in batches of 20 records.
                    if len(self.index_list) % 20 == 0:
                        txn.commit()
                        txn = env.begin(write=True)
                except Exception as err:
                    # Keep the underlying error attached so the log shows the real cause.
                    raise Exception("Read Protein Failed") from err
            except Exception as e:
                cause = f" (caused by {e.__cause__!r})" if e.__cause__ is not None else ""
                with open(log_path, "a") as log_file:
                    log_file.write(f"{index}:  {e.args}{cause} \n")
        txn.commit()
        env.close()
        torch.save(self.index_list, self.processed_paths[1])
        print('Finish writing lmdb.')

    @property
    def processed_file_names(self):
        """Names of the LMDB environment and the saved id list for this mode."""
        return [f'{self.mode}_data.lmdb', f'{self.mode}_index.pt']

    def len(self):
        """Number of complexes that were stored successfully."""
        return len(self.index_list)

    def get(self, idx):
        """Load record ``idx`` from LMDB and build its ``HeteroData`` graph."""
        index = self.index_list[idx]
        with self.lmdb_env.begin(write=False) as txn:
            # The record was pickled by process(); only open LMDB files you trust.
            data_to_load = pickle.loads(txn.get(index.encode('ascii')))

        protein_node_xyz, protein_seq, protein_node_s, protein_node_v, \
        protein_edge_index, protein_edge_s, protein_edge_v = data_to_load['protein_feat']

        coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list, \
        pair_dis_distribution, LAS_edge_index = data_to_load['compound_feat']

        # NOTE: no rdkit_coords is supplied here; the records written by
        # process() do not contain an RDKit conformer.
        data, input_node_list, keepNode = construct_data_from_graph_gvp(protein_node_xyz, protein_seq, protein_node_s,
                                                                        protein_node_v, protein_edge_index,
                                                                        protein_edge_s, protein_edge_v,
                                                                        coords, compound_node_features,
                                                                        input_atom_edge_list, input_atom_edge_attr_list,
                                                                        LAS_edge_index,
                                                                        pocket_radius=self.pocket_radius,
                                                                        use_whole_protein=False,
                                                                        use_compound_com_as_pocket=True,
                                                                        chosen_pocket_com=None,
                                                                        add_noise_to_com=self.add_noise_to_com,
                                                                        setting=self.setting,
                                                                        )

        data.idx = idx
        return data

    def _find_ligand_file(self, idx, extension):
        """Return the ligand file with ``extension`` for sample ``idx``, or ``None``.

        ``refine_folder`` is searched before ``other_folder``.
        """
        index = self.index_list[idx]
        for folder in (self.refine_folder, self.other_folder):
            ligand_file = os.path.join(folder, index, f"{index}_ligand.{extension}")
            if os.path.exists(ligand_file):
                return ligand_file
        return None

    def get_ligand_file_path_sdf(self, idx):
        """Path of the ligand SDF file for sample ``idx``, or ``None`` if absent."""
        return self._find_ligand_file(idx, "sdf")

    def get_ligand_file_path_mol2(self, idx):
        """Path of the ligand MOL2 file for sample ``idx``, or ``None`` if absent."""
        return self._find_ligand_file(idx, "mol2")

    def get_index(self, idx):
        """Return the complex id stored at position ``idx``."""
        return self.index_list[idx]


class TankBindDataSet_new(Dataset):
    """Pocket-level dataset built from precomputed protein/compound features.

    Each row of the ``info`` table describes one (protein, compound, pocket)
    sample. Protein and compound features are cached in two LMDB environments;
    RDKit conformer coordinates are read from the fourth processed file
    (``compound_rdkit_coords.pt``), which ``process`` does not create and which
    therefore has to exist in the processed directory already.
    """

    def __init__(self, root, pocket_info_path, protein_embed_folder, compound_embed_path, compound_folder=None,
                 pocket_radius=20, contactCutoff=8.0, shake_nodes=None, setting=None,
                 transform=None, pre_transform=None, pre_filter=None, native_pocket_threshold=0.9):
        self.pocket_info_path = pocket_info_path
        self.protein_embed_folder = protein_embed_folder
        self.compound_embed_path = compound_embed_path
        self.compound_folder = compound_folder
        self.add_noise_to_com = None   # TODO: not robust
        self.pocket_radius = pocket_radius
        self.contactCutoff = contactCutoff
        self.shake_nodes = shake_nodes
        self.setting = setting
        self.native_pocket_threshold = native_pocket_threshold

        # Compounds dropped from ``info`` after loading.  TODO: delete
        self.skip_pdb_list = ['4dcy', '5fjw', '5dhf', '4i60', '4xkc', '4jfv', '4jfw', '1t2v', '3zs1', '3p55', '3p3h',
                              '3w8o', '4jhq', '3p3j', '3p44', '3whw', '2jld']

        # Filled lazily: process() (run from the base constructor when processed
        # files are missing) needs them before the lines below are reached.
        self.protein_lmdb = None
        self.compound_lmdb = None
        self.compound_rdkit_coords = None

        super().__init__(root, transform, pre_transform, pre_filter)
        print(self.processed_paths)
        self._open_lmdb_readers()
        self.info = torch.load(self.processed_paths[0])
        self._load_rdkit_coords()
        self.info = self.info.query("compound_name not in @self.skip_pdb_list").reset_index(drop=True)

    def _open_lmdb_readers(self):
        """Open the protein/compound LMDB environments read-only, at most once.

        The same LMDB environment must not be opened twice within one process,
        so already-open handles are reused.
        """
        if getattr(self, "protein_lmdb", None) is None:
            self.protein_lmdb = lmdb.open(self.processed_paths[1], readonly=True, lock=False, readahead=False,
                                          meminit=False)
        if getattr(self, "compound_lmdb", None) is None:
            self.compound_lmdb = lmdb.open(self.processed_paths[2], readonly=True, lock=False, readahead=False,
                                           meminit=False)

    def _load_rdkit_coords(self):
        """Load the compound-name -> RDKit coordinates mapping if not loaded yet."""
        if getattr(self, "compound_rdkit_coords", None) is None:
            self.compound_rdkit_coords = torch.load(self.processed_paths[3])

    @property
    def processed_file_names(self):
        """Processed artefacts: info table, protein LMDB, compound LMDB, RDKit coords."""
        # TODO: compound_rdkit_coords.pt is not generated by process().
        return ['info.pt', 'protein_embed.pt', 'compound_embed.pt', 'compound_rdkit_coords.pt']

    def process(self):
        """Create whichever of the LMDB caches and the info table are missing."""
        if not os.path.exists(self.processed_paths[1]):
            # Proteins are stored in LMDB because they are too large to keep in memory.
            print("Process Protein")
            env = lmdb.open(self.processed_paths[1], map_size=int(1e11))
            txn = env.begin(write=True)
            for i, file_name in tqdm(enumerate(os.listdir(self.protein_embed_folder))):
                file_path = os.path.join(self.protein_embed_folder, file_name)
                # The file stem is both the LMDB key and the key inside the saved dict.
                index = file_name.split(".")[0]
                protein = torch.load(file_path)[index]
                txn.put(index.encode('ascii'), pickle.dumps(protein))
                if i % 20 == 0:
                    txn.commit()
                    txn = env.begin(write=True)
            txn.commit()
            env.close()

        if not os.path.exists(self.processed_paths[2]):
            print("Process Compound")
            env = lmdb.open(self.processed_paths[2], map_size=int(1e11))
            txn = env.begin(write=True)
            compound_dict = torch.load(self.compound_embed_path)
            for index, compound in tqdm(compound_dict.items()):
                txn.put(index.encode('ascii'), pickle.dumps(compound))
            txn.commit()
            env.close()

        self._open_lmdb_readers()
        if not os.path.exists(self.processed_paths[0]):
            print("Process Index")
            # get() reads the RDKit coordinates, so they must be available here.
            self._load_rdkit_coords()
            self.info = torch.load(self.pocket_info_path)
            stats = []
            # get() addresses rows by position, so iterate by position as well.
            for pos in tqdm(range(self.info.shape[0]), total=self.info.shape[0]):
                d = self.get(pos)
                p_length = d['protein'].coords.shape[0]
                c_length = d['compound'].true_coords.shape[0]
                y_length = d.y.shape[0]
                num_contact = (d.y > 0).sum().item()
                stats.append([p_length, c_length, y_length, num_contact])
            # Reuse info's own index so that the column-wise concat aligns row by row.
            t = pd.DataFrame(stats, columns=['p_length', 'c_length', 'y_length', 'num_contact'],
                             index=self.info.index)
            self.info = pd.concat([self.info, t], axis=1)
            # Contact count of the row flagged use_compound_com, per protein,
            # broadcast to every row of that protein.
            native_num_contact = self.info.query("use_compound_com").set_index("protein_name")['num_contact'].to_dict()
            self.info['native_num_contact'] = self.info.protein_name.map(native_num_contact)
            torch.save(self.info, self.processed_paths[0])

    def len(self):
        """Number of rows in the info table."""
        return len(self.info)

    def get_protein(self, index):
        """Return the unpickled protein feature record stored under ``index``."""
        with self.protein_lmdb.begin(write=False) as txn:
            protein = pickle.loads(txn.get(index.encode('ascii')))
        return protein

    def get_compound(self, index):
        """Return the unpickled compound feature record stored under ``index``."""
        with self.compound_lmdb.begin(write=False) as txn:
            compound = pickle.loads(txn.get(index.encode('ascii')))
        return compound

    def get(self, idx):
        """Build the ``HeteroData`` graph for row ``idx`` of the info table."""
        line = self.info.iloc[idx]
        pocket_com = line['pocket_com']
        use_compound_com = bool(line['use_compound_com'])
        use_whole_protein = line['use_whole_protein'] if "use_whole_protein" in line.index else False
        group = line['group'] if "group" in line.index else 'train'

        protein_name = line['protein_name']
        protein_node_xyz, protein_seq, protein_node_s, protein_node_v, protein_edge_index, \
        protein_edge_s, protein_edge_v = self.get_protein(protein_name)

        compound_name = line['compound_name']
        rdkit_coords = self.compound_rdkit_coords[compound_name]
        coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list, \
        pair_dis_distribution, LAS_edge_index = self.get_compound(compound_name)

        # Optional uniform jitter in [-shake_nodes, shake_nodes] on all coordinates.
        if self.shake_nodes is not None:
            protein_node_xyz = protein_node_xyz + self.shake_nodes * (2 * np.random.rand(*protein_node_xyz.shape) - 1)
            coords = coords + self.shake_nodes * (2 * np.random.rand(*coords.shape) - 1)

        data, input_node_list, keepNode = construct_data_from_graph_gvp(protein_node_xyz, protein_seq, protein_node_s,
                                                                        protein_node_v, protein_edge_index,
                                                                        protein_edge_s, protein_edge_v,
                                                                        coords, compound_node_features,
                                                                        input_atom_edge_list, input_atom_edge_attr_list,
                                                                        LAS_edge_index=LAS_edge_index,
                                                                        rdkit_coords=rdkit_coords,
                                                                        contactCutoff=self.contactCutoff,
                                                                        pocket_radius=self.pocket_radius,
                                                                        add_noise_to_com=self.add_noise_to_com,
                                                                        use_whole_protein=use_whole_protein,
                                                                        use_compound_com_as_pocket=use_compound_com,
                                                                        chosen_pocket_com=pocket_com,
                                                                        setting=self.setting,
                                                                        )

        affinity = float(line['affinity'])
        data.affinity = torch.tensor([affinity], dtype=torch.float)
        data.pdb = line['pdb'] if "pdb" in line.index else f'smiles_{idx}'
        data.idx = idx
        data.group = group

        # Labels count as "real" only for rows flagged use_compound_com.
        data.real_affinity_mask = torch.tensor([use_compound_com], dtype=torch.bool)
        data.real_y_mask = torch.ones(data.y.shape).bool() if use_compound_com else torch.zeros(data.y.shape).bool()

        # A pocket is treated as equivalent to the native one when it recovers
        # at least native_pocket_threshold of the native contact count.
        if "native_num_contact" in line.index:
            fract_of_native_contact = (data.y.numpy() > 0).sum() / float(line['native_num_contact'])
            is_equivalent_native_pocket = bool(fract_of_native_contact >= self.native_pocket_threshold)
            n_atoms = len(data['compound'].true_coords)
            data.is_equivalent_native_pocket = torch.tensor([is_equivalent_native_pocket], dtype=torch.bool)
            if is_equivalent_native_pocket:
                data.equivalent_native_y_mask = torch.ones(data.y.shape).bool()
                data.equivalent_coord_mask = torch.ones(n_atoms).bool()
            else:
                data.equivalent_native_y_mask = torch.zeros(data.y.shape).bool()
                data.equivalent_coord_mask = torch.zeros(n_atoms).bool()
        return data

    def _ligand_file_path(self, idx, extension):
        """Return ``<compound_folder>/<compound_name>.<extension>`` if it exists, else ``None``."""
        assert self.compound_folder is not None
        compound_name = self.info.iloc[idx]['compound_name']
        ligand_file = os.path.join(self.compound_folder, f"{compound_name}.{extension}")
        return ligand_file if os.path.exists(ligand_file) else None

    def get_ligand_file_path_sdf(self, idx):
        """Path of the compound's SDF file, or ``None`` if it does not exist."""
        return self._ligand_file_path(idx, "sdf")

    def get_ligand_file_path_mol2(self, idx):
        """Path of the compound's MOL2 file, or ``None`` if it does not exist."""
        return self._ligand_file_path(idx, "mol2")

    def get_index(self, idx):
        """Return the sample identifier, using the same fallback as ``get``."""
        line = self.info.iloc[idx]
        return line['pdb'] if "pdb" in line.index else f'smiles_{idx}'



def get_train_index(refine_folder, other_folder, test_index_file, valid_index_file, output_file):
    """Write the training ids: every entry of both folders minus held-out ids.

    Args:
        refine_folder: directory whose entry names are complex ids.
        other_folder: second directory of complex ids.
        test_index_file: text file with one held-out id per line.
        valid_index_file: text file with one held-out id per line.
        output_file: destination file; one training id is written per line,
            entries of ``refine_folder`` first, then those of ``other_folder``.

    An entry named ``readme`` is always excluded. Held-out ids that are absent
    from both folders, and blank lines in the index files, are ignored rather
    than raising. Every occurrence of a held-out id is dropped, including an id
    that appears in both folders.
    """
    excluded = {'readme'}
    for held_out_file in (test_index_file, valid_index_file):
        with open(held_out_file, "r") as f:
            excluded.update(line.strip() for line in f)

    all_index = os.listdir(refine_folder) + os.listdir(other_folder)
    with open(output_file, "w") as f:
        f.writelines(f"{index}\n" for index in all_index if index not in excluded)





class ComplexGraph:
    """Index bookkeeping for a batch of protein-compound complexes.

    Nodes of the merged "complex" graph are laid out sample by sample as
    (Compound1, Protein1, Compound2, Protein2, ...).
    """

    def __init__(self, protein_node_nums, compound_node_nums, LAS_edge_index,
                 protein_batch=None, compound_batch=None,
                 protein_ptr=None, compound_ptr=None):
        """Store per-sample node counts and derive batch vectors and offsets.

        Args:
            protein_node_nums: 1-D integer tensor, protein nodes per sample.
            compound_node_nums: 1-D integer tensor, compound nodes per sample.
            LAS_edge_index: compound edge index handed to ``to_dense_adj``.
            protein_batch, compound_batch: optional precomputed sample id per
                node; derived from the node counts when omitted.
            protein_ptr, compound_ptr: optional precomputed cumulative offsets
                (length ``B + 1``); derived from the node counts when omitted.
        """
        self.device = protein_node_nums.device
        self.B = len(protein_node_nums)
        self.protein_node_nums = protein_node_nums
        self.compound_node_nums = compound_node_nums
        self.protein_batch = protein_batch
        self.compound_batch = compound_batch
        self.protein_ptr = protein_ptr
        self.compound_ptr = compound_ptr
        if self.protein_batch is None:
            self.protein_batch = torch.arange(self.B, device=self.device).repeat_interleave(protein_node_nums)
        if self.compound_batch is None:
            self.compound_batch = torch.arange(self.B, device=self.device).repeat_interleave(compound_node_nums)
        if self.protein_ptr is None:
            self.protein_ptr = torch.cat([protein_node_nums.new_zeros(1), protein_node_nums]).cumsum(0)
        if self.compound_ptr is None:
            self.compound_ptr = torch.cat([compound_node_nums.new_zeros(1), compound_node_nums]).cumsum(0)

        self.complex_node_nums = protein_node_nums + compound_node_nums
        # Interleaved counts (c1, p1, c2, p2, ...) matching the complex layout.
        _graph_node_nums = torch.stack((compound_node_nums, protein_node_nums), dim=0).T.reshape(-1)
        # True where a complex node is a protein node.
        self.is_protein = torch.tensor([0, 1], dtype=torch.bool, device=self.device).repeat(self.B).repeat_interleave(
            _graph_node_nums)
        self.batch = torch.arange(self.B, device=self.device).repeat_interleave(self.complex_node_nums)

        self.LAS_edge_index = LAS_edge_index
        self.LAS_mask = to_dense_adj(LAS_edge_index, batch=self.compound_batch).to(torch.bool)

    def set_compound_edge(self, edge_list, edge_attr):
        """Attach the compound edge list and its attributes."""
        self.compound_edge_list = edge_list
        self.compound_edge_attr = edge_attr

    def get_node_embed(self, protein_node_embed, compound_node_embed, distinguish):
        """Scatter protein and compound embeddings into complex node order.

        With ``distinguish is True`` the embedding width is doubled: compound
        features occupy the first half and protein features the second half,
        the other half being zero. The result has the dtype of
        ``protein_node_embed``.
        """
        if distinguish is True:
            protein_node_embed = torch.cat([
                protein_node_embed.new_zeros(protein_node_embed.shape), protein_node_embed], dim=1)
            compound_node_embed = torch.cat([
                compound_node_embed, compound_node_embed.new_zeros(compound_node_embed.shape)], dim=1)

        embed_size = protein_node_embed.shape[1]
        # Allocate with the input dtype so half/double embeddings can be assigned.
        complex_node_embed = torch.empty(int(self.complex_node_nums.sum()), embed_size,
                                         dtype=protein_node_embed.dtype, device=self.device)
        complex_node_embed[self.is_protein] = protein_node_embed
        complex_node_embed[~self.is_protein] = compound_node_embed

        return complex_node_embed

    def get_node_coord(self, protein_node_coord, compound_node_coord):
        """Scatter protein and compound coordinates into complex node order.

        The result has the dtype of ``protein_node_coord``.
        """
        complex_node_coord = torch.empty(int(self.complex_node_nums.sum()), 3,
                                         dtype=protein_node_coord.dtype, device=self.device)
        complex_node_coord[self.is_protein] = protein_node_coord
        complex_node_coord[~self.is_protein] = compound_node_coord

        return complex_node_coord

    def _empty_index(self):
        """Empty integer tensor used as the first piece of every edge list."""
        return torch.empty(0, device=self.device, dtype=torch.int)

    def get_protein_compound_edge(self, use_complex_index=False):
        """Return all protein->compound pairs within each sample as a ``[2, E]`` tensor.

        Pairs are ordered (p1,c1), (p1,c2), ... per sample. With
        ``use_complex_index`` the indices refer to the merged complex node
        order; otherwise they index the protein and compound node lists
        separately. The result is computed once and cached on the instance.
        """
        if not hasattr(self, "p2c_edge_index"):
            src_parts, dst_parts = [self._empty_index()], [self._empty_index()]
            src_complex_parts, dst_complex_parts = [self._empty_index()], [self._empty_index()]
            for i in range(self.B):
                n_protein = int(self.protein_node_nums[i])
                n_compound = int(self.compound_node_nums[i])
                # Each protein node repeated for every compound node ...
                cur_protein_edge_src = torch.arange(n_protein, device=self.device).repeat_interleave(n_compound)
                # ... paired with the full compound range, tiled per protein node.
                cur_compound_edge_dst = torch.arange(n_compound, device=self.device).repeat(n_protein)

                src_parts.append(cur_protein_edge_src + self.protein_ptr[i])
                dst_parts.append(cur_compound_edge_dst + self.compound_ptr[i])
                # In complex order the proteins of sample i follow compounds 1..i
                # and proteins 1..i-1; its compounds follow all nodes of samples < i.
                src_complex_parts.append(cur_protein_edge_src + self.protein_ptr[i] + self.compound_ptr[i + 1])
                dst_complex_parts.append(cur_compound_edge_dst + self.compound_ptr[i] + self.protein_ptr[i])
            # Concatenate once instead of growing the tensors inside the loop.
            self.p2c_edge_index = {
                "complex_index": torch.stack([torch.cat(src_complex_parts), torch.cat(dst_complex_parts)], dim=0),
                "inner_index": torch.stack([torch.cat(src_parts), torch.cat(dst_parts)], dim=0),
            }

        if use_complex_index:
            return self.p2c_edge_index["complex_index"]
        return self.p2c_edge_index["inner_index"]

    def get_compound_compound_edge(self):
        """Return all ordered compound pairs (self-pairs included) within each sample.

        Pairs are ordered (c1,c1), (c1,c2), ... per sample and use compound
        node indices. The ``[2, E]`` result is computed once and cached.
        """
        if not hasattr(self, "c2c_edge_index"):
            src_parts, dst_parts = [self._empty_index()], [self._empty_index()]
            for i in range(self.B):
                n_compound = int(self.compound_node_nums[i])
                atom_ids = torch.arange(n_compound, device=self.device)
                src_parts.append(atom_ids.repeat_interleave(n_compound) + self.compound_ptr[i])
                dst_parts.append(atom_ids.repeat(n_compound) + self.compound_ptr[i])
            self.c2c_edge_index = torch.stack([torch.cat(src_parts), torch.cat(dst_parts)], dim=0)
        return self.c2c_edge_index
