# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# adapted from:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from math import prod
import os
from functools import partial

import timm.models.vision_transformer
import torch
import torch.nn as nn
from vc_models.models.vit import model_utils
from timm.models.vision_transformer import resize_pos_embed, Mlp, Attention
from timm.models.crossvit import CrossAttention
import logging
from habitat_vc.models.freeze_batchnorm import convert_frozen_batchnorm
from torch.nn import MultiheadAttention
from torch.distributions import Categorical
import triton
import triton.language as tl
from . import _VISUALIZE

logger = logging.getLogger(__name__)


class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer with support for global average pooling.

    Extension of timm's ``VisionTransformer`` that exposes one of three
    output features (see ``handle_outcome``):

    * ``"global_pool"``: mean of the patch tokens followed by ``fc_norm``.
    * ``"use_cls_token"``: the class token after ``norm``.
    * ``"reshape_embedding"``: patch tokens after ``norm`` reshaped into a
      channel-first 2-D map.

    Args:
        global_pool: select the ``"global_pool"`` feature; takes precedence
            over ``use_cls``.
        use_cls: select the ``"use_cls_token"`` feature when ``global_pool``
            is False; if both are False ``"reshape_embedding"`` is used.
        mask_ratio: if not None, fraction of patch tokens randomly dropped
            in ``forward_features``.
        del_head: delete the inherited prediction ``head``.
        freeze: when True, ``train(True)`` disables gradients on every
            parameter and passes the norm/block submodules through
            ``convert_frozen_batchnorm``.
        token_grad: when True, ``forward_features`` keeps the token tensor
            before/after every block (with retained gradients) in
            ``self.forward_dict['xs']``.
        **kwargs: forwarded to the timm constructor. ``norm_layer`` and
            ``embed_dim`` must be present for ``global_pool``;
            ``embed_dim`` must be present for ``"reshape_embedding"``.
    """

    def __init__(
        self, global_pool=False, use_cls=True, mask_ratio=None, del_head=True, freeze=False, token_grad=False, **kwargs
    ):
        super(VisionTransformer, self).__init__(**kwargs)
        if global_pool:
            self.classifier_feature = "global_pool"
        elif use_cls:
            self.classifier_feature = "use_cls_token"
        else:
            self.classifier_feature = "reshape_embedding"

        if del_head:
            del self.head  # don't use prediction head

        if self.classifier_feature == "global_pool":
            norm_layer = kwargs["norm_layer"]
            embed_dim = kwargs["embed_dim"]
            self.fc_norm = norm_layer(embed_dim)

            del self.norm  # remove the original norm

        if self.classifier_feature == "reshape_embedding":
            self.final_spatial = int(self.patch_embed.num_patches**0.5)
            # In this mode embed_dim describes the full (H, W, C) feature map.
            self.embed_dim = (
                self.patch_embed.grid_size[0],
                self.patch_embed.grid_size[1],
                kwargs["embed_dim"],
            )

        self.mask_ratio = mask_ratio
        self.freeze = freeze
        self.token_grad = token_grad

    def train(self, mode = True):
        """Set training mode, re-applying the freeze policy when enabled.

        Returns:
            ``self``, matching the ``nn.Module.train`` contract so that
            ``model = model.train()`` / ``model = model.eval()`` keep working.
        """
        super().train(mode)

        if mode and self.freeze:
            for p in self.parameters():
                p.requires_grad = False
            # Some of these submodules may be absent (``norm`` is deleted in
            # global_pool mode), so only convert the ones that exist.
            for name in ("norm_pre", "blocks", "norm", "fc_norm"):
                if hasattr(self, name):
                    setattr(self, name, convert_frozen_batchnorm(getattr(self, name)))
        return self

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence

        Returns:
            x_masked: the ``int(L * (1 - mask_ratio))`` kept tokens per sample.
            mask: [N, L] tensor in original token order, 0 is keep, 1 is remove.
            ids_restore: [N, L] indices that undo the shuffle.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1
        )  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def handle_outcome(self, x):
        """Turn the token sequence ``x`` into the configured output feature.

        Raises:
            NotImplementedError: if ``classifier_feature`` is not one of the
                three supported modes.
        """
        if self.classifier_feature == "global_pool":
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        elif self.classifier_feature == "use_cls_token":
            x = self.norm(x)
            outcome = x[:, 0]  # use cls token
        elif self.classifier_feature == "reshape_embedding":
            x = self.norm(x)
            outcome = reshape_embedding(
                x[:, 1:]
            )  # remove cls token and reshape embedding
        else:
            raise NotImplementedError

        return outcome

    def forward_features(self, x):
        """Embed the input, optionally mask patches, run the blocks and pool.

        When ``token_grad`` is set, the token tensors before the first block
        and after every block are stored in ``self.forward_dict['xs']`` with
        ``retain_grad()`` so their gradients can be read after backward.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        if self.mask_ratio is not None:
            x, _, _ = self.random_masking(x, mask_ratio=self.mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        x = torch.cat((cls_token.expand(B, -1, -1), x), dim=1)

        if self.token_grad:
            # requires_grad_() is used instead of attribute assignment because
            # it is also valid when x is already part of the autograd graph
            # (i.e. when the embedding parameters are trainable).
            x.requires_grad_(True)
            x.retain_grad()
            xs = [x]
            for block in self.blocks:
                x = block(x)
                x.retain_grad()
                xs.append(x)
            self.forward_dict = {
                'xs': xs,
            }
        else:
            x = self.blocks(x)
        return self.handle_outcome(x)

    def forward(self, x):
        """Return ``forward_features(x)``; the prediction head is not applied."""
        return self.forward_features(x)


class ClipVisionTransformer(VisionTransformer):
    """VisionTransformer variant whose token pipeline applies ``norm_pre``.

    The class token is prepended first, the full positional embedding is
    added afterwards, and ``norm_pre`` runs before the transformer blocks.
    """

    def forward_features(self, x):
        """Embed ``x``, prepend the class token, add positions, run the blocks."""
        batch_size = x.shape[0]
        tokens = self.patch_embed(x)

        # Broadcast the class token to one row per sample.
        cls_rows = self.cls_token.squeeze() + torch.zeros(
            batch_size, 1, tokens.shape[-1], device=tokens.device
        )
        tokens = torch.cat([cls_rows, tokens], dim=1)  # shape = [*, grid ** 2 + 1, width]

        pos = self.pos_embed.squeeze().to(tokens.dtype)
        tokens = self.norm_pre(tokens + pos)

        return self.handle_outcome(self.blocks(tokens))


def reshape_embedding(x):
    """Reshape a token sequence [N, L, D] into a channel-first map [N, D, H, W].

    ``H == W == int(L ** 0.5)``; the reshape fails if ``L`` is not a perfect
    square.
    """
    batch, length, dim = x.shape
    side = int(length**0.5)
    grid = x.reshape(batch, side, side, dim)
    # Move the channel axis in front of the spatial axes: nhwd -> ndhw.
    return grid.permute(0, 3, 1, 2)


def vit_small_patch16(**kwargs):
    """ViT small as defined in the DeiT paper.

    Patch size 16, width 384, depth 12, 6 heads. Extra keyword arguments are
    forwarded to ``VisionTransformer``.
    """
    arch = dict(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**arch, **kwargs)


def vit_base_patch16(**kwargs):
    """Build a ``VisionTransformer`` with patch size 16, width 768, depth 12, 12 heads.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    arch = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**arch, **kwargs)


def clip_vit_base_patch16(**kwargs):
    """Build a ``ClipVisionTransformer`` with patch size 16, width 768, depth 12, 12 heads.

    On top of the base configuration it enables ``pre_norm`` and sets
    ``num_classes=512``. Extra keyword arguments are forwarded to
    ``ClipVisionTransformer``.
    """
    arch = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    # CLIP-specific:
    clip_extras = dict(pre_norm=True, num_classes=512)
    return ClipVisionTransformer(**arch, **clip_extras, **kwargs)


def vit_large_patch16(**kwargs):
    """Build a ``VisionTransformer`` with patch size 16, width 1024, depth 24, 16 heads.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    arch = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**arch, **kwargs)


def single_cnn(img_size=224, use_cls=False, global_pool=False, **kwargs):
    """Build a single bias-free 16x16 / stride-16 convolution (3 -> 1024 channels).

    The returned module gets a ``final_spatial`` attribute equal to
    ``int(img_size / 16)``. ``use_cls`` and ``global_pool`` are only checked
    for mutual exclusion; remaining keyword arguments are ignored.
    """
    assert not use_cls or not global_pool, "use_cls and global_pool cannot be both True"
    patch = 16
    width = 1024
    conv = nn.Conv2d(
        in_channels=3,
        out_channels=width,
        kernel_size=patch,
        stride=patch,
        bias=False,
    )
    conv.final_spatial = int(img_size / patch)
    return conv


def vit_huge_patch14(**kwargs):
    """Build a ``VisionTransformer`` with patch size 14, width 1280, depth 32, 16 heads.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    arch = dict(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**arch, **kwargs)


def load_mae_encoder(model, checkpoint_path=None):
    """Load encoder weights from a checkpoint's ``"model"`` entry into ``model``.

    Returns ``model`` untouched when ``checkpoint_path`` is None. Otherwise
    the checkpoint is fetched via ``model_utils.download_model_if_needed``,
    relative paths are resolved three directories above this file, the
    positional embedding is resized if its shape differs from the model's,
    and keys containing ``decoder`` or ``mask_token`` are dropped. In
    ``global_pool`` mode, keys starting with ``norm`` are dropped and the
    model's own ``fc_norm`` parameters are inserted in their place.
    """
    if checkpoint_path is None:
        return model

    model_utils.download_model_if_needed(checkpoint_path)

    if not os.path.isabs(checkpoint_path):
        this_dir = os.path.dirname(os.path.abspath(__file__))
        checkpoint_path = os.path.join(this_dir, '..', '..', '..', checkpoint_path)

    weights = torch.load(checkpoint_path, map_location="cpu")["model"]

    if weights["pos_embed"].shape != model.pos_embed.shape:
        num_prefix_tokens = getattr(model, "num_tokens", 1)
        weights["pos_embed"] = resize_pos_embed(
            weights["pos_embed"],
            model.pos_embed,
            num_prefix_tokens,
            model.patch_embed.grid_size,
        )

    # filter out keys with name decoder or mask_token
    skipped_tags = ("decoder", "mask_token")
    weights = {
        name: tensor
        for name, tensor in weights.items()
        if not any(tag in name for tag in skipped_tags)
    }

    if model.classifier_feature == "global_pool":
        # drop the checkpoint's norm.* entries and reuse the model's fc_norm
        weights = {
            name: tensor
            for name, tensor in weights.items()
            if not name.startswith("norm")
        }
        weights.update(
            {
                "fc_norm.weight": model.fc_norm.weight,
                "fc_norm.bias": model.fc_norm.bias,
            }
        )

    model.load_state_dict(weights)
    return model


def load_contrastive_vit(model, checkpoint_path=None, state_dict_key="state_dict"):
    """Load ``module.base_encoder`` weights from a checkpoint into ``model``.

    Returns ``model`` untouched when ``checkpoint_path`` is None. Only keys
    starting with ``module.base_encoder`` (excluding the
    ``module.base_encoder.head`` subtree) are kept, with the
    ``module.base_encoder.`` prefix stripped. In ``global_pool`` mode, keys
    starting with ``norm`` are dropped and the model's own ``fc_norm``
    parameters are inserted. The positional embedding is resized if its shape
    differs from the model's.
    """
    if checkpoint_path is None:
        return model

    checkpoint = torch.load(checkpoint_path, map_location="cpu")[state_dict_key]

    encoder_tag = "module.base_encoder"
    head_tag = "module.base_encoder.head"
    strip_len = len("module.base_encoder.")

    state_dict = {}
    for name in list(checkpoint.keys()):
        # pop so the source dict is emptied as we go (renamed or unused keys)
        tensor = checkpoint.pop(name)
        # retain only base_encoder up to before the embedding layer
        if name.startswith(encoder_tag) and not name.startswith(head_tag):
            state_dict[name[strip_len:]] = tensor

    if model.classifier_feature == "global_pool":
        # drop the checkpoint's norm.* entries and reuse the model's fc_norm
        state_dict = {
            name: tensor
            for name, tensor in state_dict.items()
            if not name.startswith("norm")
        }
        state_dict.update(
            {
                "fc_norm.weight": model.fc_norm.weight,
                "fc_norm.bias": model.fc_norm.bias,
            }
        )

    if state_dict["pos_embed"].shape != model.pos_embed.shape:
        num_prefix_tokens = getattr(model, "num_tokens", 1)
        state_dict["pos_embed"] = resize_pos_embed(
            state_dict["pos_embed"],
            model.pos_embed,
            num_prefix_tokens,
            model.patch_embed.grid_size,
        )

    model.load_state_dict(state_dict)
    return model

def save_to_img(tensor, path, shape):
    """Write ``tensor`` to ``path`` as an 8-bit image.

    The tensor is reshaped to ``shape``, scaled by 255, clamped to
    [0, 255], cast to uint8 and written with ``imageio``.
    """
    scaled = tensor.reshape(*shape) * 255
    pixels = scaled.clamp(0, 255).to(torch.uint8).cpu().numpy()
    import imageio.v2 as imageio
    imageio.imwrite(path, pixels)

class ReferenceLayer(nn.Module):
    """Refine current-frame tokens from a memory of a previous frame.

    The layer attends from the current tokens to the memory keys (all with
    the positional embedding subtracted), fuses the attended memory values
    with the current tokens through an MLP, and returns a per-token
    certainty score alongside the fused tokens.

    Args:
        num_heads: stored on the module; not used by the computation here.
        embed_dim: token width ``D``.
        certainty: how the certainty score is derived. One of ``'mlp'``,
            ``'entropy'``, ``'top1'``, ``'entropy_sim'``, ``'top1_sim'`` or
            ``'ablation1'`` .. ``'ablation6'``.
    """

    # Modes whose certainty is an MLP over the three scalar cues
    # (normalized entropy, cosine similarity, top-1 attention weight).
    _SCALAR_MLP_MODES = ('mlp', 'ablation4', 'ablation5', 'ablation6')

    def __init__(self, num_heads, embed_dim, certainty='entropy_sim'):
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.ffn = Mlp(
            in_features=embed_dim*2,
            hidden_features=embed_dim * 4,
            out_features=embed_dim,
            act_layer=nn.GELU,
            drop=0.0,
        )
        self.certainty = certainty
        # NOTE(review): the entropy is normalized by log(embed_dim); confirm
        # this (rather than the log of the attention length) is intended.
        max_entropy = torch.log(torch.tensor(embed_dim, dtype=torch.float32)) + 1e-6
        self.register_buffer('max_entropy', max_entropy)
        self.register_buffer('eps', torch.tensor(1e-6, dtype=torch.float32))
        if certainty in self._SCALAR_MLP_MODES:
            self.certainty_mlp = self._build_certainty_mlp(3)
        elif certainty == 'ablation3':
            # certainty predicted directly from [current_x, output]
            self.certainty_mlp = self._build_certainty_mlp(embed_dim*2)
        # 'ablation1' (random) and 'ablation2' (fixed center) need no module.

    @staticmethod
    def _build_certainty_mlp(in_features):
        """Return the 2-layer sigmoid head mapping ``in_features`` to one score."""
        return nn.Sequential(
            nn.Linear(in_features, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    # @torch.compile
    def forward(self, current_x, memory_x, memory_target, pos_embed):
        """Fuse memory into the current tokens and score each token.

        Args:
            current_x: current-frame tokens, B x L x D.
            memory_x: memory keys (tokens of the reference frame).
            memory_target: memory values paired with ``memory_x``.
            pos_embed: positional embedding subtracted from all three inputs.
                ``memory_x`` / ``memory_target`` are presumably B x L x D as
                well, since the same ``pos_embed`` is subtracted from them --
                TODO confirm against the caller.

        Returns:
            ``(output, certainty)``: fused tokens (B x L x D) and a per-token
            certainty score.

        Raises:
            NotImplementedError: for an unknown ``certainty`` mode.
            RuntimeError: if the fused output contains NaN values.
        """
        B, L, D = current_x.shape
        query = current_x - pos_embed
        key = memory_x - pos_embed
        value = memory_target - pos_embed
        # Memory slots whose value has a component with magnitude above 10
        # are treated as invalid: both their key and value are zeroed.
        invalid_mask = value.abs().max(dim=-1)[0] > 10
        value[invalid_mask] = torch.zeros((1,), device=value.device)
        key[invalid_mask] = torch.zeros((1,), device=key.device)

        attn_weights = torch.einsum('bld,bmd->blm', query, key)
        # NOTE(review): the softmax normalizes over dim=-2 (the query axis)
        # while the entropy / top-1 below reduce over dim=-1 -- confirm that
        # this asymmetry is intended.
        attn_weights = attn_weights.softmax(dim=-2)
        attn_output = torch.einsum('blm,bmd->bld', attn_weights, value)

        output = attn_output - query
        ffn_feat = torch.cat([current_x, output], dim=-1)
        output = current_x + self.ffn(ffn_feat)

        # The certainty cues derived from the attention map carry no gradient.
        attn_weights = attn_weights.detach()
        if 'entropy' in self.certainty or self.certainty == 'mlp' or self.certainty in ['ablation6', 'ablation5']:
            norm_entropy = -torch.sum(attn_weights * torch.log(attn_weights + self.eps), dim=-1) / self.max_entropy
        if 'top1' in self.certainty or self.certainty == 'mlp' or self.certainty in ['ablation4', 'ablation5']:
            top1_conf, _ = attn_weights.max(dim=-1)
        if 'sim' in self.certainty or self.certainty == 'mlp' or self.certainty in ['ablation4', 'ablation6']:
            similarity = torch.nn.functional.cosine_similarity(query, output, dim=-1)

        if self.certainty == 'mlp':
            feat = torch.stack([norm_entropy, similarity, top1_conf], dim=-1)  # [B, N, 3]
            certainty = self.certainty_mlp(feat).squeeze(-1)  # [B, N]
        elif self.certainty == 'entropy':
            certainty = 1 - norm_entropy
        elif self.certainty == 'top1':
            certainty = top1_conf
        elif self.certainty == 'entropy_sim':
            certainty = (1 - norm_entropy) * similarity
        elif self.certainty == 'top1_sim':
            certainty = top1_conf * similarity
        elif self.certainty == 'ablation1':
            # random certainty
            certainty = torch.rand_like(output[..., 0])
        elif self.certainty == 'ablation2':
            # fixed center: certainty 1 inside the central 20x20 window of a
            # 40x40 token grid (plus the leading token), 0 on the border
            assert L == 1601
            certainty = torch.ones(B, 40, 40)
            certainty[:, :10] = 0
            certainty[:, :, :10] = 0
            certainty[:, -10:] = 0
            certainty[:, :, -10:] = 0
            certainty = torch.cat([torch.ones(B, 1), certainty.reshape(B, -1)], dim=-1)
            certainty = certainty.to(output.device)
        elif self.certainty == 'ablation3':
            feat = torch.cat([current_x, output], dim=-1)  # [B, N, 2d]
            certainty = self.certainty_mlp(feat).squeeze(-1)  # [B, N]
        elif self.certainty == 'ablation4':
            # entropy cue replaced by the constant 0.5
            feat = torch.stack([torch.zeros_like(similarity)+0.5, similarity, top1_conf], dim=-1)  # [B, N, 3]
            certainty = self.certainty_mlp(feat).squeeze(-1)  # [B, N]
        elif self.certainty == 'ablation5':
            # similarity cue replaced by the constant 0.5
            feat = torch.stack([norm_entropy, torch.zeros_like(norm_entropy)+0.5, top1_conf], dim=-1)  # [B, N, 3]
            certainty = self.certainty_mlp(feat).squeeze(-1)  # [B, N]
        elif self.certainty == 'ablation6':
            # top-1 cue replaced by the constant 0.5
            feat = torch.stack([norm_entropy, similarity, torch.zeros_like(norm_entropy)+0.5], dim=-1)  # [B, N, 3]
            certainty = self.certainty_mlp(feat).squeeze(-1)  # [B, N]
        else:
            raise NotImplementedError(f"Unknown certainty type: {self.certainty}")

        # Fail loudly instead of dropping into an interactive debugger, which
        # would hang (or fail to import) in non-interactive runs.
        if torch.isnan(output).any():
            raise RuntimeError("ReferenceLayer produced NaN values in its output")
        return output, certainty
        

class SelectiveVisionTransformer(VisionTransformer):
    def __init__(
        self, reduction_layers=(3, 6), keep_ratios=(0.7, 0.7), hidden_state_dim=2048+32,
        reference_last_frame_thr=1.0, reference_last_frame_layer_idx=3, reference_certainty='entropy_sim',
        freeze_backbone=False, freeze_batchnorm=False, freeze_reference_net=False, **kwargs):
        """Configure token pruning and optional last-frame referencing.

        Args:
            reduction_layers: block indices before which tokens are pruned.
            keep_ratios: keep ratio used at each reduction layer.
            hidden_state_dim: width of the hidden state fed to ``fc_hidden``.
            reference_last_frame_thr: certainty threshold; referencing is
                enabled only when this is below 1.0.
            reference_last_frame_layer_idx: block index at which the
                reference layer is applied.
            reference_certainty: certainty mode passed to ``ReferenceLayer``.
            freeze_backbone, freeze_batchnorm, freeze_reference_net: freeze
                policy applied by ``train``.
            **kwargs: forwarded to the parent constructor.
        """
        super().__init__(**kwargs)

        # plain configuration
        self.forward_dict = {}
        self.freeze_batchnorm = freeze_batchnorm
        self.freeze_backbone = freeze_backbone
        self.embedding_dim = self.num_features
        self.keep_ratios = keep_ratios
        self.hidden_state_dim = hidden_state_dim
        self.reduction_layers = reduction_layers

        # Projection of the hidden state into token space; only needed when
        # at least one reduction layer is configured.
        if len(reduction_layers) > 0:
            width = self.num_features
            self.fc_hidden = nn.Sequential(
                nn.Linear(self.hidden_state_dim, width),
                nn.ReLU(),
                nn.Linear(width, width),
            )

        self.reference_last_frame = bool(reference_last_frame_thr < 1.0)
        if self.reference_last_frame:
            self.reference_net = ReferenceLayer(
                num_heads=kwargs.get('num_heads', 3),
                embed_dim=self.num_features,
                certainty=reference_certainty
            )
            self.memory = None
            self.reference_last_frame_thr = reference_last_frame_thr
            self.reference_last_frame_layer_idx = reference_last_frame_layer_idx
            self.freeze_reference_net = freeze_reference_net
    
    # @torch.jit.script
    # def generate_mask(scores: torch.Tensor, keep_ratio: float, add_prefix_ones: int = 0, straight_forward: bool = False) -> torch.Tensor:
    #     # use kthvalue to get the threshold, convert topk to smallest n-k
    #     keep_num = int(keep_ratio * scores.size(1))
    #     k_rank = scores.size(1) - keep_num
    #     topk_val = torch.kthvalue(scores, k_rank, dim=1, keepdim=True).values
    #     mask = scores > topk_val
    #     if straight_forward:
    #         mask = mask.float() - scores.detach() + scores
    #     if add_prefix_ones > 0:
    #         mask = torch.cat([torch.ones(scores.size(0), add_prefix_ones, device=mask.device, dtype=mask.dtype), mask], dim=1)
    #     return mask
    
    @torch.jit.script
    def generate_mask(scores: torch.Tensor, keep_ratio: float, add_prefix_ones: int = 0, straight_forward: bool = False) -> torch.Tensor:
        # Build a keep-mask over dim 1 of `scores`.
        #
        # keep_ratio > 0 : mark the int(keep_ratio * scores.size(1)) highest
        #                  scoring entries of every row.
        # otherwise      : mark every entry whose score exceeds 0.5.
        #
        # straight_forward : return a float mask with the same forward value
        #                    as the hard mask, with `scores` added and its
        #                    detached copy subtracted so `scores` stays in
        #                    the autograd graph.
        # add_prefix_ones  : number of all-ones columns prepended to the mask.
        #
        # The function takes no `self`; it is compiled by torch.jit.script
        # while the class body is evaluated.
        keep_num = int(keep_ratio * scores.size(1))

        if keep_ratio > 0:
            topk_vals, topk_indices = torch.topk(scores, keep_num, dim=1)
            # The mask dtype follows the mode: scores' dtype for the
            # straight-through variant, bool otherwise.
            if straight_forward:
                mask = torch.zeros_like(scores, dtype=scores.dtype)
            else:
                mask = torch.zeros_like(scores, dtype=torch.bool)
            mask.scatter_(1, topk_indices, 1)
        else:
            mask = scores > 0.5
        
        if straight_forward:
            # forward value equals the hard mask; `scores` remains attached
            mask = mask.float() - scores.detach() + scores
        
        if add_prefix_ones > 0:
            prefix = torch.ones(scores.size(0), add_prefix_ones, device=mask.device, dtype=mask.dtype)
            mask = torch.cat([prefix, mask], dim=1)
        return mask
    
    def train(self, mode = True):
        """Set training mode and re-apply the freeze policy.

        With ``freeze_backbone`` every parameter is frozen, after which the
        ``fc_hidden`` projection (if present) is made trainable again. With
        ``freeze_batchnorm`` as well, the norm/block submodules are passed
        through ``convert_frozen_batchnorm``. The reference network, when
        present, is made trainable unless ``freeze_reference_net`` is set.

        Returns:
            ``self``, matching the ``nn.Module.train`` contract so that
            ``model = model.train()`` / ``model = model.eval()`` keep working.
        """
        super().train(mode)

        if mode:
            if self.freeze_backbone:
                for p in self.parameters():
                    p.requires_grad = False
                if self.freeze_batchnorm:
                    # Some of these submodules may be absent (e.g. a deleted
                    # ``norm``), so only convert the ones that exist.
                    for name in ('norm_pre', 'blocks', 'norm', 'fc_norm'):
                        if hasattr(self, name):
                            setattr(self, name, convert_frozen_batchnorm(getattr(self, name)))
                if hasattr(self, 'fc_hidden'):
                    for p in self.fc_hidden.parameters():
                        p.requires_grad = True
            # NOTE(review): freeze_reference_net only has an effect together
            # with freeze_backbone; on its own nothing here disables the
            # reference network's gradients -- confirm this is intended.
            if hasattr(self, 'reference_net') and not self.freeze_reference_net:
                for p in self.reference_net.parameters():
                    p.requires_grad = True
        return self
    
    def forward_vit(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        if self.mask_ratio is not None:
            x, _, _ = self.random_masking(x, mask_ratio=self.mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        x = torch.cat((cls_token.expand(B, -1, -1), x), dim=1)

        x = self.blocks(x)
        return self.handle_outcome(x)

    def forward_features_optimized(self, x, hidden_states):
        B = x.shape[0]
        x = self.patch_embed(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        if self.mask_ratio is not None:
            x, _, _ = self.random_masking(x, mask_ratio=self.mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        x = torch.cat((cls_token.expand(B, -1, -1), x), dim=1)
        prune_key = self.fc_hidden(hidden_states).reshape(B, 1, self.num_features)

        if self.training:
            xs = []
            for i, block in enumerate(self.blocks):
                if i in self.reduction_layers:
                    xs.append(x)
                x = block(x)
            xs.append(x)
            
            scores = torch.ones_like(xs[0][:, :, :1])
            masks = list() # prune masks
            current_ratio = 1.0
            for i in range(len(self.reduction_layers)):
                x = xs[i]
                x1 = x[:, 1:]
                x2 = x[:, :1] + prune_key
                scores = torch.nn.functional.cosine_similarity(x1, x2, dim=-1)
                keep_ratio = self.keep_ratios[i] * current_ratio
                current_ratio = keep_ratio
                if i > 0:
                    scores = scores * masks[-1]
                mask = self.generate_mask(scores, keep_ratio, add_prefix_ones=1, straight_forward=True)
                masks.append((1-mask) if i==0 else (1-mask-masks[-1]))
            masks.append(mask)
            
            
            masks = torch.stack(masks, dim=0)
            x = torch.stack(xs, dim=0)
            x = (x * masks.unsqueeze(-1)).sum(dim=0)
            return self.handle_outcome(x)
        else:
            raise NotImplementedError("Pruning is not supported in eval mode")
            current_idxes = torch.arange(x.shape[1]*B, device=x.device)
            out_x = torch.zeros_like(x).reshape(-1, x.shape[-1])
            for i, block in enumerate(self.blocks):
                if self.reference_last_frame and i == self.reference_last_frame_layer_idx and memory is not None:
                    key, value = memory
                    querys = x
                    output, certainty = self.reference_net(querys, key, value)
                    thr = self.reference_last_frame_thr
                    keep_mask = (certainty < thr).unsqueeze(-1)
                    keep_mask[:, 0] = 1
                    assert B == 1, "Batch size > 1 is not supported"
                    pruned_idxes = current_idxes[~keep_mask]
                    out_x[pruned_idxes] = output[~keep_mask].reshape(-1, x.shape[-1])
                    current_idxes = current_idxes[keep_mask.flatten()]
                    x = output[keep_mask].reshape(B, -1, x.shape[-1])
                    
                if i in self.reduction_layers:
                    x1 = x
                    x2 = x[:, :1] + prune_key
                    scores = torch.nn.functional.cosine_similarity(x1, x2, dim=-1)
                    scores[:, 0] = 100
                    prun_idx = self.reduction_layers.index(i)
                    keep_ratio = self.keep_ratios[prun_idx]
                    mask = self.generate_mask(scores, keep_ratio).bool()
                    
                    small_target_idx = current_idxes[~mask.flatten()]
                    out_x[small_target_idx] = x[~mask].reshape(-1, x.shape[-1])
                    current_idxes = current_idxes[mask.flatten()]
                    x = x[mask].reshape(B, -1, x.shape[-1])
                x = block(x)
            out_x[current_idxes] = x.reshape(-1, x.shape[-1])
            out_x = out_x.reshape(B, -1, out_x.shape[-1])
            return out_x
    
    def forward_features_optimized2(self, x, hidden_states=None, memory=None):
        """Forward pass with optional last-frame referencing and token pruning.

        Training mode: all blocks run on the (possibly temporally reduced)
        token set and pruning is applied softly by blending per-layer
        snapshots with masks from ``generate_mask``.
        Eval mode: tokens are physically removed from the sequence; removed
        tokens keep the value they had when they were removed.

        Args:
            x: input batch passed to ``patch_embed``.
            hidden_states: per-sample state projected by ``fc_hidden`` into
                the pruning key; only used when ``fc_hidden`` exists.
            memory: ``(key, value)`` pair from a previous call, or None.

        Returns:
            ``handle_outcome(tokens)`` when ``reference_last_frame`` is off,
            otherwise ``(handle_outcome(tokens), memory)`` where ``memory``
            is the pair to pass to the next call.
        """
        B = x.shape[0]
        x = self.patch_embed(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        if self.mask_ratio is not None:
            x, _, _ = self.random_masking(x, mask_ratio=self.mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        x = torch.cat((cls_token.expand(B, -1, -1), x), dim=1)
        # prune_key is only defined when the fc_hidden projection exists
        if hasattr(self, 'fc_hidden'):
            prune_key = self.fc_hidden(hidden_states).reshape(B, 1, self.num_features)

        forward_dict = dict()
        if self.training:
            xs = []
            for i, block in enumerate(self.blocks):
                if self.reference_last_frame and i == self.reference_last_frame_layer_idx:
                    # tokens entering the reference layer become the next memory key
                    mem_key = x
                    if memory is not None:
                        assert B == 1, "Batch size > 1 is not supported"
                        key, value = memory
                        query = x
                        output, certainty = self.reference_net(query, key, value, self.pos_embed)
                        thr = self.reference_last_frame_thr
                        # only tokens whose certainty is below the threshold
                        # continue through the remaining blocks
                        temporal_mask = certainty < thr
                        temporal_mask[:, 0] = 1  # always keep the cls token
                        x = output[temporal_mask].reshape(B, -1, x.shape[-1])
                    
                if i in self.reduction_layers:
                    # snapshot of the tokens entering this reduction layer
                    xs.append(x)
                    
                x = block(x)
            xs.append(x)
            
            if len(self.reduction_layers) > 0:
                # NOTE(review): this initial value is overwritten inside the
                # loop before it is read.
                scores = torch.ones_like(xs[0][:, :, :1])
                masks = list()  # masks[k] selects the tokens whose value comes from xs[k]
                current_ratio = 1.0
                for i in range(len(self.reduction_layers)):
                    x = xs[i]
                    x1 = x[:, 1:]
                    x2 = x[:, :1] + prune_key
                    # score every patch token against (cls token + pruning key)
                    scores = torch.nn.functional.cosine_similarity(x1, x2, dim=-1)
                    # keep ratios compound across successive reduction layers
                    keep_ratio = self.keep_ratios[i] * current_ratio
                    current_ratio = keep_ratio
                    if i > 0:
                        # NOTE(review): scores has no cls column while
                        # masks[-1] was built with add_prefix_ones=1, so the
                        # shapes look off by one here -- confirm.
                        scores = scores * masks[-1]
                    forward_dict[f'keep_probs_{i}'] = scores
                    mask = self.generate_mask(scores, keep_ratio, add_prefix_ones=1, straight_forward=True)
                    masks.append((1-mask) if i==0 else (1-mask-masks[-1]))
                masks.append(mask)
                
                
                masks = torch.stack(masks, dim=0)
                x = torch.stack(xs, dim=0)
                x = (x * masks.unsqueeze(-1)).sum(dim=0)
            
            # accumulate this call's statistics on the module; they are
            # consumed and cleared by merge_forward_dict_train
            for k, v in forward_dict.items():
                if k not in self.forward_dict:
                    self.forward_dict[k] = []
                self.forward_dict[k].append(v)
            self.forward_dict['merge_function'] = self.merge_forward_dict_train
            
            if self.reference_last_frame:
                if memory is not None:
                    # fuse back to temporal
                    final_output = output
                    final_output[temporal_mask] = x.reshape(-1, x.shape[-1]) # inplace operation is invalid when backward
                else:
                    final_output = x
                # NOTE(review): mem_key is only bound if the loop reached
                # reference_last_frame_layer_idx -- presumably always < depth.
                memory = (mem_key.detach(), final_output.detach())
                final_output = self.handle_outcome(final_output)
                return final_output, memory
            else:
                return self.handle_outcome(x)
        else:
            # Flat bookkeeping: out_x is the (B*L, D) output buffer and
            # current_idxes holds the flat positions of tokens still active.
            current_idxes = torch.arange(x.shape[1]*B, device=x.device)
            out_x = torch.zeros_like(x).reshape(-1, x.shape[-1])
            if _VISUALIZE:
                keep_confidence = torch.ones_like(out_x[..., 0]).flatten()

            for i, block in enumerate(self.blocks):
                if self.reference_last_frame and i == self.reference_last_frame_layer_idx:
                    # tokens entering the reference layer become the next memory key
                    querys = x
                    if memory is not None:
                        key, value = memory
                        output, certainty = self.reference_net(querys, key, value, self.pos_embed)
                        thr = self.reference_last_frame_thr
                        keep_mask = (certainty < thr)
                        if _VISUALIZE:
                            keep_confidence[current_idxes] = 1 - certainty
                        keep_mask[:, 0] = 1  # always keep the cls token
                        assert B == 1, "Batch size > 1 is not supported"
                        # tokens at/above the threshold take the reference
                        # layer's output as their final value
                        pruned_idxes = current_idxes[~keep_mask.flatten()]
                        out_x[pruned_idxes] = output[~keep_mask].reshape(-1, x.shape[-1])
                        current_idxes = current_idxes[keep_mask.flatten()]
                        x = output[keep_mask].reshape(B, -1, x.shape[-1])
                        # NOTE(review): ratio1 is not read anywhere in this method
                        ratio1 = current_idxes.shape[0] / out_x.shape[0]
                    
                if i in self.reduction_layers:
                    x1 = x
                    x2 = x[:, :1] + prune_key
                    scores = torch.nn.functional.cosine_similarity(x1, x2, dim=-1)
                    scores[:, 0] = 1  # force the cls token's score to 1
                    if _VISUALIZE:
                        keep_confidence[current_idxes] = scores
                    prun_idx = self.reduction_layers.index(i)
                    keep_ratio = self.keep_ratios[prun_idx]
                    mask = self.generate_mask(scores, keep_ratio).bool()
                    
                    # dropped tokens are written to the output buffer as they are
                    small_target_idx = current_idxes[~mask.flatten()]
                    out_x[small_target_idx] = x[~mask].reshape(-1, x.shape[-1])
                    current_idxes = current_idxes[mask.flatten()]
                    x = x[mask].reshape(B, -1, x.shape[-1])
                    # NOTE(review): ratio2 is not read anywhere in this method
                    ratio2 = current_idxes.shape[0] / out_x.shape[0]
                x = block(x)
            if _VISUALIZE:
                # per-patch confidence map (cls column dropped)
                keep_confidence = keep_confidence.reshape(B, -1)
                keep_confidence = keep_confidence[:, 1:]
                self.forward_dict = {
                    'mask': keep_confidence
                }

            # tokens that survived every stage take the last block's output
            out_x[current_idxes] = x.reshape(-1, x.shape[-1])
            out_x = out_x.reshape(B, -1, out_x.shape[-1])
            
            if self.reference_last_frame:
                # NOTE(review): querys is only bound if the loop reached
                # reference_last_frame_layer_idx -- presumably always < depth.
                memory = (querys, out_x)
                return self.handle_outcome(out_x), memory
            else:
                return self.handle_outcome(out_x)
    
    def merge_forward_dict_train(self, forward_dict):
        """Reduce the per-batch ``forward_dict`` to scalar stats, then empty it.

        Statistics are only produced when ``self.keep_ratios`` is non-empty and
        its first entry equals ``0.0``. In that case, for every entry of
        ``self.reduction_layers`` (enumerated from 0), the list stored under
        ``keep_probs_<position>`` is concatenated along dim 1 and averaged.

        Args:
            forward_dict: mutable mapping filled during the forward pass. It is
                cleared in place before returning, whether or not statistics
                were computed.

        Returns:
            dict with, per reduction layer ``l``:
              * ``keep_ratio_<l>`` -- the mean keep probability as a Python float
              * ``loss/keep_<l>``  -- the squared mean scaled by 0.1 (a tensor)
            or an empty dict when the condition above does not hold.
        """
        metrics = {}

        # NOTE(review): a leading keep ratio of 0.0 looks like a sentinel for
        # "ratios are predicted rather than fixed" -- confirm against the caller.
        report_keep_stats = len(self.keep_ratios) > 0 and self.keep_ratios[0] == 0.0
        if report_keep_stats:
            for position, layer in enumerate(self.reduction_layers):
                stacked_probs = torch.cat(forward_dict[f'keep_probs_{position}'], dim=1)
                mean_keep = stacked_probs.mean()
                metrics[f'keep_ratio_{layer}'] = mean_keep.item()
                metrics[f'loss/keep_{layer}'] = (mean_keep**2) * 0.1

        # Start the next batch from an empty dict.
        forward_dict.clear()
        return metrics
    
    def forward_features(self, x, hidden_states=None, masks=None):
        """Encode ``x`` by delegating to ``forward_features_optimized2``.

        Two modes, selected by ``self.reference_last_frame``:

        * falsy  -- the whole batch is encoded in a single call and that
          call's result is returned as is.
        * truthy -- samples are encoded one at a time (slices ``x[b:b+1]``),
          each with its own slot in ``self.memory``. The slot is passed to the
          encoder and overwritten with the memory it returns, so state carries
          over to the next call for the same batch position.

        Args:
            x: batched input, indexed along dim 0.
            hidden_states: forwarded to ``forward_features_optimized2``; must
                not be ``None`` when ``self.reference_last_frame`` is set.
            masks: only read when ``self.reference_last_frame`` is set, where
                it is required; a falsy ``masks[b]`` clears the memory slot of
                sample ``b`` before it is encoded.
                NOTE(review): presumably the "not done" episode masks --
                confirm against the caller.

        Returns:
            The encoder output for the batch; in per-sample mode the
            individual outputs concatenated along dim 0.
        """
        if not self.reference_last_frame:
            return self.forward_features_optimized2(x, hidden_states)

        assert hidden_states is not None
        if self.memory is None:
            # Lazily create one memory slot per batch position.
            self.memory = [None] * x.shape[0]

        per_sample_outputs = []
        for b in range(x.shape[0]):
            if not masks[b]:
                # Drop the stale reference before encoding this sample.
                self.memory[b] = None
            out, mem = self.forward_features_optimized2(
                x[b:b + 1], hidden_states[b:b + 1], self.memory[b]
            )
            self.memory[b] = mem
            per_sample_outputs.append(out)
        return torch.cat(per_sample_outputs, dim=0)

    def forward(self, x, hidden_states=None, masks=None):
        """Module entry point: hand all arguments to ``forward_features`` and
        return its result unchanged."""
        features = self.forward_features(x, hidden_states, masks)
        return features
        

def deit_tiny_patch16_224(**kwargs):
    """Build a ``VisionTransformer`` with the DeiT-tiny hyper-parameters.

    Fixed settings: patch size 16, embedding width 192, 12 blocks, 3 heads,
    MLP ratio 4, biased QKV projections and ``LayerNorm(eps=1e-6)``.

    Args:
        **kwargs: forwarded to ``VisionTransformer``. Repeating one of the
            fixed settings raises ``TypeError`` (duplicate keyword argument).

    Returns:
        The constructed model.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    tiny_config = dict(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
    )
    return VisionTransformer(**tiny_config, **kwargs)

def selective_deit_tiny_patch16_224(**kwargs):
    """Build a ``SelectiveVisionTransformer`` with the DeiT-tiny hyper-parameters.

    Fixed settings: patch size 16, embedding width 192, 12 blocks, 3 heads,
    MLP ratio 4, biased QKV projections and ``LayerNorm(eps=1e-6)``.

    Args:
        **kwargs: forwarded to ``SelectiveVisionTransformer``. A
            ``requires_state_keys`` entry is accepted but discarded, since it
            is not passed on to the constructor; it may also be omitted.

    Returns:
        The constructed model.
    """
    # The value is never used, so a missing key must not abort construction.
    kwargs.pop('requires_state_keys', None)
    model = SelectiveVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
    )
    return model

def deit_small_patch16_224(**kwargs):
    """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.

    Fixed settings: patch size 16, embedding width 384, 12 blocks, 6 heads.

    Args:
        **kwargs: forwarded to ``VisionTransformer``. ``norm_layer`` defaults
            to ``LayerNorm(eps=1e-6)`` (as in ``deit_tiny_patch16_224``) but
            can be overridden by the caller.

    Returns:
        The constructed model.
    """
    # VisionTransformer.__init__ reads kwargs["norm_layer"] when
    # global_pool=True, so the key has to be present; a default keeps that
    # path working while still letting callers supply their own layer.
    kwargs.setdefault("norm_layer", partial(nn.LayerNorm, eps=1e-6))
    return VisionTransformer(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)

def load_deit_vit(model, checkpoint_path=None):
    """Load DeiT checkpoint weights into ``model`` and return the model.

    Steps, in order:
      1. Return ``model`` untouched when no ``checkpoint_path`` is given.
      2. Call ``model_utils.download_model_if_needed`` for the path.
      3. Resolve a relative path against the directory three levels above
         this file and read the ``"model"`` entry of the checkpoint.
      4. Strip a leading ``model.`` from parameter names.
      5. Resize ``pos_embed`` if its shape differs from the model's.
      6. Drop entries the model cannot use (names containing ``decoder`` or
         ``mask_token``; names starting with ``norm`` for global-pool models;
         names starting with ``head`` when the model has no ``head``).
      7. Load non-strictly and log the load result at warning level.

    Args:
        model: the module to load weights into.
        checkpoint_path: absolute path, or a path relative to the directory
            described in step 3; ``None`` skips loading.

    Returns:
        ``model`` (the same object that was passed in).
    """
    if checkpoint_path is None:
        return model

    model_utils.download_model_if_needed(checkpoint_path, prefix_url="https://dl.fbaipublicfiles.com/deit/")

    if not os.path.isabs(checkpoint_path):
        this_dir = os.path.dirname(os.path.abspath(__file__))
        base_dir = os.path.join(this_dir, '..', '..', '..')
        checkpoint_path = os.path.join(base_dir, checkpoint_path)

    # NOTE(review): torch.load unpickles the file -- only point this at
    # checkpoints from a trusted source.
    raw_weights = torch.load(checkpoint_path, map_location="cpu")["model"]

    # Some checkpoints wrap every parameter name in a "model." namespace.
    wrapper_prefix = 'model.'
    state_dict = {}
    for name, tensor in raw_weights.items():
        if name.startswith(wrapper_prefix):
            name = name[len(wrapper_prefix):]
        state_dict[name] = tensor

    if state_dict["pos_embed"].shape != model.pos_embed.shape:
        state_dict["pos_embed"] = resize_pos_embed(
            state_dict["pos_embed"],
            model.pos_embed,
            getattr(model, "num_tokens", 1),
            model.patch_embed.grid_size,
        )

    uses_global_pool = model.classifier_feature == "global_pool"
    keeps_head = hasattr(model, "head")

    def is_wanted(key):
        # Entries whose name mentions decoder / mask_token are never loaded.
        if "decoder" in key or "mask_token" in key:
            return False
        # Global-pool models drop the checkpoint's final norm.
        if uses_global_pool and key.startswith("norm"):
            return False
        # Without a prediction head the head weights have nowhere to go.
        if not keeps_head and key.startswith("head"):
            return False
        return True

    state_dict = {key: value for key, value in state_dict.items() if is_wanted(key)}

    if uses_global_pool:
        # Feed the model's own fc_norm parameters back in so they stay as they are.
        state_dict["fc_norm.weight"] = model.fc_norm.weight
        state_dict["fc_norm.bias"] = model.fc_norm.bias

    load_report = model.load_state_dict(state_dict, strict=False)
    logger.warning(load_report)
    return model




# class DeformableVisionTransformerV1(VisionTransformer):
#     """Vision Transformer with support for token reduction"""

#     def __init__(
#         self, reduction_layers=(3,6,9), keep_ratios=(0.7, 0.49, 0.343), hidden_state_dim=2048+32, **kwargs
#     ):
#         super().__init__(**kwargs)
#         self.reduction_layers = reduction_layers
#         self.hidden_state_dim = hidden_state_dim
#         self.fc_hidden = nn.Linear(self.hidden_state_dim, self.num_features)
#         self.keep_ratios = keep_ratios
    
#     def forward_features(self, x, hidden_states=None, masks=None):
#         B = x.shape[0]
#         x = self.patch_embed(x)

#         # add pos embed w/o cls token
#         x = x + self.pos_embed[:, 1:, :]

#         # masking: length -> length * mask_ratio
#         if self.mask_ratio is not None:
#             x, _, _ = self.random_masking(x, mask_ratio=self.mask_ratio)

#         # append cls token
#         cls_token = self.cls_token + self.pos_embed[:, :1, :]
#         x = torch.cat((cls_token.expand(B, -1, -1), x), dim=1)

#         hidden_states = self.fc_hidden(hidden_states).reshape(B, self.num_features)
#         output = torch.zeros_like(x)
#         remain_idx = torch.arange(x.shape[1], device=x.device).unsqueeze(0).expand(B, -1) # map from current x to output
#         cur_layer = 0
#         for layer_idx, keep_ratio in zip(self.reduction_layers, self.keep_ratios):
#             x = self.blocks[cur_layer:layer_idx](x)
            
#             scores = torch.einsum("bld,bd->bl", x, hidden_states)
#             scores[:, 0] = torch.float("inf")
#             keep_num = int(keep_ratio * scores.shape[1])
#             pruned_num = scores.shape[1] - keep_num
#             _, topk_idx = torch.topk(scores, keep_num, dim=1, sorted=False, largest=True) # B x k
            
#             mask = torch.ones_like(scores, dtype=torch.bool)
#             mask.scatter_(1, topk_idx, 0)
#             small_target_idx = remain_idx[mask].reshape(B, pruned_num)
#             pruned_features = x.gather(1, remain_idx.unsqueeze(-1).expand(-1, -1, x.shape[-1]))
#             output.scatter_(1, small_target_idx.unsqueeze(-1).expand(-1, -1, x.shape[-1]), pruned_features)
            
#             remain_idx = remain_idx.gather(1, topk_idx)
#             x = x.gather(1, topk_idx.unsqueeze(-1).expand(-1, -1, x.shape[-1]))
#             cur_layer = layer_idx
#         x = self.blocks[cur_layer:](x)
#         output.scatter_(1, remain_idx.unsqueeze(-1).expand(-1, -1, x.shape[-1]), x)
        
#         return self.handle_outcome(output)

#     def forward(self, x, *args):
#         return self.forward_features(x, *args)