import os

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset



class FeaturesDataset(Dataset):
    """Dataset of pre-extracted features loaded from a single ``.pt`` file.

    The file ``<root>/<dataset>/<split>_data.pt`` is read with ``torch.load``
    and is expected to hold a mapping with the keys:

    * ``'features'``  -- indexable container of per-sample features,
    * ``'targets'``   -- indexable container of per-sample labels,
    * ``'n_classes'`` -- number of label classes (stored as ``n_classes``).

    Items are returned as ``(features, target)``, with ``features`` as a
    ``torch.float32`` tensor and ``target`` as a ``torch.int64`` tensor.

    Args:
        root: Base directory; the file is looked up under ``root/dataset``.
        dataset: Name of the sub-directory under ``root``.
        split: Split name; selects the file ``<split>_data.pt``.
    """

    def __init__(self, root='Ftask', dataset='DOG', split='train'):
        super().__init__()
        self.dataset = dataset
        self.split = split
        self.root = root

        datapath = os.path.join(root, dataset, f'{split}_data.pt')
        # NOTE(review): torch.load unpickles the file -- only load trusted data.
        data = torch.load(datapath)
        self.data = data['features']
        self.targets = data['targets']
        self.n_classes = data['n_classes']

    def __len__(self):
        """Return the number of samples, i.e. ``len(features)``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, target)`` for ``idx`` as float32 / int64 tensors.

        ``idx`` may be an int, a list of ints, or a tensor. A tensor index is
        first converted to a Python int/list with ``tolist()``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # torch.as_tensor acts like ``.to(dtype)`` for tensors (no copy when the
        # dtype already matches; device and autograd graph are kept). It also
        # accepts numpy arrays, numpy scalars and Python numbers, so the stored
        # features/targets do not have to be torch tensors.
        features = torch.as_tensor(self.data[idx], dtype=torch.float32)
        target = torch.as_tensor(self.targets[idx], dtype=torch.int64)
        return features, target
