
import os, sys
import numpy as np

# Array kinds stored per split, in the order every helper below expects them.
# The class labels are deliberately last so callers can address them as [-1].
FIELDS = ('projections', 'data_ids', 'frames', 'aa_labels')


def per_type_counts(n_total, percentages, n_types=20):
    """Return the per-class sample count for each percentage of the full set.

    Each entry is ``int((n_total / n_types) * pct / 100)``: the full set is
    split evenly over ``n_types`` classes and ``pct`` percent of that share
    is kept (truncated towards zero).
    """
    return [int((n_total / n_types) * pct / 100) for pct in percentages]


def shuffle_in_unison(arrays, seed):
    """Return every array in ``arrays`` reordered by one shared permutation.

    The permutation is drawn from a fresh ``np.random.default_rng(seed)`` and
    its length is the first-axis length of ``arrays[0]``, so the result is
    reproducible and rows stay aligned across the arrays.
    """
    order = np.arange(arrays[0].shape[0])
    np.random.default_rng(seed).shuffle(order)
    return [arr[order] for arr in arrays]


def balanced_subset(arrays, labels, classes, n_per_type):
    """Pick the first ``n_per_type`` rows of each class from every array.

    Rows are selected with the mask ``labels == cls`` and the per-class
    pieces are joined along the first axis in the order given by ``classes``.
    A class with fewer than ``n_per_type`` rows contributes what it has and a
    warning is written to stderr, because the result is then smaller than
    ``n_per_type * len(classes)``.

    Returns a list with one array per entry of ``arrays``, in the same order.
    """
    pieces = [[] for _ in arrays]
    for cls in classes:
        # One mask per class, shared by all arrays so their rows stay aligned.
        mask = labels == cls
        available = int(np.count_nonzero(mask))
        if available < n_per_type:
            print('WARNING: class {!r} has only {} of the {} requested samples'
                  .format(cls, available, n_per_type), file=sys.stderr)
        for chunks, arr in zip(pieces, arrays):
            chunks.append(arr[mask][:n_per_type])
    # np.concatenate along axis 0 matches np.vstack for >=2-D inputs but, unlike
    # vstack, keeps 1-D inputs 1-D and tolerates classes of unequal size.
    return [np.concatenate(chunks, axis=0) for chunks in pieces]


def load_split(data_dir, split, data_id):
    """Load the ``FIELDS`` arrays of one split from ``<field>-<split>-<data_id>.npy``."""
    return [np.load(os.path.join(data_dir, '{}-{}-{}.npy'.format(field, split, data_id)))
            for field in FIELDS]


def save_split(data_dir, split, data_id, arrays):
    """Save ``arrays`` (ordered like ``FIELDS``) as ``<field>-<split>-<data_id>.npy``."""
    for field, arr in zip(FIELDS, arrays):
        np.save(os.path.join(data_dir, '{}-{}-{}.npy'.format(field, split, data_id)), arr)


if __name__ == '__main__':

    data_dir = ''
    # The single %d slot (spelled "n_neigh=%d") is filled with the number of
    # samples of the file being read or written.
    DATA_ID = 'complex_sph=False-get_H=False-get_SASA=False-get_charge=False-lmax=4-n_channels=4-n_neigh=%d-rcut=10.0-rmax=20-rst_normalization=square'
    N_train = 20000
    N_valid = 20000
    prop_list = [2, 5, 10, 25]  # subset sizes, in percent of the full split
    N_train_per_type_list = per_type_counts(N_train, prop_list)
    N_valid_per_type_list = per_type_counts(N_valid, prop_list)

    print(N_train_per_type_list)
    print(N_valid_per_type_list)

    SEED = 10000000

    # Shuffle each full split once so that "the first N rows of a class" below
    # is a seeded random sample rather than whatever order the file was in.
    train_arrays = shuffle_in_unison(load_split(data_dir, 'train', DATA_ID % N_train), SEED)
    valid_arrays = shuffle_in_unison(load_split(data_dir, 'val', DATA_ID % N_valid), SEED)

    # np.unique returns the distinct labels sorted. Iterating a set instead would
    # give an order that can change between runs for str/bytes labels (hash
    # randomization), making the saved row order irreproducible despite SEED.
    AAs = list(np.unique(train_arrays[-1]))
    print(AAs)
    print(len(AAs))

    for split, arrays, counts in (('train', train_arrays, N_train_per_type_list),
                                  ('val', valid_arrays, N_valid_per_type_list)):
        for N_per_type in counts:
            # Nominal size used in the output file name.
            N = N_per_type * len(AAs)

            subset = balanced_subset(arrays, arrays[-1], AAs, N_per_type)
            for arr in subset:
                print(arr.shape)

            # The subset is grouped by class after concatenation; mix it again.
            subset = shuffle_in_unison(subset, SEED)
            save_split(data_dir, split, DATA_ID % N, subset)
