import sys
import os
import warnings
# Suppress all warnings for the remainder of the process.
warnings.filterwarnings('ignore')

# Make `<working directory>/SMAC3` importable by appending it to the module
# search path; the path is resolved from the directory the script is started in.
current_dir = os.getcwd()
smac_dir = f'{current_dir}/SMAC3'
sys.path.append(smac_dir)

import logging
import math
import json
import numpy as np
import matplotlib.pyplot as plt
from prior_acq import Prior, PriorLogEI, PriorEI, PriorLCB, PriorPI
from setup import setup_experiment
# Import SMAC facades, initial design, acquisition functions and configuration types
from smac.facade.smac_bo_facade import SMAC4BO
from smac.facade.smac_hpo_facade import SMAC4HPO
from smac.initial_design.latin_hypercube_design import LHDesign
from smac.optimizer.acquisition import LCB, EI, PI, LogEI
from smac.runhistory.runhistory2epm import RunHistory2EPM4InvScaledCost
from smac.configspace import Configuration
# Import SMAC-utilities
from smac.scenario.scenario import Scenario

def _parse_config_type(config_type):
    '''
    Split a config-type string into (model, beta, acq_function, space_fraction).

    The literal string 'sampling' maps to model 'sampling', beta 1 and the 'EI'
    acquisition function. Any other value must consist of exactly three
    underscore-separated fields: {model}_{beta}_{acq_function}. For the 'rf' model
    the third field is cut after its first five characters: those five characters
    are the acquisition function name and the remainder is parsed as the integer
    space fraction. For every other model space_fraction is None.

    Raises ValueError if the string does not have three fields, if beta is not a
    number, or (rf only) if the part after the fifth character is not an integer.
    '''
    if config_type == 'sampling':
        return config_type, 1, 'EI', None

    model, beta, acq_function = config_type.split('_')
    space_fraction = None
    if model == 'rf':
        # The fixed five-character cut means the name in front of the number
        # has to be exactly five characters long (e.g. 'LogEI').
        acq_function, space_fraction = acq_function[:5], int(acq_function[5:])
    return model, float(beta), acq_function, space_fraction


def _initial_design_kwargs(prior, model, n_iterations, n_dims, use_prior_samples):
    '''
    Build the initial_design_kwargs for a prior-weighted run.

    Without prior sampling the initial design only gets a budget of n_dims + 1
    evaluations. With prior sampling, the initial configurations are drawn from
    the prior (n_iterations - 1 draws for the 'sampling' model, n_dims otherwise)
    and the prior's maximum location is appended after them.
    '''
    if not use_prior_samples:
        return {'init_budget': n_dims + 1}

    n_samples = n_iterations - 1 if model == 'sampling' else n_dims
    init_configs = prior.sample(n_samples, normalize=False)
    # NOTE(review): assumes get_max_location() returns rows with the same number
    # of columns as the samples, so that the axis=0 append is valid - confirm.
    init_configs = np.append(init_configs, prior.get_max_location(), axis=0)
    return {'configs': init_configs}


def main(args=None):
    '''
    Run a single optimization, either plain SMAC or prior-weighted.

    All variable configurations are entered in the shell scripts. This is not really optimal,
    but as the number of configurations that have needed to be run has increased, this has
    been necessary. It's bad practice, but the system arguments are used a bit all over the place -
    mainly in this file, but in the setup file too.

    If `args` is None, the five values below are read from sys.argv[1] .. sys.argv[5].
    Otherwise `args` must be an iterable of exactly these five values, in this order:

    Config type (arg 1): String of the form {model_type}_{beta}_{acq_function}, e.g. gp_10_EI.
        For the rf model the acquisition function name must be followed by an integer
        (the space fraction), e.g. rf_10_LogEI10. Thus, this is the one argument that really
        decides model type, prior decay and acquisition function. The acquisition functions
        have not been experimented with, we've always run EI for GP and LogEI for RF.
        The special value 'sampling' selects the 'sampling' model with beta 1 and EI.

    Function (arg 2): This specifies which function is to be run, i.e. Branin or any of the meta
        benchmarks. If you look in the Shell scripts, this is already given there.

    Prior (arg 3): Specifies prior strength - strong, weak, wrong or noprior (other names can be
        given too, assuming there's a prior for it. Since the prior names are {function}_{strength}.json,
        anything can really be added to the priors folder. 'noprior' defaults to SMAC without priors.)

    Initialization (arg 4): Specifies the method for initializing - prior or sobol. With all prior-
        weighted runs, we've been running with samples from the prior as the initialization. I've tested
        sobol a bit for myself, but it hasn't proven to be better. Only 'prior' enables sampling the
        initial design from the prior; any other value gives an initial budget of n_dims + 1.
        The value 'warping' additionally forces a plain SMAC run with method 'warping' and
        writes to the 'warping' output directory.

    Prior number (arg 5): Which prior to use for this certain smac run.

    The number of iterations and the function id are read from base_configs/config.json,
    relative to the current working directory.
    '''
    if args is None:
        config_type = sys.argv[1]
        function = sys.argv[2]
        prior = sys.argv[3]
        sampling = sys.argv[4]
        prior_index = sys.argv[5]  # even without a prior, it's passed in the shell script
    else:
        config_type, function, prior, sampling, prior_index = args

    model, prior_beta, acq_function_name, space_fraction = _parse_config_type(config_type)

    with open('base_configs/config.json', 'r') as f:
        run_config = json.load(f)
    n_iterations = run_config['n_iterations']
    function_id = run_config['function_id']

    use_prior_samples = sampling == 'prior'

    # these should be moved to shell script
    # Use the parsed `sampling` value rather than sys.argv[4], so that callers
    # passing `args` explicitly get the same behaviour as the command line.
    if sampling == 'warping':
        smac_run = True
        method = 'warping'
        output_dirname = 'warping'
    else:
        smac_run = prior == 'noprior'
        method = 'bopro'
        output_dirname = config_type

    # `prior` is the strength name given by the caller; `prior_model` is the
    # object setup_experiment builds from it.
    cs, prior_model, bb_function = setup_experiment(
        function_id,
        function,
        method=method,
        prior_strength=prior,
        prior_index=prior_index,
        smac_run=smac_run
        )
    n_dims = len(cs.get_hyperparameter_names())
    logging.basicConfig(level=logging.INFO)  # logging.DEBUG for debug output

    # Scenario object
    scenario = Scenario({"run_obj": "quality",  # we optimize quality (alternatively runtime)
                        "runcount-limit": n_iterations,  # max. number of function evaluations
                        "cs": cs,  # configuration space returned by setup_experiment
                        "deterministic": "true",
                        'abort_on_first_run_crash': True,
                        'output_dir': output_dirname
                        })

    acq_func_options = {
        'LogEI': {'Normal': LogEI, 'Prior': PriorLogEI},
        'EI': {'Normal': EI, 'Prior': PriorEI},
        'LCB': {'Normal': LCB, 'Prior': PriorLCB},
        'PI': {'Normal': PI, 'Prior': PriorPI},
    }
    # Looked up before branching so an unknown name fails with KeyError in both modes.
    acq_variants = acq_func_options[acq_function_name]

    if smac_run:
        # Plain SMAC: standard acquisition function, no prior involved.
        acquisition_func = acq_variants['Normal']
        acq_kwargs = {}
        init_kwargs = {'init_budget': n_dims + 1}
    else:
        # Prior-weighted run: the acquisition function receives the prior and its decay.
        acquisition_func = acq_variants['Prior']
        acq_kwargs = {'prior': prior_model, 'beta': prior_beta}
        if model == 'rf':
            acq_kwargs['space_fraction'] = space_fraction
        init_kwargs = _initial_design_kwargs(
            prior_model, model, n_iterations, n_dims, use_prior_samples
        )

    if model == 'rf':
        smac = SMAC4HPO(scenario=scenario,
                        tae_runner=bb_function,
                        initial_design_kwargs=init_kwargs,
                        acquisition_function=acquisition_func,
                        acquisition_function_kwargs=acq_kwargs
                        )
    else:
        # Every non-rf model (including 'sampling') goes through the GP facade.
        model_type = 'gp_mcmc' if 'mcmc' in model else 'gp'
        smac = SMAC4BO(model_type=model_type,
                       scenario=scenario,
                       tae_runner=bb_function,
                       initial_design_kwargs=init_kwargs,
                       acquisition_function=acquisition_func,
                       acquisition_function_kwargs=acq_kwargs
                       )

    smac.optimize()

# Script entry point: with no explicit arguments, main() takes its
# configuration from the command line.
if __name__ == "__main__":
    main()