import graph_tool as gt
import os
import pathlib
import warnings

import torch

torch.cuda.empty_cache()
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.warnings import PossibleUserWarning

from src import utils
from metrics.abstract_metrics import TrainAbstractMetricsDiscrete, TrainAbstractMetrics
from bayesian_flow_discrete import BFN_Discrete
from diffusion_model import LiftedDenoisingDiffusion
from diffusion_model_discrete import DiscreteDenoisingDiffusion
from diffusion.extra_features import DummyExtraFeatures, ExtraFeatures
import argparse
from absl import logging
from src.config.config import Config
from pytorch_lightning.loggers import WandbLogger
import datetime, pytz

import pdb
from src.callbacks.basic import EMACallback
from src.metrics import molecular_metrics
from torch_geometric.data.lightning import LightningDataset as tgLightningDataset

warnings.filterwarnings("ignore", category=PossibleUserWarning)


def get_resume(cfg, model_kwargs):
    """Reload a checkpointed model for testing, keeping the checkpoint's config.

    The checkpoint at ``cfg.general.test_only`` is loaded with the model class
    matching ``cfg.model.type``. The config stored on the loaded model
    (``model.cfg``) becomes the base config; only ``general.test_only`` and
    ``exp_name`` are overwritten on it before it is passed, together with a
    copy of the incoming config, to ``utils.update_config_with_new_keys``
    (presumably to add keys the old config lacks -- verify against ``utils``).

    Args:
        cfg: Current run configuration. Must provide ``copy()``, ``exp_name``,
            ``general.test_only`` (checkpoint path) and ``model.type``.
        model_kwargs: Extra keyword arguments forwarded to
            ``load_from_checkpoint``.

    Returns:
        A ``(cfg, model)`` tuple: the merged configuration and the loaded model.
    """
    saved_cfg = cfg.copy()
    name = cfg.exp_name + "_resume"
    resume = cfg.general.test_only

    # Use the same model-type -> class mapping as main(); previously a
    # "bayesian" checkpoint fell through to LiftedDenoisingDiffusion.
    # NOTE(review): assumes BFN_Discrete exposes `.cfg` like the other two
    # model classes -- confirm.
    if cfg.model.type == "discrete":
        model_cls = DiscreteDenoisingDiffusion
    elif cfg.model.type == "bayesian":
        model_cls = BFN_Discrete
    else:
        model_cls = LiftedDenoisingDiffusion
    model = model_cls.load_from_checkpoint(resume, **model_kwargs)

    # From here on the checkpoint's own config is authoritative.
    cfg = model.cfg
    cfg.general.test_only = resume
    cfg.exp_name = name
    cfg = utils.update_config_with_new_keys(cfg, saved_cfg)
    return cfg, model


def get_resume_adaptive(cfg, model_kwargs):
    """Reload a checkpointed model for continued training, letting ``cfg`` override.

    The checkpoint path is ``cfg.general.resume`` joined onto the part of this
    file's directory that precedes the first occurrence of ``"outputs"`` (the
    whole directory if ``"outputs"`` does not occur). The config stored on the
    loaded model is then overwritten entry by entry with the values from the
    incoming ``cfg``, so the current settings win over the checkpointed ones.

    Args:
        cfg: Current run configuration. Must provide ``copy()``,
            ``general.resume``, ``model.type`` and be iterable by category.
        model_kwargs: Extra keyword arguments forwarded to
            ``load_from_checkpoint``.

    Returns:
        A ``(new_cfg, model)`` tuple: the merged configuration and the loaded
        model.
    """
    saved_cfg = cfg.copy()
    # Fetch path to this file to get base path
    current_path = os.path.dirname(os.path.realpath(__file__))
    root_dir = current_path.split("outputs")[0]

    resume_path = os.path.join(root_dir, cfg.general.resume)

    # Use the same model-type -> class mapping as main(); previously a
    # "bayesian" checkpoint fell through to LiftedDenoisingDiffusion.
    # NOTE(review): assumes BFN_Discrete exposes `.cfg` like the other two
    # model classes -- confirm.
    if cfg.model.type == "discrete":
        model_cls = DiscreteDenoisingDiffusion
    elif cfg.model.type == "bayesian":
        model_cls = BFN_Discrete
    else:
        model_cls = LiftedDenoisingDiffusion
    model = model_cls.load_from_checkpoint(resume_path, **model_kwargs)
    new_cfg = model.cfg

    # Current settings take precedence over the checkpointed ones.
    # NOTE(review): assumes every top-level entry of cfg is itself a mapping
    # of arguments -- confirm this holds for top-level scalars such as exp_name.
    for category in cfg:
        for arg in cfg[category]:
            new_cfg[category][arg] = cfg[category][arg]

    new_cfg.general.resume = resume_path
    new_cfg.exp_name = new_cfg.exp_name + "_resume"

    new_cfg = utils.update_config_with_new_keys(new_cfg, saved_cfg)
    return new_cfg, model


def main(cfg):
    """Build the data module, model and trainer described by ``cfg`` and run them.

    Steps, in order:
      1. Dataset prep: pick the data module, dataset infos, metrics,
         visualization tools and extra features from ``cfg.dataset.name``.
      2. Build the model selected by ``cfg.model.type`` ("discrete",
         "bayesian", anything else -> ``LiftedDenoisingDiffusion``).
      3. Set up checkpoint / EMA callbacks, the W&B logger and the Trainer.
      4. If ``cfg.general.test_only`` is falsy: fit (resuming from
         ``cfg.general.resume``) and then test unless ``cfg.exp_name`` is
         "debug" or "test". Otherwise: test the checkpoint at
         ``cfg.general.test_only`` and, if
         ``cfg.general.evaluate_all_checkpoints`` is set, every other
         ``.ckpt`` file in the same directory.

    Args:
        cfg: Run configuration object. Reads ``dataset``, ``model``,
            ``general``, ``train``, ``debug``, ``overfit_batches``,
            ``no_wandb``, ``exp_name`` and ``project_name``.

    Raises:
        NotImplementedError: If ``cfg.dataset.name`` is not one of the
            supported graph or molecular datasets.
        ValueError: If a molecular dataset name has no loader branch.
    """
    dataset_config = cfg.dataset.todict()
    dataset_name = dataset_config["name"]

    uses_extra_features = cfg.model.type in [
        "discrete",
        "bayesian",
    ] and cfg.model.extra_features not in [None, "null"]

    if dataset_name in ["sbm", "comm-20", "planar", "protein"]:
        from datasets.spectre_dataset import SpectreGraphDataModule, SpectreDatasetInfos
        from analysis.spectre_utils import (
            PlanarSamplingMetrics,
            SBMSamplingMetrics,
            Comm20SamplingMetrics,
            ProteinSamplingMetrics,
        )
        from analysis.visualization import NonMolecularVisualization

        datamodule = SpectreGraphDataModule(cfg)
        if dataset_name == "sbm":
            sampling_metrics = SBMSamplingMetrics(datamodule)
        elif dataset_name == "comm-20":
            sampling_metrics = Comm20SamplingMetrics(datamodule)
        elif dataset_name == "protein":
            sampling_metrics = ProteinSamplingMetrics(datamodule)
        # elif dataset_name == "ZINC25k":
        # sampling_metrics = ZINC25kSamplingMetrics(datamodule)
        else:
            sampling_metrics = PlanarSamplingMetrics(datamodule)

        dataset_infos = SpectreDatasetInfos(datamodule, dataset_config)
        train_metrics = (
            TrainAbstractMetricsDiscrete()
            if cfg.model.type == "discrete"
            else TrainAbstractMetrics()
        )
        visualization_tools = NonMolecularVisualization()

        if uses_extra_features:
            if cfg.model.extra_mode == "direct":
                # "direct" mode overrides the configured feature type.
                extra_features = ExtraFeatures(
                    "direct_only", dataset_info=dataset_infos
                )
                print("Extra Feature direct_only is Enabled!")
            else:
                extra_features = ExtraFeatures(
                    cfg.model.extra_features, dataset_info=dataset_infos
                )
                print(f"Extra Feature '{cfg.model.extra_features}' is Enabled!")
        else:
            print("Extra Feature DISabled!")
            extra_features = DummyExtraFeatures()

        # Non-molecular graphs have no domain-specific features.
        domain_features = DummyExtraFeatures()

        dataset_infos.compute_input_output_dims(
            datamodule=datamodule,
            extra_features=extra_features,
            domain_features=domain_features,
        )

    elif dataset_name in ["qm9", "zinc", "guacamol", "moses"]:
        from metrics.molecular_metrics import (
            TrainMolecularMetrics,
            SamplingMolecularMetrics,
        )
        from metrics.molecular_metrics_discrete import TrainMolecularMetricsDiscrete
        from diffusion.extra_features_molecular import ExtraMolecularFeatures
        from analysis.visualization import MolecularVisualization

        with_moses_metrics = False
        test_smiles = None

        # This must be ONE if/elif chain. The "zinc" test used to start a
        # second `if`, so "qm9" fell through to the final `else` and always
        # raised ValueError("Dataset not implemented").
        if dataset_name == "qm9":
            from datasets import qm9_dataset

            datamodule = qm9_dataset.QM9DataModule(cfg)
            dataset_infos = qm9_dataset.QM9infos(datamodule=datamodule, cfg=cfg)
            train_smiles = qm9_dataset.get_train_smiles(
                cfg=cfg,
                train_dataloader=datamodule.train_dataloader(),
                dataset_infos=dataset_infos,
                evaluate_dataset=False,
            )
        elif dataset_name == "zinc":
            from datasets import zinc_dataset

            datamodule = zinc_dataset.ZincDataModule(cfg)
            dataset_infos = zinc_dataset.Zincinfos(datamodule, cfg)
            train_smiles = datamodule.train_dataset.smiles
            test_smiles = datamodule.test_dataset.smiles
            # Currently unused below; kept so the attribute access still happens.
            val_smiles = datamodule.val_dataset.smiles

        elif dataset_name == "guacamol":
            from datasets import guacamol_dataset

            datamodule = guacamol_dataset.GuacamolDataModule(cfg)
            dataset_infos = guacamol_dataset.Guacamolinfos(datamodule, cfg)
            train_smiles = None

        elif dataset_name == "moses":
            from datasets import moses_dataset

            datamodule = moses_dataset.MosesDataModule(cfg)
            dataset_infos = moses_dataset.MOSESinfos(datamodule, cfg)
            train_smiles = molecular_metrics.get_dataset_smiles(
                cfg.dataset.datadir + "/raw", "train"
            )
            test_smiles = molecular_metrics.get_dataset_smiles(
                cfg.dataset.datadir + "/raw", "test"
            )  # test_scaffolds
            if cfg.general.test_only:
                # todo get the train smile and test smiles.  only for test.
                with_moses_metrics = True
        else:
            raise ValueError("Dataset not implemented")

        if uses_extra_features:
            extra_features = ExtraFeatures(
                cfg.model.extra_features, dataset_info=dataset_infos
            )
            domain_features = ExtraMolecularFeatures(dataset_infos=dataset_infos)
            print("Extra Feature Enabled!")
        else:
            extra_features = DummyExtraFeatures()
            domain_features = DummyExtraFeatures()
            print("Extra Feature DISabled!")

        dataset_infos.compute_input_output_dims(
            datamodule=datamodule,
            extra_features=extra_features,
            domain_features=domain_features,
        )

        if cfg.model.type == "discrete":
            train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
        else:
            train_metrics = TrainMolecularMetrics(dataset_infos)

        # We do not evaluate novelty during training
        visualization_tools = MolecularVisualization(
            cfg.dataset.remove_h, dataset_infos=dataset_infos
        )
        sampling_metrics = SamplingMolecularMetrics(
            dataset_infos,
            train_smiles,
            test_smiles,
            with_moses_metrics=with_moses_metrics,
        )
    else:
        raise NotImplementedError("Unknown dataset {}".format(cfg["dataset"]))

    # Both dataset families hand the same set of collaborators to the model.
    model_kwargs = {
        "dataset_infos": dataset_infos,
        "train_metrics": train_metrics,
        "sampling_metrics": sampling_metrics,
        "visualization_tools": visualization_tools,
        "extra_features": extra_features,
        "domain_features": domain_features,
    }

    os.makedirs(cfg.general.chains_path, exist_ok=True)
    os.makedirs(cfg.general.graphs_path, exist_ok=True)

    if cfg.model.type == "discrete":
        model = DiscreteDenoisingDiffusion(cfg=cfg, **model_kwargs)
    elif cfg.model.type == "bayesian":
        model = BFN_Discrete(cfg=cfg, **model_kwargs)
    else:
        model = LiftedDenoisingDiffusion(cfg=cfg, **model_kwargs)

    # Metric watched by the checkpoint callback and the checkpoint file name
    # pattern, per dataset. "comm-20" keeps the default file name (None).
    # NOTE(review): the ":2f" specs are a width of 2, not two decimals
    # (".2f") -- confirm which was intended before changing file names.
    monitor = None
    ckpt_filename = None
    if dataset_name in ["comm-20"]:
        monitor = "clustering"
    elif dataset_name in ["protein"]:
        monitor = "orbit"
        ckpt_filename = "{epoch}-{orbit:2f}"
    elif dataset_name in ["planar"]:
        monitor = "frac_unic_non_iso_valid"  # V.U.N.
        ckpt_filename = "{epoch}-{frac_unic_non_iso_valid:2f}"
    elif dataset_name in ["sbm"]:
        monitor = "frac_unic_non_iso_valid"  # V.U.N.
        ckpt_filename = "{epoch}-{frac_unic_non_iso_valid:2f}"
    elif dataset_name in ["qm9"]:
        monitor = "mol_stable"  # Molecular Stability
        ckpt_filename = "{epoch}-{mol_stable:2f}-{atm_stable:2f}-{Validity:2f}"
    elif dataset_name in ["guacamol", "moses", "zinc"]:
        monitor = "RValidity"
        ckpt_filename = "{epoch}-{RValidity:2f}"

    callbacks = []
    if cfg.train.save_model:
        # NOTE(review): mode="max" is applied to every monitor, including
        # "clustering" and "orbit" -- confirm higher is better for those.
        checkpoint_callback = ModelCheckpoint(
            dirpath=f"{cfg.general.logdir}/checkpoints/{cfg.exp_name}",
            filename=ckpt_filename,
            save_last=True,
            save_top_k=20,
            mode="max",
            every_n_epochs=cfg.general.check_point_every_n_epochs,
            monitor=monitor,
            save_on_train_epoch_end=False,
        )

        callbacks.append(checkpoint_callback)

    if cfg.train.ema_decay > 0:
        # NOTE(review): the EMA device is hard-coded to "cuda" -- confirm
        # CPU-only runs with ema_decay > 0 are not expected.
        ema_callback = EMACallback(decay=cfg.train.ema_decay, ema_device="cuda")
        callbacks.append(ema_callback)

    if cfg.debug:
        print("[WARNING]: Run is called 'debug' -- it will run with fast_dev_run. ")

    if cfg.overfit_batches > 0:
        print(
            f"[WARNING]: Run is enabled with overfit_batches -- it will overfit {cfg.overfit_batches} batches, each containing {cfg.train.batch_size} training sample(s). "
        )

    available_gpus = torch.cuda.device_count()
    # if available_gpus > 1:
    #     datamodule = tgLightningDataset(
    #         train_dataset=datamodule.train_dataset,
    #         val_dataset=datamodule.val_dataset,
    #         test_dataset=datamodule.test_dataset,
    #         batch_size=cfg.train.batch_size,
    #         num_workers=cfg.train.num_workers,
    #     )

    # The run name is the experiment name plus a UTC timestamp.
    wandb_logger = WandbLogger(
        name=cfg.exp_name
        + f'_{datetime.datetime.now(pytz.timezone("Universal")).strftime("%Y-%m-%d-%H:%M:%S")}',
        project=cfg.project_name,
        offline=cfg.debug or cfg.no_wandb,
        save_dir=cfg.general.wandb_dir,
    )  # add wandb parameters

    trainer = Trainer(
        default_root_dir=cfg.general.logdir,
        gradient_clip_val=cfg.train.clip_grad,
        # DDP only when more than one GPU is visible.
        strategy="ddp_find_unused_parameters_true" if available_gpus > 1 else "auto",
        # strategy="ddp_find_unused_parameters_true",  # Needed to load old checkpoints
        # accelerator="gpu" if use_gpu else "cpu",
        logger=wandb_logger,
        max_epochs=cfg.train.n_epochs,
        check_val_every_n_epoch=cfg.general.check_val_every_n_epochs,
        fast_dev_run=cfg.debug,
        overfit_batches=cfg.overfit_batches,
        enable_progress_bar=True,
        callbacks=callbacks,
        num_sanity_val_steps=2,
        accumulate_grad_batches=cfg.train.todict().get("accumulate_grad_batches", 1),
    )

    if not cfg.general.test_only:
        trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
        if cfg.exp_name not in ["debug", "test"]:
            trainer.test(model, datamodule=datamodule)
    else:
        trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)
        if cfg.general.evaluate_all_checkpoints:
            # Also test every sibling checkpoint of the one just evaluated.
            directory = pathlib.Path(cfg.general.test_only).parents[0]
            print("Directory:", directory)
            files_list = os.listdir(directory)
            for file in files_list:
                if ".ckpt" in file:
                    ckpt_path = os.path.join(directory, file)
                    if ckpt_path == cfg.general.test_only:
                        continue
                    print("Loading checkpoint", ckpt_path)
                    trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path)


if __name__ == "__main__":
    # Script entry point: parse the command line, build the Config from the
    # parsed values, configure absl logging verbosity, then run main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        type=str,
        default="configs/debug.yaml",
    )
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--lr", type=float, default=0.0002)
    parser.add_argument("--exp_name", type=str, default="debug")
    parser.add_argument("--logging_level", type=str, default="warning")
    parser.add_argument("--debug", action="store_true")
    # Default is an int to match type=int (it used to be the float 0.0);
    # any value given on the command line was already parsed as an int.
    parser.add_argument("--overfit_batches", type=int, default=0)
    parser.add_argument("--no_wandb", action="store_true")
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--sample_steps", type=int, default=500)
    parser.add_argument("--clip_grad", type=float, default=0.0)
    parser.add_argument("--ema_decay", type=float, default=0.0)
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--test_ckpt_fname", type=str, default=None)
    # Default is a real int (it used to be the string "2048", which argparse
    # converted through type=int); the parsed value is unchanged.
    parser.add_argument("--sampling_bs", type=int, default=2048)

    parser.add_argument("--lambda_train_node", type=float, default=1.0)
    parser.add_argument("--lambda_train_edge", type=float, default=1.0)
    parser.add_argument("--lambda_train_y", type=float, default=0.0)

    parser.add_argument("--beta_node", type=float, default=3.0)
    parser.add_argument("--beta_node_init", type=float, default=0.0)

    parser.add_argument("--beta_edge", type=float, default=3.0)
    parser.add_argument("--beta_edge_init", type=float, default=0.0)

    # beta_*_init is used to init the start point of the model [\alpha_0].
    # \beta(0) = 0 => \beta(1) = beta.

    parser.add_argument("--input_dist_sample", action="store_true")

    parser.add_argument("--prior", type=str, default="uniform")
    parser.add_argument("--extra_features", type=str)  # default is None
    parser.add_argument("--extra_mode", type=str, default="prob")
    parser.add_argument("--n_iid", type=int, default=10)
    parser.add_argument("--node_time_scheduler", type=str, default="quad")
    parser.add_argument("--edge_time_scheduler", type=str, default="quad")

    # Dev switches:
    parser.add_argument("--discretised_time", action="store_true")
    parser.add_argument("--force_symmetric_theta_E", action="store_true")
    parser.add_argument("--compare_input_output_dist_samples", action="store_true")
    parser.add_argument("--plot_input_dist_entropy", action="store_true")
    parser.add_argument(
        "--alternative_sampling_theta_update_ratio", type=float, default=0.0
    )
    parser.add_argument("--output_dist_extra", action="store_true")

    _args = parser.parse_args()
    print(_args)
    # _args, unknown = parser.parse_known_args()
    # Every parsed option is forwarded to Config as a keyword argument.
    cfg = Config(**vars(_args))
    print(f"The config of this process is:\n{cfg}")
    # Map the textual level to the absl verbosity constant; a name that is
    # not listed here raises KeyError.
    logging_level = {
        "info": logging.INFO,
        "debug": logging.DEBUG,
        "warning": logging.WARNING,
        "error": logging.ERROR,
        "fatal": logging.FATAL,
    }
    logging.set_verbosity(logging_level[cfg.logging_level])
    # create dir if not exist
    os.makedirs(cfg.general.wandb_dir, exist_ok=True)

    main(cfg)
