import jax.numpy as jnp
import numpy as np


# def get_dataset():
#     data = jnp.load("data/ns_tori.npy", allow_pickle=True).item()
#     u_ref = data["u"]
#     v_ref = data["v"]
#     w_ref = data["w"]
#
#     t = data["t"].flatten()
#     x = data["x"].flatten()
#     y = data["y"].flatten()
#     nu = data["viscosity"]
#
#     return u_ref, v_ref, w_ref, t, x, y, nu



def get_dataset(path="data/decaying_flow.npy"):
    """Load the decaying-flow reference data from a pickled ``.npy`` file.

    The file is expected to hold a single dict (as written by ``np.save`` on a
    dict) providing the keys ``"vorticity"``, ``"velocity"``, ``"t"``,
    ``"coords"`` and ``"nu"``.

    Args:
        path: Location of the ``.npy`` file. Defaults to
            ``"data/decaying_flow.npy"``, the path this loader always used.

    Returns:
        Tuple ``(u_ref, v_ref, w_ref, t, coords, nu)``:
        ``u_ref`` / ``v_ref`` are index 0 / 1 of the last axis of
        ``"velocity"``; ``w_ref`` is ``"vorticity"``; ``t`` is the flattened
        ``"t"`` array shifted so that ``t[0] == 0``; ``coords`` is
        ``"coords"``; ``nu`` is ``"nu"`` returned unchanged. All array
        outputs are JAX arrays.
    """
    # The file stores a pickled dict (a 0-d object array), which JAX cannot
    # represent, so read it with plain NumPy instead of going through
    # jnp.load's attempted JAX conversion.
    # NOTE: allow_pickle=True can execute arbitrary code; only load trusted files.
    data = np.load(path, allow_pickle=True).item()

    w_ref = jnp.array(data["vorticity"])
    velocity = jnp.array(data["velocity"])

    # assumes the velocity components are stacked on the last axis as (u, v) -- TODO confirm
    u_ref = velocity[..., 0]
    v_ref = velocity[..., 1]

    # Re-reference time so the returned series starts at t = 0.
    t = jnp.array(data["t"]).flatten()
    t = t - t[0]

    coords = jnp.array(data["coords"])
    nu = data["nu"]

    return u_ref, v_ref, w_ref, t, coords, nu
