from nesim.utils.grid_size import find_rectangle_dimensions
import torch.nn.functional as F
from einops import rearrange
import torch.nn as nn
import torch
from ..grid.two_dimensional import BaseGrid2dConv

class DownsampledConv2d(nn.Module):
    """Conv2d stand-in that convolves with a spatially downsampled filter grid.

    The filters of ``conv_layer`` are arranged on a 2D grid by
    ``BaseGrid2dConv`` (presumably one grid cell per output channel -- confirm
    against that class), the grid is shrunk with bilinear interpolation, and
    the shrunken grid is stored as the trainable convolution weight.  At
    forward time the convolution runs with the reduced filter set and the
    resulting channel grid is upsampled (nearest neighbour) back to the
    original number of output channels before the bias is added.

    The total parameter count shrinkage is roughly equal to:
    factor_h * factor_w
    """

    def __init__(
        self,
        conv_layer: nn.Conv2d,
        factor_h: int = 2,
        factor_w: int = 2,
        device: str = "cuda:0",
    ):
        """Build the compressed layer from an existing convolution.

        Args:
            conv_layer: Source convolution; its weight, bias, stride, padding,
                dilation and groups are read.
            factor_h: Downsampling factor along the grid height.
            factor_w: Downsampling factor along the grid width.
            device: Device on which the compressed weight and bias are placed.
        """
        super().__init__()

        # weight.shape: output, input, kernel_h, kernel_w
        out_channels, _, kernel_h, kernel_w = conv_layer.weight.shape
        size = find_rectangle_dimensions(area=out_channels)
        grid = BaseGrid2dConv(
            conv_layer=conv_layer,
            height=size.height,
            width=size.width,
            device=device,
        ).grid
        self.in_channels = conv_layer.in_channels
        # (h, w, e) -> (1, e, h, w) so F.interpolate treats h/w as spatial dims.
        grid = rearrange(grid, "h w e -> e h w").unsqueeze(0)

        downsampled_grid = F.interpolate(
            grid, scale_factor=(1 / factor_h, 1 / factor_w), mode="bilinear"
        ).squeeze(0)

        # Remember the grid shape interpolate really produced.  The compressed
        # channel axis below is flattened as (h w) from exactly this shape, so
        # forward_compressed must unflatten with the same numbers; re-deriving
        # them from the area alone can yield a different factorisation
        # (e.g. when factor_h != factor_w) and scramble the layout.
        self._small_grid_hw = (
            downsampled_grid.shape[-2],
            downsampled_grid.shape[-1],
        )

        # The tensor is moved to the device *before* wrapping so the attribute
        # stays a registered leaf nn.Parameter.
        self.downsampled_weight = nn.Parameter(
            rearrange(
                downsampled_grid,
                "(i kernel_h kernel_w) h w -> (h w) i kernel_h kernel_w",
                kernel_h=kernel_h,
                kernel_w=kernel_w,
            ).to(device)
        )

        # Kept for backward compatibility with code that reads this attribute;
        # the forward pass relies on ``_small_grid_hw`` instead.
        self.small_grid_size = find_rectangle_dimensions(
            area=self.downsampled_weight.shape[0]
        )
        self.num_output_neurons = out_channels

        self.bias = self._make_bias(conv_layer, device)
        self.size = size
        self.factor_h = factor_h
        self.factor_w = factor_w

        self.conv_layer_params = {
            "stride": conv_layer.stride,
            "padding": conv_layer.padding,
            "dilation": conv_layer.dilation,
            "groups": conv_layer.groups,
        }

    @staticmethod
    def _make_bias(conv_layer: nn.Conv2d, device) -> "nn.Parameter | None":
        """Return a broadcast-ready ``(1, C, 1, 1)`` bias Parameter, or None.

        The reshape happens before wrapping in ``nn.Parameter``: calling
        ``unsqueeze`` on a Parameter returns a plain tensor, which the module
        would neither register nor expose to optimizers / ``state_dict``.
        The values are cloned so the new layer does not share storage with
        ``conv_layer.bias``.
        """
        if conv_layer.bias is None:
            return None
        values = conv_layer.bias.detach().clone().to(device)
        return nn.Parameter(values.reshape(1, -1, 1, 1))

    def forward_compressed(self, x):
        """Convolve with the reduced filters and expand back to full channels.

        Args:
            x: Input accepted by ``torch.nn.functional.conv2d`` together with
                the stored weight and convolution settings.

        Returns:
            Tensor with ``num_output_neurons`` channels; the bias is added
            when the source layer had one.
        """
        y_before_bias = F.conv2d(
            input=x,
            weight=self.downsampled_weight,
            stride=self.conv_layer_params["stride"],
            padding=self.conv_layer_params["padding"],
            dilation=self.conv_layer_params["dilation"],
            groups=self.conv_layer_params["groups"],
        )

        output_height = y_before_bias.shape[2]
        output_width = y_before_bias.shape[3]
        small_grid_h, small_grid_w = self._small_grid_hw

        # Swap roles: the compressed channel grid becomes the "image" and the
        # spatial positions become channels, so interpolate can resize the grid.
        y_before_bias = rearrange(
            y_before_bias,
            "batch (small_grid_h small_grid_w) h w -> batch (h w) small_grid_h small_grid_w",
            small_grid_h=small_grid_h,
            small_grid_w=small_grid_w,
        )

        # batch, (h w), small_grid_h, small_grid_w -> batch, (h w), grid_h, grid_w
        y_before_bias_upsampled = F.interpolate(
            y_before_bias, size=(self.size.height, self.size.width), mode="nearest"
        )

        # Swap back: grid cells become output channels again.
        y_before_bias_upsampled = rearrange(
            y_before_bias_upsampled,
            "batch (h w) original_grid_h original_grid_w -> batch (original_grid_h original_grid_w) h w",
            h=output_height,
            w=output_width,
        )

        # Internal sanity check: the upsampled grid must cover every original
        # output channel exactly once.
        assert y_before_bias_upsampled.shape[1] == self.num_output_neurons

        if self.bias is None:
            return y_before_bias_upsampled
        return y_before_bias_upsampled + self.bias

    def forward(self, x):
        """Run the compressed forward pass (alias of ``forward_compressed``)."""
        return self.forward_compressed(x)

    def __repr__(self):
        # NOTE: out_channels here is the *compressed* filter count (weight
        # dim 0), not ``num_output_neurons``.
        return f"DownsampledConv2d(in_channels={self.downsampled_weight.shape[1]}, out_channels={self.downsampled_weight.shape[0]})"

    @property
    def weight(self):
        """The compressed (downsampled) convolution weight Parameter."""
        return self.downsampled_weight