"""
synthetic_classification.py
Synthetic datasets for low-sample classification.
"""

# Resolve the project's `datasets` subpackage dynamically, by package name,
# so that renaming the top-level folder only requires editing `pkg` below.
# NOTE(review): the earlier comment here read "At top of experiment file",
# which suggests this snippet was copied from an experiment script; neither
# `climate_agriculture` nor `healthcare_sparse` is used in the visible part of
# this module -- confirm it is still needed here.
import importlib
pkg = "Code"  # package directory name in this repo; adjust if you rename the folder
# Imported at module load time; fails with ImportError if `<pkg>.datasets`
# cannot be found on the import path.
datasets = importlib.import_module(f"{pkg}.datasets")
# Looked up as plain attributes of the `datasets` package, so each name must
# already be exposed there (presumably via its __init__ -- TODO confirm);
# otherwise getattr raises AttributeError at import time.
climate_agriculture = getattr(datasets, "climate_agriculture")
healthcare_sparse   = getattr(datasets, "healthcare_sparse")


import sklearn.datasets as ds
import torch


def generate_blobs(n_samples=100, n_features=2, n_classes=2, seed=0):
    """Draw isotropic Gaussian blobs and return them as torch tensors.

    Thin wrapper around ``sklearn.datasets.make_blobs``.

    Args:
        n_samples: Number of points to draw, forwarded as ``n_samples``.
        n_features: Dimensionality of each point, forwarded as ``n_features``.
        n_classes: Number of blob centers (one per class), forwarded as
            ``centers``.
        seed: Forwarded as ``random_state`` so repeated calls with the same
            arguments are reproducible.

    Returns:
        A ``(X, y)`` tuple: ``X`` holds the sampled points as a
        ``torch.float32`` tensor and ``y`` the integer class labels as a
        ``torch.long`` tensor.
    """
    points, labels = ds.make_blobs(
        n_samples=n_samples,
        n_features=n_features,
        centers=n_classes,
        random_state=seed,
    )
    # torch.tensor copies the NumPy data and casts to the requested dtype.
    features = torch.tensor(points, dtype=torch.float32)
    targets = torch.tensor(labels, dtype=torch.long)
    return features, targets


def generate_spirals(n_samples=100, noise=0.1, seed=0):
    """Draw a two-class 2-D toy dataset and return it as torch tensors.

    NOTE: despite the function name, the data comes from
    ``sklearn.datasets.make_moons`` (two interleaving half circles), not
    from a spiral generator.

    Args:
        n_samples: Number of points to draw, forwarded as ``n_samples``.
        noise: Forwarded as ``noise`` (standard deviation of the Gaussian
            noise added to the points).
        seed: Forwarded as ``random_state`` so repeated calls with the same
            arguments are reproducible.

    Returns:
        A ``(X, y)`` tuple: ``X`` holds the sampled points as a
        ``torch.float32`` tensor and ``y`` the class labels as a
        ``torch.long`` tensor.
    """
    coords, classes = ds.make_moons(
        n_samples=n_samples,
        noise=noise,
        random_state=seed,
    )
    # torch.tensor copies the NumPy data and casts to the requested dtype.
    return (
        torch.tensor(coords, dtype=torch.float32),
        torch.tensor(classes, dtype=torch.long),
    )
