import os
import argparse
import numpy as np
from ray import tune
from train_image_classification import ImageClassification
import pickle
import random
import torch


def main(args):
    """Run ``ImageClassification`` on the best configuration(s) of a tuning run.

    The experiment directory ``<save_root>/<exp_name>`` is identified by the
    optimizer, tuning algorithm, number of configurations and seed. Its
    ``best_trial.pkl`` is loaded and grid-searched with Ray Tune under the run
    name ``test_<exp_name>``. The per-trial result dataframes are then pickled
    to ``<save_root>/test_<exp_name>/log.pkl``.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``optimizer``,
            ``tune_algo``, ``num_config``, ``save_root``, ``smoke_test``,
            ``max_t`` and ``use_gpu``. The whole namespace is also forwarded
            to the trainable as ``config["args"]``. If ``args.seed`` is
            ``None`` it is filled in (mutated) with a random integer in
            ``[0, 1000]``.

    Raises:
        FileNotFoundError: If ``best_trial.pkl`` does not exist for the
            experiment.
    """
    if args.seed is None:
        # randint needs integer bounds: a float bound such as 1e3 is
        # deprecated since Python 3.10 and raises TypeError on 3.12+.
        args.seed = random.randint(0, 1000)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    exp_name = (
        f"{args.optimizer}_{args.tune_algo}_{args.num_config}"
        f"_seed-{args.seed}"
    )
    experiment_dir = os.path.join(args.save_root, exp_name)

    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # result files produced by your own tuning runs.
    with open(os.path.join(experiment_dir, "best_trial.pkl"), "rb") as f:
        best_trial = pickle.load(f)

    # tune.grid_search expects a list of values, one trial per entry --
    # presumably best_trial is such a list; verify against the tuning script.
    test_name = f"test_{exp_name}"
    config = {"args": args, "best": tune.grid_search(best_trial)}
    analysis = tune.run(
        ImageClassification,
        name=test_name,
        local_dir=args.save_root,
        stop={"training_iteration": 3 if args.smoke_test else args.max_t},
        resources_per_trial={"cpu": 3, "gpu": int(args.use_gpu)},
        num_samples=1,
        checkpoint_at_end=True,
        checkpoint_freq=3,
        config=config,
    )

    test_dir = os.path.join(args.save_root, test_name)
    # Ensure the output directory exists so a long run is not lost to a
    # FileNotFoundError if Tune placed its own results somewhere else.
    os.makedirs(test_dir, exist_ok=True)
    with open(os.path.join(test_dir, "log.pkl"), "wb") as f:
        pickle.dump(analysis.trial_dataframes, f)


if __name__ == "__main__":
    # Command-line entry point. Arguments are registered in the same order as
    # before, which also fixes the order they appear in --help.
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Filesystem locations.
    add("--data_root", type=str, default="~/data", help="Root of data folder.")
    add("--save_root", type=str, default="./checkpoint", help="Root of checkpoint.")

    # Experiment identity: optimizer, tune_algo, num_config and seed are
    # combined by main() into the experiment directory name.
    add(
        "--tune_algo",
        type=str,
        required=True,
        choices=["HB", "BOHB"],
        help="Name of tuning algorithm.",
    )
    add(
        "--optimizer",
        type=str,
        required=True,
        choices=["SGD", "Adam"],
        help="Name of optimizer.",
    )
    add("--seed", type=int, help="Random seed")

    # Training settings. --mode, --batch_size and --dataset are not read in
    # this file; presumably the trainable consumes them via config["args"].
    add("--mode", type=str, default="train")
    add("--batch_size", type=int, default=128)
    add("--max_t", type=int, default=50, help="maximum epoch per trial")
    add("--num_config", type=int, default=62, help="number of configurations")
    add("--dataset", type=str, default="mnist")

    # Runtime switches.
    add("--use-gpu", action="store_true", default=False, help="enables CUDA training")
    add("--smoke-test", action="store_true", help="Finish quickly for testing")

    args = cli.parse_args()
    main(args)
