"""
Train a model on tieredImageNet.
"""

import random
import tensorflow as tf
from MinibatchProx.args import argument_parser, model_kwargs, train_kwargs, evaluate_kwargs,data_kwargs
from MinibatchProx.eval import evaluate
from MinibatchProx.models import MiniImageNetModel
from MinibatchProx.miniimagenet import read_dataset
from MinibatchProx.train import train
import os
import pdb
from MinibatchProx.tieredimagenet import dataset_tiered

# Default to GPU 0, but respect a CUDA_VISIBLE_DEVICES value the user has
# already exported (the previous hard assignment silently overrode it).
# This runs at import time, before main() creates any tf.Session.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
#DATA_DIR = '../tiered-imagenet/'

"""  tieredImageNet
5-way 1-shot:
python -u run_tieredimagenet.py --lam_reg 10.0 --shots 1 --classes 5 --inner-batch 10 --inner-iters 8 --meta-step 1 --meta-batch 20 --meta-iters 100000 --eval-batch 15 --eval-iters 50 --learning-rate 0.001 --meta-step-final 0 --train-shots 15 --checkpoint ckpt_m51_tieredimagenet

5-way 5-shot:
python -u run_tieredimagenet.py --lam_reg 10.0 --shots 5 --classes 5 --inner-batch 10 --inner-iters 8 --meta-step 1 --meta-batch 20 --meta-iters 100000 --eval-batch 20 --eval-iters 50 --learning-rate 0.001 --meta-step-final 0 --train-shots 15 --checkpoint ckpt_m55_tieredimagenet

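Separate test script (presumably evaluates the checkpoint saved at iteration 99999; note it uses a different script and checkpoint name than the commands above and passes no --shots/--classes — verify before use):
python -u test_run_miniimagenet55-10MB3.py --model 99999 --lam_reg 10.0 --inner-batch 10 --inner-iters 8 --meta-step 1 --meta-batch 20 --meta-iters 100000 --eval-batch 20 --eval-iters 50 --learning-rate 0.001 --meta-step-final 0 --train-shots 15 --checkpoint ckpt_m55_10_MB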

10-way 1-shot:
python -u run_tieredimagenet.py --lam_reg 10.0 --shots 1 --classes 10 --inner-batch 10 --inner-iters 8 --meta-step 1 --meta-batch 20 --meta-iters 100000 --eval-batch 15 --eval-iters 50 --learning-rate 0.001 --meta-step-final 0 --train-shots 15 --checkpoint ckpt_m101_tieredimagenet


10-way 5-shot:
python -u run_tieredimagenet.py --lam_reg 10.0 --shots 5 --classes 10 --inner-batch 10 --inner-iters 16 --meta-step 1 --meta-batch 20 --meta-iters 100000 --eval-batch 15 --eval-iters 50 --learning-rate 0.001 --meta-step-final 0 --train-shots 15 --checkpoint ckpt_m105_tieredimagenet

"""


def main():
    """
    Load tieredImageNet, train (or restore) a model, and report test accuracy.

    Command-line arguments come from ``argument_parser()``. If
    ``args.pretrained`` is false, the model is trained via ``train`` (which is
    given ``args.checkpoint``). Otherwise, the latest checkpoint found in
    ``args.checkpoint`` is restored. In both cases, the model is then
    evaluated on the test split and the accuracy is printed.

    Raises:
        ValueError: if ``args.pretrained`` is set but no checkpoint can be
            found in ``args.checkpoint``.
    """
    args = argument_parser().parse_args()
    # Only Python's `random` module is seeded here.
    random.seed(args.seed)

    # Build the dataset kwargs once and reuse them for both splits
    # (`**` unpacking copies the dict, so sharing it is safe).
    dataset_kwargs = data_kwargs(args)
    train_set = dataset_tiered(split='train', **dataset_kwargs)
    train_set.load_data_pkl()

    test_set = dataset_tiered(split='test', **dataset_kwargs)
    test_set.load_data_pkl()

    model = MiniImageNetModel(args.classes, **model_kwargs(args))

    config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all at start-up.
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        if not args.pretrained:
            print('Training...')
            train(sess, model, train_set, test_set, args.checkpoint,
                  **train_kwargs(args))
        else:
            print('Restoring from checkpoint...')
            checkpoint_path = tf.train.latest_checkpoint(args.checkpoint)
            if checkpoint_path is None:
                # latest_checkpoint returns None when the directory holds no
                # checkpoint; fail with a clear message instead of letting
                # Saver.restore raise an opaque error on a None path.
                raise ValueError(
                    'No checkpoint found in %r' % (args.checkpoint,))
            tf.train.Saver().restore(sess, checkpoint_path)

        print('Evaluating...')
        eval_kwargs = evaluate_kwargs(args)
        # Only the train and test splits are loaded above; no validation
        # split exists in this script, so only test accuracy is reported.
        print('Test accuracy: ' + str(evaluate(sess, model, test_set, **eval_kwargs)))


if __name__ == '__main__':
    main()
