import torch
from torch import nn
from torch.distributions import Normal
from torch.nn import functional as F

class Encoder(nn.Module):
	"""Transition encoder: (s, a, s') -> parameters of a diagonal Gaussian over z.

	The three inputs are concatenated along their last dimension and passed
	through two ReLU hidden layers. Two linear heads then give a mean and a
	log standard deviation, each of width ``feature_dim``.
	"""

	def __init__(self, state_dim, action_dim, hidden_dim, feature_dim):
		super().__init__()

		# s and s' both have width state_dim, so the concatenated input is
		# 2 * state_dim + action_dim wide.
		# Modules are created in this exact order on purpose: the order fixes
		# state_dict / parameters() ordering and the RNG draws used for init.
		self.l1 = nn.Linear(2 * state_dim + action_dim, hidden_dim)
		self.l2 = nn.Linear(hidden_dim, hidden_dim)
		self.mean_linear = nn.Linear(hidden_dim, feature_dim)
		self.log_std_linear = nn.Linear(hidden_dim, feature_dim)

	def forward(self, state, action, state_next):
		"""Return ``(mean, log_std)`` for the transition.

		``log_std`` is clamped to ``[-20, 2]``.
		"""
		hidden = torch.cat((state, action, state_next), dim=-1)
		for layer in (self.l1, self.l2):
			hidden = F.relu(layer(hidden))

		mean = self.mean_linear(hidden)
		# Bounding log_std keeps the std = exp(log_std) used by sample() in a
		# finite, non-degenerate range.
		log_std = self.log_std_linear(hidden).clamp(min=-20, max=2)
		return mean, log_std

	def sample(self, state, action, state_next):
		"""Draw one latent z with the reparameterization trick.

		Gradients flow back through both the mean and the std.
		"""
		mean, log_std = self.forward(state, action, state_next)
		# rsample (rather than sample) keeps the draw differentiable:
		# https://stackoverflow.com/a/70818755/23344262
		return Normal(mean, log_std.exp()).rsample()

class Decoder(nn.Module):
	"""Latent decoder: z -> (predicted next state, predicted reward).

	Two ReLU hidden layers feed two linear heads. The state head has width
	``state_dim`` and the reward head has width 1.
	"""

	def __init__(self, state_dim, hidden_dim, feature_dim):
		super().__init__()

		# Creation order is kept fixed: it determines parameter ordering and
		# the RNG draws consumed by each layer's initialization.
		self.l1 = nn.Linear(feature_dim, hidden_dim)
		self.l2 = nn.Linear(hidden_dim, hidden_dim)
		self.state_linear = nn.Linear(hidden_dim, state_dim)
		self.reward_linear = nn.Linear(hidden_dim, 1)

	def forward(self, feature):
		"""Return ``(state_next, reward)``.

		The last dimensions of the two outputs are ``state_dim`` and 1.
		"""
		hidden = feature
		for layer in (self.l1, self.l2):
			hidden = F.relu(layer(hidden))
		return self.state_linear(hidden), self.reward_linear(hidden)

class Feature(nn.Module):
	"""State-action network: (s, a) -> mean and log standard deviation for z.

	This has the same layout as ``Encoder`` but is conditioned only on the
	current state and action. Both heads have width ``feature_dim``.
	"""

	def __init__(self, state_dim, action_dim, hidden_dim, feature_dim):
		super().__init__()

		# Creation order is kept fixed: it determines parameter ordering and
		# the RNG draws consumed by each layer's initialization.
		self.l1 = nn.Linear(state_dim + action_dim, hidden_dim)
		self.l2 = nn.Linear(hidden_dim, hidden_dim)
		self.mean_linear = nn.Linear(hidden_dim, feature_dim)
		self.log_std_linear = nn.Linear(hidden_dim, feature_dim)

	def forward(self, state, action):
		"""Return ``(mean, log_std)`` for (s, a).

		``log_std`` is clamped to ``[-20, 2]``.
		"""
		hidden = torch.cat((state, action), dim=-1)
		for layer in (self.l1, self.l2):
			hidden = F.relu(layer(hidden))

		mean = self.mean_linear(hidden)
		# Same bounds as Encoder, so exp(log_std) stays in a finite range.
		log_std = self.log_std_linear(hidden).clamp(min=-20, max=2)
		return mean, log_std
