import collections
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from torch.jit import ScriptModule


class IngredientModelWrapper(nn.Module):
    """
    Inference-only wrapper combining a backbone and an ingredient codebook.

    The wrapper always behaves as if in evaluation mode: ``train()`` records
    the requested mode on the wrapper itself but forces every child module
    into eval mode.

    Two configurations are supported:

    * ``codebook_jit`` given: that codebook is used directly and its code
      matrix (``codebook_jit.codebook.codes.weight``) is registered as the
      ``codebook_tensor`` buffer.
    * ``codebook_jit`` is None: ``increment_codebooks`` must be a non-empty
      list of ``(codebook, adapter, ...)`` tuples. Their code matrices are
      concatenated into the first codebook, whose parameters are then frozen.
      Only the first adapter is used.

    Return (of ``forward``):
        cls_token: [bs, 1, dim]
        feat: [bs, L, dim]
        feat_origin: [bs, L, dim]
        ingredients: [bs, L]
        attn: [bs, L, L]
        attn_cls: [bs, L]
    """
    def __init__(
        self,
        backbone_jit: ScriptModule,
        codebook_jit: ScriptModule = None,
        increment_codebooks: List[Tuple] = None
    ):
        """
        Args:
            backbone_jit: module whose output dict provides the "mid_feat"
                and "extracted" entries read by ``forward``.
            codebook_jit: single codebook module. When given,
                ``increment_codebooks`` is ignored.
            increment_codebooks: ``(codebook, adapter, ...)`` tuples used
                only when ``codebook_jit`` is None.

        Raises:
            ValueError: if ``codebook_jit`` is None and
                ``increment_codebooks`` is None or empty.
        """
        super().__init__()
        self.backbone_jit = backbone_jit
        if codebook_jit is None:
            self.incremental = True
            self._init_incremental(increment_codebooks)
            return

        self.incremental = False
        self.codebook_jit = codebook_jit
        self.codebook_tensor: torch.Tensor

        self.register_buffer("codebook_tensor", codebook_jit.codebook.codes.weight)
        self.num_ingredients: int = self.codebook_tensor.shape[0]
        self.emb_dim: int = self.codebook_tensor.shape[1]

    def _init_incremental(self, increment_codebooks: List[Tuple]) -> None:
        """Merge all incremental codebooks into the first one and freeze it."""
        if not increment_codebooks:
            raise ValueError(
                "increment_codebooks must be a non-empty list of "
                "(codebook, adapter) tuples when codebook_jit is None"
            )
        # Only the first adapter is used; adapters of later increments are ignored.
        self.adapter = increment_codebooks[0][1]
        # NOTE(review): the first codebook is extended in place, so the
        # caller's module object is mutated as well.
        self.codebook = increment_codebooks[0][0]
        codes = self.codebook.codes
        # Concatenate the raw weights without recording an autograd graph.
        with torch.no_grad():
            for entry in increment_codebooks[1:]:
                extra = entry[0].codes
                self.codebook.size += extra.num_embeddings
                codes.num_embeddings += extra.num_embeddings
                codes.weight.data = torch.cat((codes.weight.data, extra.weight.data), dim=0)
        # Assigning ``module.requires_grad = False`` would only set a plain
        # attribute; requires_grad_ actually freezes every parameter.
        self.codebook.requires_grad_(False)
        # Use the merged size so increments of different sizes are counted correctly.
        self.num_ingredients = codes.num_embeddings
        self.emb_dim = codes.embedding_dim

    def train(self, mode: bool = True):
        """Record ``mode`` on the wrapper but keep every child in eval mode."""
        self.training = mode
        for module in self.children():
            module.train(False)
        return self

    def eval(self):
        """Equivalent to ``train(False)``."""
        return self.train(False)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Run the backbone and codebook, returning batch-first outputs.

        The layout comments below assume "mid_feat" and the codebook outputs
        are sequence-first with the CLS token at index 0, and that "extracted"
        is [bs * heads, L + 1, L + 1] -- TODO confirm against the backbone.

        Returns:
            OrderedDict with keys cls_token, feat, feat_origin, ingredients,
            attn and attn_cls (shapes listed in the class docstring); every
            tensor is made contiguous.
        """
        ret: Dict[str, torch.Tensor] = collections.OrderedDict()
        out_backbone = self.backbone_jit(x)
        mid_feat = out_backbone["mid_feat"]
        extracted_attn: torch.Tensor = out_backbone["extracted"]
        if self.incremental:
            feat, ingredients = self.codebook(self.adapter.adapt(mid_feat))
        else:
            feat, ingredients = self.codebook_jit(mid_feat)
        # [1, bs, dim] -> [bs, 1, dim]
        ret["cls_token"] = torch.transpose(feat[:1], 0, 1)
        # [L, bs, dim] -> [bs, L, dim]
        ret["feat"] = torch.transpose(feat[1:], 0, 1)
        ret["feat_origin"] = torch.transpose(mid_feat[1:], 0, 1)
        # [L, bs] -> [bs, L]
        ret["ingredients"] = torch.transpose(ingredients, 0, 1)
        # Attention averaged over heads; stays all-zero if none was extracted.
        bs, L = ret["ingredients"].shape
        attn = torch.zeros(bs, L + 1, L + 1, device=x.device)
        # [bs * heads, L + 1, L + 1] -> [bs, heads, L + 1, L + 1]
        if extracted_attn is not None:
            attn_heads = extracted_attn.unflatten(0, (bs, -1))
            torch.mean(attn_heads, dim=1, out=attn)
        # [bs, L, L]: token-to-token attention without the CLS row/column
        ret["attn"] = attn[..., 1:, 1:]
        # [bs, L]: CLS-to-token attention
        ret["attn_cls"] = attn[..., 0, 1:]
        for k, v in ret.items():
            ret[k] = v.contiguous()
        return ret

