from sklearn.cluster import KMeans
import torch
import numpy as np
import argparse
import os
import tqdm

# Debug check: print the total number of observation rows stored across all
# entries of "oneshot-2.pt", then stop.
# NOTE(review): the unconditional exit(0) below ends the script right here, so
# none of the code further down (argument parsing, clustering, saving) ever
# runs. Confirm whether this snippet is a leftover that should be removed.
data = torch.load("oneshot-2.pt")
total_rows = 0
for idx in range(len(data)):
    total_rows += data[idx]["observations"].shape[0]
print(total_rows)
exit(0)

def get_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``; its ``N`` attribute is an int >= 1.

    Raises:
        SystemExit: if ``--N`` is missing, is not an integer, or is below 1
            (argparse reports the problem and exits with status 2).
    """
    parser = argparse.ArgumentParser()
    # Without required=True a missing --N silently became None and the script
    # only failed later, far away from the actual cause.
    parser.add_argument(
        "--N",
        type=int,
        required=True,
        help="number of k-means clusters (positive integer)",
    )
    args = parser.parse_args(argv)
    if args.N < 1:
        parser.error("--N must be a positive integer")
    return args

# Parse the command line once; `args` is read again further down when the
# k-means inertia is printed.
args = get_args()

# N is used below as the number of k-means clusters and as the number of
# per-cluster output files.
N = args.N

# Disabled snippet: it is a bare string literal, so none of it is executed.
# When enabled, it loads each listed file, keeps only the first entry's
# "observations" and "actions", and writes the collected list to
# "oneshot-1.pt".
# NOTE(review): the list holds a single string containing spaces, so it is
# treated as one file name — confirm that is intended before re-enabling.
"""
# get sub

seqs = []
for name in ["kettle_bottom burner_top burner_slide cabinet.pt"]: # microwave_light switch_slide cabinet_hinge cabinet
    print("processing " + name + " ...")
    data = torch.load(name)
    print(data)
    
    for i in range(1):
        seqs.append({"observations":data[i]["observations"], "actions":data[i]["actions"]})
data = torch.save(seqs,"oneshot-1.pt")
"""
# Disabled analysis snippet: it is a bare string literal, so none of it is
# executed. When enabled, it concatenates the observations and actions of the
# first four entries of "task-specific2-sub.pt" and then, for every row:
#   1. finds the nearest other row by summed squared observation difference
#      and prints that distance together with the action difference;
#   2. finds the nearest other row by summed squared action difference and
#      prints that distance together with the observation difference.
# Adding 1000 on the diagonal keeps a row from being selected as its own
# nearest neighbour. The printed label says "MSE", but the code sums squared
# differences without averaging.
"""
# analysis
data = torch.load("task-specific2-sub.pt")
# obs, action = [], []
obs, action = [], []
for i in range(4):
    obs.append(data[i]["observations"])
    action.append(data[i]["actions"])
obs = np.concatenate(obs, axis=0)
action = np.concatenate(action, axis=0)
dist = np.zeros((obs.shape[0], obs.shape[0]))
for i in range(obs.shape[0]):
    for j in range(i+1, obs.shape[0]):
        dist[i, j] = dist[j, i] = ((obs[i] - obs[j]) ** 2).sum()
dist += 1000 * np.eye(obs.shape[0]) # exclude distance to self
for i in range(obs.shape[0]):
    agmn = dist[i].argmin()
    print("nearest state MSE distance is state #", agmn, dist[i, agmn], "action difference:", ((action[i] - action[agmn]) ** 2).sum())

dist = np.zeros((obs.shape[0], obs.shape[0]))
for i in range(obs.shape[0]):
    for j in range(i+1, obs.shape[0]):
        dist[i, j] = dist[j, i] = ((action[i] - action[j]) ** 2).sum()
dist += 1000 * np.eye(obs.shape[0]) # exclude distance to self
for i in range(obs.shape[0]):
    agmn = dist[i].argmin()
    print("nearest action MSE distance is state #", agmn, dist[i, agmn], "state difference:", ((obs[i] - obs[agmn]) ** 2).sum())

# print(dist) 
"""

# Cluster the sequences in "task-agnostic-seqs.pt" by their final state and
# write one file per cluster.
# NOTE(review): torch.load unpickles the file; only load trusted data.
data = torch.load("task-agnostic-seqs.pt")

# One row per sequence: the last entry of "states", reshaped to a single row.
obs = np.concatenate([data[i]["states"][-1].reshape(1, -1) for i in range(len(data))])
kmeans = KMeans(n_clusters=N).fit(obs)

print(args.N, "inertia:", kmeans.inertia_)

# Group every sequence under its cluster label, renaming "states" to
# "observations" in the stored dicts.
# An earlier pass that collected only the first sequence of each cluster was
# removed: its result was overwritten before anything read it.
seqs = [[] for _ in range(N)]
for idx, label in enumerate(kmeans.labels_):
    seqs[label].append({"observations": data[idx]["states"], "actions": data[idx]["actions"]})

# Every cluster is written, whatever its size. A size filter used to sit here
# as a no-op string literal and is kept only as a reference:
#     if len(seqs[i]) < 10:
#         print("skipping", i)
#         continue
for cluster_id, cluster_seqs in enumerate(seqs):
    torch.save(cluster_seqs, f"task_noprior_{cluster_id}.pt")

