import os
import os.path as osp
import pathlib
from typing import Any, Sequence
import pickle
import copy
import re
from multiprocessing import Pool

from multiprocessing import Lock, Process, Queue, current_process
import time
import queue # imported for using queue.Empty exception

from rdkit.Chem import PeriodicTable
from rdkit.Chem import rdChemReactions
import torch
import torch.nn.functional as F
from rdkit import Chem
from rdkit.Chem.rdchem import BondType as BT
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.utils import subgraph
from src.datasets.abstract_dataset import AbstractDataModule, seed_worker
from torch_geometric.loader import DataLoader

from src.utils import graph, mol, setup
from src.datasets.abstract_dataset import AbstractDataModule, AbstractDatasetInfos, DistributionNodes
from src.utils.rdkit import  mol2smiles, build_molecule_with_partial_charges
from src.utils.rdkit import compute_molecular_metrics
import logging

# Upper bound on the number of nodes of one reaction graph: the per-node masks
# in Dataset.turn_reactants_and_product_into_graph are pre-allocated with this
# length, and it is the default histogram size in DataModule.node_counts.
MAX_ATOMS_RXN = 1000
# MAX_NODES_MORE_THAN_PRODUCT = 35 <- this shouldn't be used!

# Atom-type label of the dummy (padding) nodes appended to the reactant side.
DUMMY_RCT_NODE_TYPE = 'U'

# THESE ARE NOT USED ANYMORE
# size_bins = {
#     'train': [64, 83, 102], # [64,83,102]
#     'test': [250],
#     'val': [250]
# }

# batchsize_bins = { 
#     'train': [32, 16, 8], # [128, 64, 16]
#     'test': [32], # [64]
#     'val': [32] # [64]
# }

def add_chem_bond_types(bond_types):
    '''
        Swap RDKit bond-type names for the matching BondType members.

        Entries equal to 'SINGLE', 'DOUBLE', 'TRIPLE' or 'AROMATIC' are replaced
        by the attribute of the same name on BT; every other entry is copied
        through unchanged. Returns a new list in the input order.
    '''
    rdkit_bond_names = ('SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC')
    converted = []

    for entry in bond_types:
        for name in rdkit_bond_names:
            if entry == name:
                converted.append(getattr(BT, name))
                break
        else:
            # not an RDKit bond name: keep the entry untouched
            converted.append(entry)

    return converted

def get_bond_orders(bond_types):
    '''
        Legacy mapping from bond types to bond orders.

        NOTE: only the AROMATIC test is paired with the fallback, so an entry
        equal to BT.SINGLE / BT.DOUBLE / BT.TRIPLE contributes its order AND a
        trailing 0, and the returned list can be longer than the input. This
        behaviour is reproduced as-is here.
    '''
    orders = []
    for bond in bond_types:
        # the three independent checks (none of them has a fallback)
        for bond_type, order in ((BT.SINGLE, 1), (BT.DOUBLE, 2), (BT.TRIPLE, 3)):
            if bond == bond_type:
                orders.append(order)
        # the aromatic check is the only one paired with the 0 fallback
        orders.append(1.5 if bond == BT.AROMATIC else 0)

    return orders

def get_bond_orders_correct(bond_types):
    '''
        Map each bond type to its bond order (one output per input entry).

        An entry may be a BT member or its name as a string: SINGLE -> 1,
        DOUBLE -> 2, TRIPLE -> 3, AROMATIC -> 1.5. Anything else maps to 0.
    '''
    orders = []

    for bond in bond_types:
        order = 0
        for name, value in (('SINGLE', 1), ('DOUBLE', 2), ('TRIPLE', 3), ('AROMATIC', 1.5)):
            if bond == getattr(BT, name) or bond == name:
                order = value
                break
        orders.append(order)

    return orders

# File names are positional: index 0 = train, 1 = test, 2 = val
# (this is the meaning of Dataset.file_idx).
raw_files = ['train.csv', 'test.csv', 'val.csv']
processed_files = ['train.pt', 'test.pt', 'val.pt']

# NOTE: this runs at import time and configures the root logger to write
# DEBUG-level records to 'subprocess_log.txt' (a relative path).
logging.basicConfig(filename='subprocess_log.txt', level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(name)s %(message)s')

class Dataset(InMemoryDataset):
    def __init__(self, stage, root, atom_types, bond_types, size_test_splits=0, with_explicit_h=False, permute_mols=False,
                 with_formal_charge=False, max_nodes_more_than_product=35, canonicalize_molecule=True, 
                 num_processes=1, add_supernode_edges=False):
        '''
            Record the preprocessing options, initialise the InMemoryDataset base
            and load the collated tensors that belong to `stage`.

            `stage` selects the raw/processed file: 'train' -> index 0,
            'test' -> index 1, anything else -> index 2 (val). A stage containing
            'test_' is loaded from '<root>/processed/<stage>.pt' instead.
            NOTE(review): super().__init__ presumably triggers process() when the
            processed files are missing (torch_geometric behaviour) - confirm.
        '''
        # which split this instance represents and where it lives
        self.stage = stage
        self.root = root
        self.size_test_splits = size_test_splits
        self.num_processes = num_processes

        # vocabulary of node / edge labels
        self.atom_types = atom_types
        self.bond_types = bond_types

        # graph-construction switches
        self.with_explicit_h = with_explicit_h
        self.with_formal_charge = with_formal_charge
        self.max_nodes_more_than_product = max_nodes_more_than_product
        self.canonicalize_molecule = canonicalize_molecule
        self.add_supernode_edges = add_supernode_edges
        self.permute_mols = permute_mols

        indexed_stages = ('train', 'test')
        self.file_idx = indexed_stages.index(self.stage) if self.stage in indexed_stages else 2

        super().__init__(root)

        if 'test_' in self.stage:
            tensors_path = os.path.join(root, 'processed', self.stage+'.pt')
        else:
            tensors_path = self.processed_paths[self.file_idx]
        self.data, self.slices = torch.load(tensors_path)

    @property
    def raw_file_names(self):
        '''Return the module-level list of raw file names (train, test, val).'''
        return raw_files

    @property
    def processed_file_names(self):
        '''Return the module-level list of processed file names (train, test, val).'''
        return processed_files

    def split_test_data(self, graphs, size_test_splits=100):
        '''
            (optional) Cut the processed graphs into consecutive chunks and save each
            chunk to its own file, e.g. to test in parallel (slurm array jobs).

            Input:
                graphs: a list of processed graph objects (test data)
                size_test_splits: number of graphs per chunk. Usually set in the config file.
            Output:
                nothing is returned; chunk k is collated and saved next to the
                processed file as '<processed path without .pt>_<k>.pt'.
        '''
        path_stem = self.processed_paths[self.file_idx].split('.pt')[0]
        chunk_starts = range(0, len(graphs), size_test_splits)
        for split_nb, start in enumerate(chunk_starts):
            chunk = graphs[start:start+size_test_splits]
            print(f'len(graphs[i:i+size_test_splits]) {len(chunk)}\n')
            torch.save(self.collate(chunk), path_stem+f'_{split_nb}.pt')

    # Function to split your dataset into chunks
    def split_dataset(self, dataset, num_chunks):
        # Split the dataset into 'num_chunks' parts and return a list of chunks
        chunk_size = len(dataset) // num_chunks

        return [(dataset[i:i + chunk_size], i) for i in range(0, len(dataset), chunk_size)]

    # Main function to process the dataset in parallel
    def process(self):
        '''
            Turn the raw file of this stage into one collated, processed file.

            The raw lines are split into chunks (one per worker process), each
            chunk is converted by sub_process, which pickles its graphs to
            '<processed path without .pt>/graphs_<chunk offset>.pickle'. The
            per-chunk pickles are then read back in chunk order, concatenated,
            collated and saved to the processed path of this stage.

            Raises:
                any exception that escaped a worker (re-raised here), or
                FileNotFoundError when a worker did not write its pickle.
        '''
        processed_path = self.processed_paths[self.file_idx]
        chunk_dir = processed_path.split('.pt')[0]

        # Read the raw reactions; the context manager closes the file handle.
        with open(self.raw_paths[self.file_idx], 'r') as raw_file:
            raw_lines = raw_file.readlines()
        dataset_chunks = self.split_dataset(raw_lines, int(self.num_processes))

        with Pool(self.num_processes) as pool:
            # Workers write their output to disk, so no payload is collected;
            # the handles are only kept to surface worker failures below.
            pending = [pool.apply_async(self.sub_process, args=(dataset_chunk,))
                       for dataset_chunk in dataset_chunks]
            pool.close() # Prevents any more tasks from being submitted to the pool
            pool.join() # Wait for the worker processes to exit

        # apply_async stores exceptions instead of raising them; get() re-raises
        # so a failed worker is not silently ignored.
        for result in pending:
            result.get()

        every_graph = []
        for _, chunk_offset in dataset_chunks:
            subprocess_path = os.path.join(chunk_dir, f'graphs_{chunk_offset}.pickle')
            # NOTE: pickle is only safe here because the files were just
            # written by our own workers.
            with open(subprocess_path, 'rb') as chunk_file:
                every_graph.extend(pickle.load(chunk_file))

        torch.save(self.collate(every_graph), processed_path)

    def turn_reactants_and_product_into_graph(self, reactants, products, data_idx):
        '''
            Build one torch_geometric Data object for a single reaction.

            Node layout of the returned graph: all reactant atoms, then the dummy
            nodes (type DUMMY_RCT_NODE_TYPE), then the nodes returned by
            mol.rxn_to_graph_supernode for each product (the code below treats the
            first of those nodes as the supernode).

            Input:
                reactants: list of reactant SMILES strings.
                products: list of product SMILES strings.
                data_idx: stored as `idx` on the returned Data object.
            Output:
                a Data object with x, edge_index, edge_attr, y, idx, the per-node
                masks (mask_sn, mask_reactant_and_sn, mask_product_and_sn,
                mask_atom_mapping), mol_assignment and cannot_generate;
                or None when stage == 'train' and the reactants have more atoms
                than the products plus max_nodes_more_than_product.
        '''
        offset = 0 
        cannot_generate = False
        # mask: (n), with n = nb of nodes
        # All masks are allocated with MAX_ATOMS_RXN entries and trimmed to the real node count below.
        mask_product_and_sn = torch.zeros(MAX_ATOMS_RXN, dtype=torch.bool) # only reactant nodes = True
        mask_reactant_and_sn = torch.zeros(MAX_ATOMS_RXN, dtype=torch.bool) # only product nodes = True
        mask_sn = torch.ones(MAX_ATOMS_RXN, dtype=torch.bool) # only sn = False
        mask_atom_mapping = torch.zeros(MAX_ATOMS_RXN, dtype=torch.long)
        mol_assignment = torch.zeros(MAX_ATOMS_RXN, dtype=torch.long)

        # preprocess: get total number of product nodes
        nb_product_nodes = sum([len(Chem.MolFromSmiles(p).GetAtoms()) for p in products])
        nb_rct_nodes = sum([len(Chem.MolFromSmiles(r).GetAtoms()) for r in reactants])
        
        # add dummy nodes: (nodes_in_product + max_added) - nodes_in_reactants
        nb_dummy_toadd = nb_product_nodes + self.max_nodes_more_than_product - nb_rct_nodes
        if nb_dummy_toadd<0 and self.stage=='train':
            # drop the rxns in the training set which we cannot generate
            return None
        if nb_dummy_toadd<0 and (self.stage=='test' or self.stage=='val'):
            # cut the rct nodes
            # (only the dummy count is reset here; the reaction is flagged via cannot_generate)
            nb_dummy_toadd = 0
            cannot_generate = True

        for j, r in enumerate(reactants):
            # NOTE: no supernodes for reactants (treated as one block)
            gi_nodes, gi_edge_index, gi_edge_attr, atom_map = mol.mol_to_graph(mol=r, atom_types=self.atom_types, 
                                                                            bond_types=self.bond_types,
                                                                            with_explicit_h=self.with_explicit_h,
                                                                            with_formal_charge=self.with_formal_charge,
                                                                            offset=offset, get_atom_mapping=True,
                                                                            canonicalize_molecule=self.canonicalize_molecule)
            # Accumulate the reactant block; the first reactant (j == 0) initialises the tensors.
            g_nodes_rct = torch.cat((g_nodes_rct, gi_nodes), dim=0) if j > 0 else gi_nodes # already a tensor
            g_edge_index_rct = torch.cat((g_edge_index_rct, gi_edge_index), dim=1) if j > 0 else gi_edge_index
            g_edge_attr_rct = torch.cat((g_edge_attr_rct, gi_edge_attr), dim=0) if j > 0 else gi_edge_attr

            # Copy the non-zero atom-map numbers to the positions of this molecule in the full graph.
            atom_mapped_idx = (atom_map!=0).nonzero()
            mask_atom_mapping[atom_mapped_idx+offset] = atom_map[atom_mapped_idx]
            # Reactant j gets molecule id j+1 (0 is used for dummy nodes and supernodes).
            mol_assignment[offset:offset+gi_nodes.shape[0]] = j+1
            offset += gi_nodes.shape[0] 


        g_nodes_dummy = torch.ones(nb_dummy_toadd, dtype=torch.long) * self.atom_types.index(DUMMY_RCT_NODE_TYPE)
        g_nodes_dummy = F.one_hot(g_nodes_dummy, num_classes=len(self.atom_types)).float()
        # edges: fully connected to every node in the rct side with edge type 'none'
        # NOTE(review): no explicit edges are created for the dummy nodes here (both tensors are empty).
        g_edges_idx_dummy = torch.zeros([2, 0], dtype=torch.long)
        g_edges_attr_dummy = torch.zeros([0, len(self.bond_types)], dtype=torch.long)
        mask_product_and_sn[:g_nodes_rct.shape[0]+g_nodes_dummy.shape[0]] = True
        mol_assignment[offset:offset+g_nodes_dummy.shape[0]] = 0
        # mask_atom_mapping[offset:offset+g_nodes_dummy.shape[0]] = MAX_ATOMS_RXN
        offset += g_nodes_dummy.shape[0]
        
        g_nodes = torch.cat([g_nodes_rct, g_nodes_dummy], dim=0)
        g_edge_index = torch.cat([g_edge_index_rct, g_edges_idx_dummy], dim=1)
        g_edge_attr = torch.cat([g_edge_attr_rct, g_edges_attr_dummy], dim=0)

        # Permute the rows here to make sure that the NN can only process topological information
        def permute_rows(nodes, mask_atom_mapping, mol_assignment, edge_index):
            # Permutes the graph specified by nodes, mask_atom_mapping, mol_assignment and edge_index
            # nodes: (n,d_x) node feature tensor
            # mask_atom_mapping (n,) tensor
            # mol_assignment: (n,) tensor
            # edge_index: (2,num_edges) tensor
            # does everything in-place
            rct_section_len = nodes.shape[0]
            perm = torch.randperm(rct_section_len)
            nodes[:] = nodes[perm]
            mask_atom_mapping[:rct_section_len] = mask_atom_mapping[:rct_section_len][perm]
            mol_assignment[:rct_section_len] = mol_assignment[:rct_section_len][perm]
            # inv_perm[old_position] = new_position, used to relabel the edge endpoints
            inv_perm = torch.zeros(rct_section_len, dtype=torch.long)
            inv_perm.scatter_(dim=0, index=perm, src=torch.arange(rct_section_len))
            edge_index[:] = inv_perm[edge_index]

        if self.permute_mols:
            permute_rows(g_nodes, mask_atom_mapping, mol_assignment, g_edge_index)

        supernodes_prods = []
        for j, p in enumerate(products):
            # NOTE: still need supernode for product to distinguish it from reactants
            gi_nodes, gi_edge_index, gi_edge_attr, atom_map = mol.rxn_to_graph_supernode(mol=p, atom_types=self.atom_types, bond_types=self.bond_types,
                                                                                        with_explicit_h=self.with_explicit_h, supernode_nb=offset+1,
                                                                                        with_formal_charge=self.with_formal_charge,
                                                                                        add_supernode_edges=self.add_supernode_edges, get_atom_mapping=True,
                                                                                        canonicalize_molecule=self.canonicalize_molecule)
            
            g_nodes_prod = torch.cat((g_nodes_prod, gi_nodes), dim=0) if j > 0 else gi_nodes # already a tensor
            g_edge_index_prod = torch.cat((g_edge_index_prod, gi_edge_index), dim=1) if j > 0 else gi_edge_index
            g_edge_attr_prod = torch.cat((g_edge_attr_prod, gi_edge_attr), dim=0) if j > 0 else gi_edge_attr
            atom_mapped_idx = (atom_map!=0).nonzero()
            mask_atom_mapping[atom_mapped_idx+offset] = atom_map[atom_mapped_idx]
            mask_reactant_and_sn[offset:gi_nodes.shape[0]+offset] = True
            mol_assignment[offset] = 0 # supernode does not belong to any molecule
            suno_idx = offset # there should only be one supernode and one loop through the products
            mol_assignment[offset+1:offset+1+gi_nodes.shape[0]] = len(reactants)+j+1 # TODO: Is there one too many assigned as a product atom here?
            mask_sn[offset] = False
            mask_reactant_and_sn[offset] = False
            # supernode is always in the first position
            si = 0 # gi_edge_index[0][0].item()
            supernodes_prods.append(si)
            offset += gi_nodes.shape[0]

        # Keep the supernode intact here, others are permuted
        # (reads suno_idx from the enclosing scope, i.e. the value set by the last product iteration)
        def permute_rows_product(g_nodes_prod, mask_atom_mapping, g_edge_index_prod):
            prod_indices = (suno_idx, suno_idx + g_nodes_prod.shape[0])
            # position 0 (the supernode) is fixed, positions 1.. are shuffled
            perm = torch.cat([torch.tensor([0], dtype=torch.long), 1 + torch.randperm(g_nodes_prod.shape[0]-1)], 0)
            inv_perm = torch.zeros(len(perm), dtype=torch.long)
            inv_perm.scatter_(dim=0, index=perm, src=torch.arange(len(perm)))
            g_nodes_prod[:] = g_nodes_prod[perm]
            
            # sn_and_prod_selection = (prod_selection | suno_idx == torch.arange(len(prod_selection)))
            mask_atom_mapping[prod_indices[0]:prod_indices[1]] = mask_atom_mapping[prod_indices[0]:prod_indices[1]][perm]
            
            # The following because g_edge_index_prod are counted with their offset in the final graph
            offset_padded_perm = torch.cat([torch.zeros(suno_idx, dtype=torch.long), suno_idx + perm]) # for debugging
            offset_padded_inv_perm = torch.cat([torch.zeros(suno_idx, dtype=torch.long), suno_idx + inv_perm])
            
            g_edge_index_prod[:] = offset_padded_inv_perm[g_edge_index_prod]

        if self.permute_mols:
            permute_rows_product(g_nodes_prod, mask_atom_mapping, g_edge_index_prod)

        # concatenate all types of nodes and edges
        g_nodes = torch.cat([g_nodes, g_nodes_prod], dim=0)
        g_edge_index = torch.cat([g_edge_index, g_edge_index_prod], dim=1)
        g_edge_attr = torch.cat([g_edge_attr, g_edge_attr_prod], dim=0)

        # empty graph-level target: shape (1, 0)
        y = torch.zeros((1, 0), dtype=torch.float)
        
        # trim masks => one element per node in the rxn graph
        mask_product_and_sn = mask_product_and_sn[:g_nodes.shape[0]] # only reactant nodes = True
        mask_reactant_and_sn = mask_reactant_and_sn[:g_nodes.shape[0]]
        mask_sn = mask_sn[:g_nodes.shape[0]]
        mask_atom_mapping = mask_atom_mapping[:g_nodes.shape[0]]
        mol_assignment = mol_assignment[:g_nodes.shape[0]]
        
        mask_atom_mapping = mol.sanity_check_and_fix_atom_mapping(mask_atom_mapping, g_nodes)
        
        assert mask_atom_mapping.shape[0]==g_nodes.shape[0] and mask_sn.shape[0]==g_nodes.shape[0] and \
            mask_reactant_and_sn.shape[0]==g_nodes.shape[0] and mask_product_and_sn.shape[0]==g_nodes.shape[0] and \
            mol_assignment.shape[0]==g_nodes.shape[0]

        # erase atom mapping absolute information for good. 
        # Labels 1..max are randomly relabelled among themselves; label 0 stays 0.
        perm = torch.arange(mask_atom_mapping.max().item()+1)[1:]
        perm = perm[torch.randperm(len(perm))]
        perm = torch.cat([torch.zeros(1, dtype=torch.long), perm])
        mask_atom_mapping = perm[mask_atom_mapping]

        graph = Data(x=g_nodes, edge_index=g_edge_index, 
                    edge_attr=g_edge_attr, y=y, idx=data_idx,
                    mask_sn=mask_sn, mask_reactant_and_sn=mask_reactant_and_sn, 
                    mask_product_and_sn=mask_product_and_sn, mask_atom_mapping=mask_atom_mapping,
                    mol_assignment=mol_assignment, cannot_generate=cannot_generate)

        return graph

    def sub_process(self, dataset_chunk):
        '''
            Convert one chunk of raw reaction lines to graphs and pickle them.

            Input:
                dataset_chunk: a (lines, start_index) pair as produced by
                    split_dataset; each line is split on '>>' into reactants
                    and products, and each side is split on '.' into molecules.
            Output:
                nothing is returned; the list of graphs is pickled to
                '<processed path without .pt>/graphs_<start_index>.pickle'.

            A reaction that raises is reported on stdout and skipped; reactions
            for which no graph is produced (None) are skipped as well. Any other
            failure of the chunk is logged through the logging module instead of
            being raised.
        '''
        assert DUMMY_RCT_NODE_TYPE in self.atom_types, 'DUMMY_RCT_NODE_TYPE not in atom_types.'
        try:
            chunk_lines, chunk_offset = dataset_chunk[0], dataset_chunk[1]
            graphs = []
            for i, rxn_ in enumerate(chunk_lines):
                try:
                    sides = rxn_.split('>>')
                    reactants = sides[0].split('.')
                    products = sides[1].split('.')

                    # data_idx is the position of the reaction in the whole raw file
                    rxn_graph = self.turn_reactants_and_product_into_graph(reactants, products, data_idx=i+chunk_offset)
                    if rxn_graph is not None:
                        graphs.append(rxn_graph)

                except Exception as e:
                    # best effort: one bad reaction must not abort the whole chunk
                    print(e)
                    print(f"Couldn't handle reaction {rxn_}")

            chunk_dir = self.processed_paths[self.file_idx].split('.pt')[0]
            os.makedirs(chunk_dir, exist_ok=True)
            subprocess_path = os.path.join(chunk_dir, f'graphs_{chunk_offset}.pickle')
            # the context manager guarantees the pickle is flushed and closed
            with open(subprocess_path, 'wb') as chunk_file:
                pickle.dump(graphs, chunk_file)
        except Exception as e:
            logger = logging.getLogger(__name__)
            logger.error(f'Error in sub_process: {e}', exc_info=True)
    
class DataModule(AbstractDataModule):
    def __init__(self, cfg):
        '''
            Copy the dataset options from `cfg.dataset` onto the module, then
            initialise the base class with the full config.

            When cfg.dataset.dataset_nb is not the empty string, '-<dataset_nb>'
            is appended to both datadir and datadist_dir.
        '''
        dataset_cfg = cfg.dataset

        self.num_processes = dataset_cfg.num_processes
        self.with_explicit_h = dataset_cfg.with_explicit_h
        self.with_formal_charge = dataset_cfg.with_formal_charge
        self.datadir = dataset_cfg.datadir
        self.datadist_dir = dataset_cfg.datadist_dir
        # the config calls this option nb_rct_dummy_nodes
        self.max_nodes_more_than_product = dataset_cfg.nb_rct_dummy_nodes
        self.canonicalize_molecule = dataset_cfg.canonicalize_molecule
        print(f'cfg.dataset.add_supernode_edges {dataset_cfg.add_supernode_edges}\n')
        self.add_supernode_edges = dataset_cfg.add_supernode_edges
        self.atom_types = dataset_cfg.atom_types
        # bond names from the config become RDKit bond types where applicable
        self.bond_types = add_chem_bond_types(dataset_cfg.bond_types)

        self.permute_mols = dataset_cfg.permute_mols
        print(f'self.atom_types {self.atom_types}\n')

        if dataset_cfg.dataset_nb != '':
            suffix = '-' + str(dataset_cfg.dataset_nb)
            self.datadir = self.datadir + suffix
            self.datadist_dir = self.datadist_dir + suffix

        super().__init__(cfg)
    
    def prepare_data(self, shuffle=True, slices=None) -> None:
        '''
            Build the train / val / test datasets and pass them to the base class.

            Input:
                shuffle: forwarded to AbstractDataModule.prepare_data.
                slices: optional dict mapping a split name to None (keep the
                    whole split) or to a (start, stop) pair used to slice that
                    split. Defaults to no slicing for all three splits.
        '''
        # A fresh dict per call instead of a shared mutable default argument.
        if slices is None:
            slices = {'train': None, 'val': None, 'test': None}

        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        root_path = os.path.join(base_path, self.datadir)

        # Options shared by the three splits, kept in one place so they cannot drift apart.
        shared_kwargs = dict(root=root_path, atom_types=self.atom_types, bond_types=self.bond_types,
                             with_explicit_h=self.with_explicit_h, with_formal_charge=self.with_formal_charge,
                             permute_mols=self.permute_mols, num_processes=self.num_processes,
                             max_nodes_more_than_product=self.max_nodes_more_than_product,
                             canonicalize_molecule=self.canonicalize_molecule,
                             add_supernode_edges=self.add_supernode_edges)
        datasets = {'train': Dataset(stage='train', **shared_kwargs),
                    'val': Dataset(stage='val', **shared_kwargs),
                    # only the test split receives the test-split size from the config
                    'test': Dataset(stage='test', size_test_splits=self.cfg.test.size_test_splits, **shared_kwargs)}

        for key, bounds in slices.items():
            if bounds is not None:
                datasets[key] = datasets[key][bounds[0]:bounds[1]]

        print(f'len test datasets {len(datasets["test"])}\n')
        print(f'len train datasets {len(datasets["train"])}\n')
        print(f'len val datasets {len(datasets["val"])}\n')

        super().prepare_data(datasets, shuffle=shuffle)

    def node_counts(self, max_nodes_possible=MAX_ATOMS_RXN):
        '''
            Distribution of the number of nodes per reaction graph, supernodes
            excluded, computed over the train, val and test dataloaders.

            Output:
                a tensor whose entry k is the fraction of graphs with k
                non-supernode nodes, truncated after the largest observed size.
        '''
        histogram = torch.zeros(max_nodes_possible)
        for split in ('train', 'val', 'test'):
            for batch in self.dataloaders[split]:
                # mask_sn is true everywhere but on supernodes
                graph_ids = batch.batch[batch.mask_sn]
                _, graph_sizes = torch.unique(graph_ids, return_counts=True)
                for size in graph_sizes:
                    histogram[size] += 1

        largest_size = max(histogram.nonzero())
        histogram = histogram[:largest_size + 1]

        return histogram/histogram.sum()

    def node_types(self):
        '''
            Normalized frequency of every node type over the training dataloader.

            The node features are summed column-wise over all batches and the
            totals are divided by their sum.
        '''
        train_loader = self.dataloaders['train']
        first_batch = next(iter(train_loader))
        totals = torch.zeros(first_batch.x.shape[1])  # one slot per feature column

        for batch in train_loader:
            totals += batch.x.sum(dim=0)

        return totals / totals.sum()

    def edge_types(self):
        '''
            Normalized frequency of every edge type over the training dataloader.

            Index 0 counts the node pairs without an edge: the number of ordered
            pairs per graph (n * (n - 1), from data.batch) minus the edges that
            remain after restricting the graph to the nodes where mask_sn is
            true. Indices 1.. accumulate the edge-attribute columns of those
            remaining edges.
        '''
        first_batch = next(iter(self.dataloaders['train']))
        d = torch.zeros(first_batch.edge_attr.shape[1])

        for batch in self.dataloaders['train']:
            _, graph_sizes = torch.unique(batch.batch, return_counts=True)
            all_pairs = sum(size * (size - 1) for size in graph_sizes)

            kept_nodes = (batch.mask_sn==True).nonzero(as_tuple=True)[0]
            kept_edge_index, kept_edge_attr = subgraph(kept_nodes, batch.edge_index, batch.edge_attr)

            num_non_edges = all_pairs - kept_edge_index.shape[1]
            per_type_counts = kept_edge_attr.sum(dim=0)

            assert num_non_edges >= 0
            d[0] += num_non_edges
            d[1:] += per_type_counts[1:]

        return d/d.sum()

    def node_types_unnormalized(self):
        #TODO: Can this be abstracted to the AbstractDataModule class?
        '''
            Count how often each atom type occurs in the training dataloader.

            Output:
                counts: integer tensor with one count per atom type; the entry of
                the 'SuNo' type is forced to 0.
        '''
        train_loader = self.dataloaders['train']
        nb_atom_types = next(iter(train_loader)).x.shape[1] # width of the node encoding
        counts = torch.zeros(nb_atom_types)

        for batch in train_loader:
            counts += batch.x.sum(dim=0)

        # the supernode type is not reported
        counts[self.atom_types.index('SuNo')] = 0.

        return counts.long()
    
    def edge_types_unnormalized(self):
        #TODO: Can this be abstracted to the AbstractDataModule class?
        '''
            Count how often each edge type occurs in the training dataloader.

            Index 0 counts the node pairs without an edge (ordered pairs per
            graph minus the number of edges); indices 1.. accumulate the
            edge-attribute columns. The entries of the 'mol', 'within' and
            'across' bond types are forced to 0. Returns an integer tensor.
        '''
        train_loader = self.dataloaders['train']
        d = torch.zeros(next(iter(train_loader)).edge_attr.shape[1])

        for batch in train_loader:
            _, graph_sizes = torch.unique(batch.batch, return_counts=True)
            # ordered pairs of distinct nodes (no self-pairs)
            all_pairs = sum(size * (size - 1) for size in graph_sizes)

            num_non_edges = all_pairs - batch.edge_index.shape[1]
            per_type_counts = batch.edge_attr.sum(dim=0)

            assert num_non_edges >= 0
            d[0] += num_non_edges
            d[1:] += per_type_counts[1:]

        # the supernode-related edge types are not reported
        for suno_bond in ['mol', 'within', 'across']:
            d[self.bond_types.index(suno_bond)] = 0.

        return d.long()
    
class DatasetInfos:
    def __init__(self, datamodule, atom_types, bond_types, bond_orders=False, allowed_bonds=False, zero_bond_order=False, recompute_info=False, remove_h=True):
        '''
            Collect the vocabulary and the dataset statistics used by the model.

            Input:
                datamodule: provides datadist_dir and the statistics methods
                    (node_counts, node_types_unnormalized, edge_types_unnormalized).
                atom_types: list of atom-type labels; the first and last entries
                    are given atomic weight 0, the others are looked up in the
                    RDKit periodic table after dropping a '+' / '-' charge suffix.
                bond_types: list of bond-type names.
                bond_orders, allowed_bonds: not used in this constructor.
                zero_bond_order: select get_bond_orders instead of
                    get_bond_orders_correct.
                recompute_info: recompute the statistics even if cached files exist.
                remove_h: stored as-is.

            zero_bond_order is a temporary fix to a bug in the creation of bond_orders that affects extra_features later.
            The fix is to accommodate models trained with the bug.

            The statistics are cached as text files under
            '<repo root>/<datamodule.datadist_dir>/processed'; they are loaded
            when all five files exist and recompute_info is False, otherwise
            they are recomputed from the datamodule and written there.
        '''
        self.datamodule = datamodule
        self.name = 'supernode_graphs'
        self.atom_decoder = atom_types
        self.bond_decoder = add_chem_bond_types(bond_types)
        self.remove_h = remove_h
        # first entry of the valence list of each atom type
        self.valencies = [self.get_possible_valences(atom_type)[0] for atom_type in atom_types]
        periodic_table = Chem.rdchem.GetPeriodicTable()
        atom_weights = [0] + [periodic_table.GetAtomicWeight(re.split(r'\+|\-', atom_type)[0]) for atom_type in atom_types[1:-1]] + [0] # discard charge
        self.atom_weights = dict(zip(atom_types, atom_weights))
        self.max_weight = 390
        print(f'zero_bond_order {zero_bond_order}\n')
        if zero_bond_order:
            self.bond_orders = get_bond_orders(bond_types)
        else:
            self.bond_orders = get_bond_orders_correct(bond_types)
        print(f'self.bond_orders {self.bond_orders}\n')

        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        root_path = os.path.join(base_path, datamodule.datadist_dir, 'processed')
        node_count_path = os.path.join(root_path, 'n_counts.txt')
        atom_type_path = os.path.join(root_path, 'atom_types.txt')
        edge_type_path = os.path.join(root_path, 'edge_types.txt')
        atom_type_unnorm_path = os.path.join(root_path, 'atom_types_unnorm_mol.txt')
        edge_type_unnorm_path = os.path.join(root_path, 'edge_types_unnorm_mol.txt')
        paths_exist = all(os.path.exists(path) for path in (node_count_path, atom_type_path, edge_type_path,
                                                            atom_type_unnorm_path, edge_type_unnorm_path))

        if not recompute_info and paths_exist:
            # use the same distributions for all subsets of the dataset
            self.n_nodes = torch.from_numpy(np.loadtxt(node_count_path)).float()
            self.node_types = torch.from_numpy(np.loadtxt(atom_type_path)).float()
            self.edge_types = torch.from_numpy(np.loadtxt(edge_type_path)).float()
            self.node_types_unnormalized = torch.from_numpy(np.loadtxt(atom_type_unnorm_path)).long()
            self.edge_types_unnormalized = torch.from_numpy(np.loadtxt(edge_type_unnorm_path)).long()
        else:
            print('Recomputing\n')
            np.set_printoptions(suppress=True, precision=5)
            # np.savetxt does not create missing directories, so make sure the
            # cache directory exists before the first statistic is written.
            os.makedirs(root_path, exist_ok=True)

            self.n_nodes = datamodule.node_counts()
            print("Distribution of number of nodes", self.n_nodes)
            np.savetxt(node_count_path, self.n_nodes.cpu().numpy())

            self.node_types_unnormalized = datamodule.node_types_unnormalized()
            print("Counts of node types", self.node_types_unnormalized)
            np.savetxt(atom_type_unnorm_path, self.node_types_unnormalized.cpu().numpy())

            self.edge_types_unnormalized = datamodule.edge_types_unnormalized()
            print("Counts of edge types", self.edge_types_unnormalized)
            np.savetxt(edge_type_unnorm_path, self.edge_types_unnormalized.cpu().numpy())

            # the normalized distributions are derived from the raw counts
            self.node_types = self.node_types_unnormalized / self.node_types_unnormalized.sum()
            print("Distribution of node types", self.node_types)
            np.savetxt(atom_type_path, self.node_types.cpu().numpy())

            self.edge_types = self.edge_types_unnormalized / self.edge_types_unnormalized.sum()
            print("Distribution of edge types", self.edge_types)
            np.savetxt(edge_type_path, self.edge_types.cpu().numpy())

        self.complete_infos(n_nodes=self.n_nodes, node_types=self.node_types)
        
    def get_possible_valences(self, atom_type):
        '''
            Return the valence list of `atom_type` from the RDKit periodic table.

            Input:
                atom_type: atom-type label passed to PeriodicTable.GetValenceList.
            Output:
                the valences as a Python list, or [0] when the lookup fails
                (e.g. for labels that are not chemical elements).
        '''
        pt = Chem.GetPeriodicTable()
        try:
            valence_list = pt.GetValenceList(atom_type)
        except Exception:
            # Failed lookups fall back to valence 0. Catching Exception rather
            # than using a bare except lets KeyboardInterrupt / SystemExit through.
            valence_list = [0]

        return list(valence_list)

    def complete_infos(self, n_nodes, node_types):
        '''
            Derive the remaining attributes from the dataset statistics.

            input_dims / output_dims are reset to None, num_classes is the
            number of node types, max_n_nodes is len(n_nodes) - 1 and nodes_dist
            wraps n_nodes in a DistributionNodes.
        '''
        self.input_dims = self.output_dims = None
        nb_node_types = len(node_types)
        self.num_classes = nb_node_types
        largest_graph_size = len(n_nodes) - 1
        self.max_n_nodes = largest_graph_size
        self.nodes_dist = DistributionNodes(n_nodes)

    def compute_input_output_dims(self, dx=None, de=None, dy=None, datamodule=None):
        assert datamodule is not None or dx is not None, f'Got datamodule={datamodule} and dx={dx}. One of the two should be specified.\n'
        
        if dx is not None and de is not None and dy is not None:
            self.input_dims = {'X': dx, 'E': de, 'y': dy+1}  # + 1 due to time conditioning

            self.output_dims = {'X': dx, # output dim = # of features
                                'E': de,
                                'y': 0}
 
        else:
            example_batch = next(iter(datamodule.train_dataloader()))

            self.input_dims = {'X': example_batch['x'].size(1), # n or dx?
                            'E': example_batch['edge_attr'].size(1),
                            'y': example_batch['y'].size(1) + 1}  # + 1 due to time conditioning

            self.output_dims = {'X': example_batch['x'].size(1), # output dim = # of features
                                'E': example_batch['edge_attr'].size(1),
                                'y': 0}


            
