
import torch
import torch.nn as nn
import torch.nn.functional as F

def init(module, weight_init, bias_init, gain=1):
    weight_init(module.weight.data, gain=gain)
    bias_init(module.bias.data)
    return module

init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
					   constant_(x, 0))

init_relu_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
							constant_(x, 0), nn.init.calculate_gain('relu'))


def apply_init_(modules):
	"""
	Initialize NN modules
	"""
	for m in modules:
		if isinstance(m, nn.Conv2d):
			nn.init.xavier_uniform_(m.weight)
			if m.bias is not None:
				nn.init.constant_(m.bias, 0)
		elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
			nn.init.constant_(m.weight, 1)
			if m.bias is not None:
				nn.init.constant_(m.bias, 0)


class Flatten(nn.Module):
	"""
	Flatten a tensor
	"""

	def forward(self, x):
		return x.reshape(x.size(0), -1)


class Conv2d_tf(nn.Conv2d):
	"""
	Conv2d with the padding behavior from TF
	"""

	def __init__(self, *args, **kwargs):
		super(Conv2d_tf, self).__init__(*args, **kwargs)
		self.padding = kwargs.get("padding", "SAME")

	def _compute_padding(self, input, dim):
		input_size = input.size(dim + 2)
		filter_size = self.weight.size(dim + 2)
		effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
		out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
		total_padding = max(
			0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
		)
		additional_padding = int(total_padding % 2 != 0)

		return additional_padding, total_padding

	def forward(self, input):
		if self.padding == "VALID":
			return F.conv2d(
				input,
				self.weight,
				self.bias,
				self.stride,
				padding=0,
				dilation=self.dilation,
				groups=self.groups,
			)
		rows_odd, padding_rows = self._compute_padding(input, dim=0)
		cols_odd, padding_cols = self._compute_padding(input, dim=1)
		if rows_odd or cols_odd:
			input = F.pad(input, [0, cols_odd, 0, rows_odd])

		return F.conv2d(
			input,
			self.weight,
			self.bias,
			self.stride,
			padding=(padding_rows // 2, padding_cols // 2),
			dilation=self.dilation,
			groups=self.groups,
		)


class NNBase(nn.Module):
	"""
	Actor-Critic network (base class)
	"""

	def __init__(self, hidden_size):
		super(NNBase, self).__init__()

		self._hidden_size = hidden_size

	@property
	def output_size(self):
		return self._hidden_size


class BasicBlock(nn.Module):
	"""
	Residual Network Block
	"""

	def __init__(self, n_channels, stride=1):
		super(BasicBlock, self).__init__()

		self.conv1 = Conv2d_tf(n_channels, n_channels, kernel_size=3, stride=1, padding=(1, 1))
		self.relu = nn.ReLU(inplace=True)
		self.conv2 = Conv2d_tf(n_channels, n_channels, kernel_size=3, stride=1, padding=(1, 1))
		self.stride = stride

		apply_init_(self.modules())

		self.train()

	def forward(self, x):
		identity = x

		out = self.relu(x)
		out = self.conv1(out)
		out = self.relu(out)
		out = self.conv2(out)

		out += identity
		return out


class ResNetBase(NNBase):
	"""
	Residual Network
	"""

	def __init__(self, num_inputs, hidden_size=256, channels=[16, 32, 32]):
		super(ResNetBase, self).__init__(hidden_size)

		self.layer1 = self._make_layer(num_inputs, channels[0])
		self.layer2 = self._make_layer(channels[0], channels[1])
		self.layer3 = self._make_layer(channels[1], channels[2])

		self.flatten = Flatten()
		self.relu = nn.ReLU()

		self.fc = init_relu_(nn.Linear(2048, hidden_size))
		# self.critic_linear = init_(nn.Linear(hidden_size, 1))

		apply_init_(self.modules())

		self.train()

	def _make_layer(self, in_channels, out_channels, stride=1):
		layers = []

		layers.append(Conv2d_tf(in_channels, out_channels, kernel_size=3, stride=1))
		layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

		layers.append(BasicBlock(out_channels))
		layers.append(BasicBlock(out_channels))

		return nn.Sequential(*layers)

	def forward(self, inputs):
		x = inputs

		x = self.layer1(x)
		x = self.layer2(x)
		x = self.layer3(x)

		x = self.relu(self.flatten(x))
		x = self.relu(self.fc(x))

		return x
