import gym
import numpy as np

import collections
import pickle
import d4rl
import numpy as np
import os
import pickle
# Subgoal multipliers: one multiplier per subgoal (k subgoals in total).
def generate_arithmetic_sequence(k):
    """Return ``k`` descending multipliers, floored at zero.

    The step is ``round(1 / k, 2)``. The sequence starts at ``1 - step`` and
    decreases by ``step`` each element, with every element rounded to two
    decimals. A value that would fall below zero is replaced by the int ``0``.

    Because the step is rounded, the last element is not always exactly 0.
    For example, ``k=3`` gives ``[0.67, 0.34, 0.01]``.

    NOTE(review): for ``k > 200`` the step rounds to ``0.0`` and every element
    is ``1.0``.

    Args:
        k: Number of elements to produce.

    Returns:
        A list of ``k`` numbers, or ``[]`` when ``k <= 0``.
    """
    if k <= 0:
        return []

    step = round(1 / k, 2)
    value = 1 - step
    result = []
    for _ in range(k):
        # Floor at zero. max() keeps `value` when it is not below 0, which
        # mirrors an `if value < 0: value = 0` check exactly.
        value = max(value, 0)
        result.append(round(value, 2))
        value -= step
    return result

def discount_cumsum(x):
    """Return the reverse cumulative sum of ``x`` along its first axis.

    ``out[t] = x[t] + x[t + 1] + ... + x[-1]``. Despite the name, no discount
    is applied (gamma = 1), so this is the undiscounted return-to-go. The
    result has the same shape and dtype as ``x`` and is a new C-contiguous
    array.

    Args:
        x: Array-like of per-step values (e.g. rewards), summed along axis 0.

    Returns:
        np.ndarray of reverse cumulative sums. It is empty when ``x`` is empty;
        the previous loop version raised ``IndexError`` in that case.
    """
    x = np.asarray(x)
    # cumsum over the reversed array accumulates sequentially in x's own
    # dtype. That reproduces the x[t] + out[t + 1] recurrence exactly without
    # a Python loop. .copy() drops the negative stride of the reversed view,
    # so callers get an ordinary contiguous array.
    return np.cumsum(x[::-1], axis=0, dtype=x.dtype)[::-1].copy()

def _subgoal_indices(traj_len, multipliers):
    """Return the subgoal target indices for every timestep of a trajectory.

    Entry ``[i][j]`` is ``i + int((traj_len - 1 - i) * multipliers[j])``,
    clamped to ``traj_len - 1``. It points a fraction ``multipliers[j]`` of the
    way from timestep ``i`` toward the final timestep. The result is a nested
    list of Python ints with shape ``traj_len x len(multipliers)``.
    """
    steps = np.arange(traj_len)
    remaining = (traj_len - 1 - steps)[:, None].astype(np.float64)
    scale = np.asarray(multipliers, dtype=np.float64)[None, :]
    # astype truncates toward zero exactly like int(). All products are >= 0
    # because remaining >= 0 and the multipliers are floored at zero.
    targets = (remaining * scale).astype(np.int64) + steps[:, None]
    return np.minimum(targets, traj_len - 1).tolist()


def _atomic_pickle_dump(obj, filename):
    """Pickle ``obj`` to ``filename`` via a temp file, so a crash mid-write
    never leaves a truncated dataset behind."""
    tmp_path = f"{filename}.tmp"
    try:
        with open(tmp_path, 'wb') as file:
            pickle.dump(obj, file)
        os.replace(tmp_path, filename)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def find_closest_indices(filename, subgoal_num=10):
    """Annotate every trajectory in a pickled dataset with subgoal indices.

    The pickle at ``filename`` is loaded, three keys are added to each
    trajectory dict, and the file is rewritten in place:

    * ``'closest_indices'``: for each timestep ``i``, a list of target indices
      ``i + int((T - 1 - i) * m)`` clamped to ``T - 1``. There is one index per
      multiplier ``m`` in ``generate_arithmetic_sequence(subgoal_num)``, and
      ``T`` is the number of observations.
    * ``'return_reals'``: the undiscounted return-to-go at each timestep.
    * ``'subgoal_num'``: the number of subgoals used.

    Args:
        filename: Path to a pickle holding an indexable sequence of trajectory
            dicts with ``'observations'`` and ``'rewards'`` entries.
        subgoal_num: Number of subgoals per timestep. Defaults to the
            previously hard-coded 10.

    Raises:
        ValueError: If a trajectory has fewer rewards than observations.
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only run this on locally generated, trusted datasets.
    with open(filename, 'rb') as file:
        data = pickle.load(file)

    # Fractions of the remaining horizon at which the subgoals are placed,
    # e.g. [0.9, 0.8, ..., 0.0] for 10 subgoals.
    multipliers = generate_arithmetic_sequence(subgoal_num)

    # Index-based access keeps working whether `data` is a list or a dict
    # keyed 0..n-1.
    for k in range(len(data)):
        traj = data[k]
        traj_len = len(traj['observations'])
        returns = discount_cumsum(traj['rewards'])
        if len(returns) < traj_len:
            raise ValueError(
                f"trajectory {k}: {len(returns)} rewards for "
                f"{traj_len} observations"
            )
        traj['closest_indices'] = _subgoal_indices(traj_len, multipliers)
        traj['return_reals'] = list(returns[:traj_len])
        traj['subgoal_num'] = subgoal_num

    _atomic_pickle_dump(data, filename)
if __name__ == '__main__':
    # Annotate each dataset variant in place. Guarded so that importing this
    # module (e.g. to reuse its helpers) does not rewrite data files.
    # Other dataset suffixes listed in the original comment: 'umaze',
    # 'umaze-dense', 'medium', 'medium-dense', 'large', 'large-dense'.
    for env_name in ['hopper']:
        for dataset_type in ['medium-replay', 'medium-expert', 'medium']:
            name = f'data-ablation/{env_name}-{dataset_type}-v2-timestep.h10py'
            print(name)
            find_closest_indices(name)
