import torch
import numpy as np


def rotate(x: torch.Tensor, r: int) -> torch.Tensor:
    """Apply the rotation indexed by `r` to the image(s) `x` and return the rotated tensor.

    `x` is rotated by `r` quarter turns (counter-clockwise for positive `r`) in the plane spanned by its
    last two dimensions. All leading dimensions are left untouched, so the same call works on a single
    image and on a minibatch of images.
    """
    # the spatial dimensions are always the trailing two, whatever the number of batch dimensions
    spatial_dims = (-2, -1)
    return torch.rot90(x, k=r, dims=spatial_dims)


def rotate_p4(y: torch.Tensor, r: int) -> torch.Tensor:
    """Apply the rotation indexed by `r` to `y`, a function over p4.

    `y` is a function over the pixel positions and over the elements of the group C_4: its last two
    dimensions (`dim=-2` and `dim=-1`) are the spatial dimensions, `dim=-3` has size `4` and is the C_4
    dimension, and every other dimension is treated as a batch dimension.

    The action of the rotation is made of two parts: a cyclic shift by `r` positions along the C_4
    dimension and a rotation of the spatial dimensions by `r` quarter turns.
    """
    assert y.dim() >= 3
    assert y.size(-3) == 4

    ### BEGIN SOLUTION
    # first permute the 4 rotational channels cyclically, then rotate the pixels of each of them
    shifted = y.roll(shifts=r, dims=-3)
    return rotate(shifted, r)

    ### END SOLUTION

class C4LiftingConv2d(torch.nn.Module):
    """Lifting convolution from planar images to feature maps over p4.

    Each learnable `kernel_size x kernel_size` filter is applied in its 4 rotated versions, so an input of
    shape `batch_size x in_channels x H x W` is mapped to an output of shape
    `batch_size x out_channels x 4 x H' x W'`, where `dim=2` indexes the rotation of the filter.
    The convolution always uses `stride=1` and `dilation=1`.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int = 0, bias: bool = True):
        """Create the learnable filters and, if `bias` is `True`, the learnable bias.

        `padding` is forwarded unchanged to `torch.conv2d` in the forward pass.
        """
        super(C4LiftingConv2d, self).__init__()

        self.kernel_size = kernel_size
        self.stride = 1
        self.dilation = 1
        self.padding = padding
        self.out_channels = out_channels
        self.in_channels = in_channels

        # filters and bias are both sampled from a normal distribution with this standard deviation
        std = 1 / np.sqrt(out_channels * in_channels)

        # This layer has `out_channels x in_channels` different learnable filters, each of shape `kernel_size x kernel_size`
        # The bigger filter of shape `out_channels x 4 x in_channels x kernel_size x kernel_size` is built during
        # the forward pass by rotating 4 times the learnable filters in `self.weight`

        ### BEGIN SOLUTION
        weight = torch.nn.init.normal_(torch.empty(out_channels, in_channels, kernel_size, kernel_size), 0, std)
        self.weight = torch.nn.Parameter(weight, requires_grad=True)

        ### END SOLUTION

        # The bias is shared over the 4 rotations
        # In total, the bias has `out_channels` learnable parameters, one for each independent output
        # `self.bias` stays `None` when the layer is created with `bias=False`

        self.bias = None
        if bias:
            ### BEGIN SOLUTION
            self.bias = torch.nn.Parameter(torch.nn.init.normal_(torch.empty(out_channels), 0, std),
                                           requires_grad=True)

            ### END SOLUTION

    def build_filter(self) -> "tuple[torch.Tensor, torch.Tensor | None]":
        """Expand the learnable parameters into the filter and the bias used by the convolution.

        Returns the pair `(_filter, _bias)` where
        - `_filter` has shape `out_channels x 4 x in_channels x kernel_size x kernel_size` and
          `_filter[:, i, :, :, :]` contains the learnable filter rotated `i` times
        - `_bias` has shape `out_channels x 4` and `_bias[:, i]` contains a copy of the learnable bias for each
          `i=0,1,2,3`; it is `None` if the layer has no bias
        """

        ### BEGIN SOLUTION
        # `torch.stack` creates its result on the device and with the dtype of `self.weight`, so the filter
        # always lives next to the parameters, no matter whether the module is on CPU or on GPU
        _filter = torch.stack([rotate(self.weight, i) for i in range(4)], dim=1)

        ### END SOLUTION

        _bias = None
        if self.bias is not None:
            ### BEGIN SOLUTION
            _bias = torch.stack([self.bias] * 4, dim=-1)

            ### END SOLUTION

        return _filter, _bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve `x` with the 4 rotations of every learnable filter.

        `x` has shape `batch_size x in_channels x H x W`; the output has shape
        `batch_size x out_channels x 4 x H' x W'`.
        """

        _filter, _bias = self.build_filter()

        assert _filter.shape == (self.out_channels, 4, self.in_channels, self.kernel_size, self.kernel_size)

        # to be able to use torch.conv2d, we need to reshape the filter and bias to stack together all filters
        _filter = _filter.reshape(self.out_channels * 4, self.in_channels, self.kernel_size, self.kernel_size)

        # the bias is optional: a layer built with `bias=False` simply passes `None` to torch.conv2d
        if _bias is not None:
            assert _bias.shape == (self.out_channels, 4)
            _bias = _bias.reshape(self.out_channels * 4)

        out = torch.conv2d(x, _filter,
                           stride=self.stride,
                           padding=self.padding,
                           dilation=self.dilation,
                           bias=_bias)

        # `out` has now shape `batch_size x out_channels*4 x H' x W'`
        # we need to reshape it to `batch_size x out_channels x 4 x H' x W'` to have the shape we expect

        return out.view(-1, self.out_channels, 4, out.shape[-2], out.shape[-1])


class C4GroupConv2d(torch.nn.Module):
    """Group convolution between feature maps over p4.

    Each learnable filter has shape `4 x kernel_size x kernel_size` and is applied in its 4 rotated versions,
    where a rotation acts both on the pixels and, as a cyclic permutation, on the 4 rotational input channels.
    An input of shape `batch_size x in_channels x 4 x H x W` is mapped to an output of shape
    `batch_size x out_channels x 4 x H' x W'`.
    The convolution always uses `stride=2` and `dilation=1`.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int = 0, bias: bool = True):
        """Create the learnable filters and, if `bias` is `True`, the learnable bias.

        `padding` is forwarded unchanged to `torch.conv2d` in the forward pass.
        """
        super(C4GroupConv2d, self).__init__()

        self.kernel_size = kernel_size
        self.stride = 2
        self.dilation = 1
        self.padding = padding
        self.out_channels = out_channels
        self.in_channels = in_channels

        # filters and bias are both sampled from a normal distribution with this standard deviation
        std = 1 / np.sqrt(out_channels * in_channels)

        # This layer has `out_channels x in_channels` different learnable filters, each of shape `4 x kernel_size x kernel_size`
        # The bigger filter of shape `out_channels x 4 x in_channels x 4 x kernel_size x kernel_size` is built during
        # the forward pass by rotating 4 times the learnable filters in `self.weight`

        ### BEGIN SOLUTION
        weight = torch.nn.init.normal_(torch.empty(out_channels, in_channels, 4, kernel_size, kernel_size), 0, std)
        self.weight = torch.nn.Parameter(weight, requires_grad=True)

        ### END SOLUTION

        # The bias is shared over the 4 rotations
        # In total, the bias has `out_channels` learnable parameters, one for each independent output
        # `self.bias` stays `None` when the layer is created with `bias=False`

        self.bias = None
        if bias:
            ### BEGIN SOLUTION
            self.bias = torch.nn.Parameter(torch.nn.init.normal_(torch.empty(out_channels), 0, std),
                                           requires_grad=True)

            ### END SOLUTION

    def build_filter(self) -> "tuple[torch.Tensor, torch.Tensor | None]":
        """Expand the learnable parameters into the filter and the bias used by the convolution.

        Returns the pair `(_filter, _bias)` where
        - `_filter` has shape `out_channels x 4 x in_channels x 4 x kernel_size x kernel_size` and
          `_filter[:, r, :, :, :, :]` contains the learnable filter rotated `r` times
        - `_bias` has shape `out_channels x 4` and `_bias[:, i]` contains a copy of the learnable bias for each
          `i`; it is `None` if the layer has no bias
        """

        ### BEGIN SOLUTION
        # `torch.stack` creates its result on the device and with the dtype of `self.weight`, so the filter
        # always lives next to the parameters, no matter whether the module is on CPU or on GPU
        # `rotate_p4` performs both the rotation of the pixels and the cyclic permutation of the 4 rotational
        # input channels
        _filter = torch.stack([rotate_p4(self.weight, r) for r in range(4)], dim=1)

        ### END SOLUTION

        _bias = None
        if self.bias is not None:
            ### BEGIN SOLUTION
            _bias = torch.stack([self.bias] * 4, dim=-1)

            ### END SOLUTION

        return _filter, _bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve `x` with the 4 rotations of every learnable filter.

        `x` has shape `batch_size x in_channels x 4 x H x W`; the output has shape
        `batch_size x out_channels x 4 x H' x W'`.
        """

        _filter, _bias = self.build_filter()

        assert _filter.shape == (self.out_channels, 4, self.in_channels, 4, self.kernel_size, self.kernel_size)

        # to be able to use torch.conv2d, we need to reshape the filter and bias to stack together all filters
        _filter = _filter.reshape(self.out_channels * 4, self.in_channels * 4, self.kernel_size, self.kernel_size)

        # the bias is optional: a layer built with `bias=False` simply passes `None` to torch.conv2d
        if _bias is not None:
            assert _bias.shape == (self.out_channels, 4)
            _bias = _bias.reshape(self.out_channels * 4)

        # this time, also the input has shape `batch_size x in_channels x 4 x H x W`
        # so we need to reshape it to `batch_size x in_channels*4 x H x W` to be able to use torch.conv2d
        # (`reshape` rather than `view`, so that non-contiguous inputs are accepted as well)
        x = x.reshape(x.shape[0], self.in_channels * 4, x.shape[-2], x.shape[-1])

        out = torch.conv2d(x, _filter,
                           stride=self.stride,
                           padding=self.padding,
                           dilation=self.dilation,
                           bias=_bias)

        # `out` has now shape `batch_size x out_channels*4 x H' x W'`
        # we need to reshape it to `batch_size x out_channels x 4 x H' x W'` to have the shape we expect

        return out.view(-1, self.out_channels, 4, out.shape[-2], out.shape[-1])


class Encoder_group_cnn(torch.nn.Module):
    """Encoder made of one lifting convolution followed by two group convolutions.

    Every convolution is followed by a max pooling over the spatial dimensions only (the C_4 dimension is
    pooled with a kernel of size 1, i.e. left untouched) and by a ReLU. The resulting feature map is
    flattened into one vector per sample.
    """

    def __init__(self):
        super().__init__()

        ### BEGIN SOLUTION
        self.liftingconv2d = C4LiftingConv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1, bias=True)
        self.groupconv2d8x16 = C4GroupConv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True)
        self.groupconv2d16x32 = C4GroupConv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True)

        def spatial_pool() -> torch.nn.MaxPool3d:
            # 3x3 max pooling with stride 2 over the pixels, applied independently to each of the 4 rotations
            return torch.nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))

        stages = []
        for conv in (self.liftingconv2d, self.groupconv2d8x16, self.groupconv2d16x32):
            stages.extend([conv, spatial_pool(), torch.nn.ReLU()])
        self.net = torch.nn.Sequential(*stages)

        ### END SOLUTION

    def forward(self, input: torch.Tensor):
        """Encode `input` and return a 2-dimensional tensor with one flattened feature vector per sample."""
        ### BEGIN SOLUTION
        features = self.net(input)
        # merge the channel, rotation and spatial dimensions, keeping only the batch dimension
        return features.reshape(features.shape[0], -1)
        ### END SOLUTION


if __name__ == '__main__':
    # smoke test: encode a random minibatch of 16 single-channel 32x32 images and print the output shape
    encoder = Encoder_group_cnn()
    images = torch.randn(16, 1, 32, 32)
    print(encoder(images).shape)


