
from rdkit import Chem, RDLogger, DataStructs
from rdkit.Chem.rdchem import BondType as BT

import os
import os.path as osp
import pathlib
import hashlib
from typing import Any, Sequence

import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
from torch_geometric.data import Data, InMemoryDataset, download_url
from hydra.utils import get_original_cwd

from didigress.metrics.properties import qed, drd2
from didigress.datasets.dataset_utils import mol_to_torch_geometric, remove_hydrogens, StatisticsMolecule
from didigress.datasets.dataset_utils import load_pickle, save_pickle
from didigress.datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule, AbstractAdaptiveDataModule
from didigress.metrics.metrics_utils import compute_all_statistics, atom_type_counts, edge_counts, charge_counts
from didigress.utils import PlaceHolder, get_atom_counts, get_mol_fingerprint

from didigress.datasets.abstract_dataset import AbstractDatasetInfos

from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit.Chem import rdMolDescriptors, Crippen, AllChem

from didigress.utils import rstrip1, clean_mol, graph2mol

# Expected MD5 digests of the GuacaMol v1 train / valid / test SMILES files.
# download() checks the downloaded files against them with compare_hash(), so
# everyone works with the same splits.
TRAIN_HASH, VALID_HASH, TEST_HASH = (
    '05ad85d871958a05c02ab51a4fde8530',
    'e53db4bff7dc4784123ae6df72e3b1f0',
    '677b757ccec4809febd83850b43e1616',
)


def files_exist(files) -> bool:
    """Return True only if ``files`` is non-empty and every path in it exists.

    An empty collection deliberately yields False. Callers then treat the
    files as missing, which leads to re-processing on every instantiation.
    """
    if len(files) == 0:
        return False
    return all(osp.exists(path) for path in files)


def to_list(value: Any) -> Sequence:
    """Wrap ``value`` in a list unless it already is a non-string sequence.

    Sequences such as lists and tuples are returned as-is, without being
    copied. Strings are treated as scalars, so a single file name becomes
    ``[name]``.
    """
    is_non_string_sequence = isinstance(value, Sequence) and not isinstance(value, str)
    return value if is_non_string_sequence else [value]


def compare_hash(output_file: str, correct_hash: str) -> bool:
    """
    Compute the MD5 hash of a file and compare it with an expected digest.

    The file is read in binary chunks inside a ``with`` block. This guarantees
    the handle is closed, and large SMILES files are never loaded into memory
    all at once.

    Args:
        output_file: path of the file to hash.
        correct_hash: expected hex MD5 digest. It is compared verbatim with
            ``hexdigest()``, which is lowercase.

    Returns:
        True if the digests match. Otherwise a message is printed and False
        is returned.
    """
    md5 = hashlib.md5()
    with open(output_file, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            md5.update(chunk)
    output_hash = md5.hexdigest()

    if output_hash != correct_hash:
        print(f'{output_file} file has different hash, {output_hash}, than expected, {correct_hash}!')
        return False

    return True

class RemoveYTransform:
    """Per-sample transform that replaces ``data.guidance`` with an empty tensor.

    GuacamolDataModule uses it when none of the known guidance targets is
    configured. The guidance becomes a ``(1, 0)`` float tensor and the same
    ``data`` object is returned.
    """

    def __call__(self, data):
        empty_guidance = torch.zeros((1, 0), dtype=torch.float)
        data.guidance = empty_guidance
        return data

# Number of guidance columns reserved for the molecular fingerprint
# (see GuacamolDataset.process and FeatureExtractorTransform).
fprint_size = 2048

# Atom vocabulary. Neutral elements come first; charged species follow as
# separate entries named "<element><sign><charge>". Each symbol's index is its
# position in the tuple, so indices are contiguous by construction.
full_atom_encoder = {
    symbol: index
    for index, symbol in enumerate((
        'H', 'C', 'N', 'O', 'F', 'B', 'Br', 'Cl', 'I', 'P', 'S', 'Se', 'Si',
        'C-1', 'N+1', 'N-1', 'O+1', 'O-1', 'B-1', 'P+1', 'S+1',
    ))
}

# Inert reference table that the code never reads: counts of charged atoms per
# "<element><sign><charge>" key. It was apparently collected to choose which
# charged species get their own entry in full_atom_encoder.
"""
atom_stats =  {'H': 0, 'C': 0, 'N': 0, 'O': 0, 'F': 0, 'B': 0, 'Br': 0, 'Cl': 0, 'I': 0, 'P': 0, 'S': 0, 'Se': 0, 'Si': 0, 
               'C+1': 69, 'C-1': 621, 'N+1': 78851, 'N-1': 2210, 'O+1': 113, 'O-1': 64957, 'F+1': 3, 'F-1': 1, 'B+1': 0, 'B-1': 220, 'Br+1': 0, 'Br-1': 1, 
               'Cl+1': 2, 'Cl-1': 1, 'I+1': 10, 'I-1': 0, 'P+1': 417, 'P-1': 2, 'S+1': 4578, 'S-1': 73, 'Se+1': 19, 'Se-1': 1, 'Si+1': 0, 'Si-1': 2, 
               'C+2': 0, 'C-2': 0, 'N+2': 0, 'N-2': 0, 'O+2': 0, 'O-2': 0, 'F+2': 0, 'F-2': 0, 'B+2': 0, 'B-2': 0, 'Br+2': 1, 'Br-2': 0, 'Cl+2': 1, 'Cl-2': 0, 
               'I+2': 1, 'I-2': 0, 'P+2': 0, 'P-2': 0, 'S+2': 0, 'S-2': 0, 'Se+2': 0, 'Se-2': 0, 'Si+2': 0, 'Si-2': 0, 'C+3': 0, 'C-3': 0, 
               'N+3': 0, 'N-3': 0, 'O+3': 0, 'O-3': 0, 'F+3': 0, 'F-3': 0, 'B+3': 0, 'B-3': 0, 'Br+3': 0, 'Br-3': 0, 'Cl+3': 1, 'Cl-3': 0, 
               'I+3': 1, 'I-3': 0, 'P+3': 0, 'P-3': 0, 'S+3': 0, 'S-3': 0, 'Se+3': 0, 'Se-3': 0, 'Si+3': 0, 'Si-3': 0, 'C+4': 0, 'C-4': 0, 
               'N+4': 0, 'N-4': 0, 'O+4': 0, 'O-4': 0, 'F+4': 0, 'F-4': 0, 'B+4': 0, 'B-4': 0, 'Br+4': 0, 'Br-4': 0, 'Cl+4': 0, 'Cl-4': 0, 
               'I+4': 0, 'I-4': 0, 'P+4': 0, 'P-4': 0, 'S+4': 0, 'S-4': 0, 'Se+4': 0, 'Se-4': 0, 'Si+4': 0, 'Si-4': 0}
"""
class FeatureExtractorTransform:
    """Keep only the guidance columns selected by ``guidance_target``.

    GuacamolDataset.process stores ``data.guidance`` as one row laid out as
    ``[logP, BertzCT, TPSA, atom-type counts (len(full_atom_encoder) columns),
    fingerprint (fprint_size columns)]``. The requested pieces are
    concatenated along the last dimension in that fixed layout order, not in
    the order they appear in ``guidance_target``.
    """

    def __init__(self, guidance_target):
        self.guidance_target = guidance_target

    def __call__(self, data):
        counts_end = 3 + len(full_atom_encoder)
        # (target name, first column, one-past-last column)
        layout = (
            ('logp', 0, 1),
            ('bertz', 1, 2),
            ('TPSA', 2, 3),
            ('isomer', 3, counts_end),
            ('fprint', counts_end, counts_end + fprint_size),
        )
        selected = tuple(
            data.guidance[..., start:stop]
            for name, start, stop in layout
            if name in self.guidance_target
        )
        data.guidance = torch.hstack(selected)
        return data
    
class DummyDatasetInfos:
    """Minimal stand-in for the dataset infos, used while preprocessing.

    GuacamolDataset.process needs to call graph2mol before the real dataset
    infos exist. Only ``to_one_hot`` is provided, plus the node-type count it
    relies on.
    """

    def __init__(self, cfg, atom_decoder):
        self.cfg = cfg
        self.num_node_types = len(atom_decoder)

    def to_one_hot(self, X, E, node_mask, charges=None):
        """One-hot encode node types, edge types (5 classes) and optionally charges, then apply ``node_mask``."""
        node_onehot = F.one_hot(X, num_classes=self.num_node_types).float()
        edge_onehot = F.one_hot(E, num_classes=5).float()

        if self.cfg.features.use_charges:
            # Charges are shifted by +1 before encoding into 5 classes.
            charges = F.one_hot(charges + 1, num_classes=5).float()

        masked = PlaceHolder(X=node_onehot, E=edge_onehot, y=None, charges=charges, pos=None).mask(node_mask)
        return masked.X, masked.E, masked.charges
    

class GuacamolDataset(InMemoryDataset):
    """GuacaMol SMILES splits converted to PyTorch Geometric graphs.

    ``split`` selects the raw file: ``'train'``, ``'val'``, and anything else
    is treated as the test split. Processing writes seven files per split
    (see :attr:`processed_file_names`), in this index order:

    0. collated graphs (``torch.save``)
    1. node-count statistics (pickle)
    2. node-type statistics (``.npy``)
    3. edge-type statistics (``.npy``)
    4. charge statistics (``.npy``)
    5. valency statistics (pickle)
    6. set of kept canonical SMILES (pickle)
    """
    train_url = 'https://figshare.com/ndownloader/files/13612760'
    test_url = 'https://figshare.com/ndownloader/files/13612757'
    valid_url = 'https://figshare.com/ndownloader/files/13612766'
    all_url = 'https://figshare.com/ndownloader/files/13612745'

    def __init__(self, split, root, remove_h: bool, cfg, transform=None, pre_transform=None, pre_filter=None):
        self.split = split
        # Index into raw_file_names / split_paths.
        if self.split == 'train':
            self.file_idx = 0
        elif self.split == 'val':
            self.file_idx = 1
        else:
            self.file_idx = 2
        self.remove_h = remove_h

        self.cfg = cfg
        self.use_charges = self.cfg.features.use_charges
        self.charges_policy = cfg.features.charges_policy

        self.atom_encoder = full_atom_encoder.copy()

        if remove_h:
            # Drop 'H' (index 0 in full_atom_encoder) and shift the other indices down by one.
            self.atom_encoder = {k: v - 1 for k, v in self.atom_encoder.items() if k != 'H'}

        # Dataset infos has this as well but we need it in process()
        self.atom_decoder = list(self.atom_encoder)

        # PyG's dataset constructor runs download()/process() when the raw or
        # processed files are missing, so everything they use is set above.
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

        valencies = load_pickle(self.processed_paths[5])
        charge_types = torch.from_numpy(np.load(self.processed_paths[4], allow_pickle=True)).float()

        self.statistics = StatisticsMolecule(num_nodes=load_pickle(self.processed_paths[1]),
                                             node_types=torch.from_numpy(np.load(self.processed_paths[2])).float(),
                                             edge_types=torch.from_numpy(np.load(self.processed_paths[3])).float(),
                                             cfg=self.cfg,
                                             charge_types=charge_types,
                                             valencies=valencies,
                                             bond_lengths=None,
                                             bond_angles=None)
        self.smiles = load_pickle(self.processed_paths[6])

    @property
    def raw_file_names(self):
        return ['guacamol_v1_train.smiles', 'guacamol_v1_valid.smiles', 'guacamol_v1_test.smiles']

    @property
    def split_file_name(self):
        return ['guacamol_v1_train.smiles', 'guacamol_v1_valid.smiles', 'guacamol_v1_test.smiles']

    @property
    def split_paths(self):
        r"""The absolute filepaths that must be present in order to skip
        splitting."""
        files = to_list(self.split_file_name)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_file_names(self):
        h = 'noh' if self.remove_h else 'h'
        prefix = self.split if self.split in ('train', 'val') else 'test'
        return [f'{prefix}_{h}.pt', f'{prefix}_n_{h}.pickle', f'{prefix}_node_types_{h}.npy',
                f'{prefix}_edge_types_{h}.npy', f'{prefix}_charges_{h}.npy',
                f'{prefix}_valency_{h}.pickle', f'{prefix}_smiles.pickle']

    def download(self):
        """Download the three GuacaMol splits into ``raw_dir`` and verify their MD5 hashes.

        Checking the hashes ensures everyone works with the same splits.

        Raises:
            SystemExit: if any downloaded file does not match its expected hash.
        """
        train_path = self._download_as(self.train_url, 'guacamol_v1_train.smiles')
        test_path = self._download_as(self.test_url, 'guacamol_v1_test.smiles')
        valid_path = self._download_as(self.valid_url, 'guacamol_v1_valid.smiles')

        # Build the full list so that every mismatching file gets reported.
        valid_hashes = [
            compare_hash(train_path, TRAIN_HASH),
            compare_hash(valid_path, VALID_HASH),
            compare_hash(test_path, TEST_HASH),
        ]

        if not all(valid_hashes):
            raise SystemExit('Invalid hashes for the dataset files')

        print('Dataset download successful. Hashes are correct.')

    def _download_as(self, url, file_name):
        """Download ``url`` into ``raw_dir``, rename it to ``file_name`` and return the new path."""
        downloaded_path = download_url(url, self.raw_dir)
        target_path = osp.join(self.raw_dir, file_name)
        os.rename(downloaded_path, target_path)
        return target_path

    def _new_smiles_path(self, split):
        """Path of the filtered SMILES list written by process() for ``split`` (next to the raw files)."""
        return osp.join(pathlib.Path(self.raw_paths[0]).parent, f'new_{split}.smiles')

    @staticmethod
    def _compute_guidance(mol, smiles, atom_encoder_keys):
        """Build the ``(1, 3 + len(atom_encoder_keys) + fprint_size)`` guidance row for one molecule.

        Layout: ``[logP, BertzCT, TPSA, atom-type counts, fingerprint]``.
        FeatureExtractorTransform slices this layout.
        """
        counts_end = 3 + len(atom_encoder_keys)
        estimated_logp = Crippen.MolLogP(mol)
        estimated_bertz = Chem.GraphDescriptors.BertzCT(mol)
        estimated_TPSA = Chem.rdMolDescriptors.CalcTPSA(mol)
        estimated_isomer = torch.tensor(get_atom_counts(smiles, predefined_atoms=atom_encoder_keys))
        estimated_fprint = get_mol_fingerprint(mol)

        guidance = torch.zeros((1, counts_end + fprint_size))
        guidance[0, 0] = estimated_logp
        guidance[0, 1] = estimated_bertz
        guidance[0, 2] = estimated_TPSA
        guidance[0, 3:counts_end] = estimated_isomer
        guidance[0, counts_end:counts_end + fprint_size] = estimated_fprint
        return guidance

    @staticmethod
    def _descriptors_match(original_mol, reconstructed_mol, original_smiles, reconstructed_smiles,
                           atom_encoder_keys):
        """Return True if every guidance descriptor agrees between the two molecules.

        logP, Bertz CT and TPSA must agree within 1e-5. Atom-type counts and
        fingerprints must be identical. A message is printed for each TPSA,
        atom-count or fingerprint mismatch.
        """
        original_logp = Crippen.MolLogP(original_mol)
        original_bertz = Chem.GraphDescriptors.BertzCT(original_mol)
        original_tpsa = Chem.rdMolDescriptors.CalcTPSA(original_mol)
        original_counts = torch.tensor(get_atom_counts(original_smiles, predefined_atoms=atom_encoder_keys))
        original_fprint = get_mol_fingerprint(original_mol)

        rec_logp = Crippen.MolLogP(reconstructed_mol)
        rec_bertz = Chem.GraphDescriptors.BertzCT(reconstructed_mol)
        rec_tpsa = Chem.rdMolDescriptors.CalcTPSA(reconstructed_mol)
        rec_counts = torch.tensor(get_atom_counts(reconstructed_smiles, predefined_atoms=atom_encoder_keys))
        rec_fprint = get_mol_fingerprint(reconstructed_mol)

        # Every check runs (no short-circuit), so all applicable messages are printed.
        matches = True
        if abs(original_logp - rec_logp) > 1e-5:
            matches = False
        if abs(original_bertz - rec_bertz) > 1e-5:
            matches = False
        if abs(original_tpsa - rec_tpsa) > 1e-5:
            print("tpsa check")
            matches = False
        if torch.sum(torch.abs(original_counts - rec_counts)) > 0:
            print("atom count check")
            matches = False
        if torch.sum(torch.abs(original_fprint - rec_fprint)) > 0:
            print("fingerprint check")
            matches = False
        return matches

    def process(self):
        """Convert this split's SMILES into graphs, filter them, and save graphs plus statistics.

        A molecule is kept only if all of the following hold:
        it converts to a graph; it passes ``pre_filter``; the graph can be
        turned back into a molecule whose canonical SMILES equals the original;
        the guidance descriptors survive that round trip; the reconstructed
        molecule is a single fragment; and, for val/test, its SMILES is not
        among the kept training SMILES.
        """
        RDLogger.DisableLog('rdApp.*')

        # Val/test molecules already kept for the training split are dropped.
        # new_train.smiles is written by process() of the train split (below),
        # so it is read back from the same location. A set gives O(1) membership
        # tests instead of scanning the whole training list per molecule.
        train_smiles = set()
        if self.file_idx != 0:
            with open(self._new_smiles_path('train'), 'r') as f:
                train_smiles = set(f.read().split("\n"))

        with open(self.split_paths[self.file_idx]) as f:
            smile_list = f.readlines()

        data_list = []
        smiles_kept = []

        unmatches = 0          # same SMILES after the round trip, but a descriptor changed
        exceptions = 0         # reconstructed molecule could not be cleaned / written as SMILES
        string_unmatches = 0   # reconstructed SMILES differs from the original
        n_already_present = 0  # val/test molecule also present in the training split
        num_errors = 0

        # full_atom_encoder instead of self.atom_encoder because we also need H
        atom_encoder_keys = sorted(full_atom_encoder.keys(), key=full_atom_encoder.get)

        # graph2mol needs dataset infos, but the real ones cannot exist yet at
        # this point. Inside graph2mol it is only used through to_dense ->
        # to_one_hot, which DummyDatasetInfos provides. It does not depend on
        # the molecule, so build it once.
        dummy_dataset_infos = DummyDatasetInfos(cfg=self.cfg, atom_decoder=self.atom_decoder)

        print("Initial size: ", len(smile_list))
        for original_smile in tqdm(smile_list):
            # Strip the trailing newline, preprocess and canonicalise.
            smiles_2D = rstrip1(original_smile, "\n")
            mol = clean_mol(smiles_2D)
            if mol is None:
                # NOTE(review): assumes clean_mol reports failure by returning None;
                # Chem.MolToSmiles(None) would otherwise abort the whole run.
                num_errors += 1
                continue
            smiles_2D = Chem.MolToSmiles(mol)
            if smiles_2D is None:
                num_errors += 1
                continue

            data = mol_to_torch_geometric(mol=mol, atom_encoder=full_atom_encoder, smiles=smiles_2D, cfg=self.cfg)
            # None for an empty graph, or when an atom carries a non-neutral
            # charge that the configuration does not keep.
            if data is None:
                continue

            if self.remove_h:
                data = remove_hydrogens(data=data, cfg=self.cfg)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data.guidance = self._compute_guidance(mol, smiles_2D, atom_encoder_keys)
            data.n_nodes = data.x.size(0)

            node_stats = atom_type_counts(data_list=[data], num_classes=len(self.atom_encoder), verbose=False)
            edge_stats = edge_counts(data_list=[data], verbose=False)
            if np.isnan(edge_stats).all():
                # Fall back to all mass on edge type 0 (presumably "no bond" - confirm).
                print("found a NaN edge tensor")
                edge_stats = np.zeros_like(edge_stats)
                edge_stats[0] = 1

            charge_types = charge_counts(data_list=[data], num_classes=len(self.atom_encoder),
                                         charges_dic={0: 0},
                                         verbose=False)

            data.node_stats = torch.from_numpy(node_stats).unsqueeze(0)
            data.edge_stats = torch.from_numpy(edge_stats).unsqueeze(0)
            data.charge_types = torch.from_numpy(charge_types).unsqueeze(0)

            # Round-trip sanity check: rebuild an RDKit molecule from the graph
            # and keep the sample only if it faithfully reproduces the original.
            reconstructed_mol = graph2mol(data, dataset_infos=dummy_dataset_infos,
                                          atom_decoder=self.atom_decoder)

            try:
                reconstructed_mol = clean_mol(reconstructed_mol)
                reconstructed_smiles = Chem.MolToSmiles(reconstructed_mol)
            except Exception:
                exceptions += 1
                continue

            # Skip val/test molecules that are also in the training split.
            if self.file_idx != 0 and reconstructed_smiles in train_smiles:
                n_already_present += 1
                continue

            if smiles_2D != reconstructed_smiles:
                string_unmatches += 1
                num_errors += 1
                continue

            # Identical canonical SMILES. Still verify that the descriptors used
            # as guidance are unchanged by the graph round trip.
            if not self._descriptors_match(mol, reconstructed_mol, smiles_2D, reconstructed_smiles,
                                           atom_encoder_keys):
                unmatches += 1
                num_errors += 1
                continue

            try:
                mol_frags = Chem.rdmolops.GetMolFrags(reconstructed_mol, asMols=True, sanitizeFrags=True)
            except Chem.rdchem.AtomValenceException:
                print("Valence error in GetmolFrags")
                continue
            except Chem.rdchem.KekulizeException:
                print("Can't kekulize molecule")
                continue

            # Only single-fragment molecules are kept.
            if len(mol_frags) == 1:
                data_list.append(data)
                smiles_kept.append(reconstructed_smiles)

        print("data_list size:", len(data_list))
        print("Total unmatches: ", unmatches)
        print("exceptions =", exceptions)
        print("string_unmatches: ", string_unmatches)
        print("removed because already present in the training set: ", n_already_present)

        smiles_save_path = self._new_smiles_path(self.split)
        print(smiles_save_path)
        with open(smiles_save_path, 'w') as f:
            f.writelines('%s\n' % s for s in smiles_kept)
        print(f"Number of molecules kept: {len(smiles_kept)} / {len(smile_list)}")

        # NOTE(review): this charges_dic ({-1, 0, 1}) differs from the {0: 0}
        # used for the per-molecule charge_counts above - confirm which is intended.
        statistics = compute_all_statistics(data_list, self.atom_encoder, charges_dic={-1: 0, 0: 1, 1: 2}, cfg=self.cfg, use_3d=False)

        save_pickle(statistics.num_nodes, self.processed_paths[1])
        np.save(self.processed_paths[2], statistics.node_types)
        np.save(self.processed_paths[3], statistics.edge_types)
        np.save(self.processed_paths[4], statistics.charge_types)
        save_pickle(statistics.valencies, self.processed_paths[5])
        print("Number of molecules that could not be mapped to smiles: ", num_errors)
        save_pickle(set(smiles_kept), self.processed_paths[6])
        torch.save(self.collate(data_list), self.processed_paths[0])


class GuacamolDataModule(AbstractDataModule):
    """Build the train/val/test GuacamolDataset objects and pass them to AbstractDataModule.

    The per-sample transform depends on the configured guidance targets. If
    any known guidance target is configured, FeatureExtractorTransform keeps
    the requested guidance columns. Otherwise RemoveYTransform empties the
    guidance.
    """

    def __init__(self, cfg):
        self.datadir = cfg.dataset.datadir
        # Datasets live under <parent of the original working dir>/<datadir>.
        root_path = os.path.join(pathlib.Path(get_original_cwd()).parents[0], self.datadir)

        known_targets = {'logp', 'TPSA', 'bertz', 'isomer', 'fprint'}
        target = cfg.guidance.guidance_target
        if known_targets.intersection(target):
            transform = FeatureExtractorTransform(cfg.guidance.guidance_target)
        else:
            transform = RemoveYTransform()

        def make_split(split):
            return GuacamolDataset(split=split, root=root_path, remove_h=cfg.dataset.remove_h,
                                   cfg=cfg, transform=transform)

        # Keep this order: processing val/test reads a new_train.smiles file
        # produced while processing the train split.
        train_dataset = make_split('train')
        val_dataset = make_split('val')
        test_dataset = make_split('test')

        splits = (('train', train_dataset), ('val', val_dataset), ('test', test_dataset))
        self.statistics = {name: dataset.statistics for name, dataset in splits}
        self.remove_h = cfg.dataset.remove_h
        super().__init__(cfg, train_dataset=train_dataset, val_dataset=val_dataset, test_dataset=test_dataset)



class Guacamolinfos(AbstractDatasetInfos):
    """Dataset-level metadata for GuacaMol: encoders, statistics and network I/O sizes.

    Args:
        datamodule: provides the per-split ``statistics`` dict (see GuacamolDataModule).
        cfg: Hydra config. Reads ``dataset.remove_h``, ``guidance.*`` and ``features.*``.
    """

    def __init__(self, datamodule, cfg):
        self.remove_h = cfg.dataset.remove_h
        self.statistics = datamodule.statistics
        # NOTE(review): a GuacaMol infos object reporting 'moses' looks like a
        # copy-paste from the MOSES dataset. Left unchanged because callers may
        # branch on it - confirm before renaming.
        self.name = 'moses'
        # Copy so that later changes to self.atom_encoder cannot leak into the
        # module-level full_atom_encoder (GuacamolDataset copies it as well).
        self.atom_encoder = full_atom_encoder.copy()
        self.collapse_charges = torch.Tensor([-2, -1, 0, 1, 2]).int()
        if self.remove_h:
            self.atom_encoder = {k: v - 1 for k, v in self.atom_encoder.items() if k != 'H'}
        super().complete_infos(datamodule.statistics, self.atom_encoder, cfg)

        use_conditional = cfg.guidance.p_uncond >= 0
        medium_extra = {'X': 0, 'E': 0, 'y': 0, 'p': 0}
        # Number of guidance columns per target. These must agree with the
        # slices FeatureExtractorTransform takes from data.guidance, which
        # reserves fprint_size columns for the fingerprint.
        guidance_sizes = {'logp': 1, 'bertz': 1, 'TPSA': 1,
                          'isomer': len(full_atom_encoder), 'fprint': fprint_size}
        self.guidance_dims = 0

        if use_conditional:
            guidance_mediums = cfg.guidance.guidance_medium
            targets = cfg.guidance.guidance_target

            for target in targets:
                self.guidance_dims += guidance_sizes[target]
                # Every configured medium receives the guidance columns as extra inputs.
                for medium in guidance_mediums:
                    medium_extra[medium] = medium_extra[medium] + guidance_sizes[target]

        print("self.guidance_dims =", self.guidance_dims)
        print("medium_extra =", medium_extra)

        self.use_ins_del = cfg.features.use_ins_del
        # Two extra categories per output when use_ins_del is set (insert/delete, judging by the name).
        IN_ins_del_sz = 2 if self.use_ins_del else 0

        X_sz_out = self.num_node_types + IN_ins_del_sz
        E_sz_out = 5 + IN_ins_del_sz
        y_sz_out = 0
        c_sz_out = 1 + IN_ins_del_sz
        p_sz_out = 3 + IN_ins_del_sz

        X_sz = X_sz_out + medium_extra['X']
        E_sz = E_sz_out + medium_extra['E']
        y_sz = y_sz_out + medium_extra['y'] + 1  # The "+1" is the timestep t
        c_sz = c_sz_out
        p_sz = p_sz_out + medium_extra['p']

        # TODO: add INS/DEL categories
        self.input_dims = PlaceHolder(X=X_sz, charges=c_sz, E=E_sz, y=y_sz, pos=p_sz, guidance=self.guidance_dims)
        self.output_dims = PlaceHolder(X=X_sz_out, charges=c_sz_out, E=E_sz_out, y=0, pos=p_sz_out, guidance=0)

        self.cfg = cfg
        self.charges_dic = {0: 0}

    def to_one_hot(self, X, E, node_mask, charges=None):
        """One-hot encode node types, edge types (5 classes) and optionally charges, then apply ``node_mask``."""
        X = F.one_hot(X, num_classes=self.num_node_types).float()
        E = F.one_hot(E, num_classes=5).float()

        if self.cfg.features.use_charges:
            # NOTE(review): with num_classes=1 only charges == -1 (index 0) are
            # encodable; a zero charge (index 1) would raise. DummyDatasetInfos
            # uses 5 classes here - confirm the intended charge encoding.
            charges = F.one_hot(charges + 1, num_classes=1).float()

        placeholder = PlaceHolder(X=X, E=E, y=None, charges=charges, pos=None)
        pl = placeholder.mask(node_mask)
        return pl.X, pl.E, pl.charges

    def one_hot_charges(self, charges):
        """One-hot encode ``charges + 1`` into a single class (see the note in to_one_hot)."""
        return F.one_hot((charges + 1).long(), num_classes=1).float()