import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# Every network below is moved to this device: the GPU when CUDA is usable,
# otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class Critic(nn.Module):
	"""Fully connected network producing one scalar per (state, action) row.

	The state and action tensors are concatenated along dimension 1 and passed
	through two 256-unit ReLU layers followed by a linear layer with a single
	output.
	"""

	def __init__(self, state_dim, action_dim):
		"""Create the three linear layers.

		Args:
			state_dim: Number of state features (size of dimension 1 of ``state``).
			action_dim: Number of action features (size of dimension 1 of ``action``).
		"""
		super().__init__()

		hidden_units = 256
		input_units = state_dim + action_dim

		# Attribute names and creation order are kept as l1, l2, l3 so that
		# state_dict keys and seeded parameter initialisation stay the same.
		self.l1 = nn.Linear(input_units, hidden_units)
		self.l2 = nn.Linear(hidden_units, hidden_units)
		self.l3 = nn.Linear(hidden_units, 1)

	def forward(self, state, action):
		"""Return the network output for each (state, action) row.

		Args:
			state: Tensor whose dimension 1 has ``state_dim`` entries.
			action: Tensor whose dimension 1 has ``action_dim`` entries.

		Returns:
			Output of the final linear layer (last dimension of size 1).
		"""
		joint_input = torch.cat((state, action), dim=1)
		hidden = F.relu(self.l1(joint_input))
		hidden = F.relu(self.l2(hidden))
		value = self.l3(hidden)
		return value


class Dual_DICE:
	"""Adversarially trained critic / weight-network pair for policy evaluation.

	Two ``Critic`` networks are held: ``critic`` and ``w``.  ``train_R``
	performs one gradient step on each of them against the same objective,
	with ``critic`` minimising it and ``w`` maximising it.  ``eval_policy``
	returns the batch mean of ``w(state, action) * reward``.

	NOTE(review): the name and the form of the objective suggest this follows
	the DualDICE off-policy evaluation method (``critic`` as nu, ``w`` as
	zeta) -- confirm against the caller / accompanying paper.

	The replay buffer passed to the methods must provide ``sample(batch_size)``
	returning a 6-tuple ``(state, action, next_state, reward, not_done,
	fake_reward)`` and ``all_start()`` returning a batch of start states.  The
	policy must be callable on a batch of states and return a batch of actions.
	"""

	def __init__(
		self,
		state_dim,
		action_dim,
		max_action,
		discount=0.99,
		tau=0.005
	):
		"""Build both networks and their Adam optimizers (learning rate 3e-4).

		Args:
			state_dim: Number of state features fed to the networks.
			action_dim: Number of action features fed to the networks.
			max_action: Bound used to scale the action noise and to clamp
				perturbed actions to ``[-max_action, max_action]``.
			discount: Discount factor used in the training objective.
			tau: Stored on the instance; not read by any method of this class.
		"""

		self.critic = Critic(state_dim, action_dim).to(device)
		self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

		self.w = Critic(state_dim, action_dim).to(device)
		self.w_optimizer = torch.optim.Adam(self.w.parameters(), lr=3e-4)

		self.discount = discount
		self.tau = tau

		# total_it and made_start are initialised here but not updated or
		# read by the methods of this class; kept for external users.
		self.total_it = 0

		self.made_start = False

		self.max_action = max_action

	def _noisy_policy_action(self, policy, state):
		"""Return ``policy(state)`` plus Gaussian noise, clamped to the action bounds."""
		with torch.no_grad():
			action = policy(state)
			# Noise standard deviation is 10% of max_action.
			noise = torch.randn_like(action) * self.max_action * 0.1
			return (action + noise).clamp(-self.max_action, self.max_action)

	def train_R(self, replay_buffer, policy, batch_size=2048):
		"""Take one optimisation step on ``critic`` and one on ``w``.

		With ``nu`` = ``critic`` and noisy policy actions ``a0``, ``a'``, the
		shared objective is::

			J = mean((nu(s, a) - discount * not_done * nu(s', a')) * w(s, a)
			         - w(s, a)**2 / 2)
			    - (1 - discount) * mean(nu(s0, a0))

		``critic`` takes a step that decreases ``J`` (with ``w`` held fixed) and
		``w`` then takes a step that increases ``J`` (with the critic terms,
		computed before the critic step, held fixed).

		Args:
			replay_buffer: Source of transitions and start states (see class docstring).
			policy: Callable mapping a batch of states to a batch of actions.
			batch_size: Number of transitions requested from ``replay_buffer.sample``.
		"""
		state, action, next_state, _, not_done, _ = replay_buffer.sample(batch_size)
		start_state = replay_buffer.all_start()

		start_action = self._noisy_policy_action(policy, start_state)
		start_Q = (1. - self.discount) * self.critic(start_state, start_action).mean()

		w = self.w(state, action)

		next_action = self._noisy_policy_action(policy, next_state)
		critic_stuff = self.critic(state, action) - self.discount * not_done * self.critic(next_state, next_action)

		# Critic step: minimise J; w is detached so only critic parameters
		# receive gradients.
		critic_loss = (critic_stuff * w.detach() - w.pow(2).detach() / 2.).mean() - start_Q

		self.critic_optimizer.zero_grad()
		critic_loss.backward()
		self.critic_optimizer.step()

		# w step: maximise J (minimise -J); the critic terms are detached
		# values from before the critic update, so only w receives gradients.
		w_loss = -((critic_stuff.detach() * w - w.pow(2) / 2.).mean() - start_Q.detach())

		self.w_optimizer.zero_grad()
		w_loss.backward()
		self.w_optimizer.step()

	def eval_policy(self, replay_buffer, policy, batch_size=10000):
		"""Return the mean of ``w(state, action) * reward`` over one sampled batch.

		NOTE(review): presumably this is the weighted-reward estimate of the
		policy's value -- confirm the intended normalisation with the caller.

		Args:
			replay_buffer: Source of transitions (see class docstring).
			policy: Unused; accepted for interface compatibility.
			batch_size: Number of transitions requested from ``replay_buffer.sample``.

		Returns:
			The batch mean as a Python float.
		"""
		state, action, _, reward, _, _ = replay_buffer.sample(batch_size)
		# Only a scalar is returned, so skip building an autograd graph.
		with torch.no_grad():
			weighted_reward = self.w(state, action) * reward
		return float(weighted_reward.mean())