import os

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset



class FeaturesDataset(Dataset):
    """Map-style dataset over pre-extracted features and their targets.

    The data is read once, at construction time, from
    ``<root>/<dataset>/<split>_data.pt``. That file must hold a dict with the
    keys ``'features'``, ``'targets'`` and ``'n_classes'``.

    Args:
        root: Directory that contains one sub-directory per dataset.
        dataset: Name of the dataset sub-directory under ``root``.
        split: Split prefix of the file name (``f'{split}_data.pt'``).
        map_location: Forwarded to ``torch.load``. For example, pass
            ``'cpu'`` to load tensors that were saved from a GPU onto a
            machine without one. ``None`` (the default) keeps
            ``torch.load``'s own behaviour.

    Attributes:
        data: The stored ``'features'`` entry. Presumably a tensor whose
            first dimension indexes samples; verify against the producer
            of the file.
        targets: The stored ``'targets'`` entry, indexed the same way as
            ``data``.
        n_classes: The stored ``'n_classes'`` entry, kept as-is.
    """

    def __init__(self, root='Ftask', dataset='DOG', split='train',
                 map_location=None):
        super().__init__()
        self.dataset = dataset
        self.split = split
        self.root = root

        datapath = os.path.join(root, dataset, f'{split}_data.pt')
        # NOTE: torch.load is pickle-based (unless weights_only is in effect);
        # only load files from trusted sources.
        data = torch.load(datapath, map_location=map_location)

        self.data = data['features']
        self.targets = data['targets']
        self.n_classes = data['n_classes']

    def __len__(self):
        """Return the number of stored feature rows (``len(self.data)``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, target)`` for ``idx``.

        Features are cast to ``float32`` and targets to ``int64``. A tensor
        index is first converted with ``.tolist()``, so it becomes a plain
        int or list before indexing.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        features = self.data[idx].to(torch.float32)
        target = self.targets[idx].to(torch.int64)
        return features, target