import os
from pathlib import Path

# Absolute path of the directory containing this script; used to locate the
# dataset file independently of the current working directory.
parent_dir = str(Path(os.path.abspath(__file__)).parent)

import tensorflow as tf

from externals.benchmarks.MolGAN.utils.sparse_molecular_dataset import (
    SparseMolecularDataset,
)
from externals.benchmarks.MolGAN.utils.trainer import Trainer
from externals.benchmarks.MolGAN.utils.utils import *

from externals.benchmarks.MolGAN.models.gan import GraphGANModel
from externals.benchmarks.MolGAN.models import (
    encoder_rgcn,
    decoder_adj,
    decoder_dot,
    decoder_rnn,
)

from externals.benchmarks.MolGAN.optimizers.gan import GraphGANOptimizer

from rdkit import rdBase

# Suppress RDKit error logging and TensorFlow log messages below ERROR.
rdBase.DisableLog("rdApp.error")
tf.logging.set_verbosity(tf.logging.ERROR)

# ---- hyperparameters --------------------------------------------------------
batch_dim = 32  # minibatch size for training batches and noise vectors
la = 1  # fed to optimizer.la; values < 1 enable the reward terms and the V step
dropout = 0  # fed to model.dropout_rate during training
n_critic = 5  # generator step every n_critic iterations, discriminator otherwise
metric = "unique,sas"  # comma-separated reward metrics, or "all" (see reward())
n_samples = 5000  # molecules sampled by _eval_update / _test_update
z_dim = 32  # size of the noise vector z passed to GraphGANModel
epochs = 300  # forwarded to Trainer.train
save_every = 10  # forwarded to Trainer.train
create_G_data = True  # Trainer.train at the bottom only runs when this is False

data = SparseMolecularDataset()
# os.path.join instead of string "+" so the path uses the platform separator.
data.load(os.path.join(parent_dir, "data", "QM9_5k_ori_Mg.sparsedataset"))

# Number of full batches in the dataset; passed to Trainer.train as `steps`.
steps = len(data) // batch_dim


def train_fetch_dict(i, steps, epoch, epochs, min_epochs, model, optimizer):
    """Return the list of training ops to run at iteration ``i``.

    Every ``n_critic``-th iteration runs the generator step, followed by the
    value-network step when the module-level ``la`` is below 1. All other
    iterations run only the discriminator step. Only ``i`` and ``optimizer``
    are read; the other parameters are presumably present to satisfy the
    Trainer callback signature.
    """
    if i % n_critic != 0:
        return [optimizer.train_step_D]

    ops = [optimizer.train_step_G]
    if la < 1:
        ops.append(optimizer.train_step_V)
    return ops


def train_feed_dict(i, steps, epoch, epochs, min_epochs, model, optimizer, batch_dim):
    """Build the feed dict for training iteration ``i``.

    A batch of real molecules (edge/node labels) and a fresh batch of noise
    vectors are always fed, together with ``training=True``, the dropout
    rate and a value for ``optimizer.la``. When the module-level ``la`` is
    below 1 and ``i % n_critic == 0``, the rewards of the real batch and of a
    batch sampled from the current generator are fed as well.

    Args:
        i: iteration index supplied by the trainer.
        epoch: current epoch; during epoch 0 ``optimizer.la`` is fed 1.0.
        model: the GraphGANModel whose placeholders are fed.
        optimizer: the GraphGANOptimizer owning the ``la`` placeholder.
        batch_dim: number of real molecules and noise vectors per batch.
        steps, epochs, min_epochs: unused here (presumably required by the
            Trainer callback signature).

    Returns:
        dict mapping graph placeholders to the values to feed.
    """
    mols, _, _, a, x, _, _, _, _ = data.next_train_batch(batch_dim)
    embeddings = model.sample_z(batch_dim)

    # With la < 1 the configured la is fed, except during epoch 0 which is
    # fed 1.0; with la >= 1 the optimizer is always fed 1.0.
    la_value = la if la < 1 and epoch > 0 else 1.0

    feed_dict = {
        model.edges_labels: a,
        model.nodes_labels: x,
        model.embeddings: embeddings,
        model.training: True,
        model.dropout_rate: dropout,
        optimizer.la: la_value,
    }

    # Rewards are only required on generator iterations with la < 1.
    if la < 1 and i % n_critic == 0:
        feed_dict[model.rewardR] = reward(mols)

        # Decode molecules from the current generator (inference mode) using
        # the same noise vectors, then score them.
        n, e = session.run(
            [model.nodes_gumbel_argmax, model.edges_gumbel_argmax],
            feed_dict={model.training: False, model.embeddings: embeddings},
        )
        n, e = np.argmax(n, axis=-1), np.argmax(e, axis=-1)
        fake_mols = [data.matrices2mol(n_, e_, strict=True) for n_, e_ in zip(n, e)]
        feed_dict[model.rewardF] = reward(fake_mols)

    return feed_dict


def eval_fetch_dict(i, epochs, min_epochs, model, optimizer):
    """Map display names to the optimizer tensors evaluated on validation.

    Only ``optimizer`` is read; the other parameters are unused here.
    """
    labels = ("loss D", "loss G", "loss RL", "loss V", "la")
    tensors = (
        optimizer.loss_D,
        optimizer.loss_G,
        optimizer.loss_RL,
        optimizer.loss_V,
        optimizer.la,
    )
    return dict(zip(labels, tensors))


def eval_feed_dict(i, epochs, min_epochs, model, optimizer, batch_dim):
    """Build the feed dict for evaluating losses on a validation batch.

    One noise vector is drawn per real molecule in the batch. The rewards of
    the real molecules and of the molecules decoded from the generator (run
    with ``training=False``) are both fed, so the RL loss can be evaluated.
    ``i``, ``epochs``, ``min_epochs``, ``optimizer`` and ``batch_dim`` are
    unused here.
    """
    real_mols, _, _, edge_labels, node_labels, _, _, _, _ = data.next_validation_batch()
    z = model.sample_z(edge_labels.shape[0])

    real_reward = reward(real_mols)

    nodes_out, edges_out = session.run(
        [model.nodes_gumbel_argmax, model.edges_gumbel_argmax],
        feed_dict={model.training: False, model.embeddings: z},
    )
    fake_mols = [
        data.matrices2mol(node_row, edge_row, strict=True)
        for node_row, edge_row in zip(
            np.argmax(nodes_out, axis=-1), np.argmax(edges_out, axis=-1)
        )
    ]
    fake_reward = reward(fake_mols)

    return {
        model.edges_labels: edge_labels,
        model.nodes_labels: node_labels,
        model.embeddings: z,
        model.rewardR: real_reward,
        model.rewardF: fake_reward,
        model.training: False,
    }


def test_fetch_dict(model, optimizer):
    """Map display names to the optimizer tensors evaluated on the test split.

    Only ``optimizer`` is read. NOTE: because the name starts with ``test_``,
    pytest will try to collect this function if it is imported into a test
    module.
    """
    fetches = {}
    fetches["loss D"] = optimizer.loss_D
    fetches["loss G"] = optimizer.loss_G
    fetches["loss RL"] = optimizer.loss_RL
    fetches["loss V"] = optimizer.loss_V
    fetches["la"] = optimizer.la
    return fetches


def test_feed_dict(model, optimizer, batch_dim):
    """Build the feed dict for evaluating losses on a test batch.

    Same construction as ``eval_feed_dict`` but the batch comes from
    ``data.next_test_batch()``. ``optimizer`` and ``batch_dim`` are unused.
    NOTE: because the name starts with ``test_``, pytest will try to collect
    this function if it is imported into a test module.
    """
    (real_mols, _, _, edges_real, nodes_real, _, _, _, _) = data.next_test_batch()
    noise = model.sample_z(edges_real.shape[0])

    reward_real = reward(real_mols)

    fetched = session.run(
        [model.nodes_gumbel_argmax, model.edges_gumbel_argmax],
        feed_dict={model.training: False, model.embeddings: noise},
    )
    node_ids = np.argmax(fetched[0], axis=-1)
    edge_ids = np.argmax(fetched[1], axis=-1)

    generated = []
    for node_row, edge_row in zip(node_ids, edge_ids):
        generated.append(data.matrices2mol(node_row, edge_row, strict=True))

    reward_fake = reward(generated)

    result = {}
    result[model.edges_labels] = edges_real
    result[model.nodes_labels] = nodes_real
    result[model.embeddings] = noise
    result[model.rewardR] = reward_real
    result[model.rewardF] = reward_fake
    result[model.training] = False
    return result


def reward(mols):
    """Return the product of the configured metric scores as a column vector.

    The module-level ``metric`` is a comma-separated list of metric names, or
    ``"all"`` as shorthand for ``"logp,sas,qed,unique"``. Metrics are
    evaluated in the listed order and multiplied together.

    Raises:
        RuntimeError: if a listed metric name is not recognised.
    """
    scorers = {
        "np": lambda ms: MolecularMetrics.natural_product_scores(ms, norm=True),
        "logp": lambda ms: MolecularMetrics.water_octanol_partition_coefficient_scores(
            ms, norm=True
        ),
        "sas": lambda ms: MolecularMetrics.synthetic_accessibility_score_scores(
            ms, norm=True
        ),
        "qed": lambda ms: MolecularMetrics.quantitative_estimation_druglikeness_scores(
            ms, norm=True
        ),
        "novelty": lambda ms: MolecularMetrics.novel_scores(ms, data),
        "dc": lambda ms: MolecularMetrics.drugcandidate_scores(ms, data),
        "unique": lambda ms: MolecularMetrics.unique_scores(ms),
        "diversity": lambda ms: MolecularMetrics.diversity_scores(ms, data),
        "validity": lambda ms: MolecularMetrics.valid_scores(ms),
    }

    requested = "logp,sas,qed,unique" if metric == "all" else metric

    total = 1.0
    for name in requested.split(","):
        scorer = scorers.get(name)
        if scorer is None:
            raise RuntimeError("{} is not defined as a metric".format(name))
        total *= scorer(mols)

    return total.reshape(-1, 1)


def _eval_update(i, epochs, min_epochs, model, optimizer, batch_dim, eval_batch):
    """Sample ``n_samples`` molecules and return their aggregated scores.

    ``all_scores`` yields two mappings. Each list in the first is reduced to
    the mean of its non-zero entries; the second is merged in unchanged.
    Only ``model`` is read from the arguments.
    """
    # samples(..., sample=True) returns (mols, n, e), as unpacked in
    # _test_update. Binding the whole tuple to `mols` handed a 3-tuple to
    # all_scores instead of the molecule list.
    mols, _, _ = samples(data, model, session, model.sample_z(n_samples), sample=True)
    per_mol, summary = all_scores(mols, data, norm=True)
    # Average over non-zero entries only.
    scores = {k: np.array(v)[np.nonzero(v)].mean() for k, v in per_mol.items()}
    scores.update(summary)
    return scores


def _test_update(model, optimizer, batch_dim, test_batch, create_G_data=False):
    """Sample ``n_samples`` molecules from the generator and return them.

    No metrics are computed here. The previous ``all_scores`` evaluation
    built a result that was never returned, so it only cost time.

    Returns:
        tuple ``(mols, n, e, noise_vectors)``: the molecules and node/edge
        outputs returned by ``samples``, plus the noise vectors they were
        generated from. ``optimizer``, ``batch_dim``, ``test_batch`` and
        ``create_G_data`` are unused.
    """
    noise_vectors = model.sample_z(n_samples)
    mols, n, e = samples(data, model, session, noise_vectors, sample=True)
    return mols, n, e, noise_vectors


# model
# model
# Built from the dataset's vertex/bond/atom counts and z_dim. The decoder is
# decoder_adj and the discriminator is encoder_rgcn; the soft/hard Gumbel-softmax
# and batch_discriminator options are all turned off.
model = GraphGANModel(
    data.vertexes,
    data.bond_num_types,
    data.atom_num_types,
    z_dim,
    decoder_units=(128, 256, 512),
    discriminator_units=((128, 64), 128, (128, 64)),
    decoder=decoder_adj,
    discriminator=encoder_rgcn,
    soft_gumbel_softmax=False,
    hard_gumbel_softmax=False,
    batch_discriminator=False,
)

# optimizer
# Exposes the train_step_G/D/V ops, the loss tensors and the `la` placeholder
# used by the fetch/feed callbacks.
optimizer = GraphGANOptimizer(model, learning_rate=1e-3, feature_matching=False)

# session
# Module-level session that the feed-dict callbacks also use to sample
# molecules. The initializer must run after the model and optimizer have
# created all of their variables.
session = tf.Session()
session.run(tf.global_variables_initializer())

# trainer
trainer = Trainer(model, optimizer, session)

# Report the total number of trainable parameters. This fetches every
# variable's value just to read its shape.
print(
    "Parameters: {}".format(
        np.sum([np.prod(e.shape) for e in session.run(tf.trainable_variables())])
    )
)

# Train only when not in data-generation mode (create_G_data is falsy).
if not create_G_data:
    trainer.train(
        batch_dim=batch_dim,
        epochs=epochs,
        steps=steps,
        train_fetch_dict=train_fetch_dict,
        train_feed_dict=train_feed_dict,
        eval_fetch_dict=eval_fetch_dict,
        eval_feed_dict=eval_feed_dict,
        test_fetch_dict=test_fetch_dict,
        test_feed_dict=test_feed_dict,
        save_every=save_every,
        # Output folder for the saved model. Create it beforehand (or change
        # this path); a relative path resolves against the current working
        # directory, not parent_dir.
        directory="./results",
        _eval_update=_eval_update,
        _test_update=_test_update,
    )
