from agents.sr.sr import Agent, TestAgent
from core.utils import pprint, str2bool
import numpy as np, os, time, torch
import core.utils as utils
import argparse
import random


class SR_Trainer:
	"""Training driver that runs an ``Agent`` through repeated evolutionary generations.

	Each generation asks the agent for evolutionary rollouts and then one
	evolution step, accumulating the number of environment frames consumed.
	Training stops once that count exceeds ``args.frames_bound``.

		Parameters:
			args (object): Parameter/config object. This class reads
				``args.frames_bound`` and ``args.savetag``. It is also passed to
				``Agent`` and ``TestAgent``, which may read more fields.
			model_constructior: Accepted for interface compatibility; not used
				by this class.
			env_constructor: Accepted for interface compatibility; not used by
				this class.
	"""

	def __init__(self, args, model_constructior, env_constructor):
		self.args = args

		######### Initialize the Multiagent Team of agents ########
		self.agent = Agent(self.args)

		####### Initialize Test Agent/s #######
		# NOTE(review): test rollouts are currently disabled in forward_gen,
		# but these agents are still created and must be terminated on shutdown.
		self.test_agents = [TestAgent(self.args, source='sr')]

		#### STATS AND TRACKING WHICH ROLLOUT IS DONE ######
		self.total_frames = 0  # running count of frames reported by evo_rollouts()

	def forward_gen(self, gen):
		"""Run one generation: evolutionary rollouts followed by one evolution step.

			Parameters:
				gen (int): Current generation index (only used by the disabled
					test-rollout code below).

			Returns:
				tuple: ``(popn_fits, test_fit)`` exactly as returned by
				``self.agent.evo_rollouts()``.
		"""

		########## START EVO ROLLOUT ##########
		popn_fits, frames, test_fit = self.agent.evo_rollouts()
		self.total_frames += frames

		# # Test Rollout (currently disabled)
		# if gen % self.args.test_gap == 0:
		# 	for test_agent in self.test_agents:
		# 		test_agent.start_test_rollout(self.agent)

		# Evolution Step
		self.agent.evolve()

		# ####### JOIN TEST ROLLOUTS ######## (currently disabled)
		# if gen % self.args.test_gap == 0:
		# 	for test_agent in self.test_agents:
		# 		test_agent.join_test_rollout(self.total_frames)

		return popn_fits, test_fit

	def train(self):
		"""Run generations until more than ``args.frames_bound`` frames have been consumed.

		Progress is printed after every generation. All agents are terminated
		when training finishes, including when it is interrupted by an exception.

			Returns:
				None
		"""

		time_start = time.time()

		try:
			###### TRAINING LOOP ########
			# The range limit is only a safety cap; the frame-count check below
			# is the real stopping condition.
			for gen in range(1, self.args.frames_bound):

				# ONE EPOCH OF TRAINING
				popn_fits, test_fit = self.forward_gen(gen)

				# PRINT PROGRESS
				print()
				print('Ep:/Frames', gen, '/', self.total_frames, 'Popn stat:', utils.list_stat(popn_fits),
					  'Test_Fitness', test_fit,
					  'FPS:', pprint(self.total_frames / (time.time() - time_start)),
					  'Savetag', self.args.savetag)

				if self.total_frames > self.args.frames_bound:
					break
		finally:
			# Always release agent resources, even on error or Ctrl-C.
			self._terminate_workers()

	def _terminate_workers(self):
		"""Terminate the main agent, every test agent, and a global test agent if one is attached."""
		self.agent.terminate()
		for test_agent in self.test_agents:
			test_agent.terminate()
		# This class never creates `global_test_agent`; shut it down only if
		# some other code attached one, instead of swallowing all errors.
		global_test_agent = getattr(self, 'global_test_agent', None)
		if global_test_agent is not None:
			global_test_agent.terminate()








