import os
import torch
from torch.utils.data import Dataset
from data.qnli.dataset import get_dataset

class QNLINoisy(Dataset):
    """QNLI dataset whose training split has a fixed subset of labels flipped.

    The validation split is returned unchanged. For the training split, the
    indices stored in ``flipped_indices.pt`` (looked up in the same directory
    as this module) have their label replaced by ``1 - label``.

    Args:
        train: If true, load the ``'train'`` split and apply the label noise;
            otherwise load the ``'validation'`` split as-is.
    """

    def __init__(self, train=True):
        super().__init__()
        self.name = 'qnli_noisy'
        self.data = get_dataset('train' if train else 'validation')
        if train:
            # The flipped indices live next to this module.
            folder = os.path.dirname(os.path.abspath(__file__))
            flipped_indices = torch.load(os.path.join(folder, 'flipped_indices.pt'))

            # Use map to modify the dataset labels.
            self.data = self.data.map(
                self._make_label_flipper(flipped_indices),
                with_indices=True,
            )

    @staticmethod
    def _make_label_flipper(flipped_indices):
        """Build the ``(example, idx) -> example`` callback that flips labels.

        Args:
            flipped_indices: Tensor (or any iterable of integers) holding the
                example indices whose label must be flipped.

        Returns:
            A function that replaces ``example['label']`` by
            ``1 - example['label']`` when ``idx`` is one of the flipped
            indices, and returns the example unchanged otherwise.
        """
        if hasattr(flipped_indices, 'tolist'):
            flipped_indices = flipped_indices.tolist()
        # A set gives O(1) membership tests; `idx in tensor` would scan the
        # whole tensor once per example.
        flipped = {int(i) for i in flipped_indices}

        def flip_label(example, idx):
            if idx in flipped:
                # Assumes binary labels encoded as 0/1 -- TODO confirm.
                example['label'] = 1 - example['label']
            return example

        return flip_label

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``."""
        return self.data[idx]

    def select(self, indices):
        """Return ``self.data.select(indices)``.

        Note that the result is whatever the wrapped dataset's ``select``
        returns, not a ``QNLINoisy`` instance.
        """
        return self.data.select(indices)


if __name__ == "__main__":
    # Generate the noisy-label artifacts for the QNLI training split:
    #   * flipped_indices.pt -- indices of the examples whose label is flipped
    #   * noisy_labels.pt    -- all training labels after flipping
    # Both files are written next to this module.
    from data import QNLI
    from data.utils import set_seed

    qnli_trainset = QNLI(train=True)
    n_train = len(qnli_trainset)

    # Seed (42) right before sampling so the chosen subset is reproducible.
    # NOTE(review): presumably set_seed seeds torch's RNG -- confirm.
    set_seed(42)
    # Pick 10% of the training indices. The slice is a view into the full
    # permutation; clone() it so torch.save only serializes the selected
    # indices instead of the whole underlying storage.
    flipped_indices = torch.randperm(n_train)[:int(n_train * 0.1)].clone()

    targets = torch.tensor([ex['label'] for ex in qnli_trainset.data])
    # randperm indices are unique, so one indexed assignment is equivalent
    # to flipping the entries one by one.
    targets[flipped_indices] = 1 - targets[flipped_indices]

    folder = os.path.dirname(os.path.abspath(__file__))
    torch.save(flipped_indices, os.path.join(folder, 'flipped_indices.pt'))
    torch.save(targets, os.path.join(folder, 'noisy_labels.pt'))