import numpy as np
import phate
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import random
import glob
import os

# ---------------------------------------------------------------------------
# Configuration and replay-buffer discovery
# ---------------------------------------------------------------------------
domain='walker' # walker or quadruped
# alphas=[0.01,0.2,0.5,0.7,0.9,1.0,1.5,2.0,3.0,5.0,50.0] # List of alphas used for data generation
alphas=[0.05,0.2,0.5,0.7,1.0,2.0,3.0,5.0,50.0] # List of alphas used for data generation
cwd=os.getcwd() # current directory; exp_local/ is searched underneath it
data_files={} # str(alpha) -> sorted list of matching .npz paths
obs_data={} # str(alpha) -> subsampled observations (filled by the loading step)
act_data={} # str(alpha) -> subsampled actions (filled by the loading step)
num_sample=10000 # Number of points kept per alpha for the embedding and the plot
# num_sample=500 # Smaller setting for quick runs
sample_size=500000 # Rows 0..sample_size-1 of each stacked buffer are the range that gets subsampled
# Evenly spaced (not random) row indices covering [0, sample_size-1].
uni_ind=np.rint(np.linspace(start=0, stop=sample_size-1, num=num_sample)).astype(int)
for alpha in alphas:
    pattern=f'{cwd}/exp_local/**/2024.07*/**/*{domain}*/buffer{alpha}/*.npz'
    # glob returns matches in arbitrary order; sort so the row order of the
    # stacked buffer (and hence the index-based subsample) is deterministic.
    data_files[str(alpha)] = sorted(glob.glob(pattern,recursive=True))

keys=data_files.keys()
# Seed NumPy's global RNG and Python's `random` module for reproducible runs.
# np.random.seed has to be *called*: assigning to it would overwrite the
# function with an int and leave NumPy unseeded.
np.random.seed(100)
random.seed(100)

plot_obs_data=[]

# For every alpha: load all of its .npz files, stack their rows into one 2-D
# array per field, and keep only the evenly spaced rows selected by uni_ind.
for key in keys:
    obs_chunks=[]
    act_chunks=[]
    for file in data_files[key]:
        with np.load(file) as data:
            obs_chunks.append(data['observation'])
            act_chunks.append(data['action'])
    if not obs_chunks:
        raise FileNotFoundError('No .npz buffer files found for alpha='+key)
    # Concatenating along axis 0 gives the same (files*rows, dim) matrix as
    # stacking into a 3-D array and reshaping, but does not require every
    # file to hold the same number of rows.
    obs_2d=np.concatenate(obs_chunks,axis=0)
    act_2d=np.concatenate(act_chunks,axis=0)
    rows_needed=int(uni_ind.max())+1 if uni_ind.size else 0
    rows_found=min(obs_2d.shape[0],act_2d.shape[0])
    if rows_found < rows_needed:
        raise ValueError('alpha='+key+': buffer holds '+str(rows_found)
                         +' rows but the subsample indices need '+str(rows_needed))
    obs_data[key]=obs_2d[uni_ind,:] # evenly spaced subsample used for the embedding
    act_data[key]=act_2d[uni_ind,:] # matching actions for the same rows
    # obs_data[key]=obs_2d[:num_sample,:] # alternative: first num_sample rows
    # act_data[key]=act_2d[:num_sample,:]
    # obs_data[key]=obs_2d[np.random.choice(obs_2d.shape[0],num_sample,replace=False)] # alternative: random subsample
    # act_data[key]=act_2d[np.random.choice(act_2d.shape[0],num_sample,replace=False)]
    plot_obs_data.append(obs_data[key])

# Stack the per-alpha subsamples in key order: rows [i*num_sample, (i+1)*num_sample)
# belong to the i-th alpha.
plot_array=np.concatenate(plot_obs_data,axis=0)
# print(np.shape(plot_array))

# print(np.shape(obs_data['0.5']))
# Fit one PHATE embedding on the observations of all alphas together, so every
# panel drawn afterwards lives in the same coordinate system.
# (An earlier variant used the same settings without mds_solver='smacof'.)
phate_op = phate.PHATE(
    knn=10,
    decay=500,
    t=20,
    gamma=0,
    mds_solver='smacof',
    n_jobs=-1,
)
X_embedded = phate_op.fit_transform(plot_array)

# One colour per subsampled point: the viridis gradient follows the point's
# position in uni_ind, i.e. its row position in the stacked buffer.
colors = cm.viridis(np.linspace(0, 1, len(uni_ind)))
ind=0
legends=list(keys)
# Size the grid from the number of alphas instead of assuming exactly nine;
# squeeze=False keeps `ax` two-dimensional even when there is a single row.
n_cols=3
n_rows=max(1,-(-len(legends)//n_cols)) # ceiling division
fig, ax = plt.subplots(n_rows,n_cols,sharex=True,sharey=True,squeeze=False)
for c, legend in enumerate(legends):
    panel=ax[c//n_cols,c%n_cols]
    # Rows [ind, ind+num_sample) of the embedding belong to this alpha.
    panel.scatter(X_embedded[ind:ind+num_sample,0],X_embedded[ind:ind+num_sample,1],s=5,color=colors,alpha=0.6)
    panel.title.set_text(r'$\alpha$ :'+legend)
    ind+=num_sample
    print(c)
    print(ind)
# Hide panels left over when the number of alphas is not a multiple of n_cols.
for spare in ax.flat[len(legends):]:
    spare.set_visible(False)
# ax.legend()
plt.show()
# X_embedded1 = TSNE(n_components=2, learning_rate='auto',
#                   init='random', perplexity=30).fit_transform(obs_data_array2d1[-10000:])
# X_embedded2 = TSNE(n_components=2, learning_rate='auto',
#                   init='random', perplexity=30).fit_transform(obs_data_array2d2[-10000:])
# plt.scatter(X_embedded[-10000:,0],X_embedded[-10000:,1],s=5,color='blue',alpha=0.2) 
# plt.scatter(X_embedded[:10000,0],X_embedded[:10000,1],s=5,color='green',alpha=0.2) 
# # plt.scatter(X_embedded[-10000:,0],X_embedded[-10000:,1],s=2,color='red',alpha=0.5) 
# plt.show()

# The script terminates here. Everything below this call is unreachable and is
# kept only as a record of an earlier two-run t-SNE comparison.
# NOTE(review): `TSNE` is used below but never imported in the visible file, so
# this section would raise NameError if the exit() were removed — confirm the
# intended import before reviving it.
# NOTE(review): the two glob patterns are absolute paths on one machine.
exit()
npz_list_beh_beta0_5=glob.glob('/home/aamodh/Documents/Reinforcement learning/gbe-url/exp_local/2024.04.19/093546_walker_icm_apt/buffer/*.npz')
npz_list_beh_beta3_0=glob.glob('/home/aamodh/Documents/Reinforcement learning/gbe-url/exp_local/2024.04.19/214635_walker_icm_apt/buffer/*.npz')
# Lists suffixed 1 collect the files of npz_list_beh_beta3_0,
# lists suffixed 2 collect the files of npz_list_beh_beta0_5.
obs_data_list1=[]
act_data_list1=[]
obs_data_list2=[]
act_data_list2=[]
for file in npz_list_beh_beta3_0:
    with np.load(file) as data:
        # print(data.files)
        obs_data_list1.append(data['observation'])
        act_data_list1.append(data['action'])

for file in npz_list_beh_beta0_5:
    with np.load(file) as data:
        # print(data.files)
        obs_data_list2.append(data['observation'])
        act_data_list2.append(data['action'])


# Stack the per-file arrays; the reshapes below index shape[0..2], so each
# stacked array is expected to be 3-D (i.e. all files hold equally shaped arrays).
obs_data_array1=np.array(obs_data_list1)
act_data_array1=np.array(act_data_list1)

obs_data_array2=np.array(obs_data_list2)
act_data_array2=np.array(act_data_list2)

# Merge the first two axes so every row is one entry of the last axis' size.
obs_data_array2d1=obs_data_array1.reshape((obs_data_array1.shape[0]*obs_data_array1.shape[1]),obs_data_array1.shape[2])
act_data_array2d1=act_data_array1.reshape((act_data_array1.shape[0]*act_data_array1.shape[1]),act_data_array1.shape[2])

obs_data_array2d2=obs_data_array2.reshape((obs_data_array2.shape[0]*obs_data_array2.shape[1]),obs_data_array2.shape[2])
act_data_array2d2=act_data_array2.reshape((act_data_array2.shape[0]*act_data_array2.shape[1]),act_data_array2.shape[2])
# print(np.shape(obs_data_list))
# print(np.shape(act_data_list))

# print(obs_data_array2d.shape)
# print(obs_data_array[0][0][:])
# print(obs_data_array2d[0][:])
# with np.load('2024.05.03/foo.npz') as data:
#     a = data['a']

# Last 10000 observation rows of each run, run 1 first and run 2 second.
data=np.concatenate((obs_data_array2d1[-10000:],obs_data_array2d2[-10000:]))
# data=np.concatenate((obs_data_array2d1[:10000],obs_data_array2d2[:10000]))
# print(data.shape)
# X = np.array([[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 1]])
# print(X.shape)
# Joint 2-D embedding of both runs (TSNE is undefined here — see note at the top).
X_embedded = TSNE(n_components=2, learning_rate='auto',
                  init='random', perplexity=20, verbose=1).fit_transform(data)
# X_embedded1 = TSNE(n_components=2, learning_rate='auto',
#                   init='random', perplexity=30).fit_transform(obs_data_array2d1[-10000:])
# X_embedded2 = TSNE(n_components=2, learning_rate='auto',
#                   init='random', perplexity=30).fit_transform(obs_data_array2d2[-10000:])
# Blue: the last 10000 embedded rows, i.e. the second concatenated run (obs_data_array2d2).
plt.scatter(X_embedded[-10000:,0],X_embedded[-10000:,1],s=5,color='blue',alpha=0.2) 
# Green: the first 10000 embedded rows, i.e. the first concatenated run (obs_data_array2d1).
plt.scatter(X_embedded[:10000,0],X_embedded[:10000,1],s=5,color='green',alpha=0.2) 
# plt.scatter(X_embedded[-10000:,0],X_embedded[-10000:,1],s=2,color='red',alpha=0.5) 
plt.show()