import torch
import numpy as np
from torch.utils.data import Dataset

class SineDataset(Dataset):
    """Toy 2-D classification dataset made of vertically shifted sine curves.

    Class ``k`` is the curve ``y = sin(x) + intercepts[k]``, sampled at
    ``num_samples`` evenly spaced ``x`` values shared by all classes. The
    per-class offsets are spaced ``intercept`` apart and centered on zero.

    Items are ordered class-major: flat indices
    ``[k * num_samples, (k + 1) * num_samples)`` all belong to class ``k``.

    Args:
        num_classes: Number of curves (classes); must be at least 1.
        num_samples: Number of ``x`` samples per curve.
        sample_range: ``(start, stop)`` of the ``x`` interval, both ends
            included (passed to ``np.linspace``).
        intercept: Vertical spacing between adjacent curves.

    Attributes:
        function: The base curve, ``np.sin``.
        intercepts: List of per-class vertical offsets.
        sample_xs: The shared ``x`` sample positions.
        domain: ``([x_start, x_stop], [y_min, y_max])`` bounding box of all
            curves (extreme offsets padded by 1, the amplitude of ``sin``).
    """

    def __init__(self, num_classes=2, num_samples=50000, sample_range=(-3.5, 3.5), intercept=1):
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        self.function = np.sin
        # Offsets centered on 0, e.g. 3 classes with spacing 1 -> [-1.0, 0.0, 1.0].
        center = (num_classes - 1) / 2
        self.intercepts = [(i - center) * intercept for i in range(num_classes)]
        self.sample_xs = np.linspace(start=sample_range[0], stop=sample_range[1], num=num_samples)
        # Copy so later mutation of ``self.domain`` never leaks into the
        # caller's sequence (or a shared default argument).
        xlim = list(sample_range)
        # sin(x) spans [-1, 1], so pad the outermost offsets by 1.
        ylim = [min(self.intercepts) - 1, max(self.intercepts) + 1]
        self.domain = (xlim, ylim)

    def __len__(self):
        """Return the total number of points: samples per curve times classes."""
        return len(self.sample_xs) * len(self.intercepts)

    def __getitem__(self, idx):
        """Return ``(point, label)`` for flat index ``idx``.

        ``point`` is a float32 tensor ``[x, y]``; ``label`` is a scalar
        integer tensor holding the class index. Negative indices count from
        the end, and the returned label is always non-negative.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        pos = idx + total if idx < 0 else idx
        if not 0 <= pos < total:
            raise IndexError(f"index {idx} out of range for dataset of length {total}")
        class_label, sample_idx = divmod(pos, len(self.sample_xs))
        x = self.sample_xs[sample_idx]
        y = self.function(x) + self.intercepts[class_label]
        return torch.tensor([x, y], dtype=torch.float32), torch.tensor(class_label)