import torch
from torchsde import sdeint, SDEIto


class Lorenz(SDEIto):
    """Stochastic Lorenz system driven by constant diagonal (additive) noise.

    The drift is the Lorenz vector field
    ``(sigma * (y - x), x * (rho - z) - y, x * y - beta * z)`` acting on the
    last axis of the state. Every component receives the same diffusion
    coefficient ``alpha``.
    """

    def __init__(self, alpha=0.15, beta=8.0 / 3.0, rho=28, sigma=10):
        super().__init__("diagonal")
        # Noise scale first, then the three Lorenz parameters.
        self.alpha, self.beta, self.rho, self.sigma = alpha, beta, rho, sigma

    def f(self, t, x):
        """Drift term; ``t`` is unused because the system is autonomous."""
        u = x[..., 0]
        v = x[..., 1]
        w = x[..., 2]

        du = self.sigma * (v - u)
        dv = u * (self.rho - w) - v
        dw = u * v - self.beta * w

        return torch.stack([du, dv, dw], dim=-1)

    def g(self, t, x):
        """Diffusion term: a tensor shaped like ``x`` filled with ``alpha``."""
        return torch.full_like(x, fill_value=self.alpha)


class VanDerPol(SDEIto):
    """Stochastic Van der Pol oscillator driven by constant diagonal noise.

    The state holds ``(x, y)`` on its last axis, with ``x' = y`` and
    ``y' = mu * (1 - x**2) * y - x``. ``mu`` controls the nonlinear damping
    and ``alpha`` is the diffusion coefficient shared by both components.
    """

    def __init__(self, alpha=0.15, mu=2.0):
        super().__init__("diagonal")
        self.alpha, self.mu = alpha, mu

    def f(self, t, x):
        """Drift term; ``t`` is unused because the system is autonomous."""
        pos = x[..., 0]
        vel = x[..., 1]

        # Time derivative of the second component (nonlinear damping term).
        accel = self.mu * (1 - pos * pos) * vel - pos

        return torch.stack([vel, accel], dim=-1)

    def g(self, t, x):
        """Diffusion term: a tensor shaped like ``x`` filled with ``alpha``."""
        return torch.full_like(x, fill_value=self.alpha)


@torch.no_grad()
def make_lorenz_paths(times, batches, *, sde_kwargs=None, **kwargs):
    """Simulate sample paths of the stochastic Lorenz system without autograd.

    Args:
        times: Evaluation times, passed to ``sdeint`` as ``ts``.
        batches: Number of independent paths. Initial states are drawn from a
            standard normal, one 3-dimensional state per path.
        sde_kwargs: Optional keyword arguments for ``Lorenz`` (e.g. ``alpha``,
            ``rho``). When omitted, the default ``Lorenz()`` system is used.
        **kwargs: Forwarded unchanged to ``sdeint`` (e.g. ``method``, ``dt``).

    Returns:
        The result of ``sdeint``. With default ``sdeint`` options this is a
        tensor of shape ``(len(times), batches, 3)``.
    """
    # Allocate the initial state on the same device as ``times`` (when it is a
    # tensor), so the integration runs where the caller's data lives instead
    # of always falling back to the default device.
    device = times.device if torch.is_tensor(times) else None
    x0 = torch.randn(batches, 3, device=device)
    sde = Lorenz(**(sde_kwargs or {}))
    return sdeint(sde, x0, times, **kwargs)


@torch.no_grad()
def make_van_der_pol_paths(times, batches, *, sde_kwargs=None, **kwargs):
    """Simulate sample paths of the stochastic Van der Pol oscillator without autograd.

    Args:
        times: Evaluation times, passed to ``sdeint`` as ``ts``.
        batches: Number of independent paths. Initial states are drawn from a
            standard normal, one 2-dimensional state per path.
        sde_kwargs: Optional keyword arguments for ``VanDerPol`` (e.g.
            ``alpha``, ``mu``). When omitted, the default ``VanDerPol()``
            system is used.
        **kwargs: Forwarded unchanged to ``sdeint`` (e.g. ``method``, ``dt``).

    Returns:
        The result of ``sdeint``. With default ``sdeint`` options this is a
        tensor of shape ``(len(times), batches, 2)``.
    """
    # Allocate the initial state on the same device as ``times`` (when it is a
    # tensor), so the integration runs where the caller's data lives instead
    # of always falling back to the default device.
    device = times.device if torch.is_tensor(times) else None
    x0 = torch.randn(batches, 2, device=device)
    sde = VanDerPol(**(sde_kwargs or {}))
    return sdeint(sde, x0, times, **kwargs)
