# prioritized_replay.py
"""
Prioritized Experience Replay (PER) implementation.
Adapted and modified from:
https://github.com/pythonlessons/Reinforcement_Learning/tree/master/05_CartPole-reinforcement-learning_PER_D3QN
Code related to the paper: Prioritized Experience Replay
"""

import numpy as np
from .rl_utils import (MAX_SIZE_MEMORY, ABS_ERROR_UPPER, ALPHA, BETA_INCREMENT, EPSILON, MINIBATCH_SIZE, logging)

class PER:
    """Prioritized Experience Replay memory backed by an array-based sum tree.

    ``self.sumtree`` has ``2 * capacity - 1`` entries. The last ``capacity``
    entries are the leaves and hold one priority per stored transition; every
    other entry holds the sum of its two children, so ``self.sumtree[0]`` is
    the total priority. Transitions are kept in ``self.data`` and, once the
    memory is full, are overwritten in insertion order.
    """

    def __init__(self, capacity: int = MAX_SIZE_MEMORY, absolute_error_upper: float = ABS_ERROR_UPPER,
                 beta: float = 0.4) -> None:
        """Initialize the SumTree-based PER memory.

        Args:
            capacity: Maximum number of transitions kept in memory.
            absolute_error_upper: Upper clip applied to absolute errors in
                ``batch_update``.
            beta: Initial importance-sampling exponent; ``sample`` increases
                it by ``BETA_INCREMENT`` per call, up to 1.0.

        Raises:
            ValueError: If ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            logging.error("Invalid PER capacity: %s", capacity)
            raise ValueError("capacity must be a positive integer")

        self.capacity = capacity
        self.absolute_error_upper = absolute_error_upper
        self.write = 0  # slot of self.data that the next add() will (over)write
        self.n_entries = 0  # number of filled slots, capped at capacity
        self.beta = beta
        self.sumtree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        logging.info("PER initialized with capacity %d", capacity)

    def add(self, transition: tuple, priority: float) -> None:
        """Add a transition to the memory with given priority.

        Once the memory is full the oldest slot is overwritten.

        Args:
            transition: The transition to store; must be a tuple.
            priority: Finite, non-negative priority for the transition.

        Raises:
            ValueError: If ``transition`` is not a tuple or ``priority`` is
                negative or not finite.
        """
        if not isinstance(transition, tuple):
            logging.error("Transition must be a tuple, got %s", type(transition))
            raise ValueError("Transition must be a tuple")

        slot = self.write
        # Write the tree first: _update validates the priority, so a rejected
        # priority leaves the stored transitions untouched.
        self._update(slot + self.capacity - 1, priority)
        self.data[slot] = transition

        self.write = (slot + 1) % self.capacity
        self.n_entries = min(self.n_entries + 1, self.capacity)

    def _update(self, sumtree_idx: int, priority: float) -> None:
        """Update the SumTree with a new priority value.

        Args:
            sumtree_idx: Index of a leaf in ``self.sumtree``, i.e. a value in
                ``[capacity - 1, 2 * capacity - 2]``.
            priority: Finite, non-negative priority to store in that leaf.

        Raises:
            IndexError: If ``sumtree_idx`` is not a leaf index. Negative
                indices are rejected too: they would never reach the root in
                the upward walk below.
            ValueError: If ``priority`` is negative or not finite.
        """
        first_leaf = self.capacity - 1
        if not first_leaf <= sumtree_idx < len(self.sumtree):
            logging.error("Invalid SumTree index: %d", sumtree_idx)
            raise IndexError("SumTree index out of bounds")
        if not np.isfinite(priority) or priority < 0:
            logging.error("Invalid priority: %s", priority)
            raise ValueError("Priority must be a finite, non-negative number")

        tree = self.sumtree
        node = int(sumtree_idx)  # plain int: avoids fixed-width overflow in 2 * node + 1
        tree[node] = priority

        # Recompute every ancestor from its two children instead of adding a
        # delta, so rounding errors cannot accumulate across updates and a
        # parent is positive exactly when one of its children is.
        while node > 0:
            node = (node - 1) // 2
            tree[node] = tree[2 * node + 1] + tree[2 * node + 2]

    def get_leaf(self, value: float) -> tuple:
        """Retrieve a leaf node from the SumTree based on a sampled value.

        Walks down from the root: go left when ``value`` falls within the left
        child's sum, otherwise subtract that sum and go right. A child whose
        sum is zero is never entered while its sibling has mass, so boundary
        values (``value <= 0`` or a small overshoot of the total caused by
        rounding) cannot end on an empty slot.

        Args:
            value: Sampled position, normally in ``[0, self.sumtree[0]]``.

        Returns:
            Tuple ``(leaf_idx, priority, transition)`` where ``leaf_idx``
            indexes ``self.sumtree``.

        Raises:
            IndexError: If the selected leaf does not map to a valid slot of
                ``self.data``.
        """
        tree = self.sumtree
        n_nodes = len(tree)
        node = 0
        while True:
            left = 2 * node + 1
            if left >= n_nodes:
                break  # node has no children: it is a leaf
            right = left + 1  # always in range because n_nodes is odd
            left_sum = tree[left]
            right_sum = tree[right]

            if right_sum <= 0 or (left_sum > 0 and value <= left_sum):
                node = left
            else:
                value -= left_sum
                node = right

        leaf_idx = node
        transition_idx = leaf_idx - self.capacity + 1
        if transition_idx < 0 or transition_idx >= self.capacity:
            logging.error("Invalid transition index: %d", transition_idx)
            raise IndexError("Transition index out of bounds")

        return leaf_idx, tree[leaf_idx], self.data[transition_idx]

    def sample(self, n: int = MINIBATCH_SIZE) -> tuple:
        """Sample a minibatch of transitions based on priorities.

        The total priority is split into ``n`` equal segments and one value is
        drawn uniformly from each segment. ``self.beta`` is increased by
        ``BETA_INCREMENT`` (capped at 1.0) on every call.

        Args:
            n: Requested batch size; reduced to the number of stored
                transitions (with a warning) when fewer are available.

        Returns:
            Tuple ``(b_idx, minibatch, ISWeights)``: int32 array of leaf
            indices into ``self.sumtree``, list of transitions, and float32
            array of importance-sampling weights
            ``(n_entries * P(i)) ** -beta`` divided by the largest weight in
            the batch. All three are empty when nothing can be sampled.
        """
        if self.n_entries < n:
            logging.warning("Not enough entries to sample %d transitions, only %d available", n, self.n_entries)
            n = self.n_entries

        b_idx = np.empty(n, dtype=np.int32)
        ISWeights = np.empty(n, dtype=np.float32)
        minibatch = []

        self.beta = float(min(1.0, self.beta + BETA_INCREMENT))

        if n == 0:
            # Nothing to draw; also avoids dividing the total priority by zero.
            return b_idx, minibatch, ISWeights

        total = self.sumtree[0]
        if not total > 0:
            # Every stored priority is zero, so P(i) is undefined. Sample the
            # filled slots uniformly with unit weights instead of returning
            # NaN weights.
            logging.warning("Total priority is zero, falling back to uniform sampling")
            slots = np.random.randint(0, self.n_entries, size=n)
            b_idx[:] = slots + self.capacity - 1
            minibatch = [self.data[slot] for slot in slots]
            ISWeights.fill(1.0)
            return b_idx, minibatch, ISWeights

        # Segment i spans [bounds[i], bounds[i + 1]).
        bounds = (total / n) * np.arange(n + 1)
        values = np.random.uniform(bounds[:-1], bounds[1:])

        priorities = np.empty(n, dtype=np.float64)
        for i, value in enumerate(values):
            leaf_idx, priority, transition = self.get_leaf(value)
            b_idx[i] = leaf_idx
            priorities[i] = priority
            minibatch.append(transition)

        # Normalise in float64 before storing as float32 so that very large
        # raw weights cannot overflow the narrower type.
        weights = np.power(self.n_entries * (priorities / total), -self.beta)
        ISWeights[:] = weights / weights.max()
        return b_idx, minibatch, ISWeights

    def batch_update(self, sumtree_idx: np.ndarray, abs_errors: np.ndarray) -> None:
        """Update priorities for a batch of transitions.

        Each absolute error has ``EPSILON`` added, is clipped to
        ``[0, absolute_error_upper]`` and raised to the power ``ALPHA`` to
        obtain the new priority.

        Args:
            sumtree_idx: Leaf indices, as returned in ``b_idx`` by ``sample``.
            abs_errors: Absolute errors, one per index.

        Raises:
            ValueError: If the two inputs differ in length.
            IndexError: If an index is not a leaf index of the tree.
        """
        leaf_indices = np.ravel(sumtree_idx)
        errors = np.ravel(np.asarray(abs_errors, dtype=np.float64))
        if leaf_indices.size != errors.size:
            logging.error("Got %d indices but %d errors", leaf_indices.size, errors.size)
            raise ValueError("sumtree_idx and abs_errors must have the same length")

        clipped = np.clip(errors + EPSILON, 0, self.absolute_error_upper)
        ps = np.power(clipped, ALPHA)

        for ti, p in zip(leaf_indices, ps):
            self._update(ti, p)

    def get_max(self) -> float:
        """Get the maximum priority in the SumTree.

        Returns:
            The largest leaf priority among the filled slots, or 0.0 when the
            memory is empty.
        """
        if self.n_entries == 0:
            logging.debug("SumTree is empty, returning default priority 0.0")
            return 0.0
        # Leaves are the last `capacity` entries; only the first n_entries of
        # them have been written.
        leaves = self.sumtree[self.capacity - 1:]
        return float(np.max(leaves[:self.n_entries]))

    def get_beta(self) -> float:
        """Get the current beta value."""
        return self.beta
