import os, sys  
# Limit MKL, NumExpr and OpenMP to a single thread each. These variables are
# set before numpy is imported below — presumably because the numeric
# back-ends read them when they are first loaded (confirm before moving).
os.environ["MKL_NUM_THREADS"] = "1" 
os.environ["NUMEXPR_NUM_THREADS"] = "1" 
os.environ["OMP_NUM_THREADS"] = "1"
import numpy as np
import argparse
import shutil

from config import *
from utilis import *
from types import SimpleNamespace
from datetime import datetime
from algs import REGISTRY as alg_REGISTRY
from envs import REGISTRY as env_REGISTRY
from torch.utils.tensorboard import SummaryWriter
from os.path import dirname, abspath
# Command-line interface. Every option is a string with a default, so the
# script can be launched without any arguments.
parser = argparse.ArgumentParser()
# Key into alg_REGISTRY; also used to locate algs/<alg>.py and models/<alg>.py.
parser.add_argument(
	"--alg",
	type=str,
	default="DC2Net",
	help="The algorithm used in this experiment",
)
# First-level key into env_REGISTRY.
parser.add_argument(
	"--env",
	type=str,
	default="mpe",
	help="The enviroment used in this experiment",
)
# Second-level key into env_REGISTRY[<env>].
parser.add_argument(
	"--map",
	type=str,
	default="ss",
	help="The map tested in this experiment",
)
# Free-form note that is written to description.txt in the results directory.
parser.add_argument(
	"--info",
	type=str,
	default="default setting",
	help="detailed settings in this experiment",
)
args = parser.parse_args()

# Build the parallel environment selected on the command line and reset it
# once before reading its space/agent metadata.
env = env_REGISTRY[args.env][args.map].parallel_env()
env.reset()

# The first agent's spaces are taken as representative for all agents
# (assumes homogeneous agents — TODO confirm for new maps).
_n_agents = env.num_agents
_first_action_space = list(env.action_spaces.values())[0]
_first_obs_space = list(env.observation_spaces.values())[0]
env_args = {
	'n_ant': _n_agents,
	'n_actions': _first_action_space.n,
	# Observation length is widened by n_ant; presumably get_obs appends a
	# per-agent identifier of that size — verify against utilis.get_obs.
	'obs_space': _first_obs_space.shape[0] + _n_agents,
	'state_shape': env.state_space.shape[0],
}

# Source files to snapshot for this run, keyed by the name they receive in
# the backup folder. Paths are relative to this script's directory.
files = {
	'alg': "algs/{}.py".format(args.alg),
	'model': "models/{}.py".format(args.alg),
	'config': "config.py",
	'main': "main.py",
}

# Turn the settings dict into an attribute-style namespace and attach the
# file map plus a second-resolution timestamp that names the run directory.
env_args = SimpleNamespace(
	**env_args,
	files=files,
	time_token=datetime.now().strftime('%Y-%m-%d-%H-%M-%S'),
)
print(env_args)

# Instantiate the algorithm chosen with --alg, configured by env_args.
alg = alg_REGISTRY[args.alg](env_args)

# Per-run output directory: results/<env>/<map>/<alg>/<time_token>.
file_path = os.path.join('results', args.env, args.map, args.alg, env_args.time_token)
os.makedirs(file_path, exist_ok=True)

# Store the free-form --info text next to the results. The context manager
# closes the handle even if the write raises.
with open(os.path.join(file_path, 'description.txt'), 'w') as f_info:
	f_info.write(args.info)

# CSV log for this run. The header row is written and flushed immediately;
# the handle stays open because it is passed to log_data during training.
_log_csv_path = os.path.join(file_path, 'log.csv')
f = open(_log_csv_path, 'w')
f.write('timestep,reward' + '\n')
f.flush()

# TensorBoard writer that stores its event files in the same run directory.
writer = SummaryWriter(log_dir=file_path)

# Snapshot the source files listed in env_args.files into
# <file_path>/code_backup so the run can be reproduced later.
# exist_ok=True matches how the run directory itself is created, so a reused
# run directory does not abort the script here.
_backup_dir = os.path.join(file_path, "code_backup")
os.makedirs(_backup_dir, exist_ok=True)

# Entries in env_args.files are relative to the directory of this script.
_source_root = dirname(abspath(__file__))
for k, v in env_args.files.items():
	_source_file = os.path.join(_source_root, v)
	print(_source_file)
	# Each copy is stored as <key>.py regardless of its original file name.
	shutil.copyfile(_source_file, os.path.join(_backup_dir, "{}.py".format(k)))

# Main training loop.
#
# The counters and hyper-parameters used here (i_episode, n_episode,
# time_step, max_timestep, epsilon, test_flag, n_epoch) are not defined in
# this file; presumably they come from the star import of config.
#
# Per episode: roll out one episode while feeding every transition to
# alg.addbuff, then (from episode 100 onwards) run n_epoch training calls and,
# whenever more than 1000 environment steps have passed since the last
# evaluation, evaluate and log.
#
# The try/finally guarantees that the CSV log and the TensorBoard writer are
# closed (and therefore flushed) when training finishes, fails, or is
# interrupted.
try:
	while i_episode < n_episode:
		if time_step > max_timestep:
			break
		# Exploration schedule: after episode 100 epsilon drops by 0.001 per
		# episode and is floored at 0.02.
		if i_episode > 100:
			epsilon = max(epsilon - 0.001, 0.02)
		i_episode += 1

		obs = get_obs(list(env.reset().values()), env_args.n_ant)
		terminated = False
		# Identity adjacency and an all-ones action mask are used throughout.
		adj = np.eye(env_args.n_ant)
		mask = np.ones([env_args.n_ant, env_args.n_actions])
		ep_reward = 0
		agent_list = env.agents
		while not terminated:
			test_flag += 1
			time_step += 1
			# MPE has no action mask
			action = alg.forward(obs, adj, mask, epsilon)
			action_dict = dict(zip(agent_list, action))
			next_obs, reward, dones, infos = env.step(action_dict)
			# Per-agent rewards are summed into a single team reward; the
			# episode ends only when every agent reports done.
			reward = sum(reward.values())
			terminated = all(dones.values())
			ep_reward += reward
			next_obs = get_obs(list(next_obs.values()), env_args.n_ant)
			next_adj = np.eye(env_args.n_ant)
			next_mask = np.ones([env_args.n_ant, env_args.n_actions])
			alg.addbuff(np.array(obs),action,reward,np.array(next_obs),adj,next_adj,next_mask,terminated)

			obs = next_obs

		# Warm-up: the first 99 episodes only collect transitions.
		if i_episode < 100:
			continue

		for epoch in range(n_epoch):
			loss = alg.train()

		# test_flag counts environment steps since the last evaluation.
		if test_flag > 1000:
			log_r = alg.test(env)
			print(log_r)
			log_data(f, alg.model, loss, log_r, ep_reward, time_step, writer)
			test_flag = 0
			# result_plot(file_path)
finally:
	# Release the log handles so buffered rows and events reach disk.
	f.close()
	writer.close()
		