# do this one properly

import torch
import torch.nn as nn
import torch
import torch.nn as nn
from functools import reduce
import torchvision.transforms as transforms
from tqdm import trange

class ForwardHook:
    """Keep the latest inputs and output seen by a module's forward pass.

    A forward hook is registered on ``module`` at construction. Every time the
    module runs, the hook overwrites ``self.input`` (the positional-argument
    tuple PyTorch passes to forward hooks) and ``self.output`` with the values
    from that call. Call :meth:`close` to unregister the hook.
    """

    def __init__(self, module):
        """Attach the capturing hook to ``module``.

        Args:
            module (nn.Module): the module to observe.
        """
        # Nothing has been captured until the module's first forward pass.
        self.input, self.output = None, None
        # Handle returned by ``register_forward_hook``; kept so close() can detach it.
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Hook callback: remember what the module just received and produced."""
        self.input, self.output = input, output

    def close(self):
        """Unregister the hook; the last captured values remain on the instance."""
        self.hook.remove()


def get_model_layer_by_name(model: nn.Module, layer_name: str):
    """Retrieve an object nested in ``model`` by its dotted access path.

    The path is split on ``"."`` and resolved one segment at a time with
    ``getattr``. Numeric segments such as the ``"0"`` in ``"features.0.conv"``
    therefore also reach children of an ``nn.Sequential``.

    Args:
        model (nn.Module): module whose submodule you want to access.
        layer_name (str): dotted path of the submodule, like
            ``"module.something.this_thing"``. An empty string returns
            ``model`` itself.

    Returns:
        object: whatever lives at ``layer_name`` (normally an ``nn.Module``,
        but any attribute, e.g. a parameter, is returned as-is).

    Raises:
        AttributeError: if a segment of ``layer_name`` cannot be resolved. The
            message names the failing segment and the full path.
    """
    if layer_name == "":
        return model

    current = model
    for segment in layer_name.split("."):
        try:
            current = getattr(current, segment)
        except AttributeError as err:
            # Keep the AttributeError type callers may rely on, but say which
            # part of a possibly long path was wrong.
            raise AttributeError(
                f"could not resolve '{segment}' while looking up "
                f"'{layer_name}' in {type(model).__name__}"
            ) from err
    return current

class IntermediateOutputExtractor(nn.Module):
    """Run a model and return the output of one of its submodules.

    A ``ForwardHook`` is attached to ``forward_hook_layer`` when the extractor
    is constructed. It stays attached to that layer until
    ``self.forward_hook.close()`` is called.
    """

    def __init__(
        self,
        model,
        forward_hook_layer,
    ):
        """An over-engineered forward hook for pytorch.

        Note: ``model.eval()`` is called here, so the passed model is switched
        to eval mode in place.

        Args:
            model (nn.Module): the model whose intermediate outputs you'd want to extract
            forward_hook_layer (nn.Module): the submodule (see: layer) within the model whose outputs you want
        """
        super().__init__()
        self.model = model.eval()
        self.forward_hook = ForwardHook(forward_hook_layer)
        self.forward_hook_layer = forward_hook_layer

    def forward(self, x):
        """Run a forward pass through the model and return the hooked layer's output.

        Args:
            x (torch.tensor): input tensor

        Returns:
            torch.tensor: the output captured by the forward hook during this call

        Raises:
            AssertionError: if ``forward_hook_layer`` did not run during this
                forward pass.
        """
        # Clear any previous capture so a pass that never reaches the hooked
        # layer is detected, instead of silently returning an earlier call's output.
        self.forward_hook.input = None
        self.forward_hook.output = None

        self.model(x)  # the model's own output is not needed, only the hook's capture

        if self.forward_hook.output is None:
            # Raised explicitly rather than via ``assert`` so the check survives
            # ``python -O``; the type stays AssertionError for existing callers.
            raise AssertionError(
                'expected hook output to NOT be None, please check if your input tensor is passing through the forward_hook_layer or not'
            )
        return self.forward_hook.output

    @classmethod
    def from_layer_name(cls, model, layer_name: str):
        """Neat way to directly initiate this class with a model and the name of the layer as a string.

        Args:
            model (nn.Module): the model whose intermediate outputs you'd want to extract
            layer_name (str): dotted name of the submodule (see: layer) within
                the model whose outputs you want, e.g. ``"layer4.2.conv3"``

        Returns:
            IntermediateOutputExtractor: an instance of the class
        """
        forward_hook_layer = get_model_layer_by_name(model=model, layer_name=layer_name)
        return cls(model=model, forward_hook_layer=forward_hook_layer)

    def __call__(self, x):
        # Calls forward() directly, bypassing nn.Module.__call__; hooks registered
        # on the extractor itself therefore do not run.
        return self.forward(x)

    @torch.no_grad()
    def get_output_info(self, x):
        """Describe the hooked layer's output for a single-sample input.

        Args:
            x (torch.tensor): input tensor whose dim 0 (batch) has size 1

        Returns:
            dict: ``{"shape": <shape of the hooked output>,
            "layer_type": "Conv2d" | "Linear" | "unknown"}``

        Raises:
            AssertionError: if ``x.shape[0] != 1`` or the hooked layer did not run.
        """
        if x.shape[0] != 1:
            raise AssertionError(f"Expected batch size to be 1, and not: {x.shape[0]}")
        # NOTE(review): ``y.shape`` assumes the hooked layer returns a single
        # tensor (not a tuple) — confirm for layers other than Conv2d/Linear.
        y = self.forward(x)

        if isinstance(self.forward_hook_layer, nn.Conv2d):
            layer_type = "Conv2d"
        elif isinstance(self.forward_hook_layer, nn.Linear):
            layer_type = "Linear"
        else:
            layer_type = "unknown"

        return {
            "shape": y.shape,
            "layer_type": layer_type,
        }

def get_intermediate_layer_activations(model, layer_name, input_images, batch_size=1):
    """Collect the outputs of ``layer_name`` for every item in ``input_images``.

    ``input_images`` is sliced along dim 0 into chunks of ``batch_size``. Each
    chunk is run through ``model`` under ``torch.no_grad()``, and the hooked
    layer's per-chunk outputs are concatenated along dim 0. The temporary hook
    is always removed from the layer before returning, even if a forward pass
    raises.

    Args:
        model (nn.Module): model to run (e.g. an rn50). It is switched to eval
            mode in place by the extractor.
        layer_name (str): dotted layer name, e.g. ``"layer4.2.conv3"``.
        input_images (torch.Tensor): inputs stacked along dim 0; every slice
            must be a valid input for ``model``. NOTE(review): an earlier note
            said the layout is b, h, w, c, but torchvision-style models usually
            expect b, c, h, w — confirm against the caller.
        batch_size (int): number of images per forward pass (default 1, the
            previous hard-coded value).

    Returns:
        torch.Tensor: the hooked layer's outputs concatenated along dim 0.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        RuntimeError: from ``torch.cat`` if ``input_images`` is empty.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    extractor = IntermediateOutputExtractor.from_layer_name(
        model=model,
        layer_name=layer_name,
    )

    activations = []
    try:
        with torch.no_grad():
            for start in trange(0, input_images.size(0), batch_size):
                batch_images = input_images[start:start + batch_size]
                activations.append(extractor.forward(batch_images))
    finally:
        # Without this, every call leaves another hook registered on the
        # model's layer, and each one keeps a reference to its last output.
        extractor.forward_hook.close()

    return torch.cat(activations, dim=0)