from torch.utils.data import Subset
from torch.utils.data import ConcatDataset
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from base.torchvision_dataset import TorchvisionDataset
from .preprocessing import get_target_label_idx, global_contrast_normalization
from skimage import io, color
import numpy as np
import torchvision.transforms as transforms
import os
import torch
import pandas as pd
from torch.utils import data
from scipy import io


class BANK_Dataset(TorchvisionDataset):
    """Expose the bank dataset through the project's ``TorchvisionDataset`` interface.

    Builds a training split and a test split (both ``bank`` instances) from the
    root directory and records how many samples the training split holds.
    """

    def __init__(self, root: str):
        super().__init__(root)
        # NOTE(review): relies on the base class having set ``self.root`` — confirm.
        self.train_set = bank(root=self.root, train=True)
        self.test_set = bank(root=self.root, train=False)
        self.train_size = len(self.train_set)

class bank(data.Dataset):
    """Tabular bank dataset whose samples carry a ``marital`` group attribute.

    Reads two files under ``<root>/bank/``:

    * ``bank_data.npz`` containing the arrays ``data`` (features) and ``labels``;
    * ``bank_attr.p``, a pickled dict whose ``'marital'`` entry maps each
      attribute value to the indices of the samples that have that value.

    The train and test splits are the same data: the ``test_*`` attributes
    alias the ``train_*`` tensors, and ``train`` only selects which view is
    used. When CUDA is available, all tensors are moved to the GPU once at
    construction time.
    """

    def __init__(self, root, train=True):
        """Load features, labels and per-sample marital attributes.

        Args:
            root: Dataset root directory; a leading ``~`` is expanded.
            train: Use the training view (``True``) or the test view.
        """
        import pickle

        self.root = os.path.expanduser(root)
        print("self.root:", self.root)
        self.training_file = os.path.join(self.root, "bank", "bank_data.npz")
        self.attr_file = os.path.join(self.root, "bank", "bank_attr.p")
        self.train = train
        self.use_cuda = torch.cuda.is_available()

        # Read the arrays and close the .npz archive straight away.
        with np.load(self.training_file) as bank_train:
            features = bank_train["data"]
            labels = bank_train["labels"]

        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(self.attr_file, "rb") as attr_fh:
            self.attrs = pickle.load(attr_fh)

        marital = self.attrs["marital"]
        total_l = 0
        for key, indices in marital.items():
            print("key:", key, "length:", len(indices))
            total_l += len(indices)

        # Assumes the index groups partition range(total_l). A sample that is
        # missing from every group keeps the value 0; an index >= total_l raises.
        temp_attr = np.zeros((total_l,))
        for key, indices in marital.items():
            temp_attr[indices] = key

        # Use torch tensors on both CPU and GPU so the attribute type does not
        # depend on the hardware.
        attrs = torch.from_numpy(temp_attr)
        features_t = torch.tensor(features, dtype=torch.float32)
        labels_t = torch.tensor(labels, dtype=torch.int)
        print(features_t.shape, labels_t.shape)

        if self.use_cuda:
            # Move once. The test view aliases these device tensors instead of
            # holding a second, identical GPU copy.
            features_t = features_t.cuda()
            labels_t = labels_t.cuda()
            attrs = attrs.cuda()

        self.train_data, self.train_labels, self.train_attrs = features_t, labels_t, attrs
        self.test_data, self.test_labels, self.test_attrs = features_t, labels_t, attrs

    def __getitem__(self, index):
        """Return ``(features, label, index, marital_attr)`` for one sample of the active view."""
        if self.train:
            features, attr, target = self.train_data[index], self.train_attrs[index], self.train_labels[index]
        else:
            features, attr, target = self.test_data[index], self.test_attrs[index], self.test_labels[index]
        return features, target, index, attr

    def __len__(self):
        """Return the number of samples in the active (train or test) view."""
        return len(self.train_data if self.train else self.test_data)





