import os
from os.path import join
import pickle
import argparse

from tqdm import tqdm
from rdkit import Chem
import sklearn

from synthetic_coordinates.rdkit_helpers import smiles_to_mol

def get_molecule_stats(molecule, attributes, atom_encoder):
    """Build the per-molecule histograms for the four tracked attributes.

    Args:
        molecule: an RDKit molecule (presumably ``Chem.Mol``; only
            ``GetNumAtoms``/``GetAtoms`` and ``GetAdjacencyMatrix`` are used).
        attributes: exactly four attribute names, which must include
            'n_nodes', 'atom_types', 'formal_charges' and 'edge_types'.
        atom_encoder: mapping from element symbol to integer atom-type index.

    Returns:
        dict mapping each attribute to a ``{value: count}`` histogram:
          * 'n_nodes'        -> ``{number_of_atoms: 1}``
          * 'atom_types'     -> encoded atom type -> number of atoms
          * 'formal_charges' -> formal charge -> number of atoms
          * 'edge_types'     -> edge type 0..4 -> number of unordered atom
            pairs, where 0 means "no bond", 1/2/3 are bond orders and 4 is an
            aromatic bond (bond order 1.5).

    Raises:
        AssertionError: if ``attributes`` is malformed or a pair count is odd.
        KeyError: if an atom's symbol is missing from ``atom_encoder``.
    """
    assert len(attributes) == 4
    mol_stats = {name: {} for name in attributes}

    num_atoms = molecule.GetNumAtoms()
    assert 'n_nodes' in attributes
    mol_stats['n_nodes'][num_atoms] = 1

    atom_iter = molecule.GetAtoms()
    assert 'atom_types' in attributes
    assert 'formal_charges' in attributes
    type_hist = mol_stats['atom_types']
    charge_hist = mol_stats['formal_charges']
    for atom in atom_iter:
        encoded_type = atom_encoder[atom.GetSymbol()]
        type_hist[encoded_type] = type_hist.get(encoded_type, 0) + 1
        charge = atom.GetFormalCharge()
        charge_hist[charge] = charge_hist.get(charge, 0) + 1

    assert 'edge_types' in attributes
    edge_type_count = 5
    bond_orders = Chem.rdmolops.GetAdjacencyMatrix(molecule, useBO=True)
    # Aromatic bonds come back with order 1.5; fold them into integer type 4.
    bond_orders[bond_orders == 1.5] = 4
    edge_hist = mol_stats['edge_types']
    for edge_type in range(edge_type_count):
        entries = (bond_orders == edge_type).sum()
        if edge_type == 0:
            # The matrix diagonal (atom paired with itself) is assumed to be
            # zero; those num_atoms entries are not "no bond" pairs.
            entries = entries - num_atoms
        # The matrix is symmetric, so each unordered pair appears twice.
        assert entries % 2 == 0
        edge_hist[edge_type] = (entries // 2).item()

    return mol_stats

def extend_stats(dataset_info, mol_stats, attributes):
    """Add one molecule's histograms into the running dataset histograms.

    ``dataset_info`` is mutated in place and also returned for convenience.

    Args:
        dataset_info: dict holding, for every name in ``attributes``, a
            ``{value: count}`` histogram accumulated so far.
        mol_stats: dict with the same per-attribute histograms for one
            molecule (as produced by ``get_molecule_stats``).
        attributes: names of the histograms to merge.

    Raises:
        AssertionError: if an attribute is missing from either dict.
    """
    for name in attributes:
        assert name in dataset_info
        assert name in mol_stats
        totals = dataset_info[name]
        for value, n in mol_stats[name].items():
            totals[value] = totals[value] + n if value in totals else n
    return dataset_info

def compute_class_weights(stats_dict):
    """Return "balanced" class weights for a ``{class_label: count}`` histogram.

    For each class ``c`` the weight is ``n_samples / (n_classes * count_c)``.
    This is the formula used by scikit-learn's
    ``compute_class_weight(class_weight='balanced', ...)`` when it is applied to
    the label list in which every class is repeated ``count_c`` times.

    The weights are computed directly from the counts rather than by
    materialising that expanded label list. That list can be enormous: for
    example, the edge-type histogram built in this module has an entry for
    every non-bonded atom pair of every molecule. Computing from counts also
    avoids relying on ``import sklearn`` having loaded the
    ``sklearn.utils.class_weight`` submodule.

    Args:
        stats_dict: mapping class label -> non-negative occurrence count.

    Returns:
        list of float weights, one per key, in the iteration order of
        ``stats_dict``.

    Raises:
        ValueError: if some class has a count <= 0. scikit-learn likewise
            rejects classes that never occur in ``y``.
    """
    counts = list(stats_dict.values())
    missing = [label for label, count in stats_dict.items() if count <= 0]
    if missing:
        raise ValueError(
            f"classes should have valid labels that are in y; "
            f"classes with no occurrences: {missing}"
        )
    n_samples = sum(counts)
    n_classes = len(counts)
    return [n_samples / (n_classes * count) for count in counts]

def get_all_atom_types(smiles_list, only_explicit_H):
    """Collect every element symbol that occurs in a list of SMILES lines.

    Args:
        smiles_list: raw lines (bytes, UTF-8) each holding one SMILES string;
            a trailing newline is stripped before parsing.
        only_explicit_H: forwarded to ``smiles_to_mol``. When False, hydrogens
            are additionally added to every molecule with ``Chem.AddHs``.

    Returns:
        dict mapping element symbol -> atomic number, ordered by increasing
        atomic number.
    """
    atomic_numbers = {}
    for raw_line in tqdm(smiles_list):
        smiles = raw_line.decode("utf-8").strip("\n")
        mol = smiles_to_mol(smiles, kekulize=False, only_explicit_H=only_explicit_H)
        if not only_explicit_H:
            mol = Chem.AddHs(mol)

        # The first occurrence of a symbol fixes its atomic number.
        for atom in mol.GetAtoms():
            atomic_numbers.setdefault(atom.GetSymbol(), atom.GetAtomicNum())

    ordered = sorted(atomic_numbers.items(), key=lambda pair: pair[1])
    return {symbol: number for symbol, number in ordered}

def compute_dataset_info(datadir, dataset_name, only_explicit_H=True):
    """Compute statistics over the training SMILES and pickle them.

    Reads ``<datadir>/<dataset_name>/smiles/train.txt`` (one SMILES per line)
    and writes ``<datadir>/<dataset_name>/dataset_info.p``. The pickled dict
    contains:

      * 'name', 'n_molecules'
      * 'atom_decoder' (symbols ordered by atomic number) and 'atom_encoder'
        (symbol -> index into the decoder)
      * key-sorted histograms for 'n_nodes', 'atom_types', 'formal_charges'
        and 'edge_types'
      * 'max_n_nodes'
      * 'class_weights_<attribute>' for every histogram except 'n_nodes'

    Args:
        datadir: root data directory.
        dataset_name: sub-directory of ``datadir`` holding the dataset.
        only_explicit_H: forwarded to the SMILES parser. When False, all
            hydrogens are added with ``Chem.AddHs``.
    """
    dataset_dir = join(datadir, dataset_name)
    split = 'train'
    with open(join(dataset_dir, 'smiles', split + '.txt'), "rb") as handle:
        smiles_lines = handle.readlines()

    # Pass 1: discover the element vocabulary so atoms can be integer-encoded.
    symbol_to_atomic_number = get_all_atom_types(smiles_lines, only_explicit_H)
    atom_decoder = list(symbol_to_atomic_number)
    attributes = ['n_nodes', 'atom_types', 'formal_charges', 'edge_types']

    dataset_info = {
        'name': dataset_name,
        'n_molecules': len(smiles_lines),
        'atom_decoder': atom_decoder,
        'atom_encoder': {symbol: index for index, symbol in enumerate(atom_decoder)},
    }
    dataset_info.update((attribute, {}) for attribute in attributes)

    # Pass 2: accumulate per-molecule histograms into the dataset totals.
    for raw_line in tqdm(smiles_lines):
        smiles = raw_line.decode("utf-8").strip("\n")
        mol = smiles_to_mol(smiles, kekulize=False, only_explicit_H=only_explicit_H)
        if not only_explicit_H:
            mol = Chem.AddHs(mol)
        per_mol = get_molecule_stats(mol, attributes, dataset_info['atom_encoder'])
        dataset_info = extend_stats(dataset_info, per_mol, attributes)

    for attribute in attributes:
        dataset_info[attribute] = dict(sorted(dataset_info[attribute].items()))

    # The n_nodes histogram is now sorted ascending, so its last key is the max.
    dataset_info['max_n_nodes'] = list(dataset_info['n_nodes'])[-1]

    weighted_attributes = [a for a in attributes if a != 'n_nodes']
    for attribute in weighted_attributes:
        dataset_info[f'class_weights_{attribute}'] = compute_class_weights(dataset_info[attribute])

    with open(join(dataset_dir, 'dataset_info.p'), 'wb') as out_file:
        pickle.dump(dataset_info, out_file)

# To run this, go to the main folder e3_diffusion and run
# PYTHONPATH="${PYTHONPATH}:." python qm9/data/prepare/prepare_dataset_info.py --datadir data/ --dataname zinc250k_explicitH_new --only_explicit_H
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--datadir", default="data/")
    parser.add_argument("--dataname", default="zinc250k")
    # `store_true` combined with `default=True` makes --only_explicit_H a no-op:
    # the value is True whether or not the flag is passed. The flag is kept for
    # backward compatibility, and an explicit opt-out writes False to the same
    # destination so the other code path is reachable from the CLI.
    parser.add_argument('--only_explicit_H', action='store_true', default=True,
                        help='Only Store explicit H atoms (that do not come from stereochemistry)')
    parser.add_argument('--no_only_explicit_H', dest='only_explicit_H', action='store_false',
                        help='Disable --only_explicit_H and add all hydrogens with Chem.AddHs')
    args = parser.parse_args()

    compute_dataset_info(args.datadir, args.dataname, only_explicit_H=args.only_explicit_H)
