import os
import json
import torch
import numpy as np
from ..config import cfg
from ..module import nearest_multiple, check_skip_layers

BLOCK_SIZE = 64

class HiddenRepresentationPruning():

    def __init__(self, cfg, key, model_config=None):
        self.key = key
        self.prune_metric = cfg["prune_metric"]
        
        if cfg["prune_ratio_file"] is not None: 
            prune_ratio_list = json.load(open(cfg["prune_ratio_file"], "r"))
            assert len(prune_ratio_list) == model_config.num_hidden_layers - len(cfg["skip_layers"]), "Length in prune_ratio_file does not match number of (model layers - skip layers)"
            layer_idx = int(self.key.split("_")[-1])
            if layer_idx in cfg["skip_layers"]: 
                self.prune_ratio = 0.0
            else: 
                self.prune_ratio = prune_ratio_list[layer_idx - len(cfg["skip_layers"])]
        else: 
            if not isinstance(cfg["prune_ratio"], float) and len(cfg["prune_ratio"]) == 2:
                if "attention" in self.key:
                    self.prune_ratio = cfg["prune_ratio"][0]
                else:
                    self.prune_ratio = cfg["prune_ratio"][1]
            else:
                self.prune_ratio = cfg["prune_ratio"], model_config
        self.model_config = model_config

    def adjust_prune_ratio(self, prune_ratio, model_config):

        if model_config is not None:
            if prune_ratio == 0:
                return 0
            
            num_hidden_layers = model_config.num_hidden_layers
            prune_ratio = round(num_hidden_layers / (num_hidden_layers - (len(cfg["skip_layers"]))) * prune_ratio, 2) 
            return prune_ratio
        else:
            return cfg["prune_ratio"]
    
    def handle_attn_pad_tokens(self, probe_out, bsz_selected_indices, seq_selected_indices):
        # handling padding tokens
        if cfg["pad_tokens"] is not None:
            cfg["pad_tokens"] = cfg["pad_tokens"].to(probe_out.device)
            if bsz_selected_indices is not None and seq_selected_indices is not None:
                ii, jj = torch.meshgrid(bsz_selected_indices, seq_selected_indices, indexing="ij")
                probe_pad_tokens = cfg["pad_tokens"][ii, jj]
            elif bsz_selected_indices is None:
                probe_pad_tokens = cfg["pad_tokens"][:, seq_selected_indices]
            elif seq_selected_indices is None:
                probe_pad_tokens = cfg["pad_tokens"][bsz_selected_indices, :]

            probe_out[probe_pad_tokens] = 0
            probe_num = (probe_out.size(0) - torch.sum(probe_pad_tokens, dim=0) + 1e-3).to(probe_out.dtype).unsqueeze(1)
            return probe_out, probe_num
        else:
            return probe_out, probe_out.size(0)
    
    def handle_mlp_pad_tokens(self, probe_out: torch.Tensor, bsz_selected_indices: torch.Tensor = None, seq_selected_indices: torch.Tensor = None):
        if cfg["pad_tokens"] is not None:
            cfg["pad_tokens"] = cfg["pad_tokens"].to(probe_out.device)
            if bsz_selected_indices is not None and seq_selected_indices is not None:
                ii, jj = torch.meshgrid(bsz_selected_indices, seq_selected_indices, indexing="ij")
                probe_pad_tokens = cfg["pad_tokens"][ii, jj]
            elif bsz_selected_indices is None:
                probe_pad_tokens = cfg["pad_tokens"][:, seq_selected_indices]
            elif seq_selected_indices is None:
                probe_pad_tokens = cfg["pad_tokens"][bsz_selected_indices, :]
            else:
                bsz_selected_indices = torch.arange(probe_out.shape[0])
                seq_selected_indices = torch.arange(probe_out.shape[1])
                ii, jj = torch.meshgrid(bsz_selected_indices, seq_selected_indices, indexing="ij")
                probe_pad_tokens = cfg["pad_tokens"][ii, jj]

            probe_num = (probe_out.size(0) - torch.sum(probe_pad_tokens, dim=0) + 1e-3).to(probe_out.dtype).unsqueeze(1)
            return probe_out, probe_num
        else:
            return probe_out, probe_out.size(0)

    def cal_attn_prune_metric(self, probe_out, weight, metric_type, bsz_selected_indices, seq_selected_indices, global_metric_score_distribution=None):
        probe_out, probe_num = self.handle_attn_pad_tokens(probe_out, bsz_selected_indices, seq_selected_indices)
        device = probe_out.device

        if seq_selected_indices == None:
            seq_selected_indices = torch.arange(probe_out.shape[1], device=device)
        
        if cfg["is_L1"]:
            norm_probe_val = torch.linalg.vector_norm(probe_out, ord=1, dim=0) / probe_num
        else:
            # norm_probe_val = torch.linalg.vector_norm(probe_out, ord=2, dim=0) ** 2 / probe_num
            norm_probe_val = torch.sum(probe_out**2, dim=0) / probe_num
        norm_probe_val = torch.clamp(norm_probe_val, max=cfg["data_type_max"])
        if "ecovla" in metric_type:
            combined_probe_out = None
            if global_metric_score_distribution is not None:
                cur_global_metric_score_distribution = global_metric_score_distribution[seq_selected_indices, :] if seq_selected_indices is not None else global_metric_score_distribution
                if cfg["probe_ratio_flag"]:
                    combined_probe_out = cfg["probefixratio"] * cur_global_metric_score_distribution + (1-cfg["probefixratio"]) * norm_probe_val
                
                # dynaratio, since all nonnegative, no need to abs
                else: 
                    denominator = norm_probe_val + cur_global_metric_score_distribution
                    # avoid nan, nan is always a problem in float16
                    # tend to give the global metric more weight
                    probe_ratio = norm_probe_val / (denominator + 1e-6)
                    global_ratio = 1 - probe_ratio
                    combined_probe_out = global_ratio * cur_global_metric_score_distribution + probe_ratio * norm_probe_val
                combined_probe_out = torch.sum(combined_probe_out, dim=0).clamp(max=cfg["data_type_max"])
                if cfg["is_L1"]:
                    probe_out_dim_metric = torch.linalg.vector_norm((combined_probe_out.reshape((1,-1)) * torch.abs(weight)).reshape(self.model_config.hidden_size, self.model_config.num_attention_heads, -1), ord=1, dim=(0, 2)).clamp(max=cfg["data_type_max"])
                    # dim = 0 if not cfg["transpose_down_linear"] else 1
                    # probe_out_dim_metric = (combined_probe_out**2 * torch.sum(weight**2, dim=dim)).reshape(self.model_config.num_attention_heads, -1).sum(dim=-1).clamp(max=cfg["data_type_max"])
                else:
                    probe_out_dim_metric = torch.linalg.vector_norm((combined_probe_out.reshape((1,-1)) * torch.pow(weight, 2)).reshape(self.model_config.hidden_size, self.model_config.num_attention_heads, -1), ord=2, dim=(0, 2)).clamp(max=cfg["data_type_max"])
                    # dim = 0 if not cfg["transpose_down_linear"] else 1
                    # probe_out_dim_metric = (combined_probe_out**2 * torch.sum(weight**4, dim=dim)).reshape(self.model_config.num_attention_heads, -1).sum(dim=-1).clamp(max=cfg["data_type_max"])

                output_root_dir = ""
                layer_idx = int(self.key.split("_")[-1])
                if cfg["current_action_step"] == 0 and layer_idx == 3: 
                    num_episodes = len(os.listdir(output_root_dir))
                    output_dir = os.path.join(output_root_dir, "episode_{:04d}".format(num_episodes+1))
                    os.makedirs(output_dir)
                else: 
                    num_episodes = len(os.listdir(output_root_dir))
                    output_dir = os.path.join(output_root_dir, "episode_{:04d}".format(num_episodes))
                output_dir = os.path.join(output_dir, "attn_{:02d}".format(layer_idx))
                if not os.path.exists(output_dir): 
                    os.makedirs(output_dir)
                output_path = os.path.join(output_dir, "step_{:03d}".format(cfg["current_action_step"]))
                np.save(output_path, probe_out_dim_metric.cpu().detach().numpy())
                return probe_out_dim_metric
            else:
                norm_probe_val = torch.sum(norm_probe_val, dim=0).clamp(max=cfg['data_type_max'])
                if cfg["is_L1"]:
                    probe_out_dim_metric = torch.linalg.vector_norm((norm_probe_val.reshape((1,-1)) * torch.abs(weight)).reshape(self.model_config.hidden_size, self.model_config.num_attention_heads, -1), ord= 1, dim=(0, 2)).clamp(max=cfg["data_type_max"])
                    # dim = 0 if not cfg["transpose_down_linear"] else 1
                    # probe_out_dim_metric = (norm_probe_val**2 * torch.sum(weight**2, dim=dim)).reshape(self.model_config.num_attention_heads, -1).sum(dim=-1).clamp(max=cfg["data_type_max"])
                else:
                    probe_out_dim_metric = torch.linalg.vector_norm((norm_probe_val.reshape((1,-1)) * torch.pow(weight, 2)).reshape(self.model_config.hidden_size, self.model_config.num_attention_heads, -1), ord=2, dim=(0, 2)).clamp(max=cfg["data_type_max"])
                    # dim = 0 if not cfg["transpose_down_linear"] else 1
                    # probe_out_dim_metric = (norm_probe_val**2 * torch.sum(weight**4, dim=dim)).reshape(self.model_config.num_attention_heads, -1).sum(dim=-1).clamp(max=cfg["data_type_max"])
                return probe_out_dim_metric
        elif "wandasp" in metric_type:
            combined_probe_out = None
            if global_metric_score_distribution is not None:
                cur_global_metric_score_distribution = global_metric_score_distribution[seq_selected_indices, :] if seq_selected_indices is not None else global_metric_score_distribution
                norm_probe_out_square = torch.clamp(torch.linalg.vector_norm(probe_out, ord=2, dim=0) ** 2 / probe_num, max=cfg["data_type_max"])

                if cfg["probe_ratio_flag"]:
                    combined_probe_out = cfg["probefixratio"] * cur_global_metric_score_distribution + (1-cfg["probefixratio"]) * norm_probe_out_square
                # dynaratio, since all nonnegative, no need to abs
                else:  
                    denominator = norm_probe_out_square + cur_global_metric_score_distribution
                    # avoid nan, nan is always a problem in float16
                    # tend to give the global metric more weight
                    probe_ratio = norm_probe_out_square / (denominator + 1e-6)
                    global_ratio = 1 - probe_ratio
                    combined_probe_out = global_ratio * cur_global_metric_score_distribution + probe_ratio * norm_probe_out_square

                combined_probe_out = torch.sum(combined_probe_out, dim=0).clamp(max=cfg["data_type_max"])
                probe_out_dim_metric = (torch.sqrt(combined_probe_out.reshape((1,-1))) * torch.abs(weight)).sum(dim=0).clamp(max=cfg["data_type_max"])
                return probe_out_dim_metric
            else:
                norm_probe_out_square = torch.clamp(torch.linalg.vector_norm(probe_out, ord=2, dim=0) ** 2 / probe_num, max=cfg["data_type_max"])
                norm_probe_out_square = torch.sum(norm_probe_out_square, dim=0).clamp(max=cfg["data_type_max"])
                probe_out_dim_metric = (torch.sqrt(norm_probe_out_square.reshape((1,-1))) * torch.abs(weight)).sum(dim=0).clamp(max=cfg["data_type_max"])
                return probe_out_dim_metric
        elif "flap" in metric_type:
            combined_probe_out = None
            if global_metric_score_distribution is not None:
                cur_global_metric_score_distribution = global_metric_score_distribution[seq_selected_indices, :] if seq_selected_indices is not None else global_metric_score_distribution
                mean_probe_out = torch.mean(probe_out, dim=(0, 1))
                probe_variance = torch.sum(torch.pow(probe_out - mean_probe_out.reshape(1, 1, -1), 2), dim=0).clamp(max=cfg["data_type_max"]) / probe_num

                if "probefixratio" in cfg["prune_method"]:
                    combined_probe_out = cfg["probefixratio"] * cur_global_metric_score_distribution + (1-cfg["probefixratio"]) * probe_variance
                # dynaratio, since all nonnegative, no need to abs
                else:  
                    denominator = probe_variance + cur_global_metric_score_distribution
                    # avoid nan, nan is always a problem in float16
                    # tend to give the global metric more weight
                    probe_ratio = probe_variance / (denominator + 1e-6)
                    global_ratio = 1 - probe_ratio
                    combined_probe_out = global_ratio * cur_global_metric_score_distribution + probe_ratio * probe_variance

                combined_probe_out = torch.sum(combined_probe_out, dim=0).clamp(max=cfg["data_type_max"])
                probe_out_dim_metric = (combined_probe_out * torch.sum(torch.pow(weight, 2), dim=0)).clamp(max=cfg["data_type_max"])
                return probe_out_dim_metric
            else:
                mean_probe_out = torch.mean(probe_out, dim=(0, 1)).clamp(max=cfg["data_type_max"])
                probe_variance = torch.sum(torch.pow(probe_out - mean_probe_out.reshape(1, 1, -1), 2), dim=0).clamp(max=cfg["data_type_max"]) / probe_num
                probe_out_dim_metric = (probe_variance * torch.sum(torch.pow(weight, 2), dim=0)).clamp(max=cfg["data_type_max"])
                return probe_out_dim_metric

    def cal_attn_prune_metric_activation(self, probe_out, weight, metric_type, bsz_selected_indices, seq_selected_indices, global_metric_score_distribution=None):
        probe_out, probe_num = self.handle_attn_pad_tokens(probe_out, bsz_selected_indices, seq_selected_indices)
        device = probe_out.device
        if seq_selected_indices == None:
            seq_selected_indices = torch.arange(probe_out.shape[1], device=device)
        if "ecovla" in metric_type:
            combined_probe_out = None
            if global_metric_score_distribution is not None:
                cur_global_metric_score_distribution = global_metric_score_distribution[seq_selected_indices, :] if seq_selected_indices is not None else global_metric_score_distribution
                if probe_out.shape[0] == 1:
                    # norm_probe_out_square = (probe_out ** 2 / probe_num).clamp(max=cfg["data_type_max"])
                    norm_probe_out_square = torch.sum(probe_out ** 2, dim=0).clamp(max=cfg["data_type_max"])
                else:
                    norm_probe_out_square = torch.clamp(
                        torch.linalg.vector_norm(probe_out, ord=2, dim=0) ** 2 / probe_num,
                        max=cfg["data_type_max"]
                    )
                if cfg["probe_ratio_flag"]:
                    combined_probe_out = cfg["probefixratio"] * cur_global_metric_score_distribution + (1-cfg["probefixratio"]) * norm_probe_out_square
                    # dynaratio, since all nonnegative, no need to abs
                else:  
                    denominator = norm_probe_out_square + cur_global_metric_score_distribution
                    # avoid nan, nan is always a problem in float16
                    # tend to give the global metric more weight
                    probe_ratio = norm_probe_out_square / (denominator + 1e-6)
                    global_ratio = 1 - probe_ratio
                    combined_probe_out = global_ratio * cur_global_metric_score_distribution + probe_ratio * norm_probe_out_square

                combined_probe_out = torch.sum(combined_probe_out, dim=0).clamp(max=cfg["data_type_max"])
                # (combined_probe_out[0]).clamp(max=cfg["data_type_max"])
                # combined_probe_out = torch.sum(combined_probe_out, dim=0).clamp(max=cfg["data_type_max"])
                probe_out_dim_metric = torch.linalg.vector_norm((combined_probe_out.reshape((1,-1)) * torch.pow(weight, 2)).reshape(self.model_config.hidden_size, self.model_config.num_attention_heads, -1), ord=2, dim=(0, 2)).clamp(max=cfg["data_type_max"])

                return probe_out_dim_metric
            else:
                norm_probe_out_square = torch.clamp(torch.linalg.vector_norm(probe_out, ord=2, dim=0) ** 2 / probe_num, max=cfg["data_type_max"])
                norm_probe_out_square = torch.sum(norm_probe_out_square, dim=0).clamp(max=cfg["data_type_max"])
                probe_out_dim_metric = torch.linalg.vector_norm((norm_probe_out_square.reshape((1,-1)) * torch.pow(weight, 2)).reshape(self.model_config.hidden_size, self.model_config.num_attention_heads, -1), ord=2, dim=(0, 2)).clamp(max=cfg["data_type_max"])
                return probe_out_dim_metric
    
    def cal_mlp_prune_metric(
        self, 
        probe_out, 
        weight, 
        metric_type, 
        bsz_selected_indices: torch.Tensor = None, 
        seq_selected_indices: torch.Tensor = None, 
        global_metric_score_distribution = None
    ):
        probe_out, probe_num = self.handle_mlp_pad_tokens(probe_out, bsz_selected_indices, seq_selected_indices)
        device = probe_out.device
        
        if seq_selected_indices == None:
            seq_selected_indices = torch.arange(probe_out.shape[1], device=device)

        if cfg["is_L1"]:
            norm_probe_val = torch.linalg.vector_norm(probe_out, ord=1, dim=0) / probe_num
        else:
            # norm_probe_val = torch.linalg.vector_norm(probe_out, ord=2, dim=0) ** 2 / probe_num
            norm_probe_val = torch.sum(probe_out**2, dim=0) / probe_num
        norm_probe_val = torch.clamp(norm_probe_val, max=cfg["data_type_max"])

        if "ecovla" in metric_type:
            combined_probe_out = None
            if global_metric_score_distribution is not None:
                cur_global_metric_score_distribution = global_metric_score_distribution[seq_selected_indices, :] if seq_selected_indices is not None else global_metric_score_distribution

                # temp = torch.sum(cur_global_metric_score_distribution, dim=1)
                if cfg["probe_ratio_flag"]:
                    combined_probe_out = cfg["probefixratio"] * cur_global_metric_score_distribution + (1-cfg["probefixratio"]) * norm_probe_val
                # dynaratio, since all nonnegative, no need to abs
                else:  
                    denominator = norm_probe_val + cur_global_metric_score_distribution
                    # avoid nan, nan is always a problem in float16
                    # tend to give the global metric more weight
                    probe_ratio = norm_probe_val / (denominator + 1e-6)
                    global_ratio = 1 - probe_ratio
                    combined_probe_out = global_ratio * cur_global_metric_score_distribution + probe_ratio * norm_probe_val
                combined_probe_out = torch.sum(combined_probe_out, dim=0).clamp(max=cfg["data_type_max"])
                if cfg["is_L1"]: 
                    probe_out_dim_metric = torch.linalg.vector_norm(
                        (combined_probe_out.reshape((1,-1)) * torch.abs(weight)), 
                        ord=1, 
                        dim=0
                    ).clamp(max=cfg["data_type_max"])
                    # dim = 0 if not cfg["transpose_down_linear"] else 1
                    # probe_out_dim_metric = (combined_probe_out**2 * torch.sum(weight**2, dim=dim)).clamp(max=cfg["data_type_max"])
                else: 
                    probe_out_dim_metric = torch.linalg.vector_norm(
                        (combined_probe_out.reshape((1,-1)) * torch.pow(weight, 2)), 
                        ord=2, 
                        dim=0
                    ).clamp(max=cfg["data_type_max"])
                    # dim = 0 if not cfg["transpose_down_linear"] else 1
                    # probe_out_dim_metric = (combined_probe_out**2 * torch.sum(weight**4, dim=dim)).clamp(max=cfg["data_type_max"])

                output_root_dir = ""
                num_episodes = len(os.listdir(output_root_dir))
                output_dir = os.path.join(output_root_dir, "episode_{:04d}".format(num_episodes))
                layer_idx = int(self.key.split("_")[-1])
                output_dir = os.path.join(output_dir, "mlp_{:02d}".format(layer_idx))
                if not os.path.exists(output_dir): 
                    os.makedirs(output_dir)
                output_path = os.path.join(output_dir, "step_{:03d}".format(cfg["current_action_step"]))
                np.save(output_path, probe_out_dim_metric.cpu().detach().numpy())
                return probe_out_dim_metric
            else:
                if cfg["is_L1"]: 
                    probe_out_dim_metric = torch.linalg.vector_norm(
                        (norm_probe_val.reshape((1,-1)) * torch.abs(weight)), 
                        ord=1, 
                        dim=0
                    ).clamp(max=cfg["data_type_max"])
                    # dim = 0 if not cfg["transpose_down_linear"] else 1
                    # probe_out_dim_metric = (norm_probe_val**2 * torch.sum(weight**2, dim=dim)).clamp(max=cfg["data_type_max"])
                else: 
                    probe_out_dim_metric = torch.linalg.vector_norm(
                        (norm_probe_val.reshape((1,-1)) * torch.pow(weight, 2)), 
                        ord=2, 
                        dim=0
                    ).clamp(max=cfg["data_type_max"])
                    # dim = 0 if not cfg["transpose_down_linear"] else 1
                    # probe_out_dim_metric = (norm_probe_val**2 * torch.sum(weight**4, dim=dim)).clamp(max=cfg["data_type_max"])
                return probe_out_dim_metric
        elif "wandasp" in metric_type:
            combined_probe_out = None
            if global_metric_score_distribution is not None:
                cur_global_metric_score_distribution = global_metric_score_distribution[seq_selected_indices, :] if seq_selected_indices is not None else global_metric_score_distribution
                cur_global_metric_score_distribution = cur_global_metric_score_distribution.to(probe_out.device)
                norm_probe_out_square = torch.clamp(torch.linalg.vector_norm(probe_out, ord=2, dim=0) ** 2 / probe_num, max=cfg["data_type_max"])

                if "probefixratio" in cfg["prune_method"]:
                    combined_probe_out = cfg["probefixratio"] * cur_global_metric_score_distribution + (1-cfg["probefixratio"]) * norm_probe_out_square
                # dynaratio, since all nonnegative, no need to abs
                else:  
                    denominator = norm_probe_out_square + cur_global_metric_score_distribution
                    # avoid nan, nan is always a problem in float16
                    # tend to give the global metric more weight
                    probe_ratio = norm_probe_out_square / (denominator + 1e-6)
                    global_ratio = 1 - probe_ratio
                    combined_probe_out = global_ratio * cur_global_metric_score_distribution + probe_ratio * norm_probe_out_square

                combined_probe_out = torch.sum(combined_probe_out, dim=0).clamp(max=cfg["data_type_max"])
                probe_out_dim_metric = (torch.sqrt(combined_probe_out.reshape((1,-1))) * torch.abs(weight)).sum(dim=0).clamp(max=cfg["data_type_max"])
                return probe_out_dim_metric
            else:
                norm_probe_out_square = torch.clamp(torch.linalg.vector_norm(probe_out, ord=2, dim=0) ** 2 / probe_num, max=cfg["data_type_max"])
                norm_probe_out_square = torch.sum(norm_probe_out_square, dim=0).clamp(max=cfg["data_type_max"])
                probe_out_dim_metric = (torch.sqrt(norm_probe_out_square.reshape((1,-1))) * torch.abs(weight)).sum(dim=0).clamp(max=cfg["data_type_max"])
                return probe_out_dim_metric
        elif "flap" in metric_type:
            combined_probe_out = None
            if global_metric_score_distribution is not None:
                cur_global_metric_score_distribution = global_metric_score_distribution[seq_selected_indices, :] if seq_selected_indices is not None else global_metric_score_distribution
                cur_global_metric_score_distribution = cur_global_metric_score_distribution.to(probe_out.device)
                mean_probe_out = torch.mean(probe_out, dim=(0, 1))
                probe_variance = torch.sum(torch.pow(probe_out - mean_probe_out.reshape(1, 1, -1), 2), dim=0).clamp(max=cfg["data_type_max"]) / probe_num

                if "probefixratio" in cfg["prune_method"]:
                    combined_probe_out = cfg["probefixratio"] * cur_global_metric_score_distribution + (1-cfg["probefixratio"]) * probe_variance
                # dynaratio, since all nonnegative, no need to abs
                else:  
                    denominator = probe_variance + cur_global_metric_score_distribution
                    # avoid nan, nan is always a problem in float16
                    # tend to give the global metric more weight
                    probe_ratio = probe_variance / (denominator + 1e-6)
                    global_ratio = 1 - probe_ratio
                    combined_probe_out = global_ratio * cur_global_metric_score_distribution + probe_ratio * probe_variance

                combined_probe_out = torch.sum(combined_probe_out, dim=0).clamp(max=cfg["data_type_max"])
                probe_out_dim_metric = (combined_probe_out * torch.sum(torch.pow(weight, 2), dim=0)).clamp(max=cfg["data_type_max"])
                return probe_out_dim_metric
            else:
                mean_probe_out = torch.mean(probe_out, dim=(0, 1), max=cfg["data_type_max"])
                probe_variance = torch.sum(torch.pow(probe_out - mean_probe_out.reshape(1, 1, -1), 2), dim=0).clamp(max=cfg["data_type_max"]) / probe_num
                probe_out_dim_metric = (probe_variance * torch.sum(torch.pow(weight, 2), dim=0)).clamp(max=cfg["data_type_max"])
                return probe_out_dim_metric
  
    def cal_mlp_prune_metric_activation(
        self, 
        probe_out, 
        weight, 
        metric_type, 
        bsz_selected_indices: torch.Tensor = None, 
        seq_selected_indices: torch.Tensor = None, 
        global_metric_score_distribution = None
    ):
        probe_out, probe_num = self.handle_mlp_pad_tokens(probe_out, bsz_selected_indices, seq_selected_indices)
        device = probe_out.device
        if seq_selected_indices == None:
            seq_selected_indices = torch.arange(probe_out.shape[1], device=device)
        if "ecovla" in metric_type:
            combined_probe_out = None
            if global_metric_score_distribution is not None:
                cur_global_metric_score_distribution = global_metric_score_distribution[seq_selected_indices, :] if seq_selected_indices is not None else global_metric_score_distribution

                # norm_probe_out_square = torch.clamp(torch.linalg.vector_norm(probe_out, ord=2, dim=0) ** 2 / probe_num, max=cfg["data_type_max"])
                if probe_out.shape[0] == 1:
                    norm_probe_out_square = torch.clamp(
                        (probe_out[0] ** 2) / probe_num,
                        max=cfg["data_type_max"]
                    )
                else:
                    norm_probe_out_square = torch.clamp(
                        torch.linalg.vector_norm(probe_out, ord=2, dim=0) ** 2 / probe_num,
                        max=cfg["data_type_max"]
                    )
                if cfg["probe_ratio_flag"]:
                    combined_probe_out = cfg["probefixratio"] * cur_global_metric_score_distribution + (1-cfg["probefixratio"]) * norm_probe_out_square
                # dynaratio, since all nonnegative, no need to abs
                else:  
                    denominator = norm_probe_out_square + cur_global_metric_score_distribution
                    # avoid nan, nan is always a problem in float16
                    # tend to give the global metric more weight
                    probe_ratio = norm_probe_out_square / (denominator + 1e-6)
                    global_ratio = 1 - probe_ratio
                    combined_probe_out = global_ratio * cur_global_metric_score_distribution + probe_ratio * norm_probe_out_square
                # batch = 1
                combined_probe_out = torch.sum(combined_probe_out, dim=0).clamp(max=cfg["data_type_max"])
                # batch = 1
                # (combined_probe_out[0]).clamp(max=cfg["data_type_max"])
                # probe_out_dim_metric = torch.norm(probe_out, p=1, dim=1).squeeze(0)
                probe_out_dim_metric = torch.linalg.vector_norm((combined_probe_out.reshape((1,-1)) * torch.pow(weight, 2)), ord=2, dim=0).clamp(max=cfg["data_type_max"])
                
                # probe_out_dim_metric = (
                #     combined_probe_out.reshape(-1) * weight.pow(2)
                # ).abs().clamp(max=cfg["data_type_max"])


                return probe_out_dim_metric
            else:
                norm_probe_out_square = torch.clamp(torch.linalg.vector_norm(probe_out, ord=2, dim=0) ** 2 / probe_num, max=cfg["data_type_max"])
                norm_probe_out_square = torch.sum(norm_probe_out_square, dim=0).clamp(max=cfg["data_type_max"])
                probe_out_dim_metric = torch.linalg.vector_norm((norm_probe_out_square.reshape((1,-1)) * torch.pow(weight, 2)), ord=2, dim=0).clamp(max=cfg["data_type_max"])
                
                return probe_out_dim_metric
          
    def cal_attn_calib_prune_metric(self, calib, weight, metric_type):
        calib = calib.to(weight.device)
        calib = calib.to(torch.float32)
        weight = weight.to(torch.float32)
        if "ecovlasum" in metric_type:
            calib = torch.sum(calib, dim=0).clamp(max=cfg["data_type_max"])
            probe_out_dim_metric = torch.sum((calib.reshape((1,-1)) * torch.pow(weight, 2)).reshape(self.model_config.hidden_size, self.model_config.num_attention_heads, -1), dim=(0, 2)).clamp(max=cfg["data_type_max"])
        elif "ecovla" in metric_type:
            calib = torch.sum(calib, dim=0).clamp(max=cfg["data_type_max"])
            probe_out_dim_metric = torch.linalg.vector_norm((calib.reshape((1,-1)) * torch.pow(weight, 2)).reshape(self.model_config.hidden_size, self.model_config.num_attention_heads, -1), ord=2, dim=(0, 2)).clamp(max=cfg["data_type_max"])
        elif "wandasp" in metric_type:
            calib = torch.sum(calib, dim=0).clamp(max=cfg["data_type_max"])
            probe_out_dim_metric = (torch.sqrt(calib).reshape((1,-1)) * torch.abs(weight)).sum(dim=0).clamp(max=cfg["data_type_max"])
        elif "flap" in metric_type:
            calib = torch.sum(calib, dim=0).clamp(max=cfg["data_type_max"])
            probe_out_dim_metric = (calib * torch.sum(torch.pow(weight, 2), dim=0)).clamp(max=cfg["data_type_max"])
    
        return probe_out_dim_metric

    def cal_mlp_calib_prune_metric(self, calib, weight, metric_type):
        calib = calib.to(weight.device)
        calib = calib.to(torch.float32)

        weight = weight.to(torch.float32)
        if "ecovla" in metric_type:
            calib = torch.sum(calib, dim=0).clamp(max=cfg["data_type_max"])
            probe_out_dim_metric = torch.linalg.vector_norm((calib.reshape((1,-1)) * torch.pow(weight, 2)), ord=2, dim=0).clamp(max=cfg["data_type_max"])
        elif "wandasp" in metric_type:
            calib = torch.sum(calib, dim=0).clamp(max=cfg["data_type_max"])
            probe_out_dim_metric = (torch.sqrt(calib).reshape((1,-1)) * torch.abs(weight)).sum(dim=0).clamp(max=cfg["data_type_max"])
        elif "flap" in metric_type:
            calib = torch.sum(calib, dim=0).clamp(max=cfg["data_type_max"])
            probe_out_dim_metric = (calib * torch.sum(torch.pow(weight, 2), dim=0)).clamp(max=cfg["data_type_max"])
        
        return probe_out_dim_metric
    
    def sort_attn_metric(self, probe_out_dim_metric, num_heads, head_dim, prune_way, prune_module, multiple, pruning_ratio=None):
        if prune_way == None:
            return None, None, num_heads

        prune_ratio = pruning_ratio if pruning_ratio is not None else self.prune_ratio
        if "whole" in prune_way:    
            probe_out_dim_metric = probe_out_dim_metric.reshape(num_heads, -1)
            summed_metrics = torch.clamp(probe_out_dim_metric.sum(dim=-1), max=cfg["data_type_max"])
            sorted_value, sorted_indices = torch.sort(summed_metrics, dim=0)
            num_prune_heads = int(prune_ratio * num_heads)
            heads_to_preserve = sorted_indices[num_prune_heads:]
            heads_to_preserve = torch.sort(heads_to_preserve)[0]
            full_indices_to_preserve = (torch.arange(head_dim, device=probe_out_dim_metric.device) + heads_to_preserve.unsqueeze(1) * head_dim).view(-1)
            num_heads = num_heads - num_prune_heads
            return full_indices_to_preserve, num_heads, heads_to_preserve
        else:
            raise NotImplementedError

    def sort_mlp_metric(self, probe_out_dim_metric, multiple, pruning_ratio=None):        
        prune_ratio = pruning_ratio if pruning_ratio is not None else self.prune_ratio
        sorted_value, sorted_indices = torch.sort(probe_out_dim_metric, dim=0)
        num_prune = int(prune_ratio * probe_out_dim_metric.shape[0])
        num_prune = num_prune if num_prune % BLOCK_SIZE == 0 else (num_prune // BLOCK_SIZE + 1) * BLOCK_SIZE
        if "noround" in cfg["prune_method"]:
            pass
        else:
            num_prune = nearest_multiple(num_prune, probe_out_dim_metric.shape[0], multiple)
        # return torch.sort(sorted_indices[num_prune:])[0], torch.sort(sorted_indices[:num_prune])[0]
        return torch.sort(sorted_indices[num_prune:])[0], sorted_indices[:num_prune]
        # return sorted_indices[num_prune:], sorted_indices[:num_prune]
    
    def cal_multiple_flops_for_flap(self, model):
        if "csr" in cfg["task_name"]:
            approx_seq_len = 100
        elif "clm" in cfg["task_name"]:
            approx_seq_len = cfg["max_seq_len"]

        if "llama" in cfg["model_name"]:
            head_dim = model.config.hidden_size // model.config.num_attention_heads
            attn_flops = approx_seq_len * model.config.hidden_size * head_dim * 4 + approx_seq_len ** 2 * head_dim + approx_seq_len ** 2 * head_dim
            mlp_flops = approx_seq_len * model.config.hidden_size * 3 
            multiples = attn_flops / mlp_flops
            # GQA
            if ('llama-3' in cfg["model_name"] or 'llama-2-70b' in cfg["model_name"]) and ("q_proj" in cfg["cust_tgt_modules"] or "k_proj" in cfg["cust_tgt_modules"] or "v_proj" in cfg["cust_tgt_modules"] or "o_proj" in cfg["cust_tgt_modules"]) :
                head_dim = model.config.hidden_size // model.config.num_attention_heads
                attn_flops = approx_seq_len * model.config.hidden_size * head_dim * 2 + approx_seq_len ** 2 * head_dim + approx_seq_len ** 2 * head_dim
                mlp_flops = approx_seq_len * model.config.hidden_size * 3 
                multiples = attn_flops / mlp_flops
        elif "opt" in cfg["model_name"]:
            head_dim = model.config.hidden_size // model.config.num_attention_heads
            attn_flops = approx_seq_len * model.config.hidden_size * head_dim * 4 + approx_seq_len ** 2 * head_dim + approx_seq_len ** 2 * head_dim
            mlp_flops = approx_seq_len * model.config.hidden_size * 2 
            multiples = attn_flops / mlp_flops
        
        return multiples

    def flap_ratio(self, model, logger=None):
        
        prune_ratio = self.adjust_prune_ratio(cfg["prune_ratio"], model.config)
        attn_metric_list, mlp_metric_list = [], []
        standarlization = lambda x: (x - torch.mean(x, axis=1, keepdim=True)) / torch.std(x, axis=1, keepdim=True)

        if "llama" in cfg["model_name"]:
            for name, module in model.named_modules():   
                if "down_proj" not in name and "o_proj" not in name:
                    continue

                if "down_proj" in name and "down_proj" not in cfg["cust_tgt_modules"]:
                    continue
                elif "o_proj" in name and "o_proj" not in cfg["cust_tgt_modules"]:
                    continue

                if check_skip_layers(name):
                    continue

                if "o_proj" in name:
                    # flap code: W_metric = metrics[args.metrics](wrapped_layers, subset, name) ** 2
                    # we dont put the manually added square (only added for attn) here since it is unreasonable
                    metric = self.cal_attn_calib_prune_metric(module.get_global_metric_score_distribution(), module.weight.data, cfg["prune_metric"])
                    if "square" in cfg["prune_method"]:
                        metric = metric ** 2
                    attn_metric_list.append(metric.to("cpu"))
                elif "down_proj" in name:
                    metric = self.cal_mlp_calib_prune_metric(module.get_global_metric_score_distribution(), module.weight.data, cfg["prune_metric"])
                    mlp_metric_list.append(metric.to("cpu"))
                
            if len(attn_metric_list) > 0:
                attn_metric = torch.stack(attn_metric_list).to(torch.float64)
                attn_metric = standarlization(attn_metric)
                print("attn_metric", attn_metric.shape)
                attn_metric = attn_metric.reshape(attn_metric.shape[0], -1, 128).mean(dim=2)
            else:
                attn_metric = None

            if len(mlp_metric_list) > 0:
                mlp_metric = torch.stack(mlp_metric_list).to(torch.float64)
                mlp_metric = standarlization(mlp_metric)
            else:
                mlp_metric = None

            # prune 1 head will lead to multiples times more flops pruned than 1 mlp channel
            multiples = self.cal_multiple_flops_for_flap(model)
            print("multiples", multiples, flush=True)
            if attn_metric is not None and mlp_metric is not None:
                prune_metric = torch.cat([attn_metric.view(-1), mlp_metric.view(-1)])
                flops_measurement = torch.cat([torch.full_like(attn_metric, multiples).view(-1), torch.full_like(mlp_metric, 1).view(-1)])
            elif attn_metric is not None:
                prune_metric = attn_metric.view(-1)
                flops_measurement = torch.full_like(attn_metric, multiples).view(-1)
            elif mlp_metric is not None:
                prune_metric = mlp_metric.view(-1)
                flops_measurement = torch.full_like(mlp_metric, 1).view(-1)

            # descending=True
            sorted_prune, indices = torch.sort(prune_metric)
            sorted_flops_measurement = flops_measurement[indices]
            # threshold = sorted_prune[int(sorted_prune.numel() * cfg["prune_ratio"])]
            cumulative_sum = torch.cumsum(sorted_flops_measurement, dim=0)

            # Calculate the total sum and the target sum
            total_sum = sorted_flops_measurement.sum()
            target_sum = total_sum * prune_ratio
            
            # Find the index where the cumulative sum reaches or exceeds the target sum
            threshold_index = torch.searchsorted(cumulative_sum, target_sum, right=True)
            print("Total sum:", total_sum, target_sum, prune_ratio, threshold_index)
            # Set the threshold using the value at this index
            threshold = sorted_prune[threshold_index]

            print("Threshold:", threshold.item())

            attention_ratio = 0
            attention_counter = 0
            mlp_ratio = 0
            mlp_counter = 0
            for name, module in model.named_modules():
                if "down_proj" not in name and "o_proj" not in name:
                    continue
                
                if "down_proj" in name and "down_proj" not in cfg["cust_tgt_modules"]:
                    continue
                elif "o_proj" in name and "o_proj" not in cfg["cust_tgt_modules"]:
                    continue

                if check_skip_layers(name):
                    continue
                
                if "o_proj" in name:
                    # flap code: W_metric = metrics[args.metrics](wrapped_layers, subset, name) ** 2
                    # we dont put the manually added square (only added for attn) here since it is unreasonable
                    metric = self.cal_attn_calib_prune_metric(module.get_global_metric_score_distribution(), module.weight.data, cfg["prune_metric"]).reshape(1, -1)
                    if "square" in cfg["prune_method"]:
                        metric = metric ** 2
                    metric = standarlization(metric)
                    metric = metric.reshape(-1, 128).mean(dim=1)
                elif "down_proj" in name:
                    metric = self.cal_mlp_calib_prune_metric(module.get_global_metric_score_distribution(), module.weight.data, cfg["prune_metric"]).reshape(1, -1)
                    metric = standarlization(metric)

                threshold = threshold.to(metric.device)
                module.pruning_ratio = metric[metric <= threshold].numel() / metric.numel()
                if "o_proj" in name:
                    attention_ratio += module.pruning_ratio
                    attention_counter += 1
                elif "down_proj" in name:
                    mlp_ratio += module.pruning_ratio
                    mlp_counter += 1
                print("name", name, "module.pruning_ratio", module.pruning_ratio)
            
            logger.append({f"flap_attention_average_pruning_ratio": attention_ratio/(attention_counter+1e-4)}, "test")
            logger.append({f"flap_mlp_average_pruning_ratio": mlp_ratio/(mlp_counter+1e-4)}, "test")
        elif "opt" in cfg["model_name"]:
            for name, module in model.named_modules():   
                if 'fc2' not in name and "out_proj" not in name:
                    continue

                if 'fc2' in name and 'fc2' not in cfg["cust_tgt_modules"]:
                    continue
                elif "out_proj" in name and "out_proj" not in cfg["cust_tgt_modules"]:
                    continue

                numbers = int("".join(filter(str.isdigit, name)))
                # if numbers <= cfg["skip_layers"]:
                #     continue

                if check_skip_layers(name):
                    continue

                if "out_proj" in name:
                    # flap code: W_metric = metrics[args.metrics](wrapped_layers, subset, name) ** 2
                    # we dont put the manually added square (only added for attn) here since it is unreasonable
                    metric = self.cal_attn_calib_prune_metric(module.get_global_metric_score_distribution(), module.weight.data, cfg["prune_metric"])
                    if "square" in cfg["prune_method"]:
                        metric = metric ** 2
                    attn_metric_list.append(metric.to("cpu"))
                elif 'fc2' in name:
                    metric = self.cal_mlp_calib_prune_metric(module.get_global_metric_score_distribution(), module.weight.data, cfg["prune_metric"])
                    mlp_metric_list.append(metric.to("cpu"))
                
            if len(attn_metric_list) > 0:
                attn_metric = torch.stack(attn_metric_list).to(torch.float64)
                attn_metric = standarlization(attn_metric)
                attn_metric = attn_metric.reshape(attn_metric.shape[0], -1, 128).mean(dim=2)
            else:
                attn_metric = None

            if len(mlp_metric_list) > 0:
                mlp_metric = torch.stack(mlp_metric_list).to(torch.float64)
                mlp_metric = standarlization(mlp_metric)
            else:
                mlp_metric = None

            # prune 1 head will lead to multiples times more flops pruned than 1 mlp channel
            multiples = self.cal_multiple_flops_for_flap(model)
            print("multiples", multiples, flush=True)
            if attn_metric is not None and mlp_metric is not None:
                prune_metric = torch.cat([attn_metric.view(-1), mlp_metric.view(-1)])
                flops_measurement = torch.cat([torch.full_like(attn_metric, multiples).view(-1), torch.full_like(mlp_metric, 1).view(-1)])
            elif attn_metric is not None:
                prune_metric = attn_metric.view(-1)
                flops_measurement = torch.full_like(attn_metric, multiples).view(-1)
            elif mlp_metric is not None:
                prune_metric = mlp_metric.view(-1)
                flops_measurement = torch.full_like(mlp_metric, 1).view(-1)

            # descending=True
            sorted_prune, indices = torch.sort(prune_metric)
            sorted_flops_measurement = flops_measurement[indices]
            # threshold = sorted_prune[int(sorted_prune.numel() * cfg["prune_ratio"])]
            cumulative_sum = torch.cumsum(sorted_flops_measurement, dim=0)

            # Calculate the total sum and the target sum
            total_sum = sorted_flops_measurement.sum()
            target_sum = total_sum * prune_ratio
            
            # Find the index where the cumulative sum reaches or exceeds the target sum
            threshold_index = torch.searchsorted(cumulative_sum, target_sum, right=True)
            print("Total sum:", total_sum, target_sum, prune_ratio, threshold_index)
            # Set the threshold using the value at this index
            threshold = sorted_prune[threshold_index]

            print("Threshold:", threshold.item())

            for name, module in model.named_modules():
                if 'fc2' not in name and "out_proj" not in name:
                    continue
                
                if 'fc2' in name and 'fc2' not in cfg["cust_tgt_modules"]:
                    continue
                elif "out_proj" in name and "out_proj" not in cfg["cust_tgt_modules"]:
                    continue

                if check_skip_layers(name):
                    continue
                
                if "out_proj" in name:
                    # flap code: W_metric = metrics[args.metrics](wrapped_layers, subset, name) ** 2
                    # we dont put the manually added square (only added for attn) here since it is unreasonable
                    metric = self.cal_attn_calib_prune_metric(module.get_global_metric_score_distribution(), module.weight.data, cfg["prune_metric"]).reshape(1, -1)
                    if "square" in cfg["prune_method"]:
                        metric = metric ** 2
                    metric = standarlization(metric)
                    metric = metric.reshape(-1, 128).mean(dim=1)
                elif 'fc2' in name:
                    metric = self.cal_mlp_calib_prune_metric(module.get_global_metric_score_distribution(), module.weight.data, cfg["prune_metric"]).reshape(1, -1)
                    metric = standarlization(metric)

                threshold = threshold.to(metric.device)
                module.pruning_ratio = metric[metric <= threshold].numel() / metric.numel()
                print("name", name, "module.pruning_ratio", module.pruning_ratio)
        return