import torch
from torch import nn
from torchsde import sdeint

class Model(nn.Module):
    """Toy Ito SDE with constant (all-ones) drift and diffusion.

    torchsde reads the ``sde_type`` / ``noise_type`` class attributes to
    choose the solver contract. With ``noise_type = 'general'`` the
    diffusion returned for each sample is a matrix, so ``g`` has shape
    ``(batch, state, brownian)``.

    Args:
        brownian_size: Number of Brownian motions driving the system, i.e.
            the size of the last dimension of the diffusion returned by
            :meth:`f_and_g`. Defaults to 2, the value previously hard-coded.
    """

    sde_type = 'ito'
    noise_type = 'general'

    def __init__(self, brownian_size=2):
        super().__init__()
        self.brownian_size = brownian_size

    def f_and_g(self, t, y):
        """Return the drift and diffusion at ``(t, y)``.

        Both are all-ones tensors and do not depend on ``t`` or on the
        values of ``y``, only on its shape, dtype and device.

        Args:
            t: Current time (unused).
            y: Current state; assumed 2-D ``(batch, state)``, since the
                diffusion shape is built from ``y.shape[0]`` and
                ``y.shape[1]``.

        Returns:
            ``(f, g)`` where ``f`` has the same shape as ``y`` and ``g`` has
            shape ``(batch, state, brownian_size)``; both share ``y``'s dtype
            and device.
        """
        f = torch.ones_like(y)
        # new_ones allocates directly with y's dtype and device. The previous
        # torch.ones(...).to(y) built a float32 temporary on the default
        # device and then cast/copied it on every solver step.
        g = y.new_ones((y.shape[0], y.shape[1], self.brownian_size))
        return f, g


# Half-precision initial state of shape (16, 4) and a half-precision grid of
# 100 evenly spaced output times on [0, 1].
y0 = torch.ones(16, 4, dtype=torch.half)
ts = torch.linspace(start=0, end=1, steps=100, dtype=torch.half)

# Integrate the toy SDE with step size dt=0.1 and report the dtype of the
# returned solution.
ys = sdeint(Model(), y0, ts, dt=0.1)
print(ys.dtype)

# a = torch.ones((16, 4, 2), dtype=torch.half)
# b = torch.ones((16, 2, 4), dtype=torch.half)

# with torch.autocast(device_type='cuda', dtype=torch.float32):
#     torch.bmm(a, b)
