from xmeta.utils.seed import set_seed
import learn2learn as l2l
from xmeta.utils.sift import SiftFeature
from argparse import ArgumentParser
import os
import datetime
import sys


def main(
        ways=5,
        shots=5,
        seed=42,
        k=10,
        num_tasks=None,
        num_data=None
):
    """Compute SIFT features for mini-ImageNet and cache them on disk.

    Each call creates a timestamped directory under ``./cache``. It writes
    the launching command line into that directory, then runs
    ``SiftFeature`` with ``cache_dir`` pointing at it.

    Args:
        ways: Classes per task; only used when ``num_tasks`` is given.
        shots: Each task draws ``2 * shots`` samples per class; only used
            when ``num_tasks`` is given.
        seed: Passed to ``set_seed`` before building the task set (if any)
            and again right before feature extraction.
        k: Forwarded unchanged to ``SiftFeature(k=...)``.
        num_tasks: If not None, build the mini-ImageNet task sets with this
            many tasks, use the ``.train`` split as data, and use
            ``num_tasks`` as ``n_sample``.
        num_data: Used as ``n_sample`` when ``num_tasks`` is None. In that
            case the raw mini-ImageNet training dataset is the data source.
    """
    # One directory per run (second resolution), so separate runs do not
    # share a cache directory.
    stamp = datetime.datetime.now().strftime('%Y-%m%d-%H%M%S')
    save_dir = os.path.join('./cache', stamp)
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(save_dir, exist_ok=True)

    # Record the exact invocation next to the cached features.
    args_file_name = 'args_' + os.path.basename(__file__) + '.txt'
    args_file_path = os.path.join(save_dir, args_file_name)
    with open(args_file_path, mode='w', encoding='utf-8') as f:
        f.write(" ".join(sys.argv))

    if num_tasks is not None:
        # Seed before task construction so the sampled tasks are reproducible.
        set_seed(seed)
        data = l2l.vision.benchmarks.get_tasksets('mini-imagenet',
                                                  train_ways=ways,
                                                  train_samples=2*shots,
                                                  test_ways=ways,
                                                  test_samples=2*shots,
                                                  num_tasks=num_tasks,
                                                  root='~/data',
                                                  ).train
        n_sample = num_tasks
    else:
        data = l2l.vision.datasets.MiniImagenet(root='~/data', mode='train',
                                                download=True)
        n_sample = num_data

    # Re-seed so feature extraction does not depend on how much randomness
    # the branch above consumed.
    set_seed(seed)
    SiftFeature(data, k=k, name='mifeature', use_cache=True,
                n_sample=n_sample, cache_dir=save_dir)

if __name__ == '__main__':
    parser = ArgumentParser(
        description='Compute and cache SIFT features for mini-ImageNet.')
    parser.add_argument('--ways', type=int, default=5,
                        help='classes per task (used with --num-tasks)')
    parser.add_argument('--shots', type=int, default=5,
                        help='shots per class (used with --num-tasks)')
    # main() already accepts a seed; expose it so runs can be re-seeded
    # without editing the script. The default matches main()'s default.
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed')
    # NOTE: the CLI default (16) intentionally differs from main()'s
    # default k=10; both are kept as-is to preserve existing behavior.
    parser.add_argument('--k', type=int, default=16,
                        help='forwarded to SiftFeature as k')
    parser.add_argument('--num-tasks', type=int, default=None,
                        help='build this many few-shot tasks instead of '
                             'using the raw training set')
    parser.add_argument('--num-data', type=int, default=None,
                        help='n_sample for the raw training set '
                             '(ignored when --num-tasks is given)')
    args = parser.parse_args()

    main(shots=args.shots,
         ways=args.ways,
         seed=args.seed,
         k=args.k,
         num_tasks=args.num_tasks,
         num_data=args.num_data
         )

