import math
import numpy as np, os, sys, re, glob, subprocess, math, unittest, shutil, time, string, logging, gc
np.set_printoptions(precision=4)
from time import gmtime, strftime
from random import shuffle, choice, sample, choices
import random
from itertools import product
from functools import partial
import inspect
import itertools
from quan_decomp import *
import copy
from typing import List
from collections import OrderedDict


# Create the working folders used to exchange jobs and results with the agents.
# Each folder is created independently, so one pre-existing folder does not
# prevent the remaining ones from being created.
base_folder = './'
for _sub_folder in ('center_log', 'agent_log', 'agent_pool', 'job_pool', 'result_pool'):
	try:
		os.makedirs(os.path.join(base_folder, _sub_folder), exist_ok=True)
	except OSError as err:
		# Best effort: keep importing, but make the failure visible.
		print('Could not create folder {}: {}'.format(os.path.join(base_folder, _sub_folder), err), file=sys.stderr)

class Individual:
	"""One candidate gate sequence together with its evaluation settings.

	:meth:`deploy` writes the job description to ``job_pool/<sge_job_id>.npz``
	and :meth:`collect` reads the evaluation result back from
	``result_pool/<scope with '/' replaced by '_'>.npz``.
	"""

	def __init__(self, qtn_seq=None, scope=None, **kwargs):
		"""
		Args:
			qtn_seq: gate-sequence array to be evaluated.
			scope: unique name of the individual; with '/' replaced by '_' it
				is the file name the result is expected under.
			**kwargs: must contain 'qtn_Q', 'qtn_R' and 'qtn_loss_measure';
				'qtn_evaluate_repeat' (default 1), 'qtn_max_iterations'
				(default 10000) and 'qtn_init_std' (default 0.1) are optional.

		Raises:
			KeyError: if a required key is missing from ``kwargs``.
		"""
		self.qtn_seq = qtn_seq

		# decomposition arguments forwarded to the evaluating agent
		self.scope = scope
		self.repeat = kwargs.get('qtn_evaluate_repeat', 1)
		self.iters = kwargs.get('qtn_max_iterations', 10000)
		self.qtn_Q = kwargs['qtn_Q']
		self.qtn_R = kwargs['qtn_R']
		self.init_std = kwargs.get('qtn_init_std', 0.1)
		self.loss_measure = kwargs['qtn_loss_measure']

	def deploy(self, sge_job_id):
		"""Write this individual's job to ``job_pool/<sge_job_id>.npz``.

		Returns:
			True once the file is written; errors from ``np.savez`` propagate.
		"""
		path = base_folder + '/job_pool/{}.npz'.format(sge_job_id)
		np.savez(path, qtn_seq=self.qtn_seq, scope=self.scope, repeat=self.repeat, iters=self.iters, qtn_Q=self.qtn_Q, qtn_R=self.qtn_R, init_std=self.init_std, loss_measure=self.loss_measure)
		self.sge_job_id = sge_job_id
		return True

	def collect(self, fake_loss=False):
		"""Load ``repeat_loss`` from the result pool and delete the result file.

		Args:
			fake_loss: if True, skip the file and use a loss of 9999 per repeat.

		Returns:
			True if ``repeat_loss`` was set, False if the result could not be
			read or removed (e.g. it does not exist yet); callers poll again.
		"""
		if fake_loss:
			self.repeat_loss = [9999] * self.repeat
			return True
		try:
			path = base_folder + '/result_pool/{}.npz'.format(self.scope.replace('/', '_'))
			# Close the archive before deleting it; np.load keeps the file open.
			with np.load(path) as result:
				self.repeat_loss = result['repeat_loss']
			os.remove(path)
		except Exception:
			# Deliberately broad: a missing or half-written result just means
			# "not ready yet".
			return False
		return True

class GenerationQTN:
	"""One generation of the QTN structure search.

	A generation owns a population of gate-sequence arrays of shape
	``(qtn_K, qtn_L)`` (one gate per column), wraps each of them in an
	:class:`Individual`, hands the individuals out to agents, collects their
	losses and then advances the configured search strategy
	(``gss_mode``: 'ga', 'ale', 'random' or 'greedy').
	"""

	def __init__(self, previous_generation=None, name=None, mode=None, **kwargs):
		"""Build the population of this generation.

		Args:
			previous_generation: the generation that just finished, or None
				for the first one. Strategy state (evolved population, best
				QTN, step counters, pool of tried chromosomes) is taken from it.
			name: generation name; prefix of every individual's scope.
			mode: free-form label stored as ``self.mode``.
			**kwargs: search configuration. 'qtn_L', 'qtn_Q', 'qtn_K' and
				'qtn_R' are required; the dict is also forwarded to every
				:class:`Individual`.

		Raises:
			ValueError: if ``gss_mode`` is not a supported strategy.
		"""
		self.name = name
		self.mode = mode
		self.kwargs = kwargs
		self.gate_num = self.kwargs['qtn_L']
		self.qtn_Q = self.kwargs['qtn_Q']
		self.qtn_K = self.kwargs['qtn_K']
		self.qtn_R = self.kwargs['qtn_R']
		self.population_size = self.kwargs.get('ga_population_size', 20)
		# epochs / std / repeat_time are stored for callers; this class does not read them.
		self.epochs = self.kwargs.get('epochs', 3000)
		self.std = self.kwargs.get('std', 0.1)
		self.repeat_time = self.kwargs.get('repeat_time', 1)

		self.indv_to_collect = []
		self.individuals = {}

		self.gss_mode = self.kwargs.get('gss_mode', 'ga')
		if self.gss_mode == 'ga':
			if previous_generation is not None:
				# __evolve_ga__ of the previous generation wrote the evolved
				# chromosomes back into its individuals.
				self.population = [indv.qtn_seq for indv in previous_generation.individuals['indv']]
			else:
				self.population = self.initialize_population_ga()
		elif self.gss_mode == 'ale':
			self.ale_init_strategy = self.kwargs.get('ale_init_strategy', 'combine')
			if previous_generation is not None:
				self.best_qtn = previous_generation.best_qtn
				self.min_loss = previous_generation.min_loss
				self.ale_step = previous_generation.ale_step + 1
				logging.info('\033[32m===== ALE Start with QTN {} and loss {}=====\033[0m'.format(array_to_gate(self.best_qtn), self.min_loss))
			else:
				self.best_qtn = self.initialize_qtn_ale()
				self.ale_step = 0
				self.min_loss = 9999  # initial placeholder loss
			self.population = self.generate_population_ale()
		elif self.gss_mode == 'random':
			# The pool of already-tried chromosomes is shared across generations.
			self.population_pool = previous_generation.population_pool if previous_generation is not None else []
			self.population = self.generate_population_random()
		elif self.gss_mode == 'greedy':
			if previous_generation is not None:
				self.best_qtn = previous_generation.best_qtn
				self.min_loss = previous_generation.min_loss
				self.greedy_step = previous_generation.greedy_step + 1
				logging.info('\033[32m===== GREEDY Start with QTN {} and loss {}=====\033[0m'.format(array_to_gate(self.best_qtn), self.min_loss))
			else:
				self.greedy_step = 0
				self.min_loss = 9999  # initial placeholder loss
			self.population = self.generate_population_greedy()
		else:
			raise ValueError('Invalid gss_mode: {}'.format(self.gss_mode))

		# wrap every chromosome in an Individual; all start out undistributed
		self.individuals['indv'] = [
			Individual(qtn_seq=seq, scope='{}/{:03d}'.format(self.name, i), **self.kwargs)
			for i, seq in enumerate(self.population)
		]
		self.indv_to_distribute = list(self.individuals['indv'])

	def __call__(self, **kwargs):
		"""Rank the collected results, run callbacks, then advance the strategy.

		Args:
			**kwargs: if it contains 'callbacks', each callback is called with
				this generation after ranking.

		Returns:
			True.

		Raises:
			ValueError: if ``gss_mode`` is not a supported strategy.
		"""
		self.__evaluate__()
		for callback in kwargs.get('callbacks', []):
			callback(self)
		evolve = {
			'ga': self.__evolve_ga__,
			'ale': self.__evolve_ale__,
			'random': self.__evolve_random__,
			'greedy': self.__evolve_greedy__,
		}.get(self.gss_mode)
		if evolve is None:
			raise ValueError('Invalid mode: {}'.format(self.gss_mode))
		evolve()
		return True

	def __evolve_ga__(self):
		"""Produce the next GA population and write it into the individuals.

		The worst 20% (by rank) are discarded; each survivor slot yields one
		child (fitness-weighted parent selection, crossover, mutation); the
		children are de-duplicated and topped up with random chromosomes.
		"""
		ranks = self.individuals['rank']
		fitness = self.individuals['fitness']
		self.population = [self.individuals['indv'][i].qtn_seq for i in range(self.population_size)]

		# remove the worst 20% of the chromosomes
		cutoff = int(len(ranks) * 0.8)
		worst_indices = {idx for idx, rank in enumerate(ranks) if rank >= cutoff}
		self.population = [chrom for i, chrom in enumerate(self.population) if i not in worst_indices]
		fitness = [fit for i, fit in enumerate(fitness) if i not in worst_indices]

		new_population = []
		logging.info("'===== Evolution process ====='")
		for _ in range(len(self.population)):
			parent1, parent2 = self.select_parents(fitness)
			child, split = self.crossover(parent1, parent2)
			logging.info('parent1: {} | parent2: {} | split: {} | child: {}'.format(array_to_gate(parent1), array_to_gate(parent2), split, array_to_gate(child)))
			new_population.append(self.mutate(child))
		self.population = self.deduplicate_and_fill(new_population)

		# hand the new chromosomes to the next generation via the individuals
		for indv, chrom in zip(self.individuals['indv'], self.population):
			indv.qtn_seq = chrom

		self.min_loss = self.current_loss
		self.best_qtn = self.current_qtn

	def __evolve_ale__(self):
		"""Keep the best QTN seen so far (only replaced on strict improvement)."""
		if self.current_loss < self.min_loss:
			self.min_loss = self.current_loss
			self.best_qtn = self.current_qtn

	def __evolve_random__(self):
		"""Record this generation's best QTN and loss."""
		self.min_loss = self.current_loss
		self.best_qtn = self.current_qtn

	def __evolve_greedy__(self):
		"""Adopt this generation's best QTN as the prefix for the next step."""
		self.min_loss = self.current_loss
		self.best_qtn = self.current_qtn

	def __evaluate__(self):
		"""Rank individuals by their lowest repeat loss.

		Sets ``individuals['loss']``, ``['rank']`` (0 = lowest loss) and
		``['fitness']`` (``max(0.01, log(200 / (0.01 + 5 * rank)))``), plus
		``first_rank_idx``, ``current_qtn`` and ``current_loss`` of rank 0.
		"""
		losses = [np.min(indv.repeat_loss) for indv in self.individuals['indv']]
		ranks = self.get_sorted_positions(losses)
		self.individuals['loss'] = losses
		self.individuals['rank'] = ranks
		self.individuals['fitness'] = [max(0.01, math.log(200 / (1e-2 + 5 * rank))) for rank in ranks]

		self.first_rank_idx = ranks.index(min(ranks))
		self.current_qtn = self.individuals['indv'][self.first_rank_idx].qtn_seq
		self.current_loss = losses[self.first_rank_idx]

	def initialize_population_ga(self) -> List[np.ndarray]:
		"""Return ``population_size`` distinct random chromosomes."""
		initial_population = [self.random_chromosome() for _ in range(self.population_size)]
		return self.deduplicate_and_fill(initial_population)

	def initialize_qtn_ale(self) -> np.ndarray:
		"""Build the starting QTN for ALE according to ``ale_init_strategy``.

		'combine' cycles through every K-subset of the indices -1..-Q,
		'stair' cycles through (-1, -2), (-2, -3), ..., (-(Q-1), -Q), and
		'random' uses :meth:`random_chromosome`.

		Raises:
			ValueError: for an unknown strategy.
		"""
		if self.ale_init_strategy == 'combine':
			possible_gates = list(itertools.combinations(range(-1, -self.qtn_Q - 1, -1), self.qtn_K))
		elif self.ale_init_strategy == 'stair':
			possible_gates = [(-i, -i - 1) for i in range(1, self.qtn_Q)]
		elif self.ale_init_strategy == 'random':
			return self.random_chromosome()
		else:
			raise ValueError('Invalid ale_init_strategy: {}'.format(self.ale_init_strategy))

		# repeat the gate list cyclically until gate_num gates are placed
		return gate_to_array([possible_gates[i % len(possible_gates)] for i in range(self.gate_num)])

	def generate_population_ale(self):
		"""Return candidates that differ from ``best_qtn`` in one gate position.

		The position sweeps back and forth, one position per step: forward
		over 0..L-2, then backward over L-1..1, and so on. A replacement is
		skipped if it equals a neighbouring gate or the current gate. On step
		0 ``best_qtn`` itself is included so that it gets evaluated too.
		"""
		period = max(self.gate_num - 1, 1)  # max() guards gate_num == 1
		gate_idx = self.ale_step % period
		if (self.ale_step // period) % 2 == 1:
			gate_idx = self.gate_num - gate_idx - 1

		possible_gates = list(itertools.combinations(range(-1, -self.qtn_Q - 1, -1), self.qtn_K))
		current_gate = tuple(self.best_qtn[:, gate_idx])
		neighbour_gates = [tuple(self.best_qtn[:, j]) for j in (gate_idx - 1, gate_idx + 1) if 0 <= j < self.gate_num]

		new_population = []
		logging.info("===== ALE process on gate {}=====".format(gate_idx))
		if self.ale_step == 0:
			new_population.append(self.best_qtn)
		for gate in possible_gates:
			if gate in neighbour_gates or gate == current_gate:
				continue
			candidate = self.best_qtn.copy()
			candidate[:, gate_idx] = np.array(gate)
			logging.info(f"change gate {current_gate} to {gate}")
			new_population.append(candidate)

		return new_population

	def generate_population_greedy(self):
		"""Return candidates that append one gate to ``best_qtn``.

		At step 0 each candidate is a single gate; afterwards each candidate
		is the current best sequence plus one gate differing from its last.
		"""
		gate_idx = self.greedy_step
		possible_gates = list(itertools.combinations(range(-1, -self.qtn_Q - 1, -1), self.qtn_K))
		prefix = [] if gate_idx == 0 else array_to_gate(self.best_qtn)

		new_population = []
		logging.info("===== GREEDY process on gate {}=====".format(gate_idx))
		for gate in possible_gates:
			if gate_idx > 0 and gate == prefix[gate_idx - 1]:
				continue
			logging.info(f"add gate {gate} to {gate_idx}")
			new_population.append(gate_to_array(prefix + [gate]))

		return new_population

	def generate_population_random(self):
		"""Return up to ``population_size`` random chromosomes not tried before.

		Tried chromosomes are remembered in ``population_pool``. Sampling
		stops after ``20 * population_size`` attempts, logging a warning if
		fewer chromosomes than requested were found.
		"""
		new_population = []
		attempts = 0
		max_attempts = self.population_size * 20  # cap to avoid an endless loop

		while len(new_population) < self.population_size and attempts < max_attempts:
			new_chromosome = self.random_chromosome()
			new_chromosome_tuple = self.chromosome_to_tuple(new_chromosome)
			if new_chromosome_tuple not in self.population_pool:
				new_population.append(new_chromosome)
				self.population_pool.append(new_chromosome_tuple)
			attempts += 1

		if len(new_population) < self.population_size:
			logging.warning(f"Could only generate {len(new_population)} unique chromosomes out of {self.population_size} requested.")

		return new_population

	def _random_gate(self) -> np.ndarray:
		"""Return ``qtn_K`` distinct indices from -qtn_Q..-1, sorted descending."""
		return np.sort(random.sample(range(-self.qtn_Q, 0), self.qtn_K))[::-1]

	def random_chromosome(self) -> np.ndarray:
		"""Return a random ``(qtn_K, gate_num)`` chromosome with no equal adjacent gates.

		Raises:
			ValueError: if fewer than two distinct gates exist while
				``gate_num > 1`` (no valid chromosome; would loop forever).
		"""
		if self.gate_num > 1 and math.comb(self.qtn_Q, self.qtn_K) < 2:
			raise ValueError('Need at least two distinct gates (qtn_Q={}, qtn_K={}) for {} gates.'.format(self.qtn_Q, self.qtn_K, self.gate_num))
		chromosome = np.zeros((self.qtn_K, self.gate_num), dtype=int)
		for i in range(self.gate_num):
			column = self._random_gate()
			while i > 0 and np.array_equal(column, chromosome[:, i - 1]):
				column = self._random_gate()
			chromosome[:, i] = column
		return chromosome

	def merge_duplicate_gates(self, seq):
		"""Drop the first of the first pair of identical adjacent columns.

		Returns:
			``(seq, i)`` with column ``i`` removed, or ``(seq, None)`` if no
			adjacent columns are equal. At most one column is removed.
		"""
		loc = None
		for i in range(seq.shape[1] - 1):
			if np.array_equal(seq[:, i], seq[:, i + 1]):
				seq = np.delete(seq, i, axis=1)
				loc = i
				break
		return seq, loc

	def deduplicate_and_fill(self, population: List[np.ndarray]) -> List[np.ndarray]:
		"""Remove duplicates (keeping first occurrences, in order) and top up.

		The result is padded with new random chromosomes until it holds at
		least ``population_size`` distinct chromosomes.
		"""
		seen = set()
		unique_population = []
		for chrom in population:
			key = tuple(map(tuple, chrom))  # hashable form of the array
			if key not in seen:
				seen.add(key)
				unique_population.append(np.array(key))

		missing = self.population_size - len(unique_population)
		if missing > 0:
			logging.info("{} duplicate (missing) chromosomes find!".format(missing))

		while len(unique_population) < self.population_size:
			new_chrom = self.random_chromosome()
			key = tuple(map(tuple, new_chrom))
			if key not in seen:
				seen.add(key)
				unique_population.append(new_chrom)

		return unique_population

	def select_parents(self, fitness) -> tuple[np.ndarray, np.ndarray]:
		"""Draw two parents (with replacement) weighted by ``fitness``."""
		parents = random.choices(self.population, weights=fitness, k=2)
		return parents[0], parents[1]

	def chromosome_to_tuple(self, chromosome: np.ndarray) -> tuple:
		"""Return the chromosome as a hashable tuple of column tuples."""
		return tuple(tuple(col) for col in chromosome.T)

	def crossover(self, parent1: np.ndarray, parent2: np.ndarray) -> tuple[np.ndarray, int]:
		"""One-point crossover: columns before ``split`` from parent1, the rest from parent2.

		Adjacent equal columns in the child are re-drawn at random.

		Returns:
			``(child, split)``; ``split`` is 0 when ``gate_num < 2``.
		"""
		split = random.randint(1, self.gate_num - 1) if self.gate_num > 1 else 0
		child = np.zeros_like(parent1)
		child[:, :split] = parent1[:, :split]
		child[:, split:] = parent2[:, split:]
		# ensure adjacent columns are not the same
		for i in range(1, self.gate_num):
			while np.array_equal(child[:, i], child[:, i - 1]):
				child[:, i] = self._random_gate()
		return child, split

	def mutate(self, chromosome: np.ndarray, mutation_rate: float = 0.05) -> np.ndarray:
		"""Replace each column with probability ``mutation_rate`` (in place).

		A replacement never equals either neighbouring column.
		"""
		last = self.gate_num - 1
		for i in range(self.gate_num):
			if random.random() >= mutation_rate:
				continue
			while True:
				new_column = self._random_gate()
				clashes_left = i > 0 and np.array_equal(new_column, chromosome[:, i - 1])
				clashes_right = i < last and np.array_equal(new_column, chromosome[:, i + 1])
				if not (clashes_left or clashes_right):
					chromosome[:, i] = new_column
					break
		return chromosome

	def get_sorted_positions(self, lst):
		"""Return the rank of each element of ``lst`` (0 = smallest; ties keep input order)."""
		positions = [0] * len(lst)
		for rank, (index, _) in enumerate(sorted(enumerate(lst), key=lambda pair: pair[1])):
			positions[index] = rank
		return positions

	def distribute_indv(self, agent):
		"""Hand the next undistributed individual (if any) to ``agent``."""
		if self.indv_to_distribute:
			indv = self.indv_to_distribute.pop(0)
			agent.receive(indv)
			self.indv_to_collect.append(indv)
			logging.info('Assigned individual {} to agent {}.'.format(indv.scope, agent.sge_job_id))

	def collect_indv(self):
		"""Try to collect every pending individual; drop those that succeed."""
		# Iterate over a snapshot: removing from the list being iterated
		# would skip the element after each removed one.
		for indv in list(self.indv_to_collect):
			if indv.collect():
				logging.info('Collected individual result {}.'.format(indv.scope))
				self.indv_to_collect.remove(indv)

	def is_finished(self):
		"""True when every individual has been distributed and collected."""
		return not self.indv_to_distribute and not self.indv_to_collect

class Agent(object):
	"""Handle on one worker, identified by its ``agent_pool/<sge_job_id>.POOL`` file.

	From this side, an empty POOL file means the agent is available, and
	:meth:`receive` appends to the file when a job is assigned. (Presumably
	the worker truncates the file when done; that code is not shown here.)
	"""

	def __init__(self, **kwargs):
		"""
		Args:
			**kwargs: must contain 'sge_job_id'; may contain 'target_data'
				(the text written into the POOL file on assignment, defaulting
				to the module-level ``target_data`` set by the script entry point).
		"""
		super().__init__()
		self.kwargs = kwargs
		self.sge_job_id = self.kwargs['sge_job_id']

	def _pool_path(self):
		"""Path of this agent's POOL file."""
		return base_folder + '/agent_pool/{}.POOL'.format(self.sge_job_id)

	def receive(self, indv):
		"""Deploy ``indv`` under this agent's id and mark the agent as busy."""
		indv.deploy(self.sge_job_id)
		payload = self.kwargs['target_data'] if 'target_data' in self.kwargs else target_data
		with open(self._pool_path(), 'a') as f:
			f.write(payload)

	def is_available(self):
		"""True if the POOL file exists and is empty."""
		try:
			return os.stat(self._pool_path()).st_size == 0
		except FileNotFoundError:
			# The POOL file vanished after it was globbed: the agent is gone,
			# so report it as unavailable instead of crashing the main loop.
			return False


class PipelineGSS(object):
	"""Main loop of the structure search.

	Generations of :class:`GenerationQTN` are created one after another. Their
	individuals are handed to agents discovered through
	``agent_pool/<id>.POOL`` files and their results are collected. Every
	evaluated chromosome is recorded in ``generation_list`` with its lowest loss.
	"""

	def __init__(self, gss_max_iteration=100, sb_max_iteration=10, **kwargs):
		"""
		Args:
			gss_max_iteration: limit on the generation counter (see __generation__).
			sb_max_iteration: accepted for backward compatibility; not used.
			**kwargs: forwarded to every GenerationQTN and its __call__;
				must contain 'gss_mode'.
		"""
		super().__init__()
		self.kwargs = kwargs
		self.gss_mode = kwargs['gss_mode']
		self.dummy_func = lambda *args, **kwargs: None
		self.max_gss_iteration = gss_max_iteration
		self.current_generation = None
		self.previous_generation = None
		self.N_group_ss = 0
		self.generation_list = {}
		self.available_agents = []
		self.best_qtn_list = []
		self.min_loss_list = []
		self.known_agents = {}
		self.time = 0

	def __call_with_interval__(self, func, interval):
		"""Return ``func`` when the internal clock is a multiple of ``interval``, else a no-op."""
		return func if self.time % interval == 0 else self.dummy_func

	def __tik__(self, sec):
		"""Advance the internal clock by ``sec`` and sleep for that long."""
		self.time += sec
		time.sleep(sec)

	@staticmethod
	def _agent_id_from_path(pool_path):
		"""Return the agent id encoded in a ``<id>.POOL`` file path (any OS separator)."""
		return os.path.splitext(os.path.basename(pool_path))[0]

	def __check_available_agent__(self):
		"""Sync ``known_agents`` with the POOL files and refill ``available_agents``.

		Agents whose POOL file disappeared are forgotten; newly found agents
		are registered but only become available on a later check.
		"""
		self.available_agents.clear()
		pool_files = glob.glob(base_folder + '/agent_pool/*.POOL')
		agents_id = [self._agent_id_from_path(p) for p in pool_files]
		alive = set(agents_id)

		for aid in list(self.known_agents):
			if aid not in alive:
				logging.info('Dead agent id = {} found!'.format(aid))
				self.known_agents.pop(aid, None)

		for aid in agents_id:
			agent = self.known_agents.get(aid)
			if agent is None:
				self.known_agents[aid] = Agent(sge_job_id=aid)
				logging.info('New agent id = {} found!'.format(aid))
			elif agent.is_available():
				self.available_agents.append(agent)

	def __assign_job__(self):
		"""Give each available agent at most one individual of the current generation."""
		self.__check_available_agent__()
		for agent in self.available_agents:
			self.current_generation.distribute_indv(agent)

	def __collect_result__(self):
		"""Poll the result pool for the current generation."""
		self.current_generation.collect_indv()

	def __report_agents__(self):
		"""Log the number and ids of known agents."""
		logging.info('Current number of known agents is {}.'.format(len(self.known_agents)))
		logging.info(list(self.known_agents.keys()))

	def __report_generation__(self):
		"""Log how many individuals are still to be distributed / collected."""
		logging.info('Current length of indv_to_distribute is {}.'.format(len(self.current_generation.indv_to_distribute)))
		logging.info('Current length of indv_to_collect is {}.'.format(len(self.current_generation.indv_to_collect)))
		logging.info([(indv.scope, indv.sge_job_id) for indv in self.current_generation.indv_to_collect])

	def __chromosome_to_tuple__(self, chromosome: np.ndarray) -> tuple:
		"""Return the chromosome as a hashable tuple of column tuples."""
		return tuple(tuple(col) for col in chromosome.T)

	def __tuple_to_chromosome__(self, chromosome_tuple: tuple) -> np.ndarray:
		"""Inverse of ``__chromosome_to_tuple__``."""
		return np.array(chromosome_tuple).T

	def __update_all_chromosomes__(self, population: List[np.ndarray], losses: List[float]):
		"""Record the lowest loss seen per chromosome, keeping the table sorted by loss."""
		for chrom, loss in zip(population, losses):
			chrom_tuple = self.__chromosome_to_tuple__(chrom)
			if chrom_tuple not in self.generation_list or loss < self.generation_list[chrom_tuple]:
				self.generation_list[chrom_tuple] = loss

		# keep the table ordered from best (lowest loss) to worst
		self.generation_list = OrderedDict(sorted(self.generation_list.items(), key=lambda item: item[1]))

	def __generation__(self):
		"""Advance the generation state machine.

		Creates the first generation on demand. When the current generation
		has all results back, it is ranked and evolved, the best QTN is
		recorded and the next generation is created.

		Returns:
			False once the generation counter exceeds ``max_gss_iteration``,
			otherwise True.
		"""
		if self.N_group_ss > self.max_gss_iteration:
			return False

		if self.current_generation is None:
			self.current_generation = GenerationQTN(name='generation_init', mode='group_struture_search', **self.kwargs)
			self.N_group_ss += 1

		if self.current_generation.is_finished():
			if self.N_group_ss > 0:
				# capture the evaluated population before __call__ evolves it
				self.current_population = self.current_generation.population
				self.current_generation(**self.kwargs)
				self.__update_all_chromosomes__(self.current_population, self.current_generation.individuals['loss'])
				self.best_qtn_list.append(self.current_generation.best_qtn)
				self.min_loss_list.append(self.current_generation.min_loss)
				logging.info('\033[31m===== Best QTN_SEQ of No.{} Generation =====\033[0m'.format(self.N_group_ss))
				logging.info('\033[31mQtn: {} | Loss: {:.5f}\033[0m'.format(array_to_gate(self.best_qtn_list[-1]), self.min_loss_list[-1]))
			self.N_group_ss += 1
			self.previous_generation = self.current_generation
			self.current_generation = GenerationQTN(self.previous_generation,
													name='generation_{:03d}'.format(self.N_group_ss), mode='group_struture_search', **self.kwargs)

		return True

	def __call__(self):
		"""Run the search and store the overall best result.

		Sets ``gss_qtn`` (best chromosome as an array) and ``gss_loss``.

		Raises:
			RuntimeError: if no chromosome was evaluated at all.
		"""
		logging.info('\033[32m===== {} Start =====\033[0m'.format(self.gss_mode.upper()))
		while self.__generation__():
			self.__call_with_interval__(self.__check_available_agent__, 4)()
			self.__call_with_interval__(self.__assign_job__, 4)()
			self.__call_with_interval__(self.__collect_result__, 4)()
			self.__call_with_interval__(self.__report_agents__, 180)()
			self.__call_with_interval__(self.__report_generation__, 160)()
			self.__tik__(2)
		logging.info('===== {} Finished ====='.format(self.gss_mode.upper()))
		logging.info('===== Top 20 QTN_SEQ =====')
		for chrom, loss in list(self.generation_list.items())[:20]:
			logging.info('{} | {:.5f}'.format(chrom, loss))

		if not self.generation_list:
			raise RuntimeError('No chromosome was evaluated; there is no best QTN to report.')

		# generation_list is sorted by loss, so its first entry is the best
		best_qtn = next(iter(self.generation_list))
		min_loss = self.generation_list[best_qtn]
		logging.info('\033[32m===== Best QTN {} and loss {} =====\033[0m'.format(best_qtn, min_loss))
		self.gss_qtn = self.__tuple_to_chromosome__(best_qtn)
		self.gss_loss = min_loss

def array_to_gate(qtn_seq: np.ndarray) -> List[tuple]:
	"""Convert a gate-sequence array into a list of gate tuples (one per column).

	Args:
		qtn_seq: a 2-D array with one gate per column, or a dict mapping keys
			to such arrays.

	Returns:
		A list of tuples of Python ints, or, for a dict input, a dict with the
		same keys mapping to such lists.

	Raises:
		TypeError: for any other input type.
	"""
	def _columns(arr):
		# int() turns numpy scalars into plain ints for readable logging
		return [tuple(int(v) for v in arr[:, i]) for i in range(arr.shape[1])]

	if isinstance(qtn_seq, np.ndarray):
		return _columns(qtn_seq)
	if isinstance(qtn_seq, dict):
		return {k: _columns(v) for k, v in qtn_seq.items()}
	raise TypeError('array_to_gate expects a numpy array or a dict of arrays, got {}'.format(type(qtn_seq).__name__))


def gate_to_array(qtn_layer: List[tuple]) -> np.ndarray:
	"""Stack a list of gate tuples into an array with one gate per column.

	For example ``[(a, b), (c, d)]`` becomes ``[[a, c], [b, d]]``.
	"""
	rows = [list(gate) for gate in qtn_layer]
	return np.transpose(np.array(rows))

def score_summary(obj):
	"""Log one line per individual of generation ``obj`` (used as a callback).

	Each line shows the scope, per-repeat losses rounded to 4 decimals, the
	best loss, rank, fitness and gate sequence; the rank-0 line is red.
	"""
	logging.info('\033[32m===== {} =====\033[0m'.format(obj.name))

	info = obj.individuals
	for pos, member in enumerate(info['indv']):
		rounded = [float('{:0.4f}'.format(value)) for value in member.repeat_loss]
		line = '{} | repeat_losses: {} | loss: {:.5f} | rank: {} | fitness: {} | qtn_seq: {}'.format(
			member.scope, rounded, info['loss'][pos], info['rank'][pos], info['fitness'][pos], array_to_gate(member.qtn_seq))
		if info['rank'][pos] == 0:
			line = '\033[31m' + line + '\033[0m'
		logging.info(line)

if __name__ == '__main__':

	# Command line: <target_data> <gss_mode> [loss_measure]
	if len(sys.argv) < 3:
		sys.exit('usage: {} <target_data> <gss_mode> [loss_measure]'.format(sys.argv[0]))
	# target_data stays module-level: Agent.receive reads it as a global.
	target_data = sys.argv[1]
	gss_mode = sys.argv[2]
	print(f"target_data: {target_data}")
	print(f"gss_mode: {gss_mode}")
	loss_measure = sys.argv[3] if len(sys.argv) > 3 else 'total'
	print(f"Loss measure: {loss_measure}")

	# gss_mode is '<prefix>_<strategy>_<qtn_L>_<population_size>_<max_iteration>';
	# the first field is not used.
	gss_mode_sep = gss_mode.split('_')
	if len(gss_mode_sep) < 5:
		sys.exit('gss_mode must have the form <prefix>_<strategy>_<L>_<population_size>_<max_iteration>, got {!r}'.format(gss_mode))
	gss_strategy = gss_mode_sep[1]
	qtn_L = int(gss_mode_sep[2])
	ga_population_size = int(gss_mode_sep[3])
	gss_max_iteration = int(gss_mode_sep[4])

	# load data and information
	# NOTE(review): `torch` is not imported by name in this file; presumably it
	# is provided by `from quan_decomp import *` -- confirm.
	data = torch.load(target_data)
	qtn_R = data['qtn_R']
	qtn_Q = data['qtn_Q']
	qtn_K = data['qtn_K']
	mask = data['mask']  # not used below
	std_init = data['std_init'] if 'std_init' in data.keys() else 0.3
	num_of_tensors = len(data['data'])  # not used below

	# Use only the file name: a leading './' or directory component would
	# otherwise give an empty name or a log path in a non-existent folder.
	target_data_name = os.path.basename(target_data).split('.')[0]

	current_time = strftime("%Y%m%d_%H%M%S", gmtime())

	log_name = 'center_log/sim_{}_{}_{}.log'.format(target_data_name, gss_mode, current_time)
	logging.basicConfig(filename=log_name, filemode='a', level=logging.DEBUG,
						format='%(asctime)s: %(message)s', datefmt='%H:%M:%S')
	console = logging.StreamHandler()
	console.setLevel(logging.INFO)
	formatter = logging.Formatter('%(asctime)s:  %(message)s')
	console.setFormatter(formatter)
	logging.getLogger('').addHandler(console)

	gss_pipeline = PipelineGSS(
		# GA params
		gss_mode=gss_strategy, gss_max_iteration=gss_max_iteration, ga_population_size=ga_population_size,
		# QTN params
		qtn_R=qtn_R, qtn_Q=qtn_Q, qtn_K=qtn_K, qtn_L=qtn_L, qtn_max_iterations=1500, qtn_evaluate_repeat=10, qtn_init_std=std_init, qtn_loss_measure=loss_measure,
		callbacks=[score_summary]
		)

	gss_pipeline()

	os.makedirs(os.path.join(base_folder, 'gss_results'), exist_ok=True)

	# save results
	torch.save({
		'gss_qtn': gss_pipeline.gss_qtn,
		'gss_loss': gss_pipeline.gss_loss,
		'gss_best_qtn_list': gss_pipeline.best_qtn_list,
		'gss_min_loss_list': gss_pipeline.min_loss_list,
	}
	, 'gss_results/sim_{}_{}_{}.pt'.format(target_data_name, gss_mode, loss_measure)
	)