import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
	"""Stack of ``nn.Linear`` layers with ReLU between them.

	Args:
		layer_sizes: sequence of widths ``[in, h1, ..., out]``; consecutive
			pairs become ``nn.Linear(a, b)``.
		bias: passed through to every ``nn.Linear``.

	Raises:
		ValueError: if ``layer_sizes`` has fewer than two entries.
	"""
	def __init__(self, layer_sizes, bias=True):
		super().__init__()

		n_sizes = len(layer_sizes)
		if n_sizes < 2:
			raise ValueError(f'need at least two values in layers for linear(in, out), currently {n_sizes}')
		# One Linear per adjacent (in, out) pair, created in order.
		self.layers = nn.ModuleList(
			nn.Linear(fan_in, fan_out, bias=bias)
			for fan_in, fan_out in zip(layer_sizes, layer_sizes[1:])
		)

	def forward(self, x):
		"""Apply every layer, with ReLU after all but the final one."""
		*hidden_layers, final_layer = self.layers
		for linear in hidden_layers:
			x = F.relu(linear(x))
		# The last layer is left linear (no activation).
		return final_layer(x)

class TemporalModel(nn.Module):
	"""Three-layer MLP: ``obj_dim -> 128 -> 256 -> out_dim`` with ReLU between layers.

	The final layer has no activation.
	"""
	def __init__(self, obj_dim, out_dim=3):
		super().__init__()
		width_1, width_2 = 128, 256
		self.linear_1 = nn.Linear(obj_dim, width_1)
		self.linear_2 = nn.Linear(width_1, width_2)
		self.linear_3 = nn.Linear(width_2, out_dim)

	def forward(self, obj):
		"""Map ``obj`` (features in its last dimension) to ``out_dim`` raw outputs."""
		hidden = obj
		for fc in (self.linear_1, self.linear_2):
			hidden = F.relu(fc(hidden))
		return self.linear_3(hidden)

class LogicModel(nn.Module):
	"""Three-layer MLP: ``obj_dim -> 512 -> 512 -> out_dim`` with ReLU between layers.

	The final layer has no activation.
	"""
	def __init__(self, obj_dim, out_dim=3):
		super().__init__()
		width = 512
		self.linear_1 = nn.Linear(obj_dim, width)
		self.linear_2 = nn.Linear(width, width)
		self.linear_3 = nn.Linear(width, out_dim)

	def forward(self, obj):
		"""Map ``obj`` (features in its last dimension) to ``out_dim`` raw outputs."""
		hidden = obj
		for fc in (self.linear_1, self.linear_2):
			hidden = F.relu(fc(hidden))
		return self.linear_3(hidden)

class LogicModelSoft(nn.Module):
	"""Three-layer MLP (``obj_dim -> 1024 -> 512 -> out_dim``) with dropout after each hidden ReLU.

	Args:
		obj_dim: size of the input's last dimension.
		out_dim: size of the output's last dimension (default 3).
		dropout_p: dropout probability applied to both hidden activations
			(default 0.3). Dropout is active only in ``train()`` mode.
	"""
	def __init__(self, obj_dim, out_dim=3, dropout_p=0.3):
		super().__init__()
		self.linear_1 = nn.Linear(obj_dim, 1024)
		self.linear_2 = nn.Linear(1024, 512)
		self.linear_3 = nn.Linear(512, out_dim)
		# Not in-place: ReLU's backward reads its own output, so overwriting that
		# tensor in place made backward() fail while training.
		self.dropout = nn.Dropout(p=dropout_p)

	def forward(self, obj):
		"""Return ``out_dim`` raw outputs (no final activation) for ``obj``."""
		# Use dropout's return value rather than relying on in-place mutation.
		output = self.dropout(F.relu(self.linear_1(obj)))
		output = self.dropout(F.relu(self.linear_2(output)))
		return self.linear_3(output)

class FeatureModel(nn.Module):
	"""Four-layer MLP: ``obj_dim -> 2048 -> 1024 -> 512 -> out_dim`` with ReLU between layers.

	The final layer has no activation.
	"""
	def __init__(self, obj_dim, out_dim=3):
		super().__init__()
		width_1, width_2, width_3 = 2048, 1024, 512
		self.linear_1 = nn.Linear(obj_dim, width_1)
		self.linear_2 = nn.Linear(width_1, width_2)
		self.linear_3 = nn.Linear(width_2, width_3)
		self.linear_4 = nn.Linear(width_3, out_dim)

	def forward(self, obj):
		"""Map ``obj`` (features in its last dimension) to ``out_dim`` raw outputs."""
		hidden = obj
		for fc in (self.linear_1, self.linear_2, self.linear_3):
			hidden = F.relu(fc(hidden))
		return self.linear_4(hidden)

class FeatureModelCNN(nn.Module):
	"""1-D CNN front end followed by a four-layer MLP head.

	Two stride-2 convolutions (14 -> 32 -> 32 channels, kernel sizes 7 and 5)
	are flattened per sample and fed through 2048 -> 1024 -> 512 -> out_dim
	linear layers, with ReLU after every layer except the last.

	``conv_dim`` is fixed at 2304 = 32 channels * 72 positions. With these
	kernels and strides, 72 output positions come from input lengths 299-302.
	"""
	def __init__(self, out_dim=3):
		super().__init__()
		self.conv_dim = 2304
		self.cnn1 = nn.Conv1d(14, 32, 7, stride=2)
		self.cnn2 = nn.Conv1d(32, 32, 5, stride=2)
		self.linear_1 = nn.Linear(self.conv_dim, 2048)
		self.linear_2 = nn.Linear(2048, 1024)
		self.linear_3 = nn.Linear(1024, 512)
		self.linear_4 = nn.Linear(512, out_dim)

	def forward(self, obj):
		"""Run the CNN + MLP on ``obj``.

		Args:
			obj: ``(batch, 14, length)`` or unbatched ``(14, length)`` tensor.

		Returns:
			``(batch, out_dim)`` tensor; unbatched input yields ``(1, out_dim)``.

		Raises:
			ValueError: if the conv features per sample are not ``conv_dim``
				(i.e. the input length does not match the fixed head).
		"""
		output = F.relu(self.cnn1(obj))
		output = F.relu(self.cnn2(output))
		# Flatten per sample. The previous view(-1, conv_dim) could silently merge
		# several samples into one row when the sequence length was wrong.
		if output.dim() == 2:  # unbatched (channels, positions)
			output = output.unsqueeze(0)
		output = output.flatten(start_dim=1)
		if output.shape[1] != self.conv_dim:
			raise ValueError(
				f'expected {self.conv_dim} conv features per sample, got {output.shape[1]}; '
				'check the input sequence length'
			)
		output = F.relu(self.linear_1(output))
		output = F.relu(self.linear_2(output))
		output = F.relu(self.linear_3(output))
		return self.linear_4(output)