import argparse
import os

import pytorch_lightning as pl
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint
from rdkit import RDLogger

from src.data.mol_module import MolDataModule

from src.metrics import calculate_all_sampling_metrics
from src.model.load_utils import load_model_from_id
from src.model.vae import MAGNet
from src.utils import ROOT_DIR, WB_LOG_PATH, save_model_config_to_file, WB_ENTITY

RDLogger.DisableLog("rdApp.*")


def run_molgnn_training(**kwargs):
    """Train a MAGNet model, then sample molecules from it and score them.

    The keyword arguments are the run configuration. Keys read here:
    ``seed_model``, ``dataset``, ``batch_size``, ``num_workers``,
    ``cache_dataset``, ``lr``, ``lr_sch_decay``, ``dim_config``,
    ``layer_config``, ``loss_weights``, ``beta_annealing``, ``epochs`` and
    ``gradclip``. The complete ``kwargs`` dict is also recorded in the W&B run
    config and handed to ``save_model_config_to_file``.

    Returns:
        The result of ``calculate_all_sampling_metrics`` for 10000 sampled
        molecules, with the samples themselves stored under the
        ``"generated_smiles"`` key.

    Raises:
        AssertionError: If ``cache_dataset`` is truthy while ``num_workers``
            is not 0.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # the exception type and message are kept for backward compatibility.
    # Done first so an invalid configuration fails before any data is loaded.
    if kwargs["cache_dataset"] and kwargs["num_workers"] != 0:
        raise AssertionError("Caching and multiprocessing don't work together")

    pl.utilities.seed.seed_everything(kwargs["seed_model"])

    dm = MolDataModule(
        dataset=kwargs["dataset"],
        batch_size=kwargs["batch_size"],
        num_workers=kwargs["num_workers"],
        cache_dataset=kwargs["cache_dataset"],
    )
    # setup() is called eagerly because dm.feature_sizes is needed to build the model.
    dm.setup()

    model = MAGNet(
        feature_sizes=dm.feature_sizes,
        lr=kwargs["lr"],
        lr_sch_decay=kwargs["lr_sch_decay"],
        dim_config=kwargs["dim_config"],
        layer_config=kwargs["layer_config"],
        loss_weights=kwargs["loss_weights"],
        beta_annealing=kwargs["beta_annealing"],
    ).cuda()

    # NOTE(review): the W&B project is set to WB_ENTITY, same as the entity —
    # confirm this is intended and not a missing project constant.
    logger = pl.loggers.WandbLogger(
        entity=WB_ENTITY,
        project=WB_ENTITY,
        save_dir=str(WB_LOG_PATH),
    )
    # Go through logger.experiment rather than the global ``wandb.config``:
    # accessing the property creates (or reuses) the run, so this works whether
    # or not the logger constructor already started one, and it guarantees
    # logger.version is populated before it is used below.
    logger.experiment.config.update(kwargs)

    checkpointing = ModelCheckpoint(
        monitor="val_loss",
        filename="model-{epoch:02d}-{val_loss:.2f}",
        save_last=True,
    )
    save_model_config_to_file(WB_ENTITY, str(logger.version), kwargs, model)

    trainer = pl.Trainer(
        accelerator="gpu",
        gpus=1,
        num_sanity_val_steps=0,
        log_every_n_steps=1,
        max_epochs=kwargs["epochs"],
        check_val_every_n_epoch=1,
        gradient_clip_val=kwargs["gradclip"],
        enable_progress_bar=False,
        logger=logger,
        callbacks=[checkpointing],
    )
    trainer.fit(model, datamodule=dm)
    wandb.finish()

    # Reload the model by run id instead of reusing the in-memory instance
    # (presumably to pick up a saved checkpoint — verify in load_model_from_id).
    model = load_model_from_id(
        WB_ENTITY,
        str(logger.version),
        dataset=kwargs["dataset"],
        model_class=MAGNet,
    )
    output_smiles = model.sample_molecules(10000)
    results = calculate_all_sampling_metrics(output_smiles, kwargs["dataset"])
    results["generated_smiles"] = output_smiles
    return results


def parse_bool(value):
    """Interpret a command-line token as a boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value — including ``"False"`` — becomes ``True``. This
    parser accepts the usual spellings instead. The empty string maps to
    ``False``, which matches what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def json_overrides(defaults):
    """Return an argparse ``type`` callable that merges a JSON object into *defaults*.

    The returned callable parses its argument as JSON, requires it to be an
    object, and returns a new dict holding *defaults* updated with the parsed
    keys. *defaults* itself is never mutated.
    """
    import json  # local import: the file's import block does not include json

    def _parse(text):
        try:
            overrides = json.loads(text)
        except json.JSONDecodeError as err:
            raise argparse.ArgumentTypeError(f"invalid JSON: {err}") from err
        if not isinstance(overrides, dict):
            raise argparse.ArgumentTypeError(f"expected a JSON object, got {text!r}")
        return {**defaults, **overrides}

    return _parse


def build_arg_parser():
    """Create the command-line parser for a training run.

    Dict-valued options (``--dim_config``, ``--layer_config``,
    ``--loss_weights``, ``--beta_annealing``) take a JSON object whose keys
    override the defaults, e.g. ``--dim_config '{"latent_dim": 64}'``.
    """
    dim_config = dict(
        latent_dim=100,
        atom_id_dim=25,
        atom_charge_dim=10,
        shape_id_dim=35,
        atom_multiplicity_dim=10,
        shape_multiplicity_dim=10,
        motif_positional_dim=15,
        motif_seq_positional_dim=15,
        motif_feat_dim=50,
        enc_atom_dim=25,
        enc_shapes_dim=25,
        enc_joins_dim=25,
        enc_leafs_dim=25,
        enc_global_dim=25,
        leaf_rnn_hidden=256,
        shape_rnn_hidden=256,
        shape_gnn_dim=128,
        max_shape_mult=20,
        max_atom_mult=40,
        add_std=0.1,
    )
    layer_config = dict(
        num_layers_enc=2,
        num_layers_hgraph=3,
        num_layers_latent=2,
        num_layers_shape_enc=4,
        node_aggregation="sum",
    )
    loss_weights = dict(shapeset=1, shapeadj=1, motifs=1, joins=1, leafs=1)
    beta_annealing = dict(init=0, max=0.01, start=2000, every=2500, step=0.0005)

    parser = argparse.ArgumentParser()
    parser.add_argument("--seed_model", type=int, default=0)
    parser.add_argument("--dataset", default="ZINC")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--cache_dataset", type=parse_bool, default=True)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--lr", type=float, default=3.07e-4)
    parser.add_argument("--lr_sch_decay", type=float, default=0.9801)
    # argparse does not run ``type`` on non-string defaults, so the default
    # dicts below are passed through untouched when the flag is omitted.
    parser.add_argument("--dim_config", type=json_overrides(dim_config), default=dim_config)
    parser.add_argument("--layer_config", type=json_overrides(layer_config), default=layer_config)
    parser.add_argument("--loss_weights", type=json_overrides(loss_weights), default=loss_weights)
    parser.add_argument("--gradclip", type=float, default=10)
    parser.add_argument(
        "--beta_annealing",
        type=json_overrides(beta_annealing),
        default=beta_annealing,
    )
    return parser


def main(argv=None):
    """Parse the command line, run training from ROOT_DIR and print the results."""
    args = build_arg_parser().parse_args(argv)
    os.chdir(ROOT_DIR)
    results = run_molgnn_training(**vars(args))
    print(results)


if __name__ == "__main__":
    main()
