import math
import numpy as np, os, sys, re, glob, subprocess, math, unittest, shutil, time, string, logging, gc
np.set_printoptions(precision=4)
from time import gmtime, strftime
from random import shuffle, choice, sample, choices
import random
from itertools import product
from functools import partial
import inspect
import itertools
from quan_decomp import *
import copy
from typing import List
from collections import OrderedDict, defaultdict


# Create the working folders used to exchange logs, jobs and results with agents.
# Each folder is created independently so that one pre-existing folder does not
# prevent the others from being created.
base_folder = './'
for _folder in ('center_log', 'agent_log', 'agent_pool', 'job_pool', 'result_pool'):
	os.makedirs(base_folder + _folder, exist_ok=True)
del _folder

class Individual:
	"""One candidate QTN sequence plus the settings needed to evaluate it.

	The individual is handed to an agent by writing an ``.npz`` job file
	(``deploy``), and its losses are read back from an ``.npz`` result file
	(``collect``).
	"""

	def __init__(self, qtn_seq=None, scope=None, **kwargs):
		"""
		Args:
			qtn_seq: the candidate sequence to evaluate (saved as-is into the job file).
			scope: unique name of the individual, e.g. ``'<generation>/<index>'``;
				the result file name is derived from it.
			**kwargs: evaluation settings. ``qtn_Q``, ``qtn_R`` and
				``qtn_loss_measure`` are required; ``qtn_evaluate_repeat``
				(default 1), ``qtn_max_iterations`` (default 10000) and
				``qtn_init_std`` (default 0.1) are optional. Other keys are ignored.

		Raises:
			KeyError: if a required setting is missing.
		"""
		self.qtn_seq = qtn_seq
		self.scope = scope

		# decomposition settings forwarded to the agent through the job file
		self.repeat = kwargs.get('qtn_evaluate_repeat', 1)
		self.iters = kwargs.get('qtn_max_iterations', 10000)
		self.qtn_Q = kwargs['qtn_Q']
		self.qtn_R = kwargs['qtn_R']
		self.init_std = kwargs.get('qtn_init_std', 0.1)
		self.loss_measure = kwargs['qtn_loss_measure']

		# agent id this individual was deployed to; None until deploy() is called
		self.sge_job_id = None

	def deploy(self, sge_job_id):
		"""Write this individual's job file for agent ``sge_job_id``.

		Returns:
			True on success. Errors from ``np.savez`` propagate to the caller.
		"""
		path = base_folder+'/job_pool/{}.npz'.format(sge_job_id)
		np.savez(path, qtn_seq=self.qtn_seq, scope=self.scope, repeat=self.repeat, iters=self.iters, qtn_Q=self.qtn_Q, qtn_R=self.qtn_R, init_std=self.init_std, loss_measure=self.loss_measure)
		self.sge_job_id = sge_job_id
		return True

	def collect(self, fake_loss=False):
		"""Try to read this individual's result from the result pool.

		On success ``self.repeat_loss`` is set and the result file is deleted.

		Args:
			fake_loss: if True, do not touch the file system and record a
				placeholder loss of 9999 for every repeat.

		Returns:
			True once ``self.repeat_loss`` has been set, False if the result could
			not be read (yet).
		"""
		if fake_loss:
			self.repeat_loss = [9999]*self.repeat
			return True

		path = base_folder+'/result_pool/{}.npz'.format(self.scope.replace('/', '_'))
		try:
			# Use the NpzFile as a context manager so its file handle is closed
			# before the file is removed.
			with np.load(path) as result:
				self.repeat_loss = result['repeat_loss']
			os.remove(path)
		except Exception:
			# Deliberate best-effort polling: a missing or partially written
			# result file means "not ready yet"; the caller retries later.
			return False
		return True

class GenerationQTN:
	"""One generation of the symmetry-breaking (SB) search over gate removals.

	``best_qtn`` is a dict ``{tensor_idx: array}`` whose columns are gates. Each
	individual is a deep copy of ``best_qtn`` with gate column(s) removed, and
	adjacent duplicate columns created by the removal are merged. Calling the
	generation does three things:

	- it ranks the evaluated individuals;
	- it accepts the first removal, in look-up-block order, whose loss is below
	  ``sb_loss_tolerance``;
	- it updates the search state that the next generation inherits:
	  ``best_qtn``, ``min_loss``, ``sb_loss_list``, ``gates_pool`` and
	  ``total_evaluation_count``.
	"""

	def __init__(self, previous_generation=None, name=None, mode=None, **kwargs):
		"""
		Args:
			previous_generation: generation whose search state is continued. When
				None, the search starts from ``kwargs['sb_init_qtn']`` and
				``kwargs['sb_init_loss']``.
			name: generation name; used as the prefix of the individuals' scopes.
			mode: free-form mode label (only stored).
			**kwargs: search settings.
				- Required: ``qtn_L``, ``qtn_Q``, ``qtn_K`` and ``qtn_R``, plus
				  ``sb_init_qtn`` and ``sb_init_loss`` for the first generation.
				- Optional, with defaults: ``sb_look_up_num`` (1),
				  ``sb_population_size`` (10), ``sb_loss_tolerance`` (0.01) and
				  ``sb_sampling_strategy`` ('random').
				- Optional: ``num_of_tensors``.
				All kwargs are also forwarded to every Individual.
		"""
		self.name = name
		self.mode = mode
		self.kwargs = kwargs
		self.gate_num = kwargs['qtn_L']
		self.qtn_Q = kwargs['qtn_Q']
		self.qtn_K = kwargs['qtn_K']
		self.qtn_R = kwargs['qtn_R']
		self.look_up_num = kwargs.get('sb_look_up_num', 1)
		self.population_size = kwargs.get('sb_population_size', 10)
		self.sb_loss_tolerance = kwargs.get('sb_loss_tolerance', 0.01)
		self.is_pool_full = False
		self.indv_to_collect = []
		self.individuals = {}

		self.sb_sampling_strategy = kwargs.get('sb_sampling_strategy', 'random')
		if previous_generation is not None:
			# continue the search state of the previous generation (shared objects)
			self.best_qtn = previous_generation.best_qtn
			self.min_loss = previous_generation.min_loss
			self.sb_loss_list = previous_generation.sb_loss_list
			self.gates_pool = previous_generation.gates_pool
			self.total_evaluation_count = previous_generation.total_evaluation_count
		else:
			self.initial_qtn = kwargs['sb_init_qtn']
			self.min_loss = kwargs['sb_init_loss']
			self.best_qtn = self.initialize_qtn_sb(self.initial_qtn)
			# one loss entry per gate column of every tensor; every tensor starts
			# as a copy of initial_qtn
			self.sb_loss_list = [self.min_loss] * (self.initial_qtn.shape[1] * len(self.best_qtn))
			self.gates_pool = []
			self.total_evaluation_count = 0
		logging.info('\033[32m===== SB Start with QTN {} and loss {} =====\033[0m'.format(array_to_gate(self.best_qtn), self.min_loss))
		logging.info('\033[32m===== Gates Indexs Pool {} =====\033[0m'.format(self.gates_pool))

		self.gate_list = self.update_gate_list()
		self.population = self.generate_population_sb()

		# one Individual per candidate, scoped as '<name>/<index>'
		self.individuals['indv'] = [Individual(qtn_seq=qtn, scope='{}/{:03d}'.format(self.name, i), **self.kwargs) for i, qtn in enumerate(self.population)]
		self.indv_to_distribute = list(self.individuals['indv'])

	def __call__(self, **kwargs):
		"""Rank the evaluated population, run callbacks, then apply the accepted removal.

		Args:
			**kwargs: optional ``callbacks``, an iterable of callables. Each is
				called with this generation after ranking and before the gate
				bookkeeping is updated.

		Returns:
			True. Exceptions from evaluation or callbacks propagate.
		"""
		self.__evaluate__()
		for callback in kwargs.get('callbacks', ()):
			callback(self)
		self.__evolve_sb__()
		return True

	def __evolve_sb__(self):
		"""Remove the accepted gate's entries from ``sb_loss_list`` and ``gate_list``.

		Does nothing unless ``__evaluate__`` accepted a removal; it empties
		``gates_pool`` when it does. The removal may also have merged two
		adjacent duplicate columns. In that case the loss entries at ``loc-1``
		and ``loc`` (the neighbours of the removed gate after the first pop) are
		merged into one, keeping the smaller loss.
		"""
		if self.gates_pool:
			return

		loc = self.remove_gate_idx[self.first_rank_idx]
		self.sb_loss_list.pop(loc)
		self.gate_list.pop(loc)

		merge_loc = self.merge_locs[self.first_rank_idx]
		if merge_loc is None:
			return
		tensor_idx = merge_loc[0]
		gate_loc = merge_loc[1][0]
		logging.info(f"merge loss {gate_loc} and {gate_loc+1} of tensor {tensor_idx}, in the location of loss_list {loc}")
		# NOTE(review): assumes the merged columns are the removed gate's two
		# neighbours in the same tensor; if loc were 0, loc-1 would wrap to -1.
		self.sb_loss_list[loc-1] = min(self.sb_loss_list[loc-1], self.sb_loss_list[loc])
		self.sb_loss_list.pop(loc)
		self.gate_list.pop(loc)
		logging.info('Current gate_list: {}'.format(self.gate_list))
		logging.info('Current loss_list: {}'.format(self.sb_loss_list))

	def __evaluate__(self):
		"""Rank the population and accept the first removal whose loss is within tolerance.

		Individuals are scanned in consecutive blocks of ``look_up_num``. For
		every scanned individual, this method:

		- adds its removed gate index to ``gates_pool``;
		- records its loss in ``sb_loss_list``;
		- adds one to ``total_evaluation_count``.

		In the first block whose best loss is below ``sb_loss_tolerance``, that
		individual becomes ``best_qtn``, its rank is set to -1 to mark it as
		accepted, ``gates_pool`` is reset and scanning stops. Afterwards
		``first_rank_idx`` is the accepted individual, or else the lowest-loss one.
		"""
		losses = [np.min(indv.repeat_loss) for indv in self.individuals['indv']]
		self.individuals['loss'] = losses
		self.individuals['rank'] = self.get_sorted_positions(losses)

		n_indv = len(losses)
		for start in range(0, n_indv, self.look_up_num):
			block = range(start, min(start + self.look_up_num, n_indv))
			# count only the individuals actually present in a (possibly partial) block
			self.total_evaluation_count += len(block)
			for j in block:
				self.gates_pool.append(self.remove_gate_idx[j])
				self.sb_loss_list[self.remove_gate_idx[j]] = losses[j]
			# first index with the lowest loss in this block
			best_j = min(block, key=lambda j: losses[j])
			if losses[best_j] < self.sb_loss_tolerance:
				self.best_qtn = self.individuals['indv'][best_j].qtn_seq
				self.min_loss = losses[best_j]
				self.individuals['rank'][best_j] = -1
				self.gates_pool = []
				break

		if len(self.gates_pool) >= len(self.gate_list):
			self.is_pool_full = True
			logging.info('Gates pool is full!!!')

		ranks = self.individuals['rank']
		self.first_rank_idx = ranks.index(min(ranks))

	def initialize_qtn_sb(self, qtn_layer) -> dict:
		"""Return ``{tensor_idx: deep copy of qtn_layer}`` for every tensor.

		The tensor count comes from ``kwargs['num_of_tensors']`` when given.
		Otherwise it falls back to the module-level ``num_of_tensors`` set by
		the script entry point.
		"""
		num_tensors = self.kwargs.get('num_of_tensors')
		if num_tensors is None:
			num_tensors = num_of_tensors
		return {i: copy.deepcopy(qtn_layer) for i in range(num_tensors)}

	def generate_population_sb(self):
		"""Build the population: one copy of ``best_qtn`` per sampled gate removal.

		Also sets two attributes:

		- ``self.remove_gate_idx``: the removed index (or indices) into
		  ``gate_list`` for each individual.
		- ``self.merge_locs``: for each individual, ``[tensor_idx,
		  merged_column_indices]`` when the removal produced adjacent duplicate
		  columns that were merged, else None.

		Returns:
			list of dicts ``{tensor_idx: array}``, one per individual.
		"""
		if self.sb_sampling_strategy == 'rank':
			gate_list_with_loss = list(zip(self.gate_list, self.sb_loss_list))
			logging.info('Current gate_list and loss: {}'.format(gate_list_with_loss))

		self.remove_gate_idx = self.generate_remove_gates()

		remove_gates_list = []
		for idx in self.remove_gate_idx:
			if isinstance(idx, np.ndarray):
				# several gates removed at once: group column indices per tensor
				remove_gates_list.append(self.merge_lists_by_first_element([self.gate_list[j] for j in idx]))
			else:
				remove_gates_list.append([self.gate_list[idx]])

		new_population = []
		merge_locs = []
		logging.info(f"'===== Symmetric breaking process ====='")
		for remove_gates in remove_gates_list:
			parent = copy.deepcopy(self.best_qtn)
			logging.info(f"====================================")
			merge_gate_locs = []
			for items in remove_gates:
				tensor_idx = items[0]
				gate_idx = items[1:]
				logging.info(f"remove number {gate_idx} from tensor {tensor_idx}")
				parent[tensor_idx] = np.delete(parent[tensor_idx], gate_idx, axis=1)
				parent[tensor_idx], merge_gate_locs = self.merge_duplicate_gates(parent[tensor_idx])
			# only the merge of the last processed tensor is recorded
			if len(merge_gate_locs) > 0:
				merge_gate_locs_end = [idx+1 for idx in merge_gate_locs]
				logging.info(f"merge gates {merge_gate_locs} and {merge_gate_locs_end} of tensor {tensor_idx}")
				locs = [tensor_idx, merge_gate_locs]
			else:
				locs = None
			new_population.append(parent)
			merge_locs.append(locs)

		self.merge_locs = merge_locs

		return new_population

	def weighted_sample_no_replacement(self, items, weights, k):
		"""Draw ``k`` distinct positions of ``items`` with probability proportional to ``weights``.

		The input sequences are left unmodified. Positions (not values) are
		sampled, so repeated values in ``items`` are handled correctly.

		Returns:
			list of the ``k`` sampled items, in draw order. ``random.choices``
			raises if ``k`` exceeds ``len(items)``.
		"""
		pool = list(items)
		pool_weights = list(weights)
		result = []
		for _ in range(k):
			pos = random.choices(range(len(pool)), weights=pool_weights, k=1)[0]
			result.append(pool.pop(pos))
			pool_weights.pop(pos)
		return result

	def generate_remove_gates(self):
		"""Sample up to ``population_size`` gate indices that are not yet in ``gates_pool``.

		With the ``'random'`` strategy the indices are sampled uniformly. With
		``'rank'`` they are sampled with a fitness that decreases with the
		rank of the gate's recorded loss (lowest loss gets the largest weight).

		Raises:
			ValueError: for an unknown ``sb_sampling_strategy``.
		"""
		in_pool = set(self.gates_pool)
		idx_not_in_pool = [i for i in range(len(self.gate_list)) if i not in in_pool]
		k = min(self.population_size, len(idx_not_in_pool))

		if self.sb_sampling_strategy == 'random':
			remove_gate_idx = random.sample(idx_not_in_pool, k)
		elif self.sb_sampling_strategy == 'rank':
			ranks = self.get_sorted_positions(self.sb_loss_list)
			fitness = [max(0.01, math.log(200/(1e-2+5*r))) for r in ranks]
			fitness_not_in_pool = [fitness[i] for i in idx_not_in_pool]
			remove_gate_idx = self.weighted_sample_no_replacement(idx_not_in_pool, weights=fitness_not_in_pool, k=k)
		else:
			raise ValueError(f"Invalid sampling strategy: {self.sb_sampling_strategy}")

		return remove_gate_idx

	def merge_lists_by_first_element(self, lists):
		"""Group ``[key, *values]`` lists by key: ``[[0, 1], [0, 3]] -> [[0, 1, 3]]``.

		Keys keep their first-seen order.
		"""
		merged_dict = defaultdict(list)
		for lst in lists:
			merged_dict[lst[0]].extend(lst[1:])
		return [[key] + value for key, value in merged_dict.items()]

	def update_gate_list(self):
		"""Return ``[tensor_idx, col_idx]`` for every gate column of every tensor of ``best_qtn``."""
		return [
			[tensor_idx, col_idx]
			for tensor_idx in range(len(self.best_qtn))
			for col_idx in range(self.best_qtn[tensor_idx].shape[1])
		]

	def merge_duplicate_gates(self, seq):
		"""Drop every column that equals the column after it.

		Each run of identical adjacent columns therefore keeps one copy (its
		last column).

		Returns:
			``(new_seq, locs)``, where ``locs`` are the dropped column indices in
			the input ``seq``.
		"""
		locs = [i for i in range(seq.shape[1]-1) if np.array_equal(seq[:, i], seq[:, i+1])]
		if locs:
			seq = np.delete(seq, locs, axis=1)
		return seq, locs

	def get_sorted_positions(self, lst):
		"""Return the rank (0 = smallest) of each element of ``lst``; ties keep input order."""
		positions = [0] * len(lst)
		for rank, (index, _) in enumerate(sorted(enumerate(lst), key=lambda pair: pair[1])):
			positions[index] = rank
		return positions

	def distribute_indv(self, agent):
		"""Hand the next undistributed individual, if any, to ``agent``."""
		if self.indv_to_distribute:
			indv = self.indv_to_distribute.pop(0)
			agent.receive(indv)
			self.indv_to_collect.append(indv)
			logging.info('Assigned individual {} to agent {}.'.format(indv.scope, agent.sge_job_id))

	def collect_indv(self):
		"""Poll every outstanding individual and drop those whose result has arrived."""
		# iterate over a snapshot: removing from the list being iterated skips elements
		for indv in list(self.indv_to_collect):
			if indv.collect():
				logging.info('Collected individual result {}.'.format(indv.scope))
				self.indv_to_collect.remove(indv)

	def is_finished(self):
		"""Return True when every individual has been distributed and collected."""
		return not self.indv_to_distribute and not self.indv_to_collect

class Agent(object):
	"""Handle for a remote worker identified by ``sge_job_id``.

	A job is handed over by writing the individual's job file and then
	appending to the agent's ``agent_pool/<sge_job_id>.POOL`` file.
	``is_available`` treats an empty ``.POOL`` file as idle.
	"""

	def __init__(self, **kwargs):
		"""
		Args:
			**kwargs: must contain ``sge_job_id``. May contain ``target_data``,
				the text appended to the ``.POOL`` file by ``receive``. It defaults
				to the module-level ``target_data`` set by the script entry point.
		"""
		super().__init__()
		self.kwargs = kwargs
		self.sge_job_id = kwargs['sge_job_id']

	def _pool_path(self):
		"""Path of this agent's ``.POOL`` signalling file."""
		return base_folder+'/agent_pool/{}.POOL'.format(self.sge_job_id)

	def receive(self, indv):
		"""Deploy ``indv`` to this agent, then signal it through its ``.POOL`` file."""
		indv.deploy(self.sge_job_id)
		payload = self.kwargs.get('target_data')
		if payload is None:
			payload = target_data
		with open(self._pool_path(), 'a') as f:
			f.write(payload)

	def is_available(self):
		"""Return True when the agent's ``.POOL`` file is empty.

		Raises:
			OSError: if the ``.POOL`` file does not exist.
		"""
		return os.stat(self._pool_path()).st_size == 0


class PipelineSB(object):
	"""Drive the symmetry-breaking search.

	The pipeline creates generations, hands their individuals to agents,
	collects the results and records the best QTN after each evaluated
	generation.
	"""

	def __init__(self, gss_max_iteration=100, sb_max_iteration=10, **kwargs):
		"""
		Args:
			gss_max_iteration: stored as ``max_gss_iteration``; not used by the SB loop.
			sb_max_iteration: number of generations to evaluate before stopping.
			**kwargs: forwarded to every GenerationQTN and to its ``__call__``, so
				a ``callbacks`` list is honoured. Must contain ``sb_init_qtn`` and
				``sb_init_loss``.
		"""
		super().__init__()
		self.kwargs = kwargs
		self.dummy_func = lambda *args, **kwargs: None
		self.max_gss_iteration = gss_max_iteration
		self.max_symbreaking_iteration = sb_max_iteration
		self.current_generation = None
		self.previous_generation = None
		self.N_group_ss = 0
		self.N_symbreaking = 0
		self.generation_list = {}
		self.available_agents = []
		self.best_qtn_list = []
		self.min_loss_list = []
		self.evaluation_count_list = []
		self.known_agents = {}
		self.time = 0
		self.initial_qtn = kwargs['sb_init_qtn']
		self.initial_loss = kwargs['sb_init_loss']

	def __call_with_interval__(self, func, interval):
		"""Return ``func`` when the simulated clock is a multiple of ``interval``, else a no-op."""
		return func if self.time % interval == 0 else self.dummy_func

	def __tik__(self, sec):
		"""Advance the simulated clock by ``sec`` and sleep that long."""
		self.time += sec
		time.sleep(sec)

	def __check_available_agent__(self):
		"""Refresh ``known_agents`` and ``available_agents`` from the ``.POOL`` files on disk.

		Agents whose ``.POOL`` file disappeared are forgotten. Newly seen agents
		are registered but only become available on the next check.
		"""
		self.available_agents.clear()
		agent_files = glob.glob(base_folder+'/agent_pool/*.POOL')
		agents_id = [os.path.basename(path)[:-len('.POOL')] for path in agent_files]

		for aid in list(self.known_agents):
			if aid not in agents_id:
				logging.info('Dead agent id = {} found!'.format(aid))
				self.known_agents.pop(aid, None)

		for aid in agents_id:
			if aid in self.known_agents:
				if self.known_agents[aid].is_available():
					self.available_agents.append(self.known_agents[aid])
			else:
				self.known_agents[aid] = Agent(sge_job_id=aid)
				logging.info('New agent id = {} found!'.format(aid))

	def __assign_job__(self):
		"""Give each currently available agent one individual of the current generation."""
		self.__check_available_agent__()
		for agent in self.available_agents:
			self.current_generation.distribute_indv(agent)

	def __collect_result__(self):
		"""Poll the result pool for the current generation."""
		self.current_generation.collect_indv()

	def __report_agents__(self):
		logging.info('Current number of known agents is {}.'.format(len(self.known_agents)))
		logging.info(list(self.known_agents.keys()))

	def __report_generation__(self):
		logging.info('Current length of indv_to_distribute is {}.'.format(len(self.current_generation.indv_to_distribute)))
		logging.info('Current length of indv_to_collect is {}.'.format(len(self.current_generation.indv_to_collect)))
		logging.info([(indv.scope, indv.sge_job_id) for indv in self.current_generation.indv_to_collect])

	def __chromosome_to_tuple__(self, chromosome: np.ndarray) -> tuple:
		"""Convert a 2-D array into a hashable tuple of its columns."""
		return tuple(tuple(col) for col in chromosome.T)

	def __tuple_to_chromosome__(self, chromosome_tuple: tuple) -> np.ndarray:
		"""Inverse of ``__chromosome_to_tuple__``."""
		return np.array(chromosome_tuple).T

	def __update_all_chromosomes__(self, population: List[np.ndarray], losses: List[float]):
		"""Record the best loss seen per chromosome and keep ``generation_list`` sorted by loss."""
		for chrom, loss in zip(population, losses):
			chrom_tuple = self.__chromosome_to_tuple__(chrom)
			if chrom_tuple not in self.generation_list or loss < self.generation_list[chrom_tuple]:
				self.generation_list[chrom_tuple] = loss

		# sort the OrderedDict by loss values
		self.generation_list = OrderedDict(sorted(self.generation_list.items(), key=lambda x: x[1]))

	def __symbreaking__(self):
		"""Advance the search by at most one generation.

		Returns:
			False once ``max_symbreaking_iteration`` generations have been
			evaluated, True otherwise (the caller keeps polling agents).
		"""
		if self.N_symbreaking > self.max_symbreaking_iteration:
			return False

		if self.N_symbreaking == 0:
			self.current_generation = GenerationQTN(name='symbreaking_init', mode='symmetric_breaking', **self.kwargs)
			self.N_symbreaking += 1

		if self.current_generation.is_finished():
			# all results are in: rank the generation and record its best QTN
			self.current_generation(**self.kwargs)
			self.best_qtn_list.append(self.current_generation.best_qtn)
			self.min_loss_list.append(self.current_generation.min_loss)
			self.evaluation_count_list.append(self.current_generation.total_evaluation_count)
			logging.info('\033[31m===== Best QTN_SEQ of No.{} Symbreaking =====\033[0m'.format(self.N_symbreaking))
			logging.info('\033[31mQtn: {} | Loss: {:.5f}\033[0m'.format(array_to_gate(self.best_qtn_list[-1]), self.min_loss_list[-1]))
			logging.info('\033[31m===== Total number of Evaluations: {}=====\033[0m'.format(self.current_generation.total_evaluation_count))
			self.N_symbreaking += 1
			self.previous_generation = self.current_generation
			# Only build the next generation if it will actually be evaluated.
			# Otherwise its individuals would be dispatched to agents and never collected.
			if not self.current_generation.is_pool_full and self.N_symbreaking <= self.max_symbreaking_iteration:
				self.current_generation = GenerationQTN(self.previous_generation,
														name='symbreaking_{:03d}'.format(self.N_symbreaking), mode='symmetric_breaking', **self.kwargs)

		return True

	def __call__(self):
		"""Run the polling loop until the iteration budget is spent or the gate pool is full."""
		logging.info('\033[32m===== Load from GSS with QTN {} and loss {} =====\033[0m'.format(array_to_gate(self.initial_qtn), self.initial_loss))
		while self.__symbreaking__():
			self.__call_with_interval__(self.__check_available_agent__, 4)()
			self.__call_with_interval__(self.__assign_job__, 4)()
			self.__call_with_interval__(self.__collect_result__, 4)()
			self.__call_with_interval__(self.__report_agents__, 180)()
			self.__call_with_interval__(self.__report_generation__, 160)()
			self.__tik__(2)
			if self.current_generation.is_pool_full:
				break
		logging.info('\033[32m===== SB Finished =====\033[0m')
		logging.info('===== Best QTN and loss for each iteration=====')
		for qtn, loss in zip(self.best_qtn_list, self.min_loss_list):
			logging.info('Qtn: {} | Loss: {:.5f}'.format(array_to_gate(qtn), loss))

def array_to_gate(qtn_seq: np.ndarray) -> List[tuple]:
	"""Convert a gate array, or a dict of them, into lists of integer gate tuples.

	Each column ``arr[:, i]`` becomes one tuple of Python ints.

	Args:
		qtn_seq: a 2-D array whose columns are gates, or a dict mapping keys to
			such arrays.

	Returns:
		For an array: a list of tuples. For a dict: a dict with the same keys
		mapping to such lists.

	Raises:
		TypeError: for any other input type.
	"""
	def _columns(arr):
		# one tuple of Python ints per column
		return [tuple(int(v) for v in arr[:, i]) for i in range(arr.shape[1])]

	if isinstance(qtn_seq, np.ndarray):
		return _columns(qtn_seq)
	if isinstance(qtn_seq, dict):
		return {k: _columns(v) for k, v in qtn_seq.items()}
	raise TypeError('array_to_gate expects an np.ndarray or a dict of them, got {}'.format(type(qtn_seq).__name__))

def gate_to_array(qtn_layer: List[tuple]) -> np.ndarray:
	"""Stack gate tuples as the columns of a 2-D array.

	This is the reverse direction of ``array_to_gate`` for list input.

	Args:
		qtn_layer: sequence of equal-length gate tuples.

	Returns:
		Array whose i-th column is ``qtn_layer[i]``.
	"""
	rows = list(map(list, qtn_layer))
	stacked = np.array(rows)
	return stacked.T

def score_summary(obj):
	"""Log every individual of a ranked generation, grouped into look-up blocks.

	Intended as a GenerationQTN callback. The accepted individual (rank -1) is
	logged in red.

	Args:
		obj: a ranked generation. Reads ``name``, ``look_up_num`` and
			``individuals['indv' | 'loss' | 'rank']``. If ``look_up_num`` is
			missing, falls back to the module-level ``sb_look_up_num`` set by the
			script entry point.
	"""
	look_up_num = getattr(obj, 'look_up_num', None)
	if look_up_num is None:
		look_up_num = sb_look_up_num

	logging.info('\033[32m===== {} =====\033[0m'.format(obj.name))

	for idx, indv in enumerate(obj.individuals['indv']):
		if idx % look_up_num == 0:
			logging.info('-------------------------------------Block {}-------------------------------------------------'.format(int(idx//look_up_num)))
		rank = obj.individuals['rank'][idx]
		line = '{} | repeat_losses: {} | loss: {:.5f} | rank: {} | qtn_seq: {}'.format(indv.scope, [ float('{:0.4f}'.format(l)) for l in indv.repeat_loss ], obj.individuals['loss'][idx], rank, array_to_gate(indv.qtn_seq))
		if rank == -1:
			# highlight the accepted individual
			line = '\033[31m' + line + '\033[0m'
		logging.info(line)

if __name__ == '__main__':
	# torch is used below to load inputs and save results; import it explicitly
	# instead of relying on a re-export from ``from quan_decomp import *``.
	import torch

	# usage: TARGET_DATA GSS_MODE SB_MODE [LOSS_MEASURE]
	if len(sys.argv) < 4:
		sys.exit('usage: {} TARGET_DATA GSS_MODE SB_MODE [LOSS_MEASURE]'.format(sys.argv[0]))

	target_data = sys.argv[1]
	gss_mode = sys.argv[2]
	sb_mode = sys.argv[3]
	loss_measure = sys.argv[4] if len(sys.argv) > 4 else 'total'

	print(f"target_data: {target_data}")
	print(f"gss_mode: {gss_mode}")
	print(f"sb_mode: {sb_mode}")
	print(f"Loss measure: {loss_measure}")

	# load data and information
	data = torch.load(target_data)
	qtn_R = data['qtn_R']
	qtn_Q = data['qtn_Q']
	qtn_K = data['qtn_K']
	mask = data['mask']
	std_init = data['std_init']
	# module-level global read by GenerationQTN
	num_of_tensors = len(data['data'])

	target_data_name = target_data.split('.')[0]

	# gss_mode: '<prefix>_<strategy>_<L>_<population>_<max_iteration>'
	gss_mode_sep = gss_mode.split('_')
	gss_strategy = gss_mode_sep[1]
	qtn_L = int(gss_mode_sep[2])
	ga_population_size = int(gss_mode_sep[3])
	gss_max_iteration = int(gss_mode_sep[4])

	# load gss data
	gss_data = torch.load('gss_results/sim_{}_{}_{}.pt'.format(target_data_name, gss_mode, loss_measure))
	gss_qtn = gss_data['gss_qtn']
	gss_loss = gss_data['gss_loss']

	current_time = strftime("%Y%m%d_%H%M%S", gmtime())

	log_name = 'center_log/sim_{}_{}_{}_{}.log'.format(target_data_name, gss_mode, sb_mode, current_time)
	logging.basicConfig(filename=log_name, filemode='a', level=logging.DEBUG,
						format='%(asctime)s: %(message)s', datefmt='%H:%M:%S')
	console = logging.StreamHandler()
	console.setLevel(logging.INFO)
	formatter = logging.Formatter('%(asctime)s:  %(message)s')
	console.setFormatter(formatter)
	logging.getLogger('').addHandler(console)

	# sb_mode: '<prefix>_<strategy>_<look_up_num>_<population>_<max_iteration>_<loss_tolerance>'
	sb_mode_sep = sb_mode.split('_')
	sb_strategy = sb_mode_sep[1]
	# module-level global read by score_summary
	sb_look_up_num = int(sb_mode_sep[2])
	sb_population_size = int(sb_mode_sep[3])
	sb_max_iteration = int(sb_mode_sep[4])
	# never require a loss tighter than 2% above the GSS starting loss
	sb_loss_tolerance = max(gss_loss*1.02, float(sb_mode_sep[5]))

	if sb_population_size % sb_look_up_num != 0:
		raise ValueError('sb_population_size must be divisible by sb_look_up_num')

	logging.info('Load gss data from gss_results/sim_{}_{}_{}.pt'.format(target_data_name, gss_mode, loss_measure))
	logging.info(f'SB paramsettings: sb_strategy: {sb_strategy}, sb_look_up_num: {sb_look_up_num}, sb_population_size: {sb_population_size}, sb_max_iteration: {sb_max_iteration}, sb_loss_tolerance: {sb_loss_tolerance}')

	sb_pipeline = PipelineSB(
		# QTN params
		qtn_R=qtn_R, qtn_Q=qtn_Q, qtn_K=qtn_K, qtn_L=qtn_L, qtn_max_iterations=1500, qtn_evaluate_repeat=10, qtn_init_std=std_init, qtn_loss_measure=loss_measure,
		# SB params
		sb_init_qtn=gss_qtn, sb_init_loss=gss_loss, sb_max_iteration = sb_max_iteration, sb_population_size= sb_population_size, sb_loss_tolerance = sb_loss_tolerance, sb_sampling_strategy = sb_strategy, sb_look_up_num=sb_look_up_num,
		callbacks=[score_summary]
		)
	sb_pipeline()

	# create the results folder at the same (cwd-relative) location that is saved to below
	os.makedirs('sb_results', exist_ok=True)

	torch.save({
		'sb_best_qtn_list': sb_pipeline.best_qtn_list,
		'sb_min_loss_list': sb_pipeline.min_loss_list,
		'sb_evaluation_count_list': sb_pipeline.evaluation_count_list,
		'total_evaluation_count': sb_pipeline.current_generation.total_evaluation_count
	}
	, 'sb_results/sim_{}_{}_{}_{}.pt'.format(target_data_name, gss_mode, sb_mode, loss_measure)
	)