import numpy as np

import openpi.shared.normalize as normalize


def test_normalize_update():
    arr = np.arange(12)

    stats = normalize.RunningStats()
    for i in range(0, len(arr), 3):
        stats.update(arr[i : i + 3])
    results = stats.get_statistics()

    assert np.allclose(results.mean, np.mean(arr))
    assert np.allclose(results.std, np.std(arr))


def test_serialize_deserialize():
    stats = normalize.RunningStats()
    stats.update(np.arange(12))

    norm_stats = {"test": stats.get_statistics()}
    norm_stats2 = normalize.deserialize_json(normalize.serialize_json(norm_stats))
    assert np.allclose(norm_stats["test"].mean, norm_stats2["test"].mean)
    assert np.allclose(norm_stats["test"].std, norm_stats2["test"].std)
