"""solver.py.

Implements ODE and SDE solvers for the model.

Wraps the torchdyn (ODE) and torchsde (SDE) libraries behind a common solver interface.
"""

from math import prod

import torch
import torchsde
from torchdyn.core import NeuralODE

from .augmentation import AugmentationModule, AugmentedVectorField, Sequential


class TorchSDE(torch.nn.Module):
    """Adapter exposing learned drifts and a noise schedule in the form torchsde expects.

    torchsde reads ``noise_type`` / ``sde_type`` attributes and calls ``f`` (drift),
    ``g`` (diffusion) and ``h`` (prior drift, used when ``logqp=True``).

    When ``reverse`` is True the SDE is integrated in flipped time: solver time ``t``
    corresponds to model time ``1 - t``. Both the drift and the noise schedule are
    evaluated at that flipped time so that they stay consistent.
    """

    def __init__(
        self,
        sigma,
        forward_sde_drift,
        backward_sde_drift,
        noise_type,
        sde_type,
        reverse=False,
    ):
        super().__init__()
        self.sigma = sigma  # callable noise schedule: sigma(t)
        self.forward_sde_drift = forward_sde_drift  # callable (t, y) -> drift
        self.backward_sde_drift = backward_sde_drift  # callable (t, y) -> drift
        self.noise_type = noise_type
        self.sde_type = sde_type
        self.reverse = reverse

    def _model_time(self, t):
        """Map solver time to model time (``1 - t`` when integrating in reverse)."""
        return 1 - t if self.reverse else t

    def f(self, t, y):
        """Drift: backward drift at ``1 - t`` in reverse mode, else forward drift at ``t``."""
        if self.reverse:
            return self.backward_sde_drift(self._model_time(t), y)
        return self.forward_sde_drift(t, y)

    def g(self, t, y):
        """Diagonal diffusion: the noise schedule at model time, broadcast to ``y``'s shape."""
        return self.sigma(self._model_time(t)) * torch.ones_like(y)

    def h(self, t, y):
        """Prior drift (used by torchsde when ``logqp=True``); identically zero here."""
        return torch.zeros_like(y)


class FlowSolver(torch.nn.Module):
    """Integrates a learned flow (and optional score) as an ODE or as an SDE.

    ODE integration is delegated to ``torchdyn.core.NeuralODE`` and SDE integration to
    ``torchsde.sdeint``. For image data (``dim`` is not an int), SDE states are kept
    flattened to ``(batch, prod(dim))`` and reshaped to ``(batch, *dim)`` before the
    networks are called.
    """

    def __init__(
        self,
        vector_field,
        dim,
        augmentations=None,
        score_field=None,
        sigma=None,
        ode_solver="euler",
        sde_solver="euler",
        sde_noise_type="diagonal",
        sde_type="ito",
        dt=0.01,
        dt_min=1e-3,
        atol=1e-5,
        rtol=1e-5,
        **kwargs,
    ):
        r"""Initializes the solver.

        Merges torchdyn with torchsde.

        Args:
            vector_field (torch.nn.Module): Network called as ``vector_field(t, x)``. Its
                output is either the flow alone (same size as ``x`` along dim 1) or, when
                ``score_field`` is None, the flow and score concatenated along dim 1.
            dim (int or tuple): Data dimension. A non-int value (e.g. ``(C, H, W)``)
                marks image data.
            augmentations (torch.nn.Module): Augmentations used by ``odeint``. Not used
                for SDEs.
            score_field (torch.nn.Module): Optional separate score network. Score field
                is -g(t)^2 / 2 \nabla log p(x(t)).
            sigma (noise_schedule): The noise schedule of the SDE, called as ``sigma(t)``.
            ode_solver (str): torchdyn solver name.
            sde_solver (str): torchsde method name.
            sde_noise_type (str): torchsde noise type.
            sde_type (str): torchsde SDE type (e.g. ``"ito"``).
            dt (float): Step size passed to ``torchsde.sdeint``. It is not passed to
                the ODE solver.
            dt_min (float): Minimum step size passed to ``torchsde.sdeint`` (relevant
                in adaptive mode).
            atol (float): Absolute tolerance passed to both the ODE and SDE solvers.
            rtol (float): Relative tolerance passed to both the ODE and SDE solvers.
            **kwargs: Stored on ``self.kwargs``; not otherwise used by this class.
        """
        super().__init__()
        self.net = vector_field
        self.dim = dim
        self.augmentations = augmentations
        self.score_net = score_field
        self.separate_score = score_field is not None
        self.sigma = sigma
        self.ode_solver = ode_solver
        self.sde_solver = sde_solver
        self.sde_noise_type = sde_noise_type
        self.sde_type = sde_type
        self.dt = dt
        self.dt_min = dt_min
        self.atol = atol
        self.rtol = rtol
        self.nfe = 0  # number of drift evaluations since the last reset
        self.kwargs = kwargs
        self.is_image = not isinstance(self.dim, int)
        if self.is_image:
            self.flat_dim = prod(dim)

    def forward_flow_and_score(self, t, x, only_flow=False):
        """Evaluates the flow ``vt`` and, unless ``only_flow``, the score ``st`` at ``(t, x)``.

        Args:
            t: Time passed through to the network(s).
            x: State. For image data it is reshaped to ``(-1, *dim)`` first.
            only_flow (bool): Return only the flow.

        Returns:
            ``vt`` when ``only_flow`` is True, shaped like the network output.
            Otherwise ``(vt, st)``, flattened to ``(batch, flat_dim)`` for image data.

        Raises:
            ValueError: If the score is requested but the (single) network returned
                only a flow.
        """
        if self.is_image:
            x = x.reshape(-1, *self.dim)
        if self.separate_score:
            if only_flow:
                # The score network is not needed; skip evaluating it.
                return self.net(t, x)
            vt, st = self.net(t, x), self.score_net(t, x)
        else:
            vtst = self.net(t, x)
            if vtst.shape[1] == x.shape[1]:
                # The network predicts the flow only; there is no score to split off.
                if not only_flow:
                    raise ValueError(
                        "The vector field output has the same size as the input along "
                        "dim 1, so it contains no score. Provide score_field or request "
                        "only the flow."
                    )
                return vtst
            # Flow and score are concatenated along dim 1.
            split_idx = vtst.shape[1] // 2
            vt, st = vtst[:, :split_idx], vtst[:, split_idx:]
            assert vt.shape == x.shape
        if only_flow:
            return vt
        if self.is_image:
            vt = vt.reshape(-1, self.flat_dim)
            st = st.reshape(-1, self.flat_dim)
        return vt, st

    def forward_sde_drift(self, t, x):
        """Computes the forwards drift of the SDE: ``vt + st``."""
        self.nfe += 1
        vt, st = self.forward_flow_and_score(t, x)
        return vt + st

    def backward_sde_drift(self, t, x):
        """Computes the backwards drift of the SDE: ``-vt + st``."""
        self.nfe += 1
        vt, st = self.forward_flow_and_score(t, x)
        return -vt + st

    def forward_ode_drift(self, t, x):
        """Computes the forwards drift of the ODE: the flow ``vt``."""
        self.nfe += 1
        return self.forward_flow_and_score(t, x, only_flow=True)

    def backward_ode_drift(self, t, x):
        """Computes the backwards drift of the ODE: the negated flow ``-vt``."""
        self.nfe += 1
        return -self.forward_flow_and_score(t, x, only_flow=True)

    def ode_drift(self, reverse=False):
        """Return the backward ODE drift if ``reverse``, else the forward one."""
        return self.forward_ode_drift if not reverse else self.backward_ode_drift

    def sde_drift(self, reverse=False):
        """Return the backward SDE drift if ``reverse``, else the forward one."""
        return self.forward_sde_drift if not reverse else self.backward_sde_drift

    def flat_wrapper(self, func):
        """Adapt ``func(t, x)`` to operate on flattened image states.

        For image data, the returned callable reshapes its ``(batch, flat_dim)`` input
        to ``(batch, *dim)``, calls ``func`` and flattens the result back to
        ``(batch, flat_dim)``. For non-image data ``func`` is returned unchanged.
        """
        if not self.is_image:
            return func

        def wrap(t, x):
            y = func(t, x.reshape(-1, *self.dim))
            return y.reshape(-1, self.flat_dim)

        return wrap

    def sdeint(self, x0, t_span, logqp=False, adaptive=False, reverse=False):
        """Integrates the SDE from ``x0`` over ``t_span`` with ``torchsde.sdeint``.

        Args:
            x0: Initial state. Flattened internally for image data.
            t_span: Times at which the trajectory is returned.
            logqp (bool): Forwarded to torchsde. The log-ratio term is also returned.
            adaptive (bool): Forwarded to torchsde.
            reverse (bool): Integrate the backward SDE (see ``TorchSDE``).

        Returns:
            The trajectory, shaped ``(len(t_span), batch, *dim)`` for image data.
            When ``logqp`` is True, returns ``(trajectory, log_ratio)``.
        """
        self.nfe = 0
        sde = TorchSDE(
            self.sigma,
            self.forward_sde_drift,
            self.backward_sde_drift,
            self.sde_noise_type,
            self.sde_type,
            reverse,
        )
        if self.is_image:
            x0 = x0.reshape(-1, self.flat_dim)
        result = torchsde.sdeint(
            sde,
            x0,
            t_span,
            method=self.sde_solver,
            dt=self.dt,
            dt_min=self.dt_min,
            rtol=self.rtol,
            atol=self.atol,
            logqp=logqp,
            adaptive=adaptive,
        )
        # With logqp=True torchsde returns (ys, log_ratio) rather than just ys.
        traj, log_ratio = result if logqp else (result, None)
        if self.is_image:
            traj = traj.reshape(traj.shape[0], traj.shape[1], *self.dim)
        return (traj, log_ratio) if logqp else traj

    def odeint(self, x0, t_span):
        """Computes the ODE trajectory.

        Relies on the torchdyn library to compute the ODE trajectory and to handle
        reverse t_spans.

        Returns:
            The trajectory when no augmentations are configured. Otherwise a tuple
            ``(traj, aug)`` that splits the augmented state along its last axis at
            ``augmentations.aug_dims``.
        """
        self.nfe = 0

        if self.augmentations is None:
            node = NeuralODE(
                self.forward_ode_drift,
                solver=self.ode_solver,
                atol=self.atol,
                rtol=self.rtol,
                return_t_eval=False,
            )
            return node(x0, t_span)

        aug_dims = self.augmentations.aug_dims
        aug_net = AugmentedVectorField(self.forward_ode_drift, self.augmentations.regs, self.dim)
        node_partial = NeuralODE(
            aug_net,
            solver=self.ode_solver,
            atol=self.atol,
            rtol=self.rtol,
            return_t_eval=False,
        )
        # The augmenter prepends augmentation channels before integration.
        node = Sequential(
            self.augmentations.augmenter,
            node_partial,
        )
        aug_traj = node(x0, t_span)
        aug, traj = aug_traj[:, :, :aug_dims], aug_traj[:, :, aug_dims:]
        return traj, aug

    def get_nfe(self):
        """Return the number of drift evaluations since the last reset."""
        return self.nfe

    def reset_nfe(self):
        """Reset the drift-evaluation counter to zero."""
        self.nfe = 0


class DSBMFlowSolver(FlowSolver):
    """Same as SF2M except interprets net as forward and score_net as backward SDE drifts.

    With a single network (no ``score_field``), the network output is split in half
    along dim 1. The first half is the forward drift and the second half is the
    backward drift.
    """

    def forward_flow_and_score(self, t, x, only_forward=False, only_backward=False):
        """Evaluates the forward and/or backward SDE drifts at ``(t, x)``.

        Args:
            t: Time passed through to the network(s).
            x: State. For image data it is reshaped to ``(-1, *dim)`` first.
            only_forward (bool): Return only the forward drift (takes priority).
            only_backward (bool): Return only the backward drift.

        Returns:
            The requested drift, or ``(forward, backward)``. Outputs are flattened to
            ``(batch, flat_dim)`` for image data.
        """
        if self.is_image:
            x = x.reshape(-1, *self.dim)
        # only_forward takes priority when both flags are set.
        want_forward = only_forward or not only_backward
        want_backward = not only_forward
        if self.separate_score:
            fvt = self.net(t, x) if want_forward else None
            bvt = self.score_net(t, x) if want_backward else None
        else:
            fbvt = self.net(t, x)
            # A single network stacks both drifts along dim 1: (batch, 2 * channels, ...).
            split_idx = fbvt.shape[1] // 2
            fvt, bvt = fbvt[:, :split_idx], fbvt[:, split_idx:]
        if self.is_image:
            fvt = None if fvt is None else fvt.reshape(-1, self.flat_dim)
            bvt = None if bvt is None else bvt.reshape(-1, self.flat_dim)
        if only_forward:
            return fvt
        if only_backward:
            return bvt
        return fvt, bvt

    def forward_sde_drift(self, t, x):
        """Computes the forwards drift of the SDE."""
        self.nfe += 1
        return self.forward_flow_and_score(t, x, only_forward=True)

    def backward_sde_drift(self, t, x):
        """Computes the backwards drift of the SDE."""
        self.nfe += 1
        return self.forward_flow_and_score(t, x, only_backward=True)

    def forward_ode_drift(self, t, x):
        """Computes the forwards drift of the ODE: half the difference of the SDE drifts.

        The result is reshaped to ``x.shape``, so an ODE solver gets back a tensor
        shaped like its state.
        """
        self.nfe += 1
        fvt, bvt = self.forward_flow_and_score(t, x)
        return ((fvt - bvt) / 2).reshape(x.shape)

    def backward_ode_drift(self, t, x):
        """Computes the backwards drift of the ODE (negated forward ODE drift), shaped like ``x``."""
        self.nfe += 1
        fvt, bvt = self.forward_flow_and_score(t, x)
        return (-(fvt - bvt) / 2).reshape(x.shape)
