#! -*- coding: utf-8
from logging import getLogger

import numpy as np
import torch

# Public API of this module (names exported by ``from <module> import *``).
__all__ = ["InnerLoopSampler"]


class InnerLoopSampler(torch.utils.data.sampler.Sampler[int]):
    """Sampler that yields a fixed number of dataset indices per epoch.

    Each call to ``__iter__`` yields ``batch_size * inner_loop`` indices, or
    ``len(dataset)`` indices when ``inner_loop`` is None. Indices come from
    consecutive passes over ``range(len(dataset))``, each pass optionally
    shuffled. Indices left over at the end of one epoch are yielded first in
    the next epoch, so successive epochs continue through the dataset rather
    than restarting it.

    Args:
        dataset: Source dataset; only ``len(dataset)`` is used here.
        batch_size: Multiplied by ``inner_loop`` to get the epoch length.
        inner_loop: Number of batches per epoch; ``None`` means one epoch is
            exactly one pass over the dataset.
        shuffle: Whether each pass over the dataset is randomly permuted.
        seed: Seed for this sampler's private ``numpy.random.RandomState``.

    Raises:
        ValueError: If ``batch_size * inner_loop`` is negative.
    """

    def __init__(self, dataset: torch.utils.data.Dataset, batch_size: int = 32, inner_loop: int = None,
                 shuffle: bool = False, seed: int = None):
        self.logger = getLogger(__name__)

        self.dataset = dataset
        self.batch_size = batch_size
        self.inner_loop = inner_loop
        self.shuffle = shuffle
        # Private random state: shuffling is reproducible for a given seed and
        # does not consume the global NumPy RNG.
        self.rs = np.random.RandomState(seed)

        # NOTE(review): not read anywhere in this class; kept in case external
        # code relies on the attribute.
        self.noffset = 0
        # Pending indices, carried over from one epoch to the next.
        self.indices = []

        # Number of indices yielded per epoch (also reported by __len__).
        self.ndata = self.batch_size * self.inner_loop \
            if self.inner_loop is not None else len(self.dataset)
        if self.ndata < 0:
            raise ValueError(
                f"{self.__class__.__name__}: batch_size * inner_loop must be non-negative, "
                f"got batch_size={self.batch_size}, inner_loop={self.inner_loop}")

    def __make_indices__(self):
        """Refill ``self.indices`` with whole dataset passes until it holds at least ``self.ndata`` entries.

        Raises:
            ValueError: If more indices are needed but the dataset is empty;
                without this check the refill loop would never terminate.
        """
        n = len(self.dataset)
        if n == 0 and len(self.indices) < self.ndata:
            raise ValueError(
                f"{self.__class__.__name__}: cannot draw {self.ndata} indices from an empty dataset")
        base = np.arange(n)
        while len(self.indices) < self.ndata:
            # permutation() returns a shuffled copy, so ``base`` stays ordered.
            idxs = self.rs.permutation(base) if self.shuffle else base
            # tolist() gives plain Python ints, matching the Sampler[int] contract.
            self.indices.extend(idxs.tolist())

    def __len__(self):
        """Return the number of indices yielded per epoch."""
        return self.ndata

    def __iter__(self):
        """Yield the next ``self.ndata`` indices, keeping any remainder for the next epoch."""
        self.__make_indices__()
        indices = self.indices[:self.ndata]
        self.indices = self.indices[self.ndata:]

        yield from indices

    def __str__(self):
        return f"{self.__class__.__name__} " + ", ".join([f"dataset={len(self.dataset)}",
                                                          f"batch_size={self.batch_size}",
                                                          f"inner_loop={self.inner_loop}",
                                                          f"shuffle={self.shuffle}",])
