import torch
import random
import numpy as np
import mdtraj as md
import os
from rdkit import Chem
from torch_geometric.data import Data, Dataset
import random
from tqdm import tqdm
from typing import Literal

import pickle
from rdkit.Chem.rdchem import HybridizationType, BondType
from .feature_utils import get_node_features
from .transforms import *
from .mistake_fixer import TrajFixer

from utils.data_filter import filter_data, get_smiles_dict
from typing import Literal, Any, Optional, Set

# import sys
# sys.path.append


########################################################
'''
UTILS
'''
########################################################


# Constants
# Maps every RDKit ``BondType`` (in ``BondType.names`` order) to an integer id
# starting at 1; ``mol_2d`` stores these ids as edge attributes. Id 0 is never
# assigned here (NOTE(review): presumably reserved for edges that are not
# chemical bonds, e.g. ones added by graph transforms -- confirm).
BOND_TYPES = {t: i + 1 for i, t in enumerate(BondType.names.values())}
# Total number of frames every stored trajectory is assumed to contain.
# TrajectoryDataset uses it to bound ``num_frames`` and the random window start.
NUM_FRAMES = 12500
# Trajectory directory names (a molecule name plus a trailing ``_<index>``)
# that TrajectoryDataset skips when its folder path contains "DRUGS"; the
# message printed there describes them as trajectories with over 60 atoms.
IGNORE = [
 'C_C_C_C_C_O_N_C_H_c1ccccc1_C_H_O_C_O_O_C_H_1C_C_2_O_C_H_OC_O_c3ccccc3_C_H_3_C_4_OC_C_O_CO_C_H_4C_C_H_O_C_3_C_C_O_C_H_O_C_C1C_C2_C_C_42',
 'C_C_C_C_C_O_N_C_H_c1ccccc1_C_H_O_C_O_O_C_H_1C_C_2_O_C_H_OC_O_c3ccccc3_C_H_3_C_4_OC_C_O_CO_C_H_4C_C_H_O_C_3_C_C_O_C_H_O_C_C1C_C2_C_C_41',
 'C_C_C_C_C_O_N_C_H_c1ccccc1_C_H_O_C_O_O_C_H_1C_C_2_O_C_H_OC_O_c3ccccc3_C_H_3_C_4_OC_C_O_CO_C_H_4C_C_H_O_C_3_C_C_O_C_H_O_C_C1C_C2_C_C_43',
 'C_C_C_C_C_O_N_C_H_c1ccccc1_C_H_O_C_O_O_C_H_1C_C_2_O_C_H_OC_O_c3ccccc3_C_H_3_C_4_OC_C_O_CO_C_H_4C_C_H_O_C_3_C_C_O_C_H_O_C_C1C_C2_C_C_44',
 'C_C_C_C_C_O_N_C_H_c1ccccc1_C_H_O_C_O_O_C_H_1C_C_2_O_C_H_OC_O_c3ccccc3_C_H_3_C_4_OC_C_O_CO_C_H_4C_C_H_O_C_3_C_C_O_C_H_O_C_C1C_C2_C_C_40',
 'C_C_1_C_O_O_CC_C_2_C_CC_C_3_C_C_CC_O_C_H_4_C_5_C_CC_C_H_O_C_H_6O_C_H_C_O_O-_C_H_O_C_H_O_C_H_6O_C_H_6O_C_H_C_O_O-_C_H_O_C_H_O_C_H_6O_C_C_C_C_H_5CC_C_43C_C_H_2C1_131',
 'C_C_1_C_O_O_CC_C_2_C_CC_C_3_C_C_CC_O_C_H_4_C_5_C_CC_C_H_O_C_H_6O_C_H_C_O_O-_C_H_O_C_H_O_C_H_6O_C_H_6O_C_H_C_O_O-_C_H_O_C_H_O_C_H_6O_C_C_C_C_H_5CC_C_43C_C_H_2C1_132',
 'C_C_1_C_O_O_CC_C_2_C_CC_C_3_C_C_CC_O_C_H_4_C_5_C_CC_C_H_O_C_H_6O_C_H_C_O_O-_C_H_O_C_H_O_C_H_6O_C_H_6O_C_H_C_O_O-_C_H_O_C_H_O_C_H_6O_C_C_C_C_H_5CC_C_43C_C_H_2C1_130',
 'C_C_1_C_O_O_CC_C_2_C_CC_C_3_C_C_CC_O_C_H_4_C_5_C_CC_C_H_O_C_H_6O_C_H_C_O_O-_C_H_O_C_H_O_C_H_6O_C_H_6O_C_H_C_O_O-_C_H_O_C_H_O_C_H_6O_C_C_C_C_H_5CC_C_43C_C_H_2C1_134',
 'C_C_1_C_O_O_CC_C_2_C_CC_C_3_C_C_CC_O_C_H_4_C_5_C_CC_C_H_O_C_H_6O_C_H_C_O_O-_C_H_O_C_H_O_C_H_6O_C_H_6O_C_H_C_O_O-_C_H_O_C_H_O_C_H_6O_C_C_C_C_H_5CC_C_43C_C_H_2C1_133',
 'CC_C_H_1OC_O_C_H_C_C_H_O_C_H_2C_C_C_OC_C_H_O_C_H_C_O2_C_H_C_C_H_O_C_H_2O_C_H_C_C_C_H_N_C_C_C_H_2O_C_C_O_C_C_H_C_C_H_2N_C_H_COCCOC_O_C_H_C_H_2C_C_1_C_O_48',
 'CC_C_H_1OC_O_C_H_C_C_H_O_C_H_2C_C_C_OC_C_H_O_C_H_C_O2_C_H_C_C_H_O_C_H_2O_C_H_C_C_C_H_N_C_C_C_H_2O_C_C_O_C_C_H_C_C_H_2N_C_H_COCCOC_O_C_H_C_H_2C_C_1_C_O_45',
 'CC_C_H_1OC_O_C_H_C_C_H_O_C_H_2C_C_C_OC_C_H_O_C_H_C_O2_C_H_C_C_H_O_C_H_2O_C_H_C_C_C_H_N_C_C_C_H_2O_C_C_O_C_C_H_C_C_H_2N_C_H_COCCOC_O_C_H_C_H_2C_C_1_C_O_46',
 'CC_C_H_1OC_O_C_H_C_C_H_O_C_H_2C_C_C_OC_C_H_O_C_H_C_O2_C_H_C_C_H_O_C_H_2O_C_H_C_C_C_H_N_C_C_C_H_2O_C_C_O_C_C_H_C_C_H_2N_C_H_COCCOC_O_C_H_C_H_2C_C_1_C_O_49',
 'CC_C_H_1OC_O_C_H_C_C_H_O_C_H_2C_C_C_OC_C_H_O_C_H_C_O2_C_H_C_C_H_O_C_H_2O_C_H_C_C_C_H_N_C_C_C_H_2O_C_C_O_C_C_H_C_C_H_2N_C_H_COCCOC_O_C_H_C_H_2C_C_1_C_O_47',
 'CC_C_1_O_C_C_H_2CN_CCc3c_nH_c4ccccc34_C_C_O_OC_c3cc4c_cc3OC_N_C_O_C_H_3_C_O_C_O_OC_C_H_OC_C_O_C_5_CC_C_CC_N_H_6CC_C_43_C_H_65_C2_C1_12',
 'CC_C_1_O_C_C_H_2CN_CCc3c_nH_c4ccccc34_C_C_O_OC_c3cc4c_cc3OC_N_C_O_C_H_3_C_O_C_O_OC_C_H_OC_C_O_C_5_CC_C_CC_N_H_6CC_C_43_C_H_65_C2_C1_11',
 'CC_C_1_O_C_C_H_2CN_CCc3c_nH_c4ccccc34_C_C_O_OC_c3cc4c_cc3OC_N_C_O_C_H_3_C_O_C_O_OC_C_H_OC_C_O_C_5_CC_C_CC_N_H_6CC_C_43_C_H_65_C2_C1_13',
 'CC_C_1_O_C_C_H_2CN_CCc3c_nH_c4ccccc34_C_C_O_OC_c3cc4c_cc3OC_N_C_O_C_H_3_C_O_C_O_OC_C_H_OC_C_O_C_5_CC_C_CC_N_H_6CC_C_43_C_H_65_C2_C1_14',
 'CC_C_1_O_C_C_H_2CN_CCc3c_nH_c4ccccc34_C_C_O_OC_c3cc4c_cc3OC_N_C_O_C_H_3_C_O_C_O_OC_C_H_OC_C_O_C_5_CC_C_CC_N_H_6CC_C_43_C_H_65_C2_C1_10',
 'CC_O_N_C_C_H_O_CN_C_C_O_c1c_I_c_C_O_NC_C_H_O_CO_c_I_c_C_O_NC_C_H_O_CO_c1I_c1c_I_c_C_O_NC_C_H_O_CO_c_I_c_C_O_NC_C_H_O_CO_c1I_142',
 'CC_O_N_C_C_H_O_CN_C_C_O_c1c_I_c_C_O_NC_C_H_O_CO_c_I_c_C_O_NC_C_H_O_CO_c1I_c1c_I_c_C_O_NC_C_H_O_CO_c_I_c_C_O_NC_C_H_O_CO_c1I_141',
 'CC_O_N_C_C_H_O_CN_C_C_O_c1c_I_c_C_O_NC_C_H_O_CO_c_I_c_C_O_NC_C_H_O_CO_c1I_c1c_I_c_C_O_NC_C_H_O_CO_c_I_c_C_O_NC_C_H_O_CO_c1I_143',
 'CC_O_N_C_C_H_O_CN_C_C_O_c1c_I_c_C_O_NC_C_H_O_CO_c_I_c_C_O_NC_C_H_O_CO_c1I_c1c_I_c_C_O_NC_C_H_O_CO_c_I_c_C_O_NC_C_H_O_CO_c1I_144',
 'CC_O_N_C_C_H_O_CN_C_C_O_c1c_I_c_C_O_NC_C_H_O_CO_c_I_c_C_O_NC_C_H_O_CO_c1I_c1c_I_c_C_O_NC_C_H_O_CO_c_I_c_C_O_NC_C_H_O_CO_c1I_140',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_C_H_C_CC_n3cc_C_H_N_CO_nn3_CC2_nc_N2CCN_C_O_C_H_C_H_C_CC_n3cc_C_H_NH3_CO_nn3_CC2_n1_83',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_C_H_C_CC_n3cc_C_H_N_CO_nn3_CC2_nc_N2CCN_C_O_C_H_C_H_C_CC_n3cc_C_H_NH3_CO_nn3_CC2_n1_84',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_C_H_C_CC_n3cc_C_H_N_CO_nn3_CC2_nc_N2CCN_C_O_C_H_C_H_C_CC_n3cc_C_H_NH3_CO_nn3_CC2_n1_80',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_C_H_C_CC_n3cc_C_H_N_CO_nn3_CC2_nc_N2CCN_C_O_C_H_C_H_C_CC_n3cc_C_H_NH3_CO_nn3_CC2_n1_82',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_C_H_C_CC_n3cc_C_H_N_CO_nn3_CC2_nc_N2CCN_C_O_C_H_C_H_C_CC_n3cc_C_H_NH3_CO_nn3_CC2_n1_81',
 'CC_C_H_1OC_O_C_C_H_O_C_H_C_C_H_O_C_H_2O_C_H_C_C_H_O_C_H_3C_C_C_O_C_H_O_C_H_C_O3_C_H_N_C_C_C_H_2O_C_H_CC_O_C_C_H_C_C_O_C_C_C_C_C_C_H_1CO_C_H_1O_C_H_C_C_H_O_C_H_OC_C_H_1OC_111',
 'CC_C_H_1OC_O_C_C_H_O_C_H_C_C_H_O_C_H_2O_C_H_C_C_H_O_C_H_3C_C_C_O_C_H_O_C_H_C_O3_C_H_N_C_C_C_H_2O_C_H_CC_O_C_C_H_C_C_O_C_C_C_C_C_C_H_1CO_C_H_1O_C_H_C_C_H_O_C_H_OC_C_H_1OC_112',
 'CC_C_H_1OC_O_C_C_H_O_C_H_C_C_H_O_C_H_2O_C_H_C_C_H_O_C_H_3C_C_C_O_C_H_O_C_H_C_O3_C_H_N_C_C_C_H_2O_C_H_CC_O_C_C_H_C_C_O_C_C_C_C_C_C_H_1CO_C_H_1O_C_H_C_C_H_O_C_H_OC_C_H_1OC_110',
 'CC_C_H_1OC_O_C_C_H_O_C_H_C_C_H_O_C_H_2O_C_H_C_C_H_O_C_H_3C_C_C_O_C_H_O_C_H_C_O3_C_H_N_C_C_C_H_2O_C_H_CC_O_C_C_H_C_C_O_C_C_C_C_C_C_H_1CO_C_H_1O_C_H_C_C_H_O_C_H_OC_C_H_1OC_113',
 'CC_C_H_1OC_O_C_C_H_O_C_H_C_C_H_O_C_H_2O_C_H_C_C_H_O_C_H_3C_C_C_O_C_H_O_C_H_C_O3_C_H_N_C_C_C_H_2O_C_H_CC_O_C_C_H_C_C_O_C_C_C_C_C_C_H_1CO_C_H_1O_C_H_C_C_H_O_C_H_OC_C_H_1OC_114',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_NH3_C_H_C_CC_nn3_CC2_nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_N_CO_nn3_CC2_n1_123',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_NH3_C_H_C_CC_nn3_CC2_nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_N_CO_nn3_CC2_n1_124',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_NH3_C_H_C_CC_nn3_CC2_nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_N_CO_nn3_CC2_n1_120',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_NH3_C_H_C_CC_nn3_CC2_nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_N_CO_nn3_CC2_n1_122',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_NH3_C_H_C_CC_nn3_CC2_nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_N_CO_nn3_CC2_n1_121',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_NH3_C_H_C_CC_nn3_CC2_nc_N2CCN_C_O_C_H_Cc3cc4ccccc4_nH_3_n3cc_C_H_N_CC_C_C_nn3_CC2_n1_16',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_NH3_C_H_C_CC_nn3_CC2_nc_N2CCN_C_O_C_H_Cc3cc4ccccc4_nH_3_n3cc_C_H_N_CC_C_C_nn3_CC2_n1_18',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_NH3_C_H_C_CC_nn3_CC2_nc_N2CCN_C_O_C_H_Cc3cc4ccccc4_nH_3_n3cc_C_H_N_CC_C_C_nn3_CC2_n1_15',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_NH3_C_H_C_CC_nn3_CC2_nc_N2CCN_C_O_C_H_Cc3cc4ccccc4_nH_3_n3cc_C_H_N_CC_C_C_nn3_CC2_n1_17',
 'C_CCOCCOCCOCCNc1nc_N2CCN_C_O_C_H_CCCCN_n3cc_C_H_NH3_C_H_C_CC_nn3_CC2_nc_N2CCN_C_O_C_H_Cc3cc4ccccc4_nH_3_n3cc_C_H_N_CC_C_C_nn3_CC2_n1_19',
 'C_C_H_1O_C_H_O_C_H_2_C_H_O_C_H_3CC_C_4_C_C_H_CC_C_5_C_C_H_4CC_C4_C_H_6CC_C_C_CC_C_6_C_O_O_C_H_6O_C_H_CO_C_H_7O_C_H_CO_C_H_O_C_H_8O_C_H_C_C_H_O_C_H_O_C_H_8O_C_H_O_C_H_7O_C_H_O_C_H_O_C_H_6O_CC_C_45C_C_3_C_CO_OC_C_H_O_C_H_2O_C_H_O_C_H_O_C_H_1O_34',
 'C_C_H_1O_C_H_O_C_H_2_C_H_O_C_H_3CC_C_4_C_C_H_CC_C_5_C_C_H_4CC_C4_C_H_6CC_C_C_CC_C_6_C_O_O_C_H_6O_C_H_CO_C_H_7O_C_H_CO_C_H_O_C_H_8O_C_H_C_C_H_O_C_H_O_C_H_8O_C_H_O_C_H_7O_C_H_O_C_H_O_C_H_6O_CC_C_45C_C_3_C_CO_OC_C_H_O_C_H_2O_C_H_O_C_H_O_C_H_1O_33',
 'C_C_H_1O_C_H_O_C_H_2_C_H_O_C_H_3CC_C_4_C_C_H_CC_C_5_C_C_H_4CC_C4_C_H_6CC_C_C_CC_C_6_C_O_O_C_H_6O_C_H_CO_C_H_7O_C_H_CO_C_H_O_C_H_8O_C_H_C_C_H_O_C_H_O_C_H_8O_C_H_O_C_H_7O_C_H_O_C_H_O_C_H_6O_CC_C_45C_C_3_C_CO_OC_C_H_O_C_H_2O_C_H_O_C_H_O_C_H_1O_30',
 'C_C_H_1O_C_H_O_C_H_2_C_H_O_C_H_3CC_C_4_C_C_H_CC_C_5_C_C_H_4CC_C4_C_H_6CC_C_C_CC_C_6_C_O_O_C_H_6O_C_H_CO_C_H_7O_C_H_CO_C_H_O_C_H_8O_C_H_C_C_H_O_C_H_O_C_H_8O_C_H_O_C_H_7O_C_H_O_C_H_O_C_H_6O_CC_C_45C_C_3_C_CO_OC_C_H_O_C_H_2O_C_H_O_C_H_O_C_H_1O_32',
 'C_C_H_1O_C_H_O_C_H_2_C_H_O_C_H_3CC_C_4_C_C_H_CC_C_5_C_C_H_4CC_C4_C_H_6CC_C_C_CC_C_6_C_O_O_C_H_6O_C_H_CO_C_H_7O_C_H_CO_C_H_O_C_H_8O_C_H_C_C_H_O_C_H_O_C_H_8O_C_H_O_C_H_7O_C_H_O_C_H_O_C_H_6O_CC_C_45C_C_3_C_CO_OC_C_H_O_C_H_2O_C_H_O_C_H_O_C_H_1O_31',
 'COC_O_C_1_Cc2ccc_OC_cc2_C_H_2c3cc_C_O_N4CCCC4_n_CCc4c_nH_c5ccc_O_cc45_c3C_C_H_2CN1C_O_c1ccccc1_28',
 'COC_O_C_1_Cc2ccc_OC_cc2_C_H_2c3cc_C_O_N4CCCC4_n_CCc4c_nH_c5ccc_O_cc45_c3C_C_H_2CN1C_O_c1ccccc1_25',
 'COC_O_C_1_Cc2ccc_OC_cc2_C_H_2c3cc_C_O_N4CCCC4_n_CCc4c_nH_c5ccc_O_cc45_c3C_C_H_2CN1C_O_c1ccccc1_26',
 'COC_O_C_1_Cc2ccc_OC_cc2_C_H_2c3cc_C_O_N4CCCC4_n_CCc4c_nH_c5ccc_O_cc45_c3C_C_H_2CN1C_O_c1ccccc1_29',
 'COC_O_C_1_Cc2ccc_OC_cc2_C_H_2c3cc_C_O_N4CCCC4_n_CCc4c_nH_c5ccc_O_cc45_c3C_C_H_2CN1C_O_c1ccccc1_27',
 'CCCCCC_O_N_C_H_c1ccccc1_C_H_O_C_O_O_C_H_1C_C_2_O_C_H_OC_O_c3ccccc3_C_H_3_C_4_OC_C_O_CO_C_H_4C_C_H_O_C_H_4OC_C_H_O_C_H_O_C_H_4O_C_3_C_C_O_C_H_O_C_C1C_C2_C_C_14',
 'CCCCCC_O_N_C_H_c1ccccc1_C_H_O_C_O_O_C_H_1C_C_2_O_C_H_OC_O_c3ccccc3_C_H_3_C_4_OC_C_O_CO_C_H_4C_C_H_O_C_H_4OC_C_H_O_C_H_O_C_H_4O_C_3_C_C_O_C_H_O_C_C1C_C2_C_C_13',
 'CCCCCC_O_N_C_H_c1ccccc1_C_H_O_C_O_O_C_H_1C_C_2_O_C_H_OC_O_c3ccccc3_C_H_3_C_4_OC_C_O_CO_C_H_4C_C_H_O_C_H_4OC_C_H_O_C_H_O_C_H_4O_C_3_C_C_O_C_H_O_C_C1C_C2_C_C_10',
 'CCCCCC_O_N_C_H_c1ccccc1_C_H_O_C_O_O_C_H_1C_C_2_O_C_H_OC_O_c3ccccc3_C_H_3_C_4_OC_C_O_CO_C_H_4C_C_H_O_C_H_4OC_C_H_O_C_H_O_C_H_4O_C_3_C_C_O_C_H_O_C_C1C_C2_C_C_12',
 'CCCCCC_O_N_C_H_c1ccccc1_C_H_O_C_O_O_C_H_1C_C_2_O_C_H_OC_O_c3ccccc3_C_H_3_C_4_OC_C_O_CO_C_H_4C_C_H_O_C_H_4OC_C_H_O_C_H_O_C_H_4O_C_3_C_C_O_C_H_O_C_C1C_C2_C_C_11',
 'CC_O_O_C_H_1C_C_H_O_C_H_2_C_H_O_C_C_H_O_C_H_3_C_H_O_C_C_H_O_C_H_4CC_C_5_C_C_H_CC_C_H_6_C_H_5C_C_H_O_C_5_C_C_H_C7_CC_O_OC7_CC_C_65O_C4_O_C_H_3C_O_C_H_2C_O_C_H_C_C_H_1O_C_H_1O_C_H_CO_C_H_O_C_H_O_C_H_1O_6',
 'CC_O_O_C_H_1C_C_H_O_C_H_2_C_H_O_C_C_H_O_C_H_3_C_H_O_C_C_H_O_C_H_4CC_C_5_C_C_H_CC_C_H_6_C_H_5C_C_H_O_C_5_C_C_H_C7_CC_O_OC7_CC_C_65O_C4_O_C_H_3C_O_C_H_2C_O_C_H_C_C_H_1O_C_H_1O_C_H_CO_C_H_O_C_H_O_C_H_1O_8',
 'CC_O_O_C_H_1C_C_H_O_C_H_2_C_H_O_C_C_H_O_C_H_3_C_H_O_C_C_H_O_C_H_4CC_C_5_C_C_H_CC_C_H_6_C_H_5C_C_H_O_C_5_C_C_H_C7_CC_O_OC7_CC_C_65O_C4_O_C_H_3C_O_C_H_2C_O_C_H_C_C_H_1O_C_H_1O_C_H_CO_C_H_O_C_H_O_C_H_1O_5',
 'CC_O_O_C_H_1C_C_H_O_C_H_2_C_H_O_C_C_H_O_C_H_3_C_H_O_C_C_H_O_C_H_4CC_C_5_C_C_H_CC_C_H_6_C_H_5C_C_H_O_C_5_C_C_H_C7_CC_O_OC7_CC_C_65O_C4_O_C_H_3C_O_C_H_2C_O_C_H_C_C_H_1O_C_H_1O_C_H_CO_C_H_O_C_H_O_C_H_1O_7',
 'CC_O_O_C_H_1C_C_H_O_C_H_2_C_H_O_C_C_H_O_C_H_3_C_H_O_C_C_H_O_C_H_4CC_C_5_C_C_H_CC_C_H_6_C_H_5C_C_H_O_C_5_C_C_H_C7_CC_O_OC7_CC_C_65O_C4_O_C_H_3C_O_C_H_2C_O_C_H_C_C_H_1O_C_H_1O_C_H_CO_C_H_O_C_H_O_C_H_1O_9'
]

# Common function for extracting the 2D features
def mol_2d(mol, features=None):
    """Extract the 2D graph description of an RDKit molecule.

    Args:
        mol: RDKit molecule.
        features: Optional collection of node-feature names forwarded to
            ``get_node_features``. When falsy it is returned unchanged.

    Returns:
        Tuple ``(z, edge_index, edge_type, features)`` where
        ``z`` is a long tensor ``[N_atoms]`` of atomic numbers,
        ``edge_index`` is a long tensor ``[2, 2 * M_bonds]`` listing every bond
        in both directions, ``edge_type`` is a long tensor ``[2 * M_bonds]``
        holding the ``BOND_TYPES`` id of each directed edge, and ``features``
        is the result of ``get_node_features`` (or the falsy input as given).
    """
    z = torch.tensor([atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=torch.long)

    sources, targets, bond_ids = [], [], []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        type_id = BOND_TYPES[bond.GetBondType()]
        # Every undirected bond is emitted as two directed edges: i -> j and j -> i.
        sources.extend((i, j))
        targets.extend((j, i))
        bond_ids.extend((type_id, type_id))

    edge_index = torch.tensor([sources, targets], dtype=torch.long)
    edge_type = torch.tensor(bond_ids, dtype=torch.long)

    # Node features [N_atoms, Feature_dim] are only computed when requested.
    node_features = get_node_features(mol, features) if features else features
    return z, edge_index, edge_type, node_features


def get_keep_atoms(mol, z):
    """Return the indices of ``mol``'s atoms that survive ``Chem.RemoveHs``.

    The returned indices are sorted, and are validated to line up one-to-one
    with the atom order of ``Chem.RemoveHs(mol)``: position ``k`` of the
    result is the index in ``mol`` of atom ``k`` of the hydrogen-stripped
    molecule. Callers use the result to slice per-atom data (e.g. trajectory
    coordinates) so it matches the hydrogen-stripped graph.

    Args:
        mol: RDKit molecule, possibly containing explicit hydrogens.
        z: Per-atom tensor for ``mol`` (callers pass the atomic numbers from
            ``mol_2d``); only its length is used, to bounds-check indices.

    Returns:
        Sorted list of atom indices to keep.

    Raises:
        ValueError: if the hydrogen-stripped molecule cannot be matched back
            onto ``mol``, a matched index is out of range for ``z``, or the
            match is not in ascending order (so sorted indices would no longer
            correspond to the hydrogen-stripped atom order).
    """
    mol_no = Chem.RemoveHs(mol)
    # Entry k of the match is the index in ``mol`` of atom k of ``mol_no``.
    important_atoms = list(mol.GetSubstructMatch(mol_no))
    if mol_no.GetNumAtoms() > 0 and not important_atoms:
        raise ValueError(
            "Could not map the hydrogen-stripped molecule back onto the original molecule."
        )

    num_atoms = z.shape[0]
    if any(idx < 0 or idx >= num_atoms for idx in important_atoms):
        raise ValueError(
            f"Matched atom indices {important_atoms} are out of range for {num_atoms} atoms."
        )

    keep_idxs = sorted(set(important_atoms))
    # Explicit check (not ``assert``) so it still runs under ``python -O``:
    # if the match were not ascending, position k of ``keep_idxs`` would no
    # longer be atom k of ``mol_no`` and sliced data would be misaligned.
    if important_atoms != keep_idxs:
        raise ValueError(
            "Heavy atoms of the hydrogen-stripped molecule are not in ascending "
            "order in the original molecule; sliced data would be misaligned."
        )
    return keep_idxs


########################################################
'''
TRAIN/TEST DATASETS
'''
########################################################


class ConformerDataset(Dataset):
    """Molecular graphs with a single conformer, loaded from a pickle file.

    Each item is a ``torch_geometric.data.Data`` whose ``pos`` holds the
    coordinates of conformer 0 with a trailing time axis of length one
    (``[N, 3, 1]``), i.e. the same layout as a one-frame trajectory.
    """

    def __init__(
        self, pkl_path,
        features=['aromatic', 'hybridization', 'partial_charge', 'num_bond_h', 'degree', 'formal_charge', 'ring_size'],
        transforms=["edge_order|2"],
        remove_hs=True,
        filter=True
    ):
        """Load the pickled records and store the featurisation options.

        Args:
            pkl_path: Path to a pickle with an indexable collection of records;
                each record must provide an RDKit molecule under ``'rdmol'``.
            features: Node-feature names forwarded to ``mol_2d``. When falsy,
                no ``x_features`` attribute is attached to the items.
            transforms: Transform specs, either ``"name"`` or
                ``"name|<int argument>"``, looked up in ``TRANSFORMS``.
            remove_hs: Strip hydrogens with ``Chem.RemoveHs`` before
                featurisation.
            filter: If true, replace the records with the first element
                returned by ``filter_data``.
        """
        super().__init__()
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(pkl_path, 'rb') as f:
            self.data = pickle.load(f)
        if filter:
            self.data = filter_data(self.data)[0]
        # The list defaults above are shared between instances; they are only read, never mutated.
        self.features = features
        self.transforms = transforms
        self.remove_hs = remove_hs

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _build_transforms(specs):
        """Compose the transforms described by ``specs``.

        A spec is ``"name"`` (constructed without arguments) or
        ``"name|<int>"`` (constructed with that integer), matching the format
        accepted by the trajectory datasets in this module.
        """
        applied = []
        for spec in specs:
            key_arg = spec.split('|')
            if len(key_arg) == 1:
                applied.append(TRANSFORMS[key_arg[0]]())
            else:
                applied.append(TRANSFORMS[key_arg[0]](int(key_arg[1])))
        return Compose(applied)

    def __getitem__(self, idx):
        """Return the featurised graph for record ``idx`` after applying the transforms."""
        mol = self.data[idx]['rdmol']

        if self.remove_hs:
            mol = Chem.RemoveHs(mol)

        # z: atomic numbers [N]; node_features: output of get_node_features
        # for ``self.features`` (or the falsy ``self.features`` itself).
        z, edge_index, edge_type, node_features = mol_2d(mol, self.features)

        # Conformer 0 coordinates with a trailing time axis of length 1: [N, 3, 1].
        coords = torch.tensor(mol.GetConformer(0).GetPositions(), dtype=torch.float32).unsqueeze(-1)

        data = Data(
            x=z,  # [N,]
            pos=coords,  # [N, 3, 1] (The 1 represents 1 time-step)
            edge_index=edge_index,  # [2, M]
            edge_attr=edge_type,  # [M,]
        )

        if self.features:
            data.x_features = node_features  # [N, F]

        return self._build_transforms(self.transforms)(data)


class TrajectoryDataset(Dataset):
    """Windows of MD trajectories served as time-indexed molecular graphs.

    ``folder_path`` contains sub-directories (one per generation/batch), each
    holding one directory per trajectory with ``system.pdb`` (topology),
    ``traj.xtc`` (coordinates) and ``mol.pkl`` (pickled RDKit molecule).
    Trajectory directories are grouped by their name with the trailing
    ``_<suffix>`` removed (one group per molecule); subsampling acts on groups.
    Each item is a window of ``num_frames`` frames taken every ``frame_rate``
    frames, with coordinates in Angstrom.
    """

    def __init__(
        self,
        folder_path: str,
        expected_time_dim: int,
        conditioning: Literal["none", "forward", "interpolation"],
        features: list[str] = ["aromatic", "hybridization", "partial_charge", "num_bond_h"],
        transforms: list[str] = ["edge_order|3"],
        remove_hs: bool = True,
        num_frames: int | None = None,
        frame_rate: int | None = None,
        start_frame: int | None = None,
        random_frame: bool = False,
        subsample: float | str | None = None
    ):
        """Index the trajectory directories and store the loading options.

        Args:
            folder_path: Root folder described in the class docstring. When it
                contains ``"DRUGS"`` the directories listed in ``IGNORE`` are skipped.
            expected_time_dim: Number of frames each item must contain after
                slicing with ``num_frames`` / ``frame_rate``.
            conditioning: ``"none"``, ``"forward"`` (frame 0 is conditioning)
                or ``"interpolation"`` (first and last frames are conditioning).
            features: Node-feature names forwarded to ``mol_2d``.
            transforms: Transform specs ``"name"`` or ``"name|<int>"``.
            remove_hs: Drop hydrogens from both the molecule and trajectory.
            num_frames: Window length in raw frames; ``None`` means "to the end".
            frame_rate: Stride inside the window (``None`` = every frame).
            start_frame: First frame of the window when ``random_frame`` is
                false; ``None`` means frame 0.
            random_frame: Draw the window start uniformly at random per item
                (requires ``num_frames``).
            subsample: A float fraction of molecules to keep (seeded, so the
                selection is reproducible), or a path to a pickled list of
                molecule keys to keep.
        """
        super().__init__()
        if "DRUGS" in folder_path:
            print("Ignoring traj with over 60 atoms")
            ignore = set(IGNORE)
        else:
            ignore = set()

        data_dict = self._collect_trajectory_dirs(folder_path, ignore)
        data_dict = self._subsample_molecules(data_dict, subsample)
        print(len(data_dict))

        self.data_dirs = [d for dirs in data_dict.values() for d in dirs]

        self.features = features
        self.transforms = transforms
        self.remove_hs = remove_hs
        self.num_frames = num_frames
        self.frame_rate = frame_rate
        self.start_frame = start_frame
        self.random_frame = random_frame
        self.expected_time_dim = expected_time_dim
        self.conditioning = conditioning

        # Sanity check
        if self.num_frames is not None:
            assert self.num_frames <= NUM_FRAMES, (
                f"num_frames ({self.num_frames}) must be <= total frames ({NUM_FRAMES})."
            )

    @staticmethod
    def _collect_trajectory_dirs(folder_path, ignore):
        """Group trajectory directories (those containing ``system.pdb``) by molecule key.

        Directories are visited in sorted name order so that the resulting
        key order -- and hence the seeded float subsample -- does not depend
        on the filesystem's listing order.
        """
        data_dict = {}
        for gen in sorted(os.listdir(folder_path)):
            gen_dir = os.path.join(folder_path, gen)
            if not os.path.isdir(gen_dir):
                continue
            for smile in sorted(os.listdir(gen_dir)):
                if smile in ignore:
                    continue
                smile_dir = os.path.join(gen_dir, smile)
                if not (os.path.isdir(smile_dir) and os.path.exists(os.path.join(smile_dir, 'system.pdb'))):
                    continue
                # Strip the trailing "_<suffix>" so replicas of a molecule share one key.
                smile_key = smile.rsplit('_', 1)[0]
                data_dict.setdefault(smile_key, []).append(smile_dir)
        return data_dict

    @staticmethod
    def _subsample_molecules(data_dict, subsample):
        """Return the subset of ``data_dict`` selected by ``subsample`` (see ``__init__``)."""
        if isinstance(subsample, float):
            print("Subsampling the molecules based on random index")
            keys = list(data_dict.keys())
            total = len(keys)
            rng = random.Random(0)
            chosen_idxs = rng.sample(range(total), k=int(total * subsample))
            new_dict = {keys[i]: data_dict[keys[i]] for i in chosen_idxs}
        elif isinstance(subsample, str):
            print("Subsampling the molecules based on given path")
            with open(subsample, 'rb') as f:
                new_keys = pickle.load(f)
            new_dict = {k: data_dict[k] for k in new_keys}
        else:
            return data_dict
        print(f"New number of molecules {len(new_dict)} vs the old {len(data_dict)}")
        return new_dict

    @staticmethod
    def _build_transforms(specs):
        """Compose the transforms described by ``specs`` (``"name"`` or ``"name|<int>"``)."""
        applied = []
        for spec in specs:
            key_arg = spec.split('|')
            if len(key_arg) == 1:
                applied.append(TRANSFORMS[key_arg[0]]())
            else:
                applied.append(TRANSFORMS[key_arg[0]](int(key_arg[1])))
        return Compose(applied)

    def _frame_window(self):
        """Return ``(start, end)`` raw-frame bounds of the window to load (``end`` may be ``None``)."""
        if self.random_frame:
            if self.num_frames is None:
                raise ValueError("random_frame=True requires num_frames to be set.")
            start = np.random.randint(0, NUM_FRAMES - self.num_frames + 1)
        else:
            start = 0 if self.start_frame is None else self.start_frame
        end = None if self.num_frames is None else start + self.num_frames
        return start, end

    def __len__(self) -> int:
        return len(self.data_dirs)

    def __getitem__(self, idx):
        """Load trajectory ``idx`` and return its frame window as a transformed ``Data`` graph."""
        data_dir = self.data_dirs[idx]
        xtc_file = os.path.join(data_dir, 'traj.xtc')
        pdb_file = os.path.join(data_dir, 'system.pdb')
        mol_pkl = os.path.join(data_dir, 'mol.pkl')

        traj = md.load(xtc_file, top=pdb_file)

        # The full RDKit molecule for this trajectory (closed properly this time).
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted data.
        with open(mol_pkl, 'rb') as f:
            mol = pickle.load(f)

        if self.remove_hs:
            keep_idxs = get_keep_atoms(mol, mol_2d(mol)[0])
            traj = traj.atom_slice(keep_idxs)
            mol = Chem.RemoveHs(mol)

        # Center each frame in memory; the trajectories were already prealigned.
        traj.center_coordinates()

        # MDTraj stores nm; convert to Angstrom.
        traj.xyz *= 10

        start, end = self._frame_window()
        traj = traj[start:end:self.frame_rate]

        # traj.xyz is (T, N, 3); reorder to [N, 3, T].
        coords = torch.tensor(traj.xyz, dtype=torch.float32).permute(1, 2, 0)

        z, edge_index, edge_type, features = mol_2d(mol, self.features)

        # True marks conditioning frames: frame 0 for any conditioning, plus
        # the last frame for interpolation.
        conditioning = torch.zeros(self.expected_time_dim, dtype=torch.bool)
        if self.conditioning != 'none':
            conditioning[0] = True
        if self.conditioning == 'interpolation':
            conditioning[-1] = True

        num_steps = coords.shape[-1]
        if num_steps != self.expected_time_dim:
            raise ValueError(
                f"Trajectory window from {data_dir} has {num_steps} frames, "
                f"expected {self.expected_time_dim}."
            )
        if z.shape[0] != coords.shape[0]:
            raise ValueError(
                f"Molecule has {z.shape[0]} atoms but trajectory {data_dir} has {coords.shape[0]}."
            )
        denoise_coords = coords[:, :, ~conditioning]

        data = Data(
            x=z,  # [N,]
            pos=denoise_coords,  # [N, 3, T - C]
            edge_index=edge_index,  # [2, M]
            edge_attr=edge_type,  # [M,]
            x_features=features,  # [N, F]
            original_frames=coords  # [N, 3, T]
        )

        return self._build_transforms(self.transforms)(data)
    

# Get the dataset
def get_datasets(config):
    """Build the ``(train_dataset, val_dataset)`` pair described by ``config``.

    ``config.dataset.type`` selects the dataset class: ``'conformer'`` builds
    two ``ConformerDataset`` objects, ``'trajectory'`` builds two
    ``TrajectoryDataset`` objects (only the training one receives
    ``config.dataset.subsample``). Any other type raises ``ValueError``.
    """
    ds = config.dataset

    if ds.type == 'conformer':
        print("Loading conformer training dataset")
        train_dataset = ConformerDataset(
            pkl_path=ds.train_conf_path,
            features=ds.features,
            transforms=ds.transforms,
            filter=ds.filter_data,
            remove_hs=ds.remove_hs,
        )
        print("Loading conformer validation dataset")
        val_dataset = ConformerDataset(
            pkl_path=ds.val_conf_path,
            features=ds.features,
            transforms=ds.transforms,
            filter=ds.filter_data,
            remove_hs=ds.remove_hs,
        )
        return train_dataset, val_dataset

    if ds.type == 'trajectory':
        def shared_kwargs():
            # Options common to the training and validation trajectory datasets.
            return dict(
                expected_time_dim=ds.expected_time_dim,
                features=ds.features,
                num_frames=ds.num_frames,
                transforms=ds.transforms,
                frame_rate=ds.frame_rate,
                start_frame=ds.start_frame,
                random_frame=ds.random_frame,
                remove_hs=ds.remove_hs,
                conditioning=config.denoiser.conditioning,
            )

        print("Loading trajectory training dataset")
        train_dataset = TrajectoryDataset(
            folder_path=ds.train_traj_dir,
            subsample=ds.subsample,
            **shared_kwargs(),
        )
        print("Loading trajectory validation dataset")
        val_dataset = TrajectoryDataset(
            folder_path=ds.val_traj_dir,
            **shared_kwargs(),
        )
        return train_dataset, val_dataset

    raise ValueError(f"Invalid dataset type: {config.dataset.type}")


########################################################
'''
TEST DATASETS
'''
########################################################


class ConformerDatasetTest(Dataset):
    """Test-time conformer dataset that repeats each molecule for sampling.

    Every molecule (SMILES) selected from the pickle is emitted
    ``int(count * ratio)`` times, where ``count`` comes from
    ``get_smiles_dict``. Each item carries ``smiles``, the RDKit molecule and
    a per-molecule index ``conf_idx`` so generated samples can be grouped.
    """

    def __init__(
        self, pkl_path,
        ratio=2,
        subsample=None,
        features=["aromatic", "hybridization", "partial_charge", "num_bond_h"],
        transforms=["edge_order|2"],
        remove_hs=True,
        filter=False
    ):
        """Load records, select molecules and build the repeated index.

        Args:
            pkl_path: Path to a pickle of records accepted by ``get_smiles_dict``.
            ratio: Repetition factor per reference count (int or float;
                the number of repetitions is truncated to an int).
            subsample: Optional iterable of SMILES keys to keep.
            features: Node-feature names forwarded to ``mol_2d``; falsy
                disables ``x_features``.
            transforms: Transform specs ``"name"`` or ``"name|<int>"``.
            remove_hs: Strip hydrogens with ``Chem.RemoveHs``.
            filter: If true, replace the records with the first element
                returned by ``filter_data``.
        """
        super().__init__()
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(pkl_path, 'rb') as f:
            records = pickle.load(f)
        if filter:
            records = filter_data(records)[0]

        # Mapping smiles -> (record, count), optionally restricted to ``subsample``.
        all_smiles_dict = get_smiles_dict(records)
        if subsample is not None:
            smiles_keys = {k: all_smiles_dict[k] for k in subsample}
        else:
            smiles_keys = all_smiles_dict

        # One entry per sample to generate: count * ratio repetitions per molecule.
        self.data = []
        for smiles, (item, count) in smiles_keys.items():
            for i in range(int(count * ratio)):
                self.data.append((smiles, item, i))

        self.features = features
        self.transforms = transforms
        self.remove_hs = remove_hs

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _build_transforms(specs):
        """Compose the transforms described by ``specs`` (``"name"`` or ``"name|<int>"``)."""
        applied = []
        for spec in specs:
            key_arg = spec.split('|')
            if len(key_arg) == 1:
                applied.append(TRANSFORMS[key_arg[0]]())
            else:
                applied.append(TRANSFORMS[key_arg[0]](int(key_arg[1])))
        return Compose(applied)

    def __getitem__(self, idx):
        """Return the featurised graph for entry ``idx`` with its identifying metadata."""
        smiles, mol_data, i = self.data[idx]
        mol = mol_data['rdmol']

        if self.remove_hs:
            mol = Chem.RemoveHs(mol)

        # z: atomic numbers [N]; node_features: output of get_node_features
        # for ``self.features`` (or the falsy ``self.features`` itself).
        z, edge_index, edge_type, node_features = mol_2d(mol, self.features)

        # Conformer 0 coordinates with a trailing time axis of length 1: [N, 3, 1].
        coords = torch.tensor(mol.GetConformer(0).GetPositions(), dtype=torch.float32).unsqueeze(-1)

        data = Data(
            x=z,  # [N,]
            pos=coords,  # [N, 3, 1] (The 1 represents 1 time-step)
            edge_index=edge_index,  # [2, M]
            edge_attr=edge_type,  # [M,]
            smiles=smiles,
            rdmol=mol,
            conf_idx=i,
        )

        if self.features:
            data.x_features = node_features  # [N, F]

        return self._build_transforms(self.transforms)(data)



class TrajectoryDatasetTestInterpolation(Dataset):
    """Interpolation test set built from stored (start, end) frame pairs.

    For each selected molecule, the pickle at ``pkl_path`` provides an RDKit
    molecule plus matching ``start_frames`` / ``end_frames``. Each item places
    one pair at the first and last time steps of a
    ``[N, 3, expected_time_dim]`` tensor; the zero-filled middle frames are
    the ones to denoise.
    """

    def __init__(
        self,
        folder_path: str,
        pkl_path: str,
        expected_time_dim: int,
        subsample = None,
        features: list[str] = ["aromatic", "hybridization", "partial_charge", "num_bond_h"],
        transforms: list[str] = ["edge_order|2"],
        remove_hs: bool = True,
        ratio = 1,
    ):
        """Index reference trajectories and build the (molecule, frame pair) list.

        Args:
            folder_path: Folder of batch directories, each containing molecule
                directories with ``system.pdb``, ``traj.xtc`` and ``smiles.txt``.
            pkl_path: Pickle mapping SMILES -> dict with ``'rdmol'``,
                ``'start_frames'`` and ``'end_frames'``.
            expected_time_dim: Total number of time steps per item.
            subsample: Float fraction of molecules to keep (seeded), or a path
                to a pickled collection of SMILES to keep (seeded shuffle).
            features: Node-feature names forwarded to ``mol_2d``.
            transforms: Transform specs ``"name"`` or ``"name|<int>"``.
            remove_hs: Drop hydrogens from both molecule and trajectory.
            ratio: Items per molecule as a multiple of its number of frame
                pairs; pairs are reused cyclically when ``ratio > 1``.

        Raises:
            KeyError: if a selected SMILES has no trajectory under ``folder_path``.
        """
        super().__init__()

        self.mols = []
        self.features = features
        self.transforms = transforms
        self.remove_hs = remove_hs
        self.expected_time_dim = expected_time_dim

        xtc_data = self._index_trajectories(folder_path)

        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(pkl_path, 'rb') as f:
            traj_data = pickle.load(f)

        # Keep only molecules whose entry has all three fields (rdmol + frame pairs).
        keys = [smiles for smiles in traj_data if len(traj_data[smiles]) == 3]
        print("Number of Molecules with Frame Data: ", len(keys))
        keys = self._subsample_keys(keys, subsample)

        traj_data = {k: traj_data[k] for k in keys}
        print("Number of molecules after we have sampled: ", len(traj_data))

        # Build the final index with repetition ratio.
        self.data = []
        for smiles, entry in traj_data.items():
            mol, start_frames, end_frames = entry['rdmol'], entry['start_frames'], entry['end_frames']
            if smiles not in xtc_data:
                raise KeyError(f"No trajectory with smiles.txt == {smiles!r} found under {folder_path}")
            traj_things = xtc_data[smiles]
            num_pairs = start_frames.shape[0]
            assert end_frames.shape[0] == num_pairs
            # Cycle through the pairs so ratio > 1 no longer indexes past the end.
            for i in range(int(num_pairs * ratio)):
                pair = i % num_pairs
                self.data.append((smiles, mol, start_frames[pair], end_frames[pair], traj_things, i))

    @staticmethod
    def _index_trajectories(folder_path):
        """Map each SMILES (from ``smiles.txt``) to one ``(traj.xtc, system.pdb)`` path pair.

        Entries are visited in name order so the directory chosen for a SMILES
        with several replicas is deterministic; the first one found wins.
        """
        xtc_data = {}
        with os.scandir(folder_path) as it:
            batch_entries = sorted(it, key=lambda e: e.name)
        for batch_entry in tqdm(batch_entries, desc='Processing batches'):
            if not batch_entry.is_dir():
                continue
            with os.scandir(batch_entry.path) as it:
                mol_entries = sorted(it, key=lambda e: e.name)
            for mol_entry in tqdm(mol_entries, desc=f'Processing molecules in {batch_entry.name}', leave=False):
                if not mol_entry.is_dir():
                    continue
                base_path = mol_entry.path
                pdb_path = os.path.join(base_path, 'system.pdb')
                if not os.path.isfile(pdb_path):
                    continue

                # Read SMILES first (cheap)
                smiles_file = os.path.join(base_path, 'smiles.txt')
                try:
                    with open(smiles_file, 'r') as f:
                        smiles = f.readline().strip()
                except IOError:
                    continue

                if smiles not in xtc_data:
                    xtc_data[smiles] = (os.path.join(base_path, 'traj.xtc'), pdb_path)
        return xtc_data

    @staticmethod
    def _subsample_keys(keys, subsample):
        """Select molecule keys according to ``subsample`` (see ``__init__``)."""
        rng = random.Random(0)
        if isinstance(subsample, float):
            total = len(keys)
            chosen_idxs = rng.sample(range(total), k=int(total * subsample))
            return [keys[i] for i in chosen_idxs]
        if isinstance(subsample, str):
            with open(subsample, 'rb') as f:
                new_keys = set(pickle.load(f))
            selected = set(keys) & new_keys
            print(selected)
            selected = sorted(selected)
            # Seeded shuffle of all selected keys.
            chosen_idxs = rng.sample(range(len(selected)), k=len(selected))
            return [selected[i] for i in chosen_idxs]
        return keys

    @staticmethod
    def _build_transforms(specs):
        """Compose the transforms described by ``specs`` (``"name"`` or ``"name|<int>"``)."""
        applied = []
        for spec in specs:
            key_arg = spec.split('|')
            if len(key_arg) == 1:
                applied.append(TRANSFORMS[key_arg[0]]())
            else:
                applied.append(TRANSFORMS[key_arg[0]](int(key_arg[1])))
        return Compose(applied)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return entry ``idx`` as a graph whose first/last frames are the stored pair."""
        smiles, mol, start_frame, end_frame, (xtc_file, pdb_path), i = self.data[idx]

        # Load the reference trajectory, overwrite frames 1 and 2 with the
        # stored pair (divided by 10, i.e. Angstrom -> nm) and superpose every
        # frame onto frame 0.
        # NOTE(review): assumes start_frame/end_frame are torch tensors in
        # Angstrom with the atom count/order of system.pdb -- confirm upstream.
        traj = md.load(xtc_file, top=pdb_path)
        traj.xyz[1, ...] = start_frame.numpy() / 10  # Inject as nm
        traj.xyz[2, ...] = end_frame.numpy() / 10  # Inject as nm
        traj.superpose(traj, frame=0)

        if self.remove_hs:
            keep_idxs = get_keep_atoms(mol, mol_2d(mol)[0])
            traj = traj.atom_slice(keep_idxs)
            mol = Chem.RemoveHs(mol)

        # Center each frame, then read the pair back in Angstrom.
        traj.center_coordinates()
        start_frame = torch.tensor(traj.xyz[1, ...] * 10)
        end_frame = torch.tensor(traj.xyz[2, ...] * 10)

        # Coordinates tensor [N, 3, T]: only the first and last steps are filled.
        N = start_frame.shape[0]
        coords = torch.zeros((N, 3, self.expected_time_dim))
        coords[..., 0] = start_frame
        coords[..., -1] = end_frame

        z, edge_index, edge_type, features = mol_2d(mol, self.features)

        # True marks conditioning frames: the first and the last.
        conditioning = torch.zeros(self.expected_time_dim, dtype=torch.bool)
        conditioning[0] = True
        conditioning[-1] = True
        denoise_coords = coords[:, :, ~conditioning]

        # Make sure sizes make sense
        assert z.shape[0] == coords.shape[0]
        assert denoise_coords.shape[-1] == self.expected_time_dim - torch.sum(conditioning)

        data = Data(
            x=z,  # [N,]
            pos=denoise_coords,  # [N, 3, T - C]
            edge_index=edge_index,  # [2, M]
            edge_attr=edge_type,  # [M,]
            x_features=features,  # [N, F]
            original_frames=coords,  # [N, 3, T]
            smiles=smiles,
            rdmol=mol,
            conf_idx=i,
        )

        return self._build_transforms(self.transforms)(data)


class TrajectoryDatasetTestUncond(Dataset):
    """Test-time trajectory dataset for unconditional generation.

    No frame is used as conditioning: every frame of the selected window is a
    denoising target. One set of files (mol.pkl / traj.xtc / system.pdb) is kept
    per SMILES string, taken from the first directory encountered for it, and
    that entry is repeated ``int(count * ratio)`` times, where ``count`` is the
    number of directories that share the SMILES.
    """

    def __init__(
        self, 
        folder_path: str,
        expected_time_dim: int,
        features: list[str] = ["aromatic", "hybridization", "partial_charge", "num_bond_h"],
        transforms: list[str] = ["edge_order|2"],
        remove_hs: bool = True,
        num_frames: int | None = None,
        frame_rate: int | None = None,
        start_frame: int | None = None,
        random_frame: bool = False,
        ratio=2,
        subsample = None,
    ):
        """Index the trajectories found under ``folder_path``.

        Args:
            folder_path: Root folder; expected layout is
                ``<folder_path>/<batch>/<molecule>/{system.pdb, smiles.txt, mol.pkl, traj.xtc}``.
            expected_time_dim: Number of frames each sample must contain after
                frame subsampling.
            features: Node feature names forwarded to ``mol_2d``.
            transforms: Transform specs, ``"key"`` or ``"key|int_arg"``, looked
                up in ``TRANSFORMS``.
            remove_hs: If True, hydrogens are removed from molecule and trajectory.
            num_frames: Length (in raw frames) of the window cut from each trajectory.
            frame_rate: Slice step applied inside the window.
            start_frame: Fixed window start; only used when ``random_frame`` is False.
            random_frame: Draw a random window start per item. A random start is
                also used when ``start_frame`` is None, since there is nothing
                else to start from.
            ratio: Repetition factor applied to the per-SMILES directory count.
            subsample: None keeps every SMILES; a ``str`` is a path to a pickled
                iterable of SMILES keys; any other value is the fraction of
                SMILES drawn with a fixed seed (0).
        """
        super().__init__()

        # Collect trajectory directories
        self.data = []
        self.mols = []
        self.features = features
        self.transforms = transforms
        self.num_frames = num_frames
        self.start_frame = start_frame
        self.remove_hs = remove_hs
        self.frame_rate = frame_rate
        # This used to be hard-coded to True, silently ignoring ``random_frame``
        # and ``start_frame``. Falling back to random when no start is given
        # keeps the previous behavior for callers that pass neither.
        self.random_frame = random_frame or start_frame is None
        self.expected_time_dim = expected_time_dim

        # SMILES -> [mol_file, xtc_file, pdb_path, count]; the paths come from
        # the first directory seen for that SMILES.
        # NOTE(review): os.scandir yields entries in arbitrary order, so which
        # directory supplies the paths (and the seeded subsample) may differ
        # between filesystems.
        traj_data: dict[str, list] = {}

        # Iterate directories lazily with scandir
        for batch_entry in tqdm(os.scandir(folder_path), desc='Processing batches'):
            if not batch_entry.is_dir():
                continue
            for mol_entry in tqdm(os.scandir(batch_entry.path), desc=f'Processing molecules in {batch_entry.name}', leave=False):
                if not mol_entry.is_dir():
                    continue
                base_path = mol_entry.path
                pdb_path = os.path.join(base_path, 'system.pdb')
                if not os.path.isfile(pdb_path):
                    continue

                # Read SMILES first (cheap); skip directories without one
                smiles_file = os.path.join(base_path, 'smiles.txt')
                try:
                    with open(smiles_file, 'r') as f:
                        smiles = f.readline().strip()
                except IOError:
                    continue

                mol_file = os.path.join(base_path, 'mol.pkl')
                xtc_file = os.path.join(base_path, 'traj.xtc')

                # Track count per SMILES
                if smiles in traj_data:
                    traj_data[smiles][3] += 1
                else:
                    traj_data[smiles] = [mol_file, xtc_file, pdb_path, 1]

        traj_data = self._select_smiles(traj_data, subsample)

        # Build the final index with repetition ratio
        self.data: list[tuple[str, str, str, str, int]] = []
        for smiles, (mol_file, xtc_file, pdb_path, count) in traj_data.items():
            for rep_idx in range(int(count * ratio)):
                self.data.append((smiles, mol_file, xtc_file, pdb_path, rep_idx))

    @staticmethod
    def _select_smiles(traj_data: dict, subsample) -> dict:
        """Return the part of ``traj_data`` selected by ``subsample``.

        None keeps everything; a ``str`` is a path to a pickled iterable of
        SMILES keys; any other value is a fraction of keys sampled with seed 0.
        """
        if subsample is None:
            new_dict = traj_data
        elif isinstance(subsample, str):
            print("Subsampling the molecules using provided path")
            with open(subsample, 'rb') as f:
                new_keys = pickle.load(f)
            new_dict = {k: traj_data[k] for k in new_keys}
        else:
            print("Subsampling the molecules")
            keys = list(traj_data.keys())
            total = len(keys)
            number = int(total * subsample)
            rng = random.Random(0)
            chosen_idxs = rng.sample(range(total), k=number)
            new_dict = {keys[i]: traj_data[keys[i]] for i in chosen_idxs}
        print(f"New number of molecules {len(new_dict)} vs the old {len(traj_data)}")
        return new_dict

    def _build_transforms(self):
        """Instantiate ``self.transforms`` specs ("key" or "key|int_arg") as one Compose."""
        applied_transforms = []
        for transform in self.transforms:
            key_arg = transform.split('|')
            if len(key_arg) == 1:
                applied_transforms.append(TRANSFORMS[key_arg[0]]())
            else:
                applied_transforms.append(TRANSFORMS[key_arg[0]](int(key_arg[1])))
        return Compose(applied_transforms)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load one trajectory window and return it as a transformed ``Data`` graph."""
        # Get paths
        smiles, mol_file, xtc_file, pdb_path, i = self.data[idx]
        # Close the pickle file promptly. NOTE: pickle.load can execute arbitrary
        # code, so only point this dataset at trusted files.
        with open(mol_file, 'rb') as f:
            mol = pickle.load(f)
        traj = md.load(xtc_file, top=pdb_path)

        # Remove Hs if needed
        if self.remove_hs:
            z = mol_2d(mol)[0]
            keep_idxs = get_keep_atoms(mol, z)
            traj = traj.atom_slice(keep_idxs)
            mol = Chem.RemoveHs(mol)

        # Center the trajectory (in memory); the trajectories were already prealigned
        traj.center_coordinates()

        # Cut a window of num_frames frames and keep every frame_rate-th frame
        if self.random_frame:
            max_start = NUM_FRAMES - self.num_frames
            start = np.random.randint(0, max_start + 1)
        else:
            start = self.start_frame
        end = start + self.num_frames
        traj = traj[start:end:self.frame_rate]

        # Convert the traj coordinates to Angstroms because MDTraj uses Nm
        traj.xyz *= 10

        # traj.xyz shape is (T, N, 3) -> [N, 3, T]
        coords = torch.tensor(traj.xyz, dtype=torch.float32).permute(1, 2, 0)

        # Handle graph featurization
        z, edge_index, edge_type, features = mol_2d(mol, self.features)

        # No conditioning frames: the mask is all False, so every frame is denoised
        conditioning = torch.zeros(self.expected_time_dim, dtype=torch.bool)
        denoise_coords = coords[:, :, ~conditioning]

        # Make sure sizes make sense
        assert z.shape[0] == coords.shape[0]
        assert denoise_coords.shape[-1] == self.expected_time_dim - torch.sum(conditioning)

        data = Data(
            x=z,  # [N,]
            pos=denoise_coords,  # [N, 3, T - C]
            edge_index=edge_index,  # [2, M] 
            edge_attr=edge_type,  # [M, 1]
            x_features=features,  # [N, F]
            original_frames=coords, # [N, 3, T]
            smiles=smiles,
            rdmol=mol,
            conf_idx=i,
        )

        return self._build_transforms()(data)


class TrajectoryDatasetTestForward(Dataset):
    """Test-time trajectory dataset with one entry per trajectory directory.

    Each window starts at frame 0. The conditioning mode controls which frames
    are excluded from the denoising target:
    ``"forward"`` conditions on the first frame, ``"interpolation"`` on the
    first and last frames, and ``"none"`` on no frames.
    """

    def __init__(
        self, 
        folder_path: str,
        expected_time_dim: int,
        conditioning: Literal["none", "forward", "interpolation"],
        features: list[str] = ["aromatic", "hybridization", "partial_charge", "num_bond_h"],
        transforms: list[str] = ["edge_order|2"],
        remove_hs: bool = True,
        num_frames: int | None = None,
        frame_rate: int | None = None,
        subsample = None,
    ):
        """Index every trajectory directory found under ``folder_path``.

        Args:
            folder_path: Root folder; expected layout is
                ``<folder_path>/<batch>/<molecule>/{system.pdb, smiles.txt, mol.pkl, traj.xtc}``.
            expected_time_dim: Number of frames each sample must contain after
                frame subsampling.
            conditioning: ``"none"``, ``"forward"`` or ``"interpolation"``.
            features: Node feature names forwarded to ``mol_2d``.
            transforms: Transform specs, ``"key"`` or ``"key|int_arg"``, looked
                up in ``TRANSFORMS``.
            remove_hs: If True, hydrogens are removed from molecule and trajectory.
            num_frames: Length (in raw frames) of the window cut from frame 0.
            frame_rate: Slice step applied inside the window.
            subsample: A float is the fraction of SMILES drawn with a fixed seed (0);
                a ``str`` is a path to a pickled iterable of SMILES keys; anything
                else keeps every SMILES.
        """
        super().__init__()

        # Test windows always start at the first frame
        start_frame = 0
        random_frame = False

        # Collect trajectory directories
        self.data = []
        self.mols = []
        self.features = features
        self.transforms = transforms
        self.num_frames = num_frames
        self.start_frame = start_frame
        self.remove_hs = remove_hs
        self.frame_rate = frame_rate
        self.random_frame = random_frame
        self.expected_time_dim = expected_time_dim
        self.conditioning = conditioning

        # SMILES -> list of (mol_file, xtc_file, pdb_path), one per directory
        traj_data: dict[str, list[tuple[str, str, str]]] = {}

        # Iterate directories lazily with scandir
        for batch_entry in tqdm(os.scandir(folder_path), desc='Processing batches'):
            if not batch_entry.is_dir():
                continue
            for mol_entry in tqdm(os.scandir(batch_entry.path), desc=f'Processing molecules in {batch_entry.name}', leave=False):
                if not mol_entry.is_dir():
                    continue
                base_path = mol_entry.path
                pdb_path = os.path.join(base_path, 'system.pdb')
                if not os.path.isfile(pdb_path):
                    continue

                # Read SMILES first (cheap); skip directories without one
                smiles_file = os.path.join(base_path, 'smiles.txt')
                try:
                    with open(smiles_file, 'r') as f:
                        smiles = f.readline().strip()
                except IOError:
                    continue

                mol_file = os.path.join(base_path, 'mol.pkl')
                xtc_file = os.path.join(base_path, 'traj.xtc')

                traj_data.setdefault(smiles, []).append((mol_file, xtc_file, pdb_path))

        # isinstance (not ``type(...) is float``) so float subclasses such as
        # numpy.float64 coming from a config are honored instead of ignored.
        if isinstance(subsample, float):
            print("Subsampling the molecules")
            keys = list(traj_data.keys())
            total = len(keys)
            number = int(total * subsample)
            rng = random.Random(0)
            chosen_idxs = rng.sample(range(total), k=number)
            new_dict = {keys[i]: traj_data[keys[i]] for i in chosen_idxs}
            print(f"New number of molecules {len(new_dict)} vs the old {len(traj_data)}")
            traj_data = new_dict
        elif isinstance(subsample, str):
            print("Subsampling the molecules using provided path")
            with open(subsample, 'rb') as f:
                new_keys = pickle.load(f)
            new_dict = {k: traj_data[k] for k in new_keys}
            print(f"New number of molecules {len(new_dict)} vs the old {len(traj_data)}")
            traj_data = new_dict

        # One index entry per directory; rep_idx counts directories per SMILES.
        # (A distinct loop name avoids rebinding ``traj_data`` while iterating it.)
        self.data: list[tuple[str, str, str, str, int]] = []
        for smiles, entries in traj_data.items():
            for rep_idx, (mol_file, xtc_file, pdb_path) in enumerate(entries):
                self.data.append((smiles, mol_file, xtc_file, pdb_path, rep_idx))

    def _build_transforms(self):
        """Instantiate ``self.transforms`` specs ("key" or "key|int_arg") as one Compose."""
        applied_transforms = []
        for transform in self.transforms:
            key_arg = transform.split('|')
            if len(key_arg) == 1:
                applied_transforms.append(TRANSFORMS[key_arg[0]]())
            else:
                applied_transforms.append(TRANSFORMS[key_arg[0]](int(key_arg[1])))
        return Compose(applied_transforms)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load one trajectory window and return it as a transformed ``Data`` graph."""
        # Get paths
        smiles, mol_file, xtc_file, pdb_path, i = self.data[idx]
        # Close the pickle file promptly. NOTE: pickle.load can execute arbitrary
        # code, so only point this dataset at trusted files.
        with open(mol_file, 'rb') as f:
            mol = pickle.load(f)
        traj = md.load(xtc_file, top=pdb_path)

        # Remove Hs if needed
        if self.remove_hs:
            z = mol_2d(mol)[0]
            keep_idxs = get_keep_atoms(mol, z)
            traj = traj.atom_slice(keep_idxs)
            mol = Chem.RemoveHs(mol)

        # Center the trajectory (in memory); the trajectories were already prealigned
        traj.center_coordinates()

        # Cut a window of num_frames frames and keep every frame_rate-th frame
        if self.random_frame:
            max_start = NUM_FRAMES - self.num_frames
            start = np.random.randint(0, max_start + 1)
        else:
            start = self.start_frame
        end = start + self.num_frames
        traj = traj[start:end:self.frame_rate]

        # Convert the traj coordinates to Angstroms because MDTraj uses Nm
        traj.xyz *= 10

        # traj.xyz shape is (T, N, 3) -> [N, 3, T]
        coords = torch.tensor(traj.xyz, dtype=torch.float32).permute(1, 2, 0)

        # Handle graph featurization
        z, edge_index, edge_type, features = mol_2d(mol, self.features)

        # Conditioning mask: True marks frames given to the model, not denoised
        conditioning = torch.zeros(self.expected_time_dim, dtype=torch.bool)
        if self.conditioning != 'none':
            conditioning[0] = True
        if self.conditioning == 'interpolation':
            conditioning[-1] = True
        denoise_coords = coords[:, :, ~conditioning]

        # Make sure sizes make sense
        assert z.shape[0] == coords.shape[0]
        assert denoise_coords.shape[-1] == self.expected_time_dim - torch.sum(conditioning)

        data = Data(
            x=z,  # [N,]
            pos=denoise_coords,  # [N, 3, T - C]
            edge_index=edge_index,  # [2, M] 
            edge_attr=edge_type,  # [M, 1]
            x_features=features,  # [N, F]
            original_frames=coords, # [N, 3, T]
            smiles=smiles,
            rdmol=mol,
            conf_idx=i,
        )

        return self._build_transforms()(data)


def get_test_data(config):
    """Build the test dataset described by ``config``.

    ``config.dataset.type`` selects conformer or trajectory data. For trajectory
    data, ``config.denoiser.conditioning`` (``"forward"``, ``"interpolation"``
    or ``"none"``) selects the dataset class.

    Raises:
        NotImplementedError: If the dataset type, or the trajectory
            conditioning mode, is not supported. The message names the value.
    """
    dataset_type = config.dataset.type

    if dataset_type == 'conformer':
        print("Loading conformer testing dataset")
        return ConformerDatasetTest(
            pkl_path=config.dataset.test_conf_path, 
            ratio=config.test.ratio,
            subsample=config.test.subsample,
            features=config.dataset.features,
            transforms=config.dataset.transforms,
            filter=config.dataset.filter_data,
            remove_hs=config.dataset.remove_hs
        )

    if dataset_type != 'trajectory':
        raise NotImplementedError(f"Unsupported test dataset type: {dataset_type!r}")

    conditioning = config.denoiser.conditioning
    if conditioning == 'forward':
        print("Loading trajectory testing dataset for FORWARD")
        test_dataset = TrajectoryDatasetTestForward(
            folder_path=config.dataset.test_traj_dir, 
            expected_time_dim=config.dataset.expected_time_dim,
            features=config.dataset.features,
            num_frames=config.dataset.num_frames,
            transforms=config.dataset.transforms,
            frame_rate=config.dataset.frame_rate,
            remove_hs=config.dataset.remove_hs,
            conditioning=conditioning,
            subsample=config.test.subsample
        )
    elif conditioning == 'interpolation':
        print("Loading trajectory testing dataset for INTERPOLATION")
        test_dataset = TrajectoryDatasetTestInterpolation(
            folder_path=config.dataset.test_traj_dir,
            subsample=config.test.subsample,
            pkl_path=config.test.pkl_path, 
            expected_time_dim=config.dataset.expected_time_dim,
            features=config.dataset.features,
            transforms=config.dataset.transforms,
            remove_hs=config.dataset.remove_hs,
            ratio=config.test.ratio,
        )
    elif conditioning == 'none':
        print("Loading trajectory testing dataset for NO CONDITION")
        test_dataset = TrajectoryDatasetTestUncond(
            folder_path=config.dataset.test_traj_dir,
            subsample=config.test.subsample,
            num_frames=config.dataset.num_frames,
            expected_time_dim=config.dataset.expected_time_dim,
            features=config.dataset.features,
            transforms=config.dataset.transforms,
            remove_hs=config.dataset.remove_hs,
            ratio=config.test.ratio,
            frame_rate=config.dataset.frame_rate
        )
    else:
        raise NotImplementedError(
            f"Unsupported conditioning {conditioning!r} for trajectory test data"
        )

    return test_dataset