import torch.nn as nn
import torch
from .hook import ForwardHook
from .getting_modules import get_model_layer_by_name

class IntermediateOutputExtractor(nn.Module):
    """Wrap a model and return the output captured at one of its submodules.

    A ``ForwardHook`` is attached to ``forward_hook_layer``; every forward
    pass runs the full wrapped model and then returns whatever the hook
    stored in its ``output`` attribute.

    ``__call__`` is intentionally NOT overridden: ``nn.Module.__call__``
    already dispatches to ``forward`` and additionally runs any hooks
    registered on this wrapper itself.
    """

    def __init__(
        self,
        model: nn.Module,
        forward_hook_layer: nn.Module,
    ):
        """An over-engineered forward hook for pytorch.

        Args:
            model (nn.Module): the model whose intermediate outputs you'd want to extract.
                Note that it is switched to eval mode in place via ``model.eval()``.
            forward_hook_layer (nn.Module): the submodule (see: layer) within the model whose outputs you want
        """
        super().__init__()
        self.model = model.eval()
        self.forward_hook = ForwardHook(forward_hook_layer)
        self.forward_hook_layer = forward_hook_layer

    def forward(self, x):
        """Runs a forward pass through the model and returns the forward hook output

        Args:
            x (torch.tensor): input tensor

        Returns:
            torch.tensor: forward hook output

        Raises:
            AssertionError: if the hook captured nothing during this pass, i.e.
                the input did not flow through ``forward_hook_layer``.
        """
        # Drop whatever a previous pass captured, otherwise a pass that skips
        # the hooked layer would silently return stale data.
        # NOTE(review): assumes ForwardHook.output is a plain writable
        # attribute -- confirm against the implementation in .hook
        self.forward_hook.output = None

        # The model's own return value is irrelevant; only the hook's capture matters.
        self.model(x)

        captured = self.forward_hook.output
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped when Python runs with -O.
        if captured is None:
            raise AssertionError(
                'expected hook output to NOT be None, please check if your input tensor is passing through the forward_hook_layer or not'
            )
        return captured

    @classmethod
    def from_layer_name(cls, model, layer_name: str):
        """Neat way to directly initiate this class with a model and the name of the layer as a string.

        Args:
            model (nn.Module): the model whose intermediate outputs you'd want to extract
            layer_name (str): name of the submodule (see: layer) within the model whose outputs you want,
                resolved through ``get_model_layer_by_name``

        Returns:
            IntermediateOutputExtractor: an instance of the class
        """
        forward_hook_layer = get_model_layer_by_name(model=model, layer_name=layer_name)

        return cls(model=model, forward_hook_layer=forward_hook_layer)

    @torch.no_grad()
    def get_output_info(self, x):
        """Runs a single-sample forward pass and describes the hooked layer's output.

        Gradients are disabled for the duration of the call.

        Args:
            x (torch.tensor): input tensor whose first dimension (batch size) must be 1

        Returns:
            dict: ``"shape"`` holds the ``.shape`` of the hook output and
                ``"layer_type"`` is ``"Conv2d"``, ``"Linear"`` or ``"unknown"``
                depending on the type of ``forward_hook_layer``.

        Raises:
            AssertionError: if the batch size is not 1, or if the hook captured
                nothing during the pass.
        """
        batch_size = x.shape[0]
        # Explicit raise so the validation survives `python -O`.
        if batch_size != 1:
            raise AssertionError(f"Expected batch size to be 1, and not: {batch_size}")

        y = self.forward(x)

        if isinstance(self.forward_hook_layer, nn.Conv2d):
            layer_type = "Conv2d"
        elif isinstance(self.forward_hook_layer, nn.Linear):
            layer_type = "Linear"
        else:
            layer_type = "unknown"

        return {
            "shape": y.shape,
            "layer_type": layer_type,
        }