from typing import List, cast, Optional, Any, Sequence, Union

import numpy as np

from ..dataset import Transition
from .base_drop import DropTransitionIterator
from .base_drop_init import InitDropTransitionIterator

def getRandomTransition(
    temp_transition: Transition,
    observation: np.ndarray,
    action: Union[int, np.ndarray],
    next_observation: np.ndarray,
    reward: float = 0.0,
    terminal: float = 1.0,
):
    """Build a ``Transition`` shaped like ``temp_transition`` from the given values.

    Only the observation shape and action size are copied from the template;
    every other field comes from the arguments. Despite the name, nothing is
    sampled here: callers pass in already-randomized observations/actions.

    Args:
        temp_transition: template whose ``get_observation_shape()`` and
            ``get_action_size()`` define the new transition's metadata.
        observation: observation stored in the new transition.
        action: action stored in the new transition.
        next_observation: next observation stored in the new transition.
        reward: reward stored in the new transition (default ``0.0``).
        terminal: terminal flag stored in the new transition (default ``1.0``).

    Returns:
        A freshly constructed ``Transition``.
    """
    obs_shape = temp_transition.get_observation_shape()
    act_size = temp_transition.get_action_size()
    # Positional order mirrors the ``Transition`` constructor signature.
    ctor_args = (
        obs_shape,
        act_size,
        observation,
        action,
        reward,
        next_observation,
        terminal,
    )
    return Transition(*ctor_args)
                    


class DropRandomIterator(DropTransitionIterator):
    """Iterator that samples training tuples uniformly at random.

    Each ``_next(n, n_drop)`` call returns the 5-element list
    ``[R, is_init, transition, transition_drop, transition_init]`` where:

    * ``transition`` is drawn uniformly from ``index_transitions[n]``; with
      probability ``init_sample_prob`` it is instead drawn from
      ``index_transitions_init[n]``;
    * ``transition_drop`` is a synthetic transition built by
      ``getRandomTransition`` from ``observation_space.sample()`` /
      ``action_space.sample()``;
    * ``transition_init`` is drawn uniformly from ``index_transitions_init[n]``;
    * ``is_init`` is ``True`` when ``transition`` is one of the indices in
      ``index_transitions_init[n]``; in that case ``R`` is the matching entry of
      ``discounted_rewards[n]``, otherwise ``R`` is ``0.0``.
    """

    _n_steps_per_epoch: int
    # Probability of replacing the uniformly drawn transition with one drawn
    # from the initial transitions of the same group.
    _init_sample_prob: float = 0.01

    def __init__(
        self,
        transitions: List[Transition],
        index_transitions: List[List[int]],
        index_transitions_drop: Optional[List[List[int]]],
        index_transitions_init: List[List[int]],
        discounted_rewards: List[List[float]],
        action_space: Any,
        observation_space: Any,
        n_steps_per_epoch: int,
        batch_size: int,
        n_steps: int = 1,
        gamma: float = 0.99,
        n_frames: int = 1,
        real_ratio: float = 1.0,
        generated_maxlen: int = 100000,
        init_sample_prob: float = 0.01,
    ):
        """Store sampling configuration and forward the rest to the base class.

        Args:
            transitions: pool of transitions; all index lists point into it.
            index_transitions: ``index_transitions[n]`` lists the indices that
                ``_next(n, ...)`` samples from.
            index_transitions_drop: forwarded to the base class; not read by
                this class.
            index_transitions_init: ``index_transitions_init[n]`` lists the
                indices of the initial transitions of group ``n``.
            discounted_rewards: ``discounted_rewards[n][k]`` is the value
                returned as ``R`` for ``index_transitions_init[n][k]``.
            action_space: object exposing ``sample()`` for synthetic actions.
            observation_space: object exposing ``sample()`` for synthetic
                observations.
            n_steps_per_epoch: value compared against ``_count`` to end an
                epoch, and the value reported by ``len()``.
            batch_size, n_steps, gamma, n_frames, real_ratio, generated_maxlen:
                forwarded unchanged to the base class.
            init_sample_prob: probability in ``[0, 1]`` of drawing the main
                transition from the initial transitions (default ``0.01``).

        Raises:
            ValueError: if ``init_sample_prob`` is outside ``[0, 1]``.
        """
        if not 0.0 <= init_sample_prob <= 1.0:
            raise ValueError(
                f"init_sample_prob must be in [0, 1], got {init_sample_prob}"
            )
        super().__init__(
            transitions=transitions,
            index_transitions=index_transitions,
            index_transitions_drop=index_transitions_drop,
            index_transitions_init=index_transitions_init,
            discounted_rewards=discounted_rewards,
            action_space=action_space,
            observation_space=observation_space,
            batch_size=batch_size,
            n_steps=n_steps,
            gamma=gamma,
            n_frames=n_frames,
            real_ratio=real_ratio,
            generated_maxlen=generated_maxlen,
        )
        self._n_steps_per_epoch = n_steps_per_epoch
        self._init_sample_prob = init_sample_prob

    def _reset(self) -> None:
        """Nothing to reset: sampling keeps no per-epoch state."""

    def _next(self, n: int, n_drop: int) -> List[Any]:
        """Sample one ``[R, is_init, transition, transition_drop, transition_init]``.

        ``n_drop`` is not used: drop transitions are synthesized from the
        observation/action spaces instead of being looked up in
        ``index_transitions_drop``. The parameter is kept so the signature stays
        unchanged for whoever calls ``_next``.
        """
        indices = self._index_transitions[n]
        init_indices = self._index_transitions_init[n]

        index = indices[cast(int, np.random.randint(len(indices)))]
        # With a small probability, replace the sample with an initial transition.
        if np.random.rand() < self._init_sample_prob:
            index = np.random.choice(init_indices)
        transition = self._transitions[index]

        transition_drop = getRandomTransition(
            transition,
            self._observation_space.sample(),
            self._action_space.sample(),
            self._observation_space.sample(),
        )

        init_pos = cast(int, np.random.randint(len(init_indices)))
        transition_init = self._transitions[init_indices[init_pos]]

        # One ``list.index`` scan both tests membership and finds the position
        # (previously an ``in`` scan followed by a second ``index`` scan).
        try:
            init_rank = init_indices.index(index)
        except ValueError:
            is_init = False
            R = 0.0
        else:
            is_init = True
            R = self._discounted_rewards[n][init_rank]
        return [R, is_init, transition, transition_drop, transition_init]

    def _has_finished(self) -> bool:
        """Return ``True`` once ``_count`` reaches ``n_steps_per_epoch``."""
        return self._count >= self._n_steps_per_epoch

    def __len__(self) -> int:
        """Return the configured number of steps per epoch."""
        return self._n_steps_per_epoch


import math

class InitDropRandomIterator(InitDropTransitionIterator):
    """Iterator that draws initial transitions uniformly at random.

    Every ``_next()`` call returns a one-element list containing a transition
    picked uniformly, with replacement, from ``index_transitions_init``.
    """

    _n_steps_per_epoch: int

    def __init__(
        self,
        drop_num: int,
        transitions: List[Transition],
        index_transitions_init: List[int],
        n_steps_per_epoch: int,
        batch_size: int,
        n_steps: int = 1,
        gamma: float = 0.99,
        n_frames: int = 1,
    ):
        """Forward configuration to the base class and keep ``n_steps_per_epoch``.

        Args:
            drop_num: threshold that ``_count`` must reach before iteration
                stops; also the numerator of ``len()``.
            transitions: pool of transitions indexed by
                ``index_transitions_init``.
            index_transitions_init: flat list of indices into ``transitions``
                to sample from.
            n_steps_per_epoch: stored on the instance but not read by this
                class.
            batch_size, n_steps, gamma, n_frames: forwarded to the base class.
        """
        forwarded = dict(
            drop_num=drop_num,
            transitions=transitions,
            index_transitions_init=index_transitions_init,
            batch_size=batch_size,
            n_steps=n_steps,
            gamma=gamma,
            n_frames=n_frames,
        )
        super().__init__(**forwarded)
        self._n_steps_per_epoch = n_steps_per_epoch

    def _reset(self) -> None:
        """No per-epoch state to reset."""
        return None

    def _next(self) -> List[Transition]:
        """Return ``[t]`` where ``t`` is a uniformly chosen initial transition."""
        pool = self._index_transitions_init
        slot = cast(int, np.random.randint(len(pool)))
        return [self._transitions[pool[slot]]]

    def _has_finished(self) -> bool:
        """Return ``True`` once ``_count`` has reached ``drop_num``.

        NOTE(review): ``__len__`` reports ``ceil(drop_num / batch_size)``; the two
        agree only if the base class increments ``_count`` per sampled
        transition rather than per batch — confirm against the base class.
        """
        return self._drop_num <= self._count

    def __len__(self) -> int:
        """Return ``ceil(drop_num / batch_size)``."""
        n_batches = math.ceil(self._drop_num / self._batch_size)
        return n_batches