import logging
import os
import pickle
from collections.abc import Mapping

import torch
from tqdm import tqdm
import numpy as np
from pytorch_lightning.callbacks import Callback
from schnetpack.diffusion.utils import sample_R, sample_R_time
from schnetpack.diffusion.sampling_analysis import check_validity, generate_bonds_data
from schnetpack.diffusion.noise_schedule import NoiseSchedule
from schnetpack import properties


class OutputWriterCallback(Callback):
    """
    Callback to store the inputs and outputs of every train/validation batch
    using ``torch.save``.

    Each batch is written as a dict ``{"inputs": batch, "outputs": outputs}`` to
    ``<output_dir>/train_<epoch>/<batch_idx>.pt`` or
    ``<output_dir>/val_<epoch>/<batch_idx>.pt``. For validation dataloaders other
    than the first, the file name is ``<dataloader_idx>_<batch_idx>.pt`` so that
    their files do not overwrite those of dataloader 0.

    A warning is logged whenever the monitored loss (``outputs["loss"]`` for
    training, ``outputs["val_loss"]`` for validation) contains a NaN/inf entry or
    a value above ``explosion_threshold``.
    """

    def __init__(self, output_dir: str, explosion_threshold: float = 1e4):
        """
        Args:
            output_dir: output directory for prediction files; created if it
                does not exist.
            explosion_threshold: loss values strictly greater than this are
                reported as exploded (default ``1e4``).
        """
        super().__init__()
        self.output_dir = output_dir
        self.explosion_threshold = explosion_threshold
        os.makedirs(output_dir, exist_ok=True)

    def _is_exploded(self, loss) -> bool:
        """Return True if ``loss`` has a non-finite entry or one above the threshold."""
        # as_tensor accepts plain Python floats as well as tensors; detach so the
        # check never touches the autograd graph of a training loss.
        loss = torch.as_tensor(loss).detach()
        # ~isfinite covers both NaN and +/-inf, matching the original
        # isnan/isinf checks.
        return bool((~torch.isfinite(loss) | (loss > self.explosion_threshold)).any())

    def _write_batch(
        self, prefix, loss_key, trainer, outputs, batch, batch_idx, dataloader_idx=0
    ):
        """Save one batch under ``<output_dir>/<prefix>_<epoch>/`` and check its loss."""
        bdir = os.path.join(self.output_dir, f"{prefix}_{trainer.current_epoch}")
        os.makedirs(bdir, exist_ok=True)

        # Keep the historical "<batch_idx>.pt" name for the first dataloader.
        # batch_idx restarts at 0 for every dataloader, so later ones need a
        # distinguishing prefix or they would overwrite dataloader 0's files.
        if dataloader_idx == 0:
            fname = f"{batch_idx}.pt"
        else:
            fname = f"{dataloader_idx}_{batch_idx}.pt"

        # outputs may be a dict, a bare tensor, or None (e.g. a skipped step);
        # only run the explosion check when a loss is actually available.
        if isinstance(outputs, Mapping):
            loss = outputs.get(loss_key)
        else:
            loss = outputs
        if loss is not None and self._is_exploded(loss):
            logging.warning(
                "WARNING: exploded output in folder %s batch_idx %s", bdir, batch_idx
            )

        torch.save({"inputs": batch, "outputs": outputs}, os.path.join(bdir, fname))

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Save the training batch and its outputs; monitor ``outputs["loss"]``."""
        self._write_batch("train", "loss", trainer, outputs, batch, batch_idx)

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        """Save the validation batch and its outputs; monitor ``outputs["val_loss"]``."""
        self._write_batch(
            "val", "val_loss", trainer, outputs, batch, batch_idx, dataloader_idx
        )
