import copy

import numpy as np
import scipy
import statsmodels.api as sm

#import bohb.configspace as cs
import configspace as cs
import dask

from utils_SD import get_sigma
from utils_SD import get_next_round_n_model
from utils_SD import get_pair_wise_confidence


class KDEMultivariate(sm.nonparametric.KDEMultivariate):
	"""statsmodels multivariate KDE built directly from configuration objects.

	The original configurations are kept on ``self.configurations`` so that
	callers can map a density estimate back to the objects it was fitted on.
	"""

	def __init__(self, configurations):
		"""Fit the estimator on ``configurations``.

		Args:
			configurations: non-empty sequence of objects exposing
				``to_list()`` (one row of KDE input per configuration) and
				``kde_vartypes`` (read from the first element and passed on
				as the variable-type string).
		"""
		self.configurations = configurations
		rows = [np.array(entry.to_list()) for entry in configurations]
		matrix = np.array(rows)
		# Variable types are taken from the first configuration only.
		var_types = configurations[0].kde_vartypes
		# Third positional argument selects the 'normal_reference' bandwidth rule.
		super().__init__(matrix, var_types, 'normal_reference')


class Log:
	"""Fixed-size record of per-bracket results plus the overall best entry.

	``logs[s]`` is expected to be a dict mapping a budget to
	``{'loss': ..., 'hyperparameter': ...}``. ``best`` starts as
	``{'loss': inf}`` and gains ``'budget'`` / ``'hyperparameter'`` once a
	result has been recorded by the caller.
	"""

	def __init__(self, size):
		"""Create ``size`` empty slots.

		Args:
			size: number of bracket slots to allocate.
		"""
		self.size = size
		# dtype=dict produces an object array; unassigned slots hold None.
		self.logs = np.empty(self.size, dtype=dict)
		self.best = {'loss': np.inf}

	def __getitem__(self, index):
		"""Return the entry stored in slot ``index``."""
		return self.logs[index]

	def __setitem__(self, index, value):
		"""Store ``value`` in slot ``index``."""
		self.logs[index] = value

	def __repr__(self):
		"""Render every slot and the best entry, one field per line.

		Slots that were never assigned (still ``None``) or are empty only
		print their ``s: <index>`` header, and missing optional keys are
		rendered as ``None`` instead of raising.
		"""
		lines = [f's_max: {self.size}']
		for s, log in enumerate(self.logs):
			lines.append(f's: {s}')
			if not log:
				# Slot was never filled (None) or holds no budgets yet.
				continue
			for budget, entry in log.items():
				lines.append(f'Budget: {budget}')
				lines.append(f'Loss: {entry["loss"]}')
				lines.append(str(entry.get('hyperparameter')))
		lines.append('Best Hyperparameter Configuration:')
		lines.append(f'Budget: {self.best.get("budget")}')
		lines.append(f'Loss: {self.best["loss"]}')
		lines.append(str(self.best.get('hyperparameter')))
		return '\n'.join(lines)


class BOHB:
	"""Sub-sampling search over integer-indexed candidates with BOHB-style knobs.

	``optimize`` only runs the first bracket (``s == s_max``): every candidate
	is identified by an integer ``j`` and scored through
	``evaluate({'integer': j}, step)``; lower losses are better. ``get_sample``
	keeps the KDE-based sampling routine, although ``optimize`` resets the
	KDEs before calling it.
	"""

	def __init__(self, configspace, evaluate, max_budget, min_budget,
				 eta=3, best_percent=0.15, random_percent=1/3, n_samples=64,
				 bw_factor=3, min_bandwidth=1e-3, n_proc=1):
		"""Store the settings and derive the bracket schedule.

		Args:
			configspace: object providing ``sample_configuration()``.
			evaluate: callable ``evaluate(params, n_iterations)`` returning a loss.
			max_budget: largest budget of the schedule.
			min_budget: smallest budget of the schedule.
			eta: reduction factor between successive budgets.
			best_percent: stored only; not read by this class.
			random_percent: probability that ``get_sample`` ignores the KDEs.
			n_samples: number of KDE candidates drawn by ``get_sample``.
			bw_factor: multiplier applied to the KDE bandwidth of continuous
				hyperparameters in ``get_sample``.
			min_bandwidth: stored only; not read by this class.
			n_proc: stored only; not read by this class.
		"""
		self.eta = eta
		self.configspace = configspace
		self.max_budget = max_budget
		self.min_budget = min_budget
		self.evaluate = evaluate

		self.best_percent = best_percent
		self.random_percent = random_percent
		self.n_samples = n_samples
		self.min_bandwidth = min_bandwidth
		self.bw_factor = bw_factor
		self.n_proc = n_proc

		# Number of eta-fold reductions that fit between min and max budget.
		self.s_max = int(np.log(self.max_budget/self.min_budget) / np.log(self.eta))
		self.budget = (self.s_max + 1) * self.max_budget

		self.kde_good = None
		self.kde_bad = None
		self.samples = np.array([])

	@staticmethod
	def _is_excluded(index):
		"""Return True for candidate ids that are never evaluated (0-3 and 7-9)."""
		return index < 10 and index not in (4, 5, 6)

	@staticmethod
	def _select_preferred(boss_losses, boss_epochs, idxc, threshold_ss):
		"""Pick the candidates that should receive budget in the next round.

		Args:
			boss_losses: per-candidate list of all losses observed so far.
			boss_epochs: per-candidate number of evaluations performed so far.
			idxc: index of the leader, i.e. the candidate with the most epochs.
			threshold_ss: confidence below which a candidate is dropped.

		Returns:
			List of candidate indices, in increasing order.
		"""
		# Confidence above which a candidate is always granted more budget.
		upper_threshold_ss = 0.95
		most_epochs = boss_epochs[int(idxc)]
		leader_losses = boss_losses[idxc]
		preferred = []
		for j, this_epochs in enumerate(boss_epochs):
			if this_epochs == 0:
				# Never evaluated (excluded candidate).
				continue

			if this_epochs < np.sqrt(np.log(most_epochs)):
				# Too few observations compared with the leader: explore it.
				preferred.append(j)
				continue

			# Pairwise confidence between the leader's and this candidate's latest loss.
			pairwise_confidence = get_pair_wise_confidence(
				leader_losses[boss_epochs[idxc] - 1], get_sigma(boss_epochs[idxc]),
				boss_losses[j][this_epochs - 1], get_sigma(this_epochs))

			if pairwise_confidence > upper_threshold_ss:
				preferred.append(j)
				continue

			if pairwise_confidence < threshold_ss:
				continue

			if this_epochs >= most_epochs:
				continue

			# Compare the candidate's mean loss with every window of the same
			# length taken from the leader's loss history.
			this_average = np.mean(boss_losses[j])
			beats_some_window = any(
				this_average < np.mean(leader_losses[k : k + this_epochs])
				for k in range(most_epochs - this_epochs + 1))
			if beats_some_window:
				print('more potential : ', j)
				preferred.append(j)
		return preferred

	def optimize(self, threshold_ss = 0.5):
		"""Run the first bracket of the sub-sampling schedule.

		Every non-excluded candidate is first evaluated for ``int(r * eta)``
		steps. Then, for each of the ``s`` rounds, the candidates returned by
		``_select_preferred`` share ``r_i`` further evaluations equally (the
		leader receives all of them when that set is empty), and the smallest
		latest loss is logged.

		Args:
			threshold_ss: lower pairwise-confidence threshold; candidates
				scoring below it get no extra budget in that round.

		Returns:
			Tuple ``(best, total_budget_used)``. ``best`` is
			``{'Integer': position}`` where ``position`` is the index of the
			lowest latest loss among the evaluated candidates (excluded ids are
			not counted, so it is a position, not the raw candidate id).

		Raises:
			KeyError: if no round was run, so no best entry was recorded.
		"""
		print('Lalala ----------------------------------------------------')
		print('self.max_budget', self.max_budget)
		logs = Log(self.s_max+1)
		# NOTE(review): nothing below ever adds to this counter, so the second
		# element of the returned tuple is always 0 -- confirm the intended accounting.
		total_budget_used = 0
		for s in reversed(range(self.s_max + 1)):
			logs[s] = {}
			n = int(np.ceil(
				(self.budget * (self.eta ** s)) / (self.max_budget * (s + 1))))
			# At most 50 candidates are considered per bracket.
			n = min(n, 50)

			# Starting budget of this bracket (sub-sampling instead of boss).
			r = self.max_budget * (self.eta ** -s)
			print('!----------------------------- r = ', r)
			self.kde_good = None
			self.kde_bad = None
			self.samples = np.array([])

			# boss_losses[j]: every loss observed for candidate j so far;
			# boss_epochs[j]: number of evaluations candidate j has received.
			boss_losses = [[] for _ in range(n)]
			boss_epochs = [0] * n

			# Only printed; the KDEs were just reset, so this is drawn from the config space.
			sample_test = self.get_sample()
			print(sample_test.to_dict())

			# Initial phase: give every evaluated candidate the same number of steps.
			b = int(r * self.eta)
			for j in range(n):
				if self._is_excluded(j):
					continue
				for plus in range(b):
					boss_losses[j].append(self.evaluate({'integer' : j}, plus + 1))
				boss_epochs[j] = b

			print('first completed!')

			for i in range(1, s+1):
				print('round number = :', i)
				print(boss_losses)
				r_i = r * self.eta ** (i)  # Budget distributed in this round
				print("r_i = !!!!!!!", r_i)

				logs[s][r_i] = {'loss': np.inf}

				print('boss epochs :' , boss_epochs)
				# Leader: the candidate that has been evaluated the longest.
				idxc = np.argmax(boss_epochs)
				print('idxc : ', idxc)

				preference_set = self._select_preferred(
					boss_losses, boss_epochs, idxc, threshold_ss)

				print('idxc = ', idxc)
				print('preference set : ', preference_set)

				# Split the round's budget equally; the leader takes everything
				# when no candidate qualified.
				targets = preference_set or [idxc]
				share = int(r_i / len(targets))
				for idx in targets:
					for plus in range(share):
						boss_losses[idx].append(self.evaluate({'integer' : idx}, boss_epochs[idx] + plus + 1))
					boss_epochs[idx] += share

				# Latest loss of every evaluated candidate, in candidate order.
				losses = [boss_losses[j][boss_epochs[j] - 1]
						  for j in range(n) if not self._is_excluded(j)]

				print('losses : ', losses)
				midx = np.argmin(losses)
				logs[s][r_i]['loss'] = losses[midx]
				logs[s][r_i]['hyperparameter'] = {'Integer' : midx}

				if logs[s][r_i]['loss'] < logs.best['loss']:
					logs.best['loss'] = logs[s][r_i]['loss']
					logs.best['budget'] = r_i
					logs.best['hyperparameter'] = logs[s][r_i]['hyperparameter']

			break # for sub-sampling instead of boss: only the first bracket is run

		print(logs.best['hyperparameter'])
		return logs.best['hyperparameter'], total_budget_used

	def get_sample(self):
		"""Return the next configuration to try.

		Without a fitted ``kde_good``, or with probability ``random_percent``,
		a queued entry of ``self.samples`` is popped at random if any exist,
		otherwise the config space is sampled. In the remaining cases
		``n_samples`` perturbed copies of good configurations are generated
		and the one with the smallest ``kde_bad.pdf / kde_good.pdf`` ratio is
		returned. If no candidate yields a ratio below infinity, a random
		configuration from the config space is returned instead.
		"""
		if self.kde_good is None or np.random.random() < self.random_percent:
			if len(self.samples):
				idx = np.random.randint(0, len(self.samples))
				sample = self.samples[idx]
				self.samples = np.delete(self.samples, idx)
				return sample
			else:
				return self.configspace.sample_configuration()

		# Sample from the good data
		best_tpe_val = np.inf
		best_configuration = None
		for _ in range(self.n_samples):
			idx = np.random.randint(0, len(self.kde_good.configurations))
			configuration = copy.deepcopy(self.kde_good.configurations[idx])
			for hyperparameter, bw in zip(configuration, self.kde_good.bw):
				if hyperparameter.type == cs.Type.Continuous:
					value = hyperparameter.value
					bw = bw * self.bw_factor
					# Normal around the current value, truncated to [0, 1].
					hyperparameter.value = scipy.stats.truncnorm.rvs(
						-value/bw, (1-value)/bw, loc=value, scale=bw)
				elif hyperparameter.type == cs.Type.Discrete:
					if np.random.rand() >= (1-bw):
						choice_idx = np.random.randint(len(hyperparameter.choices))
						hyperparameter.value = choice_idx
				else:
					raise NotImplementedError

			tpe_val = (self.kde_bad.pdf(configuration.to_list()) /
					   self.kde_good.pdf(configuration.to_list()))
			if tpe_val < best_tpe_val:
				best_tpe_val = tpe_val
				best_configuration = configuration

		if best_configuration is None:
			# Every ratio was inf or NaN, so no candidate was selected.
			return self.configspace.sample_configuration()
		return best_configuration


# Flip to True to run the benchmark experiment below when this module is loaded.
local_test = False
if local_test:
	import sys
	sys.path.append("../examples/")
	sys.path.append("../")
	from data.nasbench2.nats import get_accuracy

	def objective(resources, rank2index=None, integer=0, checkpoint=None):
		"""Return ``(accuracy / 100, integer)`` for candidate ``integer``.

		``rank2index`` and ``checkpoint`` are accepted but not used.
		"""
		return get_accuracy(integer, resources, is_tss=True) / 100, integer

	def evaluate(params, n_iterations):
		"""Return the negated accuracy so that lower is better."""
		# Requests for 200 or more iterations are answered with iteration 199
		# (presumably the last one the benchmark stores -- TODO confirm).
		if n_iterations >= 200:
			n_iterations = 199
		return -objective(**params, resources=n_iterations)[0]

	# Same candidate ids that BOHB.optimize evaluates: 4, 5, 6 and 10..49.
	integer = cs.CategoricalHyperparameter(
		'integer', [i for i in range(50) if i >= 10 or i in (4, 5, 6)])
	configspace = cs.ConfigurationSpace([integer])

	ratio_list = [0.03, 0.06, 0.09, 0.12, 0.15, 0.18, 0.21, 0.24, 0.27, 0.3, 0.33, 0.36, 0.39, 0.42, 0.45, 0.48, 0.51, 0.54, 0.57, 0.6, 0.63, 0.66, 0.69, 0.72, 0.75, 0.78, 0.81, 0.84, 0.87, 0.90, 0.93, 0.96, 1.0]
	budget_ratio = 0.1
	for test_id in range(10):
		print('test id : ', test_id)
		best_list = []
		for ratio in ratio_list:
			budget_ratio = ratio
			opt = BOHB(configspace, evaluate, max_budget=1250 * budget_ratio, min_budget=1)

			logs, total_budget_used = opt.optimize()
			best_config = logs['Integer'] + 1
			print('best hyperparameter :  ', best_config)
			best_list.append(best_config)
			print('best list : ', best_list)
			print('Budget saving : ', total_budget_used / (1250 * budget_ratio))

		# Only the first repetition is run.
		break

	sys.exit()

'''
from eval_model_quality_BOHB import test_BOHB

rank2index = load_test_data(0)
test_BOHB(rank2index)

exit()

from bohb import BOHB
import configspace as cs


def objective(step, alpha, beta):
	return 1 / (alpha * step + 0.1) + beta


def evaluate(params, n_iterations):
	loss = 0.0
	for i in range(int(n_iterations)):
		loss += objective(**params, step=i)
	return loss/n_iterations


if __name__ == '__main__':
	alpha = cs.CategoricalHyperparameter('alpha', [0.001, 0.01, 0.1])
	beta = cs.CategoricalHyperparameter('beta', [1, 2, 3])
	configspace = cs.ConfigurationSpace([alpha, beta])

	opt = BOHB(configspace, evaluate, max_budget=10, min_budget=1)

	# Parallel
	# opt = BOHB(configspace, evaluate, max_budget=10, min_budget=1, n_proc=4)

	logs = opt.optimize()
	print(logs)
'''
