import sys
import copy
import time
import random
import threading
import numpy as np
import torch
from torch.multiprocessing import Process, Pipe
from torch.multiprocessing import Manager

from algos.ddqn import MultiDDQN
from algos.MaxminDQN import MaxminDQN
from algos.sac import MultiSAC
from components.replay import ReplayBuffer
from algos.ssne import SSNE
from gp.gp import GP_Population
from core.runner import rollout_worker
import core.utils as utils



class Agent:
	"""Learner object coordinating evolutionary and policy-gradient rollouts.

		Parameters:
		args (object): Run configuration; read for num_agents, evo_popn_size, ratio,
			num_evals, rollout_size, elite_ratio, kill_ratio, lineage_alpha,
			buffer_size, batch_size, learning_start, gradperstep,
			reward_scaling, crossover_prob and mutation_prob
		model_constructor (object): Factory exposing make_model(policy_type)
		env_constructor (object): Exposes is_discrete, state_dim and action_dim


		Population of Models self.popn --> [popn_id, agent_id]
		Rollout Team self.rollout_team --> [1, agent_id]
		AutoRewards: self.reward_recipes --> [agent_id]
		Fitnesses: self.fitnesses --> [popn_id, agent_id, []]
		Champ_ID: self.champ_ind --> [agent_id]
	"""

	def __init__(self, args, model_constructor, env_constructor):
		self.args = args
		self.manager = Manager()
		# Discrete action spaces use MaxminDQN policies, continuous ones Gaussian_FF
		self.policy_type = 'MaxminDQN' if env_constructor.is_discrete else 'Gaussian_FF'
		self.model_constructor = model_constructor

		### Neuroevolution module: one SSNE instance per agent ###
		self.evolver = [SSNE(self.args) for _ in range(args.num_agents)]

		## Population split: members [0, num_autobots) are GP-reward "autobots",
		## members [num_autobots, evo_popn_size) are plain SSNE policies
		self.num_autobots = int(self.args.ratio * self.args.evo_popn_size)
		self.num_ssne_policies = self.args.evo_popn_size - self.num_autobots

		self.state_dim, self.action_dim = env_constructor.state_dim, env_constructor.action_dim

		######## Population, held in a Manager list so rollout workers can read it
		self.popn = self.manager.list()
		for _ in range(args.evo_popn_size):
			self.popn.append([model_constructor.make_model(self.policy_type) for _ in range(args.num_agents)])

		# Autobot reward recipes: one GP population (of num_autobots recipes) per agent
		if env_constructor.is_discrete:
			recipe_dimension = 1 + self.state_dim
		else:
			recipe_dimension = 2 * self.state_dim + self.action_dim
		self.reward_recipes = [GP_Population(recipe_dimension, 1,
						self.num_autobots, args.elite_ratio, args.kill_ratio,
						args.lineage_alpha) for _ in range(args.num_agents)]

		#### Rollout team: template synced from the PG learner for the PG workers #####
		self.rollout_team = self.manager.list()
		self.rollout_team.append([model_constructor.make_model(self.policy_type) for _ in range(args.num_agents)])

		### Hall-of-fame team: copy of the current champion of each agent ###
		self.hof = [model_constructor.make_model(self.policy_type) for _ in range(args.num_agents)]

		#### Policy-gradient learner #####
		if env_constructor.is_discrete:
			self.algo = MaxminDQN(args, model_constructor)
		else:
			self.algo = MultiSAC(args, model_constructor)

		# Shared replay buffer
		self.buffer = ReplayBuffer(args.num_agents, args.buffer_size, args.batch_size)

		# Agent metrics, indexed [popn_id][agent_id]
		self.fitnesses = [[[] for _ in range(args.num_agents)] for _ in range(args.evo_popn_size)]
		self.shaped_fitnesses = [[[] for _ in range(args.num_agents)] for _ in range(args.evo_popn_size)]

		### Index of the best population member, per agent ###
		self.champ_ind = [0 for _ in range(args.num_agents)]

		self.buffer_bucket = None  # self.agent.buffer.tuples

		################### Rollout workers (torch.multiprocessing) ###############

		######### EVOLUTIONARY WORKERS ############
		# Evo workers record transitions only when PG rollouts or autobots are enabled
		store_transitions = self.args.rollout_size > 0 or self.args.ratio > 0
		if self.args.evo_popn_size > 0:
			num_evo_workers = args.evo_popn_size * args.num_evals
			self.evo_task_pipes = [Pipe() for _ in range(num_evo_workers)]
			self.evo_result_pipes = [Pipe() for _ in range(num_evo_workers)]
			self.evo_workers = [Process(target=rollout_worker,
							args=(self.args, worker_id, env_constructor, 'evo',
								self.evo_task_pipes[worker_id][1], self.evo_result_pipes[worker_id][0],
								self.popn, store_transitions, False))
							for worker_id in range(num_evo_workers)]
			for worker in self.evo_workers:
				worker.start()

		######### POLICY GRADIENT WORKERS ############
		if self.args.rollout_size > 0:
			self.pg_task_pipes = [Pipe() for _ in range(args.rollout_size)]
			self.pg_result_pipes = [Pipe() for _ in range(args.rollout_size)]
			self.pg_workers = [Process(target=rollout_worker,
							args=(self.args, worker_id, env_constructor, 'pg',
								self.pg_task_pipes[worker_id][1], self.pg_result_pipes[worker_id][0],
								self.rollout_team, self.args.rollout_size > 0, False))
							for worker_id in range(args.rollout_size)]
			for worker in self.pg_workers:
				worker.start()


	def start_evo_rollouts(self):
		"""Send one task to every evolutionary worker (its pipe index as the team id).

		No-op when no evolutionary workers exist (evo_popn_size <= 0).
		"""
		if self.args.evo_popn_size <= 0:
			return
		for team_id, pipe in enumerate(self.evo_task_pipes):
			pipe[0].send(team_id)

	def start_pg_rollouts(self):
		"""Sync the PG learner into the rollout team and start every PG worker."""
		if self.args.rollout_size > 0:
			# Copy the learner's current policies into the shared rollout template
			self.update_rollout_team()

			for pipe in self.pg_task_pipes:
				pipe[0].send(0)


	def join_evo_rollouts(self):
		"""Collect one result from every evolutionary worker.

		Adds the returned transitions to the buffer, records the mean fitness
		per (team, agent) and refreshes the champions / hall of fame.

		Returns:
			all_fits (list): Mean fitness of each result, in pipe order
			total_frames (int): Sum of the frame counts reported by the workers
		"""
		all_fits = []
		total_frames = 0
		if self.args.evo_popn_size <= 0:
			# No evo workers were created in __init__; nothing to collect
			return all_fits, total_frames

		for pipe in self.evo_result_pipes:
			entry = pipe[1].recv()
			# Fields read here: 0 team_id, 1 fitness, 2 frames, 3 transitions, 5 shaped fitness
			team_id, fitness, frames = entry[0], entry[1], entry[2]
			shaped_fitness = entry[5]
			self.buffer.add(entry[3])

			mean_fitness = utils.list_mean(fitness)
			for agent_id in range(self.args.num_agents):
				self.fitnesses[team_id][agent_id].append(mean_fitness)
				# NOTE(review): shaped fitness overwrites rather than appends — confirm intended
				self.shaped_fitnesses[team_id][agent_id] = shaped_fitness

			all_fits.append(mean_fitness)
			total_frames += frames

		self.update_champ_ind()
		return all_fits, total_frames

	def join_pg_rollouts(self):
		"""Collect one result from every PG worker and add its transitions to the buffer.

		Returns:
			pg_fits (list): Fitness field of each result, in pipe order
			total_frames (int): Sum of the frame counts reported by the workers
		"""
		pg_fits = []
		total_frames = 0
		if self.args.rollout_size <= 0:
			# No PG workers were created in __init__; nothing to collect
			return pg_fits, total_frames

		for pipe in self.pg_result_pipes:
			entry = pipe[1].recv()
			pg_fits.append(entry[1])
			total_frames += entry[2]
			self.buffer.add(entry[3])

		return pg_fits, total_frames


	def apply_auto_reward(self):
		"""Train every autobot policy on its agent's GP reward recipe.

		Runs one thread per (popn_id, agent_id) calling
		self.algo.autoreward_update, then copies each trained policy it writes
		into return_policies back into self.popn. While the buffer holds fewer
		than args.learning_start transitions, resets buffer.pg_frames and returns.
		"""
		if len(self.buffer) < self.args.learning_start:
			self.buffer.pg_frames = 0
			print('AUTOBOTS BURN IN: BUFFER AT', len(self.buffer))
			return

		# Filled by autoreward_update as return_policies[popn_id][agent_id]
		return_policies = [[None] * self.args.num_agents for _ in range(self.num_autobots)]
		threads = []
		for popn_id in range(self.num_autobots):
			for agent_id in range(self.args.num_agents):
				thread = threading.Thread(target=self.algo.autoreward_update,
						args=(self.buffer, self.reward_recipes[agent_id], self.popn[popn_id][agent_id],
							agent_id, popn_id, self.args, return_policies), daemon=True)
				thread.start()
				threads.append(thread)

		for thread in threads:
			thread.join()

		# Sync params into the self.popn Manager object from the intermediary list
		for popn_id in range(self.num_autobots):
			for agent_id in range(self.args.num_agents):
				utils.hard_update(self.popn[popn_id][agent_id], return_policies[popn_id][agent_id])


	def net_diff(self, n1, n2):
		"""Return the summed absolute difference between the parameters of n1 and n2."""
		diff = 0
		for p1, p2 in zip(n1.parameters(), n2.parameters()):
			diff += torch.abs(p1.cpu() - p2.cpu()).sum().detach().cpu().item()
		return diff

	def global_update(self):
		"""Run int(gradperstep * buffer.pg_frames) learner updates, then reset pg_frames.

		While the buffer holds fewer than args.learning_start transitions, only
		resets buffer.pg_frames and returns.
		"""
		if len(self.buffer) < self.args.learning_start:
			self.buffer.pg_frames = 0
			print('GLOBAL CRITIC BURN IN')
			return

		for _ in range(int(self.args.gradperstep * self.buffer.pg_frames)):
			self.algo.global_update(self.args.reward_scaling, self.buffer)

		self.buffer.pg_frames = 0


	def evolve(self):
		"""Run one generation of GP recipe evolution and SSNE neuroevolution.

		Resets self.fitnesses and self.shaped_fitnesses afterwards.
		"""
		# GP evolution of the autobot reward recipes
		if self.num_autobots > 0:

			if len(self.buffer) >= self.args.learning_start:
				print('AUTOBOTS START TRAINING')

				# Fill in the recipes' fitnesses
				for popn_id in range(self.num_autobots):
					for agent_id, recipe in enumerate(self.reward_recipes):
						recipe.popn[popn_id].fitness.values = (utils.list_mean(self.fitnesses[popn_id][agent_id]),)

				# Evolve each agent's recipe population independently
				for agent_id, recipe in enumerate(self.reward_recipes):
					unselects, elites, new_elites = recipe.evolve(self.args.crossover_prob, self.args.mutation_prob)

					# Preserve the elite nets' weights [copy from old elite to new elite]
					for old_idx, new_idx in zip(elites, new_elites):
						utils.hard_update(self.popn[new_idx][agent_id], self.popn[old_idx][agent_id])
						self.algo.preserve_critic(agent_id, new_idx, old_idx)

					# Reset all unselected members to freshly initialised nets
					for ind in unselects:
						new_net = self.model_constructor.make_model(self.policy_type)
						utils.hard_update(self.popn[ind][agent_id], new_net)
						self.algo.reset_critic(ind, agent_id)

		# Plain SSNE neuroevolution over members [num_autobots, evo_popn_size)
		if self.num_ssne_policies > 0:
			ssne_ids = range(self.num_autobots, self.args.evo_popn_size)
			for agent_id, ssne in enumerate(self.evolver):
				popn_i = [self.popn[ind][agent_id] for ind in ssne_ids]
				fitnesses_i = [utils.list_mean(self.fitnesses[ind][agent_id]) for ind in ssne_ids]
				shaped_fits_i = [self.shaped_fitnesses[ind][agent_id] for ind in ssne_ids]
				net_inds = range(self.num_ssne_policies)

				# Migrate the PG learner's rollout policy into SSNE when PG is enabled
				if self.args.rollout_size > 0:
					migration = [self.rollout_team[0][agent_id].cpu()]
				else:
					migration = []

				ssne.evolve(popn_i, net_inds, fitnesses_i, migration, None, shaped_fits_i)
				# The nets in self.popn were tested to be changed in place by evolve()

		# Reset fitness metrics
		self.fitnesses = [[[] for _ in range(self.args.num_agents)] for _ in range(self.args.evo_popn_size)]
		self.shaped_fitnesses = [[[] for _ in range(self.args.num_agents)] for _ in range(self.args.evo_popn_size)]

	@staticmethod
	def _champion_index(fitnesses, agent_id):
		"""Return the population index whose scores for agent_id have the highest mean.

		fitnesses is indexed [popn_id][agent_id] -> list of scores. Members
		with no score rank last. Averaging per member (instead of argmax over
		a flattened array) keeps the result a valid population index when a
		member holds several scores, and tolerates ragged score lists.
		"""
		means = [np.mean(member[agent_id]) if len(member[agent_id]) > 0 else -np.inf
				for member in fitnesses]
		return int(np.argmax(means))

	def update_champ_ind(self):
		"""Pick each agent's champion and copy its weights into the hall of fame."""
		for agent_id in range(self.args.num_agents):
			champ = self._champion_index(self.fitnesses, agent_id)
			self.champ_ind[agent_id] = champ

			utils.hard_update(self.hof[agent_id], self.popn[champ][agent_id])

			# Best-effort diagnostic: only autobot members have a recipe entry
			try:
				recipe = self.reward_recipes[agent_id]
				print('Avg_Reward:', recipe.avg_response[champ], 'Recipe', recipe.popn[champ])
			except Exception:
				pass


	def update_rollout_team(self):
		"""Copy the PG learner's policies into the shared rollout team.

		Each learner policy is moved to CPU for the copy and back to CUDA
		afterwards when CUDA is available.
		"""
		for agent_id, actor in enumerate(self.rollout_team[0]):
			self.algo.policies[agent_id].cpu()
			utils.hard_update(actor, self.algo.policies[agent_id])
			if torch.cuda.is_available():
				self.algo.policies[agent_id].cuda()

	def consolidate_features(self):
		"""Average the feature_extractor weights of all learner policies and load the
		average back into every policy.

		The original assigned into the dict returned by state_dict(), which
		never reaches the module; load_state_dict copies the values into the
		actual parameters and buffers.
		"""
		policies = self.algo.policies
		num_policies = len(policies)
		state_dicts = [net.feature_extractor.state_dict() for net in policies]

		# Element-wise mean of every entry, computed before any policy is overwritten
		averaged = {key: sum(sd[key] / num_policies for sd in state_dicts)
				for key in state_dicts[0]}

		for net in policies:
			net.feature_extractor.load_state_dict(averaged)

	def terminate(self):
		"""Send 'TERMINATE' to every PG and evolutionary worker (best effort).

		Missing pipe lists (workers never created) are skipped, and a failing
		pipe does not stop the remaining workers from being signalled.
		"""
		for attr in ('pg_task_pipes', 'evo_task_pipes'):
			for pipe in getattr(self, attr, []):
				try:
					pipe[0].send('TERMINATE')
				except (OSError, ValueError):
					# Pipe already broken/closed; keep signalling the others
					pass


class TestAgent:
	"""Evaluates a team on dedicated test workers and keeps the best-scoring team.

		Parameters:
		args (object): Run configuration; read for metric_save, log_fname,
			num_agents, num_test, model_save and actor_fname
		model_constructor (object): Factory exposing make_model(policy_type)
		env_constructor (object): Exposes is_discrete; also handed to the test workers
		source (str): 'evo' tests the agent's hall-of-fame team,
			'pg' tests the policy-gradient learner's current policies
	"""
	def __init__(self, args, model_constructor, env_constructor, source):
		self.args = args
		# 'evo' logs to log_fname, any other source to '_global' + log_fname
		prefix = '' if source == 'evo' else '_global'
		self.logger = utils.Tracker(args.metric_save, [prefix + args.log_fname], '.csv')
		self.source = source
		self.policy_type = 'MaxminDQN' if env_constructor.is_discrete else 'Gaussian_FF'

		#### Team evaluated by the test workers, shared through a Manager list #####
		self.manager = Manager()
		self.rollout_team = self.manager.list()
		self.rollout_team.append([model_constructor.make_model(self.policy_type) for _ in range(args.num_agents)])

		### Copy of the best-scoring team so far ####
		self.best_team = [model_constructor.make_model(self.policy_type) for _ in range(args.num_agents)]

		######### TEST WORKERS ############
		self.test_task_pipes = [Pipe() for _ in range(args.num_test)]
		self.test_result_pipes = [Pipe() for _ in range(args.num_test)]
		self.test_workers = [Process(target=rollout_worker,
						args=(self.args, worker_id, env_constructor, 'test',
							self.test_task_pipes[worker_id][1], self.test_result_pipes[worker_id][0],
							self.rollout_team, False, False))
						for worker_id in range(args.num_test)]
		for worker in self.test_workers:
			worker.start()

		self.trace = []  # Mean test score of every completed test round
		self.it = 0
		self.best_score = -float('inf')


	def start_test_rollout(self, agent):
		"""Sync the champion team from agent and start every test worker."""
		self.make_champ_team(agent)
		for task_pipe in self.test_task_pipes:
			task_pipe[0].send(0)

	def join_test_rollout(self, total_frames):
		"""Collect all test results, log their mean, and save the team on a new best.

		Parameters:
		total_frames (int): Frame count the logged score is recorded against
		"""
		test_fits = [utils.list_mean(result_pipe[1].recv()[1]) for result_pipe in self.test_result_pipes]
		test_mean = utils.list_mean(test_fits)

		self.logger.update([test_mean], total_frames)
		self.trace.append(test_mean)
		self.it += 1

		# Save best test score
		if test_mean > self.best_score:
			self.best_score = test_mean
			for agent_id in range(len(self.rollout_team[0])):
				utils.hard_update(self.best_team[agent_id], self.rollout_team[0][agent_id])
				self.best_team[agent_id].cpu()
				torch.save(self.best_team[agent_id].state_dict(), self.args.model_save + str(agent_id) + '_best_' + self.source + '_' + self.args.actor_fname + '.pth')

			print("Best Team Saved with Score", test_mean)


	def make_champ_team(self, agent):
		"""Copy the team selected by self.source from agent into the test rollout team.

		Raises:
		ValueError: If self.source is neither 'pg' nor 'evo' (the original
			built this exception without raising it, silently testing a stale team)
		"""
		if self.source == 'pg':  # Testing without Evo
			agent.update_rollout_team()
			for agent_id, model in enumerate(agent.rollout_team[0]):
				utils.hard_update(self.rollout_team[0][agent_id], model)

		elif self.source == 'evo':
			for agent_id, champ_net in enumerate(agent.hof):
				utils.hard_update(self.rollout_team[0][agent_id], champ_net)

		else:
			raise ValueError('Unknown source for Test Agent champ team')

	def terminate(self):
		"""Send 'TERMINATE' to every test worker (best effort).

		A missing pipe list is skipped, and a failing pipe does not stop the
		remaining workers from being signalled.
		"""
		for task_pipe in getattr(self, 'test_task_pipes', []):
			try:
				task_pipe[0].send('TERMINATE')
			except (OSError, ValueError):
				# Pipe already broken/closed; keep signalling the others
				pass
