from torch.utils.data import Subset
from torch.utils.data import ConcatDataset
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from base.torchvision_dataset import TorchvisionDataset
from .preprocessing import get_target_label_idx, global_contrast_normalization
from skimage import io, color
import numpy as np
import torchvision.transforms as transforms
import os
import torch
import pandas as pd
from torch.utils import data
from scipy import io


class HAR_Dataset(TorchvisionDataset):
    """Bundle the HAR training and test splits behind the TorchvisionDataset API.

    Attributes set here:
        train_set: ``har`` instance exposing the training split.
        test_set: ``har`` instance exposing the test split.
        train_size: number of samples in ``train_set``.
    """

    def __init__(self, root: str):
        super().__init__(root)
        # Both splits are built from the same root; each `har` construction
        # reads the .mat file on its own.
        self.train_set = har(root=self.root, train=True)
        self.test_set = har(root=self.root, train=False)
        self.train_size = len(self.train_set)

class har(data.Dataset):
    """One split of the HAR dataset, loaded from ``<root>/har/raw/HAR.mat``.

    The .mat file is read with ``scipy.io.loadmat`` and must provide the keys
    ``X_train``/``X_test`` (samples), ``y_train``/``y_test`` (labels) and
    ``subject_train``/``subject_test`` (per-sample attributes). Every array is
    also saved under ``<root>/har/processed/``. Only the split selected by
    ``train`` is kept on the instance, always as torch tensors (float32
    samples, integer labels and attributes).

    Args:
        root: Dataset root directory; ``~`` is expanded.
        train: If True expose the training split, otherwise the test split.
    """

    # Folder names used to build the raw/processed paths below. The *_file
    # names are not referenced within this class.
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    @property
    def get_attr(self):
        """Return the attribute tensor of the active split (not shifted by -1)."""
        if self.train:
            return self.train_attrs
        else:
            return self.test_attrs

    def __init__(self, root, train=True):
        self.root = os.path.expanduser(root)
        self.train = train  # training set or test set
        # Kept for backward compatibility only: tensor conversion below is
        # unconditional, so CPU-only machines get the same types as GPU ones.
        self.use_cuda = torch.cuda.is_available()

        mat = io.loadmat(os.path.join(self.root, 'har', self.raw_folder, 'HAR.mat'))
        train_data = np.array(mat['X_train'])
        test_data = np.array(mat['X_test'])
        train_y = np.squeeze(np.array(mat['y_train']))
        test_y = np.squeeze(np.array(mat['y_test']))
        # Attributes come from the subject arrays; they are not squeezed, so
        # each item keeps whatever trailing dimension loadmat produced.
        train_attrs = np.array(mat['subject_train'])
        test_attrs = np.array(mat['subject_test'])

        # Persist every array under processed/ with the same file names as
        # before, creating the directory first (it may not exist yet).
        processed_dir = os.path.join(self.root, 'har', self.processed_folder)
        os.makedirs(processed_dir, exist_ok=True)
        for name, arr in (
            ('train_data', train_data),
            ('test_data', test_data),
            ('train_attrs', train_attrs),
            ('test_attrs', test_attrs),
            ('train_y', train_y),
            ('test_y', test_y),
        ):
            with open(os.path.join(processed_dir, name), 'wb') as f:
                torch.save(arr, f)

        # Use the in-memory arrays instead of re-reading the files just written:
        # torch.load rejects pickled numpy arrays under weights_only=True, which
        # is the default since PyTorch 2.6.
        if self.train:
            self.train_data, self.train_attrs, self.train_y = self._to_tensors(
                train_data, train_attrs, train_y)
        else:
            self.test_data, self.test_attrs, self.test_y = self._to_tensors(
                test_data, test_attrs, test_y)

    @staticmethod
    def _to_tensors(samples, attrs, labels):
        """Convert one split to (float32 samples, int attrs, int labels) tensors."""
        return (
            torch.from_numpy(samples.astype(float)).float(),
            torch.from_numpy(attrs.astype(int)),
            torch.from_numpy(labels.astype(int)),
        )

    def __getitem__(self, index):
        """Return one item of the active split.

        Args:
            index (int): Position of the item in the split.

        Returns:
            tuple: ``(sample, label, index, attr - 1)``. The attribute is
            shifted down by one, presumably to map 1-based subject ids to
            0-based ones (TODO confirm against the .mat contents).
        """
        if self.train:
            sample, attr, y = self.train_data[index], self.train_attrs[index], self.train_y[index]
        else:
            sample, attr, y = self.test_data[index], self.test_attrs[index], self.test_y[index]

        return sample, y, index, attr - 1

    def __len__(self):
        """Return the number of items in the active split."""
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)


