import torch.nn as nn
import einops

"""
The base grid is an abstraction that is common to both neighbourhood cossim loss
and the laplacian pyramid loss. It takes a linear or a conv layer and arranges it's weights
in a 3 dimensional grid of shape: (height, width, embedding_size)

embedding_size depends on the nature of the layer i.e input size, output size and the kernel size for conv layers.
height * width should be the same as the output size of the model
"""


class BaseGrid2dLinear:
    """Arrange the weights of an ``nn.Linear`` layer on a 2D grid.

    Every output unit of the layer occupies one grid cell (row-major order),
    and the cell's embedding is that unit's row of the weight matrix. The
    resulting grid has shape ``(height, width, weight.shape[1])``.
    """

    def __init__(self, linear_layer: nn.Linear, height: int, width: int, device: str):
        """
        Initializes a BaseGrid2dLinear object.

        Args:
            linear_layer (nn.Linear): Linear layer for which the Weight grid is created.
            height (int): Height of the Weight grid.
            width (int): Width of the Weight grid.
            device (str): Device on which the Weight grid should be placed.

        Raises:
            AssertionError: If ``linear_layer`` is not an ``nn.Linear`` or if
                ``height * width`` differs from ``linear_layer.weight.shape[0]``.
        """
        self.width = width
        self.height = height
        self.device = device

        # Raised explicitly (rather than via `assert`) so the checks are not
        # stripped when Python runs with -O; the exception type is kept as
        # AssertionError so existing callers keep working.
        if not isinstance(linear_layer, nn.Linear):
            raise AssertionError(
                f"linear_layer expected an instance of torch.nn.Linear but got: {type(linear_layer)}"
            )
        # linear_layer.weight.shape: (output, input)
        if linear_layer.weight.shape[0] != self.width * self.height:
            raise AssertionError(
                f"Expected grid height * width to be the same as linear_layer.weight.shape[0]: {linear_layer.weight.shape[0]}"
            )

        self.linear_layer = linear_layer

    @property
    def grid(self):
        """Return the layer's weights as a ``(height, width, input)`` tensor.

        The grid is rebuilt from ``linear_layer.weight`` on every access, so it
        always reflects the layer's current weights, and it is moved to
        ``self.device`` before being returned.
        """
        weight = self.linear_layer.weight
        return weight.reshape(self.height, self.width, weight.shape[1]).to(
            self.device
        )


class BaseGrid2dConv:
    """Arrange the weights of an ``nn.Conv2d`` layer on a 2D grid.

    Every output channel of the layer occupies one grid cell (row-major
    order), and the cell's embedding is that channel's filter flattened into a
    single vector. The resulting grid has shape
    ``(height, width, input_channels * kernel_height * kernel_width)``.
    """

    def __init__(self, conv_layer: nn.Conv2d, height: int, width: int, device: str):
        """
        Initializes a BaseGrid2dConv object.

        Args:
            conv_layer (nn.Conv2d): Convolutional layer for which the weight grid is created.
            height (int): Height of the weight grid.
            width (int): Width of the weight grid.
            device (str): Device on which the weight grid should be placed.

        Raises:
            AssertionError: If ``conv_layer`` is not an ``nn.Conv2d`` or if
                ``height * width`` differs from ``conv_layer.weight.shape[0]``.
        """
        self.width = width
        self.height = height
        self.device = device

        # Raised explicitly (rather than via `assert`) so the checks are not
        # stripped when Python runs with -O; the exception type is kept as
        # AssertionError so existing callers keep working.
        if not isinstance(conv_layer, nn.Conv2d):
            raise AssertionError(
                f"conv_layer expected an instance of torch.nn.Conv2d but got: {type(conv_layer)}"
            )
        # conv_layer.weight.shape: output_channels, input_channels, kernel_size, kernel_size
        if conv_layer.weight.shape[0] != self.width * self.height:
            raise AssertionError(
                f"Expected grid height * width to be the same as conv_weights.shape[0]: {conv_layer.weight.shape[0]}"
            )

        self.conv_layer = conv_layer

    @property
    def grid(self):
        """Return the layer's filters as a ``(height, width, embedding)`` tensor.

        The grid is rebuilt from ``conv_layer.weight`` on every access, so it
        always reflects the layer's current weights, and it is moved to
        ``self.device`` before being returned.
        """
        # Collapse (input_channels, kernel_h, kernel_w) into one embedding axis:
        # (o, i, h, w) -> (o, i * h * w), in row-major order.
        filter_embeddings = self.conv_layer.weight.flatten(start_dim=1)

        return filter_embeddings.reshape(
            self.height, self.width, filter_embeddings.shape[1]
        ).to(self.device)
