import argparse
import json
import os

import pytorch_lightning as pl
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint
from tqdm.auto import tqdm

from src.data.latent_module import LatentDataModule
from src.metrics import calculate_all_sampling_metrics
from src.model.flow_vae import FlowMAGNet, get_flow_training_args
from src.model.load_utils import load_model_from_id
from src.utils import (
    ROOT_DIR,
    WB_LOG_PATH,
    save_model_config_to_file,
    WB_ENTITY,
    WB_COLLECTION,
)


def run_fm_train(*, wb_project="flow_magnet", n_samples=10000, **kwargs):
    """Train a FlowMAGNet model and return sampling metrics for it.

    Both the latent dataset (``LatentDataModule``) and the model weights
    (``load_model_from_id`` with ``model_class=FlowMAGNet``) are looked up by
    ``kwargs["magnet_id"]`` in ``WB_COLLECTION``. After ``trainer.fit``,
    ``n_samples`` molecules are drawn with ``model.sample_molecules`` and scored
    with ``calculate_all_sampling_metrics(..., "zinc")``.

    Args:
        wb_project: Weights & Biases project the run is logged to. The same
            value is passed as the first argument of
            ``save_model_config_to_file``. NOTE(review): the original code
            referenced the not-yet-assigned local ``results`` here, which
            raised ``UnboundLocalError``. The default name is a placeholder;
            confirm the intended project.
        n_samples: Number of molecules sampled for evaluation (previously
            hard-coded to 10000).
        **kwargs: Raw training options, normalised by
            ``get_flow_training_args``. After normalisation this function reads
            magnet_id, batch_size, n_datapoints, num_workers, epochs, lr,
            lr_sch_decay, flow_dim_config, sample_config, seed_model,
            gradient_clip and val_n_epochs.

    Returns:
        The value returned by ``calculate_all_sampling_metrics``.
    """
    kwargs = get_flow_training_args(kwargs)
    magnet_id = kwargs["magnet_id"]

    dm = LatentDataModule(
        collection=WB_COLLECTION,
        model_id=magnet_id,
        batch_size=kwargs["batch_size"],
        ndatapoints=kwargs["n_datapoints"],
        num_workers=kwargs["num_workers"],
    )
    model, config = load_model_from_id(
        collection=WB_COLLECTION,
        run_id=magnet_id,
        load_config=dict(
            # Patience grows with run length: one unit per 400 epochs, minimum 1.
            patience=max(1, kwargs["epochs"] // 400),
            lr=kwargs["lr"],
            lr_sch_decay=kwargs["lr_sch_decay"],
            flow_dim_config=kwargs["flow_dim_config"],
            sample_config=kwargs["sample_config"],
        ),
        model_class=FlowMAGNet,
        seed_model=kwargs["seed_model"],
        return_config=True,
    )
    model.cuda()

    logger = pl.loggers.WandbLogger(
        project=wb_project,
        entity=WB_ENTITY,
        save_dir=str(WB_LOG_PATH),
    )
    # Accessing ``logger.experiment`` creates (or attaches to) the wandb run.
    # Calling ``wandb.config.update`` before that point fails, and
    # ``logger.version`` may still be None.
    logger.experiment.config.update(kwargs)
    save_model_config_to_file(wb_project, str(logger.version), config, model)

    # Keep a "last" checkpoint, written at the end of each training epoch.
    checkpointing = ModelCheckpoint(save_last=True, save_on_train_epoch_end=True)
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,  # replaces the deprecated/removed ``gpus=1``
        num_sanity_val_steps=0,
        log_every_n_steps=20,
        max_epochs=kwargs["epochs"],
        gradient_clip_val=kwargs["gradient_clip"],
        logger=logger,
        callbacks=[checkpointing],
        enable_progress_bar=False,
        # NOTE(review): the CLI defines --val_n_times, not val_n_epochs.
        # Presumably get_flow_training_args derives this key; confirm.
        check_val_every_n_epoch=kwargs["val_n_epochs"],
        limit_val_batches=1,  # validate on a single batch per validation run
    )
    trainer.fit(model, datamodule=dm)

    all_smiles = model.sample_molecules(n_samples)
    return calculate_all_sampling_metrics(all_smiles, "zinc")


def _json_dict(text):
    """Parse a command-line string as a JSON object (argparse ``type`` hook).

    ``type=dict`` cannot work on the CLI because argparse passes a plain
    string, so the value is read as JSON instead, e.g. '{"hidden_dim": 256}'.
    """
    try:
        value = json.loads(text)
    except json.JSONDecodeError as err:
        raise argparse.ArgumentTypeError(f"invalid JSON: {err}") from err
    if not isinstance(value, dict):
        raise argparse.ArgumentTypeError(f"expected a JSON object, got {text!r}")
    return value


def build_parser():
    """Build the command-line parser for the flow-model training script."""
    parser = argparse.ArgumentParser()
    # type=int keeps CLI values consistent with the integer default.
    parser.add_argument("--seed_model", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=1024)
    parser.add_argument("--epochs", type=int, default=5000)
    parser.add_argument("--val_n_times", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=3)
    # NOTE(review): placeholder default; a real run id must be supplied.
    parser.add_argument("--magnet_id", type=str, default="TO-BE-REPLACED")
    parser.add_argument(
        "--flow_dim_config",
        type=_json_dict,
        default=dict(hidden_dim=512),
        help='JSON object, e.g. \'{"hidden_dim": 512}\'',
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    os.chdir(ROOT_DIR)

    results = run_fm_train(**vars(args))
