from tensorflow.keras.utils import Sequence
import numpy as np

# Small numerical epsilon; not referenced by the class below, presumably
# imported by other modules -- confirm before removing.
EPS: float = 1.0e-8


class OSBatchGenerator(Sequence):
    """Keras ``Sequence`` that yields ``(inputs, targets)`` batches from ``X`` and ``R``.

    ``inputs`` is ``[batch_t, batch_x]``, or ``[batch_t, batch_r, batch_x]``
    when ``config['include_R']`` is truthy (``batch_r`` is ``batch_y`` with a
    trailing axis added). ``targets`` is ``[batch_y]``. ``batch_t`` holds the
    positions ``0 .. L-1`` scaled by ``1 / L``, shaped ``(batch, L, 1)``.

    Two modes:

    * ``randomize=True``: every ``__getitem__`` call ignores ``idx`` and draws
      ``batch_size`` row indices uniformly with replacement. ``__len__``
      reports ``config['samples_per_epoch']``; Keras interprets ``len()`` as a
      number of batches, so that is the number of batches per epoch.
    * ``randomize=False``: batches walk the rows in order in chunks of
      ``batch_size``; the final batch may be smaller.

    Args:
        X: Array indexed by sample along axis 0; ``X.shape[1]`` is used as the
            sequence length ``L``.
        R: Array indexed by sample along axis 0, used as the target.
            NOTE(review): the ``include_R`` path applies
            ``np.expand_dims(batch_y, axis=2)``, which presumably expects
            ``R`` to be 2-D ``(N, L)`` -- confirm against callers.
        config: Mapping providing ``'samples_per_epoch'`` and ``'include_R'``.
        batch_size: Number of rows per batch.
        randomize: Random (sampling) mode if True, sequential mode otherwise.
    """

    def __init__(self, X, R, config, batch_size, randomize=True):
        # Keras 3's PyDataset (aliased as Sequence) sets up worker options in
        # its constructor and warns when it is skipped; tf.keras 2.x Sequence
        # defines no constructor, so this is a no-op there.
        super().__init__()
        self.X = X
        self.R = R
        self.N = X.shape[0]
        self.L = X.shape[1]
        self.B = batch_size
        self.num_samples_per_epoch = config['samples_per_epoch']
        self.randomize = randomize
        self.seq_idxs = np.arange(self.N)
        # Sequential batch count, rounded up to include a trailing partial
        # batch. Kept under a private name because Keras 3's PyDataset exposes
        # a read-only ``num_batches`` property, so assigning to that name fails.
        self._num_batches = int(np.ceil(self.N / self.B))
        self.include_R = config['include_R']
        # NOTE(review): this reseeds NumPy's *global* RNG on every
        # construction, which also affects any other code drawing from it.
        # Kept so existing runs stay reproducible; consider a per-instance
        # np.random.Generator instead.
        np.random.seed(42)

    def __len__(self):
        """Return the number of batches per epoch."""
        if self.randomize:
            # Cast so a float config value does not break len().
            return int(self.num_samples_per_epoch)
        return self._num_batches

    def _time_grid(self, batch_len):
        """Return positions ``0 .. L-1`` scaled by ``1/L``, shaped ``(batch_len, L, 1)``."""
        positions = np.expand_dims(np.arange(self.L, dtype=float), axis=(0, 2))
        return np.tile(positions, (batch_len, 1, 1)) / self.L

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` for batch ``idx``.

        In random mode ``idx`` is ignored. In sequential mode, raises
        ``IndexError`` when ``idx`` is outside ``[0, len(self))``.
        """
        if self.randomize:
            seq_ids_in_batch = np.random.choice(self.N, (self.B,), replace=True)
        else:
            if not 0 <= idx < self._num_batches:
                raise IndexError(
                    f"batch index {idx} out of range for {self._num_batches} batches"
                )
            start = idx * self.B
            # The last batch is truncated at N rather than padded.
            seq_ids_in_batch = np.arange(start, min(start + self.B, self.N))

        batch_x = self.X[seq_ids_in_batch]
        batch_y = self.R[seq_ids_in_batch]
        batch_t = self._time_grid(len(seq_ids_in_batch))

        if self.include_R:
            return [batch_t, np.expand_dims(batch_y, axis=2), batch_x], [batch_y]
        return [batch_t, batch_x], [batch_y]
