import os

# Directory that contains this module.
THIS_PATH = os.path.dirname(__file__)
# Absolute path four directory levels above this module's directory; all
# checkpoint/run paths built below are rooted here.
ROOT_PATH = os.path.abspath(os.path.join(THIS_PATH, *([os.pardir] * 4)))


def make_finetuning_folder(args, root=None):
    """Build the checkpoint folder and run folder paths for a fine-tuning job.

    Both paths end with ``[<family>/]<model_type>/<dataset>``. ``<family>``
    is ``cgeresnet`` when that substring occurs in ``args.model_type``,
    otherwise ``cgeconv`` when that substring occurs, and is omitted if
    neither does.

    Args:
        args: Namespace-like object; the ``task``, ``model_type`` and
            ``dataset`` attributes are read (each must be a ``str``).
        root: Base directory for both paths. ``None`` (the default) uses the
            module-level ``ROOT_PATH``, matching the original behaviour.

    Returns:
        A ``(folder_name, run_folder_name)`` tuple:
        ``<root>/checkpoints/jcgel_cls/<task>/[<family>/]<model_type>/<dataset>``
        and ``<root>/runs/[<family>/]<model_type>/<dataset>``.
    """
    if root is None:
        root = ROOT_PATH

    # Tail shared by both paths. 'cgeresnet' is tested first, so a model
    # name containing both substrings is filed under 'cgeresnet'.
    tail = []
    if 'cgeresnet' in args.model_type:
        tail.append('cgeresnet')
    elif 'cgeconv' in args.model_type:
        tail.append('cgeconv')
    tail.extend([args.model_type, args.dataset])

    # Relative parts are joined with '/' and only then attached to the root,
    # so a component that happens to look absolute cannot make os.path.join
    # discard the root (this keeps the exact original path strings).
    sub_name = '/'.join(['checkpoints', 'jcgel_cls', args.task] + tail)
    run_name = '/'.join(['runs'] + tail)

    folder_name = os.path.join(root, sub_name)
    run_folder_name = os.path.join(root, run_name)
    return folder_name, run_folder_name


# Argument names always encoded into the fine-tuning output file name.
_FILE_NAME_KEYS = frozenset({
    'results_file',
    'optimizer',
    'epoch',
    'lr_rate',
    'seed',
    'weight_decay',
    'train_batch_size',
    'latent_dim',
    'c_rot',
    'g_rot',
    'n_flip',
    'temperature',
    'normalization',
    'soft',
})

# Argument names encoded into the file name only when their value is not None.
_OPTIONAL_FILE_NAME_KEYS = frozenset({
    'sub_lr_rate',
    'kernel_aug',
    'aug',
    'fiber_group',
})


def make_finetuning_files(args):
    """Build the output paths for a fine-tuning job from its arguments.

    The file name is ``'_'.join`` of ``'<key>_<value>'`` fragments, one per
    selected attribute of ``args``. The fragments appear in ``vars(args)``
    iteration order. An attribute is selected when its name is in
    ``_FILE_NAME_KEYS``, or when its name is in ``_OPTIONAL_FILE_NAME_KEYS``
    and its value is not ``None``.

    NOTE(review): values are inserted with ``str()`` unescaped. A value
    containing a path separator would create nested directories. Confirm
    callers never pass such values.

    Args:
        args: Namespace-like object supporting ``vars()``. It must also
            provide what ``make_finetuning_folder`` reads.

    Returns:
        A ``(output_dir, run_output_dir, folder_name)`` tuple. The first two
        are the checkpoint and run folders from ``make_finetuning_folder``
        joined with the file name. ``folder_name`` is the checkpoint folder
        itself.
    """
    folder_name, run_folder_name = make_finetuning_folder(args)

    fragments = [
        '{}_{}'.format(key, str(value))
        for key, value in vars(args).items()
        if key in _FILE_NAME_KEYS
        or (key in _OPTIONAL_FILE_NAME_KEYS and value is not None)
    ]
    file_name = '_'.join(fragments)

    output_dir = os.path.join(folder_name, file_name)
    run_output_dir = os.path.join(run_folder_name, file_name)
    return output_dir, run_output_dir, folder_name
