import copy

import numpy as np
import scipy
import statsmodels.api as sm

#import bohb.configspace as cs
import configspace as cs
import dask

from utils_SD import get_sigma
from utils_SD import get_next_round_n_model


class KDEMultivariate(sm.nonparametric.KDEMultivariate):
	"""statsmodels multivariate KDE fitted on a collection of configurations.

	The configurations are kept on ``self.configurations`` so callers can draw
	one of the fitted points back out later.
	"""

	def __init__(self, configurations):
		"""Fit the estimator.

		Args:
			configurations: indexable, iterable collection of configuration
				objects. Each must provide ``to_list()``; the first one must
				also provide ``kde_vartypes``, which is forwarded as the
				variable-type specification.
		"""
		self.configurations = configurations
		# One row per configuration, one column per hyperparameter value.
		rows = [np.array(config.to_list()) for config in configurations]
		super().__init__(
			np.array(rows), configurations[0].kde_vartypes, 'normal_reference')


class Log():
	"""Fixed-size container of per-bracket results plus the overall best entry.

	Each slot is meant to hold a dict mapping a budget to
	``{'loss': ..., 'hyperparameter': ...}``. Slots start out as ``None``
	(object array created by ``np.empty``) until assigned via ``log[s] = {...}``.
	``best`` starts as ``{'loss': np.inf}`` and gains ``'budget'`` and
	``'hyperparameter'`` keys once a result is recorded by the caller.
	"""

	def __init__(self, size):
		"""Create ``size`` empty slots."""
		self.size = size
		# dtype=dict resolves to an object array whose entries are None.
		self.logs = np.empty(self.size, dtype=dict)
		self.best = {'loss': np.inf}

	def __getitem__(self, index):
		"""Return the record stored in slot ``index``."""
		return self.logs[index]

	def __setitem__(self, index, value):
		"""Store ``value`` in slot ``index``."""
		self.logs[index] = value

	def __repr__(self):
		"""Render every slot and the best entry as newline-separated text.

		Unfilled slots and not-yet-recorded fields are tolerated (shown as an
		empty section / ``None``) so that printing a partially filled log does
		not raise.
		"""
		lines = [f's_max: {self.size}']
		for s, log in enumerate(self.logs):
			lines.append(f's: {s}')
			if log is None:
				# Slot was never assigned; nothing to list for it.
				continue
			for budget, entry in log.items():
				lines.append(f'Budget: {budget}')
				lines.append(f'Loss: {entry["loss"]}')
				lines.append(str(entry.get('hyperparameter')))
		lines.append('Best Hyperparameter Configuration:')
		lines.append(f'Budget: {self.best.get("budget")}')
		lines.append(f'Loss: {self.best["loss"]}')
		lines.append(str(self.best.get('hyperparameter')))
		return '\n'.join(lines)


class BOHB:
	"""Bracketed successive-halving search driven by a good/bad KDE pair.

	``optimize`` runs ``s_max + 1`` brackets. Inside a bracket every rung
	evaluates a batch of configurations at a growing budget, keeps the
	lowest-loss ones for the next rung and, once enough results exist, fits
	one KDE on the best fraction and one on the rest. ``get_sample`` then
	proposes candidates that minimise the bad/good density ratio.

	Args:
		configspace: object providing ``sample_configuration()``.
		evaluate: callable ``evaluate(params_dict, int_budget) -> loss``;
			lower loss is better.
		max_budget: largest budget handed to ``evaluate``.
		min_budget: smallest budget; only the ratio ``max_budget / min_budget``
			is used, to derive the number of brackets.
		eta: reduction factor between rungs.
		best_percent: fraction of a rung's results used to fit the "good" KDE.
		random_percent: probability that ``get_sample`` bypasses the KDE model.
		n_samples: number of candidates scored per model-based draw.
		bw_factor: multiplier applied to a continuous bandwidth when perturbing.
		min_bandwidth: lower bound the fitted KDE bandwidths are clipped to.
		n_proc: when greater than 1, evaluations are dispatched through dask
			with the ``'processes'`` scheduler and this many workers.
	"""

	def __init__(self, configspace, evaluate, max_budget, min_budget,
			eta=3, best_percent=0.15, random_percent=1/3, n_samples=64,
			bw_factor=3, min_bandwidth=1e-3, n_proc=1):
		self.eta = eta
		self.configspace = configspace
		self.max_budget = max_budget
		self.min_budget = min_budget
		self.evaluate = evaluate

		self.best_percent = best_percent
		self.random_percent = random_percent
		self.n_samples = n_samples
		self.min_bandwidth = min_bandwidth
		self.bw_factor = bw_factor
		self.n_proc = n_proc

		self.s_max = self._compute_s_max(
			self.max_budget / self.min_budget, self.eta)
		self.budget = (self.s_max + 1) * self.max_budget

		self.kde_good = None
		self.kde_bad = None
		self.samples = np.array([])

	@staticmethod
	def _compute_s_max(budget_ratio, eta):
		"""Return the largest integer s with ``eta ** s <= budget_ratio``.

		The log quotient alone can land just below an integer for exact
		powers (e.g. log(243) / log(3) evaluates to 4.999...), which would
		truncate one bracket short, so the truncated value is checked against
		the next power and bumped when that power still fits.
		"""
		s_max = int(np.log(budget_ratio) / np.log(eta))
		if s_max >= 0 and eta ** (s_max + 1) <= budget_ratio:
			s_max += 1
		return s_max

	def optimize(self):
		"""Run all brackets and return the results.

		Returns:
			tuple ``(logs, total_budget_used)``: ``logs`` is a ``Log`` holding,
			per bracket and budget, the best loss and its configuration, plus
			the overall best in ``logs.best``; ``total_budget_used`` is the sum
			of the budgets of every evaluation that was issued (serial or
			parallel).
		"""
		print('Lalala ----------------------------------------------------')
		logs = Log(self.s_max+1)
		total_budget_used = 0
		for s in reversed(range(self.s_max + 1)):
			logs[s] = {}
			n = int(np.ceil(
				(self.budget * (self.eta ** s)) / (self.max_budget * (s + 1))))
			r = self.max_budget * (self.eta ** -s)
			# Each bracket starts without a model and without promoted samples.
			self.kde_good = None
			self.kde_bad = None
			self.samples = np.array([])

			for i in range(s+1):
				# NOTE(review): n is overwritten at the end of every rung and
				# re-used here, so rung sizes shrink faster than eta ** -i of
				# the bracket's initial n. Kept as in the original; confirm
				# this is intended.
				n_i = n * self.eta ** (-i)  # Number of configs
				r_i = r * self.eta ** (i)  # Budget

				logs[s][r_i] = {'loss': np.inf}

				samples = []
				losses = []
				for _ in range(n):
					sample = self.get_sample()
					if self.n_proc > 1:
						loss = dask.delayed(self.evaluate)(sample.to_dict(), int(r_i))
					else:
						loss = self.evaluate(sample.to_dict(), int(r_i))
					# Account for the budget in both modes; previously the
					# parallel path never updated the total.
					total_budget_used += r_i
					samples.append(sample)
					losses.append(loss)
				print("n : {}, r_i : {}".format(n, r_i))

				if self.n_proc > 1:
					# Resolve the delayed evaluations in worker processes.
					losses = dask.compute(
						*losses, scheduler='processes', num_workers=self.n_proc)
				midx = np.argmin(losses)
				logs[s][r_i]['loss'] = losses[midx]
				logs[s][r_i]['hyperparameter'] = samples[midx]

				if logs[s][r_i]['loss'] < logs.best['loss']:
					logs.best['loss'] = logs[s][r_i]['loss']
					logs.best['budget'] = r_i
					logs.best['hyperparameter'] = logs[s][r_i]['hyperparameter']

				# next round update

				idxs = np.argsort(losses)

				# Debug trace: promoted samples left over from the previous rung.
				models = [i[0].value for i in self.samples]
				print(models)

				n = int(np.ceil(n_i/self.eta))

				# Promote the n lowest-loss samples to the next rung.
				self.samples = np.array(samples)[idxs[:n]]

				models = [i[0].value for i in self.samples]
				print(models)

				# Fit the KDE pair only when the "good" set has more points
				# than the number of variables plus two.
				n_good = int(np.ceil(self.best_percent * len(samples)))
				if n_good > len(samples[0].kde_vartypes) + 2:
					print('n_good : ', n_good)
					print('kde_var : ', samples[0].kde_vartypes)
					good_data = np.array(samples)[idxs[:n_good]]
					bad_data = np.array(samples)[idxs[n_good:]]
					self.kde_good = KDEMultivariate(good_data)
					self.kde_bad = KDEMultivariate(bad_data)
					self.kde_bad.bw = np.clip(
						self.kde_bad.bw, self.min_bandwidth, None)
					self.kde_good.bw = np.clip(
						self.kde_good.bw, self.min_bandwidth, None)
		print('total budget used : ', total_budget_used)
		return logs, total_budget_used

	def get_sample(self):
		"""Return the next configuration to evaluate.

		Without a fitted model, or with probability ``random_percent``, a
		promoted sample is popped at random from ``self.samples`` if any are
		left; otherwise a fresh one is drawn from the configuration space.

		With a model, ``n_samples`` candidates are generated by perturbing
		randomly chosen "good" configurations, and the one with the smallest
		bad/good density ratio is returned. If no candidate yields a ratio
		below infinity (inf or nan ratios), the first candidate is returned
		instead of failing; if no candidate was generated at all, a fresh
		configuration is drawn from the configuration space.

		Raises:
			NotImplementedError: a hyperparameter is neither continuous nor
				discrete.
		"""
		if self.kde_good is None or np.random.random() < self.random_percent:
			if len(self.samples):
				idx = np.random.randint(0, len(self.samples))
				sample = self.samples[idx]
				self.samples = np.delete(self.samples, idx)
				return sample
			else:
				return self.configspace.sample_configuration()

		# Sample from the good data
		best_tpe_val = np.inf
		best_configuration = None
		for _ in range(self.n_samples):
			idx = np.random.randint(0, len(self.kde_good.configurations))
			configuration = copy.deepcopy(self.kde_good.configurations[idx])
			for hyperparameter, bw in zip(configuration, self.kde_good.bw):
				if hyperparameter.type == cs.Type.Continuous:
					value = hyperparameter.value
					bw = bw * self.bw_factor
					# Truncation bounds correspond to the interval [0, 1]
					# around the current value — assumes values are
					# normalised to that range; TODO confirm.
					hyperparameter.value = scipy.stats.truncnorm.rvs(
						-value/bw, (1-value)/bw, loc=value, scale=bw)
				elif hyperparameter.type == cs.Type.Discrete:
					# Re-draw the choice index with probability bw.
					if np.random.rand() >= (1-bw):
						choice_idx = np.random.randint(len(hyperparameter.choices))
						hyperparameter.value = choice_idx
				else:
					raise NotImplementedError

			tpe_val = (self.kde_bad.pdf(configuration.to_list()) /
					   self.kde_good.pdf(configuration.to_list()))
			if tpe_val < best_tpe_val:
				best_tpe_val = tpe_val
				best_configuration = configuration
			elif best_configuration is None:
				# Ratio was inf/nan; keep the candidate as a fallback so the
				# method always has something to return.
				best_configuration = configuration

		if best_configuration is None:
			# n_samples == 0: nothing was generated, fall back to a random draw.
			return self.configspace.sample_configuration()
		return best_configuration


import sys
sys.path.append("../examples/")
sys.path.append("../")
from data.nasbench2.nats import get_accuracy
#from data.nasbench2.nats import load_test_data

def objective(resources, rank2index = [i for i in range(50)], integer=0, checkpoint=None):
	"""Look up the benchmark accuracy of one candidate at a given budget.

	Args:
		resources: budget forwarded to ``get_accuracy``.
		rank2index: not used by this function.
		integer: candidate identifier forwarded to ``get_accuracy``.
		checkpoint: not used by this function.

	Returns:
		tuple ``(score, integer)`` where ``score`` is the ``get_accuracy``
		result divided by 100 (presumably converting a percentage to a
		fraction — confirm against ``get_accuracy``).
	"""
	raw_accuracy = get_accuracy(integer, resources, is_tss=True)
	score = raw_accuracy / 100
	return score, integer


def evaluate(params, n_iterations):
	"""Return the negated objective score for ``params`` at a capped budget.

	Args:
		params: keyword arguments forwarded to ``objective``.
		n_iterations: requested budget; any value of 200 or more is replaced
			by 199 before the lookup.

	Returns:
		The first element of ``objective``'s result with its sign flipped, so
		that a higher score becomes a lower loss.
	"""
	capped_budget = 199 if n_iterations >= 200 else n_iterations
	result = objective(**params, resources=capped_budget)
	return -result[0]

# Search space: a single categorical hyperparameter over the candidate ids
# 4, 5, 6 and 10..49.
integer = cs.CategoricalHyperparameter('integer', [4, 5, 6, *range(10, 50)])
configspace = cs.ConfigurationSpace([integer])

# Fractions of the 1250 base budget to sweep; identical for every repetition.
ratio_list = [
	0.03, 0.06, 0.09, 0.12, 0.15, 0.18, 0.21, 0.24, 0.27, 0.3, 0.33,
	0.36, 0.39, 0.42, 0.45, 0.48, 0.51, 0.54, 0.57, 0.6, 0.63, 0.66,
	0.69, 0.72, 0.75, 0.78, 0.81, 0.84, 0.87, 0.90, 0.93, 0.96, 1.0,
]

# Ten repetitions of the sweep; best_list collects the winner of each ratio
# within one repetition.
for test_id in range(10):
	print('test id : ', test_id)
	best_list = []
	for budget_ratio in ratio_list:
		max_budget = 1250 * budget_ratio
		# Serial run (n_proc defaults to 1).
		opt = BOHB(configspace, evaluate, max_budget=max_budget, min_budget=1)

		logs, total_budget_used = opt.optimize()
		best_config = logs.best['hyperparameter'][0].value
		print('best hyperparameter :  ', best_config)
		best_list.append(best_config)
		print('best list : ', best_list)
		print('total : ', total_budget_used)
		print('ratio : ', budget_ratio)
		print('/ : ', max_budget)
		print('Budget saving : ', total_budget_used / max_budget)


# Stop here; everything below this call is not executed.
exit()
# NOTE(review): unreachable as written — the exit() call directly above ends
# the process before any of these lines run.
from eval_model_quality_BOHB import test_BOHB

# NOTE(review): load_test_data is not defined or imported anywhere visible in
# this file (its import from data.nasbench2.nats is commented out), so this
# line would raise NameError if it were ever reached.
rank2index = load_test_data(0)
test_BOHB(rank2index)

exit()

from bohb import BOHB
import configspace as cs


def objective(step, alpha, beta):
	"""Toy objective: ``1 / (alpha * step + 0.1) + beta``.

	Args:
		step: iteration index.
		alpha: factor applied to ``step`` in the denominator.
		beta: constant offset added to the result.
	"""
	denominator = alpha * step + 0.1
	return 1 / denominator + beta


def evaluate(params, n_iterations):
	"""Average the toy objective over the first ``int(n_iterations)`` steps.

	Args:
		params: keyword arguments forwarded to ``objective`` alongside ``step``.
		n_iterations: budget; the loop runs ``int(n_iterations)`` times and the
			accumulated value is divided by ``n_iterations`` itself.

	Returns:
		The accumulated objective divided by ``n_iterations``.
	"""
	accumulated = 0.0
	step_count = int(n_iterations)
	for step in range(step_count):
		accumulated += objective(step=step, **params)
	return accumulated / n_iterations


if __name__ == '__main__':
	# Example search space: two categorical hyperparameters, alpha and beta.
	alpha = cs.CategoricalHyperparameter('alpha', [0.001, 0.01, 0.1])
	beta = cs.CategoricalHyperparameter('beta', [1, 2, 3])
	hyperparameters = [alpha, beta]
	configspace = cs.ConfigurationSpace(hyperparameters)

	# Serial optimiser; the original example also showed a parallel variant
	# that passes n_proc=4 to the same constructor.
	opt = BOHB(
		configspace,
		evaluate,
		max_budget=10,
		min_budget=1,
	)

	# Run the search and print whatever optimize() returns.
	logs = opt.optimize()
	print(logs)
