import re
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..config import cfg
from ..module import check_skip_layers
# from transformers.activations import ACT2FN
# from pruning_module import HiddenRepresentationPruning
# from hf.utils import generate_probe, check_nan_inf, get_next_layer  
from prismatic.vla.constants import (
    ACTION_DIM,
    ACTION_PROPRIO_NORMALIZATION_TYPE,
    ACTION_TOKEN_BEGIN_IDX,
    IGNORE_INDEX,
    NUM_ACTIONS_CHUNK,
    STOP_INDEX,
    NormalizationType,
)
from .triton_linear import (
    indices_linear, 
    up_channel_level_indices_linear, 
    down_channel_level_indices_linear, 
    down_channel_transpose_level_indices_linear, 
    original_linear, 
)
# import torch.cuda.nvtx as nvtx
from experiments.robot.libero.run_libero_eval import GenerateConfig as generate_cfg

class LlamaPPModel(torch.nn.Module):
    def __init__(self, model):
        """Wrap ``model`` and immediately install the prunable layers on it.

        Args:
            model: The module to wrap. It is stored as ``self.model`` and is
                modified in place by :meth:`add_pruner`.
        """
        super().__init__()
        self.model = model
        # Calling the wrapper dispatches straight to the wrapped model's
        # forward; the wrapper adds no computation of its own.
        self.forward = model.forward
        self.add_pruner()

    def add_pruner(self):
        """Replace the targeted layers, then freeze every parameter of the wrapped model."""
        self._find_and_replace()
        mark_no_trainable(self.model)
    
    def _create_new_module(self, target, key):
        """Build the prunable ``Linear`` that will stand in for ``target``.

        Args:
            target: The module being replaced; must be a ``torch.nn.Linear``.
            key: Dotted name of ``target`` inside the wrapped model. It is
                passed on to the new layer and used to decide whether a bias
                is forced.

        Returns:
            A new ``Linear`` with the same ``in_features``/``out_features``
            as ``target``. Weights are not copied here.

        Raises:
            ValueError: If ``target`` is not a ``torch.nn.Linear``.
        """
        # Validate before touching ``target.weight`` / ``target.bias`` so an
        # unsupported module always fails with the intended ValueError rather
        # than an incidental AttributeError.
        if not isinstance(target, torch.nn.Linear):
            raise ValueError(
                f"Target module {target} is not supported. "
                f"Currently, only `torch.nn.Linear` is supported."
            )

        # A bias is only allocated under static pruning: either the old layer
        # had one, or the layer is one of the projections listed below.
        force_bias_suffixes = ("mlp.down_proj", "self_attn.o_proj")
        # NOTE(review): ``generate_cfg`` is referenced as imported, so this
        # reads whatever ``static_prune`` resolves to on that object.
        if generate_cfg.static_prune == True:  # noqa: E712 - keep equality semantics
            want_bias = target.bias is not None or key.endswith(force_bias_suffixes)
        else:
            want_bias = False

        lm_config = self.model.language_model.config
        kwargs = {
            "prune_metric": cfg["prune_metric"],
            "key": key,
            "dev": target.weight.device,
            # Grouped-query attention: fewer key/value heads than query heads.
            "is_GQA": lm_config.num_key_value_heads != lm_config.num_attention_heads,
        }
        return Linear(target.in_features, target.out_features, bias=want_bias, **kwargs)


    def _find_and_replace(self):
        """Swap every targeted submodule of the wrapped model for a prunable ``Linear``.

        A submodule is replaced when its name matches the configured target
        modules and ``check_skip_layers`` does not exclude it. Nothing is
        replaced when the configured prune method contains ``"dense"``,
        ``"llmpruner"`` or ``"loraprune"``. If no submodule was replaced, a
        message is printed.
        """
        # Snapshot the names first: modules are replaced while we walk them.
        module_names = [name for name, _ in self.model.named_modules()]
        target_modules = _get_target_modules(cfg)

        prune_method = cfg["prune_method"]
        skip_everything = (
            "dense" in prune_method
            or "llmpruner" in prune_method
            or "loraprune" in prune_method
        )

        replaced_any = False
        for name in module_names:
            if (
                skip_everything
                or not _check_target_module_exists(target_modules, name)
                or check_skip_layers(name)
            ):
                continue

            replaced_any = True
            parent, old_module, child_name = _get_submodules(self.model, name)
            replacement = self._create_new_module(old_module, name)
            self._replace_module(parent, child_name, replacement, old_module)

        if not replaced_any:
            print(
                f"Target modules {target_modules} not found in the base model. "
                f"Please check the target modules and try again."
            )

    def _replace_module(self, parent_module, child_name, new_module, old_module):
        """Install ``new_module`` on ``parent_module`` and hand it the old layer's weights.

        Args:
            parent_module: Module that owns the layer being replaced.
            child_name: Attribute name of the layer on ``parent_module``.
            new_module: Replacement layer.
            old_module: Layer being replaced; its weight (and, in the
                non-transposed case, its bias) is reused.
        """
        setattr(parent_module, child_name, new_module)

        store_transposed = (
            ("use_triton" in cfg and cfg["use_triton"])
            and ("transpose_down_linear" in cfg and cfg["transpose_down_linear"])
            and child_name in ("o_proj", "down_proj")
        )
        if store_transposed:
            # Keep a contiguous transposed copy of the weight; the bias of
            # ``new_module`` is left as constructed in this branch.
            transposed = old_module.weight.permute(1, 0).contiguous()
            new_module.weight = nn.Parameter(transposed)
        else:
            # Share the original parameter objects rather than copying them.
            new_module.weight = old_module.weight
            new_module.bias = old_module.bias

        new_module.weight.requires_grad = False
        new_module.device = old_module.weight.device
        new_module.is_pruned = True

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.model, name)


def mark_no_trainable(model: nn.Module) -> None:
    """Turn off gradient tracking for every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = False

def _get_submodules(model, key):
    parent = model.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target = model.get_submodule(key)
    return parent, target, target_name

def _get_target_modules(cfg):
    target_modules = cfg["cust_tgt_modules"]
    return target_modules

def _check_target_module_exists(target_modules, key):
    if isinstance(target_modules, str):
        target_module_found = re.fullmatch(target_modules, key)
    else:
        target_module_found = any(key.endswith(target_key) for target_key in target_modules)
    return target_module_found

def indices0_weight(weight: torch.nn.Parameter, indices):
    """Return a frozen parameter holding the entries of ``weight`` selected along dim 0.

    The selected data is cloned, so the result does not share storage with
    ``weight``. The CUDA caching allocator is asked to release unused memory
    before returning.
    """
    selected = weight.data[indices].clone()
    pruned = torch.nn.Parameter(selected, requires_grad=False)
    torch.cuda.empty_cache()
    return pruned

def indices1_weight(weight: torch.nn.Parameter, indices):
    """Return a frozen parameter holding the entries of ``weight`` selected along dim 1.

    The selected data is cloned, so the result does not share storage with
    ``weight``. The CUDA caching allocator is asked to release unused memory
    before returning.
    """
    selected = weight.data[:, indices].clone()
    pruned = torch.nn.Parameter(selected, requires_grad=False)
    torch.cuda.empty_cache()
    return pruned

class EriLayer:
    """Mixin that records a layer's name and slices weights by index."""

    def __init__(self, in_features: int, out_features: int, **kwargs):
        """Store the layer name given as the ``key`` keyword argument.

        ``in_features`` and ``out_features`` are accepted but not used here.
        """
        self.key = kwargs["key"]

    def extract_in_dim_weight(self, weight, indices):
        """Return ``weight`` restricted to ``indices`` along its second dimension."""
        return weight[:, indices]

    def extract_out_dim_weight(self, weight, indices):
        """Return ``weight`` restricted to ``indices`` along its first dimension."""
        return weight[indices, :]


class Linear(nn.Linear, EriLayer):
    def __init__(
        self,
        in_features,
        out_features,
        bias,
        **kwargs,
    ):
        """Create a frozen linear layer and, for ``o_proj``/``down_proj``, its statistics buffers.

        Args:
            in_features: Input width of the layer.
            out_features: Output width of the layer.
            bias: Whether ``nn.Linear`` should allocate a bias.
            **kwargs: Must contain ``key`` (layer name) and ``is_GQA``.

        Raises:
            ValueError: If the layer is an ``o_proj``/``down_proj`` and the
                configured prune metric contains neither ``"wandasp"`` nor
                ``"flap"``.
        """
        nn.Linear.__init__(self, in_features, out_features, bias=bias)
        EriLayer.__init__(self, in_features=in_features, out_features=out_features, **kwargs)
        self.is_GQA = kwargs["is_GQA"]
        # The pre-trained weight matrix is never trained.
        self.weight.requires_grad = False
        self.layer_type = "linear"
        self.in_features = in_features
        self.prune_metric = cfg["prune_metric"]

        self.retrieve_weight = torch.cuda.Event(enable_timing=False, blocking=False)

        self.compensate_bias = None

        # Only the projections that are pruned along their input dimension
        # keep running statistics of their inputs.
        if "o_proj" not in self.key and "down_proj" not in self.key:
            return

        stats_device = self.weight.data.device

        def _zero_stats():
            # One float32 entry per (sequence position, input channel).
            return torch.zeros(
                (cfg["max_seq_len"], in_features), device=stats_device, dtype=torch.float32
            )

        self.nsamples = torch.zeros(in_features, dtype=torch.int32, device=stats_device)
        if "wandasp" in self.prune_metric:
            self.scaler_inp = _zero_stats()
            if "bias" in cfg["prune_method"]:
                self.baseline_inp = _zero_stats()
        elif "flap" in self.prune_metric:
            self.fluc_inp = _zero_stats()
            self.baseline_inp = _zero_stats()
        else:
            raise ValueError(f"Unknown pruning method")

    def update_global_metric_score_distribution_ema(self, inp, update_indices):
        """Fold ``inp`` into the stored input statistics using an exponential moving average.

        Only the slice ``[:seq_len, update_indices]`` of each statistics
        buffer is touched, where ``seq_len = inp.shape[1]``. The buffers and
        ``self.nsamples`` are first moved to ``inp``'s device. This method
        does not increment ``self.nsamples``.

        Args:
            inp: Layer input; must not be 2-D. Reductions are taken over
                dim 0, whose size is used as the batch size.
                # assumes inp is (batch, seq, len(update_indices)) - TODO confirm
            update_indices: Index tensor selecting the input channels to
                update.

        Raises:
            ValueError: If ``inp`` is 2-D.
        """
        
        # if cfg["cur_batch_index"] == 0 or cfg["history_update"] == False:
        #     return
        # History tracking is disabled: leave all statistics untouched.
        if cfg["history_update"] == False :
            return
        if len(inp.shape) == 2:
            raise ValueError(f"Input shape {inp.shape} is not supported. Please provide a 3D tensor.")

        batch_size = inp.shape[0]
        seq_len = inp.shape[1]
        momentum = cfg["ema_momentum"]
        cur_device = inp.device
        # if update_indices is not None:
        update_indices = update_indices.to(cur_device)
        
        self.nsamples = self.nsamples.to(cur_device)
        
        if "wandasp" in self.prune_metric:
            self.scaler_inp = self.scaler_inp.to(cur_device)
            # Decay the old value; the new contribution, weighted by
            # (1 - momentum), is added in one of the branches below.
            self.scaler_inp[:seq_len, update_indices] *= momentum

            if "bias" in cfg["prune_method"]:
                self.baseline_inp = self.baseline_inp.to(cur_device)
                self.baseline_inp[:seq_len, update_indices] *= momentum
                # NOTE(review): the mean over dim 0 is divided by batch_size
                # a second time here - confirm this is intended.
                self.baseline_inp[:seq_len, update_indices] += (1 - momentum) * (torch.mean(inp, dim=0) / batch_size)

            if cfg["calibration_stage"] == True:
                # Calibration: L1 norm or squared L2 norm over dim 0, divided
                # by the batch size, without clamping.
                if cfg["is_L1"]:
                    norm_l1 = torch.linalg.vector_norm(inp, ord=1, dim=0)
                    self.scaler_inp[:seq_len, update_indices] += (1 - momentum) * norm_l1 / batch_size
                else:
                    norm_squared = torch.linalg.vector_norm(inp, ord=2, dim=0) ** 2
                    self.scaler_inp[:seq_len, update_indices] += (1 - momentum) * norm_squared / batch_size
            elif cfg["calibration_stage"] == False:
                # Outside calibration the contribution is clamped to
                # cfg["data_type_max"], and when pad tokens are configured the
                # divisor is cfg["nonpad_tokens_denominator"] instead of the
                # batch size.
                if cfg["is_L1"]:
                    if cfg["pad_tokens"] is not None:
                        cfg["nonpad_tokens_denominator"] = cfg["nonpad_tokens_denominator"].to(cur_device)
                        # norm_squared = torch.clamp(torch.linalg.vector_norm(inp, ord=2, dim=0) ** 2, max=cfg["data_type_max"])
                        norm_l1 = torch.linalg.vector_norm(inp, ord=1, dim=0)
                        self.scaler_inp[:seq_len, update_indices] += (1 - momentum) * torch.clamp(norm_l1 / cfg["nonpad_tokens_denominator"], max=cfg["data_type_max"])
                    else:
                        norm_l1 = torch.linalg.vector_norm(inp, ord=1, dim=0)
                        self.scaler_inp[:seq_len, update_indices] += (1 - momentum) * torch.clamp(norm_l1 / batch_size, max=cfg["data_type_max"])
                else:
                    if cfg["pad_tokens"] is not None:
                        cfg["nonpad_tokens_denominator"] = cfg["nonpad_tokens_denominator"].to(cur_device)
                        norm_squared = torch.clamp(torch.linalg.vector_norm(inp, ord=2, dim=0) ** 2, max=cfg["data_type_max"])
                        self.scaler_inp[:seq_len, update_indices] += (1 - momentum) * torch.clamp(norm_squared / cfg["nonpad_tokens_denominator"], max=cfg["data_type_max"])
                    else:
                        norm_squared = torch.clamp(torch.linalg.vector_norm(inp, ord=2, dim=0) ** 2, max=cfg["data_type_max"])
                        self.scaler_inp[:seq_len, update_indices] += (1 - momentum) * torch.clamp(norm_squared / batch_size, max=cfg["data_type_max"])
        elif "flap" in self.prune_metric:
            # NOTE(review): despite the method name, this branch weights by
            # self.nsamples rather than by the EMA momentum.
            self.baseline_inp = self.baseline_inp.to(cur_device)
            self.fluc_inp = self.fluc_inp.to(cur_device)

            # Keep the baseline from before this update; the fluctuation
            # update below uses both the old and the new baseline.
            old_baseline_inp = self.baseline_inp.clone()
            self.baseline_inp[:seq_len, update_indices] *= self.nsamples[update_indices] / (self.nsamples[update_indices] + batch_size)
            self.baseline_inp[:seq_len, update_indices] += torch.mean(inp, dim=0) / (self.nsamples[update_indices] + batch_size)
            
            # The fluctuation is left untouched while every sample count is
            # still zero.
            if torch.all(self.nsamples == 0):
                pass
            else:
  
                self.fluc_inp[:seq_len, update_indices] *= (self.nsamples[update_indices] - 1) / (self.nsamples[update_indices] + batch_size - 1)
                self.fluc_inp[:seq_len, update_indices] += torch.sum((inp - torch.mean(self.baseline_inp[:seq_len, update_indices], dim=0).unsqueeze(0).unsqueeze(0)) * (inp - torch.mean(old_baseline_inp[:seq_len, update_indices], dim=0).unsqueeze(0).unsqueeze(0)), dim=0) / (self.nsamples[update_indices] + batch_size) 

    def update_global_metric_score_distribution(self, inp, update_indices):
        """Fold ``inp`` into the stored input statistics as a sample-count-weighted running average.

        Only the slice ``[:seq_len, update_indices]`` of each statistics
        buffer is touched, where ``seq_len = inp.shape[1]``. The old value is
        scaled by ``nsamples / (nsamples + d)`` and the new contribution is
        divided by ``nsamples + d``, where ``d`` is the batch size or, in
        some branches, ``cfg["nonpad_tokens_denominator"]``. Finally
        ``self.nsamples[update_indices]`` is incremented.

        Args:
            inp: Layer input; must not be 2-D. Reductions are taken over
                dim 0, whose size is used as the batch size.
                # assumes inp is (batch, seq, len(update_indices)) - TODO confirm
            update_indices: Index tensor selecting the input channels to
                update.

        Raises:
            ValueError: If ``inp`` is 2-D.
        """
        # if cfg["cur_batch_index"] == 0 or cfg["history_update"] == False :
        #     return
        # History tracking is disabled: leave all statistics untouched.
        if cfg["history_update"] == False :
            return
        
        if len(inp.shape) == 2:
            raise ValueError(f"Input shape {inp.shape} is not supported. Please provide a 3D tensor.")
        
        batch_size = inp.shape[0]
        seq_len = inp.shape[1]
        cur_device = inp.device
        update_indices = update_indices.to(cur_device)
        self.nsamples = self.nsamples.to(cur_device)
        
        if "wandasp" in self.prune_metric:
            
            self.scaler_inp = self.scaler_inp.to(cur_device)
            if "bias" in cfg["prune_method"]:
                # Running mean of the input, used for bias compensation.
                self.baseline_inp = self.baseline_inp.to(cur_device)
                self.baseline_inp[:seq_len, update_indices] *= self.nsamples[update_indices] / (self.nsamples[update_indices] + batch_size)
                self.baseline_inp[:seq_len, update_indices] += torch.mean(inp, dim=0) / (self.nsamples[update_indices] + batch_size)
            if cfg["calibration_stage"] == True:
                # Calibration: L1 norm or squared L2 norm over dim 0, no
                # clamping, batch size as the sample increment.
                self.scaler_inp[:seq_len, update_indices] *= self.nsamples[update_indices] / (self.nsamples[update_indices] + batch_size)
                if cfg['is_L1']:
                    norm_l1 = torch.linalg.vector_norm(inp, ord=1, dim=0)
                    denominator = (self.nsamples[update_indices] + batch_size)
                    self.scaler_inp[:seq_len, update_indices] += norm_l1 / denominator
                else:
                    norm_squared = torch.linalg.vector_norm(inp, ord=2, dim=0) ** 2
                    denominator = (self.nsamples[update_indices] + batch_size)
                    self.scaler_inp[:seq_len, update_indices] += norm_squared / denominator
            elif cfg["calibration_stage"] == False:
                # Outside calibration the buffer is kept in cfg["data_type"]
                # and every contribution is clamped to cfg["data_type_max"].
                self.scaler_inp = self.scaler_inp.to(cfg["data_type"])
                
                if cfg['is_L1']:
                    if cfg["pad_tokens"] is not None:
                        # With pad tokens configured, the sample increment is
                        # cfg["nonpad_tokens_denominator"] instead of the
                        # batch size.
                        cfg["nonpad_tokens_denominator"] = cfg["nonpad_tokens_denominator"].to(cur_device)
                        self.scaler_inp[:seq_len, update_indices] *= self.nsamples[update_indices] / (self.nsamples[update_indices] + cfg["nonpad_tokens_denominator"])
                        norm_L1 = torch.linalg.vector_norm(inp, ord=1, dim=0)
                        denominator = (self.nsamples[update_indices] + cfg["nonpad_tokens_denominator"])
                        self.scaler_inp[:seq_len, update_indices] += torch.clamp(norm_L1 / denominator, max=cfg["data_type_max"])
                    else:
                        self.scaler_inp[:seq_len, update_indices] *= self.nsamples[update_indices] / (self.nsamples[update_indices] + batch_size)
                        norm_L1 = torch.linalg.vector_norm(inp, ord=1, dim=0)
                        denominator = (self.nsamples[update_indices] + batch_size)
                        self.scaler_inp[:seq_len, update_indices] += torch.clamp(norm_L1 / denominator, max=cfg["data_type_max"])
                else:
                    if cfg["pad_tokens"] is not None:
                        cfg["nonpad_tokens_denominator"] = cfg["nonpad_tokens_denominator"].to(cur_device)
                        self.scaler_inp[:seq_len, update_indices] *= self.nsamples[update_indices] / (self.nsamples[update_indices] + cfg["nonpad_tokens_denominator"])
                        norm_squared = torch.clamp(torch.linalg.vector_norm(inp, ord=2, dim=0) ** 2, max=cfg["data_type_max"])
                        denominator = (self.nsamples[update_indices] + cfg["nonpad_tokens_denominator"])
                        self.scaler_inp[:seq_len, update_indices] += torch.clamp(norm_squared / denominator, max=cfg["data_type_max"])
                    else:
                        self.scaler_inp[:seq_len, update_indices] *= self.nsamples[update_indices] / (self.nsamples[update_indices] + batch_size)
                        norm_squared = torch.clamp(torch.linalg.vector_norm(inp, ord=2, dim=0) ** 2, max=cfg["data_type_max"])
                        denominator = (self.nsamples[update_indices] + batch_size)
                        self.scaler_inp[:seq_len, update_indices] += torch.clamp(norm_squared / denominator, max=cfg["data_type_max"])


        elif "flap" in self.prune_metric:
            self.baseline_inp = self.baseline_inp.to(cur_device)
            self.fluc_inp = self.fluc_inp.to(cur_device)
            
            # Keep the baseline from before this update; the fluctuation
            # update below uses both the old and the new baseline.
            old_baseline_inp = self.baseline_inp.clone()
            self.baseline_inp[:seq_len, update_indices] *= self.nsamples[update_indices] / (self.nsamples[update_indices] + batch_size)
            self.baseline_inp[:seq_len, update_indices] += torch.mean(inp, dim=0) / (self.nsamples[update_indices] + batch_size)
            
            # The fluctuation is left untouched while every sample count is
            # still zero.
            if torch.all(self.nsamples == 0):
                pass
            else:
                self.fluc_inp[:seq_len, update_indices] *= (self.nsamples[update_indices] - 1) / (self.nsamples[update_indices] + batch_size - 1)
                self.fluc_inp[:seq_len, update_indices] += torch.sum((inp - torch.mean(self.baseline_inp[:seq_len, update_indices], dim=0).unsqueeze(0).unsqueeze(0)) * (inp - torch.mean(old_baseline_inp[:seq_len, update_indices], dim=0).unsqueeze(0).unsqueeze(0)), dim=0) / (self.nsamples[update_indices] + batch_size)  
        # Advance the per-channel sample counters for the updated channels.
        # NOTE(review): this increment depends only on cfg["pad_tokens"],
        # while the calibration and flap branches above always use
        # batch_size - confirm the two are meant to differ.
        if cfg["pad_tokens"] is not None:
            self.nsamples[update_indices] += cfg["nonpad_tokens_denominator"]
        else:
            self.nsamples[update_indices] += batch_size

    def get_global_metric_score_distribution(self, to_idx=None):
        if "wandasp" in self.prune_metric:
            return self.scaler_inp if to_idx is None else self.scaler_inp[:to_idx]
        elif "flap" in self.prune_metric:
            return self.fluc_inp if to_idx is None else self.fluc_inp[:to_idx]
        else:
            raise ValueError(f"Unknown pruning metric")

    def free(self):
        if hasattr(self, "baseline_inp"):
            self.baseline_inp = None
        if hasattr(self, "fluc_inp"):
            self.fluc_inp = None
        if hasattr(self, "scaler_inp"):
            self.scaler_inp = None
        if hasattr(self, "scaler_row"):
            self.scaler_row = None
        torch.cuda.empty_cache()  

    def return_global_metric_info(self):
        if ("o_proj" in self.key or "down_proj" in self.key):
            if "wandasp" in self.prune_metric:
                if "bias" in cfg["prune_method"]:
                   return {
                        "nsamples": self.nsamples,
                        "baseline_inp": self.baseline_inp,
                        "scaler_inp": self.scaler_inp
                    } 
                else:
                    return {
                        "nsamples": self.nsamples,
                        "scaler_inp": self.scaler_inp
                    }
            elif "flap" in self.prune_metric:
                return {
                    "nsamples": self.nsamples,
                    "baseline_inp": self.baseline_inp,
                    "fluc_inp": self.fluc_inp
                }
            else:
                raise ValueError(f"Unknown pruning metric")
        else:
            return None

    def set_global_metric_to_data_type(self):
        if ("o_proj" in self.key or "down_proj" in self.key):
            if "wandasp" in self.prune_metric:
                if "bias" in cfg["prune_method"]:
                    self.baseline_inp = self.baseline_inp.to(cfg["data_type"])
                    self.scaler_inp = self.scaler_inp.to(cfg["data_type"])
                else:
                    self.scaler_inp = self.scaler_inp.to(cfg["data_type"])
            elif "flap" in self.prune_metric:
                self.baseline_inp = self.baseline_inp.to(cfg["data_type"])
                self.fluc_inp = self.fluc_inp.to(cfg["data_type"])
            else:
                raise ValueError(f"Unknown pruning metric")
    
    def get_weight(self):        
        # for GQA, just return the original weight
        if ("k_proj" in self.key or "v_proj" in self.key) and self.is_GQA == True:
            return self.weight
        return self.weight
    
    def get_compensate_bias(self, x, weight, in_dim_indices):
        """Compute a bias term from the stored input baseline, excluding ``in_dim_indices``.

        Args:
            x: Current layer input; only its device is used.
            weight: Weight matrix the baseline is projected through.
            in_dim_indices: Input channels whose baseline is zeroed before
                the projection.

        Returns:
            A zero vector of length ``weight.shape[0]`` while
            ``cfg["cur_batch_index"]`` is 0; otherwise
            ``F.linear(baseline_mean, weight)`` with the entries at
            ``in_dim_indices`` set to zero.
        """
        # return torch.zeros(weight.shape[0], device=x.device)
        if cfg["cur_batch_index"] == 0:
            return torch.zeros(weight.shape[0], device=x.device)

        # Average the baseline over its first dimension; ``torch.mean``
        # returns a new tensor, so zeroing entries below does not modify
        # ``self.baseline_inp``.
        baseline_mean = torch.mean(self.baseline_inp, dim=0).to(x.device)
        baseline_mean[in_dim_indices.to(device=x.device)] = 0
        return F.linear(baseline_mean, weight, bias=None)
    
    def forward(self, x: torch.Tensor, **kwargs):
        """Run the layer through the Triton-backed or the plain PyTorch implementation.

        Args:
            x: Layer input.
            **kwargs: Passed through unchanged to the selected implementation.

        Returns:
            The result of ``triton_forward`` when ``cfg["use_triton"]`` is
            present and truthy, otherwise the result of ``torch_forward``.
        """
        # Guard the lookup the same way ``_replace_module`` does, so a config
        # without a "use_triton" entry selects the PyTorch path instead of
        # raising KeyError.
        if "use_triton" in cfg and cfg["use_triton"]:
            return self.triton_forward(x, **kwargs)
        return self.torch_forward(x, **kwargs)
    
    def triton_forward(self, x: torch.Tensor, **kwargs):
        """Forward pass that uses the index-aware kernels from ``triton_linear`` in dynamic mode.

        Runs under ``torch.no_grad()``. No bias is applied on any path.

        Args:
            x: Layer input.
            **kwargs: Optional ``out_dim_indices`` or ``in_dim_indices``
                selecting the output or input channels to compute with, and
                the flags ``cal_mlp_probe_out_dim_metric`` /
                ``cal_attn_probe_out_dim_metric``.

        Returns:
            The layer output cast back to ``x``'s dtype.

        Raises:
            AssertionError: If ``cfg["mode"]`` is ``"static"`` outside the
                calibration stage.

        NOTE(review): if ``cfg["mode"]`` is neither ``"dynamic"`` nor
        ``"static"``, ``result`` is never assigned and the final ``return``
        raises ``UnboundLocalError``. If ``cfg["calibration_stage"]`` equals
        neither ``True`` nor ``False`` the method returns ``None``.
        """
        with torch.no_grad():
            previous_dtype = x.dtype
            if cfg["calibration_stage"] == True:
                # Calibration: use the full weight and, for o_proj/down_proj,
                # record statistics for every input channel.
                if "o_proj" in self.key or "down_proj" in self.key:
                    self.update_global_metric_score_distribution(x, torch.arange(self.in_features, dtype=torch.long, device=x.device))
                    if "transpose_down_linear" in cfg and cfg["transpose_down_linear"]: 
                        # presumably self.weight is stored transposed (in, out)
                        # in this configuration, hence matmul instead of
                        # F.linear - verify against _replace_module.
                        result = torch.matmul(x, self.weight)
                    else: 
                        # result = down_channel_level_indices_linear(x, self.weight, torch.arange(self.weight.shape[1], device=self.weight.device))
                        # result = original_linear(x, self.weight)
                        result = F.linear(x, self.weight)
                else: 
                    # result = up_channel_level_indices_linear(x, self.weight, torch.arange(self.weight.shape[0], device=self.weight.device))
                    # result = original_linear(x, self.weight)
                    result = F.linear(x, self.weight)
                result = result.to(previous_dtype)
                return result
            elif cfg["calibration_stage"] == False:
                # On the first batch, cast the statistics buffers to
                # cfg["data_type"].
                if cfg["cur_batch_index"] == 0:
                    self.set_global_metric_to_data_type()
                    
                # Probe passes use the full weight and skip both the
                # statistics update and the index-aware kernels.
                if "probe" in cfg["prune_method"] and "cal_mlp_probe_out_dim_metric" in kwargs and kwargs["cal_mlp_probe_out_dim_metric"] == True:
                    result = F.linear(x, self.weight, bias=None)
                    result = result.to(previous_dtype)
                    return result
                elif "probe" in cfg["prune_method"] and "cal_attn_probe_out_dim_metric" in kwargs and kwargs["cal_attn_probe_out_dim_metric"] == True:               
                    result = F.linear(x, self.weight, bias=None)
                    result = result.to(previous_dtype)
                    return result

                # dynamic prune
                if cfg["mode"] == "dynamic":
                    # update metric score
                    if ("o_proj" in self.key or "down_proj" in self.key) and cfg["is_prune"]:
                        # nvtx.range_push("update metric")
                        # Update only the channels named by in_dim_indices
                        # when given, otherwise all input channels.
                        if "runningmean" in cfg["prune_method"]:
                            if "in_dim_indices" in kwargs:
                                self.update_global_metric_score_distribution(x, kwargs["in_dim_indices"])
                            else:
                                self.update_global_metric_score_distribution(x, torch.arange(self.in_features, dtype=torch.long, device=x.device))
                        elif "ema" in cfg["prune_method"]:
                            if "in_dim_indices" in kwargs:
                                self.update_global_metric_score_distribution_ema(x, kwargs["in_dim_indices"])
                            else:
                                self.update_global_metric_score_distribution_ema(x, torch.arange(self.in_features, dtype=torch.long, device=x.device))
                        # nvtx.range_pop()
                    # Kernel selection: out_dim_indices takes precedence over
                    # in_dim_indices; with neither, the full weight is used.
                    if "out_dim_indices" in kwargs:
                        result = up_channel_level_indices_linear(x, self.weight, kwargs["out_dim_indices"])
                    elif "in_dim_indices" in kwargs:
                        if "transpose_down_linear" in cfg and cfg["transpose_down_linear"]: 
                            result = down_channel_transpose_level_indices_linear(x, self.weight, kwargs["in_dim_indices"])
                        else: 
                            result = down_channel_level_indices_linear(x, self.weight, kwargs["in_dim_indices"])
                    elif "transpose_down_linear" in cfg and cfg["transpose_down_linear"] and ("o_proj" in self.key or "down_proj" in self.key):
                        result = torch.matmul(x, self.weight)
                    else:
                        result = F.linear(x, self.weight, bias=None)
                    result = result.to(previous_dtype)
                # static prune
                elif cfg["mode"] == "static":
                    # Static pruning is not supported on the Triton path.
                    print("When using triton, can not use static prune")
                    raise AssertionError
                return result
    
    # no bias in llama-2
    def torch_forward(self, x: torch.Tensor, **kwargs):
        """Forward pass implemented with plain PyTorch indexing and ``F.linear``.

        Runs under ``torch.no_grad()``. The layer's own ``bias`` is never
        applied; the only additive term is the compensation bias described
        below.

        Args:
            x: Layer input.
            **kwargs: Optional ``out_dim_indices`` or ``in_dim_indices``
                selecting the weight rows or columns to compute with, and the
                flags ``cal_mlp_probe_out_dim_metric`` /
                ``cal_attn_probe_out_dim_metric``.

        Returns:
            The layer output cast back to ``x``'s dtype.

        NOTE(review): if ``cfg["mode"]`` is neither ``"dynamic"`` nor
        ``"static"``, ``result`` is never assigned and the final ``return``
        raises ``UnboundLocalError``. If ``cfg["calibration_stage"]`` equals
        neither ``True`` nor ``False`` the method returns ``None``.
        """
        with torch.no_grad():
            previous_dtype = x.dtype
            if cfg["calibration_stage"] == True:
                # Calibration: use the full weight and, for o_proj/down_proj,
                # record statistics for every input channel.
                if "o_proj" in self.key or "down_proj" in self.key:
                    self.update_global_metric_score_distribution(x, torch.arange(self.in_features, dtype=torch.long, device=x.device))
                result = F.linear(x, self.weight, bias=None)
                result = result.to(previous_dtype)
                return result
            elif cfg["calibration_stage"] == False:
                # On the first batch, cast the statistics buffers to
                # cfg["data_type"].
                if cfg["cur_batch_index"] == 0:
                    self.set_global_metric_to_data_type()
                    
                # Probe passes use the full weight and skip both the
                # statistics update and the weight slicing.
                if "probe" in cfg["prune_method"] and "cal_mlp_probe_out_dim_metric" in kwargs and kwargs["cal_mlp_probe_out_dim_metric"] == True:
                    result = F.linear(x, self.weight, bias=None)
                    result = result.to(previous_dtype)
                    return result
                elif "probe" in cfg["prune_method"] and "cal_attn_probe_out_dim_metric" in kwargs and kwargs["cal_attn_probe_out_dim_metric"] == True:               
                    result = F.linear(x, self.weight, bias=None)
                    result = result.to(previous_dtype)
                    return result
                
                # dynamic prune
                if cfg["mode"] == "dynamic":
                    weight = self.get_weight()
                    # update metric score
                    if ("o_proj" in self.key or "down_proj" in self.key) and cfg["is_prune"]:
                        # nvtx.range_push("update metric")
                        # Update only the channels named by in_dim_indices
                        # when given, otherwise all input channels.
                        if "runningmean" in cfg["prune_method"]:
                            if "in_dim_indices" in kwargs:
                                self.update_global_metric_score_distribution(x, kwargs["in_dim_indices"])
                            else:
                                self.update_global_metric_score_distribution(x, torch.arange(self.in_features, dtype=torch.long, device=x.device))
                        elif "ema" in cfg["prune_method"]:
                            if "in_dim_indices" in kwargs:
                                self.update_global_metric_score_distribution_ema(x, kwargs["in_dim_indices"])
                            else:
                                self.update_global_metric_score_distribution_ema(x, torch.arange(self.in_features, dtype=torch.long, device=x.device))
                        # nvtx.range_pop()
                    # up linear
                    if "out_dim_indices" in kwargs:
                        # With GQA, k_proj/v_proj keep their full weight even
                        # when out_dim_indices is supplied.
                        if ("k_proj" in self.key or "v_proj" in self.key) and self.is_GQA == True:
                            weight = weight
                        else:
                            weight = self.extract_out_dim_weight(weight, kwargs["out_dim_indices"])
                        result = F.linear(x, weight, bias=None)
                    # down linear
                    elif "in_dim_indices" in kwargs:
                        weight = self.extract_in_dim_weight(weight, kwargs["in_dim_indices"])
                        result = F.linear(x, weight, bias=None)
                        if "o_proj" in self.key or "down_proj" in self.key:
                            if "bias" in cfg["prune_method"]:
                                # The compensation bias is computed from the
                                # full (unsliced) weight and added in place.
                                compensate_bias = self.get_compensate_bias(x, self.weight, kwargs["in_dim_indices"])
                                result += compensate_bias
                    else:
                        result = F.linear(x, weight, bias=None)
                    result = result.to(previous_dtype)
                # static prune
                elif cfg["mode"] == "static":
                    # Static mode uses the stored weight as-is.
                    result = F.linear(x, self.weight, bias=None)
                    result = result.to(previous_dtype)
                return result