from typing import List, cast, Optional

import numpy as np

from ..dataset import Transition
from .base_drop import DropTransitionIterator


class DropRoundIterator(DropTransitionIterator):
    """Transition iterator that samples ``[transition, drop, init]`` triples.

    Each call to ``_next(n, n_drop)`` draws uniformly at random:

    * a transition whose index is taken from group ``n`` of
      ``index_transitions``;
    * a "drop" transition whose index is taken from group ``n_drop`` of
      ``index_transitions_drop``. When that argument is ``None``, the index is
      drawn from all of ``transitions`` instead.

    It also returns the transition at the first index listed in group ``n``.

    An epoch is considered finished after ``len(transitions)`` calls to
    ``_next`` since the last ``_reset``.
    """

    _shuffle: bool  # whether _reset shuffles _indices
    _indices: np.ndarray  # permutation rebuilt by _reset (not read by _next)
    _index: int  # number of _next calls since the last _reset

    def __init__(
        self,
        transitions: List[Transition],
        index_transitions: List[List[int]],
        index_transitions_drop: Optional[List[List[int]]],
        batch_size: int,
        n_steps: int = 1,
        gamma: float = 0.99,
        n_frames: int = 1,
        real_ratio: float = 1.0,
        generated_maxlen: int = 100000,
        shuffle: bool = True,
    ):
        """Forward every argument except ``shuffle`` to the base iterator.

        Args:
            transitions: pool of transitions that the index groups point into.
            index_transitions: groups of indices into ``transitions``. Presumably
                there is one group per episode or trajectory (TODO confirm).
            index_transitions_drop: optional groups of indices into
                ``transitions`` used for the "drop" sample. ``None`` means the
                drop transition is sampled from all transitions.
            batch_size, n_steps, gamma, n_frames, real_ratio, generated_maxlen:
                passed through unchanged to ``DropTransitionIterator``.
            shuffle: whether ``_reset`` shuffles ``_indices``.
        """
        super().__init__(
            transitions=transitions,
            index_transitions=index_transitions,
            index_transitions_drop=index_transitions_drop,
            batch_size=batch_size,
            n_steps=n_steps,
            gamma=gamma,
            n_frames=n_frames,
            real_ratio=real_ratio,
            generated_maxlen=generated_maxlen,
        )
        self._shuffle = shuffle
        self._indices = np.arange(len(self._transitions))
        self._index = 0

    def _reset(self) -> None:
        """Start a new epoch by rebuilding ``_indices`` and zeroing the counter.

        ``_indices`` is shuffled in place when ``shuffle`` was requested.

        NOTE(review): ``_next`` in this class samples at random and never reads
        ``_indices``. Verify that nothing else (e.g. the base class) depends on
        it before removing it.
        """
        self._indices = np.arange(len(self._transitions))
        if self._shuffle:
            np.random.shuffle(self._indices)
        self._index = 0

    def _next(self, n: int, n_drop: int) -> List[Transition]:
        """Sample one ``[transition, transition_drop, transition_init]`` triple.

        Args:
            n: position of the group in ``index_transitions`` to sample from.
            n_drop: position of the group in ``index_transitions_drop`` to draw
                the drop transition from. Ignored when that argument is
                ``None``.

        Returns:
            ``[transition, transition_drop, transition_init]``, where
            ``transition_init`` is the transition at the first index of
            group ``n``.

        Raises:
            ValueError: if the selected group (or the drop group) is empty.
        """
        group = self._index_transitions[n]
        # len() rather than truthiness so NumPy index arrays also work here.
        if len(group) == 0:
            raise ValueError(
                f"index_transitions[{n}] is empty; cannot sample a transition."
            )
        position = cast(int, np.random.randint(len(group)))
        transition = self._transitions[group[position]]

        # The drop sample is drawn after the main sample, which keeps the
        # original RNG consumption order.
        if self._index_transitions_drop is not None:
            drop_group = self._index_transitions_drop[n_drop]
            if len(drop_group) == 0:
                raise ValueError(
                    f"index_transitions_drop[{n_drop}] is empty; "
                    "cannot sample a drop transition."
                )
            drop_position = cast(int, np.random.randint(len(drop_group)))
            index_drop = drop_group[drop_position]
        else:
            index_drop = cast(int, np.random.randint(len(self._transitions)))
        transition_drop = self._transitions[index_drop]

        transition_init = self._transitions[group[0]]

        self._index += 1
        return [transition, transition_drop, transition_init]

    def _has_finished(self) -> bool:
        """Return True once ``_next`` has run ``len(transitions)`` times since the last reset."""
        return self._index >= len(self._transitions)

    def __len__(self) -> int:
        """Return ``len(transitions)`` floor-divided by the base class's ``_real_batch_size``."""
        return len(self._transitions) // self._real_batch_size
