import numpy as np
import math
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
from mnist.shared_mnists import CustomSharedHyper, individual_head
import os
from functools import reduce
import operator
import argparse
import time
from mnist.model_mnist import HyperNetwork_Head, FunctionalFullNetwork

def param_count(model):
    """Count the trainable parameters of ``model`` and print the total.

    @param model: a torch.nn.Module.
    @return: number of scalar entries across parameters with requires_grad.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    print("Total number of parameters: {:,}".format(total))
    return total

class RotateTransform:
    """Transform that rotates its input by a fixed angle (in degrees).

    Intended for use inside a torchvision ``transforms`` pipeline.
    """

    def __init__(self, angle):
        # Rotation angle in degrees, forwarded unchanged to the rotate call.
        self.angle = angle

    def __call__(self, x):
        rotate = transforms.functional.rotate
        return rotate(x, self.angle)

class GroupBase(torch.nn.Module):
    """Abstract interface for a group H acting on R^2.

    Subclasses implement the group operations below. The identity element is
    stored as a registered buffer, so it follows the module across devices;
    code in this file uses ``self.identity.device`` as the group's device.
    """

    def __init__(self, dimension, identity):
        """
        @param dimension: presumably the number of parameters describing one
            group element (CyclicGroup passes 1 for its rotation angle).
        @param identity: parameter values of the identity element; stored as
            the float buffer ``identity``.
        """
        super().__init__()
        self.dimension = dimension
        self.register_buffer('identity', torch.Tensor(identity))

    def elements(self):
        """Return a tensor containing all (sampled) elements of the group."""
        raise NotImplementedError()

    def product(self, h, h_prime):
        """Return the group product h * h_prime."""
        raise NotImplementedError()

    def inverse(self, h):
        """Return the group inverse h^{-1}."""
        raise NotImplementedError()

    def left_action_on_R2(self, h, x):
        """Return the action of h on vectors x in R^2."""
        raise NotImplementedError()

    def matrix_representation(self, h):
        """Return the matrix representation of h acting on R^2."""
        raise NotImplementedError()

    def determinant(self, h):
        """Return the determinant of the matrix representation of h."""
        raise NotImplementedError()

    def normalize_group_parameterization(self, h):
        """Map group parameters to a normalized range."""
        raise NotImplementedError()

    def normalize_group_elements(self, h):
        """Map group elements to [-1, 1] for use as grid_sample coordinates.

        The group-kernel classes call this method (not
        ``normalize_group_parameterization``), so every concrete group used
        with them must implement it.
        """
        raise NotImplementedError()

class CyclicGroup(GroupBase):
    """Discrete rotation group C_n acting on the plane.

    Elements are parameterized directly by their rotation angle in radians:
    0, 2*pi/n, ..., 2*pi*(n-1)/n. The group product is angle addition
    modulo 2*pi.
    """

    def __init__(self, order):
        """@param order: number of rotations n in the group; must be > 1."""
        super().__init__(
            dimension=1,
            identity=[0.]
        )

        assert order > 1
        # 0-dim tensor; note this is a plain attribute, not a registered buffer.
        self.order = torch.tensor(order)

    def elements(self):
        """Obtain a tensor containing all group elements in this group.

        @returns elements: Tensor of rotation angles of shape [self.order],
            on the same device as the identity buffer.
        """
        n = int(self.order)
        return torch.linspace(
            start=0,
            end=2 * np.pi * float(n - 1) / float(n),
            steps=n,
            device=self.identity.device
        )

    def product(self, h, h_prime):
        """Group product of two elements of C_n.

        @param h: Group element 1 (angle in radians)
        @param h_prime: Group element 2 (angle in radians)

        @returns product: Tensor containing h \\cdot h_prime, i.e. the angle sum
            wrapped into [0, 2*pi) so the result stays in the group.
        """
        return torch.remainder(h + h_prime, 2 * np.pi)

    def inverse(self, h):
        """Group inverse of an element of C_n.

        @param h: Group element (angle in radians)

        @returns inverse: Tensor containing h^{-1} = -h wrapped into [0, 2*pi).
        """
        return torch.remainder(-h, 2 * np.pi)

    def left_action_on_R2(self, h, x):
        """Group action of an element h on vectors in R^2.

        @param h: A group element.
        @param x: Vectors in R^2; the first axis holds the 2 coordinates.

        @returns transformed_x: Tensor containing \\rho(h)x, obtained by
            left-multiplying x with the 2x2 matrix representation of h.
        """
        return torch.tensordot(self.matrix_representation(h), x, dims=1)

    def matrix_representation(self, h):
        """Obtain the 2x2 rotation matrix representing element h.

        @param h: A group element (rotation angle in radians); a tensor or a
            Python scalar.

        @returns representation: Tensor [[cos h, -sin h], [sin h, cos h]] of
            shape [2, 2] for a scalar h (shape [2, 2, *h.shape] otherwise),
            on the identity buffer's device.
        """
        h = torch.as_tensor(h, device=self.identity.device)
        cos_t = torch.cos(h)
        sin_t = torch.sin(h)
        # torch.stack keeps h's autograd history; torch.tensor() on a nested
        # list of tensors would build a fresh leaf tensor without it.
        representation = torch.stack((
            torch.stack((cos_t, -sin_t)),
            torch.stack((sin_t, cos_t)),
        ))
        return representation

    def normalize_group_elements(self, h):
        """Normalize values of group elements to range between -1 and 1.

        The group elements range from 0 to 2pi * (self.order - 1) / self.order,
        which is mapped linearly onto [-1, 1] (the coordinate range expected
        by torch grid_sample).

        @param h: A group element.
        @return normalized_h: Tensor containing normalized value corresponding to element h.
        """
        largest_elem = 2 * np.pi * (self.order - 1) / self.order
        normalized_h = (2 * h / largest_elem) - 1.
        return normalized_h

def asserting_group():
    """Run quick sanity checks on CyclicGroup(order=4), printing progress."""
    group = CyclicGroup(order=4)
    e, quarter, half, three_quarters = group.elements()
    print("asserting")

    # Group law on a few element pairs, including the inverse.
    assert group.product(e, quarter) == quarter and group.product(quarter, half) == three_quarters
    assert group.product(quarter, group.inverse(quarter)) == e

    # Matrix representation: identity -> I, half turn -> -I.
    half_turn = torch.tensor([[-1, 0], [0, -1]]).float()
    assert torch.allclose(group.matrix_representation(e), torch.eye(2))
    assert torch.allclose(group.matrix_representation(half), half_turn, atol=1e-6)

    # A quarter turn maps [0, 1] onto [-1, 0].
    rotated = group.left_action_on_R2(quarter, torch.tensor([0., 1.]))
    assert torch.allclose(rotated, torch.tensor([-1., 0.]), atol=1e-7)
    print("finished asserting")

def bilinear_interpolation(signal, grid):
    """Obtain signal values at a set of grid points through bilinear interpolation.

    Wraps ``torch.nn.functional.grid_sample`` with zero padding and
    ``align_corners=True``.

    @param signal: Tensor containing pixel values [C, H, W] or [N, C, H, W]
    @param grid: Tensor containing coordinate values [2, H, W] or [2, N, H, W];
        component 0 indexes rows (H) and component 1 indexes columns (W),
        both normalized to [-1, 1].
    @return: Sampled values of shape [N, C, H_out, W_out] (N = 1 for
        unbatched inputs).
    """
    # grid_sample needs a batch axis on both inputs.
    batched_signal = signal.unsqueeze(0) if signal.dim() == 3 else signal
    batched_grid = grid.unsqueeze(1) if grid.dim() == 3 else grid

    # [2, N, H, W] -> [N, H, W, 2], then swap (row, col) into grid_sample's (x, y).
    coords = torch.roll(batched_grid.permute(1, 2, 3, 0), shifts=1, dims=-1)

    return torch.nn.functional.grid_sample(
        batched_signal,
        coords,
        mode="bilinear",
        padding_mode='zeros',
        align_corners=True,
    )

def trilinear_interpolation(signal, grid):
    """Obtain volumetric signal values at a set of grid points.

    Wraps ``torch.nn.functional.grid_sample`` with zero padding and
    ``align_corners=True``; for 5-D inputs its "bilinear" mode performs
    trilinear interpolation.

    @param signal: Tensor containing pixel values [C, D, H, W] or [N, C, D, H, W]
    @param grid: Tensor containing coordinate values [3, D, H, W] or [3, N, D, H, W],
        normalized to [-1, 1].
    @return: Sampled values of shape [N, C, D_out, H_out, W_out] (N = 1 for
        unbatched inputs).
    """
    # grid_sample needs a batch axis on both inputs.
    batched_signal = signal.unsqueeze(0) if signal.dim() == 4 else signal
    batched_grid = grid.unsqueeze(1) if grid.dim() == 4 else grid

    # [3, N, D, H, W] -> [N, D, H, W, 3], then cyclically shift the
    # coordinate components (c0, c1, c2) -> (c2, c0, c1).
    # NOTE(review): grid_sample reads the last axis as (x -> W, y -> H, z -> D),
    # so after this shift c2 indexes W, c0 indexes H and c1 indexes D. Unlike
    # the 2-D case, a roll of three components is not a plain "YX instead of
    # XY" swap -- confirm callers build their grids in this component order.
    coords = torch.roll(batched_grid.permute(1, 2, 3, 4, 0), shifts=1, dims=-1)

    return torch.nn.functional.grid_sample(
        batched_signal,
        coords,
        mode="bilinear",  # trilinear for 5-D inputs
        padding_mode='zeros',
        align_corners=True,
    )


class LiftingKernelBase(torch.nn.Module):

    def __init__(self, group, kernel_size, in_channels, out_channels):
        """Base class for lifting kernels.

        Registers two buffers:
        - ``grid_R2``: a [2, kernel_size, kernel_size] grid of coordinates in
          [-1, 1] (meshgrid with 'ij' indexing) on which kernel weights live;
        - ``transformed_grid_R2``: that grid acted on by the inverse of every
          element of ``group``, stacked along dim 1.
        """
        super().__init__()
        self.group = group
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels

        axis = torch.linspace(-1., 1., self.kernel_size)
        planar_grid = torch.stack(torch.meshgrid(axis, axis, indexing='ij'))
        self.register_buffer("grid_R2", planar_grid.to(self.group.identity.device))

        self.register_buffer("transformed_grid_R2", self.create_transformed_grid_R2())

    def create_transformed_grid_R2(self):
        """Apply the inverse of each group element to the planar grid.

        @return: Tensor whose index [:, i] is ``grid_R2`` transformed by the
            inverse of group element i; shape [2, |H|, kernel_size, kernel_size]
            assuming ``left_action_on_R2`` preserves the grid shape (as
            CyclicGroup's does).
        """
        inverses = self.group.inverse(self.group.elements())
        return torch.stack(
            [self.group.left_action_on_R2(h_inv, self.grid_R2) for h_inv in inverses],
            dim=1,
        )

    def sample(self, sampled_group_elements):
        """Sample convolution kernels for a given number of group elements.

        :param sampled_group_elements: the group elements over which to sample
            the convolution kernels
        :return kernels: filter bank extending over all input channels,
            containing kernels transformed for all output group elements.
        """
        raise NotImplementedError()

class InterpolativeLiftingKernel(LiftingKernelBase):
    """Lifting kernel built by bilinearly resampling one learned planar
    weight tensor at every transformed copy of the kernel grid."""

    def __init__(self, group, kernel_size, in_channels, out_channels):
        super().__init__(group, kernel_size, in_channels, out_channels)

        # Learned base weights [out, in, k, k]; the kernel for every group
        # element is interpolated from these.
        weight_shape = (self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        self.weight = torch.nn.Parameter(
            torch.zeros(weight_shape, device=self.group.identity.device)
        )
        torch.nn.init.kaiming_uniform_(self.weight.data, a=math.sqrt(5))

    def sample(self):
        """Build the lifting filter bank for all group elements.

        :return kernels: Tensor of shape
            [out_channels, |H|, in_channels, kernel_size, kernel_size].
        """
        n_group = self.group.elements().numel()
        k = self.kernel_size

        # Fold out/in channels together so each grid resamples the whole bank at once.
        folded = self.weight.view(self.out_channels * self.in_channels, k, k)
        per_element = [
            bilinear_interpolation(folded, self.transformed_grid_R2[:, idx])
            for idx in range(n_group)
        ]
        kernels = torch.stack(per_element).view(n_group, self.out_channels, self.in_channels, k, k)

        # Output channels first, so conv2d can fold (out, |H|) into one axis.
        return kernels.transpose(0, 1)

class LiftingConvolution(torch.nn.Module):
    """Lifting convolution: planar input -> function on R^2 x H.

    Implemented as a single conv2d whose output channels enumerate
    (out_channel, group element) pairs.
    """

    def __init__(self, group, in_channels, out_channels, kernel_size, padding):
        super().__init__()

        self.kernel = InterpolativeLiftingKernel(
            group=group,
            kernel_size=kernel_size,
            in_channels=in_channels,
            out_channels=out_channels
        )

        self.padding = padding

    def forward(self, x):
        """Perform lifting convolution.

        @param x: Input sample [batch_dim, in_channels, spatial_dim_1,
            spatial_dim_2]
        @return: Function on a homogeneous space of the group
            [batch_dim, out_channels, num_group_elements, out_spatial_dim_1,
            out_spatial_dim_2]
        """
        num_group_elements = self.kernel.group.elements().numel()

        # Kernels are [out, |H|, in, k, k]; fold (out, |H|) into conv2d's
        # output-channel axis.
        conv_kernels = self.kernel.sample()

        x = torch.nn.functional.conv2d(
            input=x,
            weight=conv_kernels.reshape(
                self.kernel.out_channels * num_group_elements,
                self.kernel.in_channels,
                self.kernel.kernel_size,
                self.kernel.kernel_size
            ),
            padding=self.padding
        )

        # Unfold channels back into (out_channels, |H|). Spatial axes must stay
        # in (height, width) order; swapping them scrambles non-square maps.
        return x.view(
            -1,
            self.kernel.out_channels,
            num_group_elements,
            x.shape[-2],
            x.shape[-1]
        )

class GroupKernelBase(torch.nn.Module):

    def __init__(self, group, kernel_size, in_channels, out_channels):
        """Base class for group-convolution kernels on R^2 \\rtimes H.

        Registers three buffers:
        - ``grid_R2``: a [2, kernel_size, kernel_size] planar grid in [-1, 1];
        - ``grid_H``: the group's sampled elements;
        - ``transformed_grid_R2xH``: a [3, |H|, |H|, k, k] grid combining both,
          transformed by the inverse of every group element.
        """
        super().__init__()
        self.group = group

        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels

        axis = torch.linspace(-1., 1., self.kernel_size)
        planar_grid = torch.stack(torch.meshgrid(axis, axis, indexing='ij'))
        self.register_buffer("grid_R2", planar_grid.to(self.group.identity.device))

        # Input feature maps carry a group axis, so the kernel grid spans H too.
        self.register_buffer("grid_H", self.group.elements())
        self.register_buffer("transformed_grid_R2xH", self.create_transformed_grid_R2xH())

    def create_transformed_grid_R2xH(self):
        """Transform the grid over R^2 \\rtimes H by every inverse element of H.

        @return: Tensor of shape [3, |H|, |H|, kernel_size, kernel_size]:
            channels 0-1 hold rotated planar coordinates, channel 2 holds
            shifted group coordinates rescaled to [-1, 1].
        """
        elements = self.group.elements()
        n = elements.numel()
        k = self.kernel_size
        inverses = self.group.inverse(elements)

        # Planar part: [2, |H|, k, k], dim 1 indexes the transforming element.
        rotated_R2 = torch.stack(
            [self.group.left_action_on_R2(h_inv, self.grid_R2) for h_inv in inverses],
            dim=1,
        )

        # Group part: [|H|, |H|]; entry [a, b] = inverse(element b) * grid_H[a].
        # NOTE(review): here dim 0 indexes grid_H and dim 1 the transforming
        # element, whereas rotated_R2's dim 1 is the transforming element; after
        # the views below the two parts use different axes for it -- confirm
        # this orientation is intended.
        shifted_H = torch.stack(
            [self.group.product(h_inv, self.grid_H) for h_inv in inverses],
            dim=1,
        )
        # grid_sample expects coordinates in [-1, 1].
        shifted_H = self.group.normalize_group_elements(shifted_H)

        # Broadcast both parts onto a shared [|H|, |H|, k, k] lattice.
        planar_part = rotated_R2.view(2, n, 1, k, k).repeat(1, 1, n, 1, 1)
        group_part = shifted_H.view(1, n, n, 1, 1).repeat(1, 1, 1, k, k)
        return torch.cat((planar_part, group_part), dim=0)

    def sample(self, sampled_group_elements):
        """Sample convolution kernels for a given number of group elements.

        :param sampled_group_elements: the group elements over which to sample
            the convolution kernels
        :return kernels: filter bank extending over all input channels,
            containing kernels transformed for all output group elements.
        """
        raise NotImplementedError()


class InterpolativeGroupKernel(GroupKernelBase):
    """Group-convolution kernel built by trilinearly resampling one learned
    weight tensor over R^2 x H at every transformed grid."""

    def __init__(self, group, kernel_size, in_channels, out_channels):
        super().__init__(group, kernel_size, in_channels, out_channels)

        # Unlike the lifting kernel, the weights also extend over the group axis:
        # [out, in, |H|, k, k].
        weight_shape = (
            self.out_channels,
            self.in_channels,
            self.group.elements().numel(),
            self.kernel_size,
            self.kernel_size,
        )
        self.weight = torch.nn.Parameter(
            torch.zeros(weight_shape, device=self.group.identity.device)
        )
        torch.nn.init.kaiming_uniform_(self.weight.data, a=math.sqrt(5))

    def sample(self):
        """Build the group-convolution filter bank for all group elements.

        :return kernels: Tensor of shape
            [out_channels, |H|, in_channels, |H|, kernel_size, kernel_size].
        """
        n_group = self.group.elements().numel()
        k = self.kernel_size

        # Fold out/in channels so each grid resamples the whole bank in one call.
        folded = self.weight.view(self.out_channels * self.in_channels, n_group, k, k)
        per_element = [
            trilinear_interpolation(folded, self.transformed_grid_R2xH[:, idx])
            for idx in range(n_group)
        ]
        kernels = torch.stack(per_element).view(
            n_group, self.out_channels, self.in_channels, n_group, k, k
        )

        # Output channels first, so conv2d can fold (out, |H|) into one axis.
        return kernels.transpose(0, 1)

class GroupConvolution(torch.nn.Module):
    """Group convolution on R^2 x H.

    Implemented as one planar conv2d with the group axis folded into the
    channel axis of both the input and the output.
    """

    def __init__(self, group, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        self.kernel = InterpolativeGroupKernel(
            group=group,
            kernel_size=kernel_size,
            in_channels=in_channels,
            out_channels=out_channels
        )
        self.padding = padding

    def forward(self, x):
        """Perform group convolution.

        @param x: Input sample [batch_dim, in_channels, group_dim, spatial_dim_1,
            spatial_dim_2]
        @return: Function on the group
            [batch_dim, out_channels, num_group_elements, out_spatial_dim_1,
            out_spatial_dim_2]
        """
        num_group_elements = self.kernel.group.elements().numel()

        # Fold the group axis into channels: [B, C, |H|, H, W] -> [B, C*|H|, H, W].
        x = x.reshape(
            -1,
            x.shape[1] * x.shape[2],
            x.shape[3],
            x.shape[4]
        )

        # Kernels are [out, |H|, in, |H|, k, k]; fold into a planar filter bank.
        conv_kernels = self.kernel.sample()

        x = torch.nn.functional.conv2d(
            input=x,
            weight=conv_kernels.reshape(
                self.kernel.out_channels * num_group_elements,
                self.kernel.in_channels * num_group_elements,
                self.kernel.kernel_size,
                self.kernel.kernel_size
            ),
            padding=self.padding
        )

        # Unfold channels back into (out_channels, |H|). Spatial axes must stay
        # in (height, width) order; swapping them scrambles non-square maps.
        return x.view(
            -1,
            self.kernel.out_channels,
            num_group_elements,
            x.shape[-2],
            x.shape[-1],
        )

class GroupEquivariantCNN(torch.nn.Module):
    """Group-invariant classifier.

    Pipeline: lifting conv -> num_hidden group convs (each followed by
    layer norm and ReLU) -> average pooling over group and spatial axes ->
    linear head.

    @param group: a group object, or an integer n as shorthand for CyclicGroup(n).
    @param in_channels: channels of the planar input image.
    @param out_channels: size of the final linear output.
    @param kernel_size: spatial kernel size of every convolution.
    @param num_hidden: number of GroupConvolution layers after the lifting layer.
    @param hidden_channels: channel width of every hidden layer.
    """

    def __init__(self, group, in_channels, out_channels, kernel_size, num_hidden, hidden_channels):
        super().__init__()
        self.total_layers = num_hidden + 1
        # Accept Python ints and NumPy integer scalars as cyclic-group orders.
        if isinstance(group, (int, np.integer)):
            group = CyclicGroup(int(group))

        self.lifting_conv = LiftingConvolution(
            group=group,
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=kernel_size,
            padding=0
        )

        # Hidden group convolutions, created in order so parameter
        # initialization consumes the RNG in the same sequence as before.
        self.gconvs = torch.nn.ModuleList([
            GroupConvolution(
                group=group,
                in_channels=hidden_channels,
                out_channels=hidden_channels,
                kernel_size=kernel_size,
                padding=0
            )
            for _ in range(num_hidden)
        ])
        self.projection_layer = torch.nn.AdaptiveAvgPool3d(1)
        self.final_linear = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x):
        """@param x: planar input [batch, in_channels, H, W].
        @return: Tensor [batch, out_channels] (batch axis kept even for batch 1).
        """
        x = self.lifting_conv(x)
        x = torch.nn.functional.layer_norm(x, x.shape[-4:])
        x = torch.nn.functional.relu(x)

        for gconv in self.gconvs:
            x = gconv(x)
            x = torch.nn.functional.layer_norm(x, x.shape[-4:])
            x = torch.nn.functional.relu(x)

        # Average-pool over the group and spatial axes -> [B, C, 1, 1, 1], which
        # makes the features invariant. flatten(1) (unlike squeeze()) keeps the
        # batch axis when B == 1 and the channel axis when C == 1.
        x = self.projection_layer(x)
        x = torch.flatten(x, start_dim=1)
        return self.final_linear(x)

############################  Hyper

class InterpolativeLiftingKernel_hyper(LiftingKernelBase):
    """Lifting kernel whose base weights are passed to ``sample`` by the caller
    instead of being stored as a parameter (the ``_hyper`` naming suggests a
    hypernetwork supplies them)."""

    def __init__(self, group, kernel_size, in_channels, out_channels):
        super().__init__(group, kernel_size, in_channels, out_channels)

        # Shape of the weights the caller must provide to sample(): [out, in, k, k].
        self.need_shape = (self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)

    def sample(self, generated_weight):
        """Build the lifting filter bank from externally supplied weights.

        :param generated_weight: tensor with out*in*k*k elements, viewable as
            ``self.need_shape``.
        :return kernels: Tensor of shape
            [out_channels, |H|, in_channels, kernel_size, kernel_size].
        """
        n_group = self.group.elements().numel()
        k = self.kernel_size

        # Fold out/in channels so each grid resamples the whole bank at once.
        folded = generated_weight.view(self.out_channels * self.in_channels, k, k)
        per_element = [
            bilinear_interpolation(folded, self.transformed_grid_R2[:, idx])
            for idx in range(n_group)
        ]
        kernels = torch.stack(per_element).view(n_group, self.out_channels, self.in_channels, k, k)

        # Output channels first, so conv2d can fold (out, |H|) into one axis.
        return kernels.transpose(0, 1)

class LiftingConvolution_hyper(torch.nn.Module):
    """Lifting convolution whose kernel weights are supplied per forward call."""

    def __init__(self, group, in_channels, out_channels, kernel_size, padding):
        super().__init__()

        self.kernel = InterpolativeLiftingKernel_hyper(
            group=group,
            kernel_size=kernel_size,
            in_channels=in_channels,
            out_channels=out_channels
        )

        self.padding = padding

    def need_shape(self):
        """Return the shape of the weight tensor ``forward`` expects."""
        return self.kernel.need_shape

    def forward(self, x, weight):
        """Perform lifting convolution.

        @param x: Input sample [batch_dim, in_channels, spatial_dim_1,
            spatial_dim_2]
        @param weight: base kernel weights with the shape given by
            ``need_shape()`` ([out_channels, in_channels, k, k]).
        @return: Function on a homogeneous space of the group
            [batch_dim, out_channels, num_group_elements, out_spatial_dim_1,
            out_spatial_dim_2]
        """
        num_group_elements = self.kernel.group.elements().numel()

        # Kernels are [out, |H|, in, k, k]; fold (out, |H|) into conv2d's
        # output-channel axis.
        conv_kernels = self.kernel.sample(weight)

        x = torch.nn.functional.conv2d(
            input=x,
            weight=conv_kernels.reshape(
                self.kernel.out_channels * num_group_elements,
                self.kernel.in_channels,
                self.kernel.kernel_size,
                self.kernel.kernel_size
            ),
            padding=self.padding
        )

        # Unfold channels back into (out_channels, |H|). Spatial axes must stay
        # in (height, width) order; swapping them scrambles non-square maps.
        return x.view(
            -1,
            self.kernel.out_channels,
            num_group_elements,
            x.shape[-2],
            x.shape[-1]
        )

class GroupKernelBase_hyper(GroupKernelBase):
    """Base class for group-convolution kernels whose weights are supplied
    externally (the ``_hyper`` variants).

    Grid construction is identical to ``GroupKernelBase``, so ``__init__``,
    ``create_transformed_grid_R2xH`` and the abstract ``sample`` are inherited
    rather than duplicated. The same buffers are registered (``grid_R2``,
    ``grid_H``, ``transformed_grid_R2xH``), so state_dict keys are unchanged.
    The separate class name is kept so existing subclasses such as
    ``InterpolativeGroupKernel_hyper`` keep working; subclasses override
    ``sample``.
    """

class InterpolativeGroupKernel_hyper(GroupKernelBase_hyper):
    """Group-convolution kernel whose raw weights are supplied from outside.

    The kernel owns no weight parameter. It only records, in ``need_shape``,
    the weight shape it expects. ``sample`` receives such a generated weight,
    interpolates it on each transformed R2 x H grid and returns the filter bank.
    """

    def __init__(self, group, kernel_size, in_channels, out_channels):
        super().__init__(group, kernel_size, in_channels, out_channels)
        # Unlike a lifting kernel, the weight also extends over the group
        # axis H: [out_channels, in_channels, |H|, k, k].
        group_size = self.group.elements().numel()
        self.need_shape = (
            self.out_channels,
            self.in_channels,
            group_size,
            self.kernel_size,
            self.kernel_size,
        )

    def sample(self, generated_weight):
        """Interpolate ``generated_weight`` for every group element.

        :param generated_weight: tensor with ``prod(self.need_shape)`` elements.
        :return: filter bank of shape
            [out_channels, |H|, in_channels, |H|, kernel_size, kernel_size].
        """
        group_size = self.group.elements().numel()
        k = self.kernel_size

        # Merge the two channel axes: [out * in, |H|, k, k].
        flat_weight = generated_weight.view(
            self.out_channels * self.in_channels, group_size, k, k
        )

        # One resampled copy of the weight per group element, read off the
        # R2 x H grid that element transforms to.
        per_element = [
            trilinear_interpolation(
                flat_weight, self.transformed_grid_R2xH[:, idx, :, :, :]
            )
            for idx in range(group_size)
        ]
        stacked = torch.stack(per_element)

        # Split channels again: [|H|, out, in, |H|, k, k].
        stacked = stacked.view(
            group_size, self.out_channels, self.in_channels, group_size, k, k
        )

        # Move out_channels ahead of the output group axis so the caller can
        # reshape straight into a conv2d weight.
        return stacked.transpose(0, 1)

class GroupConvolution_hyper(torch.nn.Module):
    """Regular group convolution whose kernel weights are passed in at call time."""

    def __init__(self, group, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        self.kernel = InterpolativeGroupKernel_hyper(
            group=group,
            kernel_size=kernel_size,
            in_channels=in_channels,
            out_channels=out_channels
        )
        self.padding = padding
        # Weight shape forward() expects, exposed so a hypernetwork can emit it.
        self.need_shape = self.kernel.need_shape

    def forward(self, x, weight):
        """Perform a group convolution with an externally generated weight.

        @param x: input of shape [batch_dim, in_channels, group_dim,
            spatial_dim_1, spatial_dim_2]
        @param weight: generated kernel weight, reshapeable to
            ``self.need_shape``
        @return: feature map of shape [batch_dim, out_channels,
            num_group_elements, spatial_dim_1', spatial_dim_2']
        """
        num_group_elements = self.kernel.group.elements().numel()

        # Fold (channels, group) into one axis so a plain conv2d can be used.
        x = x.reshape(
            -1,
            x.shape[1] * x.shape[2],
            x.shape[3],
            x.shape[4]
        )

        conv_kernels = self.kernel.sample(weight)

        x = torch.nn.functional.conv2d(
            input=x,
            weight=conv_kernels.reshape(
                self.kernel.out_channels * num_group_elements,
                self.kernel.in_channels * num_group_elements,
                self.kernel.kernel_size,
                self.kernel.kernel_size
            ),
            padding=self.padding
        )

        # Unfold the output channel axis into (out_channels, group). The
        # spatial sizes must stay in (height, width) order; swapping them
        # silently scrambles non-square feature maps.
        return x.view(
            -1,
            self.kernel.out_channels,
            num_group_elements,
            x.shape[-2],
            x.shape[-1],
        )

class HyperNet_GCNN(torch.nn.Module):
    """Hypernetwork that emits one weight tensor per requested shape.

    A shared feature extractor (``CustomSharedHyper``) embeds the input. One
    ``LinearHeads`` module per entry of ``weight_shapes`` maps that embedding
    to a tensor of the corresponding shape.

    :param shared_choice: forwarded to ``CustomSharedHyper``.
    :param weight_shapes: sequence of target shapes (tuples or ints); required.
    :param rank: intermediate dimension of each head.
    :param in_channels: channel count of the hypernetwork's input.
    :raises ValueError: if ``weight_shapes`` is None.
    """

    def __init__(self, shared_choice=1, weight_shapes=None, rank=4, in_channels=1):
        super().__init__()
        if weight_shapes is None:
            # Fail with a clear message instead of "'NoneType' object is not iterable".
            raise ValueError("weight_shapes must be a sequence of target weight shapes")
        self.cnn = CustomSharedHyper(shared_choice=shared_choice, in_channels=in_channels)
        cnn_result_dim = self.cnn.cnn_dim
        self.heads = torch.nn.ModuleList(
            LinearHeads(in_dim=cnn_result_dim, out_shape=shape, rank=rank)
            for shape in weight_shapes
        )

    def forward(self, x):
        """Return a list with one generated tensor per head, in ``weight_shapes`` order."""
        features = self.cnn(x)
        # reshape (not view) also copes with a non-contiguous feature map.
        features = features.reshape(features.shape[0], -1)
        return [head(features) for head in self.heads]

class LinearHeads(torch.nn.Module):
    """Hypernetwork head: a LoRA-style linear map reshaped to a fixed target shape.

    :param in_dim: size of the input feature vector.
    :param out_shape: target shape (a sequence of ints) or a plain int size.
    :param rank: passed to ``individual_head`` as ``intermediate_dim``.
    """

    def __init__(self, in_dim, out_shape, rank=4):
        super().__init__()
        self.out_shape = out_shape
        # A plain int is already the flat output size. Any other value is a
        # shape whose element count is the product of its dimensions.
        if isinstance(out_shape, int):
            out_dim = out_shape
        else:
            out_dim = reduce(operator.mul, out_shape, 1)
        self.linear = individual_head(in_dim, out_dim, lora=True, intermediate_dim=rank)

    def forward(self, x):
        """Project ``x`` and reshape the result to ``self.out_shape``.

        NOTE(review): ``view(out_shape)`` keeps no batch axis, so this assumes
        the projection yields exactly prod(out_shape) elements (presumably a
        single input row) — confirm against callers.
        """
        return self.linear(x).view(self.out_shape)


class GroupEquivariantCNN_hyper(torch.nn.Module):
    """Group-equivariant CNN whose layer weights are generated by a hypernetwork.

    Architecture:
    1. lifting convolution;
    2. ``num_hidden`` group convolutions, each followed by layer norm and ReLU;
    3. average pooling over the group and spatial axes;
    4. linear classifier.

    None of these layers owns weights. ``self.hypernet`` produces, in order:
    - the lifting kernel;
    - one kernel per group convolution;
    - the classifier weight ``(out_channels, hidden_channels[-1])``;
    - the classifier bias ``(out_channels,)``.

    :param group: a group object, or an int used as the order of a CyclicGroup.
    :param hidden_channels: an int (same width for every layer) or a list/tuple
        of per-layer widths. A shorter sequence is padded with its last entry.
    :param kernel_size: an int or a list/tuple of per-layer kernel sizes,
        padded the same way.
    :param rank: passed to each hypernetwork head as its intermediate dimension.
    :param shared_choice: forwarded to the hypernetwork's shared extractor.
    """

    def __init__(self, group, in_channels, out_channels, kernel_size,
                 num_hidden, hidden_channels, rank=4, shared_choice=1):
        super().__init__()
        self.total_layers = num_hidden + 1
        # Remembered so test_hyper can build an input the hypernetwork accepts.
        self.in_channels = in_channels
        hidden_channels = self._expand_per_layer(hidden_channels, num_hidden + 1, "hidden_channels")
        kernel_size = self._expand_per_layer(kernel_size, num_hidden + 1, "kernel_size")

        if isinstance(group, int):
            group = CyclicGroup(group)
        self.lifting_conv = LiftingConvolution_hyper(
            group=group,
            in_channels=in_channels,
            out_channels=hidden_channels[0],
            kernel_size=kernel_size[0],
            padding=0
        )
        # Weight shape the lifting convolution expects from the hypernetwork.
        self.lifting_needed = self.lifting_conv.need_shape()

        # Group convolutions, plus the weight shape each of them expects.
        self.gconvs = torch.nn.ModuleList()
        self.conv_needed = []
        for i in range(num_hidden):
            gconv = GroupConvolution_hyper(
                group=group,
                in_channels=hidden_channels[i],
                out_channels=hidden_channels[i + 1],
                kernel_size=kernel_size[i + 1],
                padding=0
            )
            self.gconvs.append(gconv)
            self.conv_needed.append(gconv.need_shape)
        self.projection_layer = torch.nn.AdaptiveAvgPool3d(1)
        # The classifier is applied functionally with a generated weight and bias.
        self.linear_needed = (out_channels, hidden_channels[-1])
        concated_shape = [self.lifting_needed] + self.conv_needed + [self.linear_needed, out_channels]
        self.hypernet = HyperNet_GCNN(shared_choice=shared_choice,
                                      weight_shapes=concated_shape, rank=rank, in_channels=in_channels)

    @staticmethod
    def _expand_per_layer(value, num_layers, name):
        """Turn a scalar or a sequence into a per-layer list.

        A scalar is repeated ``num_layers`` times. A list/tuple shorter than
        ``num_layers`` is padded with its last entry; a longer one is kept.

        :raises ValueError: if ``value`` is an empty sequence.
        """
        if not isinstance(value, (list, tuple)):
            return [value] * num_layers
        values = list(value)
        if not values:
            raise ValueError(f"{name} must not be an empty sequence")
        return values + [values[-1]] * (num_layers - len(values))

    def test_hyper(self):
        """Run the hypernetwork on a random 28x28 batch and print each output's shape and grad status."""
        example_input = torch.randn(3, self.in_channels, 28, 28)
        mean_of_example = torch.mean(example_input, dim=0, keepdim=True)
        print("testing the output of hypernet")
        for weight in self.hypernet(mean_of_example):
            print(weight.shape)
            print("requires grad" if weight.requires_grad else "does not require grad")

    def forward(self, x):
        """Classify ``x`` using weights generated from the batch mean.

        The hypernetwork sees only the mean over the batch axis. Every sample
        in a batch is therefore processed with the same generated weights.
        """
        mean_of_input = torch.mean(x, dim=0, keepdim=True)
        weights = self.hypernet(mean_of_input)

        x = self.lifting_conv(x, weights[0])
        x = torch.nn.functional.layer_norm(x, x.shape[-4:])
        x = torch.nn.functional.relu(x)

        # weights[1:1 + num_hidden] are the group-convolution kernels.
        for layer_idx, gconv in enumerate(self.gconvs, start=1):
            x = gconv(x, weights[layer_idx])
            x = torch.nn.functional.layer_norm(x, x.shape[-4:])
            x = torch.nn.functional.relu(x)

        # Average over the group and spatial axes, then drop those singleton
        # axes. flatten(-4) keeps a batch axis of size 1, unlike squeeze().
        x = self.projection_layer(x)
        x = x.flatten(-4)
        return torch.nn.functional.linear(x, weights[-2], weights[-1])

class Lift_layer(torch.nn.Module):
    """Standalone lifting-convolution layer.

    Wraps ``LiftingConvolution`` so it can be built from a plain group order.
    An int ``group`` is converted to ``CyclicGroup(group)``.
    """

    def __init__(self, group, in_channels, out_channels, kernel_size, padding=0):
        super().__init__()
        if isinstance(group, int):
            group = CyclicGroup(group)
        self.lifting_conv = LiftingConvolution(
            group=group,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x):
        """Apply the wrapped lifting convolution to ``x``."""
        return self.lifting_conv(x)

class GroupConvolution_layer(torch.nn.Module):
    """Standalone group-convolution layer.

    Wraps ``GroupConvolution`` so it can be built from a plain group order.
    An int ``group`` is converted to ``CyclicGroup(group)``.
    """

    def __init__(self, group, in_channels, out_channels, kernel_size, padding=0):
        super().__init__()
        if isinstance(group, int):
            group = CyclicGroup(group)
        self.gconv = GroupConvolution(
            group=group,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x):
        """Apply the wrapped group convolution to ``x``."""
        return self.gconv(x)

class GroupEquivariantCNN_channels(torch.nn.Module):
    """Group-equivariant CNN with per-layer channel widths and kernel sizes.

    Architecture:
    1. lifting convolution;
    2. ``num_hidden`` group convolutions, each followed by layer norm and ReLU;
    3. average pooling over the group and spatial axes;
    4. linear classifier.

    :param group: a group object, or an int used as the order of a CyclicGroup.
    :param hidden_channels: an int (same width for every layer) or a list/tuple
        of per-layer widths. A shorter sequence is padded with its last entry.
    :param kernel_size: an int or a list/tuple of per-layer kernel sizes,
        padded the same way.
    :param final_pool: currently unused; kept for interface compatibility.
    """

    def __init__(self, group, in_channels, out_channels, kernel_size,
                 num_hidden, hidden_channels, final_pool=1):
        super().__init__()
        self.total_layers = num_hidden + 1

        if isinstance(group, int):
            group = CyclicGroup(group)
        hidden_channels = self._expand_per_layer(hidden_channels, num_hidden + 1, "hidden_channels")
        kernel_size = self._expand_per_layer(kernel_size, num_hidden + 1, "kernel_size")

        self.lifting_conv = LiftingConvolution(
            group=group,
            in_channels=in_channels,
            out_channels=hidden_channels[0],
            kernel_size=kernel_size[0],
            padding=0
        )

        # Layer i maps hidden_channels[i] -> hidden_channels[i + 1].
        self.gconvs = torch.nn.ModuleList(
            GroupConvolution(
                group=group,
                in_channels=hidden_channels[i],
                out_channels=hidden_channels[i + 1],
                kernel_size=kernel_size[i + 1],
                padding=0
            )
            for i in range(num_hidden)
        )
        self.projection_layer = torch.nn.AdaptiveAvgPool3d(1)
        self.final_linear = torch.nn.Linear(hidden_channels[-1], out_channels)

    @staticmethod
    def _expand_per_layer(value, num_layers, name):
        """Turn a scalar or a sequence into a per-layer list.

        A scalar is repeated ``num_layers`` times. A list/tuple shorter than
        ``num_layers`` is padded with its last entry; a longer one is kept.

        :raises ValueError: if ``value`` is an empty sequence.
        """
        if not isinstance(value, (list, tuple)):
            return [value] * num_layers
        values = list(value)
        if not values:
            raise ValueError(f"{name} must not be an empty sequence")
        return values + [values[-1]] * (num_layers - len(values))

    def forward(self, x):
        """Return class logits for ``x``."""
        x = self.lifting_conv(x)
        x = torch.nn.functional.layer_norm(x, x.shape[-4:])
        x = torch.nn.functional.relu(x)
        for gconv in self.gconvs:
            x = gconv(x)
            x = torch.nn.functional.layer_norm(x, x.shape[-4:])
            x = torch.nn.functional.relu(x)
        # Average over the group and spatial axes, then drop those singleton
        # axes. flatten(-4) keeps a batch axis of size 1, unlike squeeze().
        x = self.projection_layer(x)
        x = x.flatten(-4)
        return self.final_linear(x)

class RegularCNN(torch.nn.Module):
    """Plain (non-equivariant) CNN baseline.

    Architecture:
    1. ``num_hidden + 1`` Conv2d layers, each followed by layer norm and ReLU;
    2. global average pooling;
    3. linear classifier.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_hidden, hidden_channels):
        super().__init__()
        self.total_layers = num_hidden + 1
        self.convs = torch.nn.ModuleList()
        self.convs.append(
            torch.nn.Conv2d(
                in_channels=in_channels,
                out_channels=hidden_channels,
                kernel_size=kernel_size,
                padding=0
            ))
        for _ in range(num_hidden):
            self.convs.append(
                torch.nn.Conv2d(
                    in_channels=hidden_channels,
                    out_channels=hidden_channels,
                    kernel_size=kernel_size,
                    padding=0
                )
            )
        self.projection_layer = torch.nn.AdaptiveAvgPool2d(1)
        self.final_linear = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x):
        """Return class logits for ``x`` ([B, C, H, W] or unbatched [C, H, W])."""
        for conv in self.convs:
            x = conv(x)
            # Normalise each sample over its (channels, height, width) axes.
            # The conv output is 4-D, so x.shape[-4:] would include the batch
            # axis and mix statistics across samples.
            x = torch.nn.functional.layer_norm(x, x.shape[-3:])
            x = torch.nn.functional.relu(x)
        # Global average pool, then drop the (1, 1) spatial axes.
        # flatten(-3) keeps a batch axis of size 1, unlike squeeze().
        x = self.projection_layer(x)
        x = x.flatten(-3)
        return self.final_linear(x)

def test_model(model, data_loader):
    """Print and return the top-1 accuracy (in percent) of ``model`` on ``data_loader``.

    Accuracy is total correct / total samples, so a smaller final batch is
    weighted correctly. Gradients are disabled and the model runs in eval
    mode; its previous train/eval mode is restored afterwards.

    :raises ValueError: if ``data_loader`` yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data, target in data_loader:
                data, target = data.to(device), target.to(device)
                pred = model(data).argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += len(data)
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("data_loader yielded no samples")
    accuracy = correct / total * 100
    print("Accuracy: ", accuracy)
    return accuracy


# Helper, not a pytest test: stop pytest from collecting it when this module
# is imported into a test file.
test_model.__test__ = False

def train(train, model, mode, epoch, lr = 0.001, count = False, count_time = False):
    """Build one of the MNIST models and, if requested, train and checkpoint it.

    :param train: when False nothing is built or trained; only the checkpoint
        path is computed and returned.
    :param model: "gcnn", "cnn", "our"/"ours" or "hg"/"hypergcnn".
    :param mode: "regular" (plain MNIST) or "rot" (RandomRotation(180) applied
        to both the training and the test split).
    :param epoch: number of training epochs.
    :param lr: Adam learning rate.
    :param count: print the parameter count. Unless ``count_time`` is also
        set, return 0 right afterwards.
    :param count_time: skip evaluation and checkpointing. After epoch index 10
        finishes, print the mean wall-clock seconds per epoch and return None.
    :return: the checkpoint path ``"constraint_pths/<model><mode>.pth"``
        (just ``"constraint_pths/"`` when ``count`` or ``count_time`` is set).
    :raises ValueError: for an unknown model name or mode.
    """
    save_location = 'constraint_pths/'
    if not count_time and not count:
        os.makedirs(save_location, exist_ok=True)
        save_location = save_location + model + mode + '.pth'
        print("saving location: ", save_location)

    if not train:
        return save_location

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model == "gcnn":
        net = GroupEquivariantCNN(group=CyclicGroup(4), in_channels=1, out_channels=10,
                                  kernel_size=2, num_hidden=3, hidden_channels=16).to(device)
    elif model == "cnn":
        net = RegularCNN(in_channels=1, out_channels=10,
                         kernel_size=2, num_hidden=3, hidden_channels=16).to(device)
    elif model in ("our", "ours"):
        shell = RegularCNN(in_channels=1, out_channels=10,
                           kernel_size=2, num_hidden=3, hidden_channels=16).to(device)
        hyper = HyperNetwork_Head(shell, lora=True, inter=4,
                                  shared_choice=1).to(device)
        net = FunctionalFullNetwork(hyper, shell, 0, head=True).to(device)
    elif model in ("hg", "hypergcnn"):
        print("Hyper-GCNN model")
        net = GroupEquivariantCNN_hyper(group=4, in_channels=1, out_channels=10,
                                        kernel_size=2, num_hidden=3, hidden_channels=16,
                                        shared_choice=1).to(device)
    else:
        raise ValueError("Model not found.")

    if count:
        print("params count:", sum(p.numel() for p in net.parameters()))
        if not count_time:
            return 0
        print()

    # Both splits use ToTensor -> (optional rotation) -> Normalize.
    if mode == 'regular':
        augmentation = []
    elif mode == 'rot':
        augmentation = [transforms.RandomRotation(180)]
    else:
        raise ValueError("Mode not found.")
    transform = transforms.Compose(
        [transforms.ToTensor(), *augmentation, transforms.Normalize((0.1307,), (0.3081,))]
    )
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=transform),
        batch_size=32, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=transform),
        batch_size=32, shuffle=True)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    begin_time = time.time()
    print("Training the model...")
    for counter in range(epoch):
        net.train()
        train_loss = []
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        if count_time:
            print(counter, end=" ", flush=True)
            if counter == 10:
                # Epochs 0..counter have finished, i.e. counter + 1 epochs.
                elapsed = time.time() - begin_time
                print(elapsed / (counter + 1), "seconds")
                return
            continue

        avg_loss = np.mean(train_loss)
        correct = 0
        total = 0
        net.eval()
        with torch.no_grad():
            for images, labels in test_loader:
                outputs = net(images.to(device))
                predicted = outputs.argmax(dim=-1).cpu()
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print(f"Epoch {counter}, train loss: {avg_loss}, test accuracy: {accuracy}")
        # Checkpoint every 10 epochs and always after the last one, so the
        # saved weights are those of the fully trained model.
        if counter % 10 == 0 or counter == epoch - 1:
            torch.save(net.state_dict(), save_location)
    return save_location

def test(model, save_location):
    """Load a checkpoint and print its accuracy on three variants of the MNIST test split.

    The variants are: unmodified, randomly rotated (``RandomRotation(180)``)
    and rotated by exactly 90 degrees (``RotateTransform(90)``).

    :param model: "gcnn", "cnn", "our"/"ours" or "hg"/"hypergcnn" (the same
        names ``train`` accepts).
    :param save_location: path of a state_dict written by ``train``.
    :raises ValueError: for an unknown model name.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model == "gcnn":
        net = GroupEquivariantCNN(group=CyclicGroup(4), in_channels=1, out_channels=10,
                                  kernel_size=2, num_hidden=3, hidden_channels=16).to(device)
    elif model == "cnn":
        net = RegularCNN(in_channels=1, out_channels=10,
                         kernel_size=2, num_hidden=3, hidden_channels=16).to(device)
    elif model in ("our", "ours"):
        shell = RegularCNN(in_channels=1, out_channels=10,
                           kernel_size=2, num_hidden=3, hidden_channels=16).to(device)
        hyper = HyperNetwork_Head(shell, lora=True, inter=4,
                                  shared_choice=1).to(device)
        net = FunctionalFullNetwork(hyper, shell, 0, head=True).to(device)
    elif model in ("hg", "hypergcnn"):
        print("Hyper-GCNN model")
        net = GroupEquivariantCNN_hyper(group=4, in_channels=1, out_channels=10,
                                        kernel_size=2, num_hidden=3, hidden_channels=16,
                                        shared_choice=1).to(device)
    else:
        raise ValueError("Model not found.")

    # map_location lets a checkpoint written on a GPU be loaded on a CPU-only
    # machine. NOTE: torch.load unpickles the file; only load trusted checkpoints.
    net.load_state_dict(torch.load(save_location, map_location=device))

    def mnist_test_loader(*augmentation):
        # Test split with ToTensor -> augmentation -> Normalize.
        transform = transforms.Compose(
            [transforms.ToTensor(), *augmentation, transforms.Normalize((0.1307,), (0.3081,))]
        )
        return torch.utils.data.DataLoader(
            datasets.MNIST('data', train=False, transform=transform),
            batch_size=32, shuffle=True)

    test_loader = mnist_test_loader()
    rot_loader = mnist_test_loader(transforms.RandomRotation(180))
    ninety_loader = mnist_test_loader(RotateTransform(90))
    print("original dataset")
    test_model(net, test_loader)
    print("rotated dataset")
    test_model(net, rot_loader)
    print("rotated 90 dataset")
    test_model(net, ninety_loader)


# Script entry point, not a pytest test: stop pytest from collecting it when
# this module is imported into a test file.
test.__test__ = False

if __name__ == '__main__':
    # cuDNN flags: benchmark=True autotunes convolution algorithms, and
    # deterministic=True restricts cuDNN to deterministic ones.
    # NOTE(review): autotuning may pick different algorithms from run to run,
    # which can work against the reproducibility deterministic=True is meant to
    # give. Confirm this combination is intended.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='gcnn',
                        choices=['gcnn', 'cnn', 'our', 'ours', 'hg', 'hypergcnn'],
                        help='model to train: gcnn (group-equivariant CNN), cnn (regular CNN), '
                             'our/ours (hypernetwork + CNN) or hg/hypergcnn (hypernetwork + G-CNN)')
    parser.add_argument('--mode', type=str, default='regular', choices=['regular', 'rot'],
                        help='train mode, regular or rot')
    parser.add_argument('--epoch', type=int, default=100,
                        help='number of epochs to train')
    parser.add_argument('--load', type=str, default=None,
                        help='checkpoint to evaluate instead of the default path (ignored with --train)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate')
    parser.add_argument("--train", action="store_true", help="Set to initiate training")
    parser.add_argument("--count", action="store_true", help="Set to count the number of parameters in the model")
    args = parser.parse_args()

    if args.count:
        # Only count parameters, no training.
        print("first, regular CNN")
        train(train=True, model='cnn', mode='rot', epoch=1, count=True)
        print("second, G-CNN")
        train(train=True, model='gcnn', mode='rot', epoch=1, count=True)
        print("third, our model")
        train(train=True, model='our', mode='rot', epoch=1, count=True)
        print("fourth, hyper-G-CNN")
        train(train=True, model='hg', mode='rot', epoch=1, count=True)
    else:
        print("On a {}, using {} data, for {} epochs, with learning rate {}.".format(args.model, args.mode, args.epoch, args.lr))
        a = train(train=args.train, model=args.model, mode=args.mode, epoch=args.epoch, lr=args.lr)
        # Without training, an explicit --load path overrides the default checkpoint.
        if args.load is not None and not args.train:
            a = args.load
        print("For model {}, test accuracy is:".format(args.model))
        print("save location", a)
        test(model=args.model, save_location=a)

