import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# Device shared by every network and tensor created in this module:
# the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def get_pref(reward_dim):
    """Draw a random preference (weight) vector over the reward components.

    Each entry is sampled uniformly from [0, 1) with NumPy's global random
    state, cast to float32, and the vector is then divided by its own sum so
    the weights add up to 1.

    Args:
        reward_dim: number of entries in the returned vector.

    Returns:
        A 1-D ``np.float32`` array of length ``reward_dim``.
    """
    # Disabled alternative kept for reference: instead of sampling, pick one
    # vector from a fixed menu such as [0.9, 0.1], [0.5, 0.5], [0.1, 0.9].
    weights = np.random.rand(reward_dim).astype(np.float32)
    return weights / weights.sum()

# Preference-conditioned, multi-reward variant of a re-tuned
# Deep Deterministic Policy Gradient (DDPG) agent.
# DDPG paper: https://arxiv.org/abs/1509.02971


class Actor(nn.Module):
	"""Policy network mapping a (state, preference) pair to an action.

	Three fully connected layers with ReLU activations between them; the
	last layer is squashed with tanh and multiplied by ``max_action``.

	Args:
		state_dim: width of the state input.
		action_dim: width of the produced action.
		reward_dim: width of the preference input.
		max_action: factor applied to the tanh output.
	"""

	def __init__(self, state_dim, action_dim, reward_dim, max_action):
		super().__init__()

		hidden_units = 256
		# State and preference are concatenated, so the first layer sees both.
		self.l1 = nn.Linear(state_dim + reward_dim, hidden_units)
		self.l2 = nn.Linear(hidden_units, hidden_units)
		self.l3 = nn.Linear(hidden_units, action_dim)

		self.max_action = max_action

	def forward(self, state, preference):
		"""Return the action for each row of ``state`` / ``preference``.

		Both inputs are joined along dimension 1, so they must be 2-D with
		the same number of rows.
		"""
		joint_input = torch.cat([state, preference], 1)
		hidden = F.relu(self.l1(joint_input))
		hidden = F.relu(self.l2(hidden))
		squashed = torch.tanh(self.l3(hidden))
		return self.max_action * squashed


class Critic(nn.Module):
	"""Value network scoring a (state, action, preference) triple.

	Three fully connected layers with ReLU activations between them. The
	output has ``reward_dim`` entries per input row, i.e. one value for each
	reward component rather than a single scalar.

	Args:
		state_dim: width of the state input.
		action_dim: width of the action input.
		reward_dim: width of the preference input and of the output.
	"""

	def __init__(self, state_dim, action_dim, reward_dim):
		super().__init__()

		hidden_units = 256
		input_width = state_dim + action_dim + reward_dim
		self.l1 = nn.Linear(input_width, hidden_units)
		self.l2 = nn.Linear(hidden_units, hidden_units)
		self.l3 = nn.Linear(hidden_units, reward_dim)

	def forward(self, state, action, preference):
		"""Return the per-component values for each input row.

		The three inputs are joined along dimension 1, so they must be 2-D
		with the same number of rows.
		"""
		joint_input = torch.cat([state, action, preference], 1)
		hidden = F.relu(self.l1(joint_input))
		hidden = F.relu(self.l2(hidden))
		return self.l3(hidden)


class DDPG(object):
	"""DDPG agent whose actor and critic are conditioned on a preference vector.

	The critic produces one value per reward component. During training those
	values are collapsed to a scalar by a dot product with a randomly drawn
	preference vector (see ``train``).

	Args:
		state_dim: width of the state vector.
		action_dim: width of the action vector.
		reward_dim: number of reward components / length of a preference vector.
		max_action: scale applied to the actor's tanh output.
		discount: factor multiplying the bootstrapped next-state value.
		tau: interpolation weight of the soft target-network update.
	"""

	def __init__(self, state_dim, action_dim, reward_dim,max_action, discount=0.99, tau=0.005):
		self.actor = Actor(state_dim, action_dim, reward_dim, max_action).to(device)
		# Target networks start as exact copies and are only changed by soft updates.
		self.actor_target = copy.deepcopy(self.actor)
		self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

		self.critic = Critic(state_dim, action_dim, reward_dim).to(device)
		self.critic_target = copy.deepcopy(self.critic)
		self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

		self.discount = discount
		self.tau = tau

		self.reward_dim = reward_dim


	def select_action(self, state, preference):
		"""Return the actor's action for a single state as a flat NumPy array.

		Args:
			state: array-like exposing ``reshape``; flattened to one row.
			preference: array-like exposing ``reshape``; flattened to one row.
				NOTE(review): presumably both are NumPy arrays - confirm with callers.
		"""
		state = torch.FloatTensor(state.reshape(1, -1)).to(device)
		preference = torch.FloatTensor(preference.reshape(1, -1)).to(device)
		# Inference only: skip building an autograd graph for the forward pass.
		with torch.no_grad():
			action = self.actor(state, preference)
		return action.cpu().numpy().flatten()


	def train(self, replay_buffer, batch_size=256):
		"""Run one critic step, one actor step and a soft target update.

		A single preference vector is drawn per call and shared by every
		transition in the batch.

		Args:
			replay_buffer: object whose ``sample(batch_size)`` returns
				``(state, action, next_state, reward, not_done)``.
				NOTE(review): presumably batched tensors already on ``device``,
				with ``reward`` broadcastable to (batch, reward_dim) - confirm.
			batch_size: number of transitions requested from the buffer.
		"""
		# Sample replay buffer
		state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

		preference = torch.FloatTensor(get_pref(self.reward_dim)).to(device)
		# Tile the preference to the number of rows actually returned, so a
		# buffer that hands back a different count does not break torch.cat.
		D_pref = preference.repeat(state.shape[0], 1)

		# Compute the target Q value. No gradient is needed through the
		# target networks, so the whole computation runs without autograd.
		with torch.no_grad():
			next_action = self.actor_target(next_state, D_pref)
			target_Q = self.critic_target(next_state, next_action, D_pref)
			target_Q = reward + not_done * self.discount * target_Q
			# Dot with preference: collapse per-component values to a scalar per row.
			target_Q = torch.tensordot(target_Q, preference, dims=1)

		# Get current Q estimate, scalarised the same way as the target.
		current_Q = self.critic(state, action, D_pref)
		current_Q = torch.tensordot(current_Q, preference, dims=1)

		critic_loss = F.mse_loss(current_Q, target_Q)

		# Optimize the critic
		self.critic_optimizer.zero_grad()
		critic_loss.backward()
		self.critic_optimizer.step()

		# Compute actor loss: negative preference-weighted value of the
		# actor's own action, averaged over the batch.
		actor_loss = -self.critic(state, self.actor(state, D_pref), D_pref)
		actor_loss = torch.tensordot(actor_loss, preference, dims=1).mean()

		# Optimize the actor
		self.actor_optimizer.zero_grad()
		actor_loss.backward()
		self.actor_optimizer.step()

		# Update the frozen target models
		self._soft_update(self.critic, self.critic_target)
		self._soft_update(self.actor, self.actor_target)


	def _soft_update(self, source, target):
		"""Move ``target``'s parameters a fraction ``tau`` towards ``source``'s."""
		for param, target_param in zip(source.parameters(), target.parameters()):
			target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)


	def save(self, filename, num):
		"""Write network and optimizer state dicts to disk.

		Four files are written, named ``filename + str(num)`` followed by
		``_critic``, ``_critic_optimizer``, ``_actor`` and ``_actor_optimizer``.
		"""
		prefix = filename + str(num)

		torch.save(self.critic.state_dict(), prefix + "_critic")
		torch.save(self.critic_optimizer.state_dict(), prefix + "_critic_optimizer")

		torch.save(self.actor.state_dict(), prefix + "_actor")
		torch.save(self.actor_optimizer.state_dict(), prefix + "_actor_optimizer")


	def load(self, filename, num=None):
		"""Restore network and optimizer state and rebuild the target networks.

		Args:
			filename: path prefix of the checkpoint files.
			num: optional suffix mirroring ``save(filename, num)``. When
				omitted, ``filename`` is used as the complete prefix, which is
				the previous behaviour of this method.

		NOTE: ``torch.load`` unpickles the files; only load checkpoints from
		trusted sources.
		"""
		prefix = filename if num is None else filename + str(num)

		# Map stored tensors onto the device the live networks are on, so a
		# checkpoint written on a GPU can still be read on a CPU-only machine.
		critic_device = next(self.critic.parameters()).device
		self.critic.load_state_dict(torch.load(prefix + "_critic", map_location=critic_device))
		self.critic_optimizer.load_state_dict(torch.load(prefix + "_critic_optimizer", map_location=critic_device))
		self.critic_target = copy.deepcopy(self.critic)

		actor_device = next(self.actor.parameters()).device
		self.actor.load_state_dict(torch.load(prefix + "_actor", map_location=actor_device))
		self.actor_optimizer.load_state_dict(torch.load(prefix + "_actor_optimizer", map_location=actor_device))
		self.actor_target = copy.deepcopy(self.actor)
		
