# Do not move these imports, the order seems to matter
import torch
import pytorch_lightning as pl
import numpy as np

import os
import warnings
import pathlib

import hydra
import omegaconf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.warnings import PossibleUserWarning

from didigress.analysis.rdkit_functions import make_molecular_list

from didigress.datasets import qm9_dataset, geom_dataset, zinc250k_dataset, moses_dataset, guacamol_dataset
from didigress.digress import DiGress
from didigress.freegress import FreeGress

import torchmetrics
from didigress.metrics.train_metrics import TrainLoss
from didigress.metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
import didigress.metrics.abstract_metrics as custom_metrics
from didigress.metrics.molecular_metrics import TrainMolecularMetrics, SamplingMetrics

from didigress.diffusion.noise_model import DiscreteUniformTransition, MarginalUniformTransition

from didigress.analysis.molecular_visualization import MolecularVisualization

# Silence PyTorch Lightning's PossibleUserWarning (imported from
# pytorch_lightning.utilities.warnings above) for the whole process.
warnings.filterwarnings("ignore", category=PossibleUserWarning)

from omegaconf import DictConfig, OmegaConf, open_dict
from didigress import utils

# (section, key) pairs whose values are taken from the *current* configuration
# and written over the configuration stored in a checkpoint by ``get_resume``.
# The tuple order is the order in which the overrides are applied.
_RESUME_OVERRIDE_KEYS = (
    ("guidance", "experiment_type"),
    ("guidance", "improvement_threshold"),
    ("guidance", "corruption_step"),
    ("guidance", "similarity_threshold"),
    ("features", "freeze_n_nodes_at_sampling"),
    ("general", "number_chain_steps"),
    ("features", "sampling_nT"),
    ("features", "max_n"),
    ("general", "samples_to_save"),
    ("general", "chains_to_save"),
    ("general", "samples_to_generate"),
    ("general", "final_model_samples_to_save"),
    ("general", "final_model_chains_to_save"),
    ("general", "final_model_samples_to_generate"),
    ("general", "wandb"),
    ("guidance", "include_split"),
    ("guidance", "s"),
    ("guidance", "n_samples_per_test_molecule"),
    ("guidance", "n_test_molecules_to_sample"),
    ("model", "n_layers"),
    ("train", "n_epochs"),
    ("train", "batch_size"),
    ("guidance", "improvement_type"),
    ("guidance", "improvement_target"),
    ("guidance", "improvement_limits"),
    ("train", "progress_bar"),
    ("features", "delt_model_path"),
    ("features", "zeta_w"),
    ("guidance", "fixed_sampling_target"),
    ("general", "gpus"),
)


def get_resume(cfg, checkpoint_path, test: bool, model_kwargs):
    """Rebuild a run's configuration from a checkpoint for testing or resuming.

    The model is loaded from ``checkpoint_path`` (as ``DiGress`` when
    ``cfg.guidance.p_uncond < 0``, otherwise as ``FreeGress``) only to read its
    stored ``cfg``; the model itself is discarded. Every key listed in
    ``_RESUME_OVERRIDE_KEYS`` is then overwritten with the value from the current
    ``cfg``, and ``general.test_only`` / ``general.resume`` / ``general.name`` are
    set for the new run.

    Args:
        cfg: the current (command-line) configuration.
        checkpoint_path: checkpoint file passed to ``load_from_checkpoint``.
        test: True for a test-only run. It selects the ``_test`` name suffix and
            keeps ``general.test_only`` (``general.resume`` becomes None). False
            selects ``_resume`` and keeps ``general.resume`` (``general.test_only``
            becomes None).
        model_kwargs: extra keyword arguments forwarded to ``load_from_checkpoint``.

    Returns:
        The checkpoint's configuration with the overrides applied, passed through
        ``utils.update_config_with_new_keys(cfg, saved_cfg)`` (presumably adds keys
        present only in the current config -- verify against ``utils``).
    """
    saved_cfg = cfg.copy()

    name = cfg.general.name + ('_test' if test else '_resume')
    test_only = cfg.general.test_only if test else None
    resume = cfg.general.resume if not test else None
    # Snapshot the overrides now: ``cfg`` is rebound to the checkpoint's config below.
    overrides = [(section, key, getattr(getattr(cfg, section), key))
                 for section, key in _RESUME_OVERRIDE_KEYS]

    model_cls = DiGress if cfg.guidance.p_uncond < 0 else FreeGress
    model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint_path,
                                           map_location='cpu', **model_kwargs)
    cfg = model.cfg
    del model  # only the stored configuration is needed

    OmegaConf.set_struct(cfg, True)
    # All writes happen under open_dict so that keys missing from an older
    # checkpoint's (struct-mode) config can still be added.
    with open_dict(cfg):
        cfg.general.test_only = test_only
        cfg.general.resume = resume
        cfg.general.name = name
        for section, key, value in overrides:
            setattr(getattr(cfg, section), key, value)

    cfg = utils.update_config_with_new_keys(cfg, saved_cfg)
    return cfg


def _build_datamodule(cfg):
    """Instantiate the datamodule and dataset infos selected by ``cfg.dataset.name``.

    For zinc250k, moses and guacamol this also overwrites some ``cfg.features``
    flags in place (``use_3d``, ``use_charges``, ``charges_policy``).

    Returns:
        ``(datamodule, dataset_infos)``.

    Raises:
        NotImplementedError: if the dataset name is not qm9, geom, zinc250k,
            moses or guacamol.
    """
    dataset_name = cfg.dataset.name
    if dataset_name == 'qm9':
        datamodule = qm9_dataset.QM9DataModule(cfg)
        dataset_infos = qm9_dataset.QM9infos(datamodule=datamodule, cfg=cfg)
    elif dataset_name == 'zinc250k':
        datamodule = zinc250k_dataset.ZINC250KDataModule(cfg)
        dataset_infos = zinc250k_dataset.ZINC250Kinfos(datamodule=datamodule, cfg=cfg)
        cfg.features.use_3d = False
    elif dataset_name == 'moses':
        datamodule = moses_dataset.MOSESDataModule(cfg)
        dataset_infos = moses_dataset.MOSESinfos(datamodule=datamodule, cfg=cfg)
        cfg.features.use_3d = False
        cfg.features.use_charges = False
        cfg.features.charges_policy = "no"
    elif dataset_name == 'guacamol':
        datamodule = guacamol_dataset.GuacamolDataModule(cfg)
        dataset_infos = guacamol_dataset.Guacamolinfos(datamodule=datamodule, cfg=cfg)
        cfg.features.use_3d = False
        cfg.features.use_charges = False
        cfg.features.charges_policy = "dictionary"
    elif dataset_name == 'geom':
        datamodule = geom_dataset.GeomDataModule(cfg)
        dataset_infos = geom_dataset.GeomInfos(datamodule=datamodule, cfg=cfg)
    else:
        raise NotImplementedError("Unknown dataset {}".format(cfg["dataset"]))
    return datamodule, dataset_infos


def _build_losses_and_metrics(cfg, datamodule, dataset_infos):
    """Build the dict of losses/metrics passed to the model as ``losses``.

    XKl and EKl are always tracked for validation and test; PosMSE is added when
    ``cfg.features.use_3d`` and ChargesKl when ``cfg.features.use_charges``.
    """
    val_metrics_array = [custom_metrics.XKl(), custom_metrics.EKl()]
    test_metrics_array = [custom_metrics.XKl(), custom_metrics.EKl()]
    if cfg.features.use_3d:
        val_metrics_array.append(custom_metrics.PosMSE())
        test_metrics_array.append(custom_metrics.PosMSE())
    if cfg.features.use_charges:
        val_metrics_array.append(custom_metrics.ChargesKl())
        test_metrics_array.append(custom_metrics.ChargesKl())

    # Training SMILES are only loaded for test-only runs; otherwise an empty list
    # is given to the sampling metrics.
    train_smiles = list(datamodule.train_dataloader().dataset.smiles) if cfg.general.test_only else []
    lambda_train = cfg.model.lambda_train if hasattr(cfg.model, "lambda_train") else cfg.train.lambda0

    # Train metrics
    train_loss = TrainLoss(lambda_train=lambda_train, cfg=cfg)
    train_metrics = TrainMolecularMetrics(dataset_infos)  # TODO: check why they removed everything

    # TODO: convert SamplingMetrics to SamplingMetricsMolecular
    # Val metrics
    val_loss = TrainLoss(lambda_train=lambda_train, cfg=cfg, name="val")
    val_metrics = torchmetrics.MetricCollection(val_metrics_array)
    val_nll = NLL()
    val_sampling_metrics = SamplingMetrics(train_smiles, dataset_infos, test=False)

    # Test metrics
    test_metrics = torchmetrics.MetricCollection(test_metrics_array)
    test_nll = NLL()
    test_sampling_metrics = SamplingMetrics(train_smiles, dataset_infos, test=True)

    ignore_hyperparameters = ['train_metrics', 'val_sampling_metrics',
                              'test_sampling_metrics', 'dataset_infos', 'train_smiles', 'losses']

    return {'train_loss': train_loss, 'train_metrics': train_metrics,

            'val_metrics': val_metrics, 'val_nll': val_nll,
            'val_sampling_metrics': val_sampling_metrics,
            'val_loss': val_loss,

            'test_metrics': test_metrics, 'test_nll': test_nll,
            'test_sampling_metrics': test_sampling_metrics,
            'graph_list_maker': make_molecular_list,
            'ignore_hyperparameters': ignore_hyperparameters}


def _build_noise_model(cfg, dataset_infos):
    """Return the transition (noise) model selected by ``cfg.model.transition``.

    Raises:
        ValueError: if the transition is neither 'uniform' nor 'marginal'.
    """
    # TODO: change DiscreteUniformTransition and MarginalUniformTransition
    #       into something specialized for molecules
    if cfg.model.transition == 'uniform':
        return DiscreteUniformTransition(output_dims=dataset_infos.output_dims, cfg=cfg)
    if cfg.model.transition == 'marginal':
        # TODO: remove charges if not used?
        print(f"Marginal distribution of the classes: nodes: {dataset_infos.node_types} --"
              f" edges: {dataset_infos.edge_types} -- charges: {dataset_infos.charges_marginals}")
        return MarginalUniformTransition(dataset_infos=dataset_infos, cfg=cfg)
    # Must be a raise: ``assert ValueError(...)`` always passes because an
    # exception instance is truthy.
    raise ValueError(f"Transition type '{cfg.model.transition}' not implemented.")


def _build_checkpoint_callbacks(cfg):
    """Return the ModelCheckpoint callbacks for this run (empty unless ``cfg.train.save_model``)."""
    if not cfg.train.save_model:
        return []

    checkpoint_dir = f"checkpoints/{cfg.general.name}"
    callbacks = []
    monitor_values = ['val/epoch_NLL'] if cfg.features.validation_loss_type == "NLL" \
        else ['val_epoch/overall_loss', 'val_epoch/x_CE', 'val_epoch/E_CE']
    if cfg.features.use_ins_del:
        monitor_values.append('val_epoch/delt_CE')

    # NOTE(review): monitor names contain '/', so Lightning will probably write these
    # checkpoints into a sub-directory of checkpoint_dir (e.g. ``val/``). The
    # evaluate_all_checkpoints scan only looks at checkpoint_dir itself.
    for monitor_value in monitor_values:
        callbacks.append(ModelCheckpoint(dirpath=checkpoint_dir,
                                         filename=f'{monitor_value}_' + '{epoch}',
                                         monitor=monitor_value,
                                         save_top_k=2,
                                         mode='min',
                                         every_n_epochs=1))

    # Fixed file names that keep being overwritten.
    callbacks.append(ModelCheckpoint(dirpath=checkpoint_dir, filename='mid',
                                     every_n_epochs=int(cfg.train.n_epochs / 2) + 1,
                                     save_on_train_epoch_end=True))
    callbacks.append(ModelCheckpoint(dirpath=checkpoint_dir, filename='last', every_n_epochs=1))
    return callbacks


def _gpus_ok(gpus):
    """Return True if ``gpus`` is a usable ``devices`` value for the GPU accelerator.

    An int is treated as a device count and must be positive; Lightning rejects
    ``devices=0``. A list holds device indices and must be non-empty with no
    negative entry. Any other type is rejected.
    """
    if isinstance(gpus, int):
        return gpus > 0
    if isinstance(gpus, (omegaconf.ListConfig, list, tuple)):
        return len(gpus) > 0 and all(gpu >= 0 for gpu in gpus)
    return False


def _test_other_checkpoints(trainer, model, datamodule, test_ckpt):
    """Run ``trainer.test`` on every other ``*.ckpt`` file in ``test_ckpt``'s directory.

    Files are visited in sorted order. ``test_ckpt`` itself is skipped by comparing
    resolved paths, so differently spelled paths to the same file are not re-tested.
    """
    test_path = pathlib.Path(test_ckpt)
    directory = test_path.parents[0]
    print("Directory:", directory)
    for ckpt in sorted(directory.glob('*.ckpt')):
        if not ckpt.is_file() or ckpt.resolve() == test_path.resolve():
            continue
        print("Loading checkpoint", ckpt)
        trainer.test(model, datamodule=datamodule, ckpt_path=str(ckpt))


@hydra.main(version_base='1.3', config_path='../configs', config_name='config')
def main(cfg: omegaconf.DictConfig):
    """Hydra entry point: build data, metrics and model from ``cfg``, then train or test.

    A test is run instead of training when ``cfg.general.test_only`` is set (it
    holds the checkpoint path). Training resumes from ``cfg.general.resume`` when
    that is not None.
    """
    print("CFG = ", cfg)

    print("get_num_threads", torch.get_num_threads())
    torch.set_num_threads(cfg.train.num_workers)
    torch.set_num_interop_threads(cfg.train.num_workers)
    print("get_num_threads", torch.get_num_threads())

    torch.manual_seed(cfg.train.seed)
    pl.seed_everything(cfg.train.seed)
    np.random.seed(cfg.train.seed)

    datamodule, dataset_infos = _build_datamodule(cfg)
    loss_and_metrics = _build_losses_and_metrics(cfg, datamodule, dataset_infos)
    visualizer = MolecularVisualization(cfg=cfg)
    noise_model = _build_noise_model(cfg, dataset_infos)
    model_kwargs = {"dataset_infos": dataset_infos, "losses": loss_and_metrics,
                    "noise_model": noise_model, "visualizer": visualizer}

    if cfg.general.test_only:
        cfg = get_resume(cfg, cfg.general.test_only, True, model_kwargs)
    elif cfg.general.resume is not None:
        # When resuming, we can override some parts of previous configuration
        print("Resuming from {}".format(cfg.general.resume))
        cfg = get_resume(cfg, cfg.general.resume, False, model_kwargs)

    if cfg.guidance.p_uncond < 0:
        model = DiGress(cfg=cfg, **model_kwargs)
    else:
        model = FreeGress(cfg=cfg, **model_kwargs)

    # need to ignore metrics because otherwise ddp tries to sync them
    params_to_ignore = ['module.model.train_smiles', 'module.model.dataset_infos']
    torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(model, params_to_ignore)

    callbacks = _build_checkpoint_callbacks(cfg)

    name = cfg.general.name
    if name == 'debug':
        print("[WARNING]: Run is called 'debug' -- it will run with fast_dev_run. ")

    gpus_ok = _gpus_ok(cfg.general.gpus)
    print("cfg.general.gpus", cfg.general.gpus)
    print("torch.cuda.is_available()", torch.cuda.is_available())
    use_gpu = torch.cuda.is_available() and gpus_ok
    print("gpus_ok", gpus_ok)
    print("use_gpu", use_gpu)

    if cfg.guidance.p_uncond >= 0 and not cfg.guidance.experiment_type == 'optimization':
        limit_test_batches = cfg.guidance.n_test_molecules_to_sample
    else:
        limit_test_batches = 1.0

    trainer = Trainer(gradient_clip_val=cfg.train.clip_grad,
                      strategy="ddp_find_unused_parameters_true",  # Needed to load old checkpoints
                      accelerator='gpu' if use_gpu else 'cpu',
                      # Lightning 2.x's CPU accelerator rejects devices=None; "auto" picks one process.
                      devices=cfg.general.gpus if use_gpu else "auto",
                      max_epochs=cfg.train.n_epochs,
                      check_val_every_n_epoch=cfg.general.check_val_every_n_epochs,
                      fast_dev_run=cfg.general.name == 'debug',
                      enable_progress_bar=cfg.train.progress_bar,
                      callbacks=callbacks,
                      log_every_n_steps=50 if name != 'debug' else 1,
                      limit_test_batches=limit_test_batches,
                      limit_val_batches=1 if cfg.guidance.p_uncond >= 0 else 1.0,
                      logger=[]
                      )

    if not cfg.general.test_only:
        trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
    else:
        # Start by evaluating the checkpoint given in test_only.
        trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)
        if cfg.general.evaluate_all_checkpoints:
            _test_other_checkpoints(trainer, model, datamodule, cfg.general.test_only)


if __name__ == "__main__":
    # Script entry point; the @hydra.main decorator on main supplies ``cfg``.
    main()
