import torch
import torchvision.datasets as tvds
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))

from npf.utils.paths import datasets_path

class EMNIST(tvds.EMNIST):
    """EMNIST 'balanced' split restricted to a range of labels, held as tensors.

    The selected samples are converted once at construction time and stored
    on ``device``; ``__getitem__`` then indexes those tensors directly (the
    torchvision ``transform`` / PIL path is not used).

    Args:
        train: Load the training split if True, otherwise the test split.
        class_range: Pair ``(lo, hi)``; samples whose label ``c`` satisfies
            ``lo <= c < hi`` are kept, grouped by label in increasing order.
            An empty range (``lo >= hi``) yields an empty dataset.
        device: Device that ``data`` and ``targets`` are moved to.
        download: Forwarded to torchvision's ``EMNIST`` constructor.
    """

    def __init__(self, train=True, class_range=(0, 47), device='cpu', download=True):
        super().__init__(datasets_path, train=train, split='balanced', download=download)

        # Select indices before converting, so discarded samples are never
        # cast to float or copied to ``device``. Gathering per label keeps the
        # kept samples grouped by class (lo first, then lo + 1, ...).
        per_class = [torch.where(self.targets == c)[0]
                     for c in range(class_range[0], class_range[1])]
        if per_class:
            idxs = torch.cat(per_class)
        else:
            # torch.cat raises on an empty list; an empty range keeps nothing.
            idxs = torch.empty(0, dtype=torch.long, device=self.targets.device)

        # Scale by 1/255 (presumably uint8 pixels -> [0, 1]; confirm), insert a
        # size-1 axis at dim 1 (presumably the channel axis), and swap the
        # last two axes (NOTE(review): presumably undoes EMNIST's transposed
        # raw image storage -- confirm).
        self.data = self.data[idxs].unsqueeze(1).float().div(255).transpose(-1, -2).to(device)
        self.targets = self.targets[idxs].to(device)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx`` straight from the stored tensors."""
        return self.data[idx], self.targets[idx]

class EMNIST_Partial(tvds.EMNIST):
    """Small prefix of the label-filtered EMNIST 'balanced' split, held as tensors.

    Samples whose label lies in ``class_range`` are ordered by label
    (lo first, then lo + 1, ...), and only the first ``num_bs`` of that
    ordering are kept. The kept samples are converted once and stored on
    ``device``; ``__getitem__`` indexes those tensors directly (the
    torchvision ``transform`` / PIL path is not used).

    Args:
        train: Load the training split if True, otherwise the test split.
        class_range: Pair ``(lo, hi)``; labels ``c`` with ``lo <= c < hi``
            are eligible. An empty range (``lo >= hi``) yields an empty dataset.
        device: Device that ``data`` and ``targets`` are moved to.
        download: Forwarded to torchvision's ``EMNIST`` constructor.
        num_bs: Slice bound applied to the label-ordered selection
            (``[:num_bs]``), so ``None`` keeps every eligible sample.
    """

    def __init__(self, train=True, class_range=(0, 47), device='cpu', download=True, num_bs=5):
        super().__init__(datasets_path, train=train, split='balanced', download=download)

        # Gather per label so the selection is grouped by class.
        per_class = [torch.where(self.targets == c)[0]
                     for c in range(class_range[0], class_range[1])]
        if per_class:
            idxs = torch.cat(per_class)
        else:
            # torch.cat raises on an empty list; an empty range keeps nothing.
            idxs = torch.empty(0, dtype=torch.long, device=self.targets.device)

        # Truncate the index list rather than the gathered data: same result
        # as ``data[idxs][:num_bs]`` but only the kept rows are ever copied,
        # converted to float, and moved to ``device``.
        idxs = idxs[:num_bs]

        # Scale by 1/255 (presumably uint8 pixels -> [0, 1]; confirm), insert a
        # size-1 axis at dim 1 (presumably the channel axis), and swap the
        # last two axes (NOTE(review): presumably undoes EMNIST's transposed
        # raw image storage -- confirm).
        self.data = self.data[idxs].unsqueeze(1).float().div(255).transpose(-1, -2).to(device)
        self.targets = self.targets[idxs].to(device)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx`` straight from the stored tensors."""
        return self.data[idx], self.targets[idx]
