from src.setup import default_gan_runner, evaluation_runner
from src.utils import set_default_arguments, get_loss_dict, print_preamble, make_experiment_paths, get_hyperparams, get_ds_and_model
from src.dataset import CountDSIterator

# Parse the command-line options (defaults are supplied by the project helper).
args = set_default_arguments()

# Hyperparameters derived from the parsed options; reused by the preamble,
# the experiment path, the loss configuration and the trainer.
hyperparams = get_hyperparams(args)

# Report the run configuration via the project's preamble helper before any
# heavy work starts.
print_preamble(
    args.exp_name,
    args.dataset,
    args.Nranks,
    args.batch_size,
    args.epochs,
    hyperparams,
)

# Output location for this run; "./experiments" is passed as the first
# argument (presumably the root directory -- confirm in make_experiment_paths).
exp_path = make_experiment_paths(
    "./experiments",
    args.exp_name,
    args.dataset,
    args.Nranks,
    args.exp_id,
    hyperparams,
)

# Resolve the main data iterator, the GAN model and the auxiliary-data
# iterator for the selected configuration.
DataIterator, GANModel, ExtraDataIterator = get_ds_and_model(args)

# Auxiliary-data spec handed to the trainer as a 3-item list:
# [iterator, extra data file, extra data directory].
extra_data = [
    ExtraDataIterator,
    args.extra_dataf,
    args.extra_data_dir,
]

# Loss terms and their weights, both derived from the hyperparameters.
loss_dict, loss_w_dict = get_loss_dict(hyperparams)

# ----------------------------------------------
# Step 1. Run Training
# ----------------------------------------------

# Keyword arguments for the training run, gathered in one mapping so the call
# itself stays short. Entries are listed (and evaluated) in the same order as
# the trainer's keyword arguments were originally passed.
training_kwargs = dict(
    trainf=args.trainf,
    data_dir=args.data_dir,
    DataIterator=DataIterator,
    batch_size=args.batch_size,
    nranks=args.Nranks,
    epochs=args.epochs,
    path=exp_path,
    GAN=GANModel,
    loss_dict=loss_dict,
    loss_w_dict=loss_w_dict,
    hyperparams=hyperparams,
    extra_data=extra_data,
)

# The trainer's return value is kept as `generator`; the evaluation step
# passes it to evaluation_runner.
generator = default_gan_runner(**training_kwargs)


"""----------------------------------------------
Step 2. Evaluate
-------------------------------------------------"""
evaluation_runner(evalf         =  args.evalf,
                  data_dir      =  args.data_dir,
                  DataIterator  =  CountDSIterator,
                  generator     =  generator,

