import numpy as np
from multinav.data.dataset import make_dataset
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import os

if __name__ == "__main__":
    num_clusters = 256
    ds, _ = make_dataset(None, 256, 1, 1)

    actions = []
    for batch in ds.take(10).iterator():
        actions.append(batch["action"].squeeze(1))
        actions.append(batch["action"].squeeze(1) * np.array([1, -1, 1, -1]))
    
    actions = np.concatenate(actions, axis=0)
    means = KMeans(n_clusters=num_clusters).fit(actions).cluster_centers_

    plt.scatter(means[:, 0], means[:, -1])
    plt.scatter(actions[:, 0], actions[:, -1], alpha=0.1, s=1)
    plt.show()

    np.save(os.path.join(os.path.dirname(__file__), "action_clusters.npy"), means)