import torch


class SDE(torch.nn.Module):
    """Square-root diffusion SDE driven by a single Brownian motion.

    In Ito form the state evolves as

        dy = a * (b - y) dt + G(y) dW,

    where G(y) has shape (state_size, 1) and is zero except for its first
    entry, ``σ * sqrt(y[:, 0])``. The drift therefore acts on every state
    component, but only component 0 receives noise. This has the shape of a
    Cox–Ingersoll–Ross process.

    The ``noise_type``/``sde_type`` class attributes and the ``f``/``g``
    method names follow the torchsde interface. Presumably this object is
    handed to a torchsde solver; confirm against callers.

    Args:
        a: Coefficient multiplying ``(b - y)`` in the drift.
        b: Level the drift pulls ``y`` toward (when ``a > 0``).
        σ: Scale applied to ``sqrt(y[:, 0])`` in the diffusion.
    """

    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self, a, b, σ):
        super().__init__()
        self.a = a
        self.b = b
        self.σ = σ
        # Size of the Brownian motion; the last axis of g's output has this length.
        self.brownian_size = 1

    def f(self, t, y):
        """Return the drift ``a * (b - y)``.

        Args:
            t: Time (unused; the drift does not depend on t).
            y: State tensor of shape (batch_size, state_size).

        Returns:
            Tensor of shape (batch_size, state_size).
        """
        return self.a * (self.b - y)

    def g(self, t, y):
        """Return the diffusion matrix, with noise on state component 0 only.

        Args:
            t: Time (unused).
            y: State tensor of shape (batch_size, state_size).

        Returns:
            Tensor of shape (batch_size, state_size, brownian_size), on the
            same device and dtype as ``y``. It is zero everywhere except
            ``[:, 0, 0] = σ * sqrt(y[:, 0])``.
        """
        # Allocate alongside y. A bare torch.zeros(...) lands on the CPU in the
        # default dtype, which breaks GPU runs and silently downcasts float64.
        g_truth = y.new_zeros(y.shape[0], y.shape[1], self.brownian_size)
        # Vectorised over the batch instead of a per-row Python loop.
        # NOTE(review): sqrt of a negative state yields NaN. A discretised
        # square-root process can step below zero, so callers may want to
        # clamp y (e.g. full truncation); confirm the intended handling.
        g_truth[:, 0, 0] = self.σ * torch.sqrt(y[:, 0])
        return g_truth

