import torch
import torch.distributions as td
from torch import nn
from torch.nn import functional as F

class Encoder(nn.Module):
	"""Map a transition (s, a, s') to logits of a categorical latent z.

	The latent is organised as ``category_size`` one-hot variables with
	``class_size`` classes each; the pair is stored in ``self.size``.
	"""

	def __init__(self, state_dim, action_dim, hidden_dim, category_size, class_size):
		super().__init__()

		# (number of categorical variables, classes per variable)
		self.size = (category_size, class_size)

		# The state appears twice in the input (current and next), plus the action.
		self.l1 = nn.Linear(2 * state_dim + action_dim, hidden_dim)
		self.l2 = nn.Linear(hidden_dim, hidden_dim)
		self.l3 = nn.Linear(hidden_dim, category_size * class_size)

	def forward(self, state, action, state_next):
		"""Return flat logits (last dim ``category_size * class_size``) clamped to [-20, 20]."""
		transition = torch.cat((state, action, state_next), dim=-1)
		hidden = torch.relu(self.l1(transition))
		hidden = torch.relu(self.l2(hidden))
		return self.l3(hidden).clamp(min=-20, max=20)

	def get_dist(self, state, action, state_next):
		"""Return the latent distribution for the given transition.

		The flat logits are regrouped to ``(-1, category_size, class_size)``, so
		all leading dimensions collapse into one batch dimension. The result is
		a straight-through one-hot categorical whose category axis is treated
		as part of the event.
		"""
		num_categories, num_classes = self.size
		grouped = self.forward(state, action, state_next).reshape(-1, num_categories, num_classes)
		per_category = td.OneHotCategoricalStraightThrough(logits=grouped)
		return td.Independent(per_category, reinterpreted_batch_ndims=1)

class Decoder(nn.Module):
	"""Decode a latent z into a next-state prediction and a scalar reward.

	A two-layer ReLU trunk feeds two linear heads: ``state_linear`` with
	``state_dim`` outputs and ``reward_linear`` with a single output.
	"""

	def __init__(self, state_dim, hidden_dim, category_size, class_size):
		super().__init__()

		# Shared trunk; its input is the flattened (category_size * class_size) latent.
		self.l1 = nn.Linear(category_size * class_size, hidden_dim)
		self.l2 = nn.Linear(hidden_dim, hidden_dim)

		# Output heads.
		self.state_linear = nn.Linear(hidden_dim, state_dim)
		self.reward_linear = nn.Linear(hidden_dim, 1)

	def forward(self, feature):
		"""Return ``(state_next, reward)`` predicted from ``feature``.

		Everything after the first dimension of ``feature`` is flattened
		before it enters the trunk.
		"""
		batch = feature.shape[0]
		hidden = feature.reshape(batch, -1)
		for layer in (self.l1, self.l2):
			hidden = torch.relu(layer(hidden))
		return self.state_linear(hidden), self.reward_linear(hidden)

class Feature(nn.Module):
	"""Map a state-action pair (s, a) to logits of a categorical latent z.

	The latent is organised as ``category_size`` one-hot variables with
	``class_size`` classes each; the pair is stored in ``self.size``.
	"""

	def __init__(self, state_dim, action_dim, hidden_dim, category_size, class_size):
		super().__init__()

		# (number of categorical variables, classes per variable)
		self.size = (category_size, class_size)

		self.l1 = nn.Linear(state_dim + action_dim, hidden_dim)
		self.l2 = nn.Linear(hidden_dim, hidden_dim)
		self.l3 = nn.Linear(hidden_dim, category_size * class_size)

	def forward(self, state, action):
		"""Return flat logits (last dim ``category_size * class_size``) clamped to [-20, 20]."""
		pair = torch.cat((state, action), dim=-1)
		hidden = torch.relu(self.l1(pair))
		hidden = torch.relu(self.l2(hidden))
		return self.l3(hidden).clamp(min=-20, max=20)

	def get_dist(self, state, action):
		"""Return the latent distribution for the given state-action pair.

		The flat logits are regrouped to ``(-1, category_size, class_size)``, so
		all leading dimensions collapse into one batch dimension. The result is
		a straight-through one-hot categorical whose category axis is treated
		as part of the event.
		"""
		num_categories, num_classes = self.size
		grouped = self.forward(state, action).reshape(-1, num_categories, num_classes)
		per_category = td.OneHotCategoricalStraightThrough(logits=grouped)
		return td.Independent(per_category, reinterpreted_batch_ndims=1)

	def get_feature(self, state, action):
		"""Return per-category class probabilities flattened to one vector per row.

		Softmax is taken over the class axis of the regrouped logits; the
		result has shape ``(N, category_size * class_size)`` where ``N`` is the
		collapsed batch size.
		"""
		num_categories, num_classes = self.size
		grouped = self.forward(state, action).reshape(-1, num_categories, num_classes)
		probs = torch.softmax(grouped, dim=-1)
		return probs.reshape(grouped.shape[0], -1)
