import torch, os, random, sys
import torch.nn as nn
from torch.optim import Adam
import numpy as np
from core import utils as utils
from core.models import GumbelPolicy, QGlobal
from torch.utils.tensorboard import SummaryWriter
import torch.functional as F


class MultiDDQN(object):
	"""Per-agent DDQN learner used to train policies against evolved (auto) rewards.

	For each of ``args.num_agents`` agents it keeps an online network, a target
	network (hard-synced at construction) and an Adam optimizer. It also keeps one
	list of per-agent target critics for every member of the evolutionary
	population (``args.evo_popn_size``); ``autoreward_update`` bootstraps from these.

		Parameters:
			args (object): Parameter class. Reads ``num_agents``, ``actor_lr``,
				``gamma``, ``tau``, ``actualize``, ``aux_save``, ``savetag`` and
				``evo_popn_size``.
			model_constructor (object): Factory whose ``make_model('DDQN')``
				returns a model supporting ``.to()`` and ``.parameters()``.
	"""
	def __init__(self, args, model_constructor):

		self.num_agents = args.num_agents
		self.actor_lr = args.actor_lr
		self.gamma = args.gamma
		self.tau = args.tau
		# Count of gradient updates performed; also used as the TensorBoard
		# global_step so logged scalars form a time series.
		self.total_update = 0
		self.model_constructor = model_constructor
		self.actualize = args.actualize
		self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
		self.tracker = utils.Tracker(args.aux_save, ['q_'+args.savetag, 'qloss_'+args.savetag, \
						'policy_loss_'+args.savetag, 'alz_score'+args.savetag,'alz_policy'+args.savetag], \
						'.csv', save_iteration=1000, conv_size=1000)

		# Online networks, target networks and optimizers (one of each per agent)
		self.policies = [model_constructor.make_model('DDQN').to(device=self.device) for _ in range(self.num_agents)]
		self.policies_target = [model_constructor.make_model('DDQN').to(device=self.device) for _ in range(self.num_agents)]
		for p, target_p in zip(self.policies, self.policies_target):
			utils.hard_update(target_p, p)
		self.optims_policy = [Adam(self.policies[agent_id].parameters(), self.actor_lr) for agent_id in range(self.num_agents)]

		self.loss = nn.MSELoss()
		self.log_softmax = torch.nn.LogSoftmax(dim=1)
		self.softmax = torch.nn.Softmax(dim=1)

		# auto_reward_target_critics[popn_id][agent_id]: target critic per population member and agent
		self.auto_reward_target_critics = [
			[model_constructor.make_model('DDQN').to(device=self.device) for _ in range(self.num_agents)]
			for _ in range(args.evo_popn_size)
		]

		# Statistics tracker
		self.writer = SummaryWriter(log_dir='tensorboard' + '/' + args.savetag)

	def autoreward_update(self, buffer, reward_recipe, policy, agent_id, popn_id, args, return_policies):
		"""Train ``policy`` with Q-learning on rewards produced by ``reward_recipe``.

		Runs ``args.autoreward_iterations`` Bellman updates. The reward stored in the
		buffer is discarded and replaced by ``reward_recipe.compute_reward`` evaluated
		on the concatenated (state, action) batch. Bootstrapped targets come from
		``self.auto_reward_target_critics[popn_id][agent_id]``, which is soft-updated
		towards ``policy`` (rate ``self.tau``) after every step.

		A fresh Adam optimizer is created on each call, so optimizer moment
		estimates do not carry over between calls.

			Parameters:
				buffer (object): Replay buffer; ``buffer.sample(agent_id)`` returns
					(state, action, next_state, reward, done) tensors.
				reward_recipe (object): Provides ``compute_reward(state_action, popn_id)``.
				policy (model): Online network being trained; moved to ``self.device``
					for training.
				agent_id (int): Index of the agent being trained.
				popn_id (int): Index of the population member whose reward/critic is used.
				args (object): Reads ``autoreward_iterations``.
				return_policies: Container indexed as ``[popn_id][agent_id]`` that
					receives the trained policy after it is moved to CPU.

			Returns:
				None. The trained policy is written into ``return_policies``.
		"""
		policy.to(device=self.device)
		target_critic = self.auto_reward_target_critics[popn_id][agent_id]

		optim = Adam(policy.parameters(), self.actor_lr)

		for _ in range(args.autoreward_iterations):
			step = self.total_update
			self.total_update += 1

			# The environment reward from the buffer is not used; it is replaced below.
			state_batch, action_batch, next_state_batch, _, done_batch = buffer.sample(agent_id)
			reward_batch = reward_recipe.compute_reward(torch.cat([state_batch, action_batch], 1), popn_id).to(device=self.device)

			# Move to self.device (not .cuda()) so this also runs on CPU-only machines.
			state_batch = state_batch.to(device=self.device)
			next_state_batch = next_state_batch.to(device=self.device)
			action_batch = action_batch.long().to(device=self.device)
			done_batch = done_batch.to(device=self.device)

			with torch.no_grad():
				na, _, ns_logits = target_critic.clean_action(next_state_batch, return_only_action=False)
				# Entropy is logged for diagnostics only; it does not enter the target.
				next_entropy = (-self.softmax(ns_logits) * self.log_softmax(ns_logits)).sum(1)
				self.writer.add_scalar('next_entropy', next_entropy.mean().item(), step)

				# NOTE(review): the next action and its value both come from the target
				# critic. A classic Double-DQN target would pick `na` with the online
				# network and evaluate it with the target network — confirm intent.
				next_target = ns_logits.gather(1, na.unsqueeze(1).long())
				next_q_value = reward_batch + self.gamma * next_target * (1 - done_batch)
				self.writer.add_scalar('next_qval', next_q_value.mean().item(), step)

			# Q-value of the taken action (assumes action_batch is (batch, 1) — TODO confirm)
			_, _, logits = policy.clean_action(state_batch, return_only_action=False)
			q_val = logits.gather(1, action_batch)
			self.writer.add_scalar('qval', q_val.mean().item(), step)

			entropy = (-self.softmax(logits) * self.log_softmax(logits)).sum(1)
			self.writer.add_scalar('entropy', entropy.mean().item(), step)

			q_loss = ((next_q_value - q_val) ** 2).mean()
			self.writer.add_scalar('q_loss', q_loss.item(), step)

			optim.zero_grad()
			q_loss.backward()
			optim.step()

			utils.soft_update(target_critic, policy, self.tau)

		return_policies[popn_id][agent_id] = policy.cpu()

	def preserve_critic(self, agent_id, new_elite, old_elite):
		"""Copy agent ``agent_id``'s target critic weights from population slot
		``old_elite`` into slot ``new_elite`` via ``utils.hard_update``.

		(Argument order follows the ``hard_update(target, source)`` usage in
		``__init__`` — confirm against ``core.utils``.)
		"""
		utils.hard_update(self.auto_reward_target_critics[new_elite][agent_id], self.auto_reward_target_critics[old_elite][agent_id])

	def reset_critic(self, popn_id, agent_id):
		"""Re-initialise the target critic at ``[popn_id][agent_id]`` by copying the
		weights of a freshly constructed DDQN model into it."""
		new_net = self.model_constructor.make_model('DDQN').to(device=self.device)
		utils.hard_update(self.auto_reward_target_critics[popn_id][agent_id], new_net)

	def global_update(self, reward_scaling, buffer):
		"""Placeholder: currently performs no update."""
		pass
