import numpy as np
def discount_cumsum(x, gamma):
    """Return the discounted cumulative sum of ``x`` along its first axis.

    The result satisfies ``out[-1] = x[-1]`` and
    ``out[t] = x[t] + gamma * out[t + 1]`` for every earlier index ``t``.

    Args:
        x: Array-like whose first axis is the one accumulated over.
        gamma: Discount factor applied to the accumulated tail at each step.

    Returns:
        A new array with the same shape as ``x``. Integer and boolean inputs
        are returned as ``float64`` so that fractional discounted values are
        not truncated; other dtypes are preserved. An empty input yields an
        empty array.
    """
    x = np.asarray(x)
    if x.dtype.kind in "biu":
        # An integer/bool buffer would silently truncate every discounted
        # value on assignment, so accumulate in floating point instead.
        out = x.astype(np.float64)
    else:
        out = x.copy()
    # out[-1] already equals x[-1]; walk backwards folding in the discounted
    # tail. For empty input the range is empty and the copy is returned as is.
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out

def get_val_path(path, batch_size=20):
    """Draw ``batch_size`` entries from ``path`` uniformly, with replacement.

    Args:
        path: Indexable collection of trajectories; ``len(path)`` gives the
            number of candidates.
        batch_size: How many entries to draw.

    Returns:
        A list of ``batch_size`` items taken from ``path``; the same item may
        appear more than once because sampling is done with replacement.
    """
    candidate_ids = np.arange(len(path))
    chosen = np.random.choice(candidate_ids, size=batch_size, replace=True)
    return [path[chosen[k]] for k in range(batch_size)]

def process_data(trajectories, variant, mode):
    """Summarise trajectories and select the top-return subset for sampling.

    Args:
        trajectories: Sequence of mappings, each with an ``'observations'``
            entry (its length is the trajectory length) and a ``'rewards'``
            entry exposing ``.sum()``.
        variant: Configuration mapping. ``'ICRL_K'`` must be present;
            ``'ICRL_pct_traj'`` is optional and defaults to ``1.``.
        mode: Not used by the computation in this function.

    Returns:
        A tuple ``(p_sample, states, traj_lens, returns, num_trajectories,
        sorted_inds)`` where ``states`` is all observations concatenated
        along axis 0, ``traj_lens`` and ``returns`` hold one entry per
        trajectory, ``sorted_inds`` indexes the kept trajectories in
        ascending order of return, ``num_trajectories`` is how many were
        kept, and ``p_sample`` is each kept trajectory's share of the kept
        timesteps.
    """
    observations = [traj['observations'] for traj in trajectories]
    traj_lens = np.array([len(obs) for obs in observations])
    returns = np.array([traj['rewards'].sum() for traj in trajectories])
    states = np.concatenate(observations, axis=0)

    # The value is not needed here, but the lookup is kept so that a missing
    # key still raises KeyError.
    variant['ICRL_K']
    keep_fraction = variant.get('ICRL_pct_traj', 1.)

    # Only keep the top `keep_fraction` of timesteps (for %BC experiments),
    # but never fewer than one timestep.
    step_budget = max(int(keep_fraction * traj_lens.sum()), 1)

    # Trajectory indices ordered from lowest to highest return.
    ascending = np.argsort(returns)

    # The best trajectory is always kept; then add the next best ones for as
    # long as the running timestep total stays within the budget.
    num_trajectories = 1
    used_steps = traj_lens[ascending[-1]]
    for candidate in ascending[-2::-1]:
        if used_steps + traj_lens[candidate] > step_budget:
            break
        used_steps += traj_lens[candidate]
        num_trajectories += 1
    sorted_inds = ascending[-num_trajectories:]

    # Trajectory lengths act as sampling weights, so sampling is proportional
    # to timesteps rather than uniform over trajectories.
    kept_lens = traj_lens[sorted_inds]
    p_sample = kept_lens / kept_lens.sum()

    return p_sample, states, traj_lens, returns, num_trajectories, sorted_inds
