import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import process_reward, norm_state, response_to_lable

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Affine map between the networks' normalised action range [-1, 1] and the
# real action range: real = action_scale * normalised + action_offsite
# (see VAE.decode and the inverse map in Critic).
# "offsite" is presumably a misspelling of "offset"; the name is kept so
# existing imports keep working.
# Float dtype so arithmetic and torch.min/torch.max against float32 network
# outputs never depend on implicit int64 -> float promotion.
# NOTE: both tensors have length 2, i.e. they assume action_dim == 2.
action_offsite = torch.tensor([0.0, 0.0], device=device)
action_scale = torch.tensor([1.0, 1.0], device=device)
# Number of classes produced by H_Encoder.predict.
n_reward_class = 5

class L_Encoder(nn.Module):
	"""Two-layer MLP that embeds the low-level state into a feature vector.

	Args:
		l_state_dim: size of the last dimension of the input state.
		hidden_dim: size of the produced feature vector.
	"""

	def __init__(self, l_state_dim, hidden_dim):
		super().__init__()
		# Layer order defines the state_dict keys (encoder.0 / encoder.2);
		# keep it stable so saved checkpoints still load.
		layers = [
			nn.Linear(l_state_dim, 256),
			nn.ReLU(inplace=True),
			nn.Linear(256, hidden_dim),
		]
		self.encoder = nn.Sequential(*layers)

	def forward(self, l_state):
		"""Map ``l_state`` (..., l_state_dim) to features (..., hidden_dim)."""
		features = self.encoder(l_state)
		return features

class H_Encoder(nn.Module):
	"""Gaussian encoder for the high-level state plus a small classifier head.

	``forward`` returns a reparameterised sample ``z`` together with the mean
	and standard deviation of the diagonal Gaussian it was drawn from.
	``predict`` maps a feature of size ``hidden_dim`` to ``n_reward_class``
	unnormalised class scores.
	"""

	def __init__(self, h_state_dim, hidden_dim):
		super().__init__()
		# Attribute names and creation order form the state_dict keys and the
		# seeded initialisation order; keep both stable.
		self.encoder = nn.Sequential(
			nn.Linear(h_state_dim, 256),
			nn.ReLU(inplace=True),
			nn.Linear(256, 256),
			nn.ReLU(inplace=True),
		)
		self.fc_mu = nn.Linear(256, hidden_dim)
		self.fc_log_std = nn.Linear(256, hidden_dim)

		# Classifier head used by predict(): hidden_dim -> hidden_dim -> classes.
		self.predict1 = nn.Linear(hidden_dim, hidden_dim)
		self.predict2 = nn.Linear(hidden_dim, n_reward_class)

	def reparameterise(self, mu, std):
		"""Return ``mu + std * eps`` with ``eps ~ N(0, I)``.

		mu, std: tensors of identical shape, e.g. [batch_size, z_dim].
		"""
		noise = torch.randn_like(std)
		return mu + std * noise

	def forward(self, h_state):
		"""Encode ``h_state`` and return ``(z, mu, std)``.

		The log-std is clamped to [-4, 15] before exponentiation for numerical
		stability; ``z`` is a fresh sample on every call.
		"""
		hidden = self.encoder(h_state)
		mu = self.fc_mu(hidden)
		std = self.fc_log_std(hidden).clamp(-4, 15).exp()
		return self.reparameterise(mu, std), mu, std

	def predict(self, h_feature):
		"""Return unnormalised class scores of shape (..., n_reward_class)."""
		return self.predict2(F.relu(self.predict1(h_feature)))

class Actor(nn.Module):
	"""Residual (perturbation) actor.

	The low-level state is embedded by ``L_Encoder`` and the high-level state
	by the stochastic ``H_Encoder``. Both features are concatenated with a
	candidate action and mapped to a bounded residual in
	``[-phi * action_scale, phi * action_scale]``. The perturbed action is
	clipped to ``[action_offsite - action_scale, action_offsite + action_scale]``.
	"""

	def __init__(self, l_state_dim, h_state_dim, hidden_dim, action_dim, phi=0.05):
		super(Actor, self).__init__()
		self.l_encoder = L_Encoder(l_state_dim, hidden_dim)
		self.h_encoder = H_Encoder(h_state_dim, hidden_dim)
		self.l1 = nn.Linear(hidden_dim*2 + action_dim, 256)
		self.l2 = nn.Linear(256, 256)
		self.l3 = nn.Linear(256, action_dim)

		# Maximum relative size of the residual.
		self.phi = phi

	def forward(self, state, h_state, action, test=False):
		"""Return ``action`` perturbed by a learned, bounded residual.

		Args:
			state: low-level state, 2-D (batch, l_state_dim).
			h_state: high-level state, 2-D (batch, h_state_dim).
			action: candidate actions, 2-D (batch, action_dim). Not modified.
			test: if True use the mean high-level embedding ``mu``; otherwise
				use the reparameterised sample ``z``.

		Returns:
			Perturbed and clipped actions with the same shape as ``action``.
		"""
		z, mu, std = self.h_encoder(h_state)
		l_feature = self.l_encoder(state)
		h_feature = mu if test else z
		a = F.relu(self.l1(torch.cat([l_feature, h_feature, action], 1)))
		a = F.relu(self.l2(a))
		# The residual is a *delta*: it is scaled but not offset, because the
		# offset is already contained in ``action``.
		delta = action_scale * self.phi * torch.tanh(self.l3(a))
		# Out-of-place add so the caller's ``action`` tensor is left untouched.
		perturbed = action + delta
		low = action_offsite - action_scale
		high = action_offsite + action_scale
		return torch.max(torch.min(perturbed, high), low)


class Critic(nn.Module):
	"""Twin Q-network: two independent 3-layer MLPs over (state, action).

	Actions are mapped back to the normalised range with
	``(action - action_offsite) / action_scale`` and clipped to [-1, 1]
	before being concatenated with the state.
	"""

	def __init__(self, state_dim, action_dim):
		super().__init__()
		in_dim = state_dim + action_dim
		# First Q head (creation order kept for state_dict / seeded init).
		self.l1 = nn.Linear(in_dim, 256)
		self.l2 = nn.Linear(256, 256)
		self.l3 = nn.Linear(256, 1)
		# Second Q head.
		self.l4 = nn.Linear(in_dim, 256)
		self.l5 = nn.Linear(256, 256)
		self.l6 = nn.Linear(256, 1)

	@staticmethod
	def _state_action(state, action):
		"""Normalise ``action`` to [-1, 1] and concatenate it after ``state``."""
		unit_action = ((action - action_offsite) / action_scale).clamp(-1, 1)
		return torch.cat([state, unit_action], 1)

	@staticmethod
	def _head(x, first, second, third):
		"""Apply one Q head: Linear -> ReLU -> Linear -> ReLU -> Linear."""
		return third(F.relu(second(F.relu(first(x)))))

	def forward(self, state, action):
		"""Return ``(q1, q2)``, each of shape (batch, 1)."""
		sa = self._state_action(state, action)
		return (self._head(sa, self.l1, self.l2, self.l3),
				self._head(sa, self.l4, self.l5, self.l6))

	def q1(self, state, action):
		"""Return only the first head's value, shape (batch, 1)."""
		sa = self._state_action(state, action)
		return self._head(sa, self.l1, self.l2, self.l3)


# Vanilla Variational Auto-Encoder over actions, conditioned on the state.
class VAE(nn.Module):
	"""Conditional VAE for ``action`` given ``state``.

	The encoder maps (state, action) to a diagonal Gaussian over a
	``latent_dim``-sized latent. The decoder maps (state, z) back to an action
	in ``[action_offsite - action_scale, action_offsite + action_scale]``.
	"""

	def __init__(self, state_dim, action_dim, latent_dim):
		super(VAE, self).__init__()
		self.e1 = nn.Linear(state_dim + action_dim, 750)
		self.e2 = nn.Linear(750, 750)

		self.mean = nn.Linear(750, latent_dim)
		self.log_std = nn.Linear(750, latent_dim)

		self.d1 = nn.Linear(state_dim + latent_dim, 750)
		self.d2 = nn.Linear(750, 750)
		self.d3 = nn.Linear(750, action_dim)

		self.latent_dim = latent_dim
		# Kept for backward compatibility; decode() follows ``state.device``.
		self.device = device

	def forward(self, state, action):
		"""Encode, sample a latent, and decode.

		Returns:
			(u, mean, std): the reconstructed action and the mean / standard
			deviation of the latent Gaussian the sample was drawn from.
		"""
		h = F.relu(self.e1(torch.cat([state, action], 1)))
		h = F.relu(self.e2(h))

		mean = self.mean(h)
		# Clamped for numerical stability
		log_std = self.log_std(h).clamp(-4, 15)
		std = torch.exp(log_std)
		z = mean + std * torch.randn_like(std)

		u = self.decode(state, z)
		return u, mean, std

	def decode(self, state, z=None):
		"""Decode ``z`` (or, if None, a clipped prior sample) into an action.

		When ``z`` is None it is drawn from N(0, I) and clipped to [-0.5, 0.5].
		"""
		if z is None:
			# Draw the noise directly on the device/dtype of ``state``. This
			# avoids a host->device copy and works wherever the inputs live.
			z = torch.randn((state.shape[0], self.latent_dim),
					device=state.device, dtype=state.dtype).clamp(-0.5, 0.5)

		a = F.relu(self.d1(torch.cat([state, z], 1)))
		a = F.relu(self.d2(a))
		return action_scale * torch.tanh(self.d3(a)) + action_offsite
		
class ResAct(object):
	"""BCQ-style offline agent with a residual (perturbation) actor.

	Components:
		* ``vae``: conditional VAE proposing candidate actions for a state.
		* ``actor``: perturbs VAE candidates by a bounded residual; its
		  high-level encoder is also trained with KL and classification losses.
		* ``critic``: twin Q-network trained with a soft clipped double-Q target.

	Args:
		args: config namespace. Reads ``discount``, ``tau``, ``state_dim``,
			``h_state_dim``, ``hidden_dim``, ``action_dim`` and ``lr``;
			optionally ``phi``; and in ``train`` also ``kl_rate`` and ``ce_rate``.
		lmbda: weight on min(Q1, Q2) in the soft clipped double-Q target.
		phi: actor perturbation scale, used only when ``args`` has no ``phi``
			attribute (``args.phi`` takes precedence).
	"""

	def __init__(self, args, lmbda=0.75, phi=0.05):
		discount, tau = args.discount, args.tau
		# args.phi wins, as before. The keyword is now a real fallback instead
		# of being silently overwritten.
		phi = getattr(args, "phi", phi)
		latent_dim = args.action_dim * 2

		self.actor = Actor(args.state_dim, args.h_state_dim, args.hidden_dim, args.action_dim, phi).to(device)
		self.actor_target = copy.deepcopy(self.actor)
		# Actor learning rate is a tenth of the critic's.
		self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=0.1*args.lr)

		# Critic and VAE operate on the concatenated [low-level, high-level] state.
		self.critic = Critic(args.state_dim+args.h_state_dim, args.action_dim).to(device)
		self.critic_target = copy.deepcopy(self.critic)
		self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=args.lr)

		self.vae = VAE(args.state_dim+args.h_state_dim, args.action_dim, latent_dim).to(device)
		self.vae_optimizer = torch.optim.Adam(self.vae.parameters())

		self.args = args
		self.action_dim = args.action_dim
		self.discount = discount
		self.tau = tau
		self.lmbda = lmbda
		self.device = device

	def select_action(self, state):
		"""Choose an action for a single observation.

		The observation is replicated 20 times and normalised with
		``norm_state``. The VAE proposes 20 candidate actions, which the actor
		perturbs using the mean high-level embedding (``test=True``). The
		candidate with the highest Q1 value is returned.

		Args:
			state: observation with ``state_dim + h_state_dim`` elements once
				flattened (assumed to be a NumPy array; TODO confirm).

		Returns:
			Tensor of shape (1, action_dim) on ``self.device``.
		"""
		with torch.no_grad():
			state = torch.FloatTensor(state.reshape(1, -1)).repeat(20, 1).to(self.device)
			state = norm_state(state)
			l_state = state[:, :self.args.state_dim]
			h_state = state[:, self.args.state_dim:]

			action = self.actor.forward(l_state, h_state, self.vae.decode(state), test=True)
			q1 = self.critic.q1(state, action)
			ind = q1.argmax(0)
		return action[ind]

	def train(self, replay_buffer, batch_size=256, iterations=1):
		"""Run ``iterations`` update steps on the VAE, critic and actor.

		Each step:
			1. fits the VAE to the batch actions;
			2. regresses both critic heads onto the soft clipped double-Q
			   target, with the max taken over 10 perturbed VAE proposals per
			   next state;
			3. updates the actor by DPG through Q1, plus weighted KL and
			   cross-entropy losses on the high-level encoder;
			4. Polyak-averages the target networks with ``tau``.

		Returns:
			dict with the last step's ``vae_loss``, ``critic_loss`` and
			``actor_loss`` as detached scalar tensors.

		Raises:
			ValueError: if ``iterations`` < 1 (there would be no loss to report).
		"""
		if iterations < 1:
			raise ValueError(f"iterations must be >= 1, got {iterations}")

		for _ in range(iterations):
			# Sample replay buffer / batch
			state, h_state, action, next_action, next_state, next_h_state, response, h_response, not_done = replay_buffer.sample(batch_size)
			reward = process_reward(h_response, response)

			state = torch.cat([state, h_state], dim=1)
			next_state = torch.cat([next_state, next_h_state], dim=1)

			state = norm_state(state)
			next_state = norm_state(next_state)

			# ---- Variational Auto-Encoder training ----
			recon, mean, std = self.vae(state, action)
			recon_loss = F.mse_loss(recon, action)
			KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
			vae_loss = recon_loss + 0.5 * KL_loss

			self.vae_optimizer.zero_grad()
			vae_loss.backward()
			self.vae_optimizer.step()

			# ---- Critic training ----
			with torch.no_grad():
				# Duplicate each next state 10 times (one per VAE proposal)
				next_state = torch.repeat_interleave(next_state, 10, 0)
				next_l_state = next_state[:, :self.args.state_dim]
				next_h_state = next_state[:, self.args.state_dim:]

				# Value of perturbed actions sampled from the VAE
				next_actions = self.actor_target.forward(next_l_state, next_h_state, self.vae.decode(next_state), test=False)
				target_Q1, target_Q2 = self.critic_target(next_state, next_actions)

				# Soft clipped double Q-learning
				target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1. - self.lmbda) * torch.max(target_Q1, target_Q2)
				# Max over the proposals of each sample; assumes
				# replay_buffer.sample returned exactly batch_size rows.
				target_Q = target_Q.reshape(batch_size, -1).max(1)[0].reshape(-1, 1)

				target_Q = reward + not_done * self.discount * target_Q

			current_Q1, current_Q2 = self.critic(state, action)
			critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

			self.critic_optimizer.zero_grad()
			critic_loss.backward()
			self.critic_optimizer.step()

			# ---- Perturbation model / actor training ----
			# The actor loss never updates the VAE (its grads are zeroed before
			# its own backward pass), so sample proposals without building a
			# graph through the decoder.
			with torch.no_grad():
				sampled_actions = self.vae.decode(state)
			l_state = state[:, :self.args.state_dim]
			h_state = state[:, self.args.state_dim:]
			perturbed_actions = self.actor.forward(l_state, h_state, sampled_actions, test=False)

			# Separate pass through the high-level encoder (fresh sample) for the
			# auxiliary KL regulariser and pseudo-label classification loss.
			z, mu, sigma = self.actor.h_encoder(h_state)
			h = self.actor.h_encoder.predict(z)  # [batch, n_reward_class]
			kl_loss = -0.5 * (1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2)).mean()

			pseudo_label = response_to_lable(h_response, response)
			ce_loss = F.cross_entropy(h, pseudo_label)
			# Update through DPG. Gradients also reach the critic parameters,
			# but critic_optimizer.zero_grad() discards them before the next
			# critic step.
			actor_loss = -self.critic.q1(state, perturbed_actions).mean()
			actor_loss = actor_loss + (self.args.kl_rate * kl_loss + self.args.ce_rate * ce_loss)
			self.actor_optimizer.zero_grad()
			actor_loss.backward()
			self.actor_optimizer.step()

			# ---- Soft (Polyak) target-network updates ----
			for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
				target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

			for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
				target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

		# Detach so callers do not keep autograd graph nodes alive.
		return {'vae_loss': vae_loss.detach(), 'critic_loss': critic_loss.detach(), 'actor_loss': actor_loss.detach()}

	def save(self, filename):
		"""Write critic/actor weights and optimiser states, plus VAE weights,
		to files prefixed with ``filename``."""
		torch.save(self.critic.state_dict(), filename + "_critic")
		torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")

		torch.save(self.actor.state_dict(), filename + "_actor")
		torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

		torch.save(self.vae.state_dict(), filename + "_vae")

	def load(self, filename):
		"""Restore the files written by ``save``.

		Tensors are mapped onto ``self.device``, so checkpoints saved on a GPU
		load on a CPU-only machine. Target networks are reset to copies of the
		loaded online networks.
		"""
		self.critic.load_state_dict(torch.load(filename + "_critic", map_location=self.device))
		self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer", map_location=self.device))
		self.critic_target = copy.deepcopy(self.critic)

		self.actor.load_state_dict(torch.load(filename + "_actor", map_location=self.device))
		self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer", map_location=self.device))
		self.actor_target = copy.deepcopy(self.actor)

		self.vae.load_state_dict(torch.load(filename + "_vae", map_location=self.device))
