import torch.nn as nn
from einops import rearrange


class BaseGrid1dLinear:
    """Expose the weight matrix of an ``nn.Linear`` layer as a 2-D grid.

    The grid has one row per output neuron; each row holds the weights of
    the connections feeding that neuron.
    """

    def __init__(self, linear_layer):
        """
        Initializes a BaseGrid1dLinear object.

        Args:
            linear_layer (nn.Linear): Linear layer for which the weight grid is created.

        Raises:
            AssertionError: If ``linear_layer`` is not an instance of ``torch.nn.Linear``.
        """
        # Explicit raise instead of a bare ``assert`` so the check survives
        # ``python -O``; AssertionError is kept so callers see the same
        # exception type as before.
        if not isinstance(linear_layer, nn.Linear):
            raise AssertionError(
                f"linear_layer expected an instance of torch.nn.Linear but got: {type(linear_layer)}"
            )
        # linear_layer.weight.shape: (output, input)
        self.linear_layer = linear_layer

    @property
    def grid(self):
        """Return the layer's weight tensor itself (not a copy).

        Shape is ``(num_output_neurons, num_input_neurons)``: row ``r`` is the
        incoming weight vector of output neuron ``r``. Since the returned
        object is the layer's own ``weight`` attribute, in-place edits to it
        modify the layer.
        """
        return self.linear_layer.weight


class BaseGrid1dConv:
    """Expose the weights of an ``nn.Conv2d`` layer as a 2-D grid.

    Each output channel's kernel stack is flattened into a single row, so the
    grid has one row per output channel.
    """

    def __init__(self, conv_layer):
        """
        Initializes a BaseGrid1dConv object.

        Args:
            conv_layer (nn.Conv2d): Conv layer for which the weight grid is created.

        Raises:
            AssertionError: If ``conv_layer`` is not an instance of ``torch.nn.Conv2d``.
        """
        # Explicit raise instead of a bare ``assert`` so the check survives
        # ``python -O``; AssertionError is kept so callers see the same
        # exception type as before.
        if not isinstance(conv_layer, nn.Conv2d):
            raise AssertionError(
                f"conv_layer expected an instance of torch.nn.Conv2d but got: {type(conv_layer)}"
            )
        # conv_layer.weight.shape: (out_channels, in_channels, kernel_size[0], kernel_size[1])
        # (PyTorch stores in_channels // groups on the second axis for grouped convs.)
        self.conv_layer = conv_layer

    @property
    def grid(self):
        """Return the conv weight flattened to ``(out_channels, C * kH * kW)``.

        ``C`` is the size of the weight's second axis. Within a row the
        flattening order is input channel first, then kernel row, then kernel
        column, as given by the pattern ``"o i k1 k2 -> o (i k1 k2)"``.
        """
        flat_weight = rearrange(
            self.conv_layer.weight, "o i k1 k2 -> o (i k1 k2)"
        )
        return flat_weight
