import torch
import torch.nn as nn
from util import log
from c4_group_cnn import C4LiftingConv2d, C4GroupConv2d
from resnet import EncoderBlock, Flatten


class C4_Encoder_group_cnn(torch.nn.Module):
    """Encoder made of one C4 lifting convolution and two C4 group convolutions.

    Every convolution uses 32 output channels, kernel size 3, padding 1 and a
    bias. Each convolution is followed by a ``MaxPool3d`` with kernel
    ``(1, 3, 3)`` and stride ``(1, 2, 2)``. This pools only the last two
    dimensions and leaves the third-from-last one untouched. That dimension is
    presumably the rotation/group axis produced by the lifting layer; confirm
    against ``c4_group_cnn``. A ReLU follows the first two poolings only. The
    final pooled features are flattened to ``(batch, -1)``.

    NOTE(review): the attribute names ``groupconv2d8x16`` / ``groupconv2d16x32``
    do not reflect the real channel counts (both are 32 -> 32). They are kept
    so existing ``state_dict`` keys stay valid.
    """

    def __init__(self, in_channels=1):
        super().__init__()

        # Hyper-parameters shared by every convolution in the stack.
        shared = dict(kernel_size=3, padding=1, bias=True)

        # Construct the convolutions in this exact order so parameter
        # initialisation consumes the RNG identically.
        self.liftingconv2d = C4LiftingConv2d(in_channels=in_channels, out_channels=32, **shared)
        self.groupconv2d8x16 = C4GroupConv2d(in_channels=32, out_channels=32, **shared)
        self.groupconv2d16x32 = C4GroupConv2d(in_channels=32, out_channels=32, **shared)

        def halve_spatial():
            # Fresh pooling layer: downsample the last two dims by 2, keep the other one.
            return torch.nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))

        # Sequential indices (0..7) must stay the same for state_dict compatibility.
        self.net = torch.nn.Sequential(
            self.liftingconv2d, halve_spatial(), torch.nn.ReLU(),
            self.groupconv2d8x16, halve_spatial(), torch.nn.ReLU(),
            self.groupconv2d16x32, halve_spatial(),
        )

    def forward(self, input: torch.Tensor):
        """Run the conv stack and flatten everything except the batch dimension."""
        features = self.net(input)
        batch_size = features.shape[0]
        return features.reshape(batch_size, -1)


class Encoder_ResNet(nn.Module):
    """Four stacked ``EncoderBlock``s, flattened, then linearly projected.

    Args:
        original_channel: number of channels of the input tensor.
        base_channel: ``out_channels`` passed to every ``EncoderBlock``.
        input_encoded_size: size of the output embedding (Linear out_features).
        flat_features: number of features produced by ``Flatten``, i.e. the
            ``in_features`` of the final ``nn.Linear``. If None it defaults to
            ``base_channel * 9``. This reproduces the previously hard-coded 288
            for ``base_channel=32`` and scales with other widths. Pass it
            explicitly if your input resolution gives a different final
            feature-map size.
    """

    def __init__(self, original_channel=3, base_channel=32, input_encoded_size=128,
                 flat_features=None):
        super(Encoder_ResNet, self).__init__()
        if flat_features is None:
            # The original hard-coded 288 == 32 * 9, valid only for base_channel == 32.
            # Assumes the flattened size is linear in base_channel (presumably a
            # 3x3 final feature map per channel) -- confirm against EncoderBlock.
            flat_features = base_channel * 9
        self.encoder = nn.Sequential(
            EncoderBlock(in_channels=original_channel, out_channels=base_channel),
            EncoderBlock(in_channels=base_channel, out_channels=base_channel),
            EncoderBlock(in_channels=base_channel, out_channels=base_channel),
            EncoderBlock(in_channels=base_channel, out_channels=base_channel),
            Flatten(),
            nn.Linear(flat_features, input_encoded_size),
        )

    def forward(self, x):
        """Encode ``x`` into a ``input_encoded_size``-dimensional vector per sample."""
        return self.encoder(x)


class Encoder_mlp(nn.Module):
    """Three-layer fully-connected encoder.

    The input is flattened from dim 1 onward and passed through
    ``in_features -> 512 -> 256 -> out_features`` linear layers. Each layer,
    including the last, is followed by a ReLU, so outputs are non-negative.

    Args:
        args: accepted for interface compatibility; not used by this class.
        in_features: flattened input size. The default ``3 * 32 * 32`` matches
            the previously hard-coded value; change it for other input sizes.
        out_features: size of the output embedding (default 128, as before).
    """

    def __init__(self, args, in_features=3 * 32 * 32, out_features=128):
        super(Encoder_mlp, self).__init__()
        log.info('Building MLP encoder...')
        # Fully-connected layers
        log.info('FC layers...')
        self.fc1 = nn.Linear(in_features, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, out_features)
        # Nonlinearity shared by all layers
        self.relu = nn.ReLU()
        # Re-initialise parameters: zero biases, Kaiming-normal weights.
        # Kaiming 'relu' gain is appropriate because every Linear feeds a ReLU.
        for name, param in self.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                nn.init.kaiming_normal_(param, nonlinearity='relu')

    def forward(self, x):
        """Flatten ``x`` per sample and apply the three Linear+ReLU layers."""
        h = torch.flatten(x, 1)
        h = self.relu(self.fc1(h))
        h = self.relu(self.fc2(h))
        return self.relu(self.fc3(h))


