import torch
from torch.utils.data import Dataset

from adversarial_superposition.constants import DEFAULT_MODULO


def one_hot_encode(number, size):
    """Return a one-hot float tensor.

    Args:
        number: Index to set to 1. Negative indices count from the end, as
            with normal tensor indexing.
        size: Length of the returned vector, passed straight to ``torch.zeros``.

    Returns:
        A tensor of zeros (default float dtype) with a single 1 at ``number``.
    """
    encoded = torch.zeros(size)
    encoded[number] = 1.0
    return encoded


def binary_encode(num, bits):
    """Encode ``num`` as a tensor of its binary digits, most significant first.

    The string is zero-padded to at least ``bits`` characters. ``bits`` is a
    minimum width, so a number that needs more bits produces a longer tensor.
    Negative numbers are not supported: the leading ``'-'`` makes ``int`` raise
    ``ValueError``.

    Args:
        num: Non-negative integer to encode.
        bits: Minimum number of digits in the output.

    Returns:
        An integer tensor of 0/1 digits.
    """
    digits = format(num, f"0{bits}b")
    return torch.tensor(list(map(int, digits)))


def generate_random_one_hot(length):
    """Return a one-hot float vector of ``length`` with the hot index drawn uniformly.

    The index comes from ``torch.randint`` and therefore follows the global
    torch RNG state (reproducible under ``torch.manual_seed``).
    """
    hot_index = int(torch.randint(0, length, (1,))[0])
    vector = torch.zeros(length)
    vector[hot_index] = 1
    return vector


class AlgorithmicDataset(Dataset):
    """Exhaustive dataset for a binary operation reduced modulo ``p``.

    Every ordered pair ``(x, y)`` with ``0 <= x, y < input_size`` becomes one
    example, in x-major order. The input is the concatenation of the one-hot
    encodings of ``x`` and ``y`` (a float row of length ``2 * input_size``).
    The target is ``operation(x, y) % p``.

    Args:
        operation: Callable taking two Python ints. If its ``__name__``
            contains ``"div"``, pairs with ``y == 0`` are skipped. Callables
            without a ``__name__`` (e.g. ``functools.partial``) are treated as
            non-division operations.
        p: Modulus applied to every result.
        input_size: Number of distinct operand values. Falls back to ``p`` when
            falsy.
        output_size: Stored as ``self.output_size`` (falls back to ``p`` when
            falsy). It is not used to build the data in this class.
    """

    def __init__(self, operation, p=DEFAULT_MODULO, input_size=None, output_size=None):
        self.p = p
        self.input_size = input_size if input_size else p
        self.output_size = output_size if output_size else self.p
        self.operation = operation

        # Read the name defensively: functools.partial objects and many
        # callable instances have no __name__ attribute.
        op_name = getattr(operation, "__name__", "")
        skip_zero_divisor = "div" in op_name

        pairs = [
            (x, y)
            for x in range(self.input_size)
            for y in range(self.input_size)
            if not (skip_zero_divisor and y == 0)
        ]

        # The operation is called in the same x-major order as the pairs.
        self.targets = torch.tensor([self.operation(x, y) % self.p for x, y in pairs])

        # Row i of the identity matrix is the one-hot encoding of i. Gathering
        # rows builds all inputs at once instead of allocating tensors per
        # pair. It also yields a well-formed (0, 2 * input_size) tensor when
        # there are no pairs, where torch.stack([]) would raise.
        eye = torch.eye(self.input_size)
        xs = torch.tensor([x for x, _ in pairs], dtype=torch.long)
        ys = torch.tensor([y for _, y in pairs], dtype=torch.long)
        self.data = torch.cat((eye[xs], eye[ys]), dim=1)

    def __len__(self):
        """Return the number of (x, y) examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at position ``idx``."""
        return self.data[idx], self.targets[idx]
