from abc import ABCMeta, abstractmethod
from typing import Iterator, List, cast, Any, Optional

import numpy as np

from ..containers import FIFOQueue
from ..dataset import Transition, TransitionMiniBatch

class DropTransitionMiniBatch:
    """Read-only view over a ``TransitionMiniBatch`` plus two extra arrays.

    Every batch attribute (observations, actions, rewards, ...) is forwarded
    unchanged to the wrapped batch.  ``Rs`` and ``Inits`` are carried
    alongside it and returned exactly as they were passed in.

    ``len()`` and iteration are deliberately unsupported and raise
    ``AssertionError``.

    Args:
        batch: the mini-batch whose attributes are forwarded.
        Rs: optional array stored next to the batch
            (presumably per-sample returns -- confirm with the caller).
        Inits: optional array stored next to the batch
            (presumably per-sample initial-state data -- confirm with the
            caller).
    """

    def __init__(
        self,
        batch: "TransitionMiniBatch",
        Rs: Optional[np.ndarray] = None,
        Inits: Optional[np.ndarray] = None,
    ):
        self._batch = batch
        self._Rs = Rs
        self._Inits = Inits

    @property
    def observations(self) -> np.ndarray:
        """Observations of the wrapped batch."""
        return self._batch.observations

    @property
    def actions(self) -> np.ndarray:
        """Actions of the wrapped batch."""
        return self._batch.actions

    @property
    def rewards(self) -> np.ndarray:
        """Rewards of the wrapped batch."""
        return self._batch.rewards

    @property
    def next_observations(self) -> np.ndarray:
        """Next observations of the wrapped batch."""
        return self._batch.next_observations

    @property
    def transitions(self) -> "List[Transition]":
        """Transitions held by the wrapped batch."""
        return self._batch.transitions

    @property
    def terminals(self) -> np.ndarray:
        """Terminal flags of the wrapped batch."""
        return self._batch.terminals

    @property
    def n_steps(self) -> np.ndarray:
        """``n_steps`` attribute of the wrapped batch."""
        return self._batch.n_steps

    def __len__(self) -> int:
        """Unsupported on this wrapper.

        Raises:
            AssertionError: always.
        """
        # Raised explicitly instead of ``assert False`` so the guard is not
        # stripped when Python runs with -O.
        raise AssertionError("DropTransitionMiniBatch does not support len()")

    def __iter__(self) -> "Iterator[Transition]":
        """Unsupported on this wrapper.

        Raises:
            AssertionError: always.
        """
        # Raised explicitly instead of ``assert False`` so the guard is not
        # stripped when Python runs with -O.
        raise AssertionError(
            "DropTransitionMiniBatch does not support iteration"
        )

    @property
    def Rs(self) -> Optional[np.ndarray]:
        """The ``Rs`` array given at construction, or ``None``."""
        return self._Rs

    @property
    def Inits(self) -> Optional[np.ndarray]:
        """The ``Inits`` array given at construction, or ``None``."""
        return self._Inits

class DropTransitionIterator(metaclass=ABCMeta):
    """Abstract iterator producing triples of transition mini-batches.

    Each call to ``__next__`` draws ``batch_size`` samples.  For every sample
    two group indices are picked uniformly (with replacement) among the
    non-empty entries of ``index_transitions`` and handed to the
    subclass-provided ``_next``.  The collected results are returned as
    ``[ns, ns_drop, batch, batch_drop, batch_init]``.

    Subclasses implement ``_reset``, ``_next``, ``_has_finished`` and
    ``__len__``.

    Mixing in generated transitions (see ``add_generated_transitions``) is
    not implemented yet: both that method and ``__next__`` raise
    ``AssertionError`` once generated transitions are involved.
    """

    _transitions: "List[Transition]"
    _index_transitions: List[List[int]]
    _index_transitions_drop: Optional[List[List[int]]]
    _index_transitions_init: List[int]
    _discounted_rewards: List[List[float]]
    _generated_transitions: "FIFOQueue[Transition]"
    _action_space: Any
    _observation_space: Any
    _batch_size: int
    _n_steps: int
    _gamma: float
    _n_frames: int
    _real_ratio: float
    _real_batch_size: int
    _count: int
    _drop_num: int
    _drop_count: List[int]
    _drop_num_count: np.ndarray
    _drop_num_rewards: np.ndarray

    def __init__(
        self,
        transitions: "List[Transition]",
        index_transitions: List[List[int]],
        index_transitions_drop: Optional[List[List[int]]],
        index_transitions_init: List[int],
        discounted_rewards: List[List[float]],
        action_space: Any,
        observation_space: Any,
        batch_size: int,
        n_steps: int = 1,
        gamma: float = 0.99,
        n_frames: int = 1,
        real_ratio: float = 1.0,
        generated_maxlen: int = 100000,
    ):
        """Store the inputs and precompute per-group statistics.

        Args:
            transitions: all real transitions; exposed via ``transitions``.
            index_transitions: one list of indices per group.  Its length
                defines the number of groups, and groups with an empty list
                are never sampled.
            index_transitions_drop: optional index lists, stored as given
                (not read by this base class).
            index_transitions_init: index list, stored as given (not read by
                this base class).
            discounted_rewards: one sequence of rewards per group; only the
                non-empty ones contribute to the reward weights.
            action_space: stored as given.
            observation_space: stored as given.
            batch_size: number of samples drawn per ``__next__`` call.
            n_steps: forwarded to ``TransitionMiniBatch``.
            gamma: forwarded to ``TransitionMiniBatch``.
            n_frames: forwarded to ``TransitionMiniBatch``.
            real_ratio: fraction of the batch taken from real transitions
                once generated transitions exist (applied in ``reset``).
            generated_maxlen: capacity passed to the ``FIFOQueue`` holding
                generated transitions.
        """
        self._transitions = transitions
        self._index_transitions = index_transitions
        self._index_transitions_drop = index_transitions_drop
        self._index_transitions_init = index_transitions_init
        self._discounted_rewards = discounted_rewards
        self._action_space = action_space
        self._observation_space = observation_space
        self._generated_transitions = FIFOQueue(generated_maxlen)
        self._batch_size = batch_size
        self._n_steps = n_steps
        self._gamma = gamma
        self._n_frames = n_frames
        self._real_ratio = real_ratio
        self._real_batch_size = batch_size
        self._count = 0

        # Number of groups, size of each group, and the indices of the groups
        # that actually contain something (the only ones ever sampled).
        self._drop_num = len(self._index_transitions)
        self._drop_count = [len(x) for x in self._index_transitions]
        self._drop_num_count = np.arange(self._drop_num)[
            np.array(self._drop_count) > 0
        ]

        # Softmax over the normalised mean reward of each non-empty reward
        # sequence.  NOTE(review): these weights are currently not used by
        # get_next (sampling is uniform).  They are filtered on
        # discounted_rewards while _drop_num_count is filtered on
        # index_transitions, so the two only line up if both are empty for
        # the same groups -- confirm before passing them as ``p=``.
        mean_rewards = [
            np.mean(rs) for rs in self._discounted_rewards if len(rs) > 0
        ]
        weights = np.array(mean_rewards) / sum(mean_rewards)
        weights = np.exp(weights)
        self._drop_num_rewards = weights / weights.sum()

    def __iter__(self) -> Iterator[List[Any]]:
        """Reset the iterator state and return ``self``."""
        self.reset()
        return self

    def __next__(self) -> List[Any]:
        """Draw one batch.

        Returns:
            ``[ns, ns_drop, batch, batch_drop, batch_init]`` where ``ns`` and
            ``ns_drop`` are the first and second sampled group index of every
            sample, and the three batches are ``DropTransitionMiniBatch``
            objects sharing the same ``Rs`` / ``Inits`` arrays (each reshaped
            to a single column).

        Raises:
            StopIteration: propagated from ``get_next`` once the subclass
                reports that iteration has finished.
            AssertionError: if generated transitions are present; mixing
                them with real data is not implemented.
        """
        if len(self._generated_transitions) > 0:
            # Statement order is kept from the original implementation so
            # that an exhausted iterator still surfaces StopIteration before
            # the unsupported-mode error.
            real_batch_size = self._real_batch_size
            fake_batch_size = self._batch_size - self._real_batch_size
            transitions = [self.get_next() for _ in range(real_batch_size)]
            transitions += self._sample_generated_transitions(fake_batch_size)
            # TODO: offline + online.  Raised explicitly (not ``assert``) so
            # the guard also holds under ``python -O``.
            raise AssertionError(
                "mixing generated transitions with real transitions is not "
                "supported yet"
            )

        nss = []
        Rs = []
        Inits = []
        transitions = []
        transitions_drop = []
        transitions_init = []
        for _ in range(self._batch_size):
            n, R, Init, temp, temp_drop, temp_init = self.get_next()
            nss.append(n)
            Rs.append(R)
            Inits.append(Init)
            transitions.append(temp)
            transitions_drop.append(temp_drop)
            transitions_init.append(temp_init)

        # Each row of nss is the pair of group indices drawn for one sample.
        nss = np.array(nss)
        ns = nss[:, 0]
        ns_drop = nss[:, 1]

        Rs = np.array(Rs).reshape(-1, 1)
        Inits = np.array(Inits).reshape(-1, 1)

        # Built in the order batch, batch_drop, batch_init.
        batch, batch_drop, batch_init = [
            DropTransitionMiniBatch(
                TransitionMiniBatch(
                    group,
                    n_frames=self._n_frames,
                    n_steps=self._n_steps,
                    gamma=self._gamma,
                ),
                Rs,
                Inits,
            )
            for group in (transitions, transitions_drop, transitions_init)
        ]

        self._count += 1

        return [ns, ns_drop, batch, batch_drop, batch_init]

    def reset(self) -> None:
        """Reset the batch counter and delegate to the subclass ``_reset``.

        When generated transitions exist, the real share of a batch is
        recomputed as ``int(real_ratio * batch_size)``.
        """
        self._count = 0
        if len(self._generated_transitions) > 0:
            self._real_batch_size = int(self._real_ratio * self._batch_size)
        self._reset()

    @abstractmethod
    def _reset(self) -> None:
        """Reset subclass-specific iteration state."""

    @abstractmethod
    def _next(self, n: int, n_drop: int) -> List[Any]:
        """Produce one sample for the two sampled group indices.

        ``get_next`` calls this with two positional indices and prepends the
        index pair to the result; ``__next__`` then unpacks the remainder as
        ``R, Init, transition, transition_drop, transition_init``.  An
        implementation must therefore return a list of five items.
        """

    @abstractmethod
    def _has_finished(self) -> bool:
        """Return ``True`` when no further samples should be drawn."""

    def add_generated_transitions(
        self, transitions: "List[Transition]"
    ) -> None:
        """Append generated transitions to the internal queue.

        The transitions are stored first and the error is raised afterwards,
        as in the original implementation.

        Raises:
            AssertionError: always; the offline + online mode is not
                implemented.
        """
        self._generated_transitions.extend(transitions)
        # TODO: offline + online.  Raised explicitly (not ``assert``) so the
        # guard also holds under ``python -O``.
        raise AssertionError(
            "mixing generated transitions with real transitions is not "
            "supported yet"
        )

    def get_next(self) -> List[Any]:
        """Sample two group indices and fetch the matching sample.

        Returns:
            A list whose first item is a copy of the sampled index pair,
            followed by the items returned by ``_next``.

        Raises:
            StopIteration: if ``_has_finished`` reports completion.
        """
        if self._has_finished():
            raise StopIteration
        # Uniform draw with replacement among the non-empty groups; the
        # reward weights in _drop_num_rewards are intentionally not applied.
        n = np.random.choice(self._drop_num_count, 2)
        return [n.copy()] + self._next(n[0], n[1])

    def _sample_generated_transitions(
        self, batch_size: int
    ) -> "List[Transition]":
        """Draw ``batch_size`` generated transitions uniformly with replacement."""
        transitions: "List[Transition]" = []
        n_generated_transitions = len(self._generated_transitions)
        for _ in range(batch_size):
            index = cast(int, np.random.randint(n_generated_transitions))
            transitions.append(self._generated_transitions[index])
        return transitions

    @abstractmethod
    def __len__(self) -> int:
        """Return the iterator length as defined by the subclass."""

    def size(self) -> int:
        """Return the number of real plus generated transitions held."""
        return len(self._transitions) + len(self._generated_transitions)

    @property
    def transitions(self) -> "List[Transition]":
        """The real transitions passed at construction."""
        return self._transitions

    @property
    def generated_transitions(self) -> "FIFOQueue[Transition]":
        """The queue of generated transitions."""
        return self._generated_transitions
