import torch.nn as nn
import torch
from pydantic import BaseModel
import torch.nn.functional as F
from .getting_modules import get_module_by_name
from .hook import ForwardHook


class FilterLocation(BaseModel, extra="forbid"):
    """
    Location of one 2D kernel inside a Conv2d weight tensor.

    The pair selects ``weight[output_channel_idx, input_channel_idx, :, :]``
    of the conv layer (as used by ``SingleConvFilterOutputExtractor.extract``).
    Because the model is declared with ``extra="forbid"``, pydantic rejects
    any field other than the two below.

    `input_channel_idx` (int): index of input channel (dim 1 of the conv
    weight, and the channel dim of the layer input)

    `output_channel_idx` (int): index of output channel (dim 0 of the conv
    weight)
    """

    # Input channel index: dim 1 of the conv weight / channel dim of the input.
    input_channel_idx: int
    # Output channel index: dim 0 of the conv weight.
    output_channel_idx: int


class SingleConvFilterOutputExtractor:
    """
    Isolate the output of a single 2D kernel of an ``nn.Conv2d`` layer.

    The constructor runs ``model`` once on ``input_tensor`` and caches, as a
    detached tensor, the input that reached the conv layer named
    ``conv_layer_name``. ``extract`` then convolves one channel of that cached
    input with one kernel ``weight[output_channel_idx, input_channel_idx]`` of
    the layer, reusing the layer's stride, padding and dilation.
    """

    def __init__(
        self,
        model: nn.Module,
        input_tensor,
        conv_layer_name: str,
    ):
        """
        Args:
            model: model that contains the conv layer.
            input_tensor: passed unchanged to ``model(...)`` for one forward
                pass under ``torch.no_grad()``.
            conv_layer_name: name of the layer, resolved with
                ``get_module_by_name``; it must be an ``nn.Conv2d``.

        Raises:
            AssertionError: if the named module is not an ``nn.Conv2d`` or the
                captured layer input is not 4-dimensional.
        """
        conv_layer = get_module_by_name(module=model, name=conv_layer_name)
        assert isinstance(conv_layer, nn.Conv2d)
        self.conv_layer_name = conv_layer_name
        self.conv_layer = conv_layer
        self.model = model

        hook = ForwardHook(module=conv_layer)
        try:
            # The forward pass only exists so the hook sees the layer input;
            # the model output itself is not needed.
            with torch.no_grad():
                self.model(input_tensor)
            # Presumably ForwardHook stores the positional forward inputs in
            # `.input`; Conv2d.forward takes a single tensor, hence index 0.
            layer_input = hook.input[0]
        finally:
            # Close the hook even if the forward pass or the lookup raises.
            hook.close()

        # Conv2d activations are expected to be batched: (N, C, H, W).
        assert (
            layer_input.ndim == 4
        ), f"Expected a 4d tensor but got: {layer_input.ndim} dims"
        self.layer_input = layer_input.detach()

    @torch.no_grad()
    def extract(self, filter_location: FilterLocation):
        """
        Convolve one cached input channel with one kernel of the conv layer.

        Args:
            filter_location: selects the kernel
                ``weight[output_channel_idx, input_channel_idx]``; negative
                indices count from the end, as in torch indexing.

        Returns:
            Tensor of shape ``(N, 1, H_out, W_out)``, where ``N`` is the batch
            size of the cached layer input. The layer's stride, padding,
            padding mode and dilation are applied. If the layer has a bias,
            only the selected output channel's bias entry is added.

        Raises:
            AssertionError: if ``filter_location`` is not a ``FilterLocation``
                or either channel index is out of range.
            NotImplementedError: if the layer is a grouped convolution
                (``groups != 1``), where ``weight[out, in]`` does not act on
                input channel ``in``.
        """
        assert isinstance(filter_location, FilterLocation)

        conv = self.conv_layer
        in_idx = filter_location.input_channel_idx
        out_idx = filter_location.output_channel_idx

        # Valid torch indices lie in [-size, size); `idx == size` is rejected.
        num_in = self.layer_input.shape[1]
        assert (
            -num_in <= in_idx < num_in
        ), f"Invalid input channel idx: {filter_location.input_channel_idx}"
        num_out = conv.weight.shape[0]
        assert (
            -num_out <= out_idx < num_out
        ), f"Invalid output channel idx: {filter_location.output_channel_idx}"

        if conv.groups != 1:
            # With groups > 1 the weight is (out, in // groups, kH, kW), so its
            # dim-1 index is relative to the output channel's group.
            raise NotImplementedError(
                f"Grouped convolutions are not supported (groups={conv.groups})"
            )

        # (N, 1, H, W): the single selected input channel.
        x = self.layer_input[:, in_idx, :, :].unsqueeze(1)
        # (1, 1, kH, kW): the single selected kernel.
        weight = conv.weight[out_idx, in_idx, :, :].unsqueeze(0).unsqueeze(0)
        # The kernel has one output channel, so only the matching bias entry
        # applies; passing the full (out_channels,) bias would fail in conv2d.
        bias = None if conv.bias is None else conv.bias[out_idx].reshape(1)

        if conv.padding_mode != "zeros":
            # Mirror nn.Conv2d._conv_forward: non-"zeros" padding modes are
            # applied with F.pad, followed by an unpadded convolution.
            # NOTE: `_reversed_padding_repeated_twice` is a private attribute.
            x = F.pad(
                x, conv._reversed_padding_repeated_twice, mode=conv.padding_mode
            )
            padding = 0
        else:
            padding = conv.padding

        return F.conv2d(
            input=x,
            weight=weight,
            bias=bias,
            stride=conv.stride,
            padding=padding,
            dilation=conv.dilation,
            # A single-input / single-output kernel is always ungrouped.
            groups=1,
        )
