import logging
logging.getLogger().setLevel(logging.INFO)
import os
import argparse

import torch
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import rdkit.Chem as Chem

from optimization.props.properties import penalized_logp, drd2, qed, tpsa, get_morgan_fingerprint
from qm9.data.prepare.compute_atomic_features import get_atom_features_from_smiles
from synthetic_coordinates.rdkit_helpers import smiles_to_mol

# Lookup table from element symbol (as written in the .xyz files) to atomic number,
# used by process_xyz. Only the last (largest) table is active; because each table
# below is a superset of the ones above it, it also covers QM9 and ZINC250k molecules.

# only QM9
#symbol_to_atomic_number = {'H': 1, 'C': 6, 'N': 7, 'O': 8, 'F': 9}

# ZINC250k but includes QM9
#symbol_to_atomic_number = {'H': 1, 'C': 6, 'N': 7, 'O': 8, 'F': 9, 'P': 15, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53}

# GuacaMol but includes ZINC250k
symbol_to_atomic_number = {'H': 1, 'B': 5, 'C': 6, 'N': 7, 'O': 8, 'F': 9, 'Si': 14, 'P': 15, 'S': 16, 'Cl': 17, 'Se': 34, 'Br': 35, 'I': 53}

def element_to_atomic_num(atomic_element):
    """Return the atomic number of an element symbol (e.g. 'C' -> 6), as reported by RDKit."""
    atom = Chem.Atom(atomic_element)
    return atom.GetAtomicNum()

def process_dataset(datadir, xyz_folder_name, split, subset=None, only_explicit_H=False, prop=None):
    """
    Convert one split of .xyz files into a single compressed .npz archive.

    Reads the .xyz files in ``<datadir>/<xyz_folder_name>/<split>`` together with the
    SMILES list ``<datadir>/smiles/<split>.txt`` and writes the stacked tensors to
    ``<datadir>/<split>.npz``. When ``subset`` is given, only the first ``subset`` files
    (by position in the directory listing) are processed and the output is written to
    ``<datadir>/<split>_<subset>_subset.npz`` instead.

    Parameters
    ----------
    datadir : str
        Root data directory.
    xyz_folder_name : str
        Sub-directory of ``datadir`` holding one folder of .xyz files per split.
    split : str
        Split name, e.g. 'train'.
    subset : int, optional
        Number of files to process; all files when None.
    only_explicit_H : bool, optional
        Forwarded to the per-file processing.
    prop : optional
        Names of molecular properties to compute, forwarded to the per-file processing.
    """
    split_dir = os.path.join(datadir, xyz_folder_name, split)
    smiles_file = os.path.join(datadir, f'smiles/{split}.txt')

    if subset is None:
        file_idx_list = None
        savedir = os.path.join(datadir, split+'.npz')
    else:
        # process only the first `subset` molecules
        file_idx_list = range(subset)
        savedir = os.path.join(datadir, f'{split}_{subset}_subset.npz')

    processed_split = process_xyz_files(
        split_dir,
        process_xyz,
        smiles_file,
        file_idx_list=file_idx_list,
        stack=True,
        only_explicit_H=only_explicit_H,
        prop=prop,
    )

    logging.info('Saving processed data:')
    np.savez_compressed(savedir, **processed_split)
    logging.info('Processing/saving complete!')


def process_xyz_files(data, process_file_fn, smiles_file, file_ext=None, file_idx_list=None, stack=True, only_explicit_H=False, prop=None):
    """
    Apply ``process_file_fn`` to every .xyz file in a directory and collate the results.

    Parameters
    ----------
    data : str
        Path to a directory containing one .xyz file per molecule. File names must look
        like ``mol_<idx>.xyz``, where ``<idx>`` is the 0-based line number of the
        molecule's SMILES string in ``smiles_file``.
    process_file_fn : callable
        Called as ``process_file_fn(openfile, smiles, max_n_atoms, only_explicit_H=...,
        prop=...)``; must return a dict of ``torch.Tensor`` values, with the same keys
        for every molecule.
    smiles_file : str
        Text file with one SMILES string per line.
    file_ext : str, optional
        If given, only files whose name ends with this string are processed.
    file_idx_list : iterable of int, optional
        If given, only files whose position in the (extension-filtered) directory
        listing is in this collection are processed, e.g. ``range(n)`` for the first n.
    stack : bool, optional
        If True, tensors with at least one dimension are zero-padded along their first
        dimension and batched with ``pad_sequence``, and 0-d tensors are combined with
        ``torch.stack``; every key then maps to one tensor. If False, every key maps to
        a list of per-molecule tensors.
    only_explicit_H : bool, optional
        Forwarded to ``process_file_fn``.
    prop : optional
        Forwarded to ``process_file_fn``.

    Returns
    -------
    dict
        Property name -> stacked tensor (``stack=True``) or list of tensors.

    Raises
    ------
    ValueError
        If ``data`` is not a directory, or no files are left after filtering.
    """
    logging.info('Processing data file: {}'.format(data))
    if not os.path.isdir(data):
        raise ValueError('Can only read from directory!')

    files = [os.path.join(data, file) for file in os.listdir(data)]

    def readfile(data_pt):
        return open(data_pt, 'r')

    # Use only files that end with specified extension.
    if file_ext is not None:
        files = [file for file in files if file.endswith(file_ext)]

    # Use only files whose position is in the requested index collection
    # (converted to a set once so each membership test is O(1)).
    if file_idx_list is not None:
        keep = set(file_idx_list)
        files = [file for idx, file in enumerate(files) if idx in keep]

    if not files:
        raise ValueError(f'No files to process in {data}')

    all_smiles = read_smiles_file(smiles_file)
    max_n_atoms = get_max_n_atoms(files, readfile)
    print('max number of atoms across dataset:', max_n_atoms)

    molecules = []
    for file in tqdm(files):
        # Parse the molecule index from this file's own name ('mol_<idx>.xyz') so the
        # SMILES lookup stays aligned with the file even after filtering above.
        stem = os.path.splitext(os.path.basename(file))[0]
        mol_idx = int(stem.rsplit('_', 1)[-1])
        with readfile(file) as openfile:
            molecules.append(process_file_fn(openfile, all_smiles[mol_idx], max_n_atoms, only_explicit_H=only_explicit_H, prop=prop))

    # Check that all molecules have the same set of items in their dictionary:
    props = molecules[0].keys()
    assert all(props == mol.keys() for mol in molecules), 'All molecules must have same set of properties/keys!'

    # Convert list-of-dicts to dict-of-lists
    molecules = {prop: [mol[prop] for mol in molecules] for prop in props}

    # If stacking is desirable, pad and then stack.
    if stack:
        molecules = {key: pad_sequence(val, batch_first=True) if val[0].dim() > 0 else torch.stack(val) for key, val in molecules.items()}

    return molecules


def process_xyz(datafile, smiles, max_n_atoms, only_explicit_H=False, prop=None):
    """
    Parse one .xyz file into a dict of per-molecule tensors.

    The file layout parsed below is:
        line 0            : number of atoms N
        line 1            : SMILES string of the molecule (must equal ``smiles``)
        lines 2 .. N+1    : ``<symbol> <x> <y> <z> <formal_charge>``

    Parameters
    ----------
    datafile : file object
        Open text-mode handle to the .xyz file.
    smiles : str
        SMILES string expected for this molecule; used to build the bond graph, the
        atomic features and any requested molecular properties.
    max_n_atoms : int
        Largest atom count across the dataset. The dense adjacency matrix is padded
        to this size, and is only stored when ``max_n_atoms < 50``.
    only_explicit_H : bool, optional
        Forwarded to ``get_adj_matrix_from_smiles``.
    prop : container of str, optional
        Names of molecular properties to compute. Supported: 'penalized_logP',
        'morgan_fingerprint', 'qed', 'drd2', 'tpsa'. None (default) computes none.

    Returns
    -------
    molecule : dict
        Maps 'num_atoms', 'atomic_numbers', 'positions', 'formal_charges', 'adj_list',
        'atomic_features', optionally 'adj_matrix', and every requested property name
        to a ``torch.Tensor``.

    Raises
    ------
    AssertionError
        If the file's SMILES, atom count or atom types disagree with ``smiles``.
    KeyError
        If an atom symbol is missing from ``symbol_to_atomic_number``.
    """
    # A missing property list means "compute nothing"; None must not reach the `in` checks below.
    prop = () if prop is None else prop

    xyz_lines = datafile.readlines()
    assert smiles == xyz_lines[1].strip("\n"), f"smiles {xyz_lines[1]} from the .xyz file does not match with the smiles {smiles} passed into this function"

    num_atoms = int(xyz_lines[0])
    mol_xyz = xyz_lines[2:num_atoms+2]

    atomic_numbers, positions, formal_charges = [], [], []
    for line in mol_xyz:
        # rewrite '*^' exponent notation (e.g. '1.5*^-6') as 'e' so float() can parse it
        atom_symbol, posx, posy, posz, formal_charge = line.replace('*^', 'e').split()
        atomic_numbers.append(symbol_to_atomic_number[atom_symbol])
        positions.append([float(posx), float(posy), float(posz)])
        formal_charges.append(int(formal_charge))

    adj_matrix = get_adj_matrix_from_smiles(smiles, only_explicit_H=only_explicit_H)
    assert adj_matrix.shape[0] == adj_matrix.shape[1] == num_atoms, 'The smiles file and the xyz file do not match'
    # the sparse edge list is always stored and keeps the raw bond orders (1.5 for aromatic)
    adj_list = get_adj_list_from_adj_matrix(adj_matrix)

    # pre-compute dense adjacency matrices only for smaller datasets, otherwise it's too much space
    if max_n_atoms < 50:
        adj_matrix_padded = np.zeros((max_n_atoms, max_n_atoms))
        adj_matrix_padded[:num_atoms, :num_atoms] = adj_matrix
        # encode aromatic bonds (bond order 1.5) as a further bond type 4
        adj_matrix_padded[adj_matrix_padded == 1.5] = 4
    else:
        adj_matrix_padded = None

    # Supported molecular properties, computed in this order when requested via `prop`.
    # To support another property, add an entry mapping its name to a function of the SMILES string.
    property_fns = {
        'penalized_logP': penalized_logp,
        'morgan_fingerprint': lambda s: get_morgan_fingerprint(s, n_bits=1024),
        'qed': qed,
        'drd2': drd2,
        'tpsa': tpsa,
    }
    prop_values = {name: fn(smiles) for name, fn in property_fns.items() if name in prop}

    # compute extra atomic features; the first feature of each atom must equal its atomic number
    atomic_features = get_atom_features_from_smiles(smiles)
    assert [features[0] for features in atomic_features] == atomic_numbers, "smiles does not match"
    # TODO: this fails for guacamol because the atom_features are not general enough, work on this!
    #assert [atomic_features[i][2]-1 for i in range(len(atomic_features))] == formal_charges, "smiles does not match"

    molecule = {'num_atoms': num_atoms, 'atomic_numbers': atomic_numbers, 'positions': positions,
                'formal_charges': formal_charges, 'adj_list': adj_list,
                'atomic_features': atomic_features}
    if adj_matrix_padded is not None:
        molecule['adj_matrix'] = adj_matrix_padded
    # add the molecule properties computed above
    molecule.update(prop_values)
    return {key: torch.tensor(val) for key, val in molecule.items()}


def get_max_n_atoms(files, readfile):
    """
    Return the largest atom count over a collection of .xyz files.

    Only the first line of each file (the atom count) is read, instead of loading
    every coordinate block just to inspect line 0.

    Parameters
    ----------
    files : iterable of str
        Paths to .xyz files.
    readfile : callable
        Maps a path to an open text file object usable as a context manager.

    Returns
    -------
    int
        Maximum atom count, or 0 when ``files`` is empty.
    """
    max_n_atoms = 0
    for file in tqdm(files):
        with readfile(file) as datafile:
            num_atoms = int(datafile.readline())
        max_n_atoms = max(max_n_atoms, num_atoms)
    return max_n_atoms


def read_smiles_file(smiles_file):
    """
    Read a UTF-8 text file containing one SMILES string per line.

    The file is opened in text mode with universal newlines, so both '\\n' and '\\r\\n'
    line endings are removed. That matches how the .xyz files are read, which keeps
    the SMILES comparison in process_xyz working for CRLF files.

    Parameters
    ----------
    smiles_file : str
        Path to the SMILES file.

    Returns
    -------
    list of str
        One entry per line, without the line terminator.
    """
    with open(smiles_file, "r", encoding="utf-8") as file:
        return [line.strip("\n") for line in file]


def get_adj_matrix_from_smiles(smiles, kekulize=False, only_explicit_H=False):
    """
    Build the bond-order adjacency matrix of a molecule from its SMILES string.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.
    kekulize : bool, optional
        Forwarded to ``smiles_to_mol``. Leave False to keep aromatic bonds as their
        own bond order in the matrix.
    only_explicit_H : bool, optional
        Forwarded to ``smiles_to_mol``. When False, hydrogens are added with
        ``Chem.AddHs`` first, so they get their own rows/columns.

    Returns
    -------
    adj_matrix : array
        Square matrix whose (i, j) entry is the bond order between atoms i and j
        (0 when not bonded).
    """
    mol = smiles_to_mol(smiles, kekulize=kekulize, only_explicit_H=only_explicit_H)
    if only_explicit_H:
        graph_mol = mol
    else:
        graph_mol = Chem.AddHs(mol)

    # useBO=True stores bond orders rather than plain 0/1 connectivity,
    # which yields 4 or 5 bond types (aromatic included)
    return Chem.rdmolops.GetAdjacencyMatrix(graph_mol, useBO=True)


def get_adj_list_from_adj_matrix(adj_matrix):
    """
    Convert a symmetric adjacency matrix into a compact edge list to save space.

    Only the upper triangle (``np.triu``, diagonal included) is scanned, so each
    undirected bond appears once; the diagonal is expected to be zero because no
    atom is bonded to itself.

    Parameters
    ----------
    adj_matrix : 2-D array
        Symmetric matrix of bond orders, 0 meaning "no bond".

    Returns
    -------
    list of [i, j, bond_order]
        One entry per non-zero upper-triangle element, in row-major order. A matrix
        without edges (e.g. a single atom) yields the placeholder ``[[0., 0., 0.]]``,
        presumably so that downstream tensors are never empty.
    """
    rows, cols = np.where(np.triu(adj_matrix) != 0)
    edge_list = [[row, col, adj_matrix[row, col]] for row, col in zip(rows, cols)]

    if not edge_list:
        # molecule with no edge: e.g. a single atom
        edge_list = [[0., 0., 0.]]
    return edge_list


def get_adj_matrix_from_adj_list(adj_list, n_nodes):
    """
    Rebuild a dense symmetric adjacency matrix from an edge list.

    This is the inverse of ``get_adj_list_from_adj_matrix``: every edge
    ``(i, j, bond_order)`` is written to both ``[i, j]`` and ``[j, i]``.

    Parameters
    ----------
    adj_list : iterable of (i, j, bond_order)
        Edges. Indices may be ints, floats (e.g. the ``[0., 0., 0.]`` placeholder for
        edge-less molecules) or 0-d tensors; they are converted with ``int()``.
    n_nodes : int
        Number of atoms, i.e. the side length of the returned matrix.

    Returns
    -------
    numpy.ndarray
        Float matrix of shape ``(n_nodes, n_nodes)``.

    Raises
    ------
    AssertionError
        If the resulting matrix is not symmetric (e.g. a NaN bond order).
    """
    adj_matrix = np.zeros((n_nodes, n_nodes))
    for i, j, bond_order in adj_list:
        # numpy indexing requires integer indices
        i, j = int(i), int(j)
        bond_order = float(bond_order)
        adj_matrix[i, j] = bond_order
        adj_matrix[j, i] = bond_order

    # adj_matrix is a 2-D numpy array, so compare it against its own transpose
    assert np.array_equal(adj_matrix, adj_matrix.T), "Adjacency matrix is not symmetric, something is wrong"

    return adj_matrix


# To run this, go to the main folder e3_diffusion and run
# PYTHONPATH="${PYTHONPATH}:." python qm9/data/prepare/process_synthetic_coordinates.py --datadir data/zinc250k/ --split train
# Optional flags: --subset N (only the first N files), --only_explicit_H, --prop qed tpsa ...
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert the .xyz files of one dataset split into a compressed .npz archive.")
    parser.add_argument("--datadir", default="data/qm9")
    parser.add_argument("--xyz_folder_name", default="synthetic_coords_rdkit")
    parser.add_argument("--split", default="train")
    parser.add_argument("--subset", type=int, default=None)
    parser.add_argument("--only_explicit_H", action="store_true",
                        help="do not add hydrogens with Chem.AddHs when building the bond graph")
    parser.add_argument("--prop", nargs="*", default=[],
                        choices=["penalized_logP", "morgan_fingerprint", "qed", "drd2", "tpsa"],
                        help="molecular properties to compute and store")
    args = parser.parse_args()

    # prop must be a container (never None): the per-file processing tests `name in prop`
    process_dataset(args.datadir, args.xyz_folder_name, args.split, args.subset,
                    only_explicit_H=args.only_explicit_H, prop=args.prop)
