import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# Module-wide compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class Critic(nn.Module):
	"""Three-layer fully connected network over a concatenated (state, action) input.

	Two ReLU-activated hidden layers of width 256 are followed by a linear
	layer that produces a single output per row.
	"""

	def __init__(self, state_dim, action_dim):
		"""Create the layers.

		Args:
			state_dim: number of state features per row.
			action_dim: number of action features per row.
		"""
		super().__init__()

		hidden_dim = 256
		input_dim = state_dim + action_dim

		# Attribute names (l1, l2, l3) determine the state_dict keys, and the
		# construction order determines how the RNG is consumed at init time.
		self.l1 = nn.Linear(input_dim, hidden_dim)
		self.l2 = nn.Linear(hidden_dim, hidden_dim)
		self.l3 = nn.Linear(hidden_dim, 1)


	def forward(self, state, action):
		"""Return the network output for each (state, action) row.

		``state`` and ``action`` are concatenated along dimension 1, so they
		must agree on every other dimension.
		"""
		joint = torch.cat((state, action), dim=1)
		hidden = self.l1(joint).relu()
		hidden = self.l2(hidden).relu()
		return self.l3(hidden)


class Grad_DICE(object):
	"""Saddle-point trainer over two networks (``f``, ``w``) and a scalar ``u``.

	All three share a single objective (see ``_objective``). ``train_R`` takes
	one optimiser step per player: ``f`` and ``u`` step to increase the
	objective, ``w`` steps to decrease it. ``eval_policy`` returns the batch
	mean of ``w(state, action) * reward``.

	NOTE(review): the objective has the form of the GradDICE Lagrangian, in
	which ``w`` would be the density-ratio estimate -- confirm against the
	reference the class is named after.
	"""

	def __init__(
		self,
		state_dim,
		action_dim,
		max_action,
		discount=0.99,
		tau=0.005
	):
		"""Build the networks, the scalar ``u`` and one Adam optimiser for each.

		Args:
			state_dim: number of state features (forwarded to ``Critic``).
			action_dim: number of action features (forwarded to ``Critic``).
			max_action: bound used to scale and clamp the noisy policy actions.
			discount: discount factor used in the objective.
			tau: stored on the instance; not read by any method of this class.
		"""
		self.f = Critic(state_dim, action_dim).to(device)
		self.f_optimizer = torch.optim.Adam(self.f.parameters(), lr=1e-5)

		self.w = Critic(state_dim, action_dim).to(device)
		self.w_optimizer = torch.optim.Adam(self.w.parameters(), lr=1e-5)

		# Scalar player, initialised to 1 and trained with its own optimiser.
		self.u = torch.ones(1, 1, requires_grad=True, device=device)
		self.u_optimizer = torch.optim.Adam([self.u], lr=0.01)

		# Weight on the (u, w) term of the objective.
		self.lmbda = 1

		self.discount = discount
		self.tau = tau

		# Bookkeeping attributes; not updated by any method of this class.
		self.total_it = 0

		self.made_start = False

		self.max_action = max_action


	def _noisy_action(self, policy, state):
		"""Return ``policy(state)`` plus Gaussian noise (std 0.1 * max_action), clamped to +/- max_action."""
		action = policy(state)
		noise = torch.randn_like(action) * self.max_action * 0.1
		return (action + noise).clamp(-self.max_action, self.max_action)


	def _objective(self, start_f, f, f_next_term, w, u):
		"""Return the shared objective for the given (possibly detached) terms.

		Callers choose which player receives gradients by detaching every
		argument that belongs to the other players.
		"""
		return (
			start_f
			+ (w * f_next_term).mean()
			- (w * f).mean()
			- 0.5 * f.pow(2).mean()
			+ self.lmbda * (u * (w.mean() - 1) - 0.5 * u.pow(2))
		)


	def train_R(self, replay_buffer, policy, batch_size=2048):
		"""Run one update of ``f``, then ``u``, then ``w`` on a sampled batch.

		Args:
			replay_buffer: object providing ``sample(batch_size)``, which returns
				``(state, action, next_state, reward, not_done, fake_reward)``,
				and ``all_start()``, which returns start states.
				Assumes tensors laid out as (batch, features) -- TODO confirm.
			policy: callable mapping a batch of states to a batch of actions.
			batch_size: number of transitions requested from the buffer.
		"""
		state, action, next_state, reward, not_done, fake_reward = replay_buffer.sample(batch_size)
		start_state = replay_buffer.all_start()

		# The policy is not trained here, so its actions carry no gradient.
		with torch.no_grad():
			start_action = self._noisy_action(policy, start_state)
			next_action = self._noisy_action(policy, next_state)

		# Forward passes shared by all three updates.
		start_f = (1. - self.discount) * self.f(start_state, start_action).mean()
		f = self.f(state, action)
		f_next_term = not_done * self.discount * self.f(next_state, next_action)
		w = self.w(state, action)

		# Gradient-free copies of the f and w terms. They hold the values
		# computed before any optimiser step in this call.
		start_f_const = start_f.detach()
		f_const = f.detach()
		f_next_const = f_next_term.detach()
		w_const = w.detach()

		# f increases the objective (loss is the negated objective).
		f_loss = -self._objective(start_f, f, f_next_term, w_const, self.u.detach())
		self.f_optimizer.zero_grad()
		f_loss.backward()
		self.f_optimizer.step()

		# u increases the objective.
		u_loss = -self._objective(start_f_const, f_const, f_next_const, w_const, self.u)
		self.u_optimizer.zero_grad()
		u_loss.backward()
		self.u_optimizer.step()

		# w decreases the objective; it sees the value of u after the step above.
		w_loss = self._objective(start_f_const, f_const, f_next_const, w, self.u.detach())
		self.w_optimizer.zero_grad()
		w_loss.backward()
		self.w_optimizer.step()


	def eval_policy(self, replay_buffer, policy, batch_size=10000):
		"""Return the batch mean of ``w(state, action) * reward`` as a Python float.

		``policy`` is accepted for interface compatibility and is not used.
		"""
		state, action, next_state, reward, not_done, fake_reward = replay_buffer.sample(batch_size)
		# Evaluation only: skip building an autograd graph.
		with torch.no_grad():
			return float((self.w(state, action) * reward).mean())