# These imports are tricky because they use c++, do not move them
import os, shutil
import warnings
import sys
import torch
import hydra
from omegaconf import DictConfig
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger

import math
import wandb

import utils
from _datasets import dataset
from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete
from metrics.molecular_metrics_sampling import SamplingMolecularMetrics
from analysis.visualization import MolecularVisualization
from ema import AdaptiveEMACallback

warnings.filterwarnings("ignore", category=UserWarning)
torch.set_float32_matmul_precision("medium")


def resolve_project_dir(script_dir, package="meld"):
    """Return *script_dir* without a trailing *package* path component.

    The ``data/split`` files are looked up relative to the directory that
    contains the ``meld`` folder holding this script. Only the LAST path
    component is compared, so *package* appearing elsewhere in the path
    (e.g. ``/home/meldrum/...`` or ``/x/meld_project/meld``) is preserved.

    Args:
        script_dir: Directory path to normalise.
        package: Name of the trailing component to strip.

    Returns:
        The normalised parent of *script_dir* if its last component equals
        *package*, otherwise the normalised *script_dir* itself.
    """
    script_dir = os.path.normpath(script_dir)
    if os.path.basename(script_dir) == package:
        return os.path.dirname(script_dir)
    return script_dir


CURRENT_DIR = resolve_project_dir(os.path.dirname(os.path.abspath(__file__)))

# Add the parent directory to sys.path if not already there
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

def remove_folder(folder):
    """Empty *folder* in place: delete everything inside it but keep the folder.

    Regular files and symbolic links (including links to directories) are
    unlinked; real sub-directories are deleted recursively with
    ``shutil.rmtree``. A failure on one entry is printed and does not stop
    the remaining deletions.
    """
    for entry in os.listdir(folder):
        path = os.path.join(folder, entry)
        try:
            if os.path.islink(path) or os.path.isfile(path):
                os.unlink(path)
                continue
            if os.path.isdir(path):
                shutil.rmtree(path)
        except Exception as exc:
            # Best-effort cleanup: report and move on to the next entry.
            print("Failed to delete %s. Reason: %s" % (path, exc))


def get_resume(cfg, model, model_kwargs):
    """Reload a checkpoint for evaluation, adopting its saved configuration.

    The checkpoint path is read from ``cfg.general.test_only``. The config
    stored on the loaded model replaces *cfg*, except that ``general.name``
    (suffixed with ``"_resume"``), ``train.batch_size`` and
    ``general.test_only`` keep the values from the current *cfg*. The result
    is then passed, together with a copy of the current *cfg*, through
    ``utils.update_config_with_new_keys`` (presumably to add keys missing
    from the saved config — see that helper).

    Args:
        cfg: Current run configuration (``DictConfig``).
        model: Object whose ``load_from_checkpoint`` restores the checkpoint.
        model_kwargs: Keyword arguments forwarded to ``load_from_checkpoint``.

    Returns:
        ``(cfg, model)``: the merged configuration and the restored model.
    """
    saved_cfg = cfg.copy()
    name = cfg.general.name + "_resume"
    resume = cfg.general.test_only
    batch_size = cfg.train.batch_size

    model = model.load_from_checkpoint(resume, **model_kwargs)

    # Pass weights_only=False explicitly, like the other torch.load calls in
    # this file: Lightning checkpoints can hold pickled non-tensor objects,
    # which the weights-only loader (the default since torch 2.6) rejects.
    # NOTE: this performs full unpickling - only load trusted checkpoints.
    state_dict = torch.load(resume, map_location="cpu", weights_only=False)["state_dict"]
    vocab_key = "noise_schedule.node_vocab"
    if vocab_key in state_dict:
        # Explicitly re-attach the saved node vocabulary as a Parameter.
        model.noise_schedule.node_vocab = torch.nn.Parameter(state_dict[vocab_key])

    cfg = model.cfg
    cfg.general.test_only = resume
    cfg.general.name = name
    cfg.train.batch_size = batch_size
    cfg = utils.update_config_with_new_keys(cfg, saved_cfg)
    return cfg, model

def get_resume_adaptive(cfg, model, model_kwargs):
    """Reload a checkpoint to continue training, letting *cfg* override it.

    The checkpoint path is read from ``cfg.general.resume``. Starting from
    the config stored on the loaded model, every ``cfg[category][arg]``
    entry of the current config is copied over it and ``general.name`` is
    suffixed with ``"_resume"``. The result is then passed, together with a
    copy of the current *cfg*, through ``utils.update_config_with_new_keys``
    (presumably to add keys missing from the saved config).

    Args:
        cfg: Current run configuration (``DictConfig``).
        model: Object whose ``load_from_checkpoint`` restores the checkpoint.
        model_kwargs: Keyword arguments forwarded to ``load_from_checkpoint``.

    Returns:
        ``(new_cfg, model)``: the merged configuration and restored model.
    """
    saved_cfg = cfg.copy()

    model = model.load_from_checkpoint(cfg.general.resume, **model_kwargs)

    new_cfg = model.cfg
    # Current values win: overwrite the checkpoint's config entry by entry.
    # NOTE(review): assumes every top-level entry of cfg is itself a mapping
    # (a config group) and exists in the saved config - confirm.
    for category in cfg:
        for arg in cfg[category]:
            new_cfg[category][arg] = cfg[category][arg]
    new_cfg.general.name = new_cfg.general.name + "_resume"

    new_cfg = utils.update_config_with_new_keys(new_cfg, saved_cfg)
    return new_cfg, model


# Split file (relative to CURRENT_DIR) for each supported
# ``cfg.dataset.guidance_target``.
_SPLIT_FILES = {
    "QM9": "data/split/qm9_split.pt",
    "ZINC": "data/split/zinc250k_split.pt",
    "O2": "data/split/O2-N2-CO2_split.pt",
    "N2": "data/split/O2-N2-CO2_split.pt",
    "CO2": "data/split/O2-N2-CO2_split.pt",
    "O2-N2-CO2": "data/split/O2-N2-CO2_split.pt",
}


def _load_split_indices(guidance_target):
    """Load the saved train/val/test index split for *guidance_target*.

    Raises:
        ValueError: if no split file is registered for *guidance_target*.
    """
    split_file = _SPLIT_FILES.get(guidance_target)
    if split_file is None:
        raise ValueError(
            f"Unknown dataset.guidance_target {guidance_target!r}; "
            f"expected one of {sorted(_SPLIT_FILES)}"
        )
    # weights_only=False performs full unpickling - only load trusted files.
    return torch.load(os.path.join(CURRENT_DIR, split_file), weights_only=False)


def _build_trainer(cfg, callbacks, logger):
    """Create the Lightning ``Trainer`` configured from ``cfg.train``/``cfg.general``."""
    use_gpu = torch.cuda.is_available() and cfg.general.gpus > 0
    return Trainer(
        gradient_clip_val=cfg.train.clip_grad,
        accelerator="gpu" if use_gpu else "cpu",
        # devices=None is not accepted by the CPU accelerator in Lightning
        # 2.x; "auto" lets Lightning choose the CPU device count.
        devices=cfg.general.gpus if use_gpu else "auto",
        max_epochs=cfg.train.n_epochs,
        # Automatic checkpointing is off; main() saves last.ckpt explicitly
        # when cfg.general.save_model is set.
        enable_checkpointing=False,
        check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
        val_check_interval=cfg.train.val_check_interval,
        strategy="ddp_find_unused_parameters_true" if cfg.general.gpus > 1 else "auto",
        enable_progress_bar=cfg.general.enable_progress_bar,
        callbacks=callbacks,
        reload_dataloaders_every_n_epochs=0,
        logger=[logger],
    )


@hydra.main(
    version_base="1.1", config_path="../configs", config_name="config"
)
def main(cfg: DictConfig):
    """Hydra entry point: build data, metrics and the MELD model, then train/test.

    - ``cfg.general.test_only`` set: the checkpoint's saved config is loaded
      via ``get_resume``, the working directory is changed to the part of
      the checkpoint path before ``"checkpoints"``, and only testing runs.
    - ``cfg.general.resume`` set: the saved config is merged with the
      current one via ``get_resume_adaptive`` and training resumes from that
      checkpoint, followed by testing.
    - Otherwise: a fresh training run followed by testing.
    """
    seed_everything(cfg.general.seed)

    indices = _load_split_indices(cfg.dataset.guidance_target)
    datamodule = dataset.DataModule(
        cfg, indices["train_index"], indices["val_index"], indices["test_index"]
    )
    datamodule.prepare_data()

    dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
    train_smiles = datamodule.train_dataset.smiles
    train_y = datamodule.train_dataset.y
    reference_smiles = datamodule.test_dataset.smiles
    test_scaffold_smiles = None

    print(f"Number of train: {len(datamodule.train_dataset)}")
    print(f"Number of val: {len(datamodule.val_dataset)}")
    print(f"Number of test: {len(datamodule.test_dataset)}")

    dataset_infos.compute_input_output_dims(datamodule=datamodule)
    train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
    sampling_metrics = SamplingMolecularMetrics(
        dataset_infos,
        train_smiles,
        reference_smiles,
        train_y=train_y,
        n_jobs=1,
        batch_size=256,
    )
    visualization_tools = MolecularVisualization(dataset_infos)

    model_kwargs = {
        "dataset_infos": dataset_infos,
        "train_metrics": train_metrics,
        "sampling_metrics": sampling_metrics,
        "visualization_tools": visualization_tools,
        "num_train_steps": math.ceil(len(train_smiles) / cfg.train.batch_size),
        "test_scaffold_smiles": test_scaffold_smiles,
    }
    # Captured before cfg may be replaced by the checkpoint's config below.
    resume_path = cfg.general.resume

    from masked_diffusion_model import MELD
    model = MELD(cfg=cfg, **model_kwargs)
    model._model_kwargs = model_kwargs

    # NOTE(review): the models returned by get_resume*/load_from_checkpoint
    # are discarded; only their configs are kept. Trainer runs on the model
    # built above from the command-line cfg, with weights restored through
    # ckpt_path. Confirm this is intended.
    if cfg.general.test_only:
        # When testing, previous configuration is fully loaded
        cfg, _ = get_resume(cfg, model, model_kwargs)
        os.chdir(cfg.general.test_only.split("checkpoints")[0])
    elif cfg.general.resume is not None:
        # When resuming, we can override some parts of previous configuration
        cfg, _ = get_resume_adaptive(cfg, model, model_kwargs)
        os.chdir(cfg.general.resume.split("checkpoints")[0])

    # os.environ only accepts strings; skip the override when no key is
    # configured so wandb.login() can use credentials set up elsewhere.
    if cfg.general.wandb_key:
        os.environ["WANDB_API_KEY"] = cfg.general.wandb_key
    wandb.login()
    logger = WandbLogger(project="MELD", name=cfg.general.exp_name)

    callbacks = []
    if cfg.train.use_ema and not cfg.general.test_only:
        callbacks.append(AdaptiveEMACallback(decay=0.999))  # EMA decay rate

    trainer = _build_trainer(cfg, callbacks, logger)

    if cfg.general.test_only:
        trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)
    else:
        trainer.fit(model, datamodule=datamodule, ckpt_path=resume_path)
        if cfg.general.save_model:
            trainer.save_checkpoint(f"checkpoints/{cfg.general.exp_name}/last.ckpt")
        trainer.test(model, datamodule=datamodule)


if __name__ == "__main__":
    main()