import math
import numpy as np, os, sys, re, glob, subprocess, math, unittest, shutil, time, string, logging, gc
np.set_printoptions(precision=4)
from time import gmtime, strftime
from random import shuffle, choice, sample, choices
import random
from itertools import product
from functools import partial
import inspect
import itertools
import torch
import copy
from typing import List
from collections import OrderedDict, defaultdict


# Create the related working folders.
# Each folder is created on its own: with a single try block around all of
# them, the first folder that already exists would abort the creation of every
# folder listed after it.
base_folder = './'
for _folder_name in ('center_log', 'agent_log', 'agent_pool', 'job_pool', 'result_pool'):
	try:
		os.makedirs(base_folder+_folder_name, exist_ok=True)
	except OSError as _folder_error:
		# Still best effort (no crash at import time), but no longer silent.
		# stderr is used instead of logging because logging is configured later.
		sys.stderr.write('Could not create folder {}: {}\n'.format(base_folder+_folder_name, _folder_error))

class Individual:
	"""One candidate structure that is written to the job pool and read back from the result pool.

	Attributes set by the methods below:
		qtn_seq, scope, repeat: set in ``__init__``.
		sge_job_id: set by ``deploy``.
		dev_loss, test_loss: set by ``collect`` (``test_loss`` only for real results).
	"""

	def __init__(self, qtn_seq=None, scope=None, **kwargs):
		"""Store the candidate and its scope.

		Args:
			qtn_seq: the candidate structure; saved as-is by ``deploy``.
			scope: identifier of the individual; ``collect`` derives the result
				file name from it by replacing '/' with '_'.
			**kwargs: optional settings. Only ``repeat`` (default 1) is read; it
				is the number of placeholder losses produced by
				``collect(fake_loss=True)``. Other keys are ignored.
		"""
		self.qtn_seq = qtn_seq

		# set the decomposition args
		self.scope = scope

		# Without this attribute collect(fake_loss=True) fails with AttributeError.
		self.repeat = kwargs.get('repeat', 1)

	# deploy the qtn_seq to the pool
	def deploy(self, sge_job_id):
		"""Write the candidate to ``job_pool/<sge_job_id>.npz`` and remember the job id.

		Returns:
			True on success. Errors raised while saving propagate to the caller.
		"""
		path = base_folder+'/job_pool/{}.npz'.format(sge_job_id)
		np.savez(path, qtn_seq=self.qtn_seq, scope=self.scope)
		self.sge_job_id = sge_job_id
		return True

	# collect the result from the pool
	def collect(self, fake_loss=False):
		"""Try to read this individual's result file and delete it afterwards.

		Args:
			fake_loss: when true, no file is read and ``dev_loss`` becomes a
				list of ``repeat`` placeholder values (9999).

		Returns:
			True when losses were stored on the instance, False when the result
			could not be read (for example because it does not exist yet).
		"""
		if fake_loss:
			self.dev_loss = [9999]*self.repeat
			return True

		try:
			path = base_folder+'/result_pool/{}.npz'.format(self.scope.replace('/', '_'))
			# Close the archive before deleting the file; an open handle can
			# make the removal fail on some platforms.
			with np.load(path) as result:
				dev_loss = result['dev_loss']
				test_loss = result['test_loss']
			os.remove(path)
		except Exception:
			# Deliberately broad: this method is polled, and a missing or
			# unreadable result simply means "try again later".
			return False

		self.dev_loss = dev_loss
		self.test_loss = test_loss
		return True

class GenerationQTN:
	"""One symmetric-breaking (SB) generation of candidate QTN structures.

	The current best structure is a dict mapping a tensor index to a gate
	matrix whose columns are gates. A generation builds candidates by deleting
	gates from that structure, hands them to agents, collects the resulting
	losses and ranks them.

	State that survives from one generation to the next (best structure,
	minimum loss, per-gate loss list, pool of already tried gate indices and
	the evaluation counter) is taken from ``previous_generation``; the lists
	are shared by reference, not copied.
	"""

	# initialize the generation
	def __init__(self, previous_generation=None, name=None, mode=None, **kwargs):
		"""Set up the generation and build its population.

		Args:
			previous_generation: generation to continue from, or None to start
				from ``kwargs['sb_init_qtn']`` / ``kwargs['sb_init_loss']``
				(both required in that case).
			name: prefix used for the scope of every individual.
			mode: stored on the instance; not read inside this class.
			**kwargs: settings. Read here: ``sb_look_up_num`` (default 1),
				``sb_population_size`` (default 10), ``sb_loss_tolerance``
				(default 0.01), ``sb_sampling_strategy`` (default 'random').
				All of them are also forwarded to every ``Individual``.
		"""
		self.name = name
		self.mode = mode
		self.kwargs = kwargs
		self.look_up_num = self.kwargs.get('sb_look_up_num', 1)
		self.population_size = self.kwargs.get('sb_population_size', 10)
		self.sb_loss_tolerance = self.kwargs.get('sb_loss_tolerance', 0.01)
		self.is_pool_full = False
		self.indv_to_collect = []
		self.individuals = {}

		self.sb_sampling_strategy = self.kwargs.get('sb_sampling_strategy', 'random')
		if previous_generation is not None:
			self.best_qtn = previous_generation.best_qtn
			self.min_loss = previous_generation.min_loss
			self.sb_loss_list = previous_generation.sb_loss_list
			self.gates_pool = previous_generation.gates_pool
			self.total_evaluation_count = previous_generation.total_evaluation_count
		else:
			self.initial_qtn = self.kwargs['sb_init_qtn']
			self.min_loss = self.kwargs['sb_init_loss']
			self.best_qtn = self.initialize_qtn_sb(self.initial_qtn)
			# One loss entry per gate: (columns of the initial layer) x (number of tensors).
			self.sb_loss_list = [self.min_loss for _ in range(self.initial_qtn.shape[1]*len(self.best_qtn))]
			self.gates_pool = []
			self.total_evaluation_count = 0
		logging.info('\033[32m===== SB Start with QTN {} and loss {} =====\033[0m'.format(array_to_gate(self.best_qtn), self.min_loss))
		logging.info('\033[32m===== Gates Indexs Pool {} =====\033[0m'.format(self.gates_pool))

		self.gate_list = self.update_gate_list()
		self.population = self.generate_population_sb()

		# initialize the population
		self.individuals['indv'] = [
			Individual(qtn_seq=qtn_seq, scope='{}/{:03d}'.format(self.name, i), **self.kwargs)
			for i, qtn_seq in enumerate(self.population)
		]
		self.indv_to_distribute = list(self.individuals['indv'])

	def __call__(self, **kwargs):
		"""Rank the collected individuals, run the callbacks, then update the bookkeeping.

		Args:
			**kwargs: only ``callbacks`` is read: an iterable of callables that
				each receive this generation.

		Returns:
			True. Errors raised by any step propagate to the caller.
		"""
		self.__evaluate__()
		if 'callbacks' in kwargs:
			for callback in kwargs['callbacks']:
				callback(self)
		self.__evolve_sb__()
		return True

	def __evolve_sb__(self):
		"""Drop the gate removed by the first-ranked individual from the bookkeeping lists.

		Nothing happens unless the gates pool is empty (``__evaluate__`` empties
		it when it accepts a candidate).
		"""
		if len(self.gates_pool) == 0:
			loc = self.remove_gate_idx[self.first_rank_idx]
			self.sb_loss_list.pop(loc)
			self.gate_list.pop(loc)
			if self.merge_locs[self.first_rank_idx] is not None:
				tensor_idx = self.merge_locs[self.first_rank_idx][0]
				gate_loc = self.merge_locs[self.first_rank_idx][1][0]
				logging.info(f"merge loss {gate_loc} and {gate_loc+1} of tensor {tensor_idx}, in the location of loss_list {loc}")
				# The two entries now adjacent at loc-1 and loc are folded into
				# one entry that keeps the smaller loss.
				min_value = min(self.sb_loss_list[loc-1], self.sb_loss_list[loc])
				self.sb_loss_list[loc-1] = min_value
				self.sb_loss_list.pop(loc)
				self.gate_list.pop(loc)
				logging.info('Current gate_list: {}'.format(self.gate_list))
				logging.info('Current loss_list: {}'.format(self.sb_loss_list))

	def __evaluate__(self):
		"""Rank the individuals and accept the first block whose best loss is below the tolerance.

		Individuals are scanned in blocks of ``look_up_num``. Every scanned
		individual records its loss in ``sb_loss_list`` and its gate index in
		``gates_pool``. When a block's lowest loss is below
		``sb_loss_tolerance`` that individual becomes the new best structure,
		gets rank -1, the pool is emptied and scanning stops.
		"""
		# RANKING
		self.individuals['loss'] = [np.min(indv.dev_loss) for indv in self.individuals['indv']]
		self.individuals['rank'] = self.get_sorted_positions(self.individuals['loss'])

		losses = self.individuals['loss']
		# find the lowest index whose loss is below the threshold self.sb_loss_tolerance
		look_up_block = math.ceil(len(losses)/self.look_up_num)
		for i in range(look_up_block):
			# NOTE(review): a final, shorter block still adds look_up_num here,
			# so the counter can exceed the number of individuals - confirm intent.
			self.total_evaluation_count += self.look_up_num
			# look at look_up_num consecutive elements and find the lowest loss
			block_start = i*self.look_up_num
			block_stop = min(block_start+self.look_up_num, len(losses))
			look_up_list = []
			for j in range(block_start, block_stop):
				look_up_list.append(losses[j])
				self.gates_pool.append(self.remove_gate_idx[j])
				self.sb_loss_list[self.remove_gate_idx[j]] = losses[j]
			min_loss = min(look_up_list)
			min_indx = look_up_list.index(min_loss)
			if min_loss < self.sb_loss_tolerance:
				min_indx_total = block_start + min_indx
				self.best_qtn = self.individuals['indv'][min_indx_total].qtn_seq
				self.min_loss = self.individuals['loss'][min_indx_total]
				# -1 sorts before every real rank, so this individual is first_rank_idx.
				self.individuals['rank'][min_indx_total] = -1
				self.gates_pool = []
				break

		if len(self.gates_pool) >= len(self.gate_list):
			self.is_pool_full = True
			logging.info('Gates pool is full!!!')

		ranks = self.individuals['rank']
		self.first_rank_idx = ranks.index(min(ranks))

	def initialize_qtn_sb(self, qtn_layer) -> dict:
		"""Return a dict with one independent deep copy of ``qtn_layer`` per tensor index.

		The number of tensors is read from the module-level ``num_of_tensors``.
		"""
		return {i: copy.deepcopy(qtn_layer) for i in range(num_of_tensors)}

	def generate_population_sb(self):
		"""Build the candidate structures of this generation.

		Every candidate is a deep copy of ``best_qtn`` with the sampled gate(s)
		deleted and adjacent duplicate gates merged afterwards. As side
		effects ``remove_gate_idx`` and ``merge_locs`` are stored on the
		instance; both are index-aligned with the returned list.

		Returns:
			The list of candidate structures.
		"""
		if self.sb_sampling_strategy == 'rank':
			gate_list_with_loss = [(gate, self.sb_loss_list[i]) for i, gate in enumerate(self.gate_list)]
			logging.info('Current gate_list and loss: {}'.format(gate_list_with_loss))

		self.remove_gate_idx = self.generate_remove_gates()

		remove_gates_list = []
		for remove_idx in self.remove_gate_idx:
			if isinstance(remove_idx, np.ndarray):
				# Several gates for one individual: group them per tensor.
				remove_per_indv = [self.gate_list[j] for j in remove_idx]
				remove_gates_list.append(self.merge_lists_by_first_element(remove_per_indv))
			else:
				remove_gates_list.append([self.gate_list[remove_idx]])

		new_population = []
		merge_locs = []
		logging.info("'===== Symmetric breaking process ====='")
		for remove_gates in remove_gates_list:
			parent = copy.deepcopy(self.best_qtn)
			logging.info("====================================")
			# Reset per individual so that an individual without removals can
			# neither hit undefined names nor inherit the previous one's values.
			tensor_idx = None
			merge_gate_locs = []
			for items in remove_gates:
				tensor_idx = items[0]
				gate_idx = items[1:]
				logging.info(f"remove number {gate_idx} from tensor {tensor_idx}")
				parent[tensor_idx] = np.delete(parent[tensor_idx], gate_idx, axis=1)
				parent[tensor_idx], merge_gate_locs = self.merge_duplicate_gates(parent[tensor_idx])
			# Only the merge information of the last processed tensor is kept.
			if len(merge_gate_locs) > 0:
				merge_gate_locs_end = [idx+1 for idx in merge_gate_locs]
				logging.info(f"merge gates {merge_gate_locs} and {merge_gate_locs_end} of tensor {tensor_idx}")
				locs = [tensor_idx, merge_gate_locs]
			else:
				locs = None
			new_population.append(parent)
			merge_locs.append(locs)

		self.merge_locs = merge_locs

		return new_population

	def weighted_sample_no_replacement(self, items, weights, k):
		"""Draw ``k`` distinct entries of ``items`` with probability proportional to ``weights``.

		The inputs are left untouched; the draw works on private copies.

		Raises:
			IndexError: if ``k`` is larger than ``len(items)``.
		"""
		remaining_items = list(items)
		remaining_weights = list(weights)
		result = []
		for _ in range(k):
			# Drawing a position (not a value) keeps items and weights aligned
			# even when items contains equal values.
			index = random.choices(range(len(remaining_items)), weights=remaining_weights, k=1)[0]
			result.append(remaining_items.pop(index))
			remaining_weights.pop(index)
		return result

	def generate_remove_gates(self):
		"""Sample the ``gate_list`` indices to remove, one per individual.

		Indices already in ``gates_pool`` are excluded. At most
		``population_size`` indices are returned.

		Raises:
			ValueError: if ``sb_sampling_strategy`` is neither 'random' nor 'rank'.
		"""
		idx_all = range(len(self.gate_list))
		idx_not_in_pool = [i for i in idx_all if i not in self.gates_pool]
		sample_size = min(self.population_size, len(idx_not_in_pool))

		if self.sb_sampling_strategy == 'random':
			remove_gate_idx = random.sample(idx_not_in_pool, sample_size)
		elif self.sb_sampling_strategy == 'rank':
			ranks = self.get_sorted_positions(self.sb_loss_list)
			# The lower the loss rank, the higher the weight; floored at 0.01.
			fitness = [max(0.01, math.log(200/(1e-2+5*rank))) for rank in ranks]
			fitness_not_in_pool = [fitness[i] for i in idx_not_in_pool]
			remove_gate_idx = self.weighted_sample_no_replacement(idx_not_in_pool, weights=fitness_not_in_pool, k=sample_size)
		else:
			raise ValueError(f"Invalid sampling strategy: {self.sb_sampling_strategy}")

		return remove_gate_idx

	def merge_lists_by_first_element(self, lists):
		"""Group lists by their first element and concatenate the remaining elements.

		Example: ``[[0, 1], [1, 2], [0, 3]]`` becomes ``[[0, 1, 3], [1, 2]]``.
		"""
		merged_dict = defaultdict(list)

		for lst in lists:
			merged_dict[lst[0]].extend(lst[1:])

		return [[key] + value for key, value in merged_dict.items()]

	def update_gate_list(self):
		"""Return ``[tensor_idx, col_idx]`` for every column of every tensor in ``best_qtn``.

		The number of tensors is read from the module-level ``num_of_tensors``.
		"""
		return [
			[tensor_idx, col_idx]
			for tensor_idx in range(num_of_tensors)
			for col_idx in range(self.best_qtn[tensor_idx].shape[1])
		]

	def merge_duplicate_gates(self, seq):
		"""Delete every column that is identical to the column right after it.

		Returns:
			``(seq, locs)``: the reduced array and the list of deleted column
			indices (indices refer to the array before deletion).
		"""
		locs = [i for i in range(seq.shape[1]-1) if np.array_equal(seq[:, i], seq[:, i+1])]
		if len(locs) > 0:
			seq = np.delete(seq, locs, axis=1)
		return seq, locs

	def get_sorted_positions(self, lst):
		"""Return the ascending rank (0 = smallest) of every element of ``lst``; ties keep input order."""
		sorted_lst = sorted(enumerate(lst), key=lambda x: x[1])
		positions = [0] * len(lst)
		for rank, (index, _value) in enumerate(sorted_lst):
			positions[index] = rank
		return positions

	def distribute_indv(self, agent):
		"""Hand the next undistributed individual, if any, to ``agent`` and start waiting for its result."""
		if self.indv_to_distribute:
			indv = self.indv_to_distribute.pop(0)
			agent.receive(indv)
			self.indv_to_collect.append(indv)
			logging.info('Assigned individual {} to agent {}.'.format(indv.scope, agent.sge_job_id))

	def collect_indv(self):
		"""Try to collect every pending individual and drop the ones whose result arrived."""
		# Iterate over a snapshot: removing from the list that is being
		# iterated skips the element following each removed one.
		for indv in list(self.indv_to_collect):
			if indv.collect():
				logging.info('Collected individual result {}.'.format(indv.scope))
				self.indv_to_collect.remove(indv)

	def is_finished(self):
		"""Return True when nothing is left to distribute or to collect."""
		return len(self.indv_to_distribute) == 0 and len(self.indv_to_collect) == 0

class Agent(object):
	"""Handle for one worker, identified by its ``sge_job_id`` and its file ``agent_pool/<sge_job_id>.POOL``."""

	def __init__(self, **kwargs):
		"""Store the keyword arguments; ``sge_job_id`` is required (KeyError otherwise)."""
		super(Agent, self).__init__()
		self.kwargs = kwargs
		self.sge_job_id = self.kwargs['sge_job_id']

	def receive(self, indv):
		"""Deploy ``indv`` under this agent's job id, then append the task name to the agent's pool file.

		The task name is read from the module-level ``task_name``.
		"""
		indv.deploy(self.sge_job_id)
		with open(base_folder+'/agent_pool/{}.POOL'.format(self.sge_job_id), 'a') as f:
			f.write(task_name)

	def is_available(self):
		"""Return True when the agent's pool file exists and is empty.

		A missing pool file yields False instead of an exception: the file can
		disappear between the directory listing and this check.
		"""
		pool_file = base_folder+'/agent_pool/{}.POOL'.format(self.kwargs['sge_job_id'])
		try:
			return os.stat(pool_file).st_size == 0
		except FileNotFoundError:
			return False


class PipelineSB(object):
	"""Main loop of the symmetric-breaking search.

	Creates one ``GenerationQTN`` after another, keeps track of the agents
	found in the agent pool folder, hands individuals to available agents and
	collects their results until the iteration limit is exceeded or a
	generation reports a full gates pool.
	"""

	def __init__(self, gss_max_iteration=100, sb_max_iteration=10, **kwargs):
		"""Store the limits and settings.

		``kwargs`` must contain ``sb_init_qtn`` and ``sb_init_loss``; the whole
		dict is later forwarded to every generation.
		"""
		super(PipelineSB, self).__init__()
		self.kwargs = kwargs
		# No-op stand-in returned when a periodic action is not due.
		self.dummy_func = lambda *args, **kwargs: None

		# iteration limits and counters
		self.max_gss_iteration = gss_max_iteration
		self.max_symbreaking_iteration = sb_max_iteration
		self.N_group_ss = 0
		self.N_symbreaking = 0
		self.time = 0

		# generations
		self.current_generation = None
		self.previous_generation = None
		self.generation_list = {}

		# agents
		self.available_agents = []
		self.known_agents = {}

		# per-iteration history
		self.best_qtn_list = []
		self.min_loss_list = []
		self.evaluation_count_list = []

		# starting point
		self.initial_qtn = kwargs['sb_init_qtn']
		self.initial_loss = kwargs['sb_init_loss']

	def __call_with_interval__(self, func, interval):
		"""Return ``func`` when the internal clock is a multiple of ``interval``, else the no-op."""
		if self.time%interval == 0:
			return func
		return self.dummy_func

	def __tik__(self, sec):
		"""Advance the internal clock by ``sec`` and sleep for that long."""
		self.time += sec
		time.sleep(sec)

	def __check_available_agent__(self):
		"""Refresh ``known_agents`` and ``available_agents`` from the pool files on disk.

		A newly discovered agent is only registered here; it can be listed as
		available from the next refresh on.
		"""
		self.available_agents.clear()
		pool_files = glob.glob(base_folder+'/agent_pool/*.POOL')
		# file name without the 5-character '.POOL' suffix
		agents_id = [pool_file.split('/')[-1][:-5] for pool_file in pool_files]

		dead_ids = [known_id for known_id in self.known_agents if known_id not in agents_id]
		for dead_id in dead_ids:
			logging.info('Dead agent id = {} found!'.format(dead_id))
			self.known_agents.pop(dead_id, None)

		for agent_id in agents_id:
			if agent_id not in self.known_agents:
				self.known_agents[agent_id] = Agent(sge_job_id=agent_id)
				logging.info('New agent id = {} found!'.format(agent_id))
				continue
			agent = self.known_agents[agent_id]
			if agent.is_available():
				self.available_agents.append(agent)

	def __assign_job__(self):
		"""Refresh the agent lists and offer one individual to every available agent."""
		self.__check_available_agent__()
		for agent in self.available_agents:
			self.current_generation.distribute_indv(agent)

	def __collect_result__(self):
		"""Let the current generation collect whatever results have arrived."""
		self.current_generation.collect_indv()

	def __report_agents__(self):
		"""Log how many agents are known and their ids."""
		known_ids = list(self.known_agents.keys())
		logging.info('Current number of known agents is {}.'.format(len(self.known_agents)))
		logging.info(known_ids)

	def __report_generation__(self):
		"""Log the sizes of the current generation's queues and the pending (scope, job id) pairs."""
		generation = self.current_generation
		logging.info('Current length of indv_to_distribute is {}.'.format(len(generation.indv_to_distribute)))
		logging.info('Current length of indv_to_collect is {}.'.format(len(generation.indv_to_collect)))
		pending = [(indv.scope, indv.sge_job_id) for indv in generation.indv_to_collect]
		logging.info(pending)

	def __chromosome_to_tuple__(self, chromosome: np.ndarray) -> tuple:
		"""Turn an array into a hashable tuple holding one tuple per column."""
		return tuple(map(tuple, chromosome.T))

	def __tuple_to_chromosome__(self, chromosome_tuple: tuple) -> np.ndarray:
		"""Inverse of ``__chromosome_to_tuple__``: the inner tuples become the columns again."""
		return np.transpose(np.array(chromosome_tuple))

	def __update_all_chromosomes__(self, population: List[np.ndarray], losses: List[float]):
		"""Remember the lowest loss seen for every chromosome and keep the record sorted by loss."""
		for chrom, loss in zip(population, losses):
			chrom_key = self.__chromosome_to_tuple__(chrom)
			if chrom_key in self.generation_list and not loss < self.generation_list[chrom_key]:
				continue
			self.generation_list[chrom_key] = loss

		# Rebuild as an OrderedDict sorted by ascending loss.
		ordered_items = sorted(self.generation_list.items(), key=lambda item: item[1])
		self.generation_list = OrderedDict(ordered_items)

	def __symbreaking__(self):
		"""Advance the search by one step.

		Returns:
			False once the iteration counter exceeds the limit, True otherwise.
		"""
		if self.N_symbreaking > self.max_symbreaking_iteration:
			return False

		if self.N_symbreaking == 0:
			# very first step: create the initial generation
			self.current_generation = GenerationQTN(name='symbreaking_init', mode='symmetric_breaking', **self.kwargs)
			self.N_symbreaking += 1

		if not self.current_generation.is_finished():
			# still waiting for individuals to be distributed or collected
			return True

		if self.N_symbreaking > 0:
			# evaluate the finished generation and record its outcome
			self.current_generation(**self.kwargs)
			self.best_qtn_list.append(self.current_generation.best_qtn)
			self.min_loss_list.append(self.current_generation.min_loss)
			self.evaluation_count_list.append(self.current_generation.total_evaluation_count)
			logging.info('\033[31m===== Best QTN_SEQ of No.{} Symbreaking =====\033[0m'.format(self.N_symbreaking))
			logging.info('\033[31mQtn: {} | Loss: {:.5f}\033[0m'.format(array_to_gate(self.best_qtn_list[-1]), self.min_loss_list[-1]))
			logging.info('\033[31m===== Total number of Evaluations: {}=====\033[0m'.format(self.current_generation.total_evaluation_count))

		self.N_symbreaking += 1
		self.previous_generation = self.current_generation
		if not self.current_generation.is_pool_full:
			next_name = 'symbreaking_{:03d}'.format(self.N_symbreaking)
			self.current_generation = GenerationQTN(self.previous_generation, name=next_name, mode='symmetric_breaking', **self.kwargs)

		return True

	def __call__(self):
		"""Run the search loop to completion and log the best structure of every iteration."""
		logging.info('\033[32m===== Load from GSS with QTN {} and loss {} =====\033[0m'.format(array_to_gate(self.initial_qtn), self.initial_loss))

		# (action, interval on the internal clock), executed in this order
		schedule = (
			(self.__check_available_agent__, 4),
			(self.__assign_job__, 4),
			(self.__collect_result__, 4),
			(self.__report_agents__, 180),
			(self.__report_generation__, 160),
		)
		while self.__symbreaking__():
			for action, interval in schedule:
				self.__call_with_interval__(action, interval)()
			self.__tik__(2)
			if self.current_generation.is_pool_full:
				break

		logging.info('\033[32m===== SB Finished =====\033[0m')
		logging.info('===== Best QTN and loss for each iteration=====')
		for best_qtn, min_loss in zip(self.best_qtn_list, self.min_loss_list):
			logging.info('Qtn: {} | Loss: {:.5f}'.format(array_to_gate(best_qtn), min_loss))

def array_to_gate(qtn_seq: 'np.ndarray | dict') -> 'List[tuple] | dict':
	"""Convert gate matrices into lists of plain-int tuples, one tuple per column.

	Args:
		qtn_seq: either a single 2-D array, or a dict whose values are 2-D arrays.

	Returns:
		For an array: a list with one tuple of Python ints per column.
		For a dict: a dict with the same keys whose values are such lists.

	Raises:
		TypeError: if ``qtn_seq`` is neither an array nor a dict.
	"""
	def columns_as_int_tuples(matrix):
		# int() turns numpy scalars into plain ints so the result prints cleanly.
		return [tuple(map(int, matrix[:, i])) for i in range(matrix.shape[1])]

	if isinstance(qtn_seq, np.ndarray):
		return columns_as_int_tuples(qtn_seq)
	if isinstance(qtn_seq, dict):
		return {key: columns_as_int_tuples(matrix) for key, matrix in qtn_seq.items()}
	# Previously an unsupported type fell through to an UnboundLocalError.
	raise TypeError('array_to_gate expects a numpy array or a dict of numpy arrays, got {}'.format(type(qtn_seq).__name__))

def gate_to_array(qtn_layer: List[tuple]) -> np.ndarray:
	"""Build an array whose columns are the given gate tuples (one gate per column)."""
	gates_as_rows = list(map(list, qtn_layer))
	return np.transpose(np.array(gates_as_rows))

def score_summary(obj):
	"""Log one line per individual of a generation, grouped into look-up blocks.

	Args:
		obj: the generation to report. Read from it: ``name``,
			``look_up_num``, ``individuals['indv']`` and ``individuals['rank']``.
			Every individual must provide ``scope``, ``dev_loss``,
			``test_loss`` and ``qtn_seq``.

	The block size is taken from ``obj.look_up_num`` rather than from a
	module-level variable, so the function also works when this file is
	imported instead of being run as a script. The individual with rank -1 is
	printed in red.
	"""
	logging.info('\033[32m===== {} =====\033[0m'.format(obj.name))

	look_up_num = obj.look_up_num

	for idx, indv in enumerate(obj.individuals['indv']):
		if idx%look_up_num == 0:
			logging.info('-------------------------------------Block {}-------------------------------------------------'.format(int(idx//look_up_num)))
		rank = obj.individuals['rank'][idx]
		line = '{} | dev_loss: {:.5f} | test_loss: {:.5f} | rank: {} | qtn_seq: {}'.format(indv.scope, indv.dev_loss, indv.test_loss, rank, array_to_gate(indv.qtn_seq))
		if -1 == rank:
			line = '\033[31m' + line + '\033[0m'
		logging.info(line)

if __name__ == '__main__':

	# Command line: <task_name> <gss_mode> <sb_mode>
	#   gss_mode: '_'-separated; the second field must be an integer.
	#   sb_mode:  '_'-separated; fields 1..5 are
	#             strategy, look_up_num, population_size, max_iteration, loss_tolerance.
	if len(sys.argv) < 4:
		sys.exit('Usage: {} <task_name> <gss_mode> <sb_mode>'.format(sys.argv[0]))

	task_name = sys.argv[1]
	gss_mode = sys.argv[2]
	sb_mode = sys.argv[3]

	# load data and information
	num_of_tensors = 32

	gss_mode_sep = gss_mode.split('_')
	qtn_L = int(gss_mode_sep[1])

	# load gss data
	gss_path = 'gss_results/quanta_{}_{}.pt'.format(task_name, gss_mode)
	gss_data = torch.load(gss_path)
	gss_qtn = gss_data['gss_qtn']
	gss_loss = gss_data['gss_loss']

	current_time = strftime("%Y%m%d_%H%M%S", gmtime())

	# Everything (DEBUG and above) goes to the log file, INFO and above also to the console.
	log_name = 'center_log/quanta_{}_{}_{}_{}.log'.format(task_name, gss_mode, sb_mode, current_time)
	logging.basicConfig(filename=log_name, filemode='a', level=logging.DEBUG,
											format='%(asctime)s: %(message)s', datefmt='%H:%M:%S')
	console = logging.StreamHandler()
	console.setLevel(logging.INFO)
	formatter = logging.Formatter('%(asctime)s:  %(message)s')
	console.setFormatter(formatter)
	logging.getLogger('').addHandler(console)

	sb_mode_sep = sb_mode.split('_')
	if len(sb_mode_sep) < 6:
		raise ValueError('sb_mode must have at least 6 "_"-separated fields, got {!r}'.format(sb_mode))
	sb_strategy = sb_mode_sep[1]
	sb_look_up_num = int(sb_mode_sep[2])
	sb_population_size = int(sb_mode_sep[3])
	sb_max_iteration = int(sb_mode_sep[4])
	sb_loss_tolerance = float(sb_mode_sep[5])

	if sb_population_size%sb_look_up_num != 0:
		raise ValueError('sb_population_size must be divisible by sb_look_up_num')

	# Log the path that was actually loaded (it includes the task name).
	logging.info('Load gss data from {}'.format(gss_path))
	logging.info(f'SB paramsettings: sb_strategy: {sb_strategy}, sb_look_up_num: {sb_look_up_num}, sb_population_size: {sb_population_size}, sb_max_iteration: {sb_max_iteration}, sb_loss_tolerance: {sb_loss_tolerance}')

	# Make sure the output folder exists before the (long) search starts, so
	# that saving the results at the very end cannot fail on a missing folder.
	os.makedirs('sb_results', exist_ok=True)

	sb_pipeline = PipelineSB(
		# SB params
		sb_init_qtn=gss_qtn, sb_init_loss=gss_loss, sb_max_iteration = sb_max_iteration, sb_population_size= sb_population_size, sb_loss_tolerance = sb_loss_tolerance, sb_sampling_strategy = sb_strategy, sb_look_up_num=sb_look_up_num,
		callbacks=[score_summary]
		)
	sb_pipeline()

	torch.save({
		'sb_best_qtn_list': sb_pipeline.best_qtn_list,
		'sb_min_loss_list': sb_pipeline.min_loss_list,
		'sb_evaluation_count_list': sb_pipeline.evaluation_count_list,
		'total_evaluation_count': sb_pipeline.current_generation.total_evaluation_count
	}
	, 'sb_results/quanta_{}_{}_{}.pt'.format(task_name, gss_mode, sb_mode)
	)