import os
import shutil


def readfile(path):
    """Return the entire text content of the file at ``path``."""
    with open(path, mode="r") as handle:
        content = handle.read()
    return content


# Hyperparameter grids ("H-para").
# Only one experiment configuration is live at a time. The configuration
# directly below is disabled by wrapping it in a bare triple-quoted string,
# which Python evaluates and discards.
''' grid-holdout_fraction-env_density
exps_name = "grid-holdout_fraction-env_density"
bash_template_path = "slurm_scripts/sbatch_baselines_singlejob.sh"
python_cmd_template = "python domainbed/scripts/train_edg.py --data_dir ../datasets_for_domainbed --algorithm {--algorithm} --dataset {--dataset} --edg_direction backward --holdout_fraction {--holdout_fraction} --env_density {--env_density} --seed {--seed} --output_dir {--output_dir} "
bashfile_ouput_dir = "EXPS/" + exps_name + "/"
hparams = {
    "algorithm": ['EDG', 'ERM'],
    "holdout_fraction": [0.2, 0.4, 0.6, 0.8],
    "env_density": [3, 5, 7, 9, 11, 13, 15, 17],
    "seed": [1, 2, 3, 4, 5]
}
'''

# Active configuration.
# Data flow: (bash_template_path -> bashfile_template) + (python_cmd_template + hparams)
# -> one bash file per hyperparameter combination, written to bashfile_ouput_dir.
# The commented-out ''' marker on the next line and the one after the hparams
# dict can be uncommented to disable this configuration the same way as the
# one above.
# ''' grid-env_number-env_distance-env_sample_number
# Previous experiment name, kept for reference:
# exps_name = "grid-env_number-env_distance-env_sample_number"
# Experiment name: used as the output sub-directory and as the prefix of
# every generated bash file name.
exps_name = "new_compare"
# Bash template that contains the bashfile_cmd_slot placeholder. The path is
# relative — presumably to the repository root; confirm where the script is run.
bash_template_path = "slurm_scripts/sbatch_baselines_singlejob.sh"
# Command template. Each "{--name}" slot is replaced with the value of the
# hparams key "name"; "{--output_dir}" is filled from the generated file name
# by write_file_from_hparams.
python_cmd_template = "python domainbed/scripts/train_edg.py --data_dir ../datasets_for_domainbed --algorithm {--algorithm} --dataset {--dataset} --test_type backward_val --holdout_fraction 0.2 --env_distance {--env_distance} --env_number {--env_number} --env_sample_number {--env_sample_number} --seed {--seed} --output_dir {--output_dir}"
# Directory that receives the generated bash files and run.sh. The name is
# misspelled ("ouput") but is referenced under this spelling elsewhere in the
# file. The trailing "/" matters: other code appends file names with "+".
bashfile_ouput_dir = "EXPS/" + exps_name + "/"
# Sweep definition: each key maps to the list of values to try, and the cross
# product of all lists is generated. Commented-out entries are alternative
# value lists kept for reference.
hparams = {
    "dataset": ['EDGPortrait'],
    "algorithm": ['EDG', 'ERM'],
    # "env_distance": [1, 3, 5, 7, 10, 15, 20, 45],
    "env_distance": [1, 3, 5, 10, 15, 20],
    # "env_number": [],
    # "env_number": [3, 5, 7, 10, 15, 20],
    "env_number": [3, 5, 7, 9],
    # "env_sample_number": [200, 500, 1000, 1500, 2000, 3000],
    "env_sample_number": [200],
    "seed": [1, 2, 3, 4, 5]
}
# '''

# The template is read at import time, so importing this module fails if
# bash_template_path cannot be opened from the current working directory.
# TODO: left unspecified by the original author (the note read "↓ TODO").
bashfile_template = readfile(bash_template_path)
# Placeholder text inside the bash template that is replaced with the
# rendered python command.
bashfile_cmd_slot = "###bashfile_cmd_slot###"



def write_file_from_hparams(cur_hparams):
    """Render one experiment's bash file from the templates and write it.

    Steps:
        1. Replace every ``{--key}`` slot of ``python_cmd_template`` with the
           stringified value of ``cur_hparams[key]``.
        2. Replace ``{--output_dir}`` with ``bashfile_ouput_dir`` joined with
           the bash file's name minus its ``.sh`` suffix.
        3. Substitute the finished command for ``bashfile_cmd_slot`` in
           ``bashfile_template`` and write the result into
           ``bashfile_ouput_dir``.

    Reads the module globals ``python_cmd_template``, ``exps_name``,
    ``bashfile_ouput_dir``, ``bashfile_template`` and ``bashfile_cmd_slot``.
    The output directory must already exist; this function does not create it.

    Args:
        cur_hparams: dict mapping a slot name (without the ``--`` prefix) to
            the value to insert, e.g. ``{"algorithm": "EDG", "seed": 1}``.

    Returns:
        The file name (not the full path) of the bash file that was written.
    """
    # Slot the hyperparameter values into the command template.
    python_cmd = python_cmd_template
    for slot, value in cur_hparams.items():
        python_cmd = python_cmd.replace("{--" + slot + "}", str(value))

    # The experiment's output directory is named after the bash file.
    bashfile_name = make_name(cur_hparams, prefix=exps_name)
    suffix = ".sh"
    # Drop only the trailing extension; a blanket str.replace would also strip
    # any ".sh" that happens to occur earlier in the name.
    if bashfile_name.endswith(suffix):
        run_name = bashfile_name[:-len(suffix)]
    else:
        run_name = bashfile_name
    python_cmd = python_cmd.replace(
        "{--output_dir}", os.path.join(bashfile_ouput_dir, run_name))

    # Slot the command into the bash template and write the file.
    bashfile_content = bashfile_template.replace(bashfile_cmd_slot, python_cmd)
    with open(os.path.join(bashfile_ouput_dir, bashfile_name), "w") as f:
        f.write(bashfile_content)
    return bashfile_name

'utils'

def gen_list_of_hp_dict(left_hparams, cur_list=None):
    """Expand a grid ``{key: [values, ...]}`` into one dict per combination.

    The result is the cross product of all value lists. Each returned dict
    holds the keys in the order they appear in ``left_hparams``. The first
    key varies fastest and the last key slowest, e.g.
    ``{"a": [1, 2], "b": [3, 4]}`` yields
    ``[{a: 1, b: 3}, {a: 2, b: 3}, {a: 1, b: 4}, {a: 2, b: 4}]``.

    Neither ``left_hparams`` nor the dicts in ``cur_list`` are modified.

    Args:
        left_hparams: dict mapping each hyperparameter name to the list of
            values to sweep.
        cur_list: optional list of partial assignments (dicts) that every
            combination is built on top of. ``None`` or an empty list means
            "start from scratch".

    Returns:
        A list of new dicts, one per combination. It is empty if any value
        list is empty, and ``[{}]`` if ``left_hparams`` has no keys and no
        seeds are given.
    """
    # Copy the seeds so the caller's dicts are never touched.
    combos = [dict(seed) for seed in cur_list] if cur_list else [{}]
    for key, values in left_hparams.items():
        # Value in the outer loop, partial assignment in the inner loop:
        # earlier keys therefore vary fastest in the final ordering.
        combos = [
            {**partial, key: value}
            for value in values
            for partial in combos
        ]
    return combos

def make_name(cur_hparams, prefix=""):
    """Build the bash file name that encodes one hyperparameter combination.

    Every key/value pair contributes a ``--key#value`` segment, in the dict's
    iteration order, with each ``.`` in the stringified value spelled out as
    ``dot``. The segments follow ``prefix`` and the name ends in ``.sh``.
    """
    segments = [
        "--" + key + "#" + str(value).replace(".", "dot")
        for key, value in cur_hparams.items()
    ]
    return prefix + "".join(segments) + ".sh"



if __name__ == "__main__":
    # Create the output directory up front: open(..., "w") does not create
    # missing parent directories, so a fresh experiment name would otherwise
    # fail with FileNotFoundError on the first write.
    os.makedirs(bashfile_ouput_dir, exist_ok=True)

    # Write one bash file per hyperparameter combination and collect the
    # matching submission line for each.
    sbatch_lines = []
    for each_hparams in gen_list_of_hp_dict(hparams):
        bashfile_name = write_file_from_hparams(each_hparams)
        sbatch_lines.append(f"sbatch {bashfile_name}")

    # NOTE: an earlier version could also prepend a "chmod" line per script;
    # that step was already disabled and is not emitted.

    # run.sh lists bare file names, so it is meant to be run from inside
    # bashfile_ouput_dir.
    with open(os.path.join(bashfile_ouput_dir, "run.sh"), "w") as f:
        f.write("\n".join(sbatch_lines))
    print(f"generate {len(sbatch_lines)} exps .sh")