from typing import Dict, Optional

import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle

from .codebook import Codebook
from .adapter import Adapter


class MidEncoder:
    """Route the output of one named layer of a model through a codebook.

    A forward hook is attached to the sub-module named ``encode_layer``. Each
    time that layer runs, its output is passed through ``adapter.adapt``, then
    ``codebook``, then ``adapter.reconstruct``. The reconstructed tensor is
    returned from the hook, so it replaces the layer's output for the rest of
    the forward pass. The intermediate values of the most recent call are
    kept in ``mid_dict``.
    """

    def __init__(
        self,
        model: nn.Module,
        encode_layer: str,
        codebook: "Codebook",
        adapter: "Adapter"
    ):
        """
        Args:
            model: model to hook. A ``DistributedDataParallel`` wrapper is
                unwrapped, so layer names are resolved on the inner module.
            encode_layer: layer name (as yielded by ``named_modules()``) where
                the codebook will apply to its output.
            codebook: callable mapping the adapted output to ``(output, match)``.
            adapter: object providing ``adapt(output)``, applied before the
                codebook, and ``reconstruct(output, match)``, applied after it.

        Raises:
            ValueError: if ``encode_layer`` does not name a sub-module of ``model``.
        """
        super().__init__()
        self.encode_layer = encode_layer
        self.codebook = codebook
        self.adapter = adapter

        # Written by the forward hook on every pass through encode_layer:
        #   origin_seq  - the layer's raw output
        #   encoded_seq - the output after adapt -> codebook -> reconstruct
        #   match       - the match value returned by adapter.reconstruct
        # Initialised before the hook is registered, because the hook writes here.
        self.mid_dict: Dict[str, Optional[torch.Tensor]] = {
            "origin_seq": None,
            "encoded_seq": None,
            "match": None
        }
        self.hook: RemovableHandle = self.register_forward_hooks(model)

    def register_forward_hooks(self, model: nn.Module) -> RemovableHandle:
        """Attach the encoding forward hook to the layer ``self.encode_layer``.

        Args:
            model: model containing the target layer, optionally wrapped in
                ``DistributedDataParallel``.

        Returns:
            The handle of the registered hook; ``clear()`` uses it to detach.

        Raises:
            ValueError: if no sub-module of ``model`` is named ``self.encode_layer``.
        """
        raw_model = model
        if isinstance(model, nn.parallel.DistributedDataParallel):
            # named_modules() on the wrapper would prefix every name with
            # "module.", so resolve names on the wrapped model instead.
            raw_model = model.module

        target = None
        for name, module in raw_model.named_modules():
            if name == self.encode_layer:
                target = module
                break
        # `is None`, not truthiness: an empty nn.Sequential is falsy.
        if target is None:
            # Fail fast rather than silently returning no handle.
            raise ValueError(
                f"encode_layer {self.encode_layer!r} not found in model"
            )

        def forward_hook(_module, _inputs, output):
            self.mid_dict["origin_seq"] = output
            output = self.adapter.adapt(output)
            output, match = self.codebook(output)
            output, match = self.adapter.reconstruct(output, match)
            self.mid_dict["encoded_seq"] = output
            self.mid_dict["match"] = match
            # A non-None return value from a forward hook replaces the
            # module's output.
            return output

        return target.register_forward_hook(forward_hook)

    def clear(self):
        """Detach the forward hook and reset every ``mid_dict`` entry to ``None``."""
        self.hook.remove()
        for key in self.mid_dict:
            self.mid_dict[key] = None

