import bisect
import itertools
import json
import os
import pickle
import random
from collections import defaultdict
from typing import Tuple, List

import numpy as np
import pandas as pd
import rdkit
import rdkit.Chem as Chem
from rdkit.Chem import AllChem
from rdkit.Chem.rdchem import Mol as Molecule
from rdkit.Chem.rdmolops import RemoveAllHs

from data.config import GEOM_DRUGS_RDKIT_SUMMARY_PATH, RDKIT_FOLDER_DIR, GEOM_DRUGS_PICKLE_PATH, \
    GEOM_DRUGS_SMALL_PICKLE_PATH

# Conformers whose Boltzmann weight falls below this value are dropped
# before resampling in mol_pickle2list.
WEIGHT_GATE: float = 0.01
# How many conformers mol_pickle2list draws (with replacement) per molecule;
# sample_drugs also uses it as the upper bound of the conformer index.
N_SAMPLE: int = 30
# NOTE(review): not referenced in this part of the file — confirm where it is used.
GENERATE_RATE: int = 2


def sample_from_weights(n_sample: int, weights: List[float]) -> List[int]:
    """Draw ``n_sample`` indices into ``weights``, each with probability proportional to its weight.

    Draws are with replacement and use exactly one ``random.random()`` call per
    draw, so results are reproducible under ``random.seed``.

    Args:
        n_sample: Number of indices to draw.
        weights: Non-negative weights, expected to sum to (approximately) 1.
            Any probability mass left over when they sum to less than 1
            (e.g. from float rounding) goes to the last index.

    Returns:
        A list of ``n_sample`` indices, each in ``range(len(weights))``.

    Raises:
        ValueError: If ``weights`` is empty, since there is no valid index to return.
    """
    if not weights:
        raise ValueError('sample_from_weights() requires at least one weight')
    # Running totals computed once, with the same left-to-right float accumulation
    # as a per-draw scan; non-negative weights keep this list sorted, so each
    # draw becomes a binary search.
    cumulative = list(itertools.accumulate(weights))
    last = len(weights) - 1
    ret = []
    for _ in range(n_sample):
        a = random.random()
        # First index j with a < cumulative[j]; clamp to the last index when
        # rounding leaves the total below a.
        ret.append(min(bisect.bisect_right(cumulative, a), last))
    return ret


def mol_pickle2list(p: dict, *, weight_gate=None, n_sample=None) -> List[Molecule]:
    """Resample one molecule's conformers according to their Boltzmann weights.

    Conformers in ``p['conformers']`` whose ``'boltzmannweight'`` is below the
    gate are dropped. The surviving weights are renormalised to sum to 1, and
    ``n_sample`` conformers are drawn from them with replacement via
    ``sample_from_weights``. Every kept conformer's ``'rd_mol'`` is passed
    through ``RemoveAllHs``.

    Args:
        p: Per-molecule record loaded from the RDKit pickle. It must contain a
            ``'conformers'`` list of dicts with ``'boltzmannweight'`` and ``'rd_mol'`` keys.
        weight_gate: Minimum Boltzmann weight for a conformer to be kept.
            Defaults to the module-level ``WEIGHT_GATE`` (read at call time).
        n_sample: Number of conformers to draw. Defaults to the module-level
            ``N_SAMPLE`` (read at call time).

    Returns:
        A list of ``n_sample`` molecules, which may contain repeats.

    Raises:
        ValueError: If no conformer passes the weight gate, or the kept weights
            do not sum to a positive value.
    """
    if weight_gate is None:
        weight_gate = WEIGHT_GATE
    if n_sample is None:
        n_sample = N_SAMPLE

    weights = []
    mols = []
    for conf in p['conformers']:
        weight = conf['boltzmannweight']
        if weight < weight_gate:
            continue
        mol = RemoveAllHs(conf['rd_mol'])
        weights.append(weight)
        mols.append(mol)

    if not mols:
        raise ValueError(f'no conformer has a Boltzmann weight >= {weight_gate}')
    total = sum(weights)
    if total <= 0:
        # Only reachable with a non-positive gate; avoids a ZeroDivisionError below.
        raise ValueError('kept conformers have a non-positive total Boltzmann weight')
    # Renormalise over the surviving conformers so the weights form a distribution.
    weights = [w / total for w in weights]
    indices = sample_from_weights(n_sample, weights)
    return [mols[i] for i in indices]


def process_drugs(max_entries: int = 100000):
    """Load GEOM-DRUGS molecules listed in the RDKit summary and resample their conformers.

    Seeds ``random`` and ``numpy.random`` with 0, so this step and any later
    sampling that relies on the global RNG state (presumably ``sample_drugs``)
    is reproducible. Processing walks at most the first ``max_entries`` summary
    entries; for each one that has a ``"pickle_path"``, it loads the pickle under
    ``RDKIT_FOLDER_DIR`` and passes it to ``mol_pickle2list``.

    Entries are skipped (best-effort) when:
    - their pickle file is missing;
    - ``mol_pickle2list`` raises ``ValueError``;
    - RDKit raises ``KekulizeException``.

    Args:
        max_entries: Maximum number of summary entries to examine
            (previously a hard-coded 100000).

    Returns:
        One element per successfully processed molecule. Each element is the
        list of resampled conformers returned by ``mol_pickle2list``.
    """
    random.seed(0)
    np.random.seed(0)
    print('Loading DRUGS...')

    with open(GEOM_DRUGS_RDKIT_SUMMARY_PATH) as fp:
        summary: dict = json.load(fp)
    n_k = len(summary)
    # Entries actually examined; used as the success-rate denominator so the
    # rate is not diluted by entries beyond max_entries.
    n_examined = min(n_k, max(max_entries, 0))

    mol_list_mol = []
    for i, v in enumerate(summary.values()):
        if (i + 1) % 1000 == 0:
            print('\t{}/{} loaded'.format(i + 1, n_k))
        if i >= max_entries:
            break
        if "pickle_path" not in v:
            continue
        path = v["pickle_path"]
        try:
            # NOTE: pickle.load can execute arbitrary code; only run on the trusted dataset.
            with open(f'{RDKIT_FOLDER_DIR}/{path}', 'rb') as fp:
                p: dict = pickle.load(fp)
        except FileNotFoundError:
            continue  # the summary references a pickle that is not on disk
        try:
            mol_list_mol.append(mol_pickle2list(p))
        except (ValueError, rdkit.Chem.rdchem.KekulizeException):
            # Deliberately skip unusable molecules rather than abort the whole run.
            continue

    n_mol = len(mol_list_mol)
    rate = 100 * n_mol / n_examined if n_examined else 0.0
    print(f'\tProcessed: {rate:.2f}% ({n_mol} out of {n_examined} examined, {n_k} in summary)')
    return mol_list_mol


def sample_drugs(mol_list_mol, small=False):
    """Split resampled molecules into train/validation/test sets and pickle them.

    Molecule order is shuffled with ``np.random.permutation``. Splits:
    - Test: the first 100 shuffled molecules.
    - Validation: the next 100 shuffled molecules.
    - Train: the remaining molecules.

    For each held-out molecule, ``n // 100`` conformers are drawn with
    replacement. Each training sample picks a random training molecule, then a
    random conformer of it. The tuple ``(train, validation, test)`` is written to
    ``GEOM_DRUGS_SMALL_PICKLE_PATH`` when ``small`` is true, otherwise to
    ``GEOM_DRUGS_PICKLE_PATH``.

    Args:
        mol_list_mol: One list of conformers per molecule (as returned by
            ``process_drugs``). Each inner list is indexed in ``range(N_SAMPLE)``,
            so it must hold at least ``N_SAMPLE`` conformers.
        small: Use the smaller split sizes and the small cache path.

    Raises:
        ValueError: If there are not more than 200 molecules, leaving none for
            training after the validation/test hold-out.
    """
    n_test_mols = 100  # shuffled positions [0, 100) form the test set
    n_val_mols = 100   # shuffled positions [100, 200) form the validation set
    n_reserved = n_test_mols + n_val_mols
    n_mol = len(mol_list_mol)
    print('\tSampling...')
    if n_mol <= n_reserved:
        raise ValueError(
            f'sample_drugs needs more than {n_reserved} molecules '
            f'({n_reserved} are held out for validation/test), got {n_mol}')
    indices = np.random.permutation(n_mol)
    if small:
        n_train, n_validate, n_test = 5000, 916, 916
    else:
        n_train, n_validate, n_test = 50000, 9161, 9161

    # The molecule is drawn before its conformer, keeping the original RNG call order.
    train_list_mol = []
    for _ in range(n_train):
        mol_idx = indices[random.randrange(n_reserved, n_mol)]
        train_list_mol.append(mol_list_mol[mol_idx][random.randrange(0, N_SAMPLE)])

    def draw_held_out(positions, n_total):
        # Spread n_total conformers evenly (rounded down) over the held-out molecules.
        per_mol = n_total // len(positions)
        return [[mol_list_mol[indices[i]][random.randrange(0, N_SAMPLE)] for _ in range(per_mol)]
                for i in positions]

    # Validation is drawn before test so the RNG consumption order is unchanged.
    validate_list_list_mol = draw_held_out(range(n_test_mols, n_reserved), n_validate)
    test_list_list_mol = draw_held_out(range(0, n_test_mols), n_test)

    print('\tCaching DRUGS...')
    with open(GEOM_DRUGS_SMALL_PICKLE_PATH if small else GEOM_DRUGS_PICKLE_PATH, 'wb') as fp:
        pickle.dump((train_list_mol, validate_list_list_mol, test_list_list_mol), fp)
    print('\tCaching Finished!')


def _load_mols_without_hs(path):
    """Unpickle a list of RDKit molecules from ``path`` and remove their hydrogens."""
    # NOTE: pickle.load can execute arbitrary code; only use on trusted dataset files.
    with open(path, 'rb') as fp:
        return [AllChem.RemoveHs(m) for m in pickle.load(fp)]


def _group_by_smiles(mols):
    """Group molecules by their ``Chem.MolToSmiles`` string, in first-seen order."""
    groups = defaultdict(list)
    for m in mols:
        groups[Chem.MolToSmiles(m)].append(m)
    return list(groups.values())


def redump_drugs_from_tj_dataset(small=False, data_dir='data/geom_drugs'):
    """Re-cache pre-split GEOM-DRUGS pickles in this module's cache format.

    Loads ``train_Drugs.pkl``, ``val_Drugs.pkl`` and ``test_Drugs.pkl`` from
    ``data_dir`` and removes hydrogens with ``AllChem.RemoveHs``. When ``small``
    is true, only every 10th molecule of each split is kept. Validation and test
    conformers are then grouped by their SMILES string. The tuple
    ``(train_list_mol, validate_list_list_mol, test_list_list_mol)`` is pickled
    to ``GEOM_DRUGS_SMALL_PICKLE_PATH`` when ``small`` is true, otherwise to
    ``GEOM_DRUGS_PICKLE_PATH``.

    Args:
        small: Subsample every 10th molecule and write the small cache.
        data_dir: Directory containing the three source pickles
            (defaults to the previously hard-coded ``data/geom_drugs``).
    """
    print('Redumping...')
    train_list_mol = _load_mols_without_hs(f'{data_dir}/train_Drugs.pkl')
    validate_list_mol = _load_mols_without_hs(f'{data_dir}/val_Drugs.pkl')
    test_list_mol = _load_mols_without_hs(f'{data_dir}/test_Drugs.pkl')
    if small:
        train_list_mol = train_list_mol[::10]
        validate_list_mol = validate_list_mol[::10]
        test_list_mol = test_list_mol[::10]
    validate_list_list_mol = _group_by_smiles(validate_list_mol)
    test_list_list_mol = _group_by_smiles(test_list_mol)
    print('\tCaching DRUGS...')
    with open(GEOM_DRUGS_SMALL_PICKLE_PATH if small else GEOM_DRUGS_PICKLE_PATH, 'wb') as fp:
        pickle.dump((train_list_mol, validate_list_list_mol, test_list_list_mol), fp)
    print('\tCaching Finished!')


def load_drugs(small=False, force_save=False) -> Tuple[List[Molecule], List[List[Molecule]], List[List[Molecule]]]:
    """Load the cached GEOM-DRUGS split.

    The cache is the pickle written by ``sample_drugs`` or
    ``redump_drugs_from_tj_dataset`` to the same paths.

    Args:
        small: Read ``GEOM_DRUGS_SMALL_PICKLE_PATH`` instead of ``GEOM_DRUGS_PICKLE_PATH``.
        force_save: Unused; kept only for backward compatibility with existing callers.

    Returns:
        ``(train_list_mol, validate_list_list_mol, test_list_list_mol)``.

    Raises:
        AssertionError: If the cache file is missing, empty or truncated.
    """
    print('\tLoading DRUGS...')
    path = GEOM_DRUGS_SMALL_PICKLE_PATH if small else GEOM_DRUGS_PICKLE_PATH
    not_cached_msg = "geom_drugs hasn't been cached. Please run cache_geom.py first."
    try:
        with open(path, 'rb') as fp:
            train_list_mol, validate_list_list_mol, test_list_list_mol = pickle.load(fp)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError) as err:
        # Raised explicitly rather than via `assert`, so the check still works under
        # `python -O`. AssertionError is kept because it is what callers saw before.
        raise AssertionError(not_cached_msg) from err
    print('\tLoading Finished!')

    return train_list_mol, validate_list_list_mol, test_list_list_mol
