from torch import nn, Tensor


class SdeIntLoss(nn.Module):
    """Score SDE sample paths, started from observed states, against observations.

    On each call the observations are tiled ``n_samples`` times along
    dimension 1. The SDE is integrated with ``torchsde`` starting from index 0
    of the tiled tensor, and ``obs_loss`` compares the resulting paths with the
    tiled observations.

    Args:
        obs_loss: Module invoked as ``obs_loss(simulated_paths, tiled_obs)``.
        n_samples: Number of times the observations are tiled along dim 1.
        use_adjoint: When true, integrate with ``torchsde.sdeint_adjoint``;
            otherwise with ``torchsde.sdeint``.
        **kwargs: Extra keyword arguments forwarded unchanged to the solver.
    """

    def __init__(self, obs_loss: nn.Module, n_samples=100, use_adjoint=False, **kwargs):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a child submodule,
        # so this name appears in state_dict keys.
        self.obs_loss = obs_loss
        self.n_samples = n_samples
        self.use_adjoint = use_adjoint
        # Solver options, splatted into every integration call.
        self.kwargs = kwargs

    def forward(self, sde: nn.Module, times: Tensor, obs: Tensor):
        """Integrate ``sde`` from the observed initial states and return the loss.

        Args:
            sde: The SDE module handed to the torchsde solver.
            times: Evaluation times handed to the solver.
            obs: Observations. NOTE(review): presumably laid out as
                (time, batch, state), given that index 0 of dim 0 seeds the
                solver and dim 1 is tiled. TODO confirm against callers.

        Returns:
            Whatever ``obs_loss`` returns for the simulated paths and the
            tiled observations.
        """
        # torchsde is imported here rather than at module level, presumably
        # so this file can be imported without torchsde installed.
        if not self.use_adjoint:
            from torchsde import sdeint as integrate
        else:
            from torchsde import sdeint_adjoint as integrate

        tiled_obs = obs.repeat(1, self.n_samples, 1)
        # The first slice along dim 0 of the tiled tensor is the initial state.
        initial_state = tiled_obs[0, :, :]
        simulated = integrate(sde, initial_state, times, **self.kwargs)
        return self.obs_loss(simulated, tiled_obs)
