"""
This file contains code that reads a dataset stored as SMILES strings, generates synthetic 3D coordinates for it,
and stores the resulting molecules as .xyz files.
"""
import os
import argparse
import logging
logging.getLogger().setLevel(logging.INFO)

from tqdm import tqdm
import rdkit.Chem.AllChem as Chem

from synthetic_coordinates.rdkit_helpers import smiles_to_mol, write_xyz_with_formal_charges
from synthetic_coordinates.conformer_generation import set_3D_coords_rdkit


def smiles_to_xyz(data_folder, split, synthetic_coords_method, filter=None, only_explicit_H=False):
    """
    Reads a .txt file of SMILES strings, generates 3D coordinates for them and saves them as .xyz files.
    Might skip some molecules if the 3D generation method fails.

    Parameters:
        data_folder (str): path to the folder where data is stored. Should contain a folder named `smiles` with
                           text files `<split>.txt` (one SMILES string per line) for the train/valid/test splits
        split (str): train | valid | test | all. If None then assumed to be all
        synthetic_coords_method (str): rdkit | other methods TBD
        filter (str): name of a function smiles -> bool such that, if it returns True, the corresponding smiles is
                      skipped and not included in the saved dataset. Can be 'no_charges' | TBD
        only_explicit_H (bool): forwarded to `smiles_to_mol`; when True, implicit hydrogens are removed again
                                (`Chem.RemoveHs(..., implicitOnly=True)`) after coordinate generation

    Returns:
        str: name of the output folder (relative to `data_folder`), i.e. `synthetic_coords_<method>[_<filter>]`.
             The files are written to `<data_folder>/<name>/<split>/mol_<line index>.xyz`.
    """
    if split is None:
        split = 'all'

    # Resolve the helper functions up front so an unknown method/filter fails before any file I/O.
    filter_fn = get_filter_fn(filter) if filter is not None else None
    synthetic_coords_fn = get_synthetic_coords_fn(synthetic_coords_method)

    smiles_file_path = os.path.join(data_folder, 'smiles', split + '.txt')
    # Text mode uses universal newlines, so both '\n' and '\r\n' line endings are handled.
    with open(smiles_file_path, encoding="utf-8") as file:
        all_smiles = [line.strip() for line in file]

    output_folder_name = 'synthetic_coords_' + synthetic_coords_method
    if filter is not None:
        output_folder_name += '_' + filter
    output_dir = os.path.join(data_folder, output_folder_name, split)
    os.makedirs(output_dir, exist_ok=True)

    n_total = len(all_smiles)
    n_bad_conformer_id = 0
    n_unknown_error = 0
    n_filtered_out = 0
    n_empty = 0

    # mol_idx is the line index in the SMILES file, so output file names stay aligned with the input
    # lines even when some molecules are skipped.
    for mol_idx, smiles in enumerate(tqdm(all_smiles)):
        if not smiles:
            # Blank lines carry no molecule; don't feed them to the SMILES parser.
            n_empty += 1
            continue
        if filter_fn is not None and filter_fn(smiles):
            n_filtered_out += 1
            continue
        try:
            # TODO: decide if kekulize should be True or False
            mol = smiles_to_mol(smiles, kekulize=False, only_explicit_H=only_explicit_H)
            # this will add all Hs
            mol = synthetic_coords_fn(mol)

            if only_explicit_H:
                mol = Chem.RemoveHs(mol, implicitOnly=True)
            write_xyz_with_formal_charges(mol, smiles, os.path.join(output_dir, f'mol_{mol_idx}.xyz'))
        except ValueError as e:
            if "Bad Conformer Id" in str(e):
                n_bad_conformer_id += 1
                logging.warning("Encountered Bad Conformer Id Error")
            else:
                n_unknown_error += 1
                logging.warning(f"Encountered Unknown Error: {str(e)}")
        except Exception as e:
            n_unknown_error += 1
            logging.warning(f"Encountered Unknown Error: {str(e)}")

    logging.info('Done with generating .xyz files from SMILES!')
    logging.info(f'{n_bad_conformer_id} bad conformer IDs out of {n_total} -> {_percent(n_bad_conformer_id, n_total)}%')
    logging.info(f'{n_unknown_error} other unknown errors out of {n_total} -> {_percent(n_unknown_error, n_total)}%')
    logging.info(f'{n_filtered_out} skipped molecules for filter {filter} out of {n_total} -> {_percent(n_filtered_out, n_total)}%')
    if n_empty:
        logging.info(f'{n_empty} empty lines skipped out of {n_total}')

    return output_folder_name


def _percent(count, total):
    """Return `count` as a percentage of `total`, or 0.0 when `total` is 0 (e.g. an empty SMILES file)."""
    return count / total * 100 if total else 0.0


def get_synthetic_coords_fn(synthetic_coords_method):
    """
    Returns the function used to generate synthetic 3D coordinates for a molecule.

    In `smiles_to_xyz` the returned function is called as `mol = fn(mol)` on the output of `smiles_to_mol`.

    Parameters:
        synthetic_coords_method (str): name of the method; currently only 'rdkit'

    Raises:
        NotImplementedError: if `synthetic_coords_method` is not a known method
    """
    mapping = {
        'rdkit': set_3D_coords_rdkit,
    }
    if synthetic_coords_method not in mapping:
        # Note the trailing space: the two literals are concatenated implicitly.
        raise NotImplementedError(f"The requested method for synthetic coordinates generation {synthetic_coords_method} "
                                  f"is not implemented. Available methods: {sorted(mapping)}")
    return mapping[synthetic_coords_method]

def get_filter_fn(filter):
    """
    Returns a predicate `smiles -> bool`; molecules for which it returns True are skipped by `smiles_to_xyz`.

    Parameters:
        filter (str): name of the filter. Supported: 'no_charges' (True for SMILES containing a charged atom)

    Raises:
        NotImplementedError: if `filter` is not a known filter name
    """
    if filter == 'no_charges':
        import re

        # In SMILES, formal charges can only be written inside bracket atoms (e.g. [NH4+], [O-], [Fe+2]).
        # Outside brackets '-' denotes an explicit single bond (e.g. 'c1ccccc1-c1ccccc1'), so a plain
        # substring test on the whole string would wrongly flag uncharged molecules.
        # Caveat: an explicit zero charge such as [N+0] is still treated as charged.
        charged_bracket_atom = re.compile(r"\[[^\]]*[+-][^\]]*\]")

        def has_charges(smiles):
            return charged_bracket_atom.search(smiles) is not None

        return has_charges

    raise NotImplementedError(f"The requested filter function {filter} is not implemented")


# To run this, go to the main folder molecule_generation and run
# PYTHONPATH="${PYTHONPATH}:." python src/datasets/mol_dataset.py --split train
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate synthetic 3D coordinates for a SMILES dataset and save them as .xyz files.")
    parser.add_argument("--data-folder", default="data/qm9_smiles/",
                        help="folder containing a `smiles/<split>.txt` file")
    parser.add_argument("--split", default="train", help="train | valid | test | all")
    parser.add_argument("--sc-method", default="rdkit", help="synthetic coordinates generation method: rdkit")
    parser.add_argument("--filter", type=str, default=None,
                        help="optional filter; molecules matching it are skipped (e.g. no_charges)")
    # Exposes the `only_explicit_H` option of smiles_to_xyz; argparse stores it as `args.only_explicit_H`.
    parser.add_argument("--only-explicit-H", action="store_true",
                        help="remove implicit hydrogens again after coordinate generation")
    args = parser.parse_args()

    smiles_to_xyz(args.data_folder, args.split, args.sc_method, filter=args.filter,
                  only_explicit_H=args.only_explicit_H)
