import numpy as np
import os
import pickle
def discount_cumsum(x, gamma=1.0):
    """Return the reverse (suffix) cumulative sum of ``x`` along axis 0.

    Computes ``out[t] = x[t] + gamma * out[t + 1]`` with ``out[-1] = x[-1]``.
    With the default ``gamma=1.0`` this is the plain undiscounted sum of
    ``x[t:]`` (the original behaviour); pass ``gamma < 1`` for a discounted
    return-to-go.

    Args:
        x: Array-like indexed along its first axis (lists are accepted).
        gamma: Per-step discount factor applied to the accumulated tail.

    Returns:
        A new array with the same shape and dtype as ``x``. An empty input
        yields an empty array instead of raising. Note that integer inputs
        keep their integer dtype, so a non-integer ``gamma`` truncates.
    """
    x = np.asarray(x)
    out = np.zeros_like(x)
    # Guard empty input: x[-1] below would raise IndexError.
    if x.shape[0] == 0:
        return out
    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out
def _match_future_returns(returns, length, multipliers, threshold):
    """For each of the first ``length`` steps, find the later step closest to each scaled return."""
    returns = np.asarray(returns)
    matches = []
    for i in range(length):
        current = returns[i]
        future = returns[i + 1:]
        row = []
        for multiplier in multipliers:
            if future.size == 0:
                # Last step: there is no later index to match.
                row.append(-1)
                continue
            target = current * multiplier
            distances = np.abs(future - target)
            # NaN distances never win a strict "<" comparison; treat them as
            # infinitely far so argmin skips them the same way.
            distances = np.where(np.isnan(distances), np.inf, distances)
            # argmin returns the FIRST minimum, i.e. the earliest j on ties.
            offset = int(np.argmin(distances))
            best = distances[offset]
            # abs(current) keeps the tolerance non-negative for negative
            # returns; otherwise every match would be rejected.
            if not np.isfinite(best) or best > threshold * abs(current):
                row.append(-1)
            else:
                row.append(i + 1 + offset)
        matches.append(row)
    return matches


def find_closest_indices(filename, multipliers=(0.8, 0.6, 0.4, 0.2, 0), threshold=0.2):
    """Annotate every trajectory in a pickle file with matching future indices.

    For each step ``i`` of each trajectory, the return-to-go ``R`` is computed
    from ``'rewards'`` with ``discount_cumsum``. For each multiplier ``m``, the
    later step ``j > i`` whose ``R[j]`` is closest to ``m * R[i]`` is found;
    ties go to the earliest ``j``. The match is recorded as ``-1`` when no
    later step exists or when its distance exceeds ``threshold * |R[i]|``.

    Args:
        filename: Path to a pickle holding an indexable sequence of
            trajectory dicts with ``'observations'`` and ``'rewards'`` keys
            (rewards presumably a 1-D numpy array — confirm against the data).
        multipliers: Fractions of the current return-to-go to search for.
        threshold: Maximum accepted distance, relative to ``|R[i]|``.

    Returns:
        The loaded data, with ``'closest_indices'`` set on every trajectory.
        The value is a list with one row per observation, each row holding
        ``len(multipliers)`` ints. The file on disk is not modified.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only use this
    # on trusted files.
    with open(filename, 'rb') as file:
        data = pickle.load(file)
    for k in range(len(data)):
        print(k)
        trajectory = data[k]
        returns = discount_cumsum(trajectory['rewards'])
        trajectory['closest_indices'] = _match_future_returns(
            returns, len(trajectory['observations']), multipliers, threshold
        )
    return data
    
# Process every regular ``.pkl`` file in the current directory, in the
# order reported by os.listdir, echoing each file name before processing it.
_pkl_files = [
    entry for entry in os.listdir(".")
    if entry.endswith(".pkl") and os.path.isfile(entry)
]
for filename in _pkl_files:
    print(filename)
    find_closest_indices(filename)

            
