import torch, os, random, sys
import torch.nn as nn
from torch.optim import Adam
import numpy as np
from core import utils as utils
from core.models import GumbelPolicy, QGlobal
from torch.utils.tensorboard import SummaryWriter
import torch.functional as F


class MaxminDQN(object):
	"""Maxmin-style multi-head DQN learner with per-population, per-agent target critics.

	Learner networks are expected to produce a sequence of per-head Q-value
	tensors. The Bellman target takes the element-wise minimum over the target
	network's heads before the max over actions (see ``compute_q_target``).

		 Parameters:
			   args (object): Parameter class. Fields read here: num_agents,
				   actor_lr, gamma, tau, actualize, aux_save, savetag, evo_popn_size.
			   model_constructor (object): Factory whose ``make_model('MaxminDQN')``
				   builds a fresh network.
	 """
	def __init__(self, args, model_constructor):

		self.num_agents = args.num_agents
		self.actor_lr = args.actor_lr
		self.gamma = args.gamma
		self.tau = args.tau
		# Number of gradient steps taken so far; used as the TensorBoard global step.
		self.total_update = 0
		self.model_constructor = model_constructor
		self.actualize = args.actualize
		self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
		self.tracker = utils.Tracker(args.aux_save, ['q_'+args.savetag, 'qloss_'+args.savetag, \
						'policy_loss_'+args.savetag, 'alz_score'+args.savetag,'alz_policy'+args.savetag], \
						'.csv', save_iteration=1000, conv_size=1000)


		#Initialize actors (one online and one target network per agent)
		self.policies = [model_constructor.make_model('MaxminDQN').to(device=self.device) for _ in range(self.num_agents)]
		self.policies_target = [model_constructor.make_model('MaxminDQN').to(device=self.device) for _ in range(self.num_agents)]
		for p, target_p in zip(self.policies, self.policies_target):
			utils.hard_update(target_p, p)
		self.optims_policy = [Adam(self.policies[agent_id].parameters(), self.actor_lr) for agent_id in range(self.num_agents)]

		self.loss = nn.MSELoss()
		self.log_softmax = torch.nn.LogSoftmax(dim=1)
		self.softmax = torch.nn.Softmax(dim=1)
		# Target critics indexed as [popn_id][agent_id]; used by autoreward_update.
		self.auto_reward_target_critics = []
		for _ in range(args.evo_popn_size):
			self.auto_reward_target_critics.append([model_constructor.make_model('MaxminDQN').to(device=self.device)\
			for _ in range(self.num_agents)])
		# Statistics Tracker
		self.writer = SummaryWriter(log_dir='tensorboard' + '/' + args.savetag)

	def _to_device(self, tensor):
		"""Move ``tensor`` onto this learner's device (CUDA when available, otherwise CPU)."""
		return tensor.to(device=self.device)

	def autoreward_update(self, buffer, reward_recipe, policy, agent_id, popn_id, args, return_policies):
		"""Train ``policy`` for ``args.autoreward_iterations`` TD steps using recipe-computed rewards.

			 Parameters:
				  buffer (object): Provides ``sample(agent_id)`` returning
					  (state, action, next_state, reward, done) tensors. The sampled
					  reward is discarded and replaced by ``reward_recipe``'s output.
				  reward_recipe (object): Provides ``compute_reward(state_action, popn_id)``,
					  where ``state_action`` is state and action concatenated on dim 1.
				  policy (nn.Module): Multi-head Q network. Calling it returns a sequence
					  of per-head Q tensors, and it exposes ``policy.heads[i]``.
				  agent_id (int): Agent index (buffer sampling and target-critic lookup).
				  popn_id (int): Population index (reward recipe and target-critic lookup).
				  args (object): Must provide ``autoreward_iterations``.
				  return_policies: Indexable as ``[popn_id][agent_id]``; the trained
					  policy (moved to CPU) is written there.

			 Returns:
				   None. The result is delivered through ``return_policies``.

			 Assumes actions are integer indices of shape (batch, 1), as required by
			 ``gather(1, action_batch)``. TODO confirm against the buffer.
		 """
		policy.to(device=self.device)

		# A fresh optimizer per call: the policy handed in may differ between calls.
		optim = Adam(policy.parameters(), self.actor_lr)

		for _ in range(args.autoreward_iterations):
			state_batch, action_batch, next_state_batch, reward_batch, done_batch = buffer.sample(agent_id)
			# The buffer's reward is overridden by the population-specific reward recipe.
			reward_batch = reward_recipe.compute_reward(torch.cat([state_batch, action_batch], 1), popn_id).to(device=self.device)
			# Follow self.device instead of hard-coding CUDA so CPU-only runs work.
			state_batch = self._to_device(state_batch)
			next_state_batch = self._to_device(next_state_batch)
			action_batch = self._to_device(action_batch.long())
			done_batch = self._to_device(done_batch)
			with torch.no_grad():
				q_target = self.get_q_target(next_state_batch, reward_batch, done_batch, popn_id, agent_id)

			q_values = policy(state_batch)
			# Every head regresses onto the same maxmin target; per-head MSEs are summed.
			q_loss = 0
			for head_q in q_values:
				q_loss += ((q_target - head_q.gather(1, action_batch)) ** 2).mean()
			self.writer.add_scalar('q_loss', q_loss.item(), self.total_update)

			# Flattened parameters of each head, used by the diversity regularizer.
			all_nn_params = [
				torch.cat([param.view(-1) for param in policy.heads[idx].parameters()])
				for idx in range(len(q_values))
			]
			head_id = np.random.randint(len(q_values))
			# Subtracting the distance-to-mean term rewards the chosen head for moving
			# away from the heads' average weights.
			q_loss = q_loss - self.regularizer(all_nn_params, head_id)

			optim.zero_grad()
			q_loss.backward()
			optim.step()

			utils.soft_update(self.auto_reward_target_critics[popn_id][agent_id], policy, self.tau)
			self.total_update += 1

		return_policies[popn_id][agent_id] = policy.cpu()

	def get_q_target(self, next_states, rewards, dones, popn_id, agent_id):
		"""Compute the maxmin Bellman target using the [popn_id][agent_id] target critic."""
		head_output = self.auto_reward_target_critics[popn_id][agent_id](next_states)
		return self.compute_q_target(head_output, rewards, dones)

	def regularizer(self, all_params, head_id):
		"""Return 1e-8 times the L2 distance between head ``head_id``'s flattened
		parameters and the (detached) mean over all heads' flattened parameters.

		All entries of ``all_params`` must have the same length so they can be stacked.
		"""
		weight_vector = all_params[head_id]
		# Detached so gradients flow only through the selected head's parameters.
		mean_vector = torch.mean(torch.stack(all_params), dim=0).detach()
		return 1e-8 * torch.dist(weight_vector, mean_vector, 2)

	def compute_q_target(self, head_output, rewards, dones):
		"""Return ``rewards + gamma * max_a min_heads Q(s', a) * (1 - dones)``.

		``head_output`` is a non-empty sequence of per-head Q tensors whose dim 1
		indexes actions (presumably shape (batch, num_actions); confirm). The
		result has shape (batch, 1) before broadcasting with ``rewards``.
		"""
		q_min = head_output[0]
		for head_q in head_output[1:]:
			q_min = torch.min(q_min, head_q)
		q_next = q_min.max(1)[0].unsqueeze(1)
		return rewards + self.gamma * q_next * (1 - dones)

	def preserve_critic(self, agent_id, new_elite, old_elite):
		"""Copy the old elite's target critic weights into the new elite's slot for ``agent_id``."""
		utils.hard_update(self.auto_reward_target_critics[new_elite][agent_id], self.auto_reward_target_critics[old_elite][agent_id])

	def reset_critic(self, popn_id, agent_id):
		"""Re-initialise a target critic by copying in a freshly constructed network's weights."""
		new_net = self.model_constructor.make_model('MaxminDQN').to(device=self.device)
		utils.hard_update(self.auto_reward_target_critics[popn_id][agent_id], new_net)

	def global_update(self, reward_scaling, buffer):
		"""Placeholder: no global update is performed by this learner."""
		pass
