import os, sys
import subprocess
import json
import ast

import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH
from ConfigSpace.read_and_write import json as config_space_json_r_w
from hpbandster.core.worker import Worker


class Regression_worker(Worker):
    """hpbandster ``Worker`` that evaluates a configuration through ``run_config``.

    The constructor stores the experiment settings; ``compute`` forwards them,
    together with the sampled configuration and the budget, to the
    module-level ``run_config`` function.
    """

    def __init__(self, eta, min_budget, max_budget, operations, layers, exp_name, dataset, variant, noise_level, *args, **kwargs):
        """Store the experiment settings on the instance.

        Args:
            eta: Forwarded to ``run_config`` by ``compute``.
            min_budget: Forwarded to ``run_config`` by ``compute``.
            max_budget: Stored on the instance; not used by ``compute``.
            operations: Forwarded to ``run_config`` (``--operations``).
            layers: Forwarded to ``run_config`` (``--layers``).
            exp_name: Forwarded to ``run_config`` (``--exp_name``).
            dataset: Forwarded to ``run_config`` (``--dataset``).
            variant: Forwarded to ``run_config`` (``--variant``).
            noise_level: Forwarded to ``run_config`` (``--noise_level``).
            *args: Passed through to ``Worker.__init__``.
            **kwargs: Passed through to ``Worker.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.path = '.'
        self.eta = eta
        self.min_budget = min_budget
        self.max_budget = max_budget
        self.operations = operations
        self.layers = layers
        self.exp_name = exp_name
        self.dataset = dataset
        self.variant = variant
        self.noise_level = noise_level

    def compute(self, config, budget, config_id, working_directory):
        """Evaluate ``config`` at ``budget`` and return ``run_config``'s result.

        Args:
            config: Mapping of hyperparameter names to values.
            budget: Budget for this evaluation; truncated to ``int`` and
                passed on as the number of epochs.
            config_id: Iterable identifying the configuration; ``run_config``
                joins its items with ``_`` to name the output directory.
            working_directory: Directory under which the run's output
                directory is created.

        Returns:
            The dictionary returned by ``run_config``.
        """
        return run_config(
            config=config,
            budget=int(budget),
            min_budget=self.min_budget,
            eta=self.eta,
            config_id=config_id,
            directory=working_directory,
            operations=self.operations,
            layers=self.layers,
            exp_name=self.exp_name,
            dataset=self.dataset,
            variant=self.variant,
            noise_level=self.noise_level,
        )

    @staticmethod
    def get_config_space():
        """Load the search space from ``configs/darts_1_configspace.json``.

        The path is relative to the current working directory.

        Returns:
            The object produced by ConfigSpace's JSON reader for that file.
        """
        config_path = 'configs/darts_1_configspace.json'
        # Use a context manager so the file handle is closed deterministically.
        with open(config_path, 'r') as fh:
            return config_space_json_r_w.read(fh.read())


def load_data(dest_dir):
    """Collect the per-line results and the log text written to ``dest_dir``.

    ``results.txt`` is read line by line; each line is decoded with
    ``json.loads`` and the resulting string is parsed with
    ``ast.literal_eval`` into a record. ``log.txt`` is read in full.

    Args:
        dest_dir: Directory containing ``results.txt`` and ``log.txt``.

    Returns:
        A dict with the key ``'config'`` (the log lines joined with ``'\\n'``)
        followed by one list per metric key, each holding that key's value
        from every record in file order.
    """
    metric_keys = (
        'loss',
        'Train_Loss',
        'Validation_Loss',
        'Trunc_Loss',
        'Train_Metric',
        'Trunc_Metric',
    )

    records = []
    with open(os.path.join(dest_dir, 'results.txt'), 'r') as results_file:
        for raw_line in results_file:
            records.append(ast.literal_eval(json.loads(raw_line)))

    with open(os.path.join(dest_dir, 'log.txt'), 'r') as log_file:
        # Lines keep their own trailing newline, so joining with '\n'
        # leaves a blank line between consecutive log lines.
        log_text = '\n'.join(log_file.readlines())

    info = {'config': log_text}
    for key in metric_keys:
        info[key] = [record[key] for record in records]
    return info


def run_config(config, budget, min_budget, eta, config_id, directory, operations, layers, exp_name, dataset, variant, noise_level):
    """Run ``search_IP_NAS.py`` for one configuration and return its final loss.

    A run directory named after ``config_id`` (items joined with ``_``) is
    created under ``directory``; the script is launched with ``--save_dir``
    pointing at it, and the result files are then read back with
    ``load_data``.

    Args:
        config: Mapping that must provide ``param_lr``, ``alpha_lr``,
            ``param_weight_decay``, ``alpha_weight_decay``, ``alpha_warmup``,
            ``param_warmup``, ``alpha_scheduler`` and ``alpha_optimizer``.
        budget: Passed to the script as ``--epochs``.
        min_budget: Accepted for interface compatibility; unused here.
        eta: Accepted for interface compatibility; unused here.
        config_id: Iterable whose items form the run directory name.
        directory: Parent directory for the run directory.
        operations: Passed as ``--operations``.
        layers: Passed as ``--layers``.
        exp_name: Passed as ``--exp_name``.
        dataset: Passed as ``--dataset``.
        variant: Passed as ``--variant``.
        noise_level: Passed as ``--noise_level``.

    Returns:
        ``{'loss': <last entry of info['loss']>, 'info': info}`` where
        ``info`` is the dictionary returned by ``load_data``.

    Raises:
        KeyError: If ``config`` lacks one of the required hyperparameters.
        subprocess.CalledProcessError: If the script exits with a non-zero
            status.
    """
    dest_dir = os.path.join(directory, '_'.join(map(str, config_id)))
    # exist_ok avoids the check-then-create race when several workers
    # share the same working directory.
    os.makedirs(dest_dir, exist_ok=True)

    print(config)

    bash_string = [" python search_IP_NAS.py  --bohb",
                   "--operations {}".format(operations),
                   "--layers {}".format(layers),
                   "--exp_name {}".format(exp_name),
                   "--dataset {}".format(dataset),
                   "--variant {}".format(variant),
                   "--noise_level {}".format(noise_level),
                   "--save_dir {}".format(dest_dir),
                   "--epochs {}".format(budget),
                   "--param_lr {param_lr}".format(**config),
                   "--alpha_lr {alpha_lr}".format(**config),
                   "--param_weight_decay {param_weight_decay}".format(**config),
                   "--alpha_weight_decay {alpha_weight_decay}".format(**config),
                   "--alpha_warmup {alpha_warmup}".format(**config),
                   "--param_warmup {param_warmup}".format(**config),
                   "--alpha_scheduler {alpha_scheduler}".format(**config),
                   "--alpha_optimizer {alpha_optimizer}".format(**config)]

    # NOTE(review): the command is a single string run through the shell, so
    # values containing whitespace or shell metacharacters (including
    # ``directory``) are interpreted by the shell. Left unchanged because
    # switching to an argv list could alter how arguments are tokenized.
    subprocess.check_call(" ".join(bash_string), shell=True)

    # Any failure above propagates to the caller; no fallback result is built.
    info = load_data(dest_dir)
    return {'loss': info['loss'][-1], 'info': info}
