import time
import argparse
from pathlib import Path
from optuna.samplers import RandomSampler

from mnist_auto_aug.training_utils import prepare_skorch_training

from mnist_auto_aug.auto_augment import search_subpolicies
from mnist_auto_aug.auto_augment import make_train_and_see_objective_skorch



def _build_parser():
    """Build the command-line parser for the augmentation subpolicy search.

    Returns:
        argparse.ArgumentParser: parser holding the training settings consumed
        by ``prepare_skorch_training`` and the search settings used by ``main``.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=0.002, metavar='LR',
                        help='learning rate (default: 0.002)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    # The two literals are implicitly concatenated, so the trailing space
    # on the first one is required to keep the words apart.
    parser.add_argument('--idx', type=int, default=0,
                        help='Index of this job, useful to launch multiple '
                        'searchers in parallel.')
    parser.add_argument(
        '--log-interval', type=int, default=75, metavar='N',
        help='how many batches to wait before logging training status'
    )
    parser.add_argument('-t', '--n_trials', type=int, default=200,
                        help='Number of trials in the policy search.')
    parser.add_argument(
        'experiment_path',
        help='Path to folder where to save the search results.'
    )
    # Not read in this file; presumably consumed by prepare_skorch_training
    # (which receives the whole args namespace) — verify there.
    parser.add_argument(
        '--use-all-classes', action='store_true',
        help='Disables class filtering (all 10 mnist classes will be used).'
    )
    # store_false: args.classwise defaults to True and this flag clears it.
    parser.add_argument(
        '--no-classwise', action='store_false', dest="classwise",
        help='Disables classwise sampling and uses regular transforms.'
    )
    parser.add_argument(
        '--baseline', action='store_true',
        help='Disables sampling and uses Identity everywhere.'
    )
    parser.add_argument('--n_jobs', type=int, default=1,
                        help='Number of parallel search jobs.')
    return parser


def _resolve_experiment_path(args):
    """Return the results sub-folder matching the selected search mode.

    ``--baseline`` takes precedence; otherwise results go to ``classwise``
    (the default) or ``single`` when ``--no-classwise`` is given.

    Args:
        args: parsed namespace providing ``experiment_path``, ``baseline``
            and ``classwise``.

    Returns:
        pathlib.Path: ``<experiment_path>/<mode>``.
    """
    root = Path(args.experiment_path)
    if args.baseline:
        return root / 'baseline'
    if args.classwise:
        return root / 'classwise'
    return root / 'single'


def main():
    """Run a random-sampler augmentation subpolicy search from the command line.

    Parses the CLI arguments, builds the model and datasets with
    ``prepare_skorch_training`` and hands everything to
    ``search_subpolicies`` (random sampling, pruning enabled). Prints the
    total elapsed time at the end.
    """
    args = _build_parser().parse_args()
    experiment_path = _resolve_experiment_path(args)

    # perf_counter is monotonic, so the reported duration is not skewed by
    # system clock changes during a long-running search.
    t_start = time.perf_counter()
    (
        model, data, model_params, callbacks, epochs, classes
    ) = prepare_skorch_training(args)

    train_set, valid_set, test_loader = data
    obj_params = {
        'train_set': train_set,
        'valid_set': valid_set,
        'test_loader': test_loader,
        'model_params': model_params,
        'callbacks': callbacks,
        'classes': classes,
        'classwise': args.classwise,
        'baseline': args.baseline,
        'epochs': epochs,
    }

    # A dry run caps the search at 3 trials as a quick smoke check.
    n_trials = 3 if args.dry_run else args.n_trials

    search_subpolicies(
        experiment_path=experiment_path,
        objective_factory=make_train_and_see_objective_skorch,
        objective_params=obj_params,
        model=model,
        sampler_class=RandomSampler,
        seed=args.seed,
        n_trials=n_trials,
        pruning=True,
        n_jobs=args.n_jobs,
        job_idx=args.idx,
    )

    print(f"Script ran in {time.perf_counter() - t_start:.1f}s")


if __name__ == "__main__":
    # Launch the search only when executed as a script, not when imported.
    main()
