import os
import torch
from torch.utils.data import Dataset
from data.utils import load_openxai_dataset, set_seed

class HELOCNoisy(Dataset):
    """HELOC dataset with a fixed subset of training labels flipped.

    Features and labels come from ``load_openxai_dataset(name='heloc', ...)``.
    For the training split, the labels at the indices stored in
    ``flipped_indices.pt`` (located next to this file) are replaced by
    ``1 - y``. That file is produced by running this module as a script.
    The non-training split is returned unmodified.

    Args:
        train: If True, load the training split and apply the label flips;
            otherwise load the other split without modification.
        download: Forwarded to ``load_openxai_dataset``.
        scale: Feature scaling option forwarded to ``load_openxai_dataset``.

    Raises:
        FileNotFoundError: If ``train`` is True and ``flipped_indices.pt``
            does not exist.
        IndexError: If the stored indices exceed the size of the loaded
            training split (e.g. the file was generated for other data).
    """

    def __init__(self, train=True, download=False, scale='minmax'):
        self.name = 'heloc_noisy'
        self.data, self.targets = load_openxai_dataset(name='heloc', train=train, download=download, scale=scale)
        # Taken before flipping; swapping 0 <-> 1 leaves the class set unchanged
        # provided labels are binary — TODO confirm labels are exactly {0, 1}.
        self.classes = torch.unique(self.targets)
        if train:
            loc = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'flipped_indices.pt')
            flipped_indices = self._load_flipped_indices(loc)
            n_samples = len(self.targets)
            # Fail with an actionable message instead of an opaque indexing
            # error when the saved mask does not match the loaded split.
            if flipped_indices.numel() > 0 and int(flipped_indices.max()) >= n_samples:
                raise IndexError(
                    f"{loc} references index {int(flipped_indices.max())}, but the "
                    f"training split has only {n_samples} samples; regenerate the "
                    f"file by running this module as a script."
                )
            # Assumes labels are encoded as 0/1, so 1 - y swaps them.
            self.targets[flipped_indices] = 1 - self.targets[flipped_indices]

    @staticmethod
    def _load_flipped_indices(path):
        """Load and return the saved tensor of label indices to flip.

        Args:
            path: Location of the ``.pt`` file written by this module's script.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if not os.path.isfile(path):
            raise FileNotFoundError(
                f"Flipped-label index file not found: {path}. "
                "Generate it by running this module as a script."
            )
        # Load onto CPU so the file is readable regardless of the device it
        # may have been saved from.
        return torch.load(path, map_location='cpu')

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair at ``idx``.

        Tensor indices are converted to Python ints/lists first.
        """
        idx = idx.tolist() if isinstance(idx, torch.Tensor) else idx
        return (self.data[idx], self.targets[idx])

    def get_num_samples(self):
        """Return the number of rows in ``self.data``."""
        return self.data.shape[0]

    def get_num_features(self):
        """Return the number of feature columns in ``self.data``."""
        return self.data.shape[1]

if __name__ == "__main__":
    # Generate the fixed label-noise mask consumed by HELOCNoisy: a random 10%
    # of the HELOC training indices, drawn after set_seed(42), saved as
    # flipped_indices.pt next to this file.
    #
    # The split is sized with the same loader call HELOCNoisy uses, so the
    # saved indices are valid for exactly the labels they will be applied to.
    # This also avoids importing a top-level `datasets` module, which may
    # resolve to an unrelated third-party package depending on sys.path.
    train_data, _ = load_openxai_dataset(name='heloc', train=True, download=True, scale='minmax')
    n_train = train_data.shape[0]
    noise_rate = 0.1
    # Seed after loading (as before), so the draw depends only on n_train;
    # presumably set_seed seeds torch's global RNG — confirm in data.utils.
    set_seed(42)
    flipped_indices = torch.randperm(n_train)[:int(n_train * noise_rate)]
    loc = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'flipped_indices.pt')
    torch.save(flipped_indices, loc)