import torch
from einops import rearrange


def sort_tensor_along_hw(tensor):
    """
    Sort the values of an (h, w, e) tensor over its flattened h*w positions.

    The h and w axes are merged (row-major) into a single axis of length h*w.
    Every slice along e is then sorted independently, in ascending order,
    along that merged axis. This is a per-column sort: entries from different
    e-columns are not moved together. An (h, w) position in the output
    therefore generally combines values that came from different positions
    in the input. The result is reshaped back to (h, w, e).

    Args:
        tensor (torch.Tensor): Three-dimensional input of shape (h, w, e).

    Returns:
        torch.Tensor: A new tensor of shape (h, w, e) with the sorted values;
        the input tensor is not modified.
    """
    height, width, channels = tensor.shape
    # Row-major flatten: grid position (i, j) becomes row i * width + j.
    flat = tensor.reshape(height * width, channels)
    ordered = torch.sort(flat, dim=0).values
    return ordered.reshape(height, width, channels)


import torch.nn as nn
from .grid_size import find_rectangle_dimensions
from ..grid.two_dimensional import BaseGrid2dConv, BaseGrid2dLinear
from typing import Union


def sort_layer(
    layer: Union[nn.Linear, nn.Conv2d],
) -> Union[nn.Linear, nn.Conv2d]:
    """
    Lay out a layer's output neurons on a 2D grid, sort it, and write it back.

    The number of output neurons is ``layer.weight.shape[0]``
    (``out_features`` for ``nn.Linear``, ``out_channels`` for ``nn.Conv2d``).
    ``find_rectangle_dimensions`` turns this count into a ``height x width``
    grid. The grid built by ``BaseGrid2dLinear`` / ``BaseGrid2dConv`` is passed
    through ``sort_tensor_along_hw``. That function sorts every trailing-axis
    column independently across the h*w grid positions. The result is
    reshaped to the weight's original shape and assigned to
    ``layer.weight.data``.

    Args:
        layer (Union[nn.Linear, nn.Conv2d]): Layer whose weights are
            rearranged. It is modified in place.

    Returns:
        Union[nn.Linear, nn.Conv2d]: The same ``layer`` object.

    Raises:
        AssertionError: If ``layer`` is neither ``nn.Linear`` nor ``nn.Conv2d``.
    """
    # A tuple is required: isinstance() only accepts typing.Union objects on
    # Python >= 3.10 and raises TypeError on older interpreters.
    assert isinstance(
        layer, (nn.Linear, nn.Conv2d)
    ), f"Expected nn.Linear or nn.Conv2d, but got {type(layer)}"

    weights = layer.weight.data
    original_shape = weights.shape
    # Dim 0 of the weight is the number of output neurons / channels.
    grid_size = find_rectangle_dimensions(area=original_shape[0])

    if isinstance(layer, nn.Linear):
        grid_cls = BaseGrid2dLinear
    elif isinstance(layer, nn.Conv2d):
        grid_cls = BaseGrid2dConv
    else:
        # Only reachable when assertions are stripped (python -O); without
        # this the function would fail later with an UnboundLocalError.
        raise TypeError(
            f"Expected nn.Linear or nn.Conv2d, but got {type(layer)}"
        )

    weights_reshaped = grid_cls(
        layer, grid_size.height, grid_size.width, device=weights.device
    ).grid
    weights_arranged = sort_tensor_along_hw(tensor=weights_reshaped)

    # Restore the exact original weight shape. Rebuilding it from
    # in_features / in_channels and kernel_size[0] would break for non-square
    # kernels and for grouped convolutions, whose weight shape is
    # (out_channels, in_channels // groups, kh, kw).
    # NOTE(review): assumes grid_size.height * grid_size.width equals
    # original_shape[0] and that the grid holds exactly weights.numel()
    # elements — confirm in find_rectangle_dimensions / BaseGrid2d*.
    layer.weight.data = weights_arranged.reshape(original_shape)
    return layer
