import numpy as np
from torch.utils.data import Dataset

class WeightedDataset(Dataset):
    """Wrap a dataset and attach a per-sample weight and score to it.

    Two float arrays with one entry per sample are kept next to the wrapped
    dataset: ``weights`` (initialised to 1.0) and ``scores`` (initialised to
    0.0). Items of the wrapped dataset are expected to be mappings; each
    item returned by this wrapper is a copy of the underlying mapping with
    two extra keys, ``'sample_idx'`` and ``'weight'``.
    """

    def __init__(self, dataset):
        """Store ``dataset`` and allocate its weight and score arrays.

        Args:
            dataset: A sized, indexable collection whose items are mappings.
        """
        super().__init__()
        self.dataset = dataset
        self.weights = np.ones(len(dataset))
        self.scores = np.zeros(len(dataset))

    @staticmethod
    def _to_numpy(obj):
        """Return ``obj`` as a NumPy array if it looks like a torch tensor.

        NumPy cannot implicitly convert a tensor that requires grad or that
        lives on a non-CPU device, so such inputs are detached and moved to
        the CPU first. The check is duck-typed to avoid needing the ``torch``
        name in this module. Anything else is returned unchanged.
        """
        if hasattr(obj, "detach") and hasattr(obj, "cpu"):
            return obj.detach().cpu().numpy()
        return obj

    def __len__(self) -> int:
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx) -> dict:
        """Return a copy of sample ``idx`` with its index and weight added.

        Args:
            idx: Anything ``int()`` accepts (plain int, NumPy integer,
                single-element tensor, ...).

        Returns:
            A new dict holding the wrapped item's entries plus
            ``'sample_idx'`` (the integer index) and ``'weight'`` (the
            current weight of that sample).
        """
        idx = int(idx)
        # Shallow copy so the extra keys never leak into the wrapped item.
        data = dict(self.dataset[idx])
        data["sample_idx"] = idx
        data["weight"] = self.weights[idx]
        return data

    def update_scores(self, indices, values) -> None:
        """Write ``values`` into ``scores`` at positions ``indices``.

        Both arguments may be anything NumPy indexing/assignment accepts, or
        torch tensors (including ones that require grad or are on another
        device).
        """
        self.scores[self._to_numpy(indices)] = self._to_numpy(values)

    def update_weights(self, indices, values) -> None:
        """Write ``values`` into ``weights`` at positions ``indices``.

        Both arguments may be anything NumPy indexing/assignment accepts, or
        torch tensors (including ones that require grad or are on another
        device).
        """
        self.weights[self._to_numpy(indices)] = self._to_numpy(values)

    def reset_weights(self) -> None:
        """Replace ``weights`` with a fresh array of ones.

        The attribute is rebound to a new array rather than filled in place,
        so references to the previous array keep their old values.
        """
        self.weights = np.ones(len(self.dataset))