import random
from typing import List, Tuple

from algorithms.abstract.replay_buffer import ReplayBuffer
from algorithms.mu_zero.game_history import MuZeroGameHistory
from algorithms.utils.types import GameSampleData, StateFeature, TrainingTarget, ActionImage, TrainingBatch


class MuZeroReplayBuffer(ReplayBuffer):
    """Fixed-capacity buffer of game histories with FIFO replacement.

    Whole games (`MuZeroGameHistory` objects) are stored, not individual
    transitions.  Sampling first picks a stored game uniformly at random and
    then delegates to that game's `sample_data` to pick the position.

    The underlying data structure is a list used as a ring buffer, allowing
    O(1) adding and sampling of a game.

    NOTE(review): `_replay_buffer_capacity` and `_next_entry_index` are not
    set in this class; they are assumed to be initialised by
    `ReplayBuffer.__init__` -- confirm against the base class.
    """

    def __init__(self, replay_buffer_capacity):
        """Creates an empty buffer.

        Args:
          replay_buffer_capacity: maximum number of game histories kept;
            forwarded unchanged to the `ReplayBuffer` base class.
        """
        super().__init__(replay_buffer_capacity)
        # Running total of `len(game)` over all games currently stored.
        self.num_actions = 0
        self._data: List[MuZeroGameHistory] = []

    def add(self, game_history: MuZeroGameHistory) -> None:
        """Adds `game_history` to the buffer.

        If the buffer is full, the entry at `_next_entry_index` is replaced
        and the index advances cyclically, so entries are overwritten in
        insertion order.

        Args:
          game_history: data to be added to the buffer.
        """
        self.num_actions += len(game_history)
        if len(self._data) < self._replay_buffer_capacity:
            self._data.append(game_history)
            return

        # Buffer is full: the evicted game no longer counts towards the total.
        slot = self._next_entry_index
        self.num_actions -= len(self._data[slot])
        self._data[slot] = game_history
        self._next_entry_index = (slot + 1) % self._replay_buffer_capacity

    def sample_game_data_single(self, k: int) -> GameSampleData:
        """Returns one training sample drawn from a uniformly chosen game.

        Args:
          k: `int`, the number of moves to unroll; passed through to
            `MuZeroGameHistory.sample_data`.

        Returns:
          Whatever `sample_data(k)` of the selected game history returns.

        Raises:
          ValueError: If the buffer is empty.
        """
        if not self._data:
            raise ValueError("Cannot sample from an empty replay buffer")
        game_history = random.choice(self._data)
        return game_history.sample_data(k)

    def __len__(self) -> int:
        """Returns the number of game histories currently stored."""
        return len(self._data)

    def __iter__(self):
        """Iterates over the stored game histories in storage order."""
        return iter(self._data)

    def sample(self, batch_size: int = 256, k: int = 12) -> TrainingBatch:
        """Draws `batch_size` samples and regroups them for training.

        Each sample is drawn independently with `sample_game_data_single`, so
        the same game may contribute more than once.  The per-sample fields
        are then transposed: the outer axis of every returned sequence except
        the root ones is the position within a sample, the inner axis is the
        batch.

        Args:
          batch_size: number of samples to draw; must be at least 1.
          k: number of moves to unroll per sample.

        Returns:
          A 6-tuple `(root_state_features, root_targets, root_state_strings,
          action_images, recurrent_targets, recurrent_state_strings)`.  The
          root entries hold one item per sample; the remaining three are
          lists with one batch-sized tuple per position.

        Raises:
          ValueError: If `batch_size` is smaller than 1 or the buffer is
            empty.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1, got {}".format(batch_size))

        game_data: List[GameSampleData] = [
            self.sample_game_data_single(k) for _ in range(batch_size)
        ]

        # Split the list of per-sample 4-tuples into four per-field tuples.
        root_state_features: Tuple[StateFeature, ...]
        action_images: Tuple[List[ActionImage], ...]
        targets: Tuple[List[TrainingTarget], ...]
        state_strings: Tuple[List[str], ...]
        root_state_features, action_images, targets, state_strings = zip(*game_data)

        # Transpose to position-major; the first position belongs to the root.
        root_targets: Tuple[TrainingTarget, ...]
        recurrent_targets: List[Tuple[TrainingTarget, ...]]
        root_targets, *recurrent_targets = zip(*targets)

        root_ss: Tuple[str, ...]
        recurr_ss: List[Tuple[str, ...]]
        root_ss, *recurr_ss = zip(*state_strings)

        # Separate name: the transposed value has a different type from the
        # `action_images` tuple unpacked above.
        action_images_by_step: List[Tuple[ActionImage, ...]] = list(zip(*action_images))

        return (
            root_state_features,
            root_targets,
            root_ss,
            action_images_by_step,
            recurrent_targets,
            recurr_ss,
        )

    def get_num_actions(self) -> int:
        """Returns the summed length of all game histories in the buffer."""
        return self.num_actions

