
"""Evaluate hyperparameter-search strategies (successive halving, successive discarding, Hyperband, BOHB) on benchmark accuracies provided by data.nasbench2.nats."""

import os
import sys
import time
import glob
import torch
import logging
import argparse
import numpy as np
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from timeit import default_timer as timer
import torchvision.transforms as transforms

sys.path.append("../models")
sys.path.append("../../")
sys.path.append("../")
import utils

# Command-line options. The text is passed as ``description`` (not as the
# first positional argument, which argparse treats as the program name).
parser = argparse.ArgumentParser(description="Evaluating a model")
parser.add_argument('--model_path', type=str, default='../pre_trained_models/squeeze_complex_bypass.pt', help='path to the model file')
parser.add_argument('--checkpoint_path', type=str, default='../EXP', help='checkpoint and logging directory')
parser.add_argument('--dataset_path', type=str, default='../../data', help='dataset directory')
parser.add_argument('--model_name', type=str, default='SuccessiveDiscarding', help='name of model')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
args = parser.parse_args()

# Logging settings: utils.create_exp_dir is called on the base checkpoint
# directory and on a per-run subdirectory named <model_name>-<timestamp>.
# INFO-level records go both to stdout and to log.txt in that subdirectory.
utils.create_exp_dir(args.checkpoint_path)
args.checkpoint_path = '{}/{}-{}'.format(args.checkpoint_path, args.model_name, time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.checkpoint_path)
log_format = '%(asctime)s %(message)s'
log_datefmt = '%m/%d %I:%M:%S %p'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
	format=log_format, datefmt=log_datefmt)
fh = logging.FileHandler(os.path.join(args.checkpoint_path, 'log.txt'))
# Use the same timestamp format as the console handler so both logs agree.
fh.setFormatter(logging.Formatter(log_format, datefmt=log_datefmt))
logging.getLogger().addHandler(fh)


from data.nasbench2.nats import prepare_dataset
from data.nasbench2.nats import get_accuracy

def load_test_data(test_id=0):
	"""Load test data for one test id.

	Args:
		test_id: Identifier handed unchanged to ``prepare_dataset``.

	Returns:
		Whatever ``prepare_dataset(test_id)`` returns; callers in this file
		use it as ``rank2index``.
	"""
	return prepare_dataset(test_id)


from baseline_SH import successive_halving
print("***************************************")
from hpo import successive_discarding
from baseline_HB import hyperband

from skopt.space import Real, Integer

# Evaluation
def test_SH(rank2index=None):
	"""Run successive halving and return its raw result.

	Args:
		rank2index: Mapping forwarded to ``get_accuracy``. ``None`` (the
			default) means ``list(range(50))``, built fresh on every call so
			that calls never share one list object.

	Returns:
		``(scores, hyperparameters)`` as returned by ``successive_halving``.
	"""
	if rank2index is None:
		rank2index = list(range(50))

	def objective(resources, rank2index=rank2index, integer=0, checkpoint=None):
		# NOTE(review): unlike test_SD / test_HB / test_BOHB this objective
		# does not divide the accuracy by 100 — confirm that is intended.
		return get_accuracy(integer, resources, is_tss=True, rank2index=rank2index), integer

	# NOTE(review): this space has only 5 values although n_models=50 — confirm.
	synthesized_dimensions = [Integer(0, 4, name='integer')]

	scores, hyperparameters = successive_halving(
			objective=objective,
			dimensions=synthesized_dimensions,
			max_resources_per_model=80,
			downsample=2,
			initial_resources=5,
			n_models=50,
			random_seed=None,
			progress_bar=True)

	return scores, hyperparameters

def test_SD(rank2index=None, budget_ratio=1.0):
	"""Run successive discarding and return its raw result.

	Args:
		rank2index: Mapping forwarded to ``get_accuracy``. ``None`` (the
			default) means ``list(range(50))``, built fresh on every call so
			that calls never share one list object.
		budget_ratio: Multiplier applied to the total budget of ``250 * 5``.

	Returns:
		``(scores, hyperparameters)`` as returned by ``successive_discarding``.
	"""
	if rank2index is None:
		rank2index = list(range(50))

	def objective(resources, rank2index=rank2index, integer=0, checkpoint=None):
		# Divided by 100, presumably so scores share the scale of
		# threshold=0.9 below — TODO confirm get_accuracy returns percent.
		return get_accuracy(integer, resources, is_tss=True, rank2index=rank2index) / 100, integer

	synthesized_dimensions = [Integer(0, 4, name='integer')]

	scores, hyperparameters = successive_discarding(
			objective = objective,
			dimensions = synthesized_dimensions,
			max_resources_per_round = 250,
			total_budgets = 250*5 * budget_ratio,
			threshold = 0.9,
			initial_resources=3,
			n_models = 50,
			random_seed=None,
			progress_bar=True,
			test=True)

	return scores, hyperparameters

def test_HB(rank2index=None, budget_ratio=1.0):
	"""Run Hyperband and report the best configuration it evaluated.

	Args:
		rank2index: Mapping forwarded to ``get_accuracy``. ``None`` (the
			default) means ``list(range(50))``, built fresh on every call so
			that calls never share one list object.
		budget_ratio: Multiplier applied to the total budget of ``250 * 5``.

	Returns:
		``(best_score, best_integer)``: the highest score returned by
		``hyperband`` and the ``'integer'`` entry of the configuration that
		produced it.
	"""
	if rank2index is None:
		rank2index = list(range(50))

	def objective(resources, rank2index=rank2index, integer=0, checkpoint=None):
		return get_accuracy(integer, resources, is_tss=True, rank2index=rank2index) / 100, integer

	synthesized_dimensions = [Integer(0, 49, name='integer')]

	scores, hyperparameters = hyperband(
			objective = objective,
			dimensions = synthesized_dimensions,
			max_resources_per_model = 80,
			total_resources = 250*5 * budget_ratio,
			random_seed=None,
			progress_bar=True)

	# Pair every score with its configuration, best score first.
	results = sorted(zip(scores, hyperparameters), key = lambda k: -k[0])

	print(results)
	best_score = results[0][0]
	best_integer = results[0][1]['integer']

	logging.info('scores: {}'.format(best_score))
	logging.info('hyperparameters: {}'.format(best_integer))

	return best_score, best_integer

from bohb import BOHB
import configspace as cs


def test_BOHB(rank2index, budget_ratio=1.0):
	"""Run BOHB over a restricted set of ``integer`` values.

	Args:
		rank2index: Mapping forwarded to ``get_accuracy``, as in the other
			``test_*`` runners, so each test id is evaluated with its own
			mapping.
		budget_ratio: Scales BOHB's maximum budget
			(``1250 * budget_ratio * 0.05``).

	Returns:
		The ``integer`` value of the best configuration, read from
		``logs[0].best['hyperparameter'][0].value``.
	"""
	def objective(resources, rank2index=rank2index, integer=0, checkpoint=None):
		# Forward rank2index like test_SD / test_HB; it was previously
		# accepted here but never used.
		return get_accuracy(integer, resources, is_tss=True, rank2index=rank2index) / 100, integer

	def evaluate(params, n_iterations):
		# Budgets of 200 or more are clamped to 199 before evaluation.
		if n_iterations >= 200:
			n_iterations = 199
		# BOHB minimises its objective, so return the negated accuracy.
		return -objective(**params, resources=n_iterations)[0]

	# Candidate values: 4, 5, 6 and 10..49 (0-3 and 7-9 are excluded).
	integer = cs.CategoricalHyperparameter('integer', [i for i in range(50) if i >= 10 or i in (4, 5, 6)])
	configspace = cs.ConfigurationSpace([integer])

	opt = BOHB(configspace, evaluate, max_budget=1250 * budget_ratio * 0.05, min_budget=1)
	# Parallel alternative:
	# opt = BOHB(configspace, evaluate, max_budget=10, min_budget=1, n_proc=4)

	logs = opt.optimize()

	best_config = logs[0].best['hyperparameter'][0].value
	print('best hyperparameter :  ', best_config)

	return best_config



def test(rank2index):
	"""Compare successive halving with successive discarding.

	Both strategies are run with the same ``rank2index``. The halving result
	is echoed to stdout, and both results are written to the log.
	"""
	sh_scores, sh_params = test_SH(rank2index)
	sd_scores, sd_params = test_SD(rank2index)

	print("print testing !--!  SH --- SD")
	print(str(sh_scores))

	halving_msg = 'Successive Halving : {}{}'.format(str(sh_scores), str(sh_params))
	logging.info(halving_msg)
	discarding_msg = 'Successive Discarding : {}{}'.format(sd_scores, sd_params)
	logging.info(discarding_msg)

	print(sh_scores, sh_params)

def main():
	"""Run BOHB once per test id and report the selected configurations.

	For every test id in ``range(10, 100)``, this loads the mapping with
	``load_test_data``, runs ``test_BOHB`` on it, and logs the running list
	of results. Each selected value is then converted to its 1-based
	position among BOHB's candidates (4, 5, 6, 10..49; so 4 -> 1 and
	10 -> 4). The positions are printed first and then logged.
	"""
	logging.info('Starting...')

	results = []
	for test_id in range(10, 100):
		logging.info('test id : {}'.format(test_id))

		rank2index = load_test_data(test_id)
		print(rank2index)

		results.append(test_BOHB(rank2index))
		logging.info('current results : {}'.format(results))

	# Close the gaps left by the excluded values 0-3 and 7-9.
	positions = [config - 3 if config < 10 else config - 6 for config in results]

	# Two separate passes keep the original output order: all prints first,
	# then all log records (logging also writes to stdout).
	for position in positions:
		print(position)
	for position in positions:
		logging.info(position)


# Separator line. It sits at module level, so it is printed whenever this
# file is executed or imported, not only when run as a script.
print("***************************************")


# Script entry point: run the evaluation only when executed directly.
if __name__ == '__main__':
    main()
