from agents.autobot.autobot import Agent, TestAgent
from core.utils import pprint, str2bool
import numpy as np, os, time, torch
import core.utils as utils
import argparse
import random


class Autobot_Trainer:
	"""Top-level driver for the Autobot training loop.

	Owns one training ``Agent`` plus zero or more ``TestAgent`` instances and,
	for each generation, sequences the auto-reward update, evolutionary and
	policy-gradient rollouts, the evolution step and periodic test rollouts.

	Parameters:
		args: Parameter object with all run settings (read here: ``evo_popn_size``,
			``rollout_size``, ``global_update``, ``test_gap``, ``ps``,
			``total_frames``, ``ratio``, ``savetag``).
		model_constructor: Passed through to ``Agent`` / ``TestAgent``.
		env_constructor: Passed through to ``Agent`` / ``TestAgent``.
	"""

	def __init__(self, args, model_constructor, env_constructor):
		self.args = args

		# The training agent that performs rollouts, evolution and updates.
		self.agent = Agent(self.args, model_constructor, env_constructor)

		# Test agents: one tagged 'evo' when an evolutionary population is used,
		# and one tagged 'pg' when global (policy-gradient) updates are enabled.
		self.test_agents = []
		if self.args.evo_popn_size > 0:
			self.test_agents.append(TestAgent(self.args, model_constructor, env_constructor, source='evo'))
		if self.args.global_update:
			self.test_agents.append(TestAgent(self.args, model_constructor, env_constructor, source='pg'))

		# Running count of environment frames reported by joined rollouts.
		self.total_frames = 0

	def forward_gen(self, gen):
		"""Run one generation of training.

		Parameters:
			gen (int): Current generation index; test rollouts are launched and
				joined when ``gen % args.test_gap == 0``.

		Returns:
			tuple: ``(pg_fits, popn_fits)`` -- fitness lists returned by the
			policy-gradient and evolutionary rollouts of this generation. Each is
			an empty list when the corresponding rollout type is disabled.
		"""
		print("Generation {}".format(gen))
		run_tests = gen % self.args.test_gap == 0

		# AutoReward update using the recipe.
		tic = time.time()
		self.agent.apply_auto_reward()
		print("Time take to do autoreward {}".format(time.time() - tic))

		# Launch evolutionary rollouts (joined further below).
		tic = time.time()
		if self.args.evo_popn_size > 0:
			self.agent.start_evo_rollouts()
		print("Time take to do Rollout {}".format(time.time() - tic))

		# Launch policy-gradient rollouts and, if enabled, the global update.
		# The rollouts are joined later in this method.
		if self.args.rollout_size > 0:
			self.agent.start_pg_rollouts()
			if self.args.global_update:
				self.agent.global_update()

		# Join evolutionary rollouts and account for their frames.
		popn_fits = []
		if self.args.evo_popn_size > 0:
			popn_fits, frames = self.agent.join_evo_rollouts()
			self.total_frames += frames

		# Launch test rollouts against the current agent.
		if run_tests:
			for test_agent in self.test_agents:
				test_agent.start_test_rollout(self.agent)

		# Join policy-gradient rollouts and account for their frames.
		pg_fits = []
		if self.args.rollout_size > 0:
			pg_fits, frames = self.agent.join_pg_rollouts()
			self.total_frames += frames

		# Evolution step.
		if self.args.evo_popn_size > 0:
			self.agent.evolve()

		# Sync parameters in the feature extractor.
		if self.args.ps:
			self.agent.consolidate_features()

		# Join test rollouts, passing the frame count collected so far.
		if run_tests:
			for test_agent in self.test_agents:
				test_agent.join_test_rollout(self.total_frames)

		return pg_fits, popn_fits

	def train(self):
		"""Run generations until more than ``args.total_frames`` frames are collected.

		Each generation calls ``forward_gen`` and prints a progress summary. The
		loop also stops after ``args.total_frames - 1`` generations, a cap that
		is effectively unbounded. All agents are terminated afterwards.
		"""
		time_start = time.time()

		for gen in range(1, self.args.total_frames):
			pg_fits, popn_fits = self.forward_gen(gen)
			print('Ep:/Frames', gen, '/', self.total_frames,
				'Popn stat:', utils.list_stat(popn_fits),
				'PG_stat:', utils.list_stat(pg_fits),
				'FPS:', pprint(self.total_frames / (time.time() - time_start)),
				'Ratio', self.args.ratio,
				'Savetag', self.args.savetag)
			for test_agent in self.test_agents:
				print(test_agent.source, test_agent.trace[-5:])
			self._print_champ_recipes()
			print()

			# The frame budget is what normally ends training.
			if self.total_frames > self.args.total_frames:
				break

		self._terminate_all()

	def _print_champ_recipes(self):
		"""Print the reward recipe of each agent's current champion.

		Prints 'Champ is a SSNE agent' instead when the recipe cannot be looked up.
		"""
		# NOTE(review): the agent count is read from ``self.args.config.num_agents``
		# (assumed). It was previously ``self.config``, which this class never
		# sets, so recipes were never printed. Confirm against the parameter class.
		try:
			num_agents = self.args.config.num_agents
			for agent_id in range(num_agents):
				recipe = self.agent.reward_recipes[agent_id].popn[self.agent.champ_ind[agent_id]]
				print('Agent Recipe:', recipe)
		except (AttributeError, IndexError, KeyError, TypeError):
			print('Champ is a SSNE agent')

	def _terminate_all(self):
		"""Terminate the training agent, all test agents and any attached global test agent."""
		self.agent.terminate()
		for test_agent in self.test_agents:
			test_agent.terminate()
		# This class never sets ``global_test_agent``; only terminate it if
		# something else attached one.
		global_test_agent = getattr(self, 'global_test_agent', None)
		if global_test_agent is not None:
			global_test_agent.terminate()








