"""
save_decoupled_jit.py

This script saves a codebook model as separate traced jit models: `encoder`, `codebook`, and `decoder`.

Warning:
    To save memory, each output jit model is traced as a static function and therefore
    no longer supports access to its parameters (they become constants) or submodules, and can never be trained.
    However, if input tensor `requires_grad == True`, the gradients can be correctly computed.
"""

import argparse
import os
import collections

import torch
import torch.nn as nn
import torch.jit
from torch.utils.hooks import RemovableHandle

from cv_lib.utils import get_cfg

from models import get_model, ModelWrapper
from dark_kg.utils import load_pretrain_model
from codebook import Codebook
from codebook.adapter import get_adapter, Adapter


class Decoupling:
    """Manage the forward hooks that split a model at ``encode_layer``.

    Two modes are available, each installed by its own ``register_*`` method
    (installing one first removes whatever the other installed):

    * backbone mode stores the output of ``encode_layer`` in ``mid_feat`` and,
      when ``extract_layer`` names a module, the output of that module in
      ``extracted``;
    * classification-header mode makes ``encode_layer`` output whatever is
      currently held in ``discrete_feat``.

    ``codebook`` and ``adapter`` are only stored on the instance so that
    callers can reach them through this object.
    """

    def __init__(
        self,
        encode_layer: str,
        codebook: "Codebook",
        adapter: "Adapter",
        extract_layer: str = None
    ):
        """
        Args:
            encode_layer: layer name where the codebook will apply to its output
            codebook: codebook module, kept as ``self.codebook``
            adapter: adapter object, kept as ``self.adapter``
            extract_layer: extract feature from given layer name
        """
        super().__init__()
        self.encode_layer = encode_layer
        self.extract_layer = extract_layer
        self.codebook = codebook
        self.adapter = adapter

        # Handles of the currently installed hooks (None when not installed).
        self.hook: RemovableHandle = None
        self.extract_layer_hook: RemovableHandle = None
        # Tensors captured by / fed to the hooks.
        self.mid_feat: torch.Tensor = None
        self.discrete_feat: torch.Tensor = None
        self.extracted: torch.Tensor = None
        self.match = None

    @staticmethod
    def _named_modules(model: nn.Module) -> dict:
        """Map module names to modules, looking through a DDP wrapper."""
        if isinstance(model, nn.parallel.DistributedDataParallel):
            model = model.module
        return dict(model.named_modules())

    def _encode_module(self, model: nn.Module) -> nn.Module:
        """Return the module named ``encode_layer`` or raise ``ValueError``."""
        modules = self._named_modules(model)
        if self.encode_layer not in modules:
            raise ValueError(f"layer {self.encode_layer!r} not found in model")
        return modules[self.encode_layer]

    def register_backbone_hooks(self, model: nn.Module) -> None:
        """Install the capturing hooks on ``model``.

        Any previously installed hooks are removed first. After the next
        forward pass of ``model``, ``mid_feat`` holds the output of
        ``encode_layer`` and ``extracted`` holds the output of
        ``extract_layer`` (it stays ``None`` if no module has that name).

        Raises:
            ValueError: if ``model`` has no module named ``encode_layer``.
        """
        self.clear()
        encode_module = self._encode_module(model)

        def capture_mid_feat(module, input, output):
            self.mid_feat = output

        self.hook = encode_module.register_forward_hook(capture_mid_feat)

        # `extract_layer` is optional: a None or unknown name installs nothing.
        extract_module = self._named_modules(model).get(self.extract_layer)
        if extract_module is not None:
            def capture_extracted(module, input, output):
                self.extracted = output

            self.extract_layer_hook = extract_module.register_forward_hook(capture_extracted)

    def register_cls_header_hooks(self, model: nn.Module) -> None:
        """Install the hook that overrides the output of ``encode_layer``.

        Any previously installed hooks are removed first. Because ``clear``
        resets ``discrete_feat``, assign it *after* calling this method. While
        ``discrete_feat`` is ``None`` the hook returns ``None``, which leaves
        the layer output untouched.

        Raises:
            ValueError: if ``model`` has no module named ``encode_layer``.
        """
        self.clear()
        encode_module = self._encode_module(model)

        def inject_discrete_feat(module, input, output):
            # A non-None return value of a forward hook replaces the output.
            return self.discrete_feat

        self.hook = encode_module.register_forward_hook(inject_discrete_feat)

    def clear(self) -> None:
        """Remove every installed hook and drop all cached tensors."""
        # Detach the hooks from the model *before* forgetting their handles,
        # otherwise they would stay registered forever.
        if self.hook is not None:
            self.hook.remove()
        if self.extract_layer_hook is not None:
            self.extract_layer_hook.remove()
        self.hook = None
        self.extract_layer_hook = None

        self.mid_feat = None
        self.discrete_feat = None
        self.extracted = None
        self.match = None


class JitWrapper:
    """Expose pieces of a hooked model as plain callables for tracing.

    Every method (re-)installs the hooks it needs on the shared ``model``
    through ``decoupling`` before running the model, so the methods can be
    called in any order.
    """

    def __init__(self, model: ModelWrapper, decoupling: Decoupling, model_input: torch.Tensor):
        """
        Args:
            model: the full model the hooks are attached to
            decoupling: hook manager holding the codebook and the adapter
            model_input: input fed to ``model`` by ``cls_header_forward``
        """
        self.model = model
        self.decoupling = decoupling
        self.model_input: torch.Tensor = model_input

    def backbone_forward(self, dummy_input: torch.Tensor):
        """Run the model on ``dummy_input`` and return the captured features.

        Returns:
            An ordered dict with key ``"mid_feat"`` and, only when the
            decoupling captured something for it, key ``"extracted"``.
        """
        hooks = self.decoupling
        hooks.register_backbone_hooks(self.model)
        # The model output itself is discarded; only the hooked values matter.
        self.model(dummy_input)

        captured = collections.OrderedDict([("mid_feat", hooks.mid_feat)])
        side_feature = hooks.extracted
        if side_feature is None:
            return captured
        captured["extracted"] = side_feature
        return captured

    def cls_header_forward(self, dummy_input: torch.Tensor):
        """Run the model on the stored input with ``dummy_input`` injected.

        ``dummy_input`` is handed to the decoupling as ``discrete_feat`` after
        the classification-header hook has been installed.
        """
        hooks = self.decoupling
        hooks.register_cls_header_hooks(self.model)
        hooks.discrete_feat = dummy_input
        model_output = self.model(self.model_input)
        return model_output

    def backbone_codebook_forward(self, dummy_input: torch.Tensor):
        """Chain backbone, adapter and codebook; return the first result only.

        The second value produced by the codebook / adapter is dropped.
        """
        captured = self.backbone_forward(dummy_input)
        adapter = self.decoupling.adapter
        adapted = adapter.adapt(captured["mid_feat"])
        coded, matched = self.decoupling.codebook(adapted)
        restored, _ = adapter.reconstruct(coded, matched)
        return restored


class CodebookJitWrapper(nn.Module):
    """Module bundling the adapter and the codebook into one forward pass."""

    def __init__(self, codebook: Codebook, adapter: Adapter):
        """
        Args:
            codebook: callable returning a pair of values for an adapted input
            adapter: object providing ``adapt`` and ``reconstruct``
        """
        super().__init__()
        # Assignment order is kept: it fixes the submodule registration order.
        self.codebook = codebook
        self.adapter = adapter

    def forward(self, dummy_input: torch.Tensor):
        """Adapt the input, run the codebook, then reconstruct both results.

        Returns:
            The two values produced by ``adapter.reconstruct`` as a tuple.
        """
        coded, matched = self.codebook(self.adapter.adapt(dummy_input))
        restored = self.adapter.reconstruct(coded, matched)
        first, second = restored
        return first, second


if __name__ == "__main__":
    # Command line: the three paths below are mandatory because the script
    # cannot do anything useful (and fails with an obscure error) without them.
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg_fp", type=str, required=True,
                        help="path of the config file")
    parser.add_argument("--ckpt_fp", type=str, required=True,
                        help="path of the model checkpoint")
    # NOTE(review): left optional because it is passed straight to
    # `Codebook.initial_codes`; confirm whether None is acceptable there.
    parser.add_argument("--codebook_fp", type=str,
                        help="path passed to the codebook's `initial_codes`")
    parser.add_argument("--save_path", type=str, required=True,
                        help="directory the traced models are written to")
    parser.add_argument("--num_classes", type=int, default=100)
    parser.add_argument("--img_size", type=int, default=224)
    parser.add_argument("--img_channels", type=int, default=3)
    parser.add_argument("--extract_layer", type=str, default=None,
                        help="optional layer name whose output is also returned by the backbone")
    args = parser.parse_args()

    os.makedirs(args.save_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The main config points at a separate model config via its "model" entry.
    cfg = get_cfg(args.cfg_fp)
    model_cfg = get_cfg(cfg["model"])
    codebook_cfg = cfg["codebook"]

    # Build model, codebook and adapter, then tie them together.
    model: ModelWrapper = get_model(model_cfg["model"], args.num_classes, with_wrapper=True)
    codebook = Codebook(**codebook_cfg["codebook_cfg"])
    adapter = get_adapter(codebook_cfg["adapter_name"])
    decoupling = Decoupling(
        codebook_cfg["encoder_layer"],
        codebook,
        adapter,
        extract_layer=args.extract_layer
    )

    # load state dict
    ckpt = torch.load(args.ckpt_fp, map_location="cpu")
    load_pretrain_model(ckpt, model)
    decoupling.codebook.initial_codes(args.codebook_fp)

    # Random image-shaped tensor used as the example input for tracing.
    model_input = torch.randn(1, args.img_channels, args.img_size, args.img_size).to(device)
    jit_wrapper = JitWrapper(model, decoupling, model_input)
    codebook_jit_wrapper = CodebookJitWrapper(codebook, adapter)

    # Inference mode, no parameter gradients, everything on the same device.
    model.eval().requires_grad_(False).to(device)
    codebook.eval().requires_grad_(False).to(device)
    codebook_jit_wrapper.eval().requires_grad_(False).to(device)

    # get mid seq
    # (the intermediate feature is the example input for the codebook and the
    # classification header traces below)
    mid_feat = jit_wrapper.backbone_forward(model_input)["mid_feat"]

    # tracing
    # Each `jit_wrapper` method re-registers its own hooks on the shared model
    # before running it, so the traces are produced one after another.
    backbone_jit: torch.jit.ScriptModule = torch.jit.trace(
        jit_wrapper.backbone_forward,
        (model_input,),
        strict=False
    )
    codebook_jit: torch.jit.ScriptModule = torch.jit.trace(
        codebook_jit_wrapper,
        (mid_feat,),
        strict=False
    )
    cls_header_jit: torch.jit.ScriptModule = torch.jit.trace(
        jit_wrapper.cls_header_forward,
        (mid_feat,),
        strict=False
    )
    backbone_codebook_jit: torch.jit.ScriptModule = torch.jit.trace(
        jit_wrapper.backbone_codebook_forward,
        (model_input,),
        strict=False
    )

    # Write the four traced artifacts into the output directory.
    torch.jit.save(backbone_jit, os.path.join(args.save_path, "backbone-jit.pth"))
    torch.jit.save(cls_header_jit, os.path.join(args.save_path, "cls_header-jit.pth"))
    torch.jit.save(codebook_jit, os.path.join(args.save_path, "codebook-jit.pth"))
    torch.jit.save(backbone_codebook_jit, os.path.join(args.save_path, "backbone_codebook-jit.pth"))

