import torch
import numpy as np
from torch.utils.data import Dataset


def data_generator(N, seq_length):
    """Build a batch of samples for the adding problem.

    Each sample is a two-channel sequence: channel 0 holds values drawn by
    ``torch.rand``, channel 1 is a 0/1 mask with exactly two distinct
    positions set. The target is the sum of the two marked values.

    Args:
        N: Number of samples to generate.
        seq_length: Length of each sequence. Two distinct positions are drawn
            per sample, so ``np.random.choice`` needs at least two to pick
            from.

    Returns:
        A tuple ``(X, Y)`` where ``X`` has shape ``(N, 2, seq_length)``
        (values stacked on top of the mask) and ``Y`` has shape ``(N, 1)``.
    """
    values = torch.rand([N, 1, seq_length])
    mask = torch.zeros([N, 1, seq_length])

    # One draw of two distinct positions per sample, in sample order.
    picks = np.array(
        [np.random.choice(seq_length, size=2, replace=False) for _ in range(N)],
        dtype=np.int64,
    ).reshape(N, 2)
    # Shape (N, 1, 2) so it indexes along the sequence axis of both tensors.
    index = torch.from_numpy(picks).unsqueeze(1)

    mask.scatter_(2, index, 1.0)
    chosen = values.gather(2, index)
    targets = chosen[:, :, 0] + chosen[:, :, 1]

    return torch.cat((values, mask), dim=1), targets



class AddingProblemDataset(Dataset):
    """Map-style dataset holding adding-problem samples generated up front.

    All samples are produced once in the constructor by ``data_generator``
    and kept as two attributes, ``X`` (inputs) and ``Y`` (targets); indexing
    reads from those stored values.
    """

    def __init__(self, N, seq_length):
        """Generate and store the samples.

        Args:
            N (int): Number of data points in the dataset.
            seq_length (int): Length of the sequences.
        """
        inputs, targets = data_generator(N, seq_length)
        self.X = inputs
        self.Y = targets

    def __len__(self):
        """Return the number of samples, taken from the stored targets."""
        sample_count = len(self.Y)
        return sample_count

    def __getitem__(self, idx):
        """Fetch one sample.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            A tuple ``(X, Y)``: the input and the target stored at ``idx``.
        """
        sample = self.X[idx]
        target = self.Y[idx]
        return sample, target