# --- data_loader.py ---
import os

import numpy as np

# Reference-solution file for each supported dataset, relative to z_star_dir.
# a9a/w8a are stored as .npz archives; ijcnn1/cod-rna carry a .npy extension.
_Z_STAR_FILES = {
    'a9a': 'z_star_a9a.npz',
    'w8a': 'z_star_w8a.npz',
    'ijcnn1': 'z_star_ijcnn1.npy',
    'cod-rna': 'z_star_cod-rna.npy',
}


def _load_z_star(path: str) -> np.ndarray:
    """Load the ``z_star`` array from either a ``.npz`` archive or a ``.npy`` file.

    A ``.npz`` archive must contain an entry named ``'z_star'``. A plain
    ``.npy`` file is taken to be z_star itself, unless it is a structured
    array with a ``'z_star'`` field, in which case that field is returned.
    """
    loaded = np.load(path)
    if isinstance(loaded, np.ndarray):
        # Plain .npy: indexing a regular ndarray with a string would raise,
        # so only pick a field when the array is structured and has one.
        names = loaded.dtype.names
        if names is not None and 'z_star' in names:
            return loaded['z_star']
        return loaded
    # .npz archive (NpzFile): read the entry, then close the underlying file.
    with loaded:
        return loaded['z_star']


def load_data(Data_type: str, m: int, data_dir: str = './data',
              z_star_dir: str = ''):
    """Load a training set, split it into ``m`` equal shards, and load z_star.

    Args:
        Data_type: One of ``'a9a'``, ``'w8a'``, ``'ijcnn1'``, ``'cod-rna'``.
        m: Number of shards (must be >= 1). Trailing rows that do not fill a
            complete shard (``n % m`` of them) are dropped.
        data_dir: Directory holding ``<Data_type>_train_features.npy`` and
            ``<Data_type>_train_labels.npy``. Defaults to ``'./data'``.
        z_star_dir: Directory holding the z_star file. Defaults to ``''``
            (the current working directory), matching the original paths.

    Returns:
        ``(a, b, d, n, z_star)`` where ``a`` has shape ``(m, n // m, d)``,
        ``b`` has shape ``(m, n // m, 1)``, ``d`` is the feature count, ``n``
        is the row count *before* truncation, and ``z_star`` is the array
        loaded from the dataset's z_star file.

    Raises:
        ValueError: If ``Data_type`` is not supported or ``m < 1``.
    """
    if Data_type not in _Z_STAR_FILES:
        raise ValueError(f'Unknown Data_type: {Data_type}')
    if m < 1:
        raise ValueError(f'm must be a positive integer, got {m}')

    # load features & labels
    X = np.load(os.path.join(data_dir, f'{Data_type}_train_features.npy'))
    y = np.load(os.path.join(data_dir, f'{Data_type}_train_labels.npy'))

    n, d = X.shape
    # Integer division avoids float rounding of n/m for large n.
    per_shard = n // m
    kept = per_shard * m

    # Truncate to a multiple of m, then reshape into m shards.
    a = X[:kept, :].reshape(m, per_shard, d)
    b = y[:kept].reshape(m, per_shard, 1)

    z_star = _load_z_star(os.path.join(z_star_dir, _Z_STAR_FILES[Data_type]))

    # NOTE(review): n is the pre-truncation row count, as in the original;
    # callers normalizing by sample count may want m * (n // m) instead.
    return a, b, d, n, z_star
