import torch
import torch.nn as nn

class InputEncoder(nn.Module):
	"""Convolutional encoder mapping each input grid to a flat embedding.

	Layout: a 3x3 conv stem (io_feature_size -> HIDDEN_CHANNELS, ReLU), two
	residual branches (``enc = enc + block(enc)``) of three 3x3 conv + ReLU
	layers each, then a linear projection of the flattened
	HIDDEN_CHANNELS x grid_size x grid_size feature map to io_embedding_size.

	Args:
		args: configuration object; reads ``args.grid_size``,
			``args.io_feature_size`` and ``args.io_embedding_size``.
	"""

	# Channel width of the stem output and of both residual branches.
	HIDDEN_CHANNELS = 64

	def __init__(self, args):
		super().__init__()
		self.grid_size = args.grid_size
		self.io_feature_size = args.io_feature_size
		self.io_embedding_size = args.io_embedding_size
		# Submodules are created stem -> block_1 -> block_2 -> fc; this order
		# fixes parameter-initialisation order for seeded runs and state_dict keys.
		self.input_encoder = nn.Sequential(
			nn.Conv2d(in_channels=self.io_feature_size, out_channels=self.HIDDEN_CHANNELS, kernel_size=3, padding=1),
			nn.ReLU())
		self.block_1 = self._residual_branch(self.HIDDEN_CHANNELS)
		self.block_2 = self._residual_branch(self.HIDDEN_CHANNELS)
		self.fc = nn.Linear(self.HIDDEN_CHANNELS * self.grid_size * self.grid_size, self.io_embedding_size)

	@staticmethod
	def _residual_branch(channels):
		"""Build three (3x3 conv, ReLU) pairs; padding=1 keeps the spatial size."""
		layers = []
		for _ in range(3):
			layers.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1))
			layers.append(nn.ReLU())
		return nn.Sequential(*layers)

	def forward(self, input_grids):
		"""Encode grids that may carry any number of leading batch dimensions.

		Args:
			input_grids: tensor whose trailing three dims are
				(io_feature_size, grid_size, grid_size); leading dims are
				treated as batch dims and may be absent or zero-sized.

		Returns:
			Tensor of shape (*batch_dims, io_embedding_size).
		"""
		batch_dims = input_grids.size()[:-3]
		# Explicit sizes instead of -1: PyTorch cannot infer -1 for a
		# zero-element tensor, so an empty batch would otherwise raise.
		num_grids = batch_dims.numel()
		grids = input_grids.contiguous().view(num_grids, self.io_feature_size, self.grid_size, self.grid_size)
		enc = self.input_encoder(grids)
		enc = enc + self.block_1(enc)
		enc = enc + self.block_2(enc)
		return self.fc(enc.view(*batch_dims, self.fc.in_features))

class IOEncoder(nn.Module):
	"""Convolutional encoder mapping an (input grid, output grid) pair to an embedding.

	Each grid gets its own 3x3 conv stem (io_feature_size -> HIDDEN_CHANNELS // 2,
	ReLU). The two feature maps are concatenated along the channel dim, passed
	through two residual branches (``enc = enc + block(enc)``) of three 3x3
	conv + ReLU layers, then flattened and linearly projected to
	io_embedding_size.

	Args:
		args: configuration object; reads ``args.grid_size``,
			``args.io_feature_size`` and ``args.io_embedding_size``.
	"""

	# Channel width after concatenating the two stems; each stem yields half.
	HIDDEN_CHANNELS = 64

	def __init__(self, args):
		super().__init__()
		self.grid_size = args.grid_size
		self.io_feature_size = args.io_feature_size
		self.io_embedding_size = args.io_embedding_size
		stem_channels = self.HIDDEN_CHANNELS // 2
		# Creation order (input stem, output stem, block_1, block_2, fc) fixes
		# parameter-initialisation order for seeded runs and state_dict keys.
		self.input_encoder = nn.Sequential(
			nn.Conv2d(in_channels=self.io_feature_size, out_channels=stem_channels, kernel_size=3, padding=1),
			nn.ReLU())
		self.output_encoder = nn.Sequential(
			nn.Conv2d(in_channels=self.io_feature_size, out_channels=stem_channels, kernel_size=3, padding=1),
			nn.ReLU())
		self.block_1 = self._residual_branch(self.HIDDEN_CHANNELS)
		self.block_2 = self._residual_branch(self.HIDDEN_CHANNELS)
		self.fc = nn.Linear(self.HIDDEN_CHANNELS * self.grid_size * self.grid_size, self.io_embedding_size)

	@staticmethod
	def _residual_branch(channels):
		"""Build three (3x3 conv, ReLU) pairs; padding=1 keeps the spatial size."""
		layers = []
		for _ in range(3):
			layers.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1))
			layers.append(nn.ReLU())
		return nn.Sequential(*layers)

	def forward(self, input_grids, output_grids):
		"""Encode paired input/output grids.

		Args:
			input_grids: tensor whose trailing three dims are
				(io_feature_size, grid_size, grid_size); its leading dims are
				the batch dims and may be absent or zero-sized.
			output_grids: tensor with the same total element count as
				input_grids; reshaped to the same number of grids.

		Returns:
			Tensor of shape (*batch_dims, io_embedding_size), where batch_dims
			are taken from input_grids.
		"""
		batch_dims = input_grids.size()[:-3]
		# Explicit sizes instead of -1: PyTorch cannot infer -1 for a
		# zero-element tensor, so an empty batch would otherwise raise.
		num_grids = batch_dims.numel()
		grid_shape = (num_grids, self.io_feature_size, self.grid_size, self.grid_size)
		in_grids = input_grids.contiguous().view(*grid_shape)
		out_grids = output_grids.contiguous().view(*grid_shape)
		input_enc = self.input_encoder(in_grids)
		output_enc = self.output_encoder(out_grids)
		# Channel-wise concat: stem_channels + stem_channels == HIDDEN_CHANNELS.
		enc = torch.cat([input_enc, output_enc], 1)
		enc = enc + self.block_1(enc)
		enc = enc + self.block_2(enc)
		return self.fc(enc.view(*batch_dims, self.fc.in_features))