import torch


class SDE(torch.nn.Module):
    """Scalar-noise Itô SDE with linear drift and constant diffusion.

    The dynamics encoded by this module are::

        dy = (μ * t - θ * y) dt + σ dW

    where ``W`` is a single (scalar) Brownian motion shared by every state
    dimension.  The class attributes ``sde_type`` / ``noise_type`` and the
    method names ``f`` / ``g`` follow the torchsde solver convention
    (presumably this module is passed to ``torchsde.sdeint`` -- confirm
    against the caller).

    Args:
        μ: Coefficient multiplying time ``t`` in the drift.
        θ: Coefficient multiplying the state ``y`` in the drift.
        σ: Constant diffusion magnitude.  Must be expandable to
            ``(batch, state, 1)``; a Python scalar or 0-dim tensor works.
    """

    sde_type = 'ito'
    noise_type = 'scalar'

    def __init__(self, μ, θ, σ):
        super().__init__()
        # Buffers (rather than plain attributes) so that the coefficients
        # follow the module across ``.to(device)`` and appear in state_dict,
        # while staying out of ``parameters()``.
        self.register_buffer('μ', torch.as_tensor(μ))
        self.register_buffer('θ', torch.as_tensor(θ))
        self.register_buffer('σ', torch.as_tensor(σ))
        self.brownian_size = 1

    def f(self, t, y):
        """Return the drift ``μ * t - θ * y`` evaluated at time ``t``, state ``y``."""
        return self.μ * t - self.θ * y

    def g(self, t, y):
        """Return the constant diffusion ``σ`` expanded to ``(batch, state, 1)``.

        The batch and state sizes are read from ``y`` so that the result
        lines up with the drift for any state dimension; the trailing axis
        of size 1 is the single Brownian channel.  A 1-D ``y`` is treated
        as having state size 1.
        """
        batch_size = y.size(0)
        # Fall back to a single state dimension when y carries no explicit
        # state axis, which keeps the historical (batch, 1, 1) output.
        state_size = y.size(1) if y.dim() > 1 else 1
        return self.σ.expand(batch_size, state_size, 1)


