import argparse
import os
import time
import gymnasium as gym
import numpy as np
import HA3C
import copy

def train(RL_agent, env, eval_env, args):
	"""Run the environment-interaction and training loop for ``RL_agent``.

	Every transition is stored together with two sliding windows of past
	states: ``con_state`` holds the ``args.sampling_size`` states that preceded
	the current ``state`` (zero-padded at the start of an episode), and
	``con_next_state`` is the same window shifted by one with ``state``
	appended. Both have shape ``(1, 1, args.sampling_size, state_dim)``.

	Actions are sampled uniformly from ``env.action_space`` until the first
	episode boundary at or after ``args.timesteps_before_training`` steps;
	afterwards they come from ``RL_agent.select_action``.

	Args:
		RL_agent: agent providing ``select_action``, ``train``,
			``maybe_train_and_checkpoint`` and ``replay_buffer.add``.
		env: Gymnasium-style environment used to collect transitions.
		eval_env: environment forwarded to ``maybe_evaluate_and_print``.
		args: parsed options; reads ``seed``, ``sampling_size``,
			``max_timesteps``, ``timesteps_before_training`` and
			``use_checkpoints``.
	"""
	evals = []
	start_time = time.time()
	allow_train = False

	# Take the observation size from the environment instead of relying on a
	# module-level global that only exists when this file runs as a script.
	state_dim = env.observation_space.shape[0]
	history_shape = (1, 1, args.sampling_size, state_dim)

	# Seed only the very first reset. Re-seeding at every episode boundary
	# would restart the environment RNG and replay the same initial state.
	state, _ = env.reset(seed=int(args.seed))
	ep_total_reward, ep_timesteps, ep_num = 0, 0, 1
	con_state = np.zeros(history_shape)

	for t in range(int(args.max_timesteps + 1)):
		maybe_evaluate_and_print(RL_agent, eval_env, evals, t, start_time, args)

		if allow_train:
			action = RL_agent.select_action(np.array(state), con_state)
		else:
			action = env.action_space.sample()

		next_state, reward, terminated, truncated, _ = env.step(action)
		ep_finished = terminated or truncated

		ep_total_reward += reward
		ep_timesteps += 1

		# Only a true termination ends the return; a time-limit truncation
		# must still bootstrap from next_state, so it is stored as not-done.
		done = float(terminated)

		# Build a fresh window for next_state (drop the oldest row, append the
		# state just acted on). A new array per step means nothing handed to
		# the replay buffer is mutated afterwards.
		con_next_state = np.empty_like(con_state)
		con_next_state[0, 0, :-1, :] = con_state[0, 0, 1:, :]
		con_next_state[0, 0, -1, :] = state
		RL_agent.replay_buffer.add(state, action, next_state, reward, done, con_state, con_next_state)

		state = next_state
		con_state = con_next_state

		if allow_train and not args.use_checkpoints:
			RL_agent.train()

		if not ep_finished:
			continue

		print(f"Total T: {t + 1} Episode Num: {ep_num} Episode T: {ep_timesteps} Reward: {ep_total_reward:.3f}")

		if allow_train and args.use_checkpoints:
			RL_agent.maybe_train_and_checkpoint(ep_timesteps, ep_total_reward)

		# The switch from random actions to the policy happens only at an
		# episode boundary.
		if t >= args.timesteps_before_training:
			allow_train = True

		state, _ = env.reset()
		ep_total_reward, ep_timesteps = 0, 0
		con_state = np.zeros(history_shape)
		ep_num += 1





def maybe_evaluate_and_print(RL_agent, eval_env, evals, t, start_time, args):
	"""Evaluate the agent every ``args.eval_freq`` steps and save the results.

	When ``t`` is a multiple of ``args.eval_freq``, runs ``args.eval_eps``
	episodes on ``eval_env`` with exploration disabled, prints the mean
	episode return, appends the per-episode returns to ``evals`` and writes
	``evals`` to ``results/<args.file_name>`` with ``np.save``. Otherwise it
	does nothing.

	Args:
		RL_agent: agent providing ``select_action``.
		eval_env: Gymnasium-style environment used only for evaluation.
		evals: list of per-evaluation return arrays; extended in place.
		t: current training time step.
		start_time: ``time.time()`` value taken when training started.
		args: parsed options; reads ``eval_freq``, ``eval_eps``, ``seed``,
			``sampling_size``, ``use_checkpoints`` and ``file_name``.
	"""
	if t % args.eval_freq != 0:
		return

	print("---------------------------------------")
	print(f"Evaluation at {t} time steps")
	print(f"Total time passed: {round((time.time()-start_time)/60.,2)} min(s)")

	# Read the observation size from the environment rather than a global
	# that is only defined when the file runs as a script.
	state_dim = eval_env.observation_space.shape[0]
	base_seed = int(args.seed) + 100  # arbitrary offset away from the training seed

	total_reward = np.zeros(args.eval_eps)
	for ep in range(args.eval_eps):
		# One fixed but distinct seed per episode: resetting every episode
		# with the same seed would make all of them start identically, while
		# fixed seeds keep successive evaluations comparable.
		state, _ = eval_env.reset(seed=base_seed + ep)
		con_state = np.zeros((args.sampling_size, state_dim))
		done = False
		while not done:
			# Next window: drop the oldest row and append the current state.
			con_next_state = np.empty_like(con_state)
			con_next_state[:-1, :] = con_state[1:, :]
			con_next_state[-1, :] = state
			# The agent is given the window that precedes the current state.
			action = RL_agent.select_action(np.array(state), con_state, args.use_checkpoints, use_exploration=False)
			con_state = con_next_state
			state, reward, terminated, truncated, _ = eval_env.step(action)
			done = bool(terminated or truncated)
			total_reward[ep] += reward

	print(f"Average total reward over {args.eval_eps} episodes: {total_reward.mean():.3f}")
	print("---------------------------------------")

	evals.append(total_reward)
	np.save(f"results/{args.file_name}", evals)


def build_parser():
	"""Build the command-line parser for an HA3C training run."""
	parser = argparse.ArgumentParser()
	# RL
	parser.add_argument("--env", default="Ant-v4", type=str)
	# The seed is passed to env.reset(seed=...), which needs an integer, so it
	# must be parsed as int rather than kept as a string.
	parser.add_argument("--seed", default=0, type=int)
	parser.add_argument('--use_checkpoints', default=True, action=argparse.BooleanOptionalAction)
	# argparse applies `type` only to string defaults, so the defaults are
	# written as ints to match what a command-line value would produce.
	parser.add_argument("--timesteps_before_training", default=25_000, type=int)
	parser.add_argument('--sampling_size', default=5, type=int)
	parser.add_argument('--discount', default=0.99, type=float)
	parser.add_argument("--eval_freq", default=5_000, type=int)
	parser.add_argument("--eval_eps", default=10, type=int)
	parser.add_argument("--max_timesteps", default=1_000_000, type=int)

	# File
	parser.add_argument('--file_name', default=None)
	return parser


if __name__ == "__main__":
	args = build_parser().parse_args()

	if args.file_name is None:
		args.file_name = f"HA3C_{args.env}_{args.seed}"

	# Evaluation results are written to this directory with np.save.
	os.makedirs("results", exist_ok=True)

	env = gym.make(args.env)
	eval_env = gym.make(args.env)

	# Seed the sources of randomness this script controls directly: NumPy's
	# global RNG and the sampler behind env.action_space.sample().
	# NOTE(review): any RNG used inside HA3C is not seeded here.
	seed = args.seed
	np.random.seed(seed)
	env.action_space.seed(seed)

	print("---------------------------------------")
	print(f"Algorithm: HA3C Env: {args.env}")
	print("---------------------------------------")

	# Assumes flat (1-D) observation and action spaces.
	state_dim = env.observation_space.shape[0]
	action_dim = env.action_space.shape[0]
	max_action = float(env.action_space.high[0])

	RL_agent = HA3C.Agent(state_dim, action_dim, max_action, args.sampling_size, args.discount)

	train(RL_agent, env, eval_env, args)