from __future__ import print_function
import argparse
from os.path import join

import pandas as pd

from skorch.callbacks import Checkpoint

from mnist_auto_aug.training_utils import fit_and_predict,\
    prepare_skorch_training


def _build_parser():
    """Return the command-line parser for this training script.

    Only the parser is constructed here; parsing happens in ``main``. The
    resulting namespace is passed as a whole to ``prepare_skorch_training``,
    so options not read directly in this module are presumably consumed
    there (not visible from this file — confirm in training_utils).
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=0.002, metavar='LR',
                        help='learning rate (default: 0.002)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval', type=int, default=10, metavar='N',
        help='how many batches to wait before logging training status'
    )
    parser.add_argument(
        'experiment_path',
        help='Path to folder where to save the results.'
    )
    parser.add_argument(
        '--use-all-classes', action='store_true',
        help='Disables class filtering (all 10 mnist classes will be used).'
    )
    return parser


def _save_final_results(experiment_path, test_perf, valid_perf):
    """Write final test/validation metrics to ``final_results.csv``.

    Parameters
    ----------
    experiment_path : str
        Folder the CSV is written into.
    test_perf, valid_perf : tuple
        ``(loss, accuracy)`` pairs, in the order ``fit_and_predict`` returns
        them.

    The CSV has a single row with columns ``test_loss``, ``test_acc``,
    ``valid_loss`` and ``valid_acc``, and no index column.
    """
    test_loss, test_acc = test_perf
    valid_loss, valid_acc = valid_perf
    # No explicit index: the index is not written to the CSV anyway.
    final_perf = pd.DataFrame(
        [[test_loss, test_acc, valid_loss, valid_acc]],
        columns=['test_loss', 'test_acc', 'valid_loss', 'valid_acc'],
    )
    final_perf.to_csv(
        join(experiment_path, "final_results.csv"),
        index=False,
    )


def main():
    """Train an MNIST model with skorch and record its final performance.

    Parses the command line and builds the model, datasets, parameters,
    callbacks and epoch count via ``prepare_skorch_training``. It adds a
    checkpoint callback that saves into ``experiment_path`` and trains and
    evaluates via ``fit_and_predict``. Finally it writes test and validation
    loss/accuracy to ``<experiment_path>/final_results.csv``.
    """
    args = _build_parser().parse_args()

    # The class selection returned by the helper is not needed in this
    # script; it is unpacked only to match the returned tuple.
    model, data, model_params, callbacks, epochs, _classes = (
        prepare_skorch_training(args)
    )
    train_set, valid_set, test_loader = data

    # skorch's Checkpoint saves whenever the history flag named by
    # ``monitor`` is True for an epoch — here, on each new best validation
    # accuracy. Assumes the net records a ``valid_acc`` score (skorch
    # classifiers do by default) — confirm in prepare_skorch_training.
    callbacks += [
        (
            'checkpoint', Checkpoint(
                monitor='valid_acc_best',
                dirname=args.experiment_path
            )
        )
    ]

    # With return_valid_perf=True the helper returns two (loss, accuracy)
    # pairs: test first, then validation.
    test_perf, valid_perf = fit_and_predict(
        model=model,
        train_set=train_set,
        valid_set=valid_set,
        test_loader=test_loader,
        epochs=epochs,
        model_params=model_params,
        callbacks=callbacks,
        return_valid_perf=True,
    )

    _save_final_results(args.experiment_path, test_perf, valid_perf)


if __name__ == '__main__':
    # Start training only when this file is executed directly, so importing
    # the module (e.g. to reuse its helpers) has no side effects.
    main()
