import numpy as np
from torch.utils.data import Dataset


def extract_digits():
    """Return the selected indices ``k`` in ``range(100)`` as a 1-D int array.

    Each ``k`` is split into a tens digit and a units digit. ``k`` is kept
    when any of these holds:

    * tens < 5 and units < 5  (the 5x5 lower-left corner),
    * tens is 5, 6, 7 or 8    (four complete rows, any units digit),
    * tens >= 9 and units >= 5 (the right half of the last row).

    The three regions do not overlap, so every index appears at most once.
    The result is in ascending order.

    NOTE(review): the module stores the result as ``training_idx``, which
    suggests this selects a training subset of a 10x10 grid. Confirm the
    intended split, particularly the ``>= 9`` threshold.
    """
    candidates = np.arange(100)
    tens = candidates // 10 % 10
    units = candidates % 10

    lower_corner = (tens < 5) & (units < 5)
    last_row_right = (tens >= 9) & (units >= 5)
    middle_rows = (tens >= 5) & (tens <= 8)

    # Boolean-mask indexing preserves the ascending order of `candidates`.
    return candidates[lower_corner | last_row_right | middle_rows]


# Computed once at import time from the selection rule above.
training_idx = extract_digits()


class MyDataset_X(Dataset):
    """Map-style dataset that serves single items taken from ``X``.

    ``X`` may be any object that supports ``len()`` and ``X[idx]``
    indexing. It is stored by reference, with no copy or conversion.
    """

    def __init__(self, X):
        self.X = X

    def __len__(self):
        # The length is delegated directly to the wrapped container.
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        return sample

class MyDataset_X2(Dataset):
    """Map-style dataset over the leading axis of an array-like ``X``.

    The author's shape note says ``X`` is ``(num_examples, num_samples, dim)``.
    Only ``X.shape[0]`` is read here, so the length comes from the leading
    dimension and ``X[idx]`` is returned unchanged.
    """

    def __init__(self, X):
        # X:(num_examples, num_samples, dim)
        self.X = X
        # Read once at construction and served by __len__.
        self.num_examples = self.X.shape[0]

    def __len__(self):
        count = self.num_examples
        return count

    def __getitem__(self, idx):
        example = self.X[idx]
        return example
    
class MyDataset_XU(Dataset):
    """Map-style dataset that yields paired items ``(X[idx], U[idx])``.

    ``X`` and ``U`` are indexed in lockstep, so they must have the same
    length. A mismatch is rejected at construction. Otherwise it would
    surface mid-iteration as an ``IndexError`` (``X`` shorter), or the
    trailing rows of ``X`` would silently never be served (``X`` longer).

    Raises:
        ValueError: if ``len(X) != len(U)``.
    """

    def __init__(self, X, U):
        if len(X) != len(U):
            raise ValueError(
                f"X and U must have the same length, got {len(X)} and {len(U)}"
            )
        self.X = X
        self.U = U

    def __len__(self):
        # The lengths were validated as equal, so len(U) is also len(X).
        return len(self.U)

    def __getitem__(self, idx):
        return self.X[idx], self.U[idx]
    
class MyDataset_XU2(Dataset):
    """Map-style dataset that yields ``(X[idx], U[idx])`` over the leading axis.

    The author's shape note says ``X`` is ``(num_examples, num_samples, X_dim)``
    and ``U`` is ``(num_examples, U_dim)``. Both inputs must therefore share
    the same number of examples, and this is checked at construction.
    Without the check, a shorter ``U`` would raise ``IndexError`` partway
    through an epoch.

    Raises:
        ValueError: if ``X.shape[0] != len(U)``.
    """

    def __init__(self, X, U):
        # X:(num_examples, num_samples, X_dim), U:(num_examples, U_dim)
        num_examples = X.shape[0]
        # len(U) rather than U.shape[0], so U only needs len() and indexing,
        # as before.
        if num_examples != len(U):
            raise ValueError(
                f"X and U must have the same number of examples, "
                f"got {num_examples} and {len(U)}"
            )
        self.X = X
        self.U = U
        self.num_examples = num_examples

    def __len__(self):
        return self.num_examples

    def __getitem__(self, idx):
        return self.X[idx], self.U[idx]
    