import torch
import torchvision
from PIL import Image
import numpy as np
import pycls.datasets.utils as ds_utils
import pandas as pd

class TRPB(torch.utils.data.Dataset):
    """Dataset over the precomputed "TRPB" feature table.

    Each item is one row of the table returned by
    ``ds_utils.load_features``: the first element of the row is returned as
    the target and the remaining elements as the feature vector.
    """

    def __init__(self, root, train, transform, test_transform, download=True, only_features=False):
        """Load the "TRPB" features for the requested split.

        Args:
            root: Unused; kept so the signature matches the other datasets.
            train: Forwarded to ``ds_utils.load_features`` as ``train``.
            transform: Unused.
            test_transform: Unused.
            download: Unused.
            only_features: Unused.
        """
        self.features = ds_utils.load_features(
            "TRPB",
            train=train,
            normalized=False,
        )

    def __len__(self):
        """Return the size of the first dimension of the feature table."""
        row_count = self.features.shape[0]
        return row_count

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (feat, target) where feat is the row without its first
            element and target is that first element.
        """
        row = self.features[index]
        return row[1:], row[0]
    

class TRPB_umap(torch.utils.data.Dataset):
    """Dataset over the precomputed "TRPB_umap" feature table.

    Each item is one row of the table returned by
    ``ds_utils.load_features``: the first element of the row is returned as
    the target and the remaining elements as the feature vector.
    """

    def __init__(self, root, train, transform, test_transform, download=True, only_features=False):
        """Load the "TRPB_umap" features for the requested split.

        Only ``train`` is used (forwarded to ``ds_utils.load_features``);
        the other parameters are accepted and ignored.
        """
        self.features = ds_utils.load_features(
            "TRPB_umap",
            train=train,
            normalized=False,
        )

    def __len__(self):
        """Return the size of the first dimension of the feature table."""
        row_count = self.features.shape[0]
        return row_count

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (feat, target) where feat is the row without its first
            element and target is that first element.
        """
        row = self.features[index]
        return row[1:], row[0]

        
class Sysdata(torch.utils.data.Dataset):
    """Dataset over a precomputed feature table selected by ``group``.

    Each item is one row of the table returned by
    ``ds_utils.load_features``: the first element of the row is returned as
    the target and the remaining elements as the feature vector.
    """

    def __init__(self, root, group, train, transform, test_transform, download=True, only_features=False):
        """Load the features of ``group`` for the requested split.

        Args:
            root: Unused.
            group: Name passed as the first argument of
                ``ds_utils.load_features``.
            train: Forwarded to ``ds_utils.load_features`` as ``train``.
            transform: Unused.
            test_transform: Unused.
            download: Unused.
            only_features: Unused.
        """
        self.features = ds_utils.load_features(
            group,
            train=train,
            normalized=False,
        )

    def __len__(self):
        """Return the size of the first dimension of the feature table."""
        row_count = self.features.shape[0]
        return row_count

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (feat, target) where feat is the row without its first
            element and target is that first element.
        """
        row = self.features[index]
        return row[1:], row[0]
    

class Octanoate(torch.utils.data.Dataset):
    """Dataset over a precomputed feature table selected by ``group``.

    Each item is one row of the table returned by
    ``ds_utils.load_features``: the first element of the row is returned as
    the target and the remaining elements as the feature vector.
    """

    def __init__(self, root, group, train, transform, test_transform, download=True, only_features=False):
        """Load the features of ``group`` for the requested split.

        Only ``group`` and ``train`` are used (both forwarded to
        ``ds_utils.load_features``); the other parameters are accepted and
        ignored.
        """
        self.features = ds_utils.load_features(
            group,
            train=train,
            normalized=False,
        )

    def __len__(self):
        """Return the size of the first dimension of the feature table."""
        row_count = self.features.shape[0]
        return row_count

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (feat, target) where feat is the row without its first
            element and target is that first element.
        """
        row = self.features[index]
        return row[1:], row[0]

class Octanoate_1028(torch.utils.data.Dataset):
    """Dataset over a precomputed feature table selected by ``group``.

    Each item is one row of the table returned by
    ``ds_utils.load_features``. The first element of the row is returned as
    the target; the feature vector excludes both the first and the last
    element of the row.
    """

    def __init__(self, root, group, train, transform, test_transform, download=True, only_features=False):
        """Load the features of ``group`` for the requested split.

        Only ``group`` and ``train`` are used (both forwarded to
        ``ds_utils.load_features``); the other parameters are accepted and
        ignored.
        """
        self.features = ds_utils.load_features(
            group,
            train=train,
            normalized=False,
        )

    def __len__(self):
        """Return the size of the first dimension of the feature table."""
        row_count = self.features.shape[0]
        return row_count

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (feat, target) where feat is the row with its first and
            last elements removed and target is the first element.
        """
        row = self.features[index]
        # The trailing element is deliberately left out of the features.
        return row[1:-1], row[0]
    

class Octanoate_test(torch.utils.data.Dataset):
    """Dataset over the table returned by ``ds_utils.load_features_test700``.

    Unlike the sibling tabular datasets, items are returned as whole rows;
    no target is split off.
    """

    def __init__(self, root, group, train, transform, test_transform, download=True, only_features=False):
        """Load the test features of ``group``.

        Only ``group`` and ``train`` are used (both forwarded to
        ``ds_utils.load_features_test700``); the other parameters are
        accepted and ignored.
        """
        self.features = ds_utils.load_features_test700(
            group,
            train=train,
            normalized=False,
        )

    def __len__(self):
        """Return the size of the first dimension of the feature table."""
        row_count = self.features.shape[0]
        return row_count

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            The complete row stored at ``index``.
        """
        return self.features[index]
    

class Q1_data(torch.utils.data.Dataset):
    """Dataset wrapping an in-memory table passed in by the caller.

    The first element of each row is returned as the target and the
    remaining elements as the feature vector.
    """

    def __init__(self, data):
        """Store ``data``; it must expose ``.shape`` and row indexing."""
        self.data = data

    def __len__(self):
        """Return the size of the first dimension of ``data``."""
        row_count = self.data.shape[0]
        return row_count

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (feat, target) where feat is the row without its first
            element and target is that first element.
        """
        row = self.data[index]
        return row[1:], row[0]


class Octanoate_test_1028(torch.utils.data.Dataset):
    """Dataset over the table returned by ``ds_utils.load_features_test700``.

    Items are returned as whole rows; no target is split off.
    """

    def __init__(self, root, group, train, transform, test_transform, download=True, only_features=False):
        """Load the test features of ``group``.

        Only ``group`` and ``train`` are used (both forwarded to
        ``ds_utils.load_features_test700``); the other parameters are
        accepted and ignored.
        """
        self.features = ds_utils.load_features_test700(
            group,
            train=train,
            normalized=False,
        )

    def __len__(self):
        """Return the size of the first dimension of the feature table."""
        row_count = self.features.shape[0]
        return row_count

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            The complete row stored at ``index``.
        """
        return self.features[index]

class CIFAR10(torchvision.datasets.CIFAR10):
    """CIFAR-10 with switchable train/test transforms and a feature mode.

    ``no_aug`` (initially ``False``) selects ``test_transform`` instead of
    ``transform``. With ``only_features`` set, the precomputed feature at
    the same index is returned in place of the image.
    """

    def __init__(self, root, train, transform, test_transform, download=True, only_features=False):
        """Initialise the torchvision dataset and load precomputed features.

        Args:
            root: Dataset root directory, forwarded to torchvision.
            train: Split flag, forwarded to torchvision and to
                ``ds_utils.load_features``.
            transform: Transform applied when ``no_aug`` is false.
            test_transform: Transform applied when ``no_aug`` is true.
            download: Forwarded to torchvision.
            only_features: Return precomputed features instead of images.
        """
        super().__init__(root, train, transform=transform, download=download)
        self.test_transform = test_transform
        self.no_aug = False
        self.only_features = only_features
        self.features = ds_utils.load_features(
            "CIFAR10",
            train=train,
            normalized=False,
        )

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target
            class; image is the precomputed feature when ``only_features``
            is set.
        """
        raw, target = self.data[index], self.targets[index]

        # Convert to a PIL Image so the output type matches the other
        # datasets; done before the feature shortcut, as the sample is
        # always decoded first.
        picture = Image.fromarray(raw)

        if self.only_features:
            return self.features[index], target

        chosen = self.test_transform if self.no_aug else self.transform
        if chosen is None:
            return picture, target
        return chosen(picture), target


class CIFAR100(torchvision.datasets.CIFAR100):
    """CIFAR-100 with switchable train/test transforms and a feature mode.

    ``no_aug`` (initially ``False``) selects ``test_transform`` instead of
    ``transform``. With ``only_features`` set, the precomputed feature at
    the same index is returned in place of the image.
    """

    def __init__(self, root, train, transform, test_transform, download=True, only_features=False):
        """Initialise the torchvision dataset and load precomputed features.

        Args:
            root: Dataset root directory, forwarded to torchvision.
            train: Split flag, forwarded to torchvision and to
                ``ds_utils.load_features``.
            transform: Transform applied when ``no_aug`` is false.
            test_transform: Transform applied when ``no_aug`` is true.
            download: Forwarded to torchvision.
            only_features: Return precomputed features instead of images.
        """
        super().__init__(root, train, transform=transform, download=download)
        self.test_transform = test_transform
        self.no_aug = False
        self.only_features = only_features
        self.features = ds_utils.load_features(
            "CIFAR100",
            train=train,
            normalized=False,
        )

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target
            class; image is the precomputed feature when ``only_features``
            is set.
        """
        raw, target = self.data[index], self.targets[index]

        # Convert to a PIL Image so the output type matches the other
        # datasets; done before the feature shortcut, as the sample is
        # always decoded first.
        picture = Image.fromarray(raw)

        if self.only_features:
            return self.features[index], target

        chosen = self.test_transform if self.no_aug else self.transform
        if chosen is None:
            return picture, target
        return chosen(picture), target


class STL10(torchvision.datasets.STL10):
    """STL-10 with switchable train/test transforms.

    ``no_aug`` (initially ``False``) selects ``test_transform`` instead of
    ``transform``. The parent's ``labels`` are also exposed as ``targets``.
    """

    def __init__(self, root, train, transform, test_transform, download=True):
        """Initialise the torchvision dataset.

        Args:
            root: Dataset root directory, forwarded to torchvision.
            train: Forwarded as the second positional argument of
                ``torchvision.datasets.STL10``.
                NOTE(review): that position is torchvision's ``split``
                parameter, so this presumably must be a split name rather
                than a bool; confirm against callers.
            transform: Transform applied when ``no_aug`` is false.
            test_transform: Transform applied when ``no_aug`` is true.
            download: Forwarded to torchvision.
        """
        super().__init__(root, train, transform=transform, download=download)
        self.test_transform = test_transform
        self.no_aug = False
        # Alias so code written against `.targets` works with this dataset.
        self.targets = self.labels

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        raw, target = self.data[index], int(self.targets[index])

        # Move the leading axis last before building the PIL Image, so the
        # output type matches the other datasets.
        picture = Image.fromarray(raw.transpose(1, 2, 0))

        chosen = self.test_transform if self.no_aug else self.transform
        if chosen is None:
            return picture, target
        return chosen(picture), target


class MNIST(torchvision.datasets.MNIST):
    """MNIST with switchable train/test transforms.

    ``no_aug`` (initially ``False``) selects ``test_transform`` instead of
    ``transform``.
    """

    def __init__(self, root, train, transform, test_transform, download=True):
        """Initialise the torchvision dataset.

        Args:
            root: Dataset root directory, forwarded to torchvision.
            train: Split flag, forwarded to torchvision.
            transform: Transform applied when ``no_aug`` is false.
            test_transform: Transform applied when ``no_aug`` is true.
            download: Forwarded to torchvision.
        """
        super().__init__(root, train, transform=transform, download=download)
        self.test_transform = test_transform
        self.no_aug = False

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        raw, target = self.data[index], int(self.targets[index])

        # Build a single-channel ('L' mode) PIL Image so the output type
        # matches the other datasets.
        picture = Image.fromarray(raw.numpy(), mode='L')

        chosen = self.test_transform if self.no_aug else self.transform
        if chosen is None:
            return picture, target
        return chosen(picture), target


class SVHN(torchvision.datasets.SVHN):
    """SVHN with switchable train/test transforms.

    ``no_aug`` (initially ``False``) selects ``test_transform`` instead of
    ``transform``. The parent's ``labels`` are also exposed as ``targets``,
    matching the ``STL10`` wrapper in this file.
    """

    def __init__(self, root, train, transform, test_transform, download=True):
        """Initialise the torchvision dataset.

        Args:
            root: Dataset root directory, forwarded to torchvision.
            train: Forwarded as the second positional argument of
                ``torchvision.datasets.SVHN``, which is its ``split``
                parameter ('train', 'test' or 'extra').
            transform: Transform applied when ``no_aug`` is false.
            test_transform: Transform applied when ``no_aug`` is true.
            download: Forwarded to torchvision.
        """
        super(SVHN, self).__init__(root, train, transform=transform, download=download)
        self.test_transform = test_transform
        self.no_aug = False
        # torchvision's SVHN keeps its labels in `labels` and defines no
        # `targets` attribute; alias it so `__getitem__` and any code written
        # against `.targets` work.
        self.targets = self.labels

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image. torchvision stores SVHN samples channel-first
        # (C, H, W), while PIL expects (H, W, C), hence the transpose.
        img = Image.fromarray(img.transpose(1, 2, 0))

        if self.no_aug:
            if self.test_transform is not None:
                img = self.test_transform(img)
        else:
            if self.transform is not None:
                img = self.transform(img)

        return img, target