import os

import numpy as np
import scipy.io
import pickle

# Directory holding the SARCOS .mat files; the pickled buffers are written here too.
data_dir = '/home/nwaftp23/projects/def-dpmeger/nwaftp23/uncertainty_estimation/mujoco/SARCOS_test_aquisition/noiseweight0.2_modes0/'

mat_train = scipy.io.loadmat(os.path.join(data_dir, 'sarcos_inv.mat'))
mat_test = scipy.io.loadmat(os.path.join(data_dir, 'sarcos_inv_test.mat'))
mat_train = mat_train['sarcos_inv']
mat_test = mat_test['sarcos_inv_test']


def _split_columns(mat):
    """Split a 2-D data matrix into (state, action, next_state, unseen) arrays.

    Columns 0-13 become the state, columns 14-20 the action, and every
    remaining column the next_state. ``unseen`` is an all-zero array with
    the same shape as ``action``.

    Using one helper for both the train and test matrices keeps the column
    layout identical between the two splits.
    """
    state = mat[:, :14]
    action = mat[:, 14:21]
    next_state = mat[:, 21:]
    unseen = np.zeros(action.shape)
    return state, action, next_state, unseen


state_train, action_train, next_state_train, unseen_train = _split_columns(mat_train)
# The test split must come from the held-out test matrix, not mat_train.
state_test, action_test, next_state_test, unseen_test = _split_columns(mat_test)


# Constant reward stored in every replay-buffer tuple.
reward = 0
# Number of training rows drawn (without replacement); all other rows form the oracle pool.
# np.random.choice raises ValueError if this exceeds the number of training rows.
train_size = 2000
oracle_size = state_train.shape[0] - train_size
test_size = state_test.shape[0]

# Seed the global NumPy RNG so the train/oracle split is reproducible. The
# value matches the ``seed14`` suffix used in the pickled buffer filenames.
np.random.seed(14)
train_choice = np.random.choice(state_train.shape[0], size=(train_size,), replace=False)

# Boolean row mask: True for rows selected into the training set.
ind = np.zeros(state_train.shape[0], dtype=bool)
ind[train_choice] = True

# The oracle pool must be extracted before the *_train arrays are overwritten below.
state_oracle = state_train[~ind]
action_oracle = action_train[~ind]
next_state_oracle = next_state_train[~ind]
unseen_oracle = unseen_train[~ind]

state_train = state_train[ind]
action_train = action_train[ind]
next_state_train = next_state_train[ind]
unseen_train = unseen_train[ind]

def _build_buffer(states, actions, next_states, unseen, reward, episode_len=200):
    """Pack parallel per-row arrays into a list of replay-buffer tuples.

    Args:
        states, actions, next_states, unseen: arrays indexed by row; row ``i``
            of each contributes to the ``i``-th tuple.
        reward: value stored unchanged in every tuple.
        episode_len: ``done`` is True on every ``episode_len``-th row
            (1-based), i.e. rows ``episode_len - 1``, ``2 * episode_len - 1``, ...

    Returns:
        A list of ``(state, action, reward, next_state, done, unseen)`` tuples,
        with every array entry cast to float32 and ``done`` a Python bool.
    """
    buffer = []
    for i, (s, a, ns, u) in enumerate(zip(states, actions, next_states, unseen)):
        done = (i + 1) % episode_len == 0
        buffer.append((
            s.astype(np.float32),
            a.astype(np.float32),
            reward,
            ns.astype(np.float32),
            done,
            u.astype(np.float32),
        ))
    return buffer


train_list = _build_buffer(state_train, action_train, next_state_train, unseen_train, reward)
oracle_list = _build_buffer(state_oracle, action_oracle, next_state_oracle, unseen_oracle, reward)
test_list = _build_buffer(state_test, action_test, next_state_test, unseen_test, reward)

# Write each buffer as a pickle into data_dir, in train -> oracle -> test order.
for buffer_filename, buffer_contents in (
    ('train_buffer_seed14.pkl', train_list),
    ('oracle_buffer_seed14.pkl', oracle_list),
    ('test_buffer_seed14.pkl', test_list),
):
    with open(os.path.join(data_dir, buffer_filename), 'wb') as f:
        pickle.dump(buffer_contents, f)
