from .base import Baseline
import numpy as np
import os
import time
import pickle as pkl
from openbox.surrogate.base.build_gp import create_gp_model
from openbox.surrogate.base.rf_with_instances import RandomForestWithInstances
from openbox.acquisition_function.acquisition import EI
from openbox.utils.util_funcs import get_types
from openbox.utils.config_space.util import convert_configurations_to_array

from .acq_optimizer.local_random import InterleavedLocalAndRandomSearch
from utils.space import get_space
from utils.time import time_limit, TimeoutException


class RisingBandit(Baseline):
    """Bandit-style baseline that tunes one arm at a time and rejects weak arms.

    Every choice of the first hyperparameter of ``config_space`` is an arm.
    Each arm owns a sub-space (built with ``get_space``), a surrogate model,
    an EI acquisition function and an acquisition optimizer. Arms are visited
    round-robin, ``trial_per_action`` trials per visit; after each full round
    past the initial phase, arms whose extrapolated upper bound falls below
    another arm's current reward are dropped from ``candidates``.

    Observations are tuples
    ``(config, val_perf, test_perf, val_pred, test_pred, time)`` and lower
    ``val_perf`` is treated as better.

    NOTE(review): ``self.rng``, ``self.seed``, ``self.observations``,
    ``self.incumbent_value``, ``self.incumbent_config``, ``self.iter_num``,
    ``self.save_dir`` and ``self.eval_func`` are presumably set up by
    ``Baseline.__init__`` -- confirm against the base class.
    """

    # Number of random draws attempted when every optimizer challenger has
    # already been evaluated, before a duplicate is accepted.
    _FALLBACK_RANDOM_TRIES = 100

    def __init__(self, config_space, eval_func, iter_num=200, save_dir='./results', task_name='default',
                 surrogate_type='prf'):
        """Build one surrogate / acquisition function / optimizer per arm.

        Args:
            config_space: Full search space; the choices of its first
                hyperparameter are used as arm names.
            eval_func: Callable evaluated in ``run``; expected to return
                ``(val_obj, test_obj, val_pred, test_pred)``.
            iter_num: Total number of trials performed by ``run``.
            save_dir: Directory in which the observation pickle is written.
            task_name: Prefix of the pickle file name.
            surrogate_type: ``'gp'`` or ``'prf'``.

        Raises:
            ValueError: If ``surrogate_type`` is neither ``'gp'`` nor ``'prf'``.
        """
        super().__init__(config_space, eval_func, iter_num, save_dir, task_name)

        self.arm_choices = config_space.get_hyperparameters()[0].choices
        self.arms = dict()            # arm name -> surrogate model
        self.acq_funcs = dict()       # arm name -> EI acquisition function
        self.acq_optimizers = dict()  # arm name -> acquisition optimizer
        self.sub_space = dict()       # arm name -> configuration sub-space

        for arm_name in self.arm_choices:
            sub_space = get_space(algorithm_set=[arm_name])  # Only support large space
            self.sub_space[arm_name] = sub_space
            types, bounds = get_types(sub_space)
            if surrogate_type == 'gp':
                surrogate = create_gp_model(model_type='gp',
                                            config_space=sub_space,
                                            types=types,
                                            bounds=bounds,
                                            rng=self.rng)
            elif surrogate_type == 'prf':
                surrogate = RandomForestWithInstances(types=types, bounds=bounds, seed=self.seed)
            else:
                raise ValueError("Surrogate type %s not supported!" % surrogate_type)
            self.arms[arm_name] = surrogate
            self.acq_funcs[arm_name] = EI(surrogate)
            self.acq_optimizers[arm_name] = InterleavedLocalAndRandomSearch(
                acquisition_function=self.acq_funcs[arm_name],
                config_space=sub_space, rng=self.rng)

        self.timestamp = time.time()
        self.save_path = os.path.join(self.save_dir, '%s_%s_%d_%s.pkl' % (task_name, 'rb', iter_num, self.timestamp))

        self.alpha = 3             # window (in rounds) used for the reward slope
        self.trial_per_action = 3  # consecutive trials spent on an arm per visit
        # The initial phase gives every arm `alpha` visits of random sampling.
        self.init_num = self.trial_per_action * len(self.arm_choices) * self.alpha
        self.lower_bounds = dict()
        self.upper_bounds = dict()
        self.rewards = {arm_name: list() for arm_name in self.arm_choices}
        self.candidates = list(self.arm_choices)  # arms that have not been rejected
        self.candidate_idx = 0     # index into `candidates` of the arm being played
        self.trial_in_action = 0   # trials already spent in the current visit

    def _arm_observations(self, arm_name):
        """Return the observations whose config has ``'algorithm' == arm_name``."""
        return [obs for obs in self.observations if obs[0]['algorithm'] == arm_name]

    def _is_evaluated(self, config):
        """Return True if ``config`` equals the config of any stored observation."""
        return any(config == obs[0] for obs in self.observations)

    def _sample_unseen_random(self, sub_space, max_tries=None):
        """Draw random configurations from ``sub_space`` until an unseen one is found.

        With ``max_tries=None`` sampling repeats until an unseen configuration
        is drawn. With an integer ``max_tries`` the last draw is returned after
        that many attempts even if it duplicates an earlier observation.
        """
        tries = 0
        while True:
            config = sub_space.sample_configuration()
            tries += 1
            if not self._is_evaluated(config):
                return config
            if max_tries is not None and tries >= max_tries:
                return config

    def sample(self):
        """Suggest the next configuration for the arm currently being played.

        During the first ``init_num`` observations the configuration is drawn
        at random from the arm's sub-space. Afterwards the arm's surrogate is
        retrained on the arm's own observations and the first not-yet-evaluated
        challenger proposed by the acquisition optimizer is returned.

        Returns:
            A configuration from the current arm's sub-space.
        """
        candidate_name = self.candidates[self.candidate_idx]
        sub_space = self.sub_space[candidate_name]
        print(candidate_name)

        if len(self.observations) < self.init_num:  # Sample initial configurations randomly
            return self._sample_unseen_random(sub_space)

        sub_observations = self._arm_observations(candidate_name)
        X = convert_configurations_to_array([obs[0] for obs in sub_observations])
        Y = np.array([obs[1] for obs in sub_observations])

        surrogate = self.arms[candidate_name]
        surrogate.train(X, Y)

        self.acq_funcs[candidate_name].update(model=surrogate,
                                              eta=self.incumbent_value,
                                              num_data=len(sub_observations))

        challengers = self.acq_optimizers[candidate_name].maximize(observations=sub_observations,
                                                                   num_points=5000)

        for challenger in challengers.challengers:
            if not self._is_evaluated(challenger):
                return challenger

        # Every challenger was already evaluated. Indexing past the end of the
        # challenger list used to raise IndexError here; fall back to a bounded
        # number of random draws instead.
        return self._sample_unseen_random(sub_space, max_tries=self._FALLBACK_RANDOM_TRIES)

    def _reject_dominated_arms(self, left_budget):
        """Drop candidates whose extrapolated reward cannot reach another arm's.

        For each candidate the lower bound is its latest reward and the upper
        bound is that reward plus the recent slope extrapolated over the
        remaining visits, capped at 1.0. A candidate is rejected when its upper
        bound is strictly below the lower bound of any other candidate.
        """
        lower_bounds, upper_bounds = [], []
        for name in self.candidates:
            rewards = self.rewards[name]
            # NOTE(review): if an arm's rewards are all -inf (every trial failed
            # and was recorded as np.inf), this slope is nan, every comparison
            # below is False and the arm is never rejected -- confirm intent.
            slope = (rewards[-1] - rewards[-self.alpha]) / self.alpha
            steps = int(left_budget / self.trial_per_action)
            upper_bounds.append(np.min([1.0, rewards[-1] + slope * steps]))
            lower_bounds.append(rewards[-1])

        # Reject the sub-optimal arms.
        n = len(self.candidates)
        self.candidates = [
            name for i, name in enumerate(self.candidates)
            if not any(upper_bounds[i] < lower_bounds[j] for j in range(n) if j != i)
        ]
        print("Candidates after reject:")
        print(self.candidates)

    def update(self, config, val_perf, test_perf, val_pred, test_pred, time, left_budget):
        """Record one trial and advance the bandit state.

        Args:
            config: Evaluated configuration.
            val_perf: Validation objective (lower is better).
            test_perf: Test objective.
            val_pred: Validation predictions stored with the observation.
            test_pred: Test predictions stored with the observation.
            time: Evaluation time of the trial. The name shadows the ``time``
                module inside this method; it is kept for interface
                compatibility.
            left_budget: Number of trials remaining after this one.
        """
        if val_perf < self.incumbent_value:
            self.incumbent_value = val_perf
            self.incumbent_config = config
        self.observations.append((config, val_perf, test_perf, val_pred, test_pred, time))

        self.trial_in_action += 1
        if self.trial_in_action == self.trial_per_action:
            # The visit is over: the arm's reward is the negated best (lowest)
            # validation objective over all of its observations so far.
            candidate_name = self.candidates[self.candidate_idx]
            arm_observations = self._arm_observations(candidate_name)
            self.rewards[candidate_name].append(-min(obs[1] for obs in arm_observations))
            self.trial_in_action = 0
            self.candidate_idx += 1

        if self.candidate_idx != len(self.candidates):
            return

        # A full round over all candidates has finished.
        if len(self.observations) > self.init_num:  # Eliminate arm
            self._reject_dominated_arms(left_budget)
        self.candidate_idx = 0

    def run(self, time_limit_per_trial=30):
        """Run ``iter_num`` trials, pickling all observations after each one.

        A trial that times out or raises is recorded with ``np.inf`` objectives
        and ``None`` predictions instead of aborting the run.

        Args:
            time_limit_per_trial: Limit passed to ``time_limit`` for each
                evaluation.
        """
        for iteration in range(self.iter_num):
            config = self.sample()
            start_time = time.time()
            try:
                with time_limit(time_limit_per_trial):
                    val_obj, test_obj, val_pred, test_pred = self.eval_func(config)
                runtime = time.time() - start_time
                print('Iter: %d, Obj: %f, Test obj: %f, Eval time: %f' % (iteration, val_obj, test_obj, runtime))
            except (TimeoutException, Exception) as e:
                # Timeouts and other failures are handled the same way; only
                # the first printed line differs.
                print('Time out!' if isinstance(e, TimeoutException) else e)
                val_obj, test_obj, val_pred, test_pred = np.inf, np.inf, None, None
                runtime = time.time() - start_time
                print('Iter: %d, Failed Obj: %f, Test obj: %f, Eval time: %f' % (iteration, val_obj, test_obj, runtime))
            # Remaining budget counts the trials left after this one.
            self.update(config, val_obj, test_obj, val_pred, test_pred, runtime, self.iter_num - (iteration + 1))
            with open(self.save_path, 'wb') as f:
                pkl.dump(self.observations, f)
