import os
import json
import pandas as pd
import numpy as np


if __name__ == '__main__':
    # Build the mini-ImageNet meta-train/val/test class-directory lists (using
    # the Ravi & Larochelle class split) and save each list as a .npy file.

    # Load the config file; this script reads its 'data_dir' and
    # 'split_folder' entries. A context manager ensures the handle is closed.
    config_name = 'oqc_bomla_lam1.json'
    with open(os.path.join('./config/la_seqdataset', config_name),
              encoding='utf-8') as jsonfile:
        config = json.load(jsonfile)

    dest_dir = os.path.join(config['data_dir'], 'mini_imagenet')
    split_dir = os.path.join(config['data_dir'], config['split_folder'], 'mini_imagenet')

    # exist_ok=True so re-running the script does not fail with FileExistsError.
    os.makedirs(split_dir, exist_ok=True)

    # Use split in Ravi & Larochelle: the unique values of each CSV's 'label'
    # column name the classes in that split; each becomes a path under dest_dir.
    # All CSVs are read before anything is saved, so a missing CSV leaves no
    # partially written output.
    class_dirs_by_split = {}
    for split in ('train', 'val', 'test'):
        labels = pd.read_csv(f'../data/mini_imagenet_split/{split}.csv')['label'].unique()
        class_dirs_by_split[split] = [os.path.join(dest_dir, clsname) for clsname in labels]

    # Writes metatrain.npy, metaval.npy and metatest.npy into split_dir.
    for split, class_dirs in class_dirs_by_split.items():
        np.save(os.path.join(split_dir, f'meta{split}.npy'), class_dirs)
