from abc import ABCMeta, abstractmethod
from typing import Iterator, List, cast, Any, Optional

import numpy as np

from ..dataset import Transition, TransitionMiniBatch


class InitDropTransitionIterator(metaclass=ABCMeta):
    """Abstract iterator that draws up to ``drop_num`` transitions in batches.

    Each ``__next__`` call returns ``[ns, batch]``:

    * ``ns`` -- a numpy array holding, for every transition drawn in this
      batch, its draw index (``_count % _drop_num`` at draw time);
    * ``batch`` -- a ``TransitionMiniBatch`` built from those transitions.

    Iteration stops once ``drop_num`` transitions have been drawn since the
    last ``reset()``, or once the subclass reports ``_has_finished()``.
    Subclasses implement ``_reset``, ``_next``, ``_has_finished`` and
    ``__len__``.
    """

    _transitions: List[Transition]
    _index_transitions_init: List[int]
    _batch_size: int
    _n_steps: int
    _gamma: float
    _n_frames: int
    _count: int  # transitions drawn since the last reset()
    _drop_num: int  # upper bound on transitions drawn per pass

    def __init__(
        self,
        drop_num: int,
        transitions: List[Transition],
        index_transitions_init: List[int],
        batch_size: int,
        n_steps: int = 1,
        gamma: float = 0.99,
        n_frames: int = 1,
    ):
        """Store the iteration settings.

        Args:
            drop_num: maximum number of transitions drawn per pass.
            transitions: transitions exposed via ``transitions`` / ``size()``.
            index_transitions_init: stored for subclasses (presumably indices
                into ``transitions`` to draw from -- TODO confirm in
                subclasses; it is not read in this base class).
            batch_size: maximum number of transitions per returned batch.
            n_steps: forwarded to ``TransitionMiniBatch``.
            gamma: forwarded to ``TransitionMiniBatch``.
            n_frames: forwarded to ``TransitionMiniBatch``.
        """
        self._drop_num = drop_num
        self._transitions = transitions
        self._index_transitions_init = index_transitions_init
        self._batch_size = batch_size
        self._n_steps = n_steps
        self._gamma = gamma
        self._n_frames = n_frames
        self._count = 0

    def __iter__(self) -> Iterator[list]:
        """Reset the iterator and return it (yields ``[ns, batch]`` lists)."""
        self.reset()
        return self

    def __next__(self) -> list:
        """Draw the next batch and return ``[ns, batch_init]``.

        At most ``min(batch_size, drop_num - count)`` transitions are drawn.
        If the subclass reports it has finished part-way through, the
        transitions already drawn are still returned as a (smaller) batch.

        Raises:
            StopIteration: when no transition could be drawn -- either
                ``drop_num`` transitions were already drawn, ``batch_size``
                is not positive, or ``_has_finished()`` was already True.
        """
        ns = []
        transitions_init = []
        n_to_draw = min(self._batch_size, self._drop_num - self._count)
        for _ in range(n_to_draw):
            try:
                n, transition = self.get_next()
            except StopIteration:
                # Keep the transitions already drawn for this batch (their
                # draws were already counted) instead of discarding them.
                break
            ns.append(n)
            transitions_init.append(transition)
            self._count += 1

        if not transitions_init:
            # Nothing left to draw: end iteration instead of building an
            # empty mini-batch on every call.
            raise StopIteration

        batch_init = TransitionMiniBatch(
            transitions_init,
            n_frames=self._n_frames,
            n_steps=self._n_steps,
            gamma=self._gamma,
        )
        return [np.array(ns), batch_init]

    def reset(self) -> None:
        """Reset the draw counter and let the subclass reset its own state."""
        self._count = 0
        self._reset()

    @abstractmethod
    def _reset(self) -> None:
        """Reset subclass-specific iteration state."""

    @abstractmethod
    def _next(self) -> List[Transition]:
        """Return a list containing exactly one transition.

        ``get_next`` prepends the draw index and its result is unpacked into
        two values, so any other length raises ``ValueError`` there.
        """

    @abstractmethod
    def _has_finished(self) -> bool:
        """Return True when the subclass has no more transitions to give."""

    def get_next(self) -> List[Any]:
        """Return ``[n, transition]`` for the next single draw.

        ``n`` is ``self._count % self._drop_num`` (the counter is advanced by
        the caller, ``__next__``).

        Raises:
            StopIteration: if ``_has_finished()`` is True.
        """
        if self._has_finished():
            raise StopIteration
        n = self._count % self._drop_num
        return [n] + self._next()

    @abstractmethod
    def __len__(self) -> int:
        """Return the subclass-defined length of the iterator."""

    def size(self) -> int:
        """Return the number of stored transitions."""
        return len(self._transitions)

    @property
    def transitions(self) -> List[Transition]:
        """The transitions passed at construction."""
        return self._transitions
