from tensorflow.python.keras.utils.data_utils import get_file
import os 
import numpy as np
# from google_drive_downloader import GoogleDriveDownloader as gdd

# Credit: https://www.microsoft.com/en-us/download/details.aspx?id=54765

# NOTE(review): not referenced anywhere in this module; presumably the image
# side length (pixels) expected by the consuming model — TODO confirm.
IMG_SIZE = 128

def load_data(path, train_fraction=0.9, seed=0, perm_ids_path=None):
    """Load the stroke arrays from ``path`` and split them into train/test sets.

    Reads ``stroke_80*4_X.npy`` (samples) and ``stroke_80*4_Y.npy`` (labels)
    from ``path``. It draws a seeded permutation of the sample indices and
    assigns the first ``floor(train_fraction * N)`` shuffled samples to the
    training split and the remainder to the test split. The permutation is
    saved to disk so the exact split can be reproduced later.

    Args:
        path: Directory containing the two ``.npy`` files.
        train_fraction: Fraction of samples placed in the training split
            (default 0.9, the previously hard-coded value).
        seed: Seed passed to ``np.random.seed`` before drawing the
            permutation (default 0, the previously hard-coded value).
        perm_ids_path: File the permutation is saved to. Defaults to
            ``stroke_80*4_perm_ids.npy`` inside ``path`` (previously this was
            a hard-coded absolute path that ignored ``path``).

    Returns:
        ``(train_data, train_label, test_data, test_label)``. The data arrays
        have their axes reordered by ``(0, 1, 4, 2, 3)``. The label arrays get
        an extra trailing axis of size 1 and are cast to ``int``.

    Raises:
        ValueError: if ``train_fraction`` is outside ``[0, 1]`` or the data
            and label files contain different numbers of samples.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(
            'train_fraction must be in [0, 1], got {}'.format(train_fraction))

    data = np.load(os.path.join(path, 'stroke_80*4_X.npy'))
    labels = np.expand_dims(
        np.load(os.path.join(path, 'stroke_80*4_Y.npy')), axis=1).astype(int)
    # Move the last axis to position 2: (a, b, c, d, e) -> (a, b, e, c, d).
    # NOTE(review): presumably (N, T, H, W, C) -> (N, T, C, H, W) — confirm.
    # (An alternative (0, 4, 2, 3, 1) layout was previously used for a 3D model.)
    data = np.transpose(data, (0, 1, 4, 2, 3))
    print(data.shape)

    n_samples = labels.shape[0]
    if data.shape[0] != n_samples:
        # Without this check the split would silently pair the wrong rows or
        # fail later with an opaque IndexError.
        raise ValueError(
            'data has {} samples but labels have {}'.format(
                data.shape[0], n_samples))
    n_train = int(np.floor(train_fraction * n_samples))

    # Seeds numpy's *global* RNG (as before) so the permutation is
    # reproducible; code running after this call may rely on that seeding.
    np.random.seed(seed)
    perm = np.random.permutation(n_samples)
    if perm_ids_path is None:
        perm_ids_path = os.path.join(path, 'stroke_80*4_perm_ids.npy')
    np.save(perm_ids_path, perm)

    train_ids, test_ids = perm[:n_train], perm[n_train:]
    test_data, test_label = data[test_ids], labels[test_ids]
    train_data, train_label = data[train_ids], labels[train_ids]

    print(train_data.shape)
    print(test_data.shape)
    # Per-column label mean for each split. Skip empty splits, which
    # previously raised ZeroDivisionError.
    for split_labels in (train_label, test_label):
        if len(split_labels):
            print(split_labels.mean(axis=0))

    return train_data, train_label, test_data, test_label


def STROKE(path='/home/dixzhu/data/'):
    """Load the stroke dataset splits and cast them to fixed dtypes.

    Args:
        path: Directory holding the ``stroke_80*4_*.npy`` files. Defaults to
            the previously hard-coded location, so existing ``STROKE()``
            calls behave as before.

    Returns:
        ``((train_X, train_Y), (test_X, test_Y))``. The ``X`` arrays are cast
        to ``float`` and the ``Y`` arrays to ``np.int32``.
    """
    train_X, train_Y, test_X, test_Y = load_data(path)

    # Convert data types.
    train_X, train_Y = train_X.astype(float), train_Y.astype(np.int32)
    test_X, test_Y = test_X.astype(float), test_Y.astype(np.int32)

    return (train_X, train_Y), (test_X, test_Y)


