import numpy as np, os, sys, re, glob, subprocess, math, unittest, time, shutil, logging, gc, psutil
np.set_printoptions(precision=2)
from quan_decomp import Quantum_Decomp
import torch
from torch import optim, nn
from random import shuffle, choice
from itertools import product
from functools import partial
# from overlore import Generation
from pathlib import Path
import copy

# Identifier of this agent, taken from the first command-line argument. It
# names the agent's log file, pool file and job file.
agent_id = sys.argv[1]
base_folder = './'

# Create every working directory independently. A single try block around all
# the mkdir calls would stop at the first directory that already exists and
# leave the remaining ones missing.
for sub_dir in ('center_log', 'agent_log', 'agent_pool', 'job_pool', 'result_pool'):
    os.makedirs(base_folder+sub_dir, exist_ok=True)

# Root logger: DEBUG and above is appended to the per-agent log file.
log_name = base_folder + 'agent_log/{}.log'.format(agent_id)
logging.basicConfig(filename=log_name, filemode='a', level=logging.DEBUG,
                    format='%(asctime)s: %(message)s', datefmt='%H:%M:%S')

# Additional handler: INFO and above is also echoed to the console.
console = logging.StreamHandler()
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:  %(message)s')
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)

def evaluate(tensors, qtn_layers, qtn_Q, qtn_R, epochs=10000, std=1, lr=1e-2, repeat_time=1, loss_measure='total'):
    """Fit a ``Quantum_Decomp`` model to ``tensors`` and return its relative error.

    The model is trained from scratch ``repeat_time`` times with Adam on the
    sum of per-tensor MSE losses. For every tensor the smallest squared error
    over all repeats is kept and reduced according to ``loss_measure``.

    Args:
        tensors: dict mapping an index (convertible with ``int``) to a torch
            tensor. The model is called as ``model(int(idx))`` for each entry.
        qtn_layers: layer description passed to ``Quantum_Decomp``. A dict is
            deep-copied as is; anything else is deep-copied once per tensor
            under the keys ``0 .. len(tensors) - 1``.
        qtn_Q, qtn_R: forwarded to ``Quantum_Decomp`` as ``Q`` and ``R``.
        epochs: number of optimisation steps per repeat.
        std: forwarded to ``Quantum_Decomp`` as ``std``.
        lr: Adam learning rate.
        repeat_time: number of independent training runs.
        loss_measure: ``'total'`` (root of the summed squared error over the
            norm of all tensors), ``'max'`` or ``'mean'`` (maximum / average of
            the per-tensor relative errors).

    Returns:
        A one-element list holding the selected error. For ``'total'`` the
        element is a float; for ``'max'`` and ``'mean'`` it is a 0-dim torch
        tensor, because it is divided by a tensor-valued norm.

    Raises:
        ValueError: if ``loss_measure`` is not one of the supported names.
            Checked before training so a bad setting fails immediately rather
            than yielding an empty result after the whole run.
    """
    if loss_measure not in ('total', 'max', 'mean'):
        raise ValueError(
            "loss_measure must be 'total', 'max' or 'mean', got {!r}".format(loss_measure))

    if isinstance(qtn_layers, dict):
        qtn_layers_all = copy.deepcopy(qtn_layers)
    else:
        qtn_layers_all = {i: copy.deepcopy(qtn_layers) for i in range(len(tensors))}

    # Per-tensor Euclidean norms (in iteration order of ``tensors``) and the
    # norm of all tensors taken together.
    norm_list = [torch.linalg.vector_norm(tensor.reshape(-1)) for tensor in tensors.values()]
    total_norms = math.sqrt(sum(norm ** 2 for norm in norm_list))

    # Move the targets to the GPU once instead of once per tensor per epoch.
    device_tensors = {idx: tensor.cuda() for idx, tensor in tensors.items()}

    loss_fn = nn.MSELoss()
    loss_fn.cuda()

    # se_list[r][k]: squared error of the k-th tensor after the r-th repeat.
    se_list = []

    for _ in range(repeat_time):

        start_time = time.time()

        model = Quantum_Decomp(Q=qtn_Q, R=qtn_R, qtn_layers=qtn_layers_all, std=std, tensors=tensors)
        model.cuda()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        for _epoch in range(epochs):
            optimizer.zero_grad()
            total_loss = 0

            for idx, target in device_tensors.items():
                total_loss += loss_fn(model(int(idx)), target)

            total_loss.backward()
            optimizer.step()

        end_time = time.time()

        # Compute the output for all tensors. MSELoss averages over elements,
        # so multiplying by numel() recovers the summed squared error. No
        # gradients are needed for this scoring pass.
        with torch.no_grad():
            se_list_t = [
                loss_fn(model(int(idx)), target).item() * target.numel()
                for idx, target in device_tensors.items()
            ]

        se_list.append(se_list_t)

        # The epoch counter always equals ``epochs`` here; using ``epochs``
        # directly keeps the message valid even when no epoch was run.
        logging.info(f'Epoch [{epochs}/{epochs}], Se_list: {se_list_t}, time: {end_time - start_time:.4f}')

    # Best (smallest) squared error per tensor across the repeats.
    min_se = np.min(np.array(se_list), axis=0)

    if loss_measure == 'total':
        return [math.sqrt(sum(min_se))/total_norms]

    min_rse = [math.sqrt(se)/norm for se, norm in zip(min_se, norm_list)]
    if loss_measure == 'max':
        return [max(min_rse)]
    return [sum(min_rse)/len(min_rse)]

def check_and_load(agent_id):
    """Check the agent's pool file for a job and load its target tensors.

    The pool file ``<base_folder>/agent_pool/<agent_id>.POOL`` is expected to
    hold, on its first line, the path of a file readable by ``torch.load``
    whose content is a mapping with a ``'data'`` entry.

    Args:
        agent_id: identifier used to build the pool file name.

    Returns:
        ``(False, False)`` when the pool file is empty or its first line is
        blank, otherwise ``(True, data['data'])``.

    Raises:
        OSError: if the pool file does not exist or cannot be read.
    """
    file_name = base_folder+'/agent_pool/{}.POOL'.format(agent_id)
    if os.stat(file_name).st_size == 0:
        return False, False

    with open(file_name, 'r') as f:
        # Strip the line terminator; otherwise the path handed to torch.load
        # would end in '\n' and not match the real file.
        goal_name = f.readline().strip()

    if not goal_name:
        return False, False

    # NOTE(security): torch.load may unpickle arbitrary objects. Only point
    # the pool file at data files from a trusted source.
    data = torch.load(goal_name)
    return True, data['data']

def memory():
    """Print the memory usage of the current process.

    Takes the first field of psutil's ``memory_info()`` for this process and
    divides it by 2**30 before printing (presumably bytes of resident memory
    converted to GiB; confirm against the psutil documentation).
    """
    this_process = psutil.Process(os.getpid())
    first_field = this_process.memory_info()[0]
    print('memory use:', first_field/2.**30)

def build_qtn_layers(qtn_seq):
    """Convert a stored ``qtn_seq`` array into the layer structure for ``evaluate``.

    Args:
        qtn_seq: numpy array as read from the job file. Either a 2-D array, or
            a 0-d object array wrapping a dict whose values are 2-D arrays.

    Returns:
        For a 2-D array, a list with one tuple per column. For a wrapped dict,
        a dict with the same keys whose values are such lists.

    Raises:
        TypeError: if ``qtn_seq`` unpacks to neither a list nor a dict. Without
            this check the layers of a previous job could be reused silently.
    """
    unpacked = qtn_seq.tolist()
    if isinstance(unpacked, list):
        return [tuple(qtn_seq[:, i]) for i in range(qtn_seq.shape[1])]
    if isinstance(unpacked, dict):
        return {
            k: [tuple(v[:, i]) for i in range(v.shape[1])]
            for k, v in unpacked.items()
        }
    raise TypeError(
        'qtn_seq must unpack to a list or a dict, got {}'.format(type(unpacked).__name__))


def load_job(job_path):
    """Read all job settings from the ``.npz`` file at ``job_path``.

    The archive is closed before returning, so every entry is read eagerly.

    Returns:
        A dict with the keys ``scope``, ``qtn_seq``, ``repeat``, ``iters``,
        ``qtn_Q``, ``qtn_R``, ``init_std`` and ``loss_measure``.
    """
    # NOTE(security): allow_pickle=True lets the archive execute arbitrary
    # code on load. Only read job files written by a trusted process.
    with np.load(job_path, allow_pickle=True) as indv:
        return {
            'scope': indv['scope'].tolist(),
            'qtn_seq': indv['qtn_seq'],
            'repeat': int(indv['repeat']),
            'iters': int(indv['iters']),
            'qtn_Q': int(indv['qtn_Q']),
            'qtn_R': int(indv['qtn_R']),
            'init_std': float(indv['init_std']),
            'loss_measure': indv['loss_measure'].tolist(),
        }


if __name__ == '__main__':

    pool_path = base_folder+'/agent_pool/{}.POOL'.format(agent_id)
    job_path = base_folder+'/job_pool/{}.npz'.format(agent_id)

    # Make sure the pool file exists; check_and_load treats an empty pool
    # file as "no job yet".
    Path(pool_path).touch()

    while True:

        flag, target_tensors = check_and_load(agent_id)

        if flag:

            try:
                job = load_job(job_path)
                qtn_layer = build_qtn_layers(job['qtn_seq'])

                logging.info('Receiving individual {} with repeat times {} , iterations {} and initial std {}...'.format(
                    job['scope'], job['repeat'], job['iters'], job['init_std']))

                repeat_loss = evaluate(tensors=target_tensors, qtn_layers=qtn_layer,
                                       qtn_Q=job['qtn_Q'], qtn_R=job['qtn_R'],
                                       epochs=job['iters'], repeat_time=job['repeat'],
                                       std=job['init_std'], loss_measure=job['loss_measure'])
                logging.info('Reporting result {}.'.format(repeat_loss))
                np.savez(base_folder+'/result_pool/{}.npz'.format(job['scope'].replace('/', '_')),
                         repeat_loss=[float('{:0.4f}'.format(l)) for l in repeat_loss],
                         qtn_seq=job['qtn_seq'])

                # Job done: drop the job file and empty the pool file again.
                os.remove(job_path)
                open(pool_path, 'w').close()

            except Exception:
                # Record the failure in the agent log, remove the pool file
                # and let the original error terminate the agent.
                logging.exception('Agent {} failed while processing its job.'.format(agent_id))
                try:
                    os.remove(pool_path)
                except FileNotFoundError:
                    # Already gone; do not mask the original error.
                    pass
                raise

            del repeat_loss
            gc.collect()

        time.sleep(1)



