from typing import Generator, List, Optional

import numpy as np
import torch

from .dataset import EpisodeDataset
from .segment import SegmentId


class BatchSampler(torch.utils.data.Sampler):
    """Endless sampler yielding batches of randomly drawn episode segments.

    Each batch is a list of ``batch_size`` ``SegmentId`` objects, every one
    built from ``(episode_id, start, stop)`` with ``stop - start`` equal to
    ``sequence_length``.

    Args:
        dataset: Episode dataset to draw from. ``sample`` reads its
            ``num_episodes`` and ``lengths`` attributes; ``lengths`` is
            indexed with an array of episode ids (presumably a NumPy integer
            array -- confirm against ``EpisodeDataset``).
        batch_size: Number of segments per batch.
        sequence_length: Length (``stop - start``) of every segment.
        can_sample_beyond_end: If true, a segment may extend both before the
            start and past the end of its episode; otherwise it may only
            extend before the start.
        sample_weights: Optional probabilities, one per contiguous group of
            episodes (see ``_episode_probabilities``). Ignored when ``None``,
            empty, or longer than the current number of episodes.
    """

    def __init__(self, dataset: EpisodeDataset, batch_size: int, sequence_length: int, can_sample_beyond_end: bool, sample_weights: Optional[List[float]] = None) -> None:
        super().__init__(dataset)
        assert isinstance(dataset, EpisodeDataset)
        self.dataset = dataset
        self.sample_weights = sample_weights
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.can_sample_beyond_end = can_sample_beyond_end

    def __len__(self):
        # The sampler never terminates (see __iter__), so it has no length.
        raise NotImplementedError

    def __iter__(self) -> Generator[List[SegmentId], None, None]:
        """Yield freshly sampled batches forever."""
        while True:
            yield self.sample()

    def _episode_probabilities(self, num_episodes: int) -> Optional[np.ndarray]:
        """Expand the group-level ``sample_weights`` into per-episode probabilities.

        Episodes are split, in index order, into ``len(sample_weights)``
        contiguous groups of ``num_episodes // len(sample_weights)`` episodes;
        the last group also receives the remainder. The weight of a group is
        shared equally between its episodes.

        Returns:
            An array of ``num_episodes`` probabilities, or ``None`` (uniform
            sampling) when there are no weights or more weights than episodes.

        Raises:
            AssertionError: If a weight lies outside ``[0, 1]`` or the weights
                do not sum to 1 (within floating-point tolerance).
        """
        weights = self.sample_weights
        num_weights = len(weights) if weights is not None else 0
        if num_weights == 0 or num_weights > num_episodes:
            return None

        weights = np.asarray(weights, dtype=np.float64)
        in_range = bool(np.all((weights >= 0) & (weights <= 1)))
        # Compare with a tolerance: an exact `== 1` rejects valid inputs such
        # as [0.7, 0.2, 0.1], whose float sum is 0.9999999999999999. The
        # tolerance is kept below the one np.random.choice applies to `p`.
        sums_to_one = bool(np.isclose(weights.sum(), 1.0, rtol=0.0, atol=1e-8))
        assert in_range and sums_to_one, "sample_weights must lie in [0, 1] and sum to 1"

        group_sizes = np.full(num_weights, num_episodes // num_weights, dtype=np.int64)
        group_sizes[-1] += num_episodes % num_weights
        return np.repeat(weights / group_sizes, group_sizes)

    def sample(self) -> List[SegmentId]:
        """Draw one batch of ``batch_size`` segment identifiers.

        Episodes are drawn with replacement (uniformly, or according to
        ``sample_weights``), then a timestep is drawn uniformly inside each
        chosen episode and a window of ``sequence_length`` steps is placed
        around it.

        Returns:
            A list of ``SegmentId(episode_id, start, stop)``. ``start`` can be
            negative, and ``stop`` can exceed the episode length when
            ``can_sample_beyond_end`` is true.
        """
        num_episodes = self.dataset.num_episodes
        probabilities = self._episode_probabilities(num_episodes)

        episode_ids = np.random.choice(np.arange(num_episodes), size=self.batch_size, replace=True, p=probabilities)
        episode_lengths = self.dataset.lengths[episode_ids]
        timesteps = np.random.randint(low=0, high=episode_lengths)
        offsets = np.random.randint(0, self.sequence_length, len(timesteps))

        # padding allowed, both before start and after end
        if self.can_sample_beyond_end:
            starts = timesteps - offsets
            stops = starts + self.sequence_length

        # padding allowed only before start
        else:
            # Clamp the stop to the episode length so nothing past the end is requested.
            stops = np.minimum(episode_lengths, timesteps + 1 + offsets)
            starts = stops - self.sequence_length

        return [SegmentId(*x) for x in zip(episode_ids, starts, stops)]
