import numpy as np
import random
from collections import namedtuple
import torch
Experience = namedtuple(
    'Experience',
    ['history_dict', 'action', 'reward', 'next_history_dict', 'goal', 'done'],
)


class PrioritizedReplayBuffer:
    """Prioritized replay buffer with hindsight goal relabelling.

    Transitions are kept in a fixed-size circular list; a parallel NumPy array
    holds one sampling priority per slot.  Whole episodes can be staged with
    ``add_to_episode`` and flushed with ``process_episode``, which stores each
    step once with its original goal and again with goals taken from outputs
    that were actually reached later in the same episode.
    """

    # Floor applied to every stored priority.  A priority of exactly zero (or
    # a negative one) would yield zero/NaN sampling probabilities and make
    # ``np.random.choice(..., replace=False, p=...)`` raise.
    _MIN_PRIORITY = 1e-6

    def __init__(self, capacity, k_future=4, alpha=0.6, beta=0.4,
                 beta_increment=1e-5, reward_threshold=5e-3,
                 reward_mode='combined', num_final_goals=4):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of stored transitions (must be > 0).
            k_future: Maximum number of randomly chosen future outputs used as
                relabelled goals for each step in ``process_episode``.
            alpha: Exponent applied to priorities before normalising them into
                sampling probabilities; 0 gives uniform sampling.
            beta: Importance-sampling exponent; 0 disables the correction.
            beta_increment: Amount added to ``beta`` on every ``sample`` call
                (``beta`` is capped at 1.0).
            reward_threshold: Distance below which a goal counts as reached.
            reward_mode: One of ``'binary'``, ``'distance'``, ``'exp'``,
                ``'progress'``, ``'step'``, ``'combined'``.  Unknown values
                behave like ``'binary'``.
            num_final_goals: Maximum number of outputs from the end of the
                episode that are used as relabelled goals for every step.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity!r}")

        self.capacity = capacity
        self.k_future = k_future
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.reward_threshold = reward_threshold
        self.reward_mode = reward_mode
        self.num_final_goals = num_final_goals

        # Initialises buffer, priorities, position, size and episode_buffer.
        self.clear()

    def _sanitize_priority(self, priority):
        """Return ``priority`` as a finite, strictly positive float.

        Raises:
            ValueError: If ``priority`` is NaN or infinite.
        """
        value = abs(float(priority))
        if not np.isfinite(value):
            raise ValueError(f"priority must be finite, got {priority!r}")
        return max(value, self._MIN_PRIORITY)

    def add(self, history_dict, action, reward, next_history_dict, goal, done=False, priority=None):
        """Store a single transition, overwriting the oldest one when full.

        Args:
            history_dict: Observation history before the action.
            action: Action taken.
            reward: Reward assigned to the transition.
            next_history_dict: Observation history after the action.
            goal: Goal the reward was computed against.
            done: Whether the goal was reached.
            priority: Sampling priority.  When ``None`` the largest priority
                currently stored is used (1.0 for an empty buffer).
        """
        if priority is None:
            # Vectorised max over the filled slots; the builtin max() would
            # iterate all ``capacity`` entries in Python on every call.
            priority = float(self.priorities[:self.size].max()) if self.size > 0 else 1.0

        experience = Experience(history_dict, action, reward, next_history_dict, goal, done)
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            self.buffer[self.position] = experience

        self.priorities[self.position] = self._sanitize_priority(priority)
        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def add_to_episode(self, history_dict, action, next_history_dict, goal):
        """Stage a transition in the current episode (no reward yet)."""
        self.episode_buffer.append((history_dict, action, next_history_dict, goal))

    def _compute_reward(self, distance, done, prev_distance=None):
        """Compute the reward for one transition according to ``reward_mode``.

        Args:
            distance: Distance between the achieved output and the goal.
            done: Whether the goal counts as reached.
            prev_distance: Distance to the same goal one step earlier; only
                used by the ``'progress'`` mode.  ``None`` makes that mode
                fall back to ``-distance``.
        """
        mode = self.reward_mode

        if mode == 'distance':
            return -distance

        if mode == 'exp':
            return np.exp(-10 * distance) - 1

        if mode == 'progress':
            if prev_distance is None:
                return -distance
            progress = prev_distance - distance
            # Moving away from the goal is penalised harder than moving
            # towards it is rewarded.
            return progress * 2.0 if progress > 0 else progress * 3.0

        if mode == 'step':
            if done:
                return 0.0
            if distance < self.reward_threshold * 3:
                return -0.3
            if distance < self.reward_threshold * 10:
                return -0.7
            return -1.0

        if mode == 'combined':
            return -distance * 2.0 + (10.0 if done else 0.0)

        # 'binary' and any unrecognised mode.
        return 0.0 if done else -1.0

    def _store_with_goal(self, transition, achieved, prev_achieved, goal):
        """Score one transition against ``goal`` and add it to the buffer."""
        history_dict, action, next_history_dict = transition

        distance = float(np.linalg.norm(achieved - goal))
        done = distance < self.reward_threshold
        prev_distance = None
        if prev_achieved is not None:
            prev_distance = float(np.linalg.norm(prev_achieved - goal))

        reward = self._compute_reward(distance, done, prev_distance)
        priority = 1.0 if done else distance
        self.add(history_dict, action, reward, next_history_dict, goal, done, priority)

    def process_episode(self):
        """Flush the staged episode into the buffer with goal relabelling.

        Every step is stored once with its original goal, once for each of the
        (up to ``num_final_goals``) outputs reached at the end of the episode,
        and once for each of up to ``k_future`` randomly chosen outputs reached
        later in the episode.  The staged episode is cleared afterwards.

        For the ``'progress'`` reward mode the previous distance is measured
        from the output reached by the preceding step to the goal being
        scored, so it is available for every goal from the second step on.

        Note: relabelled goals share the same array object across the
        experiences that use them; treat stored goals as read-only.
        """
        episode = self.episode_buffer
        if not episode:
            return

        n_steps = len(episode)
        # Extract every achieved output once, before anything is stored, so a
        # malformed history dict cannot leave a half-written episode behind.
        achieved = [self._get_current_output(step[2]) for step in episode]

        # Last outputs of the episode, newest first; the first step's output
        # is never used here, hence the ``n_steps - 1`` cap.
        n_final = min(self.num_final_goals, n_steps - 1)
        final_goals = [achieved[n_steps - 1 - k] for k in range(n_final)]

        for t, (history_dict, action, next_history_dict, goal) in enumerate(episode):
            transition = (history_dict, action, next_history_dict)
            current = achieved[t]
            previous = achieved[t - 1] if t > 0 else None

            self._store_with_goal(transition, current, previous, goal)

            for final_goal in final_goals:
                self._store_with_goal(transition, current, previous, final_goal)

            future_range = range(t + 1, n_steps)
            if future_range:
                future_idxs = random.sample(future_range, min(self.k_future, len(future_range)))
            else:
                future_idxs = []
            for future_idx in future_idxs:
                self._store_with_goal(transition, current, previous, achieved[future_idx])

        self.episode_buffer = []

    def _get_current_output(self, history_dict):
        """Return a copy of the latest output stored in ``history_dict``.

        Reads ``history_dict['outputs'][0, -1]``, falling back to
        ``history_dict['prev_outputs'][0, -1]``.

        Raises:
            ValueError: If neither key is present.
        """
        for key in ('outputs', 'prev_outputs'):
            if key in history_dict:
                # np.array(copy=True) detaches the result from the caller's
                # storage and accepts any array-like NumPy can convert.
                return np.array(history_dict[key][0, -1], copy=True)
        raise ValueError("历史字典中未找到输出特征")

    def sample(self, batch_size):
        """Draw a prioritized batch without replacement.

        ``beta`` is advanced by ``beta_increment`` (capped at 1.0) on every
        successful call.

        Returns:
            ``None`` when ``batch_size`` is not positive or exceeds the number
            of stored transitions; otherwise ``(batch, indices, weights)``
            where ``batch`` is a list of ``Experience``, ``indices`` are their
            slots and ``weights`` are float32 importance-sampling weights
            normalised so the largest is 1.
        """
        if batch_size <= 0 or self.size < batch_size:
            return None

        self.beta = min(1.0, self.beta + self.beta_increment)

        # Work in float64 and re-apply the floor so that priorities written
        # directly into ``self.priorities`` cannot produce zero/NaN
        # probabilities.
        raw = np.abs(self.priorities[:self.size]).astype(np.float64)
        scaled = np.maximum(raw, self._MIN_PRIORITY) ** self.alpha
        probs = scaled / scaled.sum()

        indices = np.random.choice(self.size, batch_size, replace=False, p=probs)
        weights = (self.size * probs[indices]) ** (-self.beta)
        weights /= weights.max()
        batch = [self.buffer[idx] for idx in indices]

        return batch, indices, weights.astype(np.float32)

    def update_priorities(self, indices, priorities):
        """Overwrite the priorities of the given slots.

        Each value is stored as ``max(abs(priority), eps)`` so that signed or
        zero TD errors remain valid sampling priorities.

        Raises:
            ValueError: If a priority is NaN or infinite.
        """
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = self._sanitize_priority(priority)

    def __len__(self):
        """Return the number of stored transitions."""
        return self.size

    def clear(self):
        """Drop all stored transitions and any staged episode."""
        self.buffer = []
        self.priorities = np.zeros((self.capacity,), dtype=np.float32)
        self.position = 0
        self.size = 0
        self.episode_buffer = []