import random

import torch
from torch.utils.data import Dataset


class RandomDataset(Dataset):
    """Map every integer in ``[0, max_value]`` to a fixed, random class label.

    The labels are drawn once, at construction time, from the global
    ``random`` module, so calling ``random.seed(...)`` beforehand makes the
    dataset reproducible.
    """

    def __init__(self, max_value: int, num_classes: int) -> None:
        """Draw one label per integer in ``[0, max_value]``.

        Args:
            max_value: Largest integer contained in the dataset (inclusive).
            num_classes: Number of distinct classes; every label lies in
                ``[0, num_classes - 1]``.
        """
        self.max_value = max_value
        self.num_classes = num_classes
        # One randint call per sample, in index order, so a seeded run yields
        # the same labels as the former append loop. The labels are no longer
        # printed: dumping the full list to stdout was debug output.
        self.random_labels = [
            random.randint(0, num_classes - 1) for _ in range(max_value + 1)
        ]

    def __len__(self) -> int:
        """Return the number of samples, ``max_value + 1``."""
        return self.max_value + 1

    def __getitem__(self, index):
        """Return the sample for ``index``.

        Returns:
            A ``(data, label)`` pair: ``data`` is a float32 tensor of shape
            ``(1,)`` holding ``index``; ``label`` is a scalar int64 tensor
            holding the label drawn for that index.

        Raises:
            IndexError: If ``index`` is negative or past the last sample.
        """
        # Without this guard, list indexing would accept a negative index and
        # pair a negative input value with the label of a different sample.
        if index < 0:
            raise IndexError(
                f"index {index} is out of range for a dataset of size {len(self)}"
            )
        # The list lookup raises IndexError past the end, which also
        # terminates plain ``for sample in dataset`` iteration.
        drawn_label = self.random_labels[index]
        data = torch.tensor([index], dtype=torch.float32)
        label = torch.tensor(drawn_label, dtype=torch.long)
        return data, label


class ModuloDataset(Dataset):
    """Pair every integer in ``[0, max_value]`` with its remainder class.

    The input is either the integer itself or, when ``binary`` is set, its
    zero-padded binary representation. The label is
    ``index % modulo_value``.
    """

    # Minimum width of the binary encoding; kept at 20 so that datasets with
    # max_value < 2**20 produce exactly the same tensors as before.
    _DEFAULT_BITS = 20

    def __init__(self, max_value: int, modulo_value: int, binary: bool = False) -> None:
        """Store the dataset parameters.

        Args:
            max_value: Largest integer contained in the dataset (inclusive).
            modulo_value: Divisor used to compute the label.
            binary: If true, encode the input as a list of bits instead of a
                single number.
        """
        self.max_value = max_value
        self.modulo_value = modulo_value
        self.binary = binary

    @staticmethod
    def int2bits(i, fill=_DEFAULT_BITS):
        """Return the bits of ``i``, most significant first, as a list of ints.

        The result is left-padded with zeros to ``fill`` entries; values that
        need more than ``fill`` bits are returned at full length, never
        truncated.

        Raises:
            ValueError: If ``i`` is negative.
        """
        if i < 0:
            # bin() would yield "-0b..." and the slice below would leave a
            # stray "b" that fails later with an unrelated-looking message.
            raise ValueError(f"int2bits expects a non-negative integer, got {i}")
        return [int(bit) for bit in bin(i)[2:].zfill(fill)]

    def _bit_width(self) -> int:
        """Return a width large enough to encode every index up to max_value."""
        return max(self._DEFAULT_BITS, int(self.max_value).bit_length())

    def __len__(self) -> int:
        """Return the number of samples, ``max_value + 1``."""
        return self.max_value + 1

    def __getitem__(self, index):
        """Return the sample for ``index``.

        Returns:
            A ``(data, label)`` pair: ``data`` is a float32 tensor holding
            either ``[index]`` or its bit encoding; ``label`` is a scalar
            int64 tensor holding ``index % modulo_value``.

        Raises:
            IndexError: If ``index`` lies outside ``[0, max_value]``.
        """
        # Samples are computed on the fly, so nothing else would ever signal
        # the end of the dataset: without this check, ``for x in dataset``
        # (which relies on IndexError to stop) would loop forever.
        if not 0 <= index <= self.max_value:
            raise IndexError(
                f"index {index} is out of range for a dataset of size {len(self)}"
            )
        if self.binary:
            # Use one width for the whole dataset so every sample has the same
            # shape, even when max_value needs more than the default 20 bits.
            number = self.int2bits(index, fill=self._bit_width())
        else:
            number = [index]
        data = torch.tensor(number, dtype=torch.float32)
        label = torch.tensor(index % self.modulo_value, dtype=torch.long)
        return data, label


class BinaryDataset(Dataset):
    """Pair every integer in ``[0, max_value]`` with its binary representation.

    The input is the integer itself and the target is its zero-padded list of
    bits, most significant bit first.
    """

    # Minimum width of the binary encoding; kept at 20 so that datasets with
    # max_value < 2**20 produce exactly the same tensors as before.
    _DEFAULT_BITS = 20

    def __init__(self, max_value: int) -> None:
        """Store the largest integer contained in the dataset (inclusive)."""
        self.max_value = max_value

    @staticmethod
    def int2bits(i, fill=_DEFAULT_BITS):
        """Return the bits of ``i``, most significant first, as a list of ints.

        The result is left-padded with zeros to ``fill`` entries; values that
        need more than ``fill`` bits are returned at full length, never
        truncated.

        Raises:
            ValueError: If ``i`` is negative.
        """
        if i < 0:
            # bin() would yield "-0b..." and the slice below would leave a
            # stray "b" that fails later with an unrelated-looking message.
            raise ValueError(f"int2bits expects a non-negative integer, got {i}")
        return [int(bit) for bit in bin(i)[2:].zfill(fill)]

    def _bit_width(self) -> int:
        """Return a width large enough to encode every index up to max_value."""
        return max(self._DEFAULT_BITS, int(self.max_value).bit_length())

    def __len__(self) -> int:
        """Return the number of samples, ``max_value + 1``."""
        return self.max_value + 1

    def __getitem__(self, index):
        """Return the sample for ``index``.

        Returns:
            A ``(data, label)`` pair: ``data`` is a float32 tensor of shape
            ``(1,)`` holding ``index``; ``label`` is a float32 tensor holding
            the bits of ``index``.

        Raises:
            IndexError: If ``index`` lies outside ``[0, max_value]``.
        """
        # Samples are computed on the fly, so nothing else would ever signal
        # the end of the dataset: without this check, ``for x in dataset``
        # (which relies on IndexError to stop) would loop forever.
        if not 0 <= index <= self.max_value:
            raise IndexError(
                f"index {index} is out of range for a dataset of size {len(self)}"
            )
        data = torch.tensor([index], dtype=torch.float32)
        # Use one width for the whole dataset so every label has the same
        # shape, even when max_value needs more than the default 20 bits.
        bits = self.int2bits(index, fill=self._bit_width())
        label = torch.tensor(bits, dtype=torch.float32)
        return data, label
