from typing import Optional, TypeVar, List, Any, Sequence, Tuple
import torch

from pfrl.replay_buffers import PrioritizedReplayBuffer
from pfrl.collections.prioritized import PrioritizedBuffer, MinTreeQueue, SumTreeQueue

from utils import batch_experiences

# Type variable for the items stored in the buffer; used to annotate the
# ``value`` parameter of ``GreedyPrioritizedBuffer.append``.
T = TypeVar("T")


class GreedyPrioritizedBuffer(PrioritizedBuffer):
    """Prioritized buffer that overwrites its lowest-scored entry when full.

    In addition to the sampling priorities written to ``priority_sums`` and
    ``priority_mins``, every stored item carries a separate eviction score
    (called ``kl`` throughout) kept in ``prune_priority``.  Once the buffer
    holds ``capacity`` items, a new item replaces the item with the smallest
    eviction score, unless that smallest score is strictly greater than the
    newcomer's score, in which case the newcomer is dropped.
    """

    def __init__(
        self,
        capacity: Optional[int] = None,
        wait_priority_after_sampling: bool = True,
        initial_max_priority: float = 1.0,
    ):
        """Create an empty buffer.

        All three arguments are forwarded unchanged to the parent
        constructor.
        """
        super().__init__(
            capacity=capacity,
            wait_priority_after_sampling=wait_priority_after_sampling,
            initial_max_priority=initial_max_priority,
        )
        # Per-item eviction score; its minimum designates the slot that is
        # overwritten when the buffer is full.
        self.prune_priority = MinTreeQueue()
        # for stochastic sampling
        # self.prune_priority = SumTreeQueue()
        # self.t_tree = {}
        # Number of append() calls so far, including rejected ones.
        self.t = 0
        # Number of times find_least_priority() refused an eviction.
        self.count = 0

    def append(self, value: T, priority: Optional[float] = None, kl: Optional[float] = None) -> None:
        """Store ``value``, evicting the lowest-scored item if the buffer is full.

        Args:
            value: Item to store.
            priority: Sampling priority for the item.  Defaults to the
                current ``max_priority``.
            kl: Eviction score for the item.  When ``None`` the item is
                never stored (see the note in the body).

        The item is silently dropped when the buffer is full and the
        smallest stored eviction score is strictly greater than the
        item's own score.
        """
        index = None
        self.t += 1
        if self.capacity is not None and len(self) == self.capacity:
            # An item without a score of its own competes using the current
            # maximum sampling priority.
            threshold = self.max_priority if kl is None else kl
            # for stochastic sampling
            # index = self.sample_priority()
            index = self.find_least_priority(threshold)
            if index is None:
                # Every stored item outranks the newcomer: reject it.
                return

        if kl is None:
            # NOTE(review): items without an eviction score are discarded
            # even when there is room (and after the capacity check above
            # may already have bumped ``count``).  Behaviour kept as-is;
            # confirm this is intended.
            return

        if priority is None:
            # Append with the highest priority
            priority = self.max_priority

        if index is None:
            # Buffer not full yet: grow every parallel structure by one.
            self.data.append(value)
            self.priority_sums.append(priority)
            self.priority_mins.append(priority)
            self.prune_priority.append(kl)
        else:
            # Overwrite the evicted slot in place.
            self.data[index] = value
            self.priority_sums[index] = priority
            self.priority_mins[index] = priority
            self.prune_priority[index] = kl

    def find_least_priority(self, priority: float) -> Optional[int]:
        """Return the index of the item with the smallest eviction score.

        Args:
            priority: Score of the candidate that wants to take the slot.

        Returns:
            The index of a minimum-score item, or ``None`` (after
            incrementing ``count``) when the minimum stored score is
            strictly greater than ``priority``.  A tie therefore still
            yields an index.
        """
        min_priority = self.prune_priority.min()
        if min_priority > priority:
            self.count += 1
            return None

        index_left, index_right = self.prune_priority.bounds
        return _find_index(index_left, index_right, self.prune_priority.root, min_priority)

    def set_last_priority(self, priority: Sequence[float], onpolicy_action: Sequence[bool], kls: Optional[Sequence[float]]=None) -> None:
        """Update the scores of the most recently sampled items.

        Args:
            priority: New sampling priorities, passed on to the parent's
                ``set_last_priority``.
            onpolicy_action: One flag per sampled item.  Each eviction score
                is multiplied by its flag, so an item whose flag is false
                gets an eviction score of 0.
            kls: New eviction scores, one per sampled item.  Required.

        Raises:
            ValueError: If ``kls`` is ``None``.
        """
        if kls is None:
            raise ValueError(
                "kls must be provided to update the eviction scores of the sampled items"
            )

        # zip() stops at the shortest input; the three sequences are
        # presumably of equal length (verify against the caller).
        for i, kl, oa in zip(self.sampled_indices, kls, onpolicy_action):
            self.prune_priority[i] = kl * oa

        super().set_last_priority(priority)

    def sample_priority(self):
        """Draw one entry from ``prune_priority`` via ``prioritized_sample``.

        Belongs to the stochastic-sampling variant that is commented out in
        ``__init__`` and ``append``; nothing in this class calls it.

        NOTE(review): ``prioritized_sample`` is presumably only offered by
        ``SumTreeQueue``, so this likely fails while ``prune_priority`` is a
        ``MinTreeQueue`` -- confirm before enabling.

        Returns:
            Element ``[0][0]`` of the sampling result (presumably the
            sampled index).
        """
        index = self.prune_priority.prioritized_sample(1, remove=True)[0][0]
        return index


def _find_index(index_left: int, index_right: int, node: List[Any], value: float) -> int:
    if index_right - index_left == 1:
        return index_left
    else:
        node_left, node_right, _ = node
        index_center = (index_left + index_right) // 2
        if node_left:
            left_value = node_left[2]
        else:
            left_value = None

        if left_value is not None and value == left_value:
            return _find_index(index_left, index_center, node_left, value)
        else:
            return _find_index(index_center, index_right, node_right, value)


class GreedyReplayBuffer(PrioritizedReplayBuffer):
    """Prioritized replay buffer backed by a ``GreedyPrioritizedBuffer``.

    Every stored n-step window is scored with the agent at insertion time;
    that score is passed to the underlying memory as the window's ``kl``
    value.
    """

    def __init__(
        self,
        capacity=None,
        alpha=0.6,
        beta0=0.4,
        betasteps=2e5,
        eps=0.01,
        normalize_by_max=True,
        error_min=0,
        error_max=1,
        num_steps=1,
    ):
        """Forward every argument to the parent, then swap in the greedy memory."""
        super().__init__(
            capacity=capacity,
            alpha=alpha,
            beta0=beta0,
            betasteps=betasteps,
            eps=eps,
            normalize_by_max=normalize_by_max,
            error_min=error_min,
            error_max=error_max,
            num_steps=num_steps,
        )
        # Replace the parent's storage with the greedy-eviction variant.
        self.memory = GreedyPrioritizedBuffer(capacity=capacity)
        # Counters initialised here; not updated anywhere in this class.
        self.up_count = 0
        self.low_count = 0

    def append(
        self,
        state,
        action,
        reward,
        next_state=None,
        next_action=None,
        is_state_terminal=False,
        env_id=0,
        agent=None,
        **kwargs
    ):
        """Record one transition for environment ``env_id``.

        The transition is added to that environment's window of recent
        transitions.  On a terminal transition the window is stored and
        shortened from the front repeatedly until it is empty; otherwise
        the window is stored only once it holds ``num_steps`` transitions.

        ``agent`` is needed whenever a window is actually stored, because
        it is used to score the window.
        """
        window = self.last_n_transitions[env_id]
        window.append(
            {
                "state": state,
                "action": action,
                "reward": reward,
                "next_state": next_state,
                "next_action": next_action,
                "is_state_terminal": is_state_terminal,
                **kwargs,
            }
        )

        if not is_state_terminal:
            if len(window) == self.num_steps:
                self._store_window(agent, window)
            return

        # Episode finished: flush every remaining suffix of the window.
        while window:
            self._store_window(agent, window)
            del window[0]
        assert len(window) == 0

    def _store_window(self, agent, window):
        """Score a snapshot of ``window`` and append it to the memory."""
        priority, kl = self.calculate_priority(agent, list(window))
        self.memory.append(list(window), priority, kl)

    def calculate_priority(self, agent, experiences):
        """Score a single window of experiences with ``agent``.

        Returns:
            A ``(priority, kl)`` pair.  ``priority`` is always ``None``.
            ``kl`` is ``sum(-t * log(clamp(y, 1e-10, 1.0)))`` over
            dimension 1 for the first (only) batch element, i.e. it is
            computed as a cross-entropy of ``y`` against ``t``.
        """
        assert agent is not None
        batch = batch_experiences(
            [experiences],
            device=agent.device,
            gamma=agent.gamma,
        )
        with torch.no_grad():
            y, t = agent._compute_y_and_t(batch)
            # Clamp before the log so zero entries do not produce -inf.
            log_y = torch.log(torch.clamp(y, 1e-10, 1.0))
            per_sample = (-t * log_y).sum(dim=1)
            kl = per_sample.detach().cpu().numpy()[0]
        return None, kl

    def update_errors(self, errors, onpolicy_action, kls=None):
        """Convert ``errors`` to priorities and update the last sampled items."""
        priorities = self.priority_from_errors(errors)
        self.memory.set_last_priority(priorities, onpolicy_action=onpolicy_action, kls=kls)