import gym
import numpy as np

import collections
import pickle
import d4rl
import numpy as np
import os
import pickle
import random
# k subgoals
def generate_arithmetic_sequence(k):
    """Return ``k`` non-negative values stepping down from ``1 - step``.

    ``step`` is ``1 / k`` rounded to two decimals. Values are produced by
    repeatedly subtracting ``step``; any value that falls below zero is
    clamped to 0, and every emitted value is rounded to two decimals.
    A non-positive ``k`` yields an empty list.

    Example: ``generate_arithmetic_sequence(4) == [0.75, 0.5, 0.25, 0.0]``.

    NOTE(review): because ``step`` is rounded to two decimals, ``k >= 201``
    gives ``step == 0.0`` and therefore a list of ``1.0`` values.
    """
    if k <= 0:
        return []

    step = round(1 / k, 2)

    def _levels():
        # Infinite stream: start one step below 1 and keep stepping down,
        # never letting the running value go negative.
        level = 1 - step
        while True:
            if level < 0:
                level = 0
            yield round(level, 2)
            level -= step

    stream = _levels()
    return [next(stream) for _ in range(k)]

def discount_cumsum(x, gamma=1.0):
    """Return the (optionally discounted) reverse cumulative sum of ``x``.

    Computes ``out[t] = x[t] + gamma * out[t + 1]`` with ``out[-1] = x[-1]``.
    With the default ``gamma=1.0`` this is the plain return-to-go,
    ``out[t] == sum(x[t:])``.

    Args:
        x: Sequence or array of per-step values (e.g. rewards), accumulated
            along axis 0. Converted with ``np.asarray``.
        gamma: Discount factor. ``1.0`` (default) gives the undiscounted sum.

    Returns:
        A new array with the same shape as ``x``. For ``gamma == 1`` it has
        ``x``'s dtype; otherwise the dtype is promoted with ``gamma`` so that
        integer inputs are not truncated. Empty input returns an empty array.
    """
    x = np.asarray(x)
    dtype = x.dtype if gamma == 1 else np.result_type(x, gamma)
    out = np.zeros_like(x, dtype=dtype)
    n = x.shape[0]
    if n == 0:
        # Nothing to accumulate (indexing x[-1] would raise here).
        return out

    if gamma == 1:
        # Undiscounted case: a reversed cumulative sum. Writing into ``out``
        # keeps the input dtype; float addition is commutative, so the values
        # match a back-to-front loop exactly.
        out[...] = np.cumsum(x[::-1], axis=0)[::-1]
        return out

    out[-1] = x[-1]
    for t in range(n - 2, -1, -1):
        out[t] = x[t] + gamma * out[t + 1]
    return out

def _sample_subgoal_indices(i, n, num_subgoals, rng):
    """Return ``num_subgoals`` sorted step indices strictly after ``i`` (of ``n``), padded with -1.

    If at least ``num_subgoals`` future steps exist, they are drawn uniformly
    without replacement via ``rng.sample``. Otherwise every remaining future
    step is used and the rest of the list is filled with -1.
    """
    # random.sample accepts a range directly, so no per-step list is built.
    future = range(i + 1, n)
    if len(future) >= num_subgoals:
        chosen = sorted(rng.sample(future, num_subgoals))
    else:
        chosen = list(future)
    # Pad so that every timestep carries exactly num_subgoals entries.
    return chosen + [-1] * (num_subgoals - len(chosen))


def _annotate_trajectory(traj, num_subgoals, rng):
    """Add 'closest_indices', 'return_reals' and 'subgoal_num' to one trajectory dict in place."""
    n = len(traj['observations'])
    if n == 0:
        # Empty trajectory: no timesteps to annotate.
        closest_indices, return_real = [], []
    else:
        # Undiscounted return-to-go at each step. Its last entry always equals
        # the final reward, so there is no end-of-trajectory offset to remove.
        returns_to_go = discount_cumsum(traj['rewards'])
        return_real = [returns_to_go[i] for i in range(n)]
        closest_indices = [
            _sample_subgoal_indices(i, n, num_subgoals, rng) for i in range(n)
        ]
    traj['closest_indices'] = closest_indices
    traj['return_reals'] = return_real
    traj['subgoal_num'] = num_subgoals


def _atomic_pickle_dump(obj, filename):
    """Pickle ``obj`` to ``filename`` through a sibling temp file so a failed dump cannot truncate it."""
    tmp_name = f"{filename}.tmp"
    try:
        with open(tmp_name, 'wb') as file:
            pickle.dump(obj, file)
        os.replace(tmp_name, filename)
    except BaseException:
        # Drop the partial temp file, then propagate the original error.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise


def find_closest_indices(filename, num_subgoals=10, seed=None):
    """Annotate every trajectory in a pickled dataset with random future subgoal indices.

    ``filename`` must hold a pickled, integer-indexable collection of
    trajectory dicts, each with at least ``'observations'`` and ``'rewards'``.
    For each trajectory this adds the following keys:

    * ``'closest_indices'``: one list per timestep ``i`` containing
      ``num_subgoals`` sorted timestep indices strictly after ``i``, sampled
      without replacement. The list is padded with ``-1`` when fewer future
      steps remain (the final step is all ``-1``).
    * ``'return_reals'``: the undiscounted return-to-go at every timestep.
    * ``'subgoal_num'``: ``num_subgoals``.

    The annotated data is written back over ``filename``. The file is replaced
    only after the new pickle has been fully written.

    Args:
        filename: Path of the pickle file to read and overwrite.
        num_subgoals: Number of subgoal indices per timestep (default 10). This
            is a fixed count, not derived from trajectory length.
        seed: Optional seed for a private ``random.Random``. ``None`` (default)
            uses the global ``random`` module generator.

    Raises:
        ValueError: If ``num_subgoals`` is less than 1.
    """
    if num_subgoals < 1:
        raise ValueError(f"num_subgoals must be >= 1, got {num_subgoals}")
    rng = random if seed is None else random.Random(seed)

    # NOTE(security): pickle.load can execute arbitrary code; only run this on
    # trusted, locally produced dataset files.
    with open(filename, 'rb') as file:
        data = pickle.load(file)

    for k in range(len(data)):
        print(k)  # progress indicator
        _annotate_trajectory(data[k], num_subgoals, rng)

    _atomic_pickle_dump(data, filename)
# Annotate each listed hopper dataset variant in place.
# Alternative dataset types (currently disabled): 'umaze', 'umaze-dense',
# 'medium', 'medium-dense', 'large', 'large-dense'.
for env_name, dataset_type in [
    (env, kind)
    for env in ['hopper']
    for kind in ['medium-replay', 'medium-expert', 'medium']
]:
    name = f'data-ablation/{env_name}-{dataset_type}-v2-random.h10py'
    print(name)
    find_closest_indices(name)

