import numpy as np
import h5py

def load_data_from_mat(path="data/mnist_fea_pca.mat"):
    """Load train/test features and labels from an HDF5-readable ``.mat`` file.

    The file is opened read-only with ``h5py`` and must contain the datasets
    ``trX``, ``trY``, ``teX`` and ``teY``.

    Args:
        path: Location of the file to read. The default file name suggests
            PCA-reduced MNIST features -- TODO confirm against the data.

    Returns:
        A tuple ``(trX, trY, teX, teY)`` of NumPy arrays where

        * ``trX`` / ``teX`` are the stored feature arrays, transposed;
        * ``trY`` / ``teY`` are the first row of the stored label arrays,
          cast to ``int``.
    """
    # The context manager guarantees the file handle is released, even when
    # one of the reads below raises.
    with h5py.File(path, 'r') as f:
        # np.array() copies each dataset into memory, so the results remain
        # usable once the file has been closed.
        # NOTE(review): the transpose presumably undoes the axis reversal of
        # MATLAB (column-major) data as seen through HDF5 -- confirm.
        trX = np.array(f.get('trX')).T
        trY = np.array(f.get('trY'))[0]
        teX = np.array(f.get('teX')).T
        teY = np.array(f.get('teY'))[0]

    return trX, trY.astype(int), teX, teY.astype(int)

def load_celeb_from_mat(path="data/celeb_fea.mat"):
    """Load features and two label vectors from an HDF5-readable ``.mat`` file.

    The file is opened read-only with ``h5py`` and must contain the datasets
    ``data``, ``labels`` and ``labels_gt``.

    Args:
        path: Location of the file to read.

    Returns:
        A tuple ``(data, labels, labels_gt)`` of NumPy arrays where

        * ``data`` is the stored feature array, transposed;
        * ``labels`` and ``labels_gt`` are the first row of the corresponding
          stored arrays, cast to ``int``. ``labels_gt`` is presumably the
          ground-truth labelling -- verify against the data source.
    """
    # The context manager guarantees the file handle is released, even when
    # one of the reads below raises.
    with h5py.File(path, 'r') as f:
        # np.array() copies each dataset into memory, so the results remain
        # usable once the file has been closed.
        data = np.array(f.get('data')).T
        labels = np.array(f.get('labels'))[0].astype(int)
        labels_gt = np.array(f.get('labels_gt'))[0].astype(int)

    return data, labels, labels_gt

def load_cifar10_from_mat(path):
    """Load train/test features and labels from an HDF5 file.

    The file is opened read-only with ``h5py``. Each of the entries
    ``trainXCp``, ``testXCp``, ``trainY`` and ``testY`` must itself contain a
    dataset named ``value`` holding the actual array.

    Args:
        path: Location of the file to read.

    Returns:
        A tuple ``(trX, trY, teX, teY)`` of NumPy arrays where

        * ``trX`` / ``teX`` are the stored feature arrays, transposed and
          converted to ``float64``;
        * ``trY`` / ``teY`` are the first row of the stored label arrays,
          cast to ``int``.
    """
    # The context manager guarantees the file handle is released, even when
    # one of the reads below raises.
    with h5py.File(path, 'r') as f:
        # Slicing with [:] reads the whole dataset into an in-memory array,
        # so the values remain usable once the file has been closed.
        trX = f.get('trainXCp')['value'][:]
        teX = f.get('testXCp')['value'][:]
        trY = f.get('trainY')['value'][:]
        teY = f.get('testY')['value'][:]

    # astype() returns a new array, so no extra defensive copy is needed
    # before transposing / indexing.
    trX, teX = trX.T.astype(np.float64), teX.T.astype(np.float64)
    trY, teY = trY[0].astype(int), teY[0].astype(int)
    return trX, trY, teX, teY