import itertools
import os
import random
from pathlib import Path


# Verbose-output switch. NOTE(review): not read anywhere else in this script
# as shown — confirm it is still needed.
DEBUG = False

# Stopping rule for the search: quit once the number of submitted jobs reaches
# this fraction of all combinations or this absolute cap, whichever is smaller.
PERCENT_OF_SEARCH_SPACE_TO_SEARCH = 1.0
MAX_NUM_COMBINATIONS = 100

# Fold index (presumably cross-validation) forwarded to every job as ``--cv``
# and embedded in its log-file name.
CV = 0
# Wall-clock limit in hours for each SLURM job (``#SBATCH --time``).
LIMIT_HRS = 20

# Search space. List values are searched over; any other value is a fixed
# setting shared by every job. A key naming two parameters separated by a
# comma takes tuple options, one element per parameter. Key order is part of
# each run's identity (it fixes the order of the long name recorded in
# id_mapper.txt and of the CLI flags), so do not reorder entries.
hyperparams = {
    'data': ['sim_easy', 'sim_semi_markov', 'sim_hard', 'har', 'cpap'],
    # NOTE(review): an earlier note said "100 for har_70, 50 for everything
    # else", yet 40 is used here — confirm the intended factor.
    'ds_factor': [40],
    'iters': [10001],
    'init_way': 'prior',
    'k_max': 20,
    'c_pri,d_pri': [(1, 1), (2, 1)],  # (1, 1) is uniform prior on kappa
    'alpha0_a_pri,alpha0_b_pri': [(1, 1), (2, 1)],
    'gamma0_a_pri,gamma0_b_pri': [(1, 1)],
    # 'emperical_cov' spelling kept verbatim: it is passed to the training
    # script as-is on the command line.
    'S0_type,S0_factor': [('emperical_cov', 0.75), ('emperical_cov', 1)],
    # ('identity', 3) is a further option that is currently commented out.
    'V0_type,V0_factor': [('identity', 0.1), ('identity', 1)],
    'M0_type,M0_factor': [('zeros', 0)],
}


# Alternative single-dataset configuration (HAR only, one option per
# searched hyper-parameter). It sits inside a bare string expression, so it
# has no effect at runtime. Removing the surrounding quotes would make it
# override the ``hyperparams`` assignment above, because it runs afterwards.
'''
hyperparams = {'data': ['har'],
               'ds_factor': [50], # 100 for har_70, 50 for everything else
               'iters': [10001],
               'init_way': 'prior',
               'k_max': 20,
               'c_pri,d_pri':[(1, 1)], 
               'alpha0_a_pri,alpha0_b_pri': [(1, 1)],
               'gamma0_a_pri,gamma0_b_pri': [(1, 1)],
               'S0_type,S0_factor': [('emperical_cov', 1)],
               'V0_type,V0_factor': [('identity', 1)],#, ('identity', 3)],
               'M0_type,M0_factor': [('zeros', 0)]}
'''



# Hyper-parameter dicts submitted as SLURM jobs so far in this run.
combos_tried = []

# Size of the search grid: the product of the option counts of every
# list-valued entry in ``hyperparams``; fixed (non-list) settings add no
# factor.
possible_num_combos = 1
for options in hyperparams.values():
    if type(options) is list:
        possible_num_combos *= len(options)


def _candidate_combos(space):
    """Expand a hyper-parameter search space into every concrete combination.

    Parameters
    ----------
    space : dict
        Maps a key to either a list of options (searched over) or any other
        value (a fixed setting, copied through unchanged). When a chosen list
        option is a tuple, the key is split on ',' and each resulting name
        receives the matching tuple element, e.g. ``'c_pri,d_pri': (1, 1)``
        yields ``{'c_pri': 1, 'd_pri': 1}``.

    Returns
    -------
    list of dict
        One flat ``{name: value}`` dict per grid point. Names follow the
        insertion order of ``space``. That order determines each run's long
        name, so it must stay stable across runs.

    Raises
    ------
    ValueError
        If a tuple option does not have exactly one element per
        comma-separated name in its key.
    """
    keys = list(space)
    searched = [isinstance(space[key], list) for key in keys]
    # Wrap fixed settings in a one-element list so product() treats every key alike.
    option_lists = [space[key] if is_list else [space[key]]
                    for key, is_list in zip(keys, searched)]

    combos = []
    for choice in itertools.product(*option_lists):
        combo = {}
        for key, is_list, value in zip(keys, searched, choice):
            if is_list and isinstance(value, tuple):
                names = key.split(',')
                if len(names) != len(value):
                    raise ValueError(
                        '%r: option %r has %d values but the key names %d parameters'
                        % (key, value, len(value), len(names)))
                combo.update(zip(names, value))
            else:
                combo[key] = value
        combos.append(combo)
    return combos


def _lookup_or_register(dataset_dir, long_name):
    """Return ``(short_name, is_new)`` for ``long_name`` in the dataset's ID map.

    ``<dataset_dir>/id_mapper.txt`` holds one ``<short_name>:<long_name>``
    entry per line. If ``long_name`` is already listed, its existing short
    name is returned with ``is_new=False``. Otherwise a new short name is
    appended and returned with ``is_new=True``. The new name is one more than
    the largest numeric ID in the file, zero-padded to four digits.
    """
    mapper_path = dataset_dir / 'id_mapper.txt'
    mapper_path.touch(exist_ok=True)

    max_id = 0
    with open(mapper_path, 'r') as f:
        for line in f:
            mapped_id, sep, mapped_name = line.strip().partition(':')
            # Compare the whole name. A substring test would treat e.g.
            # '..._M0_factor_0' as present when only '..._M0_factor_0.5' is.
            if sep and mapped_name == long_name:
                return mapped_id, False
            try:
                max_id = max(max_id, int(mapped_id))
            except ValueError:
                pass  # blank line (files start with one) or a non-numeric ID

    short_name = '%04d' % (max_id + 1)
    with open(mapper_path, 'a') as f:
        # Entries are written newline-first to match existing files, which
        # therefore never end with a trailing newline.
        f.write('\n' + short_name + ':' + long_name)
    return short_name, True


def _submit_job(job_file, dataset_chosen, short_name, hyper_param_dict):
    """Write a SLURM batch script for one combination and submit it via sbatch.

    The job's output goes to ``./<dataset>/<short_name>/output_<CV>.txt``.
    Every entry of ``hyper_param_dict`` is passed to the training script as a
    ``--<name> <value>`` flag, followed by ``--short_name`` and ``--cv``.
    """
    (Path(dataset_chosen) / short_name).mkdir(parents=True, exist_ok=True)
    text_file = './' + dataset_chosen + '/' + short_name + '/output_%d.txt' % CV
    # Create (or truncate) the log now so it exists even before SLURM starts the job.
    with open(text_file, 'w'):
        pass

    cli_flags = ''.join('--' + name + ' ' + str(value) + ' '
                        for name, value in hyper_param_dict.items())
    python_call = ('python -u run_full_bayesian_approx_parallel_gibbs_ar_efox.py '
                   + cli_flags
                   + ' --short_name ' + short_name
                   + ' --cv ' + str(CV))

    # 'w' replaces whatever script the previous submission left behind.
    with open(job_file, 'w') as fh:
        fh.writelines([
            '#!/bin/bash\n',
            '#SBATCH --job-name=%s\n' % (dataset_chosen + short_name),
            '#SBATCH --output=%s\n' % text_file,
            '#SBATCH --qos=cpu_qos\n',
            '#SBATCH --time=%d:00:00\n' % LIMIT_HRS,
            "#SBATCH --partition='cpu'\n",
            '#SBATCH --mem=32G\n',
            '#SBATCH --open-mode=append\n',
            '%s\n' % python_call,
        ])
    os.system('sbatch %s' % job_file)


job_file = 'hyper_param_opt.sh'
# Submit new jobs until this many have gone out in this run (both caps apply).
target_num_combos = min(int(PERCENT_OF_SEARCH_SPACE_TO_SEARCH * possible_num_combos),
                        MAX_NUM_COMBINATIONS)

# Walking the grid in a shuffled order samples combinations uniformly without
# replacement. Unlike repeated random.choice, it always terminates, even when
# every combination was already submitted by an earlier run.
candidates = _candidate_combos(hyperparams)
random.shuffle(candidates)

for hyper_param_dict in candidates:
    dataset_chosen = hyper_param_dict['data']
    long_name = '_'.join(hp + '_' + str(value) for hp, value in hyper_param_dict.items())

    dataset_dir = Path(dataset_chosen)
    dataset_dir.mkdir(parents=True, exist_ok=True)
    short_name, is_new = _lookup_or_register(dataset_dir, long_name)

    # Combinations already listed in id_mapper.txt were submitted earlier; skip them.
    if is_new:
        _submit_job(job_file, dataset_chosen, short_name, hyper_param_dict)
        combos_tried.append(hyper_param_dict)

    if len(combos_tried) >= target_num_combos:
        break
else:
    print('Search space exhausted: submitted %d new job(s); the remaining '
          'combinations were already listed in id_mapper.txt.' % len(combos_tried))




