
import os, sys
import numpy as np
import argparse

sys.path.append('../..')
from utils.argparse_utils import *

# Fixed order in which the splits are carved out of each residue type's sample;
# it also determines the order of RNG calls, and therefore reproducibility per seed.
SPLIT_ORDER = ('train', 'val', 'test')

# Array families loaded from / written to proj_dir, in save order.
DATA_ITEMS = ('projections', 'data_ids', 'frames', 'aa_labels')


def parse_split_percentages(splits_str):
    """Parse a comma-separated 'train,val,test' percentage string.

    Args:
        splits_str: e.g. '40,40,20'.

    Returns:
        dict mapping each name in SPLIT_ORDER to its integer percentage.

    Raises:
        ValueError: if a value is not an integer, if there are not exactly
            len(SPLIT_ORDER) values, if any value is negative, or if the
            values do not sum to 100.
    """
    percentages = [int(part) for part in splits_str.split(',')]
    if len(percentages) != len(SPLIT_ORDER):
        raise ValueError('--splits must have exactly {} comma-separated values ({}), got {!r}'.format(
            len(SPLIT_ORDER), ','.join(SPLIT_ORDER), splits_str))
    if any(p < 0 for p in percentages):
        raise ValueError('--splits values must be non-negative, got {!r}'.format(splits_str))
    if sum(percentages) != 100:
        raise ValueError('--splits values must sum to 100, got {!r}'.format(splits_str))
    return dict(zip(SPLIT_ORDER, percentages))


def stratified_split_indices(residue_types_N, n_per_type, props, rng):
    """Draw an equal-size sample per residue type and divide it among the splits.

    For every distinct value in ``residue_types_N`` (in order of first
    appearance), ``n_per_type`` indices are drawn without replacement. Then
    consecutive chunks of ``n_per_type * props[split] // 100`` are assigned
    to each split in SPLIT_ORDER. Because of the integer division, any
    leftover indices are dropped. Finally each split's indices are
    concatenated and shuffled in place.

    Args:
        residue_types_N: 1-D array with one residue-type label per data point.
        n_per_type: number of samples to draw for each residue type.
        props: dict from split name to integer percentage (see
            parse_split_percentages).
        rng: numpy Generator; it is consumed in the same order as the
            original script, so results are reproducible per seed.

    Returns:
        dict mapping each split name to a 1-D integer index array. The splits
        are pairwise disjoint.

    Raises:
        ValueError: if some residue type has fewer than ``n_per_type`` points.
    """
    all_idxs = np.arange(residue_types_N.shape[0])
    # dict.fromkeys keeps first-appearance order (same as the original Counter
    # keys), so the sequence of rng.choice calls is unchanged.
    residue_types = list(dict.fromkeys(residue_types_N))

    parts = {split: [] for split in SPLIT_ORDER}
    for residue_type in residue_types:
        candidates = all_idxs[residue_types_N == residue_type]
        if candidates.shape[0] < n_per_type:
            raise ValueError('residue type {!r} has only {} data points, but {} per type were requested'.format(
                residue_type, candidates.shape[0], n_per_type))
        idxs_of_type = rng.choice(candidates, n_per_type, replace=False)

        # Carve consecutive, non-overlapping chunks so splits stay disjoint.
        start_i = 0
        for split in SPLIT_ORDER:
            delta_i = n_per_type * props[split] // 100
            parts[split].append(idxs_of_type[start_i:start_i + delta_i])
            start_i += delta_i

    splits_idxs = {}
    for split in SPLIT_ORDER:
        # np.hstack rejects an empty list (no residue types at all).
        merged = np.hstack(parts[split]) if parts[split] else np.empty(0, dtype=all_idxs.dtype)
        rng.shuffle(merged)
        splits_idxs[split] = merged
    return splits_idxs


def load_projection_arrays(proj_dir, tag):
    """Load '<item>-all-<tag>.npy' from proj_dir for every item in DATA_ITEMS.

    Returns:
        dict from item name to the loaded array, in DATA_ITEMS order.
    """
    return {
        item: np.load(os.path.join(proj_dir, '{}-all-{}.npy'.format(item, tag)))
        for item in DATA_ITEMS
    }


if __name__ == '__main__':
    from collections import Counter

    parser = argparse.ArgumentParser()
    # Required: every input and output path is built from it.
    parser.add_argument('--proj_dir', type=str, required=True)
    # Contains a single '%d' placeholder that is filled with a dataset size.
    parser.add_argument('--proj_id', type=str, default='complex_sph=False-get_H=False-get_SASA=False-get_charge=False-lmax=4-n_channels=4-n_neigh=%d-rcut=10.0-rmax=20-rst_normalization=square')
    parser.add_argument('--N_name', type=int, default=80000)
    parser.add_argument('--N_per_type_desired', type=int, default=2500)
    parser.add_argument('--splits', type=str, default='40,40,20')
    parser.add_argument('--seed', type=int, default=100000000)

    args = parser.parse_args()

    # Validate with parser.error instead of assert so the check survives -O.
    try:
        props = parse_split_percentages(args.splits)
    except ValueError as err:
        parser.error(str(err))

    rng = np.random.default_rng(args.seed)

    data_arrays = load_projection_arrays(args.proj_dir, args.proj_id % args.N_name)

    # Need to make the division equal across residue types.
    # Presumably element 0 of each data id encodes the residue type -- TODO
    # confirm against whatever writes data_ids-all-*.npy.
    residue_types_N = np.array([data_id[0] for data_id in data_arrays['data_ids']])

    print(data_arrays['projections'].shape)
    print(Counter(residue_types_N))

    splits_idxs = stratified_split_indices(residue_types_N, args.N_per_type_desired, props, rng)

    for split in SPLIT_ORDER:
        print(split)
        print(splits_idxs[split])
        print(splits_idxs[split].shape)
        print()

    # Save each subset as soon as it is sliced, rather than holding every
    # split's copy in memory at once. The filename embeds the split size.
    for split in SPLIT_ORDER:
        idxs = splits_idxs[split]
        size_tag = args.proj_id % idxs.shape[0]
        for data_item, array in data_arrays.items():
            np.save(os.path.join(args.proj_dir, '{}-{}-{}.npy'.format(data_item, split, size_tag)), array[idxs])

    
