import gym
import numpy as np
import collections
import pickle
import d4rl
import numpy as np
import os
import pickle
import re

# Helpers for generating k subgoals (k = number of subgoals).
def generate_arithmetic_sequence(k):
    """Return ``k`` descending multipliers spaced by ``round(1 / k, 2)``.

    The first element is ``1 - step`` and each following element is ``step``
    smaller. A value that would fall below zero is replaced by ``0``. Every
    element is rounded to two decimals.

    For example, ``k=10`` gives ``[0.9, 0.8, ..., 0.1, 0.0]``. A non-positive
    ``k`` yields ``[]``. Because the step itself is rounded, ``k > 200`` gives
    a step of ``0.0`` and therefore every element is ``1.0``.
    """
    if k <= 0:
        return []

    step = round(1 / k, 2)
    multipliers = []
    value = 1 - step
    for _ in range(k):
        # Clamp at zero once the rounded step overshoots the range [0, 1].
        value = max(value, 0)
        multipliers.append(round(value, 2))
        value -= step
    return multipliers

def discount_cumsum(x, gamma=1.0):
    """Return the reverse (return-to-go) cumulative sum of ``x``.

    ``out[t] = x[t] + gamma * out[t + 1]``, with ``out[-1] = x[-1]``.
    With the default ``gamma=1.0`` this is the plain undiscounted suffix sum.

    Args:
        x: 1-D sequence or numpy array of per-step values (e.g. rewards).
            Plain lists are accepted and converted with ``np.asarray``.
        gamma: Discount factor applied to future values. Defaults to 1.0.

    Returns:
        A numpy array with the same shape and dtype as ``np.asarray(x)``.
        Pass a float array when ``gamma`` is fractional, otherwise results
        are truncated into an integer dtype. An empty input returns an empty
        array instead of raising.
    """
    x = np.asarray(x)
    out = np.zeros_like(x)
    n = x.shape[0]
    if n == 0:
        return out

    out[-1] = x[-1]
    if gamma == 1.0:
        # Keep the undiscounted path multiplication-free so results are
        # bit-identical to a plain suffix sum (no float promotion of ints).
        for t in reversed(range(n - 1)):
            out[t] = x[t] + out[t + 1]
    else:
        for t in reversed(range(n - 1)):
            out[t] = x[t] + gamma * out[t + 1]
    return out

def find_closest_indices(filename, subgoal_num):
    """Annotate every trajectory in a pickled dataset with subgoal indices.

    For each timestep ``i``, the return-to-go ``R[i]`` (undiscounted reverse
    cumulative reward from ``discount_cumsum``) is scaled by each multiplier
    from ``generate_arithmetic_sequence(subgoal_num)``.

    For every scaled target, the later timestep ``j > i`` whose return-to-go
    is closest becomes the subgoal. The result is ``-1`` in two cases:
    there is no later timestep, or the best match is farther than
    ``round(1 / subgoal_num, 2) * |R[i]|``.

    The pickle is rewritten (atomically) with three extra keys per trajectory:
      * ``'closest_indices'``: per timestep, a list of ``subgoal_num`` ints.
      * ``'return_reals'``: per timestep, the return-to-go value.
      * ``'subgoal_num'``: the ``subgoal_num`` argument.

    Args:
        filename: Path to a pickle holding an indexable collection of
            trajectory dicts, each with ``'observations'`` and ``'rewards'``.
            The rewards are presumably a 1-D numpy array (TODO: confirm
            against the dataset builder).
        subgoal_num: Number of subgoal multipliers; must be positive.

    Raises:
        ValueError: If ``subgoal_num`` is not positive.
    """
    # Validate before touching the file; 0 would otherwise divide by zero.
    if subgoal_num <= 0:
        raise ValueError(f'subgoal_num must be positive, got {subgoal_num}')

    # NOTE(review): pickle.load can execute arbitrary code -- only run this
    # on locally produced, trusted dataset files.
    with open(filename, 'rb') as file:
        data = pickle.load(file)

    # Relative match tolerance, rounded to two decimals (1 / subgoal_num).
    threshold = round(1 / subgoal_num, 2)
    # Descending multipliers, e.g. [0.9, 0.8, ..., 0.0] for subgoal_num=10.
    multipliers = generate_arithmetic_sequence(subgoal_num)

    for k in range(len(data)):
        print(k)  # progress indicator
        traj = data[k]
        num_steps = len(traj['observations'])
        returns = discount_cumsum(traj['rewards'])

        traj['closest_indices'] = _closest_subgoal_indices(
            returns, num_steps, multipliers, threshold)
        traj['return_reals'] = [returns[i] for i in range(num_steps)]
        traj['subgoal_num'] = subgoal_num

    # Write to a sibling temp file, then swap it in. A crash mid-dump can no
    # longer destroy the only copy of the dataset.
    tmp_path = f'{filename}.tmp'
    with open(tmp_path, 'wb') as file:
        pickle.dump(data, file)
    os.replace(tmp_path, filename)


def _closest_subgoal_indices(returns, num_steps, multipliers, threshold):
    """For each step, pick the later step whose return best matches each target."""
    # Compare in float64 so distances do not depend on the rewards' dtype.
    r = np.asarray(returns, dtype=np.float64).reshape(-1)
    mult = np.asarray(multipliers, dtype=np.float64)

    rows = []
    for i in range(num_steps):
        tail = r[i + 1:]
        if tail.size == 0:
            # Last step: nothing lies after it.
            rows.append([-1] * mult.size)
            continue

        current = r[i]
        targets = current * mult                               # (k,)
        dist = np.abs(tail[None, :] - targets[:, None])        # (k, len(tail))
        # argmin returns the first minimum, i.e. the earliest tied step.
        best = dist.argmin(axis=1)
        best_dist = dist[np.arange(mult.size), best]
        # abs(): with a negative return-to-go a signed tolerance would be
        # negative and reject every match.
        tol = threshold * abs(current)
        rows.append([-1 if d > tol else i + 1 + int(j)
                     for j, d in zip(best, best_dist)])
    return rows

# Annotate every maze2d-<dataset_type>-v1.h<N>py file found in data_dir. The
# positive integer N parsed from each file name is the number of subgoals
# passed to find_closest_indices.
data_dir = '/home/cike/plandt/gym/data/data-maze2d'

if __name__ == '__main__':
    for dataset_type in ['medium-dense']:
        # Only N >= 1 is accepted, since subgoal_num must be positive.
        name_re = re.compile(
            rf'maze2d-{re.escape(dataset_type)}-v1\.h([1-9]\d*)py')
        found = []
        if os.path.isdir(data_dir):
            for name in os.listdir(data_dir):
                match = name_re.fullmatch(name)
                if match:
                    found.append((int(match.group(1)), name))

        if not found:
            print(f'文件不存在: {data_dir}/maze2d-{dataset_type}-v1.h*py')
            continue

        # Process in ascending N for a stable, predictable order.
        for idx, name in sorted(found):
            filename = f'{data_dir}/{name}'
            print(f'处理文件: {filename}')
            find_closest_indices(filename, idx)
