import torch

from . import base_sde
from . import misc
from ..settings import SDE_TYPES, NOISE_TYPES
from ..types import TensorOrTensors, Sequence


class AdjointSDE(base_sde.BaseSDE):
    """Augmented adjoint SDE built on top of a forward SDE.

    The state of this SDE is a flat tensor `y_aug`. Its first two parts, with shapes `shapes[0]` and `shapes[1]`, are
    the forward state `y` and its adjoint `adj_y`. The remaining parts presumably hold the adjoints of `params`
    (TODO confirm against the caller that assembles `y_aug`).

    Every method queries the forward SDE at `-t`, i.e. the adjoint is integrated in reversed time.
    """

    def __init__(self,
                 sde: base_sde.ForwardSDE,
                 params: TensorOrTensors,
                 shapes: Sequence[torch.Size]):
        """
        Args:
            sde: The forward SDE whose adjoint is constructed.
            params: Parameters to differentiate with respect to. They are concatenated as `[y] + params` in every
                vjp, so this is expected to be a list of tensors.
            shapes: Shapes used to unflatten the augmented state. Only the first two (for `y` and `adj_y`) are read
                by this class.
        """
        # There's a mapping from the noise type of the forward SDE to the noise type of the adjoint.
        # Usually, these two aren't the same. E.g. when the forward SDE has additive noise, the adjoint SDE's diffusion
        # is a linear function of the adjoint variable, so it is not of additive noise.
        sde_type = sde.sde_type
        noise_type = {
            NOISE_TYPES.general: NOISE_TYPES.general,
            NOISE_TYPES.additive: NOISE_TYPES.general,
            NOISE_TYPES.scalar: NOISE_TYPES.scalar,
            NOISE_TYPES.diagonal: NOISE_TYPES.diagonal,
        }.get(sde.noise_type)
        super(AdjointSDE, self).__init__(sde_type=sde_type, noise_type=noise_type)

        self._base_sde = sde
        self._params = params
        self._shapes = shapes

        # Register the core functions. This avoids polluting the codebase with if-statements and achieves speed-ups
        # by making sure it's a one-time cost. The `sde_type` and `noise_type` of the forward SDE determine the
        # registered functions.
        self.f = {
            SDE_TYPES.ito: {
                NOISE_TYPES.diagonal: self.f_corrected_diagonal,
                NOISE_TYPES.additive: self.f_uncorrected,
                NOISE_TYPES.scalar: self.f_corrected_default,
                NOISE_TYPES.general: self.f_corrected_default
            }.get(sde.noise_type),
            SDE_TYPES.stratonovich: self.f_uncorrected
        }.get(sde.sde_type)
        self.gdg_prod = {
            NOISE_TYPES.diagonal: self.gdg_prod_diagonal,
        }.get(sde.noise_type, self.gdg_prod_default)

    def _get_state(self, t, y_aug, v=None):
        """Unpacks y_aug, whilst enforcing the necessary checks so that we can calculate derivatives wrt state.

        Returns:
            A tuple `(y, adj_y, requires_grad)`:
            - `y` is guaranteed to require grad.
            - `adj_y` is the adjoint slice of `y_aug`.
            - `requires_grad` is the grad mode that was active on entry.
        """

        # These leaf checks are very important.
        # _get_state is used where we want to compute:
        # ```
        # with torch.enable_grad():
        #     s = some_function(y)
        #     torch.autograd.grad(s, [y] + params, ...)
        # ```
        # where `some_function` implicitly depends on `params`.
        # However if y has history of its own then in principle it could _also_ depend upon `params`, and this call to
        # `grad` will go all the way back to that. To avoid this, we require that every input tensor be a leaf tensor.
        #
        # This is also the reason for the `y0.detach()` in adjoint.py::_SdeintAdjointMethod.forward. If we don't detach,
        # then y0 may have a history and these checks will fail. This is a spurious failure as
        # `torch.autograd.Function.forward` has an implicit `torch.no_grad()` guard, i.e. we definitely don't want to
        # use its history there.
        assert t.is_leaf, "Internal error: please report a bug to torchsde"
        assert y_aug.is_leaf, "Internal error: please report a bug to torchsde"
        if v is not None:
            assert v.is_leaf, "Internal error: please report a bug to torchsde"

        # Callers use this to decide whether to build a differentiable graph (`create_graph`) or detach the results.
        requires_grad = torch.is_grad_enabled()

        y, adj_y = misc.flat_to_shape(y_aug, self._shapes[:2])

        # To support the later differentiation wrt y, we set it to require_grad if it doesn't already.
        if not y.requires_grad:
            y = y.detach().requires_grad_(True)
        return y, adj_y, requires_grad

    ########################################
    #                  f                   #
    ########################################

    def f_uncorrected(self, t, y_aug):  # For Ito additive and Stratonovich.
        """Drift of the augmented adjoint SDE, with no Itô correction.

        Returns:
            The flat concatenation of `-f(-t, y)` and the vjp `adj_y^T d f / d[y, params]`.
        """
        y, adj_y, requires_grad = self._get_state(t, y_aug)
        with torch.enable_grad():
            f = self._base_sde.f(-t, y)
            vjp_y_and_params = misc.vjp(
                outputs=f,
                inputs=[y] + self._params,
                grad_outputs=adj_y,
                allow_unused=True,
                create_graph=requires_grad
            )
            if not requires_grad:
                # We had to build a computational graph to be able to compute the above vjp.
                # However, if we don't require_grad then we don't need to backprop through this function, so we should
                # delete the computational graph to avoid a memory leak. (Which for example would keep the local
                # variable `y` in memory: f->grad_fn->...->AccumulateGrad->y.)
                f = f.detach()
        return misc.flatten((-f, *vjp_y_and_params))

    def f_corrected_default(self, t, y_aug):  # For Ito general/scalar.
        """Drift of the augmented adjoint SDE for Itô SDEs with general or scalar noise.

        The diffusion `g(-t, y)` is split into its noise columns `g_j` along the last dimension. The forward drift is
        corrected by `sum_j (d g_j / d y) g_j`. A per-column extra vjp then converts the Stratonovich adjoint to Itô
        form.

        Returns:
            The flat concatenation of `-f_corrected` and the vjps with respect to `[y] + params`.
        """
        y, adj_y, requires_grad = self._get_state(t, y_aug)
        with torch.enable_grad():
            g_columns = [g_column.squeeze(dim=-1) for g_column in self._base_sde.g(-t, y).split(1, dim=-1)]
            dg_g_jvp = sum(
                misc.jvp(
                    outputs=g_column,
                    inputs=y,
                    grad_inputs=g_column,
                    allow_unused=True,
                    create_graph=True
                )[0] for g_column in g_columns
            )
            # Double Stratonovich correction.
            f = self._base_sde.f(-t, y) - dg_g_jvp
            vjp_y_and_params = misc.vjp(
                outputs=f,
                inputs=[y] + self._params,
                grad_outputs=adj_y,
                allow_unused=True,
                retain_graph=True,
                create_graph=requires_grad
            )
            # Convert the adjoint Stratonovich SDE to Itô form.
            extra_vjp_y_and_params = []
            for g_column in g_columns:
                a_dg_vjp, = misc.vjp(
                    outputs=g_column,
                    inputs=y,
                    grad_outputs=adj_y,
                    allow_unused=True,
                    retain_graph=True,
                    create_graph=requires_grad
                )
                # Every column is a view of the same `g` evaluation, so they all share one autograd graph.
                # `torch.autograd.grad` defaults `retain_graph` to `create_graph` (`misc.vjp` presumably forwards these
                # kwargs). When `requires_grad` is False, that default would free the shared graph after the first
                # column, and the next column's backward pass would then fail. Keep it alive explicitly; it is released
                # once these locals go out of scope.
                extra_vjp_y_and_params_column = misc.vjp(
                    outputs=g_column,
                    inputs=[y] + self._params,
                    grad_outputs=a_dg_vjp,
                    allow_unused=True,
                    retain_graph=True,
                    create_graph=requires_grad
                )
                extra_vjp_y_and_params.append(extra_vjp_y_and_params_column)
            vjp_y_and_params = misc.seq_add(vjp_y_and_params, *extra_vjp_y_and_params)
            if not requires_grad:
                # See corresponding note in f_uncorrected.
                f = f.detach()
        return misc.flatten((-f, *vjp_y_and_params))

    def f_corrected_diagonal(self, t, y_aug):  # For Ito diagonal.
        """Drift of the augmented adjoint SDE for Itô SDEs with diagonal noise.

        The forward drift is corrected by the vjp of `g` against itself. An extra vjp of `g` (weighted by
        `adj_y^T dg/dy`) then converts the Stratonovich adjoint to Itô form.

        Returns:
            The flat concatenation of `-f_corrected` and the vjps with respect to `[y] + params`.
        """
        y, adj_y, requires_grad = self._get_state(t, y_aug)
        with torch.enable_grad():
            g = self._base_sde.g(-t, y)
            g_dg_vjp, = misc.vjp(
                outputs=g,
                inputs=y,
                grad_outputs=g,
                allow_unused=True,
                create_graph=True
            )
            # Double Stratonovich correction.
            f = self._base_sde.f(-t, y) - g_dg_vjp
            vjp_y_and_params = misc.vjp(
                outputs=f,
                inputs=[y] + self._params,
                grad_outputs=adj_y,
                allow_unused=True,
                retain_graph=True,
                create_graph=requires_grad
            )
            # Convert the adjoint Stratonovich SDE to Itô form.
            a_dg_vjp, = misc.vjp(
                outputs=g,
                inputs=y,
                grad_outputs=adj_y,
                allow_unused=True,
                retain_graph=True,
                create_graph=requires_grad
            )
            # Last backward pass through `g`'s graph, so the default `retain_graph` is fine here.
            extra_vjp_y_and_params = misc.vjp(
                outputs=g,
                inputs=[y] + self._params,
                grad_outputs=a_dg_vjp,
                allow_unused=True,
                create_graph=requires_grad
            )
            vjp_y_and_params = misc.seq_add(vjp_y_and_params, extra_vjp_y_and_params)
            if not requires_grad:
                # See corresponding note in f_uncorrected.
                f = f.detach()
        return misc.flatten((-f, *vjp_y_and_params))

    ########################################
    #                  g                   #
    ########################################

    def g(self, t, y):
        # We don't want to define it, it's super inefficient to compute.
        # In theory every part of the code which _could_ call it either does something else, or has some more
        # informative error message to tell the user what went wrong.
        # This is here as a fallback option.
        raise RuntimeError("Adjoint `g` not defined. Please report a bug to torchsde.")

    ########################################
    #                g_prod                #
    ########################################

    def g_prod(self, t, y_aug, v):
        """Diffusion of the augmented adjoint SDE applied to the noise `v`.

        Returns:
            The flat concatenation of `-g_prod(-t, y, v)` and the vjp `adj_y^T d g_prod / d[y, params]`.
        """
        y, adj_y, requires_grad = self._get_state(t, y_aug, v)
        with torch.enable_grad():
            g_prod = self._base_sde.g_prod(-t, y, v)
            vjp_y_and_params = misc.vjp(
                outputs=g_prod,
                inputs=[y] + self._params,
                grad_outputs=adj_y,
                allow_unused=True,
                create_graph=requires_grad
            )
            if not requires_grad:
                # See corresponding note in f_uncorrected.
                g_prod = g_prod.detach()
        return misc.flatten((-g_prod, *vjp_y_and_params))

    ########################################
    #               gdg_prod               #
    ########################################

    def gdg_prod_default(self, t, y, v):  # For Ito/Stratonovich general/additive/scalar.
        """Not available: `gdg_prod` is only registered for diagonal noise (see `__init__`)."""
        raise NotImplementedError("Adjoint `gdg_prod` is only implemented for diagonal noise.")

    def gdg_prod_diagonal(self, t, y_aug, v):  # For Ito/Stratonovich diagonal.
        """`g * dg/dy * v`-type product of the augmented adjoint SDE, for diagonal noise.

        Returns:
            The flat concatenation of:
            - the `y`-component: the vjp of `g` with cotangent `v * g`;
            - the adjoint/param components: `prod_partials - mixed_partials` (see comments below).
        """
        y, adj_y, requires_grad = self._get_state(t, y_aug, v)
        with torch.enable_grad():
            g = self._base_sde.g(-t, y)
            vg_dg_vjp, = misc.vjp(
                outputs=g,
                inputs=y,
                grad_outputs=v * g,
                allow_unused=True,
                retain_graph=True,
                create_graph=requires_grad
            )
            # Gradient of g.sum() wrt y. When each g_i depends only on y_i, this is the diagonal of dg/dy.
            dgdy, = misc.vjp(
                outputs=g.sum(),
                inputs=y,
                allow_unused=True,
                retain_graph=True,
                create_graph=requires_grad
            )
            # First-derivative term: vjp of g with cotangent adj_y * v * dgdy.
            prod_partials_adj_y_and_params = misc.vjp(
                outputs=g,
                inputs=[y] + self._params,
                grad_outputs=adj_y * v * dgdy,
                allow_unused=True,
                retain_graph=True,
                create_graph=requires_grad
            )
            # Second-derivative term: differentiate the vjp of g (with cotangent adj_y * v * g) once more. The
            # cotangent is detached so that only the derivative of dg/dy itself is taken; create_graph=True is
            # required for that second differentiation.
            avg_dg_vjp, = misc.vjp(
                outputs=g,
                inputs=y,
                grad_outputs=(adj_y * v * g).detach(),
                allow_unused=True,
                create_graph=True
            )
            mixed_partials_adj_y_and_params = misc.vjp(
                outputs=avg_dg_vjp.sum(),
                inputs=[y] + self._params,
                allow_unused=True,
                create_graph=requires_grad
            )
            vjp_y_and_params = misc.seq_sub(prod_partials_adj_y_and_params, mixed_partials_adj_y_and_params)
        return misc.flatten((vg_dg_vjp, *vjp_y_and_params))
