import numpy as np
import random
from collections import namedtuple
import torch

# Experience: one stored transition. ``goal`` is the target that ``reward`` and
# ``done`` were computed against (the original goal or a hindsight-relabeled one).
Experience = namedtuple('Experience', ['history_dict', 'action', 'reward', 'next_history_dict', 'goal', 'done'])


class PrioritizedReplayBuffer:
    """Prioritized experience replay buffer with hindsight goal relabeling.

    Transitions live in a fixed-capacity ring buffer (``self.buffer``) with a
    parallel float32 priority array (``self.priorities``). :meth:`sample` draws
    slots with probability proportional to ``priority ** alpha`` and returns
    importance-sampling weights ``(N * P(i)) ** -beta`` scaled by their maximum.

    Whole episodes can be staged with :meth:`add_to_episode` and committed with
    :meth:`process_episode`, which stores every step against its original goal
    and against substitute goals taken from the episode's own achieved outputs.
    """

    def __init__(self, capacity, k_future=4, alpha=0.6, beta=0.4, beta_increment=1e-5,
                 reward_threshold=5e-3, reward_mode='combined', priority_eps=1e-6):
        """
        Args:
            capacity: maximum number of stored transitions; once full, the
                oldest slot is overwritten.
            k_future: number of randomly chosen later steps used as substitute
                goals for each step in process_episode.
            alpha: priority exponent; 0 means uniform sampling.
            beta: initial importance-sampling exponent; 0 disables the correction.
            beta_increment: amount added to beta (capped at 1.0) on every sample().
            reward_threshold: distance below which a goal counts as reached.
            reward_mode: 'binary', 'distance', 'exp', 'progress', 'step' or
                'combined'; any other value behaves like 'binary'.
            priority_eps: small positive floor keeping every priority strictly
                positive so sampling never sees zero or NaN probabilities.
        """
        self.capacity = capacity
        self.k_future = k_future
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.reward_threshold = reward_threshold
        self.reward_mode = reward_mode
        self.priority_eps = priority_eps

        self.buffer = []
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        self.position = 0  # next slot to write (wraps modulo capacity)
        self.size = 0

        # Transitions of the episode currently being collected.
        self.episode_buffer = []

    def add(self, history_dict, action, reward, next_history_dict, goal, done=False, priority=None):
        """Store one transition, overwriting the oldest slot once the buffer is full.

        If ``priority`` is None the transition receives the current maximum
        stored priority (1.0 for an empty buffer).
        """
        if priority is None:
            # Vectorized max over the filled prefix; the builtin max() walked the
            # whole capacity-sized array element by element in Python.
            priority = float(self.priorities[:self.size].max()) if self.size > 0 else 1.0

        experience = Experience(history_dict, action, reward, next_history_dict, goal, done)
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            self.buffer[self.position] = experience

        self.priorities[self.position] = priority
        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def add_to_episode(self, history_dict, action, next_history_dict, goal):
        """Stage one transition of the current episode (committed by process_episode)."""
        self.episode_buffer.append((history_dict, action, next_history_dict, goal))

    def _compute_reward(self, distance, done, prev_distance=None):
        """Map a goal distance to a scalar reward according to ``self.reward_mode``.

        Modes:
            'binary'   -- 0 when done, else -1.
            'distance' -- negative distance.
            'exp'      -- exp(-10 * distance) - 1.
            'progress' -- 2x the distance reduction since the previous step, or
                          3x the (non-positive) change when it did not shrink;
                          -distance when no previous distance is known.
            'step'     -- 0 when done; otherwise -0.3 / -0.7 / -1 for distance
                          below 3x / 10x reward_threshold / anything larger.
            'combined' -- -2 * distance plus a +10 bonus when done.
        Unknown modes fall back to the 'binary' rule.
        """
        if self.reward_mode == 'binary':
            return 0.0 if done else -1.0

        elif self.reward_mode == 'distance':
            return -distance

        elif self.reward_mode == 'exp':
            return np.exp(-10 * distance) - 1

        elif self.reward_mode == 'progress':
            if prev_distance is None:
                return -distance
            progress = prev_distance - distance
            # Regression is penalized more strongly than progress is rewarded.
            if progress > 0:
                return progress * 2.0
            else:
                return progress * 3.0

        elif self.reward_mode == 'step':
            if done:
                return 0.0
            elif distance < self.reward_threshold * 3:
                return -0.3
            elif distance < self.reward_threshold * 10:
                return -0.7
            else:
                return -1.0

        elif self.reward_mode == 'combined':
            distance_reward = -distance * 2.0
            goal_reward = 10.0 if done else 0.0
            return distance_reward + goal_reward

        else:
            return 0.0 if done else -1.0

    def _store_with_goal(self, history_dict, action, next_history_dict, current_output,
                         goal, prev_distances, step_distances):
        """Evaluate one transition against ``goal``, store it, and record its distance.

        ``prev_distances`` maps goal keys to distances from the previous step
        (used by the 'progress' reward); ``step_distances`` is filled for the
        current step.
        """
        goal_key = tuple(goal.flatten())
        distance = np.linalg.norm(current_output - goal)
        done = distance < self.reward_threshold
        reward = self._compute_reward(distance, done, prev_distances.get(goal_key))
        step_distances[goal_key] = distance

        # Reached goals get priority 1.0; otherwise the priority is the remaining
        # distance, so transitions farther from their goal are sampled more often.
        priority = 1.0 if done else distance
        self.add(history_dict, action, reward, next_history_dict, goal, done, priority)

    def process_episode(self):
        """Commit the staged episode to the buffer with hindsight goal relabeling.

        For every step t this stores:
          * the transition against its original goal;
          * one copy per "final" goal -- the achieved outputs of the last up to
            four steps, excluding step 0 (so a one-step episode gets none);
          * up to ``k_future`` copies whose goal is the achieved output of a
            randomly chosen later step.
        The staged episode is cleared afterwards.
        """
        episode = self.episode_buffer
        if not episode:
            return

        n_steps = len(episode)
        # Achieved outputs of the last steps, newest first: indices n_steps-1
        # down to max(1, n_steps-4).
        final_states = [
            self._get_current_output(episode[idx][2])
            for idx in range(n_steps - 1, max(0, n_steps - 5), -1)
        ]

        # Goal key -> distance for the previous step only; the 'progress'
        # reward never looks further back, so older steps are not retained.
        prev_distances = {}
        for t, (history_dict, action, next_history_dict, goal) in enumerate(episode):
            current_output = self._get_current_output(next_history_dict)
            step_distances = {}

            # Original goal.
            self._store_with_goal(history_dict, action, next_history_dict, current_output,
                                  goal, prev_distances, step_distances)

            # Hindsight: end-of-episode achieved outputs as substitute goals.
            for final_state in final_states:
                self._store_with_goal(history_dict, action, next_history_dict, current_output,
                                      final_state, prev_distances, step_distances)

            # Hindsight: achieved outputs of randomly chosen later steps.
            future_range = range(t + 1, n_steps)
            future_idxs = (random.sample(future_range, min(self.k_future, len(future_range)))
                           if future_range else [])
            for future_idx in future_idxs:
                future_goal = self._get_current_output(episode[future_idx][2])
                self._store_with_goal(history_dict, action, next_history_dict, current_output,
                                      future_goal, prev_distances, step_distances)

            prev_distances = step_distances

        self.episode_buffer = []

    def _get_current_output(self, history_dict):
        """Return a copy of the latest output stored in ``history_dict``.

        Looks up 'outputs' first, then 'prev_outputs', and takes element
        ``[0, -1]`` (presumably batch 0, last time step -- TODO confirm layout).

        Raises:
            ValueError: if neither key is present.
        """
        if 'outputs' in history_dict:
            return history_dict['outputs'][0, -1].copy()
        elif 'prev_outputs' in history_dict:
            return history_dict['prev_outputs'][0, -1].copy()
        else:
            raise ValueError("Output feature not found in historical dictionary")

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions according to their priorities.

        Returns:
            ``(batch, indices, weights)``: a list of Experience, the buffer slots
            they came from (pass these to update_priorities), and importance-
            sampling weights scaled so the largest is 1.0. Returns None when
            fewer than ``batch_size`` transitions are stored.
        """
        if self.size < batch_size:
            return None

        # Anneal beta toward 1.0 (full importance-sampling correction).
        self.beta = min(1.0, self.beta + self.beta_increment)

        # Floor at priority_eps: zero entries could otherwise leave fewer
        # non-zero probabilities than batch_size (rejected by np.random.choice
        # with replace=False) or an all-zero sum (NaN probabilities).
        scaled = np.maximum(self.priorities[:self.size], self.priority_eps) ** self.alpha
        probs = scaled / np.sum(scaled)

        indices = np.random.choice(self.size, batch_size, replace=False, p=probs)
        weights = (self.size * probs[indices]) ** (-self.beta)
        weights /= weights.max()

        batch = [self.buffer[idx] for idx in indices]
        return batch, indices, weights

    def update_priorities(self, indices, priorities):
        """Set new priorities (for example, TD-error magnitudes) for sampled slots.

        Each value is stored as ``|priority| + priority_eps``. A negative value
        would otherwise become NaN under ``** alpha``, and zero would make the
        transition unsampleable.
        """
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = abs(float(priority)) + self.priority_eps

    def __len__(self):
        return self.size

    def clear(self):
        """Drop all stored and staged transitions and reset priorities."""
        self.buffer = []
        self.priorities = np.zeros((self.capacity,), dtype=np.float32)
        self.position = 0
        self.size = 0
        self.episode_buffer = []