""" Vision Transformer (ViT) in PyTorch

A PyTorch implementation of Vision Transformers as described in:

'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
    - https://arxiv.org/abs/2010.11929

`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
    - https://arxiv.org/abs/2106.10270

`FlexiViT: One Model for All Patch Sizes`
    - https://arxiv.org/abs/2212.08013

The official jax code is released and available at
  * https://github.com/google-research/vision_transformer
  * https://github.com/google-research/big_vision

Acknowledgments:
  * The paper authors for releasing code and weights, thanks!
  * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch
  * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
  * Bert reference code checks against Huggingface Transformers and Tensorflow Bert

Hacked together by / Copyright 2020, Ross Wightman
"""
import logging
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List
try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final
import copy

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
    OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
    get_act_layer, get_norm_layer, LayerType
from timm.models._builder import build_model_with_cfg
from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations
import time

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)


class Attention(nn.Module):
    """Multi-head self-attention backed by an externally owned per-layer K/V cache.

    The cache is not stored on the module. ``forward`` looks it up as
    ``cache_dic['enco_cache'][current['enco_layer_idx']]``: a dict holding existing
    ``'k'`` / ``'v'`` tensors indexed ``[batch, head, position, head_dim]`` and an
    integer ``'cur_kv_len'``. Along the position axis the cache is laid out as:

    * ``[0, num_prefix_tokens)``: overwritten from the current input on every call;
    * ``[num_prefix_tokens, num_prefix_tokens + cur_kv_len)``: an append-only region
      that grows by ``N - num_prefix_tokens`` positions on each cached call.

    Args:
        dim: Input/output channel count; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias.
        qk_norm: Apply ``norm_layer`` (over ``head_dim``) to q and k.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        norm_layer: Norm layer used for the optional q/k normalisation.
        num_prefix_tokens: Number of leading tokens that are refreshed rather than
            appended on every call. Defaults to 64.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
            num_prefix_tokens: int = 64,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # Kept for interface parity with timm. forward() always uses the explicit
        # softmax path and does not consult this flag.
        self.fused_attn = use_fused_attn()
        self.num_prefix_tokens = num_prefix_tokens

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Retained for backward compatibility only. forward() tracks the cache
        # length in a local variable and leaves this attribute at 0.
        self.cur_kv_len = 0

    def _project_qkv(self, x: torch.Tensor):
        """Project ``x`` (B, N, C) to normalised q, k and v of shape (B, num_heads, N, head_dim)."""
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        return self.q_norm(q), self.k_norm(k), v

    def forward(
            self,
            x: torch.Tensor,
            current: Dict[str, Any],
            cache_dic: Dict[str, Any],
    ) -> torch.Tensor:
        """Attend over ``x`` and update this layer's K/V cache in place.

        Two modes are selected from ``current``:

        * Fresh pass (``token_cache`` false, or ``is_force_fresh`` true): k/v for all
          N tokens are written to cache positions ``[:N]``. Only the first ``B // 2``
          batch rows are written when ``cfg_cache`` is set (presumably one half of a
          classifier-free-guidance batch; TODO confirm). Attention runs over the
          N fresh tokens, and ``cur_kv_len`` becomes ``N - num_prefix_tokens``.
        * Cached pass: the prefix positions are overwritten with the current prefix
          k/v. The remaining ``N - num_prefix_tokens`` tokens are appended after
          those already cached, and attention spans the whole valid cache. This mode
          assumes ``N >= num_prefix_tokens`` and a cache batch dim equal to ``B``
          (TODO confirm).

        Args:
            x: Input of shape ``(B, N, C)``.
            current: Step state. Reads ``'enco_layer_idx'``, ``'token_cache'``,
                ``'is_force_fresh'``, and (fresh pass only) ``'cfg_cache'``.
            cache_dic: Holds ``'enco_cache'``. The selected entry's ``'k'``, ``'v'``
                and ``'cur_kv_len'`` are updated.

        Returns:
            Tensor of shape ``(B, N, C)``.

        Raises:
            RuntimeError: On a cached pass whose appended tokens do not fit in the
                cache tensors.
        """
        B, N, C = x.shape
        n_prefix = self.num_prefix_tokens
        layer_cache = cache_dic['enco_cache'][current['enco_layer_idx']]
        cur_kv_len = layer_cache['cur_kv_len']

        q, k, v = self._project_qkv(x)
        # Tensor.to() returns a new tensor instead of converting in place, so the
        # result must be stored back for the cache to follow the activations'
        # dtype/device. When they already match, .to() returns the same tensor.
        layer_cache['k'] = layer_cache['k'].to(k)
        layer_cache['v'] = layer_cache['v'].to(v)

        if current['token_cache'] and not current['is_force_fresh']:
            app_kv_len = N - n_prefix
            start = n_prefix + cur_kv_len
            end = start + app_kv_len
            cache_len = layer_cache['k'].shape[2]
            # Without this check an overflowing slice can broadcast silently
            # (e.g. a single appended token) and corrupt the attention result.
            if end > cache_len:
                raise RuntimeError(
                    f'KV cache overflow: {end} positions needed, cache holds {cache_len}')
            for name, fresh in (('k', k), ('v', v)):
                layer_cache[name][:, :, start:end, :] = fresh[:, :, n_prefix:, :]
                layer_cache[name][:, :, :n_prefix, :] = fresh[:, :, :n_prefix, :]
            cur_kv_len = cur_kv_len + app_kv_len
            # Attend over the cache itself; it is sliced to the valid length below.
            k, v = layer_cache['k'], layer_cache['v']
        else:
            rows = B // 2 if current['cfg_cache'] else B
            layer_cache['k'][:rows, :self.num_heads, :N, :self.head_dim] = k[:rows]
            layer_cache['v'][:rows, :self.num_heads, :N, :self.head_dim] = v[:rows]
            cur_kv_len = N - n_prefix

        available_kv_len = cur_kv_len + n_prefix
        keys = k[:, :, :available_kv_len, :]
        values = v[:, :, :available_kv_len, :]

        # Explicit softmax attention. The fused SDPA kernel is intentionally not
        # used here. NOTE(review): presumably for numerical parity between cached
        # and fresh passes; confirm before switching to F.scaled_dot_product_attention.
        attn = torch.matmul(q * self.scale, keys.transpose(-2, -1))
        attn = self.attn_drop(attn.softmax(dim=-1))
        x = torch.matmul(attn, values)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj_drop(self.proj(x))
        layer_cache['cur_kv_len'] = cur_kv_len
        return x


class LayerScale(nn.Module):
    """Learnable per-channel scale, broadcast over the last dimension of the input.

    Args:
        dim: Size of the scaled (last) dimension.
        init_values: Initial value of every entry of ``gamma``.
        inplace: When True, the input tensor is scaled in place and returned.
    """

    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        initial_scale = torch.ones(dim) * init_values
        self.gamma = nn.Parameter(initial_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.inplace:
            return x * self.gamma
        return x.mul_(self.gamma)


class Block(nn.Module):
    """Pre-norm transformer block whose attention uses an external K/V cache.

    Computes ``x = x + drop_path1(ls1(attn(norm1(x))))`` followed by
    ``x = x + drop_path2(ls2(mlp(norm2(x))))``. ``ls1``/``ls2`` are ``LayerScale``
    when ``init_values`` is truthy and identity otherwise. ``drop_path1``/``drop_path2``
    are ``DropPath`` when ``drop_path > 0`` and identity otherwise.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: nn.Module = nn.GELU,
            norm_layer: nn.Module = nn.LayerNorm,
            mlp_layer: nn.Module = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(
            self,
            x: torch.Tensor,
            current: Dict[str, Any],
            cache_dic: Dict[str, Any],
    ) -> torch.Tensor:
        """Run the attention and MLP residual branches.

        Args:
            x: Input tokens. The attention layer unpacks it as ``(B, N, C)``.
            current: Step-state dict. ``current['module']`` is set to ``'attn'``
                before the attention branch and to ``'mlp'`` before the MLP branch.
                The dict is also passed to the attention layer.
            cache_dic: K/V cache container, passed through to the attention layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        current['module'] = 'attn'
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), current, cache_dic)))
        current['module'] = 'mlp'
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x



