import os

# Directory containing this module.
THIS_PATH = os.path.dirname(__file__)
# Absolute path of the directory five levels above THIS_PATH; every
# checkpoint/run folder built in this module is placed beneath it.
ROOT_PATH = os.path.abspath(os.path.join(THIS_PATH, *(['..'] * 5)))


def make_run_folder(args):
    """Return the checkpoint folder and the run folder for ``args``.

    The model *family* is the first of ``'mipet'``, ``'cfasl'``, ``'cmcs'``,
    ``'uni'`` (checked in that order) that occurs as a substring of
    ``args.model_type``; when none matches, the family is ``'baselines'``.

    Args:
        args: object exposing ``model_type`` and ``dataset`` attributes
            (presumably an ``argparse.Namespace`` -- confirm with callers).

    Returns:
        A ``(folder_name, run_folder_name)`` tuple:
        ``ROOT_PATH`` joined with
        ``checkpoints/dl/<family>/<dataset>/<model_type>`` and
        ``ROOT_PATH`` joined with ``runs/<family>/<dataset>/<model_type>``.
        The relative part uses ``'/'`` separators.
    """
    model_type = args.model_type

    family = 'baselines'
    for candidate in ('mipet', 'cfasl', 'cmcs', 'uni'):
        # Order matters: the first substring hit wins.
        if candidate in model_type:
            family = candidate
            break

    # Both folders share the same <family>/<dataset>/<model_type> suffix.
    tail = [family, args.dataset, model_type]
    checkpoint_rel = '/'.join(['checkpoints', 'dl'] + tail)
    run_rel = '/'.join(['runs'] + tail)

    return os.path.join(ROOT_PATH, checkpoint_rel), os.path.join(ROOT_PATH, run_rel)


# Argument names that are always encoded into the run file name.
_ALWAYS_KEYS = frozenset({
    'results_file',
    'optimizer',
    'epoch',
    'lr_rate',
    'seed',
    'weight_decay',
    'train_batch_size',
    'latent_dim',
    'st',
})

# Argument names encoded into the run file name only when their value is set.
_OPTIONAL_KEYS = frozenset({
    'discri_lr_rate',
    'sub_lr_rate',
    'num_inv_equ',
    'mask',
    'th',
    'sub_sec',
    'nth_root',
    'alpha',
    'beta',
    'gamma',
    'epsilon',
})

# Any argument whose name *contains* one of these is encoded when its value is set.
_OPTIONAL_KEY_SUBSTRINGS = ('hy_', 'eq_prob', 'sizes_ls')


def _include_in_file_name(key, value):
    """Return True if the argument ``key``/``value`` belongs in the run file name."""
    if key in _ALWAYS_KEYS:
        return True
    if value is None:
        # Optional arguments are only encoded when explicitly set.
        return False
    return key in _OPTIONAL_KEYS or any(part in key for part in _OPTIONAL_KEY_SUBSTRINGS)


def make_run_files(args):
    """Return the per-configuration output directories for ``args``.

    A file name is built from the selected attributes of ``args``, in the
    order ``vars(args)`` yields them, as ``key_value`` pairs joined by
    ``'_'`` (values rendered with ``str``). Keys in ``_ALWAYS_KEYS`` are
    always included. Keys in ``_OPTIONAL_KEYS``, or containing one of
    ``_OPTIONAL_KEY_SUBSTRINGS``, are included only when their value is
    not ``None``.

    Args:
        args: object accepted by ``vars`` and by ``make_run_folder``
            (presumably an ``argparse.Namespace`` -- confirm with callers).

    Returns:
        ``(output_dir, run_output_dir, folder_name)``. ``folder_name`` and the
        run folder come from ``make_run_folder(args)``; ``output_dir`` and
        ``run_output_dir`` are those two folders joined with the file name.
    """
    folder_name, run_folder_name = make_run_folder(args)

    parts = [
        '{}_{}'.format(key, str(value))
        for key, value in vars(args).items()
        if _include_in_file_name(key, value)
    ]
    file_name = '_'.join(parts)

    output_dir = os.path.join(folder_name, file_name)
    run_output_dir = os.path.join(run_folder_name, file_name)
    return output_dir, run_output_dir, folder_name
