import datetime
import os
import pprint
import time
import threading
import torch as th
from types import SimpleNamespace as SN
from utils.logging import Logger
from utils.timehelper import time_left, time_str
from os.path import dirname, abspath

from learners import REGISTRY as le_REGISTRY
from runners import REGISTRY as r_REGISTRY
from controllers import REGISTRY as mac_REGISTRY
from components.episode_buffer import ReplayBuffer
from components.transforms import OneHot

import shutil

import random

def save_one_buffer(args, save_buffer, env_name, from_start=False):
    """Save ``save_buffer`` to ``../buffer/<env_name>/buffer_<args.save_buffer_id>/``.

    With ``from_start=True`` the environment directory gets a ``_from_start/``
    suffix. If the target directory already exists, it is first moved aside to
    ``buffer_<n>`` (``n`` a random int in ``[10, 1000]`` that is not already
    used), so earlier contents are kept rather than overwritten.

    :param args: namespace providing ``save_buffer_id``.
    :param save_buffer: object with a ``save(path)`` method; it receives the
        directory path as a string ending in ``/``.
    :param env_name: environment directory name (may contain ``/``, e.g. ``sc2/<map>``).
    :param from_start: select the ``_from_start`` directory variant.
    """
    x_env_name = env_name
    if from_start:
        x_env_name += '_from_start/'
    prefix = '../buffer/' + x_env_name + '/buffer_'
    path_name = prefix + str(args.save_buffer_id) + '/'
    if os.path.exists(path_name):
        # Choose a backup name that is free: renaming onto an existing
        # non-empty directory raises OSError, and picking path_name itself
        # would silently overwrite the buffer we meant to keep.
        while True:
            random_name = prefix + str(random.randint(10, 1000)) + '/'
            if not os.path.exists(random_name):
                break
        os.rename(path_name, random_name)
    os.makedirs(path_name, exist_ok=True)
    save_buffer.save(path_name)

def run(_run, _config, _log):
	"""Top-level entry for one experiment run.

	Sanitises the config with ``args_sanity_check``, converts it to a
	``SimpleNamespace`` and sets ``args.device``. It then logs the parameters
	and sets up the tensorboard (optional) and Sacred loggers. Training runs
	through ``run_sequential``. Afterwards every non-main thread is joined
	with a 1 s timeout, and the process is hard-exited with
	``os._exit(os.EX_OK)``.

	:param _run: Sacred run object, passed to ``logger.setup_sacred``.
	:param _config: experiment configuration dict.
	:param _log: logger used for console output.
	"""
	# check args sanity
	_config = args_sanity_check(_config, _log)

	args = SN(**_config)
	args.device = "cuda" if args.use_cuda else "cpu"

	# setup loggers
	logger = Logger(_log)

	_log.info("Experiment Parameters:")
	experiment_params = pprint.pformat(_config,
	                                   indent=4,
	                                   width=1)
	_log.info("\n\n" + experiment_params + "\n")

	# configure tensorboard logger
	# The run token ends with the map name when env_args has one, else the env name.
	token_suffix = args.env_args['map_name'] if 'map_name' in args.env_args else args.env
	unique_token = "{}__{}__{}".format(
		args.name,
		datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
		token_suffix
	)
	args.unique_token = unique_token
	if args.use_tensorboard:
		tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs")
		# Join the token directly; str.format-ing the joined path would choke on
		# '{' or '}' appearing anywhere in the checkout path.
		tb_exp_direc = os.path.join(tb_logs_direc, unique_token)
		logger.setup_tb(tb_exp_direc)

	# sacred is on by default
	logger.setup_sacred(_run)

	# Run and train
	run_sequential(args=args, logger=logger)

	# Clean up after finishing
	print("Exiting Main")

	print("Stopping all threads")
	for t in threading.enumerate():
		if t.name != "MainThread":
			print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
			t.join(timeout=1)
			print("Thread joined")

	print("Exiting script")

	# Making sure framework really exits
	os._exit(os.EX_OK)


def evaluate_sequential(args, runner):
	"""Run test-mode episodes with ``runner``, then optionally save a replay and close the env.

	Mode selection from ``args``:

	* ``test_is_cut`` false: ``test_nepisode`` plain test episodes; the summed
	  ``'reward'`` of each returned batch is printed.
	* ``test_is_cut`` and ``test_is_cut_prob`` true: one sweep at threshold
	  ``0.``, then one sweep per value in ``test_cut_prob_list`` at threshold
	  ``test_cut_prob_thres``.
	* ``test_is_cut`` true, ``test_is_cut_prob`` false: one sweep per
	  threshold in ``test_cut_list``.

	A sweep is ``test_nepisode`` episodes in which only the first is run with
	``is_clean=True``.
	"""
	def sweep(**extra_kwargs):
		# First episode of every sweep is flagged clean; the rest are not.
		for ep_idx in range(args.test_nepisode):
			runner.run(test_mode=True, is_clean=(ep_idx == 0), **extra_kwargs)

	if not args.test_is_cut:
		for _unused in range(args.test_nepisode):
			test_batch = runner.run(test_mode=True)
			print(test_batch['reward'].sum())
	elif args.test_is_cut_prob:
		print('mu thres:', 0.)
		sweep(thres=0.)
		cut_thres = args.test_cut_prob_thres
		for cut_prob in args.test_cut_prob_list:
			print('mu+prob thres:', cut_thres, cut_prob)
			sweep(thres=cut_thres, prob=cut_prob)
	else:
		for cut_thres in args.test_cut_list:
			print('mu thres:', cut_thres)
			sweep(thres=cut_thres)

	if args.save_replay:
		runner.save_replay()

	runner.close_env()


def run_sequential(args, logger):
	"""Set up runner, buffers, controller and learner, then run the training loop.

	Steps, in order:

	1. Build the runner from ``r_REGISTRY[args.runner]``. Copy the env
	   dimensions it reports (``n_agents``, ``n_actions``, ``state_shape``,
	   ``obs_shape``) onto ``args``.
	2. Build the training ``buffer``. When ``args.is_save_buffer`` is set,
	   also build a CPU ``save_buffer`` that is written to disk with
	   ``save_one_buffer``. When ``args.is_batch_rl`` is set, the training
	   buffer is loaded from
	   ``../buffer/<env>[_from_start/]/buffer_<load_buffer_id>/``.
	3. Build the controller (``mac_REGISTRY``) and the learner
	   (``le_REGISTRY``). Move the learner to CUDA when ``args.use_cuda`` is set.
	4. If ``args.checkpoint_path`` is non-empty, there are two modes:
	   * Message-distribution video mode: load every numeric timestep
	     sub-directory in turn and run one episode for each.
	   * Otherwise: load the timestep selected by ``args.load_step``. If
	     ``args.evaluate`` or ``args.save_replay`` is set, evaluate and return.
	5. Load communication weights from ``args.comm_checkpoint_path`` if given.
	6. Loop while ``runner.t_env <= args.t_max``:
	   * Run one episode and store it (skipped in batch-RL mode).
	   * Train once the buffer can supply ``args.batch_size`` episodes.
	   * Test, save models and log stats at their configured intervals.

	Returns early (``None``) in these cases:

	* the checkpoint directory is missing;
	* the checkpoint directory has no numeric timestep sub-directories;
	* after the first episode when ``args.draw_message_distributions`` is
	  set outside batch-RL mode.

	:param args: experiment namespace built in ``run``.
	:param logger: project ``Logger`` (console / tensorboard / sacred output).
	:raises FileNotFoundError: batch-RL mode and the buffer directory is missing.
	:raises ValueError: video mode and the derived plots directory would be the
	    checkpoint directory itself (it would be deleted).
	"""
	# Init runner so we can get env info
	runner = r_REGISTRY[args.runner](args=args, logger=logger)

	# Set up schemes and groups here
	env_info = runner.get_env_info()
	args.n_agents = env_info["n_agents"]
	args.n_actions = env_info["n_actions"]
	args.state_shape = env_info["state_shape"]
	args.obs_shape=env_info["obs_shape"]

	# Default/Base scheme
	scheme = {
		"state": {"vshape": env_info["state_shape"]},
		"obs": {"vshape": env_info["obs_shape"], "group": "agents"},
		"messages": {"vshape": args.comm_embed_dim*args.n_agents, "group": "agents"},
		"actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
		"avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
		"reward": {"vshape": (1,)},
		"terminated": {"vshape": (1,), "dtype": th.uint8},
	}
	# NOTE(review): same content as ``scheme``, but a separate dict. It goes to the
	# training buffer, while ``scheme`` goes to save_buffer and the runner.
	# Presumably this keeps either consumer from seeing the other's edits. Confirm.
	scheme_old = {
		"state": {"vshape": env_info["state_shape"]},
		"obs": {"vshape": env_info["obs_shape"], "group": "agents"},
		"messages": {"vshape": args.comm_embed_dim*args.n_agents, "group": "agents"},
		"actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
		"avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
		"reward": {"vshape": (1,)},
		"terminated": {"vshape": (1,), "dtype": th.uint8},
	}
	print(scheme)
	groups = {
		"agents": args.n_agents
	}
	preprocess = {
		"actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)])
	}

	# Directory key passed to save_one_buffer; SC2 buffers are grouped per map.
	env_name = args.env
	if env_name == 'sc2':
		env_name += '/' + args.env_args['map_name']
		
	buffer = ReplayBuffer(scheme_old, groups, args.buffer_size, env_info["episode_limit"] + 1,
	                      preprocess=preprocess,
	                      device="cpu" if args.buffer_cpu_only else args.device)

	# Buffer that is dumped to disk with save_one_buffer(); always kept on the CPU.
	if args.is_save_buffer:
		save_buffer = ReplayBuffer(scheme, groups, args.save_buffer_size, env_info["episode_limit"] + 1,
									preprocess=preprocess,
									device="cpu")# if args.buffer_cpu_only else args.device)

	# Batch RL: train on a buffer loaded from disk; collected episodes are not added to it.
	if args.is_batch_rl:
		env_name = args.env
		if env_name == 'sc2':
			env_name += '/' + args.env_args['map_name']
		#assert (args.is_save_buffer == False)
		x_env_name = env_name
		if args.is_from_start:
			x_env_name += '_from_start/'
		path_name = '../buffer/' + x_env_name + '/buffer_' + str(args.load_buffer_id) + '/'
		print('path',path_name)
		# Explicit check instead of assert, which is stripped under ``python -O``.
		if not os.path.exists(path_name):
			raise FileNotFoundError("Replay buffer directory {} does not exist".format(path_name))
		buffer.load(path_name)

	# Setup multiagent controller here
	mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args)

	# Give runner the scheme
	runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac)

	# Learner
	learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args)

	if args.use_cuda:
		learner.cuda()

	if args.checkpoint_path != "":

		timesteps = []
		timestep_to_load = 0

		if not os.path.isdir(args.checkpoint_path):
			logger.console_logger.info("Checkpoint directiory {} doesn't exist".format(args.checkpoint_path))
			return

		# Go through all files in args.checkpoint_path
		for name in os.listdir(args.checkpoint_path):
			full_name = os.path.join(args.checkpoint_path, name)
			# Check if they are dirs the names of which are numbers
			if os.path.isdir(full_name) and name.isdigit():
				timesteps.append(int(name))

		timesteps.sort()

		if args.make_message_distribution_video and args.draw_message_distributions:
			save_dir = args.checkpoint_path.replace('models', 'plots')
			# Without 'models' in the path, save_dir IS the checkpoint directory and
			# the rmtree below would delete the checkpoints we are about to load.
			if os.path.abspath(save_dir) == os.path.abspath(args.checkpoint_path):
				raise ValueError("Refusing to delete checkpoint directory {}: its path must contain 'models'".format(args.checkpoint_path))
			if os.path.exists(save_dir):
				shutil.rmtree(save_dir)
			# makedirs rather than mkdir: the 'plots' parent may not exist yet.
			os.makedirs(save_dir)

			for timestep in timesteps:
				model_path = os.path.join(args.checkpoint_path, str(timestep))

				logger.console_logger.info("Loading model from {}".format(model_path))
				learner.load_models(model_path)
				args.loaded_model_ts = timestep

				episode_batch = runner.run(test_mode=False)
		else:
			if not timesteps:
				# max()/min() below would raise ValueError on an empty list.
				logger.console_logger.info("No timestep checkpoints found in {}".format(args.checkpoint_path))
				return
			if args.load_step == 0:
				# choose the max timestep
				timestep_to_load = max(timesteps)
			else:
				# choose the timestep closest to load_step
				timestep_to_load = min(timesteps, key=lambda x: abs(x - args.load_step))
			print('timestep load',timestep_to_load)

			model_path = os.path.join(args.checkpoint_path, str(timestep_to_load))

			logger.console_logger.info("Loading model from {}".format(model_path))
			learner.load_models(model_path)
			runner.t_env = timestep_to_load

			if args.evaluate or args.save_replay:
				evaluate_sequential(args, runner)
				return
	
	if args.comm_checkpoint_path!='':
		learner.load_comm(args.comm_checkpoint_path)

	# start training
	episode = 0
	# One interval (plus one) in the past, so the first iteration already runs a test
	# (assuming t_env starts >= 0 and test_interval > 0).
	last_test_T = -args.test_interval - 1
	last_log_T = 0
	model_save_time = 0

	start_time = time.time()
	last_time = start_time

	logger.console_logger.info("Beginning training for {} timesteps".format(args.t_max))

	# Only referenced by the disabled "for mmm" snippet inside the loop.
	save_list=[500,1000,1500,2000,2500,3000,3500,4000,4500]
	while runner.t_env <= args.t_max:
		# Run for a whole episode at a time
		episode_batch = runner.run(test_mode=False)# todo: make right

		if not args.is_batch_rl:

			# When drawing message distributions, stop after this first episode.
			if args.draw_message_distributions:
				return

			buffer.insert_episode_batch(episode_batch)
			#print('buffer episodes',buffer.episodes_in_buffer)

			# Fill save_buffer until it is full once, then write it to disk a single
			# time; clearing is_from_start stops any further inserts.
			if args.is_save_buffer and save_buffer.is_from_start:
				'''# by rew
				rew=episode_batch['reward']
				for i in range(rew.shape[0]):
					if rew[i].sum()>=15:
						save_buffer.insert_episode_batch(episode_batch[i])'''

				# by t
				if runner.t_env>=0:#5000000:#2000000:
					save_buffer.insert_episode_batch(episode_batch)
					print('current episodes_in_buffer: ', save_buffer.episodes_in_buffer)

					#for mmm
					'''if save_buffer.episodes_in_buffer>save_list[0]:
						save_one_buffer(args, save_buffer, env_name, from_start=True)
						save_list=save_list[1:]'''

				'''# original
				save_buffer.insert_episode_batch(episode_batch)'''
				if save_buffer.is_from_start and save_buffer.episodes_in_buffer == save_buffer.buffer_size:
					save_buffer.is_from_start = False
					save_one_buffer(args, save_buffer, env_name, from_start=True)
				if save_buffer.buffer_index % args.save_buffer_interval == 1:
					print('current episodes_in_buffer: ', save_buffer.episodes_in_buffer)

		if buffer.can_sample(args.batch_size):
			episode_sample = buffer.sample(args.batch_size)

			# Truncate batch to only filled timesteps
			max_ep_t = episode_sample.max_t_filled()
			episode_sample = episode_sample[:, :max_ep_t]

			if episode_sample.device != args.device:
				episode_sample.to(args.device)

			learner.train(episode_sample, runner.t_env, episode)

		# Execute test runs once in a while
		n_test_runs = max(1, args.test_nepisode // runner.batch_size)
		if (runner.t_env - last_test_T) / args.test_interval >= 1.0:

			logger.console_logger.info("t_env: {} / {}".format(runner.t_env, args.t_max))
			logger.console_logger.info("Estimated time left: {}. Time passed: {}".format(
				time_left(last_time, last_test_T, runner.t_env, args.t_max), time_str(time.time() - start_time)))
			last_time = time.time()

			last_test_T = runner.t_env
			'''
			for i in range(6):
				for _ in range(n_test_runs):
					runner.run(test_mode=True, thres=i * 20.)
			'''
			for i in [100.]:#[90., 95., 100.]:
				for _ in range(n_test_runs):
					runner.run(test_mode=True, thres=i)

		if args.save_model and (runner.t_env - model_save_time >= args.save_model_interval or model_save_time == 0):
			model_save_time = runner.t_env
			save_path = os.path.join(args.local_results_path, "models", args.unique_token, str(runner.t_env))
			# "results/models/{}".format(unique_token)
			os.makedirs(save_path, exist_ok=True)
			logger.console_logger.info("Saving models to {}".format(save_path))

			# learner should handle saving/loading -- delegate actor save/load to mac,
			# use appropriate filenames to do critics, optimizer states
			learner.save_models(save_path)

		episode += args.batch_size_run

		if (runner.t_env - last_log_T) >= args.log_interval:
			logger.log_stat("episode", episode, runner.t_env)
			logger.print_recent_stats()
			last_log_T = runner.t_env

	runner.close_env()
	logger.console_logger.info("Finished Training")


def args_sanity_check(config, _log):
	"""Adjust ``config`` in place and return the same dict.

	- If ``use_cuda`` is requested but ``th.cuda.is_available()`` is false,
	  the flag is switched off and a warning is emitted through ``_log``.
	- ``test_nepisode`` is forced to a multiple of ``batch_size_run``. Values
	  below ``batch_size_run`` are raised to it; larger values are floored to
	  the nearest multiple.
	"""
	# Forcing use_cuda on unconditionally is kept disabled:
	# config["use_cuda"] = True # Use cuda whenever possible!
	cuda_requested = config["use_cuda"]
	if cuda_requested and not th.cuda.is_available():
		config["use_cuda"] = False
		_log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!")

	n_test = config["test_nepisode"]
	per_run = config["batch_size_run"]
	config["test_nepisode"] = per_run if n_test < per_run else (n_test // per_run) * per_run

	return config
