import os

import numpy as np


# Load the raw dataset.
# NOTE(review): the indexing below requires at least three axes and groups
# values into rows of 3 -- presumably the array is laid out as
# (num_sequences, num_steps, 3). Confirm against whatever writes raw_data.npy.
raw_data = np.load("./raw_data.npy")

# Collapse all leading axes so that every row is one 3-component state.
all_states = raw_data.reshape(-1, 3)
# Residual per sequence: last entry along axis 1 minus the first entry.
all_residuals = raw_data[:, -1, :] - raw_data[:, 0, :]

# np.save does not create missing parent directories, so make sure the output
# directory exists before any statistics are written.
os.makedirs("./stats", exist_ok=True)

# Compute stats for states (reduced over rows, one value per component)
std_x = np.std(all_states, axis=0)
np.save("./stats/std_x.npy", std_x)
mean_x = np.mean(all_states, axis=0)
np.save("./stats/mean_x.npy", mean_x)

# Compute stats for residuals (only the standard deviation is stored)
std_z = np.std(all_residuals, axis=0)
np.save("./stats/std_z.npy", std_z)
