import os

# Retraining from scratch tests.
# For every (epoch_count, experiment) pair, locate the reference DensEMANN
# network saved for that experiment and retrain it n_runs times by invoking
# run_DensEMANN.py through the shell.

# Experiment name (sub-directory of ref_networks_dir holding the reference
# network) -> suffix appended to "DensEMANN_retrain_" in the save directory.
experiments_dict = {"x3_+_cutout": "_x3_cutout", "x3": "_x3",
                    "x1_+_cutout": "_cutout", "x1": ""}
dataset = "C10+"  # Passed to run_DensEMANN.py via -ds.
true_dataset = "C10+"  # Must appear in the reference network's log file name.

# Sparsification option templates, filled with str.format:
#   {0} -> value of --spars_end_epoch / --rlr_start_epoch / --every_epoch_until,
#   {1} -> total number of epochs (-nep),
#   {2} -> value of --dsd_middle, {3} -> value of --dsd_pattern.
spars_config_dict = {
    "no_DSD": "-nep {1} --no-sparsify",
    "DSD": "-nep {1} --sparsify --spars_sched_func sched_dsd"
    " --granularity weight --end_sparsity 0 --dsd_middle {2} --dsd_pattern {3}"
    " --spars_end_epoch {0} --rlr_start_epoch {0} --every_epoch_until {0}"}
# Each pair is summed to get the total epoch count; its first element fills
# placeholder {0} of the sparsification template.
epoch_count_list = [(200, 100)]  # (400, 200)]
prlr = "DensEMANN"  # Passed to run_DensEMANN.py via -prlr.
dsd_middle = 80
dsd_pattern = 'square'
n_runs = 5  # Independent retraining runs per (epoch_count, experiment).
ref_networks_dir = "./ft-logs/Reference_DensEMANN__networks/"
ft_log_suffix = "_ft_log.csv"


def find_source_id(load_dir, dataset_name, suffix=ft_log_suffix):
    """Return the source ID of the reference network stored in load_dir.

    The source ID is the name of the first file, in sorted order so that the
    choice is deterministic, that contains dataset_name and ends with suffix,
    with that suffix stripped.  Returns None when load_dir does not exist
    (or is not a directory) or holds no such file.
    """
    try:
        fnames = sorted(os.listdir(load_dir))
    except (FileNotFoundError, NotADirectoryError):
        return None
    for fname in fnames:
        if dataset_name in fname and fname.endswith(suffix):
            # Slice by length so an empty suffix cannot yield fname[:-0] == "".
            return fname[:len(fname) - len(suffix)]
    return None


def build_retrain_command(epoch_count, experiment, source_id, load_dir):
    """Return the shell command that retrains one reference network.

    epoch_count: pair whose sum is the total number of epochs and whose first
        element fills placeholder {0} of the sparsification template.
    experiment: key of experiments_dict; names ending in "cutout" append
        --cutout to the dataset argument.
    source_id: ID of the network to load, as found by find_source_id.
    load_dir: directory the network is loaded from.
    """
    total_epochs = sum(epoch_count)
    ds_arg = dataset + (" --cutout" if experiment.endswith("cutout") else "")
    spars_args = spars_config_dict["no_DSD"].format(
        epoch_count[0], total_epochs, 0)
    return ("python run_DensEMANN.py --train --test -m DenseNet-BC"
            " -ds {4} --prebuilt --import-only-hypers"
            # " --data /dev/shm"
            " --load {6}"
            " --save ./ft-logs/DensEMANN_retrain_{7}/"
            "{5}_same-k/no_DSD_{3}ep"
            " -prlr {0} {1} --source_id {2}").format(
        prlr, spars_args, source_id, total_epochs, ds_arg,
        true_dataset, load_dir, experiments_dict[experiment])


for epoch_count in epoch_count_list:
    for experiment in experiments_dict:
        # Get the path and source ID for the network to load.
        load_dir = ref_networks_dir + experiment
        source_id = find_source_id(load_dir, true_dataset)
        if source_id is None:
            # Report the skip so a missing reference network is noticed.
            print("Skipping {}: no '{}' file containing '{}' in {}".format(
                experiment, ft_log_suffix, true_dataset, load_dir))
            continue
        # Run the experiment n_runs times; the command does not change
        # between runs, so build it once.
        command = build_retrain_command(
            epoch_count, experiment, source_id, load_dir)
        for run_idx in range(n_runs):
            status = os.system(command)
            if status != 0:
                # Best-effort sweep: report the failure and keep going.
                print("Run {}/{} of {} exited with status {}".format(
                    run_idx + 1, n_runs, experiment, status))
