import h5py
import numpy as np

# Source dataset, opened read-only. The handle is not closed anywhere in the
# visible part of this script.
dataset = h5py.File('Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5', 'r')

# HDF5 paths of the fields that are loaded and later written to the filtered files.
dkeys = [
    'observations',
    'actions',
    'rewards',
    'terminals',
    'timeouts',
    'infos/goal',
    'infos/qpos',
    'infos/qvel',
]

# Materialise every field as a NumPy array and keep only its first 502502 entries.
data_copy = {item_name: np.array(dataset[item_name])[:502502] for item_name in dkeys}

# The same arrays as a list (in ``dkeys`` order) and as individual names.
dataset_copy = [data_copy[item_name] for item_name in dkeys]
obs, acs, rews, terminals, timeouts, goals, qpos, qvel = dataset_copy

def reset_data():
    """Return a new buffer dict holding one empty list per recorded field."""
    fields = ('observations', 'actions', 'terminals', 'rewards', 'timeout')
    # A comprehension gives every field its own, independent list.
    return {field: [] for field in fields}

def append_data(data, s, a, r, done, timeout):
    """Append one transition to the per-field lists of ``data``, in place.

    ``data`` must already contain a list under each of the keys
    'observations', 'actions', 'rewards', 'terminals' and 'timeout'.
    Returns None.
    """
    transition = (
        ('observations', s),
        ('actions', a),
        ('rewards', r),
        ('terminals', done),
        ('timeout', timeout),
    )
    for field, value in transition:
        data[field].append(value)

def npify(data):
    """Convert every value of ``data`` to a NumPy array, in place.

    The 'terminals' entry becomes a boolean array; every other entry becomes
    a float32 array. Returns None.
    """
    for field in data:
        target_dtype = np.bool_ if field == 'terminals' else np.float32
        # Re-assigning an existing key does not resize the dict, so this is
        # safe while iterating over it.
        data[field] = np.array(data[field], dtype=target_dtype)

import time
start_time = time.time()

# Corner thresholds of the three regions of interest. They are compared
# against the first two observation components below (presumably the agent's
# x/y position -- TODO confirm against the environment).
UL = (24, 14)
UR = (27, 18)
LR = (7, 10)


def _trajectory_flags(step_flags, boundaries):
    """Return one bool per trajectory: True if any of its steps is flagged.

    Trajectory ``t`` covers steps ``boundaries[t]:boundaries[t + 1]``, the
    same slices that are used when the trajectories are exported.
    """
    return np.array(
        [step_flags[start:stop].any()
         for start, stop in zip(boundaries[:-1], boundaries[1:])],
        dtype=bool,
    )


# A step closes a trajectory when the episode timed out or terminated there,
# or when the goal recorded for the following step is different.
cond = np.bitwise_or(timeouts.astype('bool'),
                     terminals.astype('bool'))
# Compare goals component-wise: summing the differences would miss a change
# whose components cancel out, e.g. (1, 2) -> (2, 1).
# Assumes goals has shape (num_steps, goal_dim) -- TODO confirm.
goals_cond = np.any(goals[1:] != goals[:-1], axis=-1)
cond[:-1] = np.bitwise_or(cond[:-1], goals_cond)

# Per-step region membership, sized from the data instead of a fixed buffer.
first_coord, second_coord = obs[:, 0], obs[:, 1]
blocked_UL = (first_coord > UL[0]) & (second_coord < UL[1])
blocked_UR = (first_coord > UR[0]) & (second_coord > UR[1])
blocked_LR = (first_coord < LR[0]) & (second_coord > LR[1])
print(np.sum(blocked_LR), np.sum(blocked_UR), np.sum(blocked_UL), "items filtered out of", len(obs), "total items")

# cond_inds[t] is the first step of trajectory t and cond_inds[t + 1] is one
# past its last step. Steps after the last closing step belong to no
# trajectory and are never exported.
cond_inds = [0] + (np.flatnonzero(cond) + 1).tolist()

# Per-trajectory flags: does the trajectory visit the region at least once?
tblocked_LR = _trajectory_flags(blocked_LR, cond_inds)
tblocked_UR = _trajectory_flags(blocked_UR, cond_inds)
tblocked_UL = _trajectory_flags(blocked_UL, cond_inds)

for region, tblocked_region in zip(['LR', 'UL', 'UR'], [tblocked_LR, tblocked_UL, tblocked_UR]):
    region_fname = 'Antmaze_filtered_' + region + '.hdf5'

    # Step ranges of every trajectory that touches this region.
    segments = [(cond_inds[t_num], cond_inds[t_num + 1])
                for t_num in np.flatnonzero(tblocked_region)]

    data_region = {}
    for key in dkeys:
        chunks = [data_copy[key][start:stop] for start, stop in segments]
        # np.concatenate rejects an empty list; an empty slice keeps the
        # field's per-step shape when no trajectory matched.
        data_region[key] = np.concatenate(chunks) if chunks else data_copy[key][:0]

    # Normalise dtypes the same way for every exported file.
    npify(data_region)

    # The context manager flushes and closes the file even if a write fails.
    with h5py.File(region_fname, 'w') as dataset_region:
        for k in data_region:
            dataset_region.create_dataset(k, data=data_region[k], compression='gzip')

print("Finished in", time.time() - start_time, "seconds")

exit()
