import numpy as np

class her_sampler:
    """Hindsight Experience Replay sampler with geometric future-goal sampling.

    Relabelled goals are achieved goals taken from later in the same episode.
    The distance into the future follows a truncated geometric distribution
    with ratio ``gamma`` (as in contrastive RL), instead of HER's uniform choice.
    """

    def __init__(self, replay_strategy, replay_k, env_params, env_name, reward_func=None, clipv=200, gamma=0.99):
        """Configure the sampler.

        Args:
            replay_strategy: ``'future'`` enables goal relabelling; any other
                value disables it (relabel probability 0).
            replay_k: ratio of relabelled to original goals; the relabel
                probability is ``k / (1 + k)``.
            env_params: dict; only ``'max_timesteps'`` is read by this class.
            env_name: environment name; its prefix selects how rewards are
                recomputed.
            reward_func: callable ``(achieved_goal, goal, info) -> reward``
                used for the Ant/Fetch/Maze/Point/MultiGoal environments.
            clipv: stored as ``self.clipv``; not read within this class.
            gamma: discount factor defining the geometric distribution over
                future offsets (CRL-style).
        """
        self.env_name = env_name
        self.env_params = env_params
        self.replay_strategy = replay_strategy
        self.replay_k = replay_k
        self.gamma = gamma  # discount factor for CRL-style future sampling
        if self.replay_strategy == 'future':
            self.future_p = 1 - (1. / (1 + replay_k))
        else:
            self.future_p = 0
        self.reward_func = reward_func
        self.clipv = clipv

    def _sample_future_offsets(self, num_future):
        """Draw one offset per row from a truncated geometric distribution.

        Row ``i`` receives an offset ``k`` in ``[0, num_future[i])`` with
        probability proportional to ``gamma ** k``. All rows are sampled in one
        vectorised inverse-CDF pass.

        Args:
            num_future: int array giving the number of valid future steps per
                row (expected to be >= 1).

        Returns:
            Int array of offsets with the same shape as ``num_future``.
        """
        num_future = np.asarray(num_future)
        if num_future.size == 0:
            return np.zeros(num_future.shape, dtype=int)
        weights = self.gamma ** np.arange(np.max(num_future))
        cdf = np.cumsum(weights)
        # Scale each uniform draw to the total mass of that row's truncated
        # support; searchsorted then yields k with P(k) = w_k / cdf[n - 1].
        u = np.random.uniform(size=num_future.shape) * cdf[num_future - 1]
        offsets = np.searchsorted(cdf, u, side='right')
        # Guard against float rounding landing exactly on the upper boundary.
        return np.minimum(offsets, num_future - 1)

    def sample_her_transitions(self, episode_batch, batch_size_in_transitions):
        """Sample a batch of transitions with hindsight goal relabelling.

        Args:
            episode_batch: dict of arrays indexed ``[episode, timestep, ...]``.
                ``'actions'`` has T timesteps. ``'obs'`` and ``'ag'`` are
                indexed up to timestep T, so they presumably hold T + 1
                entries (TODO confirm against the replay buffer).
            batch_size_in_transitions: number of transitions to sample.

        Returns:
            ``(transitions, her_indexes, her_indexes2, episode_idxs, t_samples,
            future_offset)``:

            - ``her_indexes`` / ``her_indexes2`` are ``np.where`` tuples for the
              relabelled / non-relabelled rows.
            - The future timestep of row ``i`` is
              ``t_samples[i] + 1 + future_offset[i]``, which lies in
              ``[t_samples[i] + 1, T]``.
        """
        T = episode_batch['actions'].shape[1]
        rollout_batch_size = episode_batch['actions'].shape[0]
        batch_size = batch_size_in_transitions

        # Select which rollouts and timesteps to use.
        episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
        t_samples = np.random.randint(T, size=batch_size)
        transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() for key in episode_batch.keys()}

        # Choose which rows receive a hindsight (relabelled) goal.
        random_index = np.random.uniform(size=batch_size)
        her_indexes = np.where(random_index < self.future_p)
        her_indexes2 = np.where(random_index >= self.future_p)

        # CRL-inspired geometric sampling of the future timestep.
        # Valid future indices are t + 1 .. T, i.e. T - t of them (always >= 1).
        future_offset = self._sample_future_offsets(T - t_samples)
        future_t_all = t_samples + 1 + future_offset
        future_t = future_t_all[her_indexes]

        # Replace the goal with the achieved goal at the sampled future step.
        future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t]
        transitions['g'][her_indexes] = future_ag

        # Future observation / achieved goal for every row, plus full-episode views.
        transitions['obsfuture'] = episode_batch['obs'][episode_idxs, future_t_all]
        transitions['sg'] = episode_batch['ag'][episode_idxs, future_t_all]
        transitions['fg'] = episode_batch['ag'][episode_idxs, -1]
        transitions['obst'] = episode_batch['obs'][episode_idxs, :]
        transitions['agt'] = episode_batch['ag'][episode_idxs, :]
        transitions['actionst'] = episode_batch['actions'][episode_idxs, :]
        # Repeat the last action so the action sequence gains one extra step.
        transitions['actionst'] = np.concatenate((transitions['actionst'], transitions['actionst'][:, -1:, :]), 1)

        # For every step 0..max_timesteps, pick one achieved goal at index
        # t + floor(u * (T - t)) (uniform, not geometric); the same picks are
        # shared by all rows. Presumably max_timesteps == T -- confirm.
        step = self.env_params['max_timesteps'] + 1
        t_samples_all = np.arange(0, step, 1)
        future_offset_all = np.random.uniform(size=step) * (T - t_samples_all)
        future_offset_all = future_offset_all.astype(int)
        transitions['herg'] = episode_batch['ag'][episode_idxs][:, future_offset_all + t_samples_all, :]
        transitions['trag'] = np.tile(episode_batch['ag'][episode_idxs][:, -1:, :], (1, step, 1))

        # Recompute rewards for the (possibly relabelled) goals.
        # NOTE(review): unrecognised env names leave 'r' as sampled from the buffer (if present).
        if self.env_name.startswith(('Ant', 'Fetch', 'Maze', 'Point')):
            transitions['r'] = np.expand_dims(self.reward_func(transitions['ag_next'], transitions['g'], None), 1)
        elif self.env_name.startswith('Sawyer'):
            transitions['r'] = -np.linalg.norm(transitions['rag_next'] - transitions['dg'], axis=1)
        elif self.env_name == 'MultiGoal':
            r = self.reward_func(transitions['ag_next'], transitions['g'], transitions['actions'])
            transitions['r'] = np.expand_dims(r, 1)

        transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) for k in transitions.keys()}
        transitions['T'] = T

        return transitions, her_indexes, her_indexes2, episode_idxs, t_samples, future_offset