import torch
from torch import nn
import inspect
import logging

from schnetpack.task import AtomisticTask, ModelOutput, UnsupervisedModelOutput
from schnetpack.transform import Transform
from schnetpack.diffusion.utils import sample_R
from schnetpack.diffusion.loss import nll
from schnetpack.diffusion import NoiseSchedule
from schnetpack import properties


# Module-level logger; used below to warn when a training batch is skipped.
log = logging.getLogger(__name__)


class DiffusionTask(AtomisticTask):
    """Atomistic task for diffusion models.

    Compared to the parent task, predictions are computed with only those
    post-processors that carry a truthy ``force_apply`` attribute, training
    batches with a non-finite or exploding loss can be skipped, and the terms
    returned by ``nll`` can optionally be logged for every subset.
    """

    def __init__(
        self,
        skip_exploding_batches: bool = True,
        noise_schedule: NoiseSchedule = None,
        log_nll: bool = False,
        **kwargs,
    ):
        """
        Args:
            skip_exploding_batches: if True, ``training_step`` returns ``None``
                when the loss is NaN, infinite or larger than 1e10.
            noise_schedule: noise schedule handed to ``nll``; required when
                ``log_nll`` is True.
            log_nll: if True, log the terms returned by ``nll`` in every step.
            **kwargs: forwarded unchanged to ``AtomisticTask.__init__``.

        Raises:
            ValueError: if ``log_nll`` is True but no noise schedule is given.
        """
        super().__init__(**kwargs)
        self.skip_exploding_batches = skip_exploding_batches
        if log_nll and noise_schedule is None:
            raise ValueError(
                "You need to provide a noise schedule for the task module if you want to log the variational lower bound!"
            )
        self._log_nll = log_nll
        self.noise_schedule = noise_schedule

    def setup(self, stage=None):
        """Run the parent setup and collect the forced post-processors.

        Post-processors of the model with a truthy ``force_apply`` attribute are
        stored as ``self.model.forced_postprocessors``.
        """
        AtomisticTask.setup(self, stage=stage)
        # force some post-processing transforms during training
        forced_postprocessors = [
            pp
            for pp in self.model.postprocessors
            if getattr(pp, "force_apply", False)
        ]
        self.model.forced_postprocessors = nn.ModuleList(forced_postprocessors)

    def predict_without_postprocessing(self, batch):
        """Run the forward pass with only the forced post-processors active.

        The full post-processor list is put back afterwards, also when the
        forward pass raises, so a failing batch cannot leave the model in the
        reduced configuration.
        """
        saved_postprocessors = self.model.postprocessors
        self.model.postprocessors = self.model.forced_postprocessors
        try:
            return self(batch)
        finally:
            self.model.postprocessors = saved_postprocessors

    def forward_t0(self, batch):
        """Run an additional forward pass for t=0 to compute L_0 in the VLB.

        Adds ``"eps_0"`` and ``"eps_0_pred"`` to ``batch`` and restores the
        entries that were temporarily overwritten for the extra pass.

        Returns:
            The same ``batch`` object, modified in place.
        """
        # tmp save the initial (relevant) predictions for L_t
        # NOTE(review): assumes an earlier forward pass already stored
        # "eps_pred" in the batch -- confirm against the model.
        tmp_diff_step = batch["diff_step"].clone()
        tmp_eps_pred = batch["eps_pred"].clone()
        tmp_R = batch[properties.R].clone()

        # feed-forward for t=0
        # This will override the neighbors and distances computed for previous t
        # first index along dim 1 of batch["t"] where all entries along dim 0 are zero
        idx_t0 = torch.where((batch["t"] == 0).all(0))[0][0]
        batch[properties.R] = batch["all_diff_R"][:, idx_t0]
        batch["diff_step"] = torch.zeros_like(batch["diff_step"])
        pred = self.predict_without_postprocessing(batch)

        # restore the initial (relevant) predictions for L_t
        batch["eps_0"] = batch["eps_all"][:, idx_t0]
        batch["eps_0_pred"] = pred["eps_pred"]
        batch["diff_step"] = tmp_diff_step
        batch["eps_pred"] = tmp_eps_pred
        batch[properties.R] = tmp_R

        return batch

    def log_nll(self, batch, subset):
        """Log every term returned by ``nll`` as ``"{subset}_{name}"``.

        Terms are logged per step for ``subset == "train"`` and per epoch for
        any other subset.
        """
        nll_terms = nll(
            batch, self.noise_schedule, include_l0=True, include_lT=True, training=False
        )
        is_train = subset == "train"
        for k, v in nll_terms.items():
            self.log(
                f"{subset}_{k}",
                v,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=False,
            )

    def _compute_loss(self, batch):
        """Build the targets, predict, apply constraints and evaluate the loss.

        Returns:
            Tuple ``(loss, pred, targets)`` where ``pred`` and ``targets`` are
            the values returned by ``apply_constraints``.
        """
        targets = {
            output.target_property: batch[output.target_property]
            for output in self.outputs
            if not isinstance(output, UnsupervisedModelOutput)
        }
        # "considered_atoms" is optional; only its absence is tolerated
        try:
            targets["considered_atoms"] = batch["considered_atoms"]
        except KeyError:
            pass

        pred = self.predict_without_postprocessing(batch)
        pred, targets = self.apply_constraints(pred, targets)
        loss = self.loss_fn(pred, targets)
        return loss, pred, targets

    def _maybe_log_nll(self, batch, subset):
        """Log the ``nll`` terms for ``subset`` if NLL logging is enabled."""
        if not self._log_nll:
            return
        # this needs an extra forward pass to get the correct eps_0 and eps_0_pred
        # and will override the previous inputs of l_t
        batch = self.forward_t0(batch)
        self.log_nll(batch, subset)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Returns:
            The loss, or ``None`` if ``skip_exploding_batches`` is set and the
            loss is NaN, infinite or larger than 1e10.
        """
        loss, pred, targets = self._compute_loss(batch)

        self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=False)
        self.log_metrics(pred, targets, "train")
        self._maybe_log_nll(batch, "train")

        # not finite covers NaN as well as +/- infinity
        if self.skip_exploding_batches and (
            not torch.isfinite(loss) or loss > 1e10
        ):
            log.warning(
                f"Loss is {loss} for train batch_idx {batch_idx} and training step {self.global_step}, training step will be skipped!"
            )
            return None

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch.

        Returns:
            Dict with the single key ``"val_loss"``.
        """
        torch.set_grad_enabled(self.grad_enabled)

        loss, pred, targets = self._compute_loss(batch)

        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log_metrics(pred, targets, "val")
        self._maybe_log_nll(batch, "val")

        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for one batch.

        Returns:
            Dict with the single key ``"test_loss"``.
        """
        torch.set_grad_enabled(self.grad_enabled)

        loss, pred, targets = self._compute_loss(batch)

        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log_metrics(pred, targets, "test")
        self._maybe_log_nll(batch, "test")

        return {"test_loss": loss}


class DiffModelOutput(ModelOutput):
    """Model output whose loss function may receive extra prediction entries."""

    def calculate_loss(self, pred, target):
        """Evaluate the weighted loss for this output.

        Every argument name of ``self.loss_fn`` beyond the first two that is
        also a key of ``pred`` is passed to the loss function as a keyword
        argument.

        Args:
            pred: mapping of predictions, indexed by ``self.name``.
            target: mapping of targets, indexed by ``self.target_property``.

        Returns:
            ``0.0`` if the loss weight is zero or no loss function is set,
            otherwise ``self.loss_weight`` times the loss function value.
        """
        if self.loss_weight == 0 or self.loss_fn is None:
            return 0.0

        # argument names following the first two entries of the signature
        extra_names = inspect.getfullargspec(self.loss_fn).args[2:]

        extras = {}
        for arg_name in extra_names:
            if arg_name in pred:
                extras[arg_name] = pred[arg_name]

        # unpacking an empty dict is the same as passing no keyword arguments
        return self.loss_weight * self.loss_fn(
            pred[self.name], target[self.target_property], **extras
        )
