import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, Dataset
import numpy as np

def data_generator(num_samples=200, input_dim=50, seed=42):
    """Generate a synthetic regression dataset with linear and quadratic terms.

    Targets follow ``y = X @ w + (X ** 2) @ v + 0.1 * eps``, where ``X``, ``w``,
    ``v`` and ``eps`` are all drawn with ``torch.randn``.

    Note: this reseeds the *global* torch and NumPy RNGs as a side effect.

    Args:
        num_samples: Number of samples (rows of ``X`` and ``y``).
        input_dim: Number of features per sample.
        seed: Seed passed to ``torch.manual_seed`` and ``np.random.seed``.
            Defaults to 42, the previously hard-coded value.

    Returns:
        Tuple ``(X, y)`` with shapes ``(num_samples, input_dim)`` and
        ``(num_samples, 1)``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    X = torch.randn(num_samples, input_dim)
    true_w = torch.randn(input_dim, 1)
    true_v = torch.randn(input_dim, 1)
    # This draw is unused, but it advances the torch RNG. Removing it would
    # change the noise below (and therefore y) for a given seed, breaking
    # reproducibility of previously generated datasets.
    torch.randn(input_dim, 1)
    noise = torch.randn(num_samples, 1)

    y = X @ true_w + (X ** 2) @ true_v + 0.1 * noise
    return X, y


class CIFAR10OneHot(Dataset):
    """Wrap an indexable ``(data, label)`` dataset so labels are one-hot encoded.

    The wrapped dataset's items must unpack into ``(data, label)`` where
    ``label`` is usable as an index (e.g. an int class id). ``data`` is passed
    through unchanged.

    Args:
        dataset: Indexable dataset supporting ``len()`` whose items are
            ``(data, label)`` pairs.
        num_classes: Length of the one-hot vector. Defaults to 10 (CIFAR-10);
            pass a different value to reuse this wrapper for other label sets.
    """

    def __init__(self, dataset, num_classes=10):
        self.dataset = dataset
        self.num_classes = num_classes

    def __getitem__(self, index):
        """Return ``(data, one_hot)`` for item ``index``.

        ``one_hot`` is a float tensor of length ``num_classes``, with 1 at the
        label position and 0 elsewhere. A label outside
        ``[-num_classes, num_classes)`` raises ``IndexError``.
        """
        data, label = self.dataset[index]
        # Float (torch.zeros default dtype) rather than integer, so the target
        # can be used directly with float-valued losses.
        one_hot = torch.zeros(self.num_classes)
        one_hot[label] = 1
        return data, one_hot

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

def _random_subset_indices(dataset_len, num_samples, split):
    """Draw distinct indices from ``range(dataset_len)`` using NumPy's global RNG.

    ``num_samples=None`` selects every index (in shuffled order). A request
    larger than ``dataset_len`` is clamped, with a warning, rather than letting
    ``np.random.choice(..., replace=False)`` raise.
    """
    import warnings

    if num_samples is None:
        size = dataset_len
    elif num_samples > dataset_len:
        warnings.warn(
            f"num_samples={num_samples} exceeds the {split} split size "
            f"({dataset_len}); using all {dataset_len} samples.",
            stacklevel=3,
        )
        size = dataset_len
    else:
        size = num_samples
    return np.random.choice(dataset_len, size, replace=False)


def get_cifar10_loaders(batch_size, num_samples=1000, root='./official_data', seed=42):
    """Build random one-hot-labelled CIFAR-10 train/test subsets and DataLoaders.

    Both splits are loaded via ``torchvision.datasets.CIFAR10`` with
    ``download=True``. ``num_samples`` random indices are drawn from each split
    without replacement, and labels are wrapped with ``CIFAR10OneHot``.

    Note: this reseeds the *global* torch and NumPy RNGs as a side effect.

    Args:
        batch_size: Batch size for both DataLoaders.
        num_samples: Number of samples drawn from *each* split. If it exceeds a
            split's size, that split is clamped to its full size and a warning is
            emitted. ``None`` uses the full splits.
        root: Directory passed to ``datasets.CIFAR10``. Defaults to the
            previously hard-coded ``'./official_data'``.
        seed: Seed for ``torch.manual_seed`` / ``np.random.seed``. Defaults to 42.

    Returns:
        ``(train_dataset_onehot, train_loader, test_dataset_onehot, test_loader)``.
        The train loader shuffles; the test loader does not.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    transform = transforms.Compose([
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5; assuming ToTensor yields values in
        # [0, 1], this rescales images to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_dataset = datasets.CIFAR10(
        root=root,
        train=True,
        download=True,
        transform=transform,
    )
    test_dataset = datasets.CIFAR10(
        root=root,
        train=False,
        download=True,
        transform=transform,
    )

    # Sample train before test, so that for a fixed seed and any previously
    # valid num_samples the NumPy RNG stream (and thus the subsets) is
    # unchanged.
    train_indices = _random_subset_indices(len(train_dataset), num_samples, "train")
    test_indices = _random_subset_indices(len(test_dataset), num_samples, "test")

    train_dataset_onehot = CIFAR10OneHot(Subset(train_dataset, train_indices))
    test_dataset_onehot = CIFAR10OneHot(Subset(test_dataset, test_indices))

    train_loader = DataLoader(
        train_dataset_onehot,
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        test_dataset_onehot,
        batch_size=batch_size,
        shuffle=False,
    )

    return train_dataset_onehot, train_loader, test_dataset_onehot, test_loader