import os
import json
import pickle
import random
import numpy as np
import pandas as pd
import rdkit
import rdkit.Chem as Chem
from rdkit.Chem import AllChem
from rdkit.Chem.rdchem import Mol as Molecule
from rdkit.Chem.rdmolops import RemoveAllHs
from typing import Tuple, List

from data.config import GEOM_QM9_RDKIT_SUMMARY_PATH, RDKIT_FOLDER_DIR, GEOM_QM9_PICKLE_PATH, GEOM_QM9_SMALL_PICKLE_PATH

# Conformers whose Boltzmann weight is below this threshold are dropped by
# mol_pickle2list before sampling.
WEIGHT_GATE = 0.01
# Number of conformers drawn (with replacement) per molecule by
# mol_pickle2list; sample_qm9 indexes into per-molecule lists of this length.
N_SAMPLE = 30
# NOTE(review): not referenced anywhere in the visible part of this file --
# confirm whether it is still needed.
GENERATE_RATE = 2


def sample_from_weights(n_sample: int, weights: List[float]) -> List[int]:
    """Draw ``n_sample`` indices into ``weights`` using the module-level ``random`` RNG.

    Each draw consumes exactly one ``random.random()`` value and returns the
    first index whose running weight total exceeds that value. When no running
    total exceeds the draw (e.g. the weights sum to less than the drawn
    value), the last index, ``len(weights) - 1``, is used instead.

    :param n_sample: number of indices to draw
    :param weights: per-index weights; they are used as given, not normalized here
    :return: list of ``n_sample`` indices (draws are independent, so repeats are possible)
    """
    # The running totals do not depend on the random draw, so build them once.
    thresholds = []
    running = 0.
    for weight in weights:
        running += weight
        thresholds.append(running)
    fallback = len(weights) - 1

    picks = []
    for _ in range(n_sample):
        draw = random.random()
        first_hit = next(
            (idx for idx, bound in enumerate(thresholds) if draw < bound),
            fallback,
        )
        picks.append(first_hit)
    return picks


def mol_pickle2list(p: dict) -> List[Molecule]:
    """Sample ``N_SAMPLE`` conformers of one molecule according to their Boltzmann weights.

    Reads ``p['conformers']``; every entry must provide ``'boltzmannweight'``
    and ``'rd_mol'``. Conformers with a weight below ``WEIGHT_GATE`` are
    ignored, the rest have all hydrogens removed with ``RemoveAllHs``. The
    surviving weights are renormalized to sum to one and used to draw
    ``N_SAMPLE`` conformers via ``sample_from_weights``.

    :param p: unpickled record of a single molecule
    :return: list of ``N_SAMPLE`` hydrogen-free molecules (repeats possible)
    :raises ValueError: if no conformer reaches ``WEIGHT_GATE``
    """
    kept = []  # (weight, hydrogen-free molecule) pairs, in input order
    for conformer in p['conformers']:
        boltzmann = conformer['boltzmannweight']
        if not boltzmann < WEIGHT_GATE:
            kept.append((boltzmann, RemoveAllHs(conformer['rd_mol'])))

    if not kept:
        raise ValueError

    total = sum(weight for weight, _ in kept)
    normalized = [weight / total for weight, _ in kept]
    return [kept[pick][1] for pick in sample_from_weights(N_SAMPLE, normalized)]


def process_qm9():
    """Load every molecule listed in the GEOM-QM9 summary and sample its conformers.

    Seeds both ``random`` and ``numpy.random`` with 0, reads the JSON summary
    at ``GEOM_QM9_RDKIT_SUMMARY_PATH`` and, for each entry that has a
    ``"pickle_path"``, unpickles the file under ``RDKIT_FOLDER_DIR`` and
    converts it with ``mol_pickle2list``. Entries for which that conversion
    raises ``ValueError`` or RDKit's ``KekulizeException`` are skipped on
    purpose (best effort).

    :return: list with one element per successfully processed molecule, each
        element being the list returned by ``mol_pickle2list``
    """
    random.seed(0)
    np.random.seed(0)
    print('Loading QM9...')

    with open(GEOM_QM9_RDKIT_SUMMARY_PATH) as fp:
        summary: dict = json.load(fp)
    n_k = len(summary)

    mol_list_mol = []
    for i, v in enumerate(summary.values()):
        if (i + 1) % 1000 == 0:
            print('\t{}/{} loaded'.format(i + 1, n_k))
        if "pickle_path" not in v:
            continue
        path = v["pickle_path"]
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # dataset files from a trusted source.
        with open(f'{RDKIT_FOLDER_DIR}/{path}', 'rb') as fp:
            p: dict = pickle.load(fp)
        try:
            mol_list_mol.append(mol_pickle2list(p))
        except (ValueError, rdkit.Chem.rdchem.KekulizeException):
            # Molecule has no usable conformer or cannot be sanitized: skip it.
            pass

    n_mol = len(mol_list_mol)
    # An empty summary would otherwise raise ZeroDivisionError here.
    processed_percent = 100 * n_mol / n_k if n_k else 0.0
    print('\tProcessed: {:.2f}%'.format(processed_percent))
    return mol_list_mol


def sample_qm9(mol_list_mol, small=False):
    """Split the per-molecule conformer lists into train/validate/test sets and cache them.

    The molecules are shuffled with ``np.random.permutation``. Shuffled
    positions 0-149 feed the test set, 150-299 the validation set, and
    positions from 300 onward form the pool for training. Every pick takes a
    random conformer (index in ``[0, N_SAMPLE)``) of the chosen molecule.
    Because the training pool starts at position 300, more than 300 molecules
    are required.

    The tuple ``(train, validate, test)`` is pickled to
    ``GEOM_QM9_SMALL_PICKLE_PATH`` when ``small`` is true, otherwise to
    ``GEOM_QM9_PICKLE_PATH``. ``train`` is a flat list of molecules;
    ``validate`` and ``test`` are lists holding one list per molecule.

    :param mol_list_mol: one list of ``N_SAMPLE`` conformers per molecule
    :param small: build the reduced-size dataset instead of the full one
    """
    held_out_molecules = 150  # molecules per held-out split
    n_mol = len(mol_list_mol)
    print('\tSampling...')
    indices = np.random.permutation(n_mol)
    n_train, n_validate, n_test = (5000, 1781, 1781) if small else (50000, 17813, 17813)

    def random_conformer(position):
        # One randrange call per pick, made after ``position`` has been evaluated.
        return mol_list_mol[indices[position]][random.randrange(0, N_SAMPLE)]

    # Training: each pick first draws a molecule from the pool, then a conformer.
    train_list_mol = []
    for _ in range(n_train):
        train_list_mol.append(random_conformer(random.randrange(300, n_mol)))

    # Validation: a fixed number of conformers for each of positions 150-299.
    per_validate_molecule = n_validate // held_out_molecules
    validate_list_list_mol = []
    for position in range(held_out_molecules, 2 * held_out_molecules):
        validate_list_list_mol.append(
            [random_conformer(position) for _ in range(per_validate_molecule)])

    # Test: same scheme for positions 0-149.
    per_test_molecule = n_test // held_out_molecules
    test_list_list_mol = []
    for position in range(0, held_out_molecules):
        test_list_list_mol.append(
            [random_conformer(position) for _ in range(per_test_molecule)])

    print('\tCaching QM9...')
    cache_path = GEOM_QM9_SMALL_PICKLE_PATH if small else GEOM_QM9_PICKLE_PATH
    with open(cache_path, 'wb+') as fp:
        pickle.dump((train_list_mol, validate_list_list_mol, test_list_list_mol), fp)
    print('\tCaching Finished!')


def redump_qm9_from_tj_dataset(small=False):
    """Rebuild the QM9 cache from the pre-split pickles under ``data/geom_qm9``.

    Loads ``train_QM9.pkl``, ``val_QM9.pkl`` and ``test_QM9.pkl``, strips
    hydrogens with ``AllChem.RemoveHs`` and, when ``small`` is true, keeps
    every tenth molecule of each split. Validation molecules that raise
    ``KekulizeException`` during hydrogen removal are dropped; the train and
    test splits have no such handling. Validation and test molecules are then
    grouped by their SMILES string, and the tuple ``(train, validate, test)``
    is pickled to ``GEOM_QM9_SMALL_PICKLE_PATH`` or ``GEOM_QM9_PICKLE_PATH``.

    :param small: write the reduced dataset (every tenth molecule) instead of the full one
    """
    print('Redumping...')

    def bucket_by_smiles(molecules):
        # Group molecules sharing a SMILES string, keeping first-seen order.
        buckets = {}
        for molecule in molecules:
            key = Chem.MolToSmiles(molecule)
            if key not in buckets:
                buckets[key] = []
            buckets[key].append(molecule)
        return list(buckets.values())

    with open('data/geom_qm9/train_QM9.pkl', 'rb') as fp:
        train_list_mol = [AllChem.RemoveHs(raw) for raw in pickle.load(fp)]

    with open('data/geom_qm9/val_QM9.pkl', 'rb') as fp:
        validate_list_mol = []
        for raw in pickle.load(fp):
            try:
                stripped = AllChem.RemoveHs(raw)
            except rdkit.Chem.rdchem.KekulizeException:
                continue
            else:
                validate_list_mol.append(stripped)

    with open('data/geom_qm9/test_QM9.pkl', 'rb') as fp:
        test_list_mol = [AllChem.RemoveHs(raw) for raw in pickle.load(fp)]

    if small:
        train_list_mol = train_list_mol[::10]
        validate_list_mol = validate_list_mol[::10]
        test_list_mol = test_list_mol[::10]

    validate_list_list_mol = bucket_by_smiles(validate_list_mol)
    test_list_list_mol = bucket_by_smiles(test_list_mol)

    print('\tCaching QM9...')
    cache_path = GEOM_QM9_SMALL_PICKLE_PATH if small else GEOM_QM9_PICKLE_PATH
    with open(cache_path, 'wb+') as fp:
        pickle.dump((train_list_mol, validate_list_list_mol, test_list_list_mol), fp)
    print('\tCaching Finished!')


def load_qm9(small=False, force_save=False) -> Tuple[List[Molecule], List[List[Molecule]], List[List[Molecule]]]:
    """Load the cached QM9 train/validate/test split from disk.

    Reads ``GEOM_QM9_SMALL_PICKLE_PATH`` when ``small`` is true, otherwise
    ``GEOM_QM9_PICKLE_PATH``, and unpickles the three-element tuple stored
    there.

    :param small: load the reduced dataset instead of the full one
    :param force_save: accepted but not used by this function
    :return: ``(train_list_mol, validate_list_list_mol, test_list_list_mol)``
    :raises AssertionError: if the cache file does not exist or is truncated.
        The exception type is kept from the original ``assert`` statements,
        but it is raised explicitly so the checks still run under
        ``python -O``, which strips ``assert``.
    """
    print('\tLoading QM9...')
    path = GEOM_QM9_SMALL_PICKLE_PATH if small else GEOM_QM9_PICKLE_PATH
    not_cached = "geom_qm9 hasn't been cached. Please run cache_geom.py first."
    if not os.path.exists(path):
        raise AssertionError(not_cached)
    try:
        # NOTE: pickle.load can execute arbitrary code; the cache file must be trusted.
        with open(path, 'rb') as fp:
            train_list_mol, validate_list_list_mol, test_list_list_mol = pickle.load(fp)
    except EOFError as err:
        # An empty or truncated cache file is treated like a missing cache.
        raise AssertionError(not_cached) from err
    print('\tLoading Finished!')

    return train_list_mol, validate_list_list_mol, test_list_list_mol
