import argparse
import os

from _create_experiment_cross_validation import *
from utils.constants import Cte

# Command-line interface: choose the dataset, the model family, the experiment
# preset name, and a Python executable path.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--dataset', type=str, default=Cte.TRIANGLE)
parser.add_argument('--model', type=str, default=Cte.VCAUSE)
parser.add_argument('--experiment_name', type=str, default='all')
# NOTE(review): the default interpreter path is specific to one machine;
# pass --executable explicitly on any other host.
parser.add_argument('--executable', type=str,
                    default='/home/author/miniconda3/envs/gnn/bin/python')
args = parser.parse_args()

dataset, model, experiment_name, executable = (
    args.dataset, args.model, args.experiment_name, args.executable)

from datetime import date

# Suffix the experiment name with today's date formatted as YY_MM_DD.
today = date.today().strftime("%y_%m_%d")
complete_experiment_name = f'{dataset}_{model}_{experiment_name}_{today}'
# Seeds 0..9, one run per seed.
seed_list = [*range(10)]

# %% Files

# Paths (under _params/) of the dataset, model and trainer YAML configs that
# are handed to create_experiment_job_file as `yaml_file`.
yaml_file_d = os.path.join('_params', f'dataset_{dataset}.yaml')
yaml_file_m = os.path.join('_params', f'model_{model}.yaml')
# Plain literal: the trainer config name has no per-run component.
yaml_file_t = os.path.join('_params', 'trainer.yaml')

# %%
# Trainer overrides; stays empty unless the selected model's getter returns some.
trainer_dict = {}
# %% Get dataset &  model permutations

dataset_dict = None
model_dict = None

# NOTE(review): job_template_data is never passed to create_experiment_job_file
# nor used elsewhere in this script -- TODO confirm whether it should be.
# NOTE(review): '8096' may be a typo for '8192' (8 GiB expressed in MiB) -- confirm.
job_template_data = {'executable': executable,
                     'request_gpus': '0',
                     'request_memory': '8096',
                     'request_cpus': '1'
                     }

# Each model family has its own getter (presumably provided by the star import
# of _create_experiment_cross_validation); only the MCVAE getter also returns
# trainer settings.
if model == Cte.VCAUSE:
    dataset_dict, model_dict = get_dict_vcause(dataset, experiment_name)
elif model == Cte.CARELF:
    dataset_dict, model_dict = get_dict_carefl(dataset, experiment_name)
elif model == Cte.MCVAE:
    dataset_dict, model_dict, trainer_dict = get_dict_mcvae(dataset, experiment_name)
else:
    # Fail loudly with a useful message; a bare `assert` gives none and is
    # stripped entirely under `python -O`.
    raise ValueError(
        f'Unsupported model {model!r}; expected one of '
        f'{Cte.VCAUSE!r}, {Cte.CARELF!r}, {Cte.MCVAE!r}.')

# Guard against a getter that returned no configuration (explicit check rather
# than `assert`, so it also runs under `python -O`).
if dataset_dict is None or model_dict is None:
    raise ValueError(
        f'No dataset/model configuration for model={model!r}, '
        f'dataset={dataset!r}, experiment_name={experiment_name!r}.')

from utils.tools import create_experiment_job_file

# Optimiser hyper-parameter grid: a single learning rate.
optim_dict = {'lr': [0.005]}

# Emit the run script _run_files/<complete_experiment_name>.sh, with
# root_dir set to exper_<complete_experiment_name>.
create_experiment_job_file(
    run_filename=f'_run_files/{complete_experiment_name}.sh',
    root_dir=f'exper_{complete_experiment_name}',
    yaml_file=[yaml_file_d, yaml_file_m, yaml_file_t],
    dataset_dict=dataset_dict,
    model_dict=model_dict,
    optim_dict=optim_dict,
    trainer_dict=trainer_dict,
    seed_list=seed_list,
)
