import math
import logging
from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.layers import get_act_layer
from timm.layers.config import use_fused_attn
from timm.layers.weight_init import trunc_normal_tf_

from .custom_utils.tensor_plot import log_tensor_statistics, plot_tensor_views, save_top5_abs_tokens, log_tensor_statistics_with_rank, save_attention_heatmap
from .custom_utils.plot_utils_vit import plot_3d_feat_vit
from .custom_utils.sweep import *

import numpy as np
# from probe import probe
# from token_probe import norm_probing_not_sorted
# from token_select import token_select
import pandas as pd
import os
import re
import json
from datetime import datetime
import matplotlib.pyplot as plt

# Module-level experiment switches. PROBE, CLAMP and CLAMP_VAL are not referenced
# anywhere in the visible part of this file.
PROBE=False
## About Fc2 input
CLAMP=False 
CLAMP_VAL=2.5

# Path template and file patterns handed to apply_prefix_cache() by Attention.forward.
# The "{}" placeholders are filled in by that helper (defined outside this file).
# NOTE(review): the meaning of the integer in each (pattern, int) entry is not
# visible here -- confirm against the helper.
QKV_CACHE_PATH_TEMPLATE = "/home/user/regcache/clip_global_prefix/prefix_from_{}/block_{}_top20_results/qkv_cache_input_token/"
QKV_CACHE_FILES = [
    ('rank_{:02d}/block_{:02d}_qkv.pt', 1), 
]

# Same layout for cached block-input tokens. Only referenced from the commented-out
# add_residual_prefix() call in Q_Block.forward.
Q_BLOCK_INPUT_PATH_TEMPLATE = "/home/user/regcache/clip_global_prefix/prefix_from_{}/block_{}_top20_results/block_input_token/"
Q_BLOCK_PREFIX_FILES = [
    ('rank_{:02d}/block_{:02d}_input.pt', 1),
]

# Hard-coded index lists. The RANDOM_* lists are only referenced from the
# commented-out delete_residual_tokens() call in Q_Block.forward;
# FREQUENT_OUTLIER_INDEX is not referenced in the visible part of this file.
# NOTE(review): presumably token indices -- confirm.
FREQUENT_OUTLIER_INDEX = [139, 25, 16, 170]
RANDOM_OUTLIER_INDEX1 = [1,30,51,68,91,114,177,180]
RANDOM_OUTLIER_INDEX2 = [38,63,72,114,118,123,158,162]
RANDOM_OUTLIER_INDEX3 = [14,20,88,105,146,163,177,190]

# Probing flags (not referenced in the visible part of this file).
HEATMAP_PROBE = False
TOKEN_INDEX_PROBE= False

# Not referenced in the visible part of this file.
OPTION='delete_compare'
# Module logger.
_logger = logging.getLogger(__name__)
# __all__ = ['qt_deit_small_patch16_224']

def save_token_norm_plot(x, block_num, model_name, option, layer_info):
    """Save a per-token magnitude bar chart and log the global maximum.

    Only the first sample of the batch (``x[0]``) is inspected. For every
    token the largest absolute value across the last dimension is plotted,
    and the largest absolute value of the whole sample is appended to a text
    file next to the plot.

    Args:
        x: Tensor indexed as ``x[0, :, :]`` (at least 3-D).
        block_num: Block index, used in the output file names.
        model_name: Path component of the output directory.
        option: Path component of the output directory.
        layer_info: Path component of the output directory.

    Side effects:
        Creates the output directory if needed, writes ``block{N}.png`` and
        appends one line to ``block{N}.txt``.
    """
    save_dir = f"/home/user/regcache/Meeting_plot/stats/{model_name}/{layer_info}/{option}"
    os.makedirs(save_dir, exist_ok=True)

    x_abs = x[0, :, :].abs()

    # Per-token maximum absolute value (the L-infinity norm of each token).
    token_maxs, _ = torch.max(x_abs, dim=-1)
    token_maxs_np = token_maxs.detach().cpu().numpy()

    # Set the plot file path and save the figure.
    plot_save_path = os.path.join(save_dir, f"block{block_num}.png")
    plt.figure(figsize=(12, 6))
    try:
        plt.bar(range(len(token_maxs_np)), token_maxs_np)
        plt.xlabel('Token Index')
        # The plotted quantity is max |x| per token, not an L1 norm.
        plt.ylabel('Magnitude (max |x|, L-inf norm)')
        plt.grid(axis='y', linestyle='--')
        plt.savefig(plot_save_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close()
    print(f"Plot saved to {plot_save_path}")

    x_max = x_abs.max().item()

    probe_file_path = os.path.join(save_dir, f'block{block_num}.txt')
    with open(probe_file_path, 'a') as f:
        f.write(f'Block {block_num}: {x_max}\n')

    return



def error_log(sweep_config): 
    """Dump a failed experiment configuration to a timestamped JSON file.

    Args:
        sweep_config: JSON-serialisable configuration object.

    Side effects:
        Creates the output directory if it does not exist and writes
        ``config_<YYYYmmdd_HHMMSS>.json`` into it.
    """
    output_dir = "/home/user/regcache/failed_experiment_config"
    # Without this the open() below fails on a machine where the directory
    # has never been created.
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = f"config_{timestamp}.json"
    save_path = os.path.join(output_dir, file_name)
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(sweep_config, f, ensure_ascii=False, indent=4)
    print(f"Saved failed config to: {save_path}")
    return

def round_pass(x):
    """Round ``x`` while letting gradients pass through unchanged.

    The forward value is ``round(x)``; the rounding residual is detached, so
    the gradient with respect to ``x`` is the identity.
    """
    rounded = x.round()
    residual = (rounded - x).detach()
    return residual + x

class Quantizer():
    """Uniform quantizer returning integer-valued tensors plus their scale.

    Args:
        N_bits: Bit width. ``None`` disables quantization (inputs are returned
            unchanged with a scale of ``1``).
        q_type: Granularity of the scale: ``'per_tensor'`` (one scale),
            ``'per_token'`` (reduce over the last dim) or ``'per_channel'``
            (reduce over dim 0).
        signed: Use the signed range ``[-2^(N-1), 2^(N-1) - 1]`` instead of
            ``[0, 2^N - 1]``.
        symmetric: Scale by ``max|x| / Qp`` (no zero point). Otherwise an
            affine mapping of ``[min, max]`` onto ``[Qn, Qp]`` is used.
    """

    def __init__(self, N_bits: int, q_type: str, signed: bool = True, symmetric: bool = True):
        super().__init__()
        self.N_bits = N_bits
        self.signed = signed
        self.symmetric = symmetric
        self.q_type = q_type
        self.minimum_range = 1e-6

        # Qn / Qp are intentionally left undefined when quantization is off.
        if self.N_bits is None:
            return

        if self.signed:
            self.Qn = - 2 ** (self.N_bits - 1)
            self.Qp = 2 ** (self.N_bits - 1) - 1
        else:
            self.Qn = 0
            self.Qp = 2 ** self.N_bits - 1

    def __call__(self, x):
        return self.forward(x)

    def _reduction_dim(self):
        """Map ``q_type`` to the reduced dimension (``None`` = whole tensor).

        Raises:
            ValueError: If ``q_type`` is not a supported granularity.
        """
        if self.q_type == 'per_tensor':
            return None
        if self.q_type == 'per_token':
            return -1
        if self.q_type == 'per_channel':
            return 0
        raise ValueError(
            f"Unsupported q_type {self.q_type!r}; expected 'per_tensor', 'per_token' or 'per_channel'")

    def forward(self, x):
        """Quantize ``x``.

        Returns:
            ``(q_x, scale)`` where ``q_x`` holds integer values stored in the
            input dtype. With ``N_bits is None`` this is ``(x, 1)``.

        Raises:
            ValueError: If quantization is enabled and ``q_type`` is unknown.
        """
        if self.N_bits is None:
            return x, 1

        dim = self._reduction_dim()

        if self.symmetric:
            abs_x = x.abs()
            max_x = abs_x.max() if dim is None else abs_x.amax(dim=dim, keepdim=True)
            # Floor the range so an all-zero slice does not divide by zero.
            max_x = max_x.clamp(min=self.minimum_range)
            scale = max_x / self.Qp
            x = round_pass(x / scale)

        else:  # Asymmetric
            if dim is None:
                min_x = x.min().detach()
                max_x = x.max().detach()
            else:
                # amin/amax return plain tensors; x.min(dim=...) returns a
                # (values, indices) tuple that has no .detach().
                min_x = x.amin(dim=dim, keepdim=True).detach()
                max_x = x.amax(dim=dim, keepdim=True).detach()

            range_x = (max_x - min_x).clamp_(min=self.minimum_range)
            scale = range_x / (self.Qp - self.Qn)
            # Chosen so that min_x maps to Qn and max_x maps to Qp.
            zero_point = torch.round(self.Qn - (min_x / scale))
            x = (x / scale) + zero_point
            x = round_pass(x.clamp(self.Qn, self.Qp))

        return x, scale


class QuantAct(nn.Module):
    """``nn.Module`` wrapper that applies a ``Quantizer`` to its input."""

    def __init__(self,
                 N_bits: int,
                 q_type: str ,
                 signed: bool = True,
                 symmetric: bool = True):
        super().__init__()
        quantizer_kwargs = dict(N_bits=N_bits, q_type=q_type, signed=signed, symmetric=symmetric)
        self.quantizer = Quantizer(**quantizer_kwargs)

    def forward(self, x):
        """Return ``(quantized_x, scale)`` as produced by the quantizer."""
        quantized, scale = self.quantizer(x)
        return quantized, scale
class Quantized_Linear(nn.Linear):
    """``nn.Linear`` whose forward/backward run through ``_quantize_global``.

    Besides the four quantizers supplied by the caller, the layer owns a
    ``prefix_qmodule`` used for prefix tokens: a ``'per_token'`` quantizer at
    ``abits`` bits, or a quantizer with ``q_type=None`` when the activation
    quantizer has no ``q_type``.
    """

    def __init__(self, prefix_position, weight_quantize_module: Quantizer, act_quantize_module: Quantizer, weight_grad_quantize_module: Quantizer, act_grad_quantize_module: Quantizer,
                 in_features, out_features, abits, bias=True, clamping=False):
        super().__init__(in_features, out_features, bias=bias)
        self.prefix_position = prefix_position
        self.clamping = clamping

        self.weight_quantize_module = weight_quantize_module
        self.act_quantize_module = act_quantize_module
        self.weight_grad_quantize_module = weight_grad_quantize_module
        self.act_grad_quantize_module = act_grad_quantize_module

        if act_quantize_module.q_type is None:
            prefix_q_type = None
        else:
            prefix_q_type = 'per_token'
        self.prefix_qmodule = Quantizer(abits, prefix_q_type)

    def forward(self, input, block_num=None, prefix_token_num=0):
        """Apply the quantized linear map to ``input``."""
        fn_args = (
            self.prefix_position,
            block_num,
            prefix_token_num,
            input,
            self.weight,
            self.bias,
            self.weight_quantize_module,
            self.act_quantize_module,
            self.weight_grad_quantize_module,
            self.act_grad_quantize_module,
            self.prefix_qmodule,
            self.clamping,
        )
        return _quantize_global.apply(*fn_args)
    
class _quantize_global(torch.autograd.Function):
    """Custom autograd function for a linear layer with quantized operands.

    Forward quantizes the activation and the weight, multiplies the quantized
    values and rescales the product by the activation and weight scales.
    When ``prefix_token_num`` is non-zero, the prefix tokens are quantized
    with ``prefix_qmodule`` and the remaining tokens with ``a_qmodule``.
    Backward optionally quantizes the incoming gradient as well.
    """
    @staticmethod
    def forward(ctx, prefix_position, block_num, prefix_token_num, x, w, bias=None,
                w_qmodule=None, a_qmodule=None, w_g_qmodule=None, a_g_qmodule=None, prefix_qmodule=None, clamping=False):
        """Compute ``x @ w.T (+ bias)`` with quantized operands.

        Args:
            ctx: Autograd context; tensors needed by backward are stored on it
                as plain attributes.
            prefix_position: ``'front'``, ``'specific'``, ``'middle'`` or
                ``'back'``; where the prefix tokens sit along dim 1. Only read
                when ``prefix_token_num != 0``.
            block_num: Stored on ``ctx``; not otherwise used here.
            prefix_token_num: Number of prefix tokens (0 = none).
            x: 3-D activation, unpacked as ``(B, S, C)``.
            w: Weight, used as ``w.t()`` in the matmul.
            bias: Optional bias added to the output.
            w_qmodule, a_qmodule: Quantizers for weight / activation.
            w_g_qmodule, a_g_qmodule: Quantizers applied to the gradient in
                backward (for the weight- and activation-gradient products).
            prefix_qmodule: Quantizer applied to the prefix tokens.
            clamping: If true, clamp ``x`` before quantization.

        Returns:
            Tensor of shape ``(B, S, out_features)``.
        """
        #save for backward
        ctx.block_num = block_num
        ctx.a_g_qmodule = a_g_qmodule
        ctx.w_g_qmodule = w_g_qmodule 
        ctx.has_bias = bias is not None

        B, S, C = x.shape[0], x.shape[1], x.shape[2]
        ctx.x_size = B,S,C
        # NOTE(review): the clamp bounds come from the *gradient* quantizer.
        # Quantizer.__init__ does not define Qn/Qp when N_bits is None, and every
        # Quantized_Linear built in this file passes Quantizer(None, None) here, so
        # clamping=True would raise AttributeError in those setups -- confirm the
        # intended bounds.
        if clamping == True:
            x = torch.clamp(x, min=a_g_qmodule.Qn, max=a_g_qmodule.Qp)
        x_2d = x.view(-1, C)
    
        #full precision
        # Taken only when all four quantizers are disabled (N_bits is None).
        if all(x is None for x in (w_qmodule.N_bits, a_qmodule.N_bits, w_g_qmodule.N_bits, a_g_qmodule.N_bits)):
            ctx.fullprecision = True
            output = torch.matmul(x_2d, w.t())
            ctx.weight = w
            ctx.activation = x_2d
            if bias is not None:
                output += bias.unsqueeze(0).expand_as(output)
            return output.view(B, S, -1)
        else: 
            ctx.fullprecision = False

        #Quantization 
        # No prefix tokens: quantize the flattened activation in one go.
        if prefix_token_num == 0:
            input_quant, s_input_quant = a_qmodule(x_2d)
            weight_quant, s_weight_quant = w_qmodule(w)

            ctx.weight = (weight_quant, s_weight_quant)
            # Backward needs the quantized activation only when the weight
            # gradient is quantized; otherwise it uses the raw activation.
            ctx.activation = (input_quant, s_input_quant) if w_g_qmodule.N_bits is not None else x_2d
            output = torch.matmul(input_quant, weight_quant.t())
            s_o = s_input_quant * s_weight_quant
            output = output * s_o
            if bias is not None:
                output += bias.unsqueeze(0).expand_as(output)
            return output.view(B, S, -1)
        
        else: 
            # NOTE(review): in the branches below only 'per_token' and 'per_tensor'
            # activation scales are reshaped/expanded to [B, tokens]; other q_types
            # and unknown prefix_position values are not handled explicitly.
            if prefix_position == 'front' or prefix_position == 'specific': 
                # First prefix_token_num tokens along dim 1 (original note: this
                # count is meant to include the CLS token -- TODO confirm).
                prefix_token = x[:, :prefix_token_num, :]  # [B, prefix_token_num, C]
                patch_x = x[:, prefix_token_num:, :]  # [B, N, C]

                prefix_token = prefix_token.reshape(-1, C)  # [B * prefix_token_num, C]
                patch_x = patch_x.reshape(-1, C)            # [B * patch_num, C]

                q_prefix_token, s_prefix_token = prefix_qmodule(prefix_token)  # quantized registers
                q_patch_x, s_patch_x = a_qmodule(patch_x)                      # quantized patches

                q_prefix_token = q_prefix_token.reshape(B, -1, C)  # [B, prefix_token_num, C]
                q_patch_x = q_patch_x.reshape(B, -1, C)            # [B, patch_num, C]

                input_quant = torch.cat([q_prefix_token,
                                        q_patch_x
                                        ], dim=1)  # [B, patch_num + prefix_token_num, C]

                s_prefix_token = s_prefix_token.reshape(B, -1) # [B, prefix_token_num]
                if a_qmodule.q_type == 'per_token':
                    s_patch_x = s_patch_x.reshape(B, -1)                   # [B, patch_num]
                elif a_qmodule.q_type == 'per_tensor':
                    s_patch_x = s_patch_x.expand(B, S - prefix_token_num)  # broadcast to [B, patch_num]

                s_input_quant = torch.cat((s_prefix_token, s_patch_x), dim=1)  # [B, total_token]
                
            elif prefix_position == 'middle': 
                # Prefix tokens sit between the two halves of the remaining tokens.
                half_size = (S - prefix_token_num) // 2  # Example: 98
                prefix_token = x[:, half_size:half_size + prefix_token_num, :]  # [B, prefix_token_num, C]
                first_half, second_half = x[:, :half_size, :], x[:, half_size + prefix_token_num:, :]
                patch_x = torch.cat([first_half, second_half], dim=1)  # [B, N, C]

                prefix_token = prefix_token.reshape(-1, C)  # [B * prefix_token_num, C]
                patch_x = patch_x.reshape(-1, C)            # [B * patch_num, C]

                q_prefix_token, s_prefix_token = prefix_qmodule(prefix_token)  # quantized registers
                q_patch_x, s_patch_x = a_qmodule(patch_x)                      # quantized patches

                q_prefix_token = q_prefix_token.reshape(B, -1, C)  # [B, prefix_token_num, C]
                q_patch_x = q_patch_x.reshape(B, -1, C)            # [B, patch_num, C]

                # Re-insert the prefix tokens at their original position.
                input_quant = torch.cat([
                                        q_patch_x[:, :half_size, :], 
                                        q_prefix_token, 
                                        q_patch_x[:, half_size:, :]
                                        ], dim=1)  # [B, patch_num + prefix_token_num, C]
                s_prefix_token = s_prefix_token.reshape(B, -1) # [B, prefix_token_num]
                if a_qmodule.q_type == 'per_token':
                    s_patch_x = s_patch_x.reshape(B, -1)                   # [B, patch_num]
                elif a_qmodule.q_type == 'per_tensor':
                    s_patch_x = s_patch_x.expand(B, S - prefix_token_num)  # broadcast to [B, patch_num]

                s_input_quant = torch.cat((s_patch_x[:, :half_size], s_prefix_token, s_patch_x[:, half_size:]), dim=1)  # [B, total_token]
        
                
            elif prefix_position == 'back': 
                # Prefix tokens are the last prefix_token_num tokens along dim 1.
                # The bracketed shapes below are example values from the original author.
                prefix_token = x[:, -(prefix_token_num):, :] #[32, 9, 768]
                patch_x = x[:, :-(prefix_token_num), :] #[32, 196, 768]

                prefix_token = prefix_token.reshape(-1, C) #[288, 768]
                patch_x = patch_x.reshape(-1, C) #[6272, 768]
                
                q_prefix_token, s_prefix_token = prefix_qmodule(prefix_token) #per-token 
                q_patch_x, s_patch_x = a_qmodule(patch_x)
                
                q_prefix_token = q_prefix_token.reshape(B,-1, C)
                q_patch_x = q_patch_x.reshape(B,-1,C)
                input_quant = torch.cat((q_patch_x, q_prefix_token), dim=1)
                if a_qmodule.q_type == 'per_token':
                    s_prefix_token = s_prefix_token.reshape(B, -1) #[32, 9]
                    s_patch_x = s_patch_x.reshape(B, -1) #[32, 196]
                    s_input_quant = torch.cat((s_patch_x, s_prefix_token),dim=1) # Here!
                elif a_qmodule.q_type == 'per_tensor':
                    s_prefix_token = s_prefix_token.reshape(B, -1) #[32, 9]
                    s_patch_x = s_patch_x.expand(B, S-prefix_token_num)
                    s_input_quant = torch.cat((s_patch_x, s_prefix_token), dim=1) # Here!
                    

            ctx.activation = (input_quant, s_input_quant) if w_g_qmodule.N_bits is not None else x_2d
            weight_quant, s_weight_quant = w_qmodule(w) 
            ctx.weight = (weight_quant, s_weight_quant)

            # Per-token output scale: weight scale times each token's activation scale.
            s_o = s_weight_quant * s_input_quant
            output = torch.matmul(input_quant, weight_quant.t())
            output = output.view(B, S, -1) #[32, 205, 768])
            s_o = s_o.unsqueeze(-1).expand(-1, -1, output.shape[2]) 
            output = output * s_o
            if bias is not None:
                output += bias.unsqueeze(0).expand_as(output)
            return output

    @staticmethod
    def backward(ctx, g_3D):
        """Compute gradients for ``x``, ``w`` and ``bias``.

        Returns one entry per forward argument (12 in total); every
        non-tensor argument gets ``None``.
        """
        g_2D = g_3D.reshape(-1, g_3D.size(-1)) #reshape to 2D
        grad_X = grad_W = grad_bias = None
        B,S,C = ctx.x_size
        
        if ctx.fullprecision:
            # Plain linear-layer gradients using the saved float tensors.
            w = ctx.weight
            x = ctx.activation
            grad_W = torch.matmul(g_2D.t(), x)
            grad_X = torch.matmul(g_2D, w)
            grad_X = grad_X.view(B,S,-1)
            if ctx.has_bias:
                grad_bias = g_2D.sum(dim=0)
            else:
                grad_bias = None
        else:
            q_w, s_w = ctx.weight
            a_g_qmodule = ctx.a_g_qmodule
            w_g_qmodule = ctx.w_g_qmodule
            a_g_2D_quant, a_s_g_2D_quant = a_g_qmodule(g_2D)
            
            # ctx.activation is a (quantized, scale) pair only when the weight
            # gradient is quantized; otherwise it is the raw 2-D activation.
            if w_g_qmodule.N_bits is not None:
                w_g_2D_quant, w_s_g_2D_quant = w_g_qmodule(g_2D)
                (q_x, s_x) = ctx.activation
            else:
                x = ctx.activation
            # Input gradient: (quantized) gradient times the quantized weight,
            # rescaled by both scales.
            grad_X = torch.matmul(a_g_2D_quant, q_w)
            s_grad_X = a_s_g_2D_quant * s_w
            grad_X = grad_X * s_grad_X

            # Weight gradient
            if w_g_qmodule.N_bits is None: 
                grad_W = torch.matmul(g_2D.t(), x)
            else: 
                q_x = q_x.reshape(-1,q_x.size(-1))
                s_x = s_x.reshape(-1,s_x.size(-1))
                grad_W = torch.matmul(w_g_2D_quant.t(), q_x) #([768, 3072])
                s_grad_W = w_s_g_2D_quant * s_x
                grad_W = grad_W * s_grad_W

            if ctx.has_bias:
                grad_bias = g_2D.sum(dim=0)
            else:
                grad_bias = None
            grad_X = grad_X.view(B,S,-1)
        return None, None, None, grad_X, grad_W, grad_bias, None, None, None, None, None, None

class Mlp(nn.Module):
    """Two-layer feed-forward block built from ``Quantized_Linear`` layers.

    Optionally drops the ``token_delete_number`` highest-magnitude tokens
    between the activation and the second linear layer, as configured by
    ``sweep_config``.

    Args:
        sweep_config: Dict with ``token_delete``, ``token_delete_block`` and
            ``token_delete_number`` keys, or ``None`` to disable deletion.
        prefix_position: Forwarded to both ``Quantized_Linear`` layers.
        block_num: Index of the owning block; compared against
            ``token_delete_block``.
        w_quant_type, a_quant_type: Quantizer granularity for weights /
            activations.
        wg_quant_type, ag_quant_type, w_gbits, a_gbits: Accepted for interface
            compatibility; gradient quantizers are always built disabled.
        abits, wbits: Activation / weight bit widths.
        in_features: Input (and output) feature size.
        hidden_features: Hidden size; defaults to ``4 * in_features``.
        act_layer: Activation module class.
        clamping: Forwarded to both ``Quantized_Linear`` layers.
    """

    def __init__(
            self,
            sweep_config,
            prefix_position,
            block_num,
            w_quant_type, wg_quant_type,
            a_quant_type, ag_quant_type,
            abits, 
            wbits,
            w_gbits, 
            a_gbits,
            in_features,
            hidden_features=None,
            act_layer=nn.GELU, 
            clamping=False):
        super().__init__()
        self.block_num = block_num
        self.sweep_config = sweep_config
        out_features = in_features
        hidden_features = hidden_features or in_features * 4

        def make_linear(n_in, n_out):
            # Each layer gets its own quantizer instances; gradient quantizers
            # are disabled (N_bits=None).
            return Quantized_Linear(
                prefix_position=prefix_position,
                weight_quantize_module=Quantizer(wbits, w_quant_type),
                act_quantize_module=Quantizer(abits, a_quant_type),
                weight_grad_quantize_module=Quantizer(None, None),
                act_grad_quantize_module=Quantizer(None, None),
                in_features=n_in,
                out_features=n_out,
                abits=abits,
                bias=True,
                clamping=clamping,
            )

        self.fc1 = make_linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = make_linear(hidden_features, out_features)

    def forward(self, x, epoch=None, iteration=None, device_id=None, prefix_token_num=0):
        """Run the MLP.

        Returns:
            ``(output, indices_to_delete)``. ``indices_to_delete`` is ``None``
            unless token deletion ran for this block, in which case it holds
            the top-k indices that were removed (shape ``[B, k]``).
        """
        indices_to_delete = None
        x = self.fc1(x)
        x = self.act(x)

        cfg = self.sweep_config
        # cfg may be None (e.g. when the MLP is used outside a Q_Block); treat
        # that as "token deletion disabled" instead of crashing on subscript.
        if cfg and cfg['token_delete'] and self.block_num in cfg['token_delete_block']:
            topk = cfg['token_delete_number']
            B, N, C = x.shape
            # Rank tokens by their max |activation|, skipping token 0.
            # NOTE(review): the resulting indices are relative to x[:, 1:], but
            # are applied to the full-length mask below without a +1 offset.
            # Left unchanged because the returned indices are also consumed by
            # delete_residual_tokens2() outside this file -- confirm intent.
            token_max_norms, _ = torch.max(x[:, 1:, :].abs(), dim=-1)
            _, indices_to_delete = torch.topk(token_max_norms, k=topk, dim=-1)
            keep_mask = torch.ones(B, N, dtype=torch.bool, device=x.device)
            keep_mask.scatter_(1, indices_to_delete, False)
            num_kept_tokens = N - topk
            x = torch.masked_select(x, keep_mask.unsqueeze(-1)).view(B, num_kept_tokens, C)

        x = self.fc2(x)

        return x,indices_to_delete

class Attention(nn.Module):
    """Multi-head self-attention with quantized QKV / output projections.

    Args:
        sweep_config: Experiment dict; ``prefix_add``, ``cache_option``,
            ``prefix_add_block`` and ``prefix_number`` are read here. It is
            also kept as the fallback config for ``forward``.
        prefix_position: Forwarded to the ``Quantized_Linear`` layers.
        block_num: Index of the owning block.
        w_quant_type, a_quant_type: Quantizer granularity for weights /
            activations.
        wg_quant_type, ag_quant_type, w_gbits, a_gbits: Accepted for interface
            compatibility; gradient quantizers are always built disabled.
        abits, wbits: Activation / weight bit widths.
        dim: Embedding dimension.
        num_heads: Number of attention heads.
        qkv_bias: Whether the QKV projection has a bias.
        clamping: Forwarded to the ``Quantized_Linear`` layers.
    """

    def __init__(
            self,
            sweep_config,
            prefix_position,
            block_num,
            w_quant_type, wg_quant_type,
            a_quant_type, ag_quant_type,
            abits, 
            wbits, 
            w_gbits,
            a_gbits,
            dim,
            num_heads,
            qkv_bias=True,
            clamping=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.block_num = block_num
        # Kept so forward() still works when no sweep_config is passed in.
        self.sweep_config = sweep_config

        def make_linear(n_out, use_bias):
            # Gradient quantizers are disabled (N_bits=None).
            return Quantized_Linear(
                prefix_position=prefix_position,
                weight_quantize_module=Quantizer(wbits, w_quant_type),
                act_quantize_module=Quantizer(abits, a_quant_type),
                weight_grad_quantize_module=Quantizer(None, None),
                act_grad_quantize_module=Quantizer(None, None),
                in_features=dim,
                out_features=n_out,
                abits=abits,
                bias=use_bias,
                clamping=clamping,
            )

        self.qkv = make_linear(dim * 3, qkv_bias)
        self.proj = make_linear(dim, True)

        self.index_correction = 0
        if sweep_config['prefix_add']:
            prefix_blocks = sweep_config['prefix_add_block']
            # Guard against an empty block list: min([]) would raise.
            if (sweep_config['cache_option'] == 'Token_Insert'
                    and prefix_blocks
                    and self.block_num >= min(prefix_blocks)):
                self.index_correction = sweep_config['prefix_number']
            elif sweep_config['cache_option'] == 'KV_Cache' and self.block_num in prefix_blocks:
                self.index_correction = sweep_config['prefix_number']

    def forward(self, x, epoch=None, iteration=None, device_id=None, sweep_config=None, dynamic_delete_index=None):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``.

        Args:
            x: Input tokens.
            epoch, iteration, device_id, dynamic_delete_index: Unused here.
            sweep_config: Experiment dict; defaults to the one given at
                construction time.

        Returns:
            ``(output, predicted_delete_index)``; the second element is
            always ``None`` in this implementation.
        """
        if sweep_config is None:
            sweep_config = self.sweep_config
        predicted_delete_index = None

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q = q * self.scale
        q, k, v = apply_prefix_cache(q, k, v, sweep_config, self.block_num, qkv_cache_path_template=QKV_CACHE_PATH_TEMPLATE, qkv_cache_files=QKV_CACHE_FILES)

        # The result is currently unused; the call is kept because the helper
        # lives outside this file and may validate the configuration.
        calculate_prediction_blocks(sweep_config)

        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)

        x = (attn @ v).transpose(1, 2)
        x = x.reshape(B, N, C)
        x = self.proj(x)

        return x, predicted_delete_index

class Q_Block(nn.Module):
    """Pre-norm transformer block with quantized attention and MLP.

    Args:
        timm_kwargs: Dict of quantization and experiment options. Quantization
            keys: ``wbits``/``abits`` (default 8), ``w_quant_type``/
            ``a_quant_type`` (default ``'per_tensor'``), ``wg_quant_type``,
            ``ag_quant_type``, ``w_gbits``, ``a_gbits`` (default ``None``) and
            ``prefix_position`` (default ``'front'``). The experiment keys are
            copied into ``self.sweep_config``.
        block_num: Index of this block.
        dim: Embedding dimension.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden size of the MLP as a multiple of ``dim``.
        act_layer: Activation module class for the MLP.
        norm_layer: Normalisation layer class.
        clamping: Forwarded to the quantized linear layers.
    """

    def __init__(self, timm_kwargs, block_num, dim, num_heads, mlp_ratio=4.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, clamping=False):
        super().__init__()
        wbits = timm_kwargs.get('wbits', 8)
        abits = timm_kwargs.get('abits', 8)
        w_quant_type = timm_kwargs.get('w_quant_type', 'per_tensor')
        a_quant_type = timm_kwargs.get('a_quant_type', 'per_tensor')

        wg_quant_type = timm_kwargs.get('wg_quant_type', None)
        ag_quant_type = timm_kwargs.get('ag_quant_type', None)
        w_gbits = timm_kwargs.get('w_gbits', None)
        a_gbits = timm_kwargs.get('a_gbits', None)

        prefix_position = timm_kwargs.get('prefix_position', 'front')

        # (key, default) pairs of the experiment options shared with the
        # attention and MLP sub-modules.
        sweep_defaults = (
            ("prefix_add", False),
            ("cache_option", None),
            ("prefix_add_block", []),
            ("prefix_number", 0),
            ("token_delete", False),
            ("token_delete_block", None),
            ("token_delete_number", None),
            ("token_delete_previous_layer", None),
            ("token_delete_method", None),
            ("head_index_for_score", None),
            ("delete_option", None),
            ("global_prefix_rank", None),
            ("target_block", None),
            ("target_layer", None),
        )
        self.sweep_config = {key: timm_kwargs.get(key, default) for key, default in sweep_defaults}

        self.norm1 = norm_layer(dim)
        self.block_num = block_num
        self.attn = Attention(
            self.sweep_config,
            prefix_position,
            block_num,
            w_quant_type, wg_quant_type,
            a_quant_type, ag_quant_type,
            abits,
            wbits,
            w_gbits,
            a_gbits,
            dim,
            num_heads=num_heads,
            clamping=clamping,
        )
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            self.sweep_config,
            prefix_position,
            block_num,
            w_quant_type, wg_quant_type,
            a_quant_type, ag_quant_type,
            abits,
            wbits,
            w_gbits,
            a_gbits,
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            clamping=clamping
        )
        self.index_correction = 0
        if self.sweep_config['prefix_add'] and self._is_token_insert_block():
            self.index_correction = self.sweep_config['prefix_number']

    def _is_token_insert_block(self):
        """Return True if 'Token_Insert' prefixes are present at this block."""
        prefix_blocks = self.sweep_config['prefix_add_block']
        # An empty/None block list means no block inserts prefixes; without
        # this check min() would raise.
        return bool(
            self.sweep_config['cache_option'] == 'Token_Insert'
            and prefix_blocks
            and self.block_num >= min(prefix_blocks)
        )

    def forward(self, x, dynamic_delete_index=None):
        """Run attention and MLP with residual connections.

        Args:
            x: Input tokens.
            dynamic_delete_index: Forwarded to the attention module.

        Returns:
            ``(output, predicted_delete_index)`` where the second element is
            whatever the attention module returned.
        """
        residual_1 = x
        x = self.norm1(x)
        x, predicted_delete_index = self.attn(x, sweep_config=self.sweep_config, dynamic_delete_index=dynamic_delete_index)

        # TODO: Prefix Add in residual
        # residual_1 = add_residual_prefix(residual_1, self.sweep_config, self.block_num, x, q_block_input_path_template=Q_BLOCK_INPUT_PATH_TEMPLATE, q_block_prefix_files=Q_BLOCK_PREFIX_FILES)

        # TODO: token delete in residual
        # residual_1 = delete_residual_tokens(residual_1, self.sweep_config, self.block_num, self.index_correction, x,
        #                                     random_outlier_index1=RANDOM_OUTLIER_INDEX1, random_outlier_index2=RANDOM_OUTLIER_INDEX2, random_outlier_index3=RANDOM_OUTLIER_INDEX3,
        #                                     predicted_delete_index=predicted_delete_index, dynamic_delete_index=dynamic_delete_index, error_log=error_log)
        x = residual_1 + x
        residual_2 = x
        x = self.norm2(x)

        if self._is_token_insert_block():
            prefix_number_in_certain_block = self.sweep_config['prefix_number']
        else:
            prefix_number_in_certain_block = 0

        x, delete_index = self.mlp(x, prefix_token_num=prefix_number_in_certain_block)
        # If the MLP dropped tokens, drop the same ones from the residual so
        # the shapes match for the addition below.
        if delete_index is not None: 
            residual_2 = delete_residual_tokens2(residual_2, delete_index)
        x = residual_2 + x
        return x, predicted_delete_index


class AttentionPoolLatent(nn.Module):
    """Attention pooling with learned latent queries.

    A set of ``latent_len`` learned query vectors attends over the input
    tokens; the result goes through an output projection and a residual MLP
    and is then pooled according to ``pool_type`` (``'token'`` takes the
    first latent, ``'avg'`` averages the latents, anything else returns all
    latents).

    The MLP is this file's quantized ``Mlp`` built with every quantizer
    disabled and token deletion turned off.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 8,
            feat_size: Optional[int] = None,
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            latent_len: int = 1,
            latent_dim: int = None,
            pos_embed: str = '',
            pool_type: str = 'token',
            norm_layer: Optional[nn.Module] = None,
            act_layer: Optional[nn.Module] = nn.GELU,
            drop: float = 0.0,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.feat_size = feat_size
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()

        if pos_embed == 'abs':
            assert feat_size is not None
            self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        if qk_norm: 
            qk_norm_layer = norm_layer or nn.LayerNorm
            self.q_norm = qk_norm_layer(self.head_dim)
            self.k_norm = qk_norm_layer(self.head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)

        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
        self.mlp = Mlp(
            # Mlp.forward subscripts its config, so pass an explicit
            # "token deletion disabled" dict rather than None.
            sweep_config={'token_delete': False},
            prefix_position=None,
            block_num=None,
            w_quant_type=None, wg_quant_type=None,
            a_quant_type=None, ag_quant_type=None,
            abits=None, 
            wbits=None, 
            w_gbits=None, 
            a_gbits=None,
            in_features=embed_dim,
            hidden_features=int(embed_dim * mlp_ratio),
            act_layer=act_layer,
            clamping=False
        )

        self.init_weights()

    def init_weights(self):
        """Initialise the positional embedding (if any) and the latent queries."""
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)

    def forward(self, x):
        """Pool ``x`` of shape ``(B, N, C)``.

        Returns:
            ``(B, C)`` for ``pool_type`` ``'token'`` or ``'avg'``; otherwise
            the un-pooled ``(B, latent_len, C)`` tensor.
        """
        B, N, C = x.shape

        if self.pos_embed is not None:
            # FIXME interpolate
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        q_latent = self.latent.expand(B, -1, -1)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        # Mlp.forward returns (output, deleted_indices); only the output takes
        # part in the residual connection.
        mlp_out, _ = self.mlp(self.norm(x), prefix_token_num=0)
        x = x + mlp_out

        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(1)
        return x

class CustomSequential(nn.Sequential):
    """Sequential container for blocks that return ``(x, delete_index)``.

    Each block is called as ``block(x, dynamic_delete_index=...)`` and must
    return a 2-tuple; the second element is handed to the next block.
    """

    def __init__(self, *modules):
        super(CustomSequential, self).__init__()
        for idx, module in enumerate(modules):
            self.add_module(str(idx), module)

    def forward(self, x, epoch=None, iteration=None, device_id=None, prefix_token_num=0, sweep_config=None):
        """Run ``x`` through every block in order and return the final tensor.

        Args:
            x: Input tensor.
            epoch, iteration, device_id, prefix_token_num, sweep_config:
                Accepted for caller compatibility but not forwarded; the
                block class in this file (``Q_Block``) takes only
                ``(x, dynamic_delete_index)`` and holds its own config.
        """
        dynamic_delete_index = None

        for module in self._modules.values():
            # Forwarding the extra positional arguments would raise TypeError
            # against Q_Block.forward(x, dynamic_delete_index=None).
            x, dynamic_delete_index = module(x, dynamic_delete_index=dynamic_delete_index)

        return x


class lowbit_VisionTransformer(VisionTransformer):
    """``VisionTransformer`` whose block stack is a ``CustomSequential`` of ``Q_Block``.

    The parent constructor is run first; this subclass then installs its own
    classification head and replaces ``self.blocks``. ``forward`` and
    ``forward_features`` accept extra bookkeeping arguments (``epoch``,
    ``iteration``, ``device_id``) that are passed through to every block.
    """

    def __init__(self, pos_embed_order, w_quant_type, wg_quant_type, a_quant_type, ag_quant_type,
                 register_num, num_classes, abits, wbits, w_gbits, a_gbits,
                 patch_size, embed_dim, depth, num_heads,
                 norm_layer, sweep_config, mlp_ratio=4, qkv_bias=True, class_token=True,
                 act_layer=None, global_pool=None, prefix_type=None, embed_prefix=True,
                 prefix_position=None, clamping=False, pretrained_cfg=None,
                 pretrained_cfg_overlay=None, **kwargs):
        """Build the parent ViT, then swap in the custom head and blocks.

        Args:
            pos_embed_order: Stored on the instance as ``self.pos_embed_order``.
            w_quant_type, wg_quant_type, a_quant_type, ag_quant_type: Forwarded
                to every ``Q_Block``; ``a_quant_type`` is also stored.
            register_num: Stored as ``self.prefix_token_num`` and passed to the
                blocks on every forward call.
            num_classes: Output size of ``self.head``; a value that is not
                ``> 0`` installs ``nn.Identity``.
            abits, wbits, w_gbits, a_gbits: Forwarded to every ``Q_Block``;
                ``abits`` is also stored.
            patch_size, embed_dim, depth, num_heads, norm_layer, mlp_ratio,
                qkv_bias: Forwarded to the parent constructor. ``embed_dim``,
                ``num_heads`` and ``mlp_ratio`` are also forwarded to every
                ``Q_Block``, and ``depth`` is the number of blocks built.
            sweep_config: Stored, handed to every ``Q_Block`` at construction
                and passed to the blocks on every forward call.
            act_layer: Resolved via ``get_act_layer`` but not forwarded.
            class_token, global_pool, pretrained_cfg, pretrained_cfg_overlay:
                Accepted for signature compatibility; not used here.
            prefix_type, embed_prefix, prefix_position: Stored on the instance;
                ``prefix_position`` is also forwarded to every ``Q_Block``.
            clamping: Forwarded to every ``Q_Block``.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(patch_size=patch_size,
                         embed_dim=embed_dim,
                         depth=depth,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
                         norm_layer=norm_layer,
                         **kwargs)

        # NOTE(review): the resolved activation is not forwarded to Q_Block or
        # to the parent constructor; the call is kept so an unknown activation
        # name is still rejected at construction time.
        act_layer = get_act_layer(act_layer) or nn.GELU
        self.sweep_config = sweep_config
        self.pos_embed_order = pos_embed_order
        self.prefix_position = prefix_position
        self.prefix_token_num = register_num
        self.prefix_type = prefix_type
        self.reg_token = None
        self.embed_prefix = embed_prefix

        self.a_quant_type = a_quant_type
        self.abits = abits
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        # num_classes is not passed to the parent constructor, so the parent
        # records its own default; keep the attribute in step with the head
        # that is actually installed above.
        self.num_classes = num_classes

        print(f"🧬 Prefix token ({prefix_type}): {register_num}")
        print(f"🧩 Embedding prefix token: {self.embed_prefix}")
        print(f"🧱 Quantization Settings:\n"
            f"   - Activation quantization: {a_quant_type} in {abits} bit.\n"
            f"   - Weight quantization: {w_quant_type} in {wbits} bit.")

        # Replace the block stack built by the parent with Q_Block modules.
        self.blocks = CustomSequential(*[
            Q_Block(sweep_config, self.prefix_position, w_quant_type, wg_quant_type, a_quant_type, ag_quant_type, abits, wbits, w_gbits, a_gbits, block_num=i, dim=embed_dim,
                    num_heads=num_heads, mlp_ratio=mlp_ratio, clamping=clamping)
            for i in range(depth)])

    def forward_features(self, x, epoch, iteration, device_id):
        """Embed ``x``, run it through the block stack and apply the final norm.

        ``epoch``, ``iteration`` and ``device_id`` are passed through to the
        blocks together with ``self.prefix_token_num`` and ``self.sweep_config``.
        """
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        x = self.blocks(x, epoch, iteration, device_id, self.prefix_token_num, self.sweep_config)
        # x = x[:, self.sweep_config['prefix_number']:,:]
        x = self.norm(x)
        return x

    def forward(self, x, siglip=None, epoch=None, iteration=None, device_id=None):
        """Return ``forward_head(forward_features(x, ...))``.

        ``siglip`` is accepted but not used. When the module-level ``PROBE`` or
        ``HEATMAP_PROBE`` flag is set, the process terminates right after
        feature extraction.

        Raises:
            SystemExit: If ``PROBE`` or ``HEATMAP_PROBE`` is truthy.
        """
        x = self.forward_features(x, epoch, iteration, device_id)
        if PROBE or HEATMAP_PROBE:
            # Raise SystemExit directly instead of calling the ``exit()``
            # helper, which the ``site`` module injects for interactive use
            # and which is absent under ``python -S`` or in frozen builds.
            raise SystemExit

        return self.forward_head(x)
     


#✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨
@register_model
def fullbits_vit_base_patch16_clip_224(
    pos_embed_order=None,
    wbits=None,
    abits=None,
    w_quant_type=None,
    wg_quant_type=None,
    a_quant_type=None,
    ag_quant_type=None,
    register_num=0,
    num_classes=0,
    prefix_position=None,
    prefix_type=None,
    embed_prefix=True,
    **kwargs
):
    """Build a ``lowbit_VisionTransformer`` with fixed ViT-B/16 dimensions.

    The architecture is hard-wired here: ``patch_size=16``, ``embed_dim=768``,
    ``depth=12``, ``num_heads=12``, ``pre_norm=True`` and a ``LayerNorm`` with
    ``eps=1e-5``. ``w_gbits`` and ``a_gbits`` are always ``None`` and
    ``clamping`` is always ``False``.

    Only the sweep-related keys listed below are read from ``**kwargs``
    (missing ones become ``None``); every other extra keyword is ignored.
    The returned model has ``default_cfg`` set to ``_cfg()``.
    """
    sweep_keys = (
        "prefix_add",
        "cache_option",
        "prefix_add_block",
        "prefix_number",
        "token_delete",
        "token_delete_block",
        "token_delete_number",
        "token_delete_previous_layer",
        "token_delete_method",
        "head_index_for_score",
        "delete_option",
        "global_prefix_rank",
        "target_block",
        "target_layer",
    )
    sweep_config = {key: kwargs.get(key, None) for key in sweep_keys}

    layer_norm = partial(nn.LayerNorm, eps=1e-5)
    model = lowbit_VisionTransformer(
        pos_embed_order=pos_embed_order,
        w_quant_type=w_quant_type,
        wg_quant_type=wg_quant_type,
        a_quant_type=a_quant_type,
        ag_quant_type=ag_quant_type,
        register_num=register_num,
        num_classes=num_classes,
        abits=abits,
        wbits=wbits,
        w_gbits=None,
        a_gbits=None,
        prefix_type=prefix_type,
        embed_prefix=embed_prefix,
        prefix_position=prefix_position,
        clamping=False,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        pre_norm=True,
        norm_layer=layer_norm,
        sweep_config=sweep_config,
    )
    model.default_cfg = _cfg()
    return model