import weakref
from typing import Tuple

import torch
from torch import Tensor
import torch.nn as nn
from bort.models.layers import ConvLayer

class BaseVisualizer:
    """Base class for collecting per-layer information via forward hooks.

    ``wrap`` attaches ``forward_hook`` to every ``nn.Conv2d`` / ``ConvLayer``
    in a model. While the visualizer is enabled, each hook call is forwarded
    to ``process``, which subclasses override to record whatever they need
    (e.g. by appending to ``output_list``). Hooks can be detached again with
    ``remove_hooks``.
    """

    def __init__(self):
        # Results gathered by ``process``; reset on every ``enable`` call.
        self.output_list = []
        # Gate checked by ``forward_hook``; hooks are inert while False.
        self.is_enable = False
        # id(module) -> qualified name from ``named_modules`` (or the name
        # passed to ``register``), for use inside ``process``.
        self.register_dict = {}
        # module -> handle returned by ``register_forward_hook``. Keyed weakly
        # so hooked modules are not kept alive by the visualizer and a
        # recycled ``id()`` can never be mistaken for an already-hooked module.
        self.hook_handles = weakref.WeakKeyDictionary()

    def enable(self, is_enable: bool = False):
        """Switch collection on or off and discard previously collected outputs.

        Note: the default is ``False``, so a bare ``enable()`` call *disables*
        collection; pass ``True`` to turn it on.
        """
        self.output_list = []
        self.is_enable = is_enable

    def register(self, module: nn.Module, name: str) -> None:
        """Attach ``forward_hook`` to a single module under ``name``.

        A module that is already hooked by this visualizer is not hooked a
        second time (which would make ``process`` run twice per forward);
        only its recorded name is updated.
        """
        if module not in self.hook_handles:
            self.hook_handles[module] = module.register_forward_hook(self.forward_hook)
        self.register_dict[id(module)] = name

    def wrap(self, model: nn.Module) -> nn.Module:
        """Hook every ``nn.Conv2d`` / ``ConvLayer`` in ``model``; return ``model``."""
        for name, module in model.named_modules():
            if isinstance(module, (nn.Conv2d, ConvLayer)):
                self.register(module, name)
        return model

    def remove_hooks(self) -> None:
        """Detach every hook registered by this visualizer and forget module names."""
        # Snapshot first: removing handles while iterating a weak mapping is fragile.
        for handle in list(self.hook_handles.values()):
            handle.remove()
        self.hook_handles.clear()
        self.register_dict.clear()

    def process(self, module: nn.Module, input: Tuple[Tensor, ...], output: Tensor):
        """Handle one hooked forward call; no-op here, meant to be overridden.

        Args:
            module: the hooked module; ``self.register_dict[id(module)]`` gives
                its registered name.
            input: the positional arguments passed to ``module.forward``.
            output: the value returned by ``module.forward``.
        """
        pass

    def forward_hook(self, module: nn.Module, input: Tuple[Tensor, ...], output: Tensor):
        """Forward-hook callback: delegate to ``process`` only while enabled.

        Returns ``None``, so PyTorch leaves the module's output unchanged.
        """
        if self.is_enable:
            self.process(module, input, output)