import numpy as np
import os

# Arrays expected in the .npz file, reported in this order.
DATA_KEYS = ("rewards", "terminals", "timeouts", "observations", "actions")


def average_episode_reward(rewards, terminals, timeouts):
    """Return the mean summed reward per episode.

    A step is treated as the end of an episode when it is flagged by
    ``terminals`` OR ``timeouts``. The result is the sum of all rewards
    divided by the number of such end steps.

    Args:
        rewards: per-step rewards (any array-like summable with numpy).
        terminals: per-step terminal flags (nonzero/True marks an end).
        timeouts: per-step timeout flags (nonzero/True marks an end).

    Returns:
        The average summed reward per episode as a float, or ``nan`` when no
        step is flagged as an episode end (avoids dividing by zero).

    Note:
        Trailing steps after the last flagged end step are still included
        in the reward sum, but they do not count as an episode.
    """
    # Count a step flagged as both terminal and timeout only once.
    episode_ends = np.logical_or(np.asarray(terminals), np.asarray(timeouts))
    n_episodes = int(episode_ends.sum())
    if n_episodes == 0:
        return float("nan")
    return float(np.sum(rewards)) / n_episodes


def main():
    """Prompt for an .npz path and print basic statistics about it."""
    # for example : data/expert/large/data.npz
    path = input('请输入data.npz路径:').strip()

    # NpzFile holds an open file handle; the context manager closes it.
    with np.load(path) as data:
        # Each data[key] access re-reads from the archive, so load each array once.
        arrays = {key: data[key] for key in DATA_KEYS}

    for key in DATA_KEYS:
        print(f'{key}的形状:', arrays[key].shape)

    print('timeout 中为 1的个数：', arrays["timeouts"].sum())
    print('terminals 中为 1的个数：', arrays["terminals"].sum())
    print('平均rewards', average_episode_reward(
        arrays["rewards"], arrays["terminals"], arrays["timeouts"]))

    # To list the step indices where a timeout occurred:
    # print(np.flatnonzero(arrays['timeouts']))


if __name__ == "__main__":
    main()