from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import cast, Dict, List, Union, Generic, TypeVar, Final
from copy import deepcopy
from logging import Logger as LoggerType

from regex import X
import torch
from torch import nn
from torch.utils.hooks import RemovableHandle
from rich.progress import (
    BarColumn,
    Progress,
    Task,
    SpinnerColumn,
    TaskProgressColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    MofNCompleteColumn
)

# from prune.reconstructor import BaseRLSReconstructor
from utils.layer_utils import LayerConfig, LayerSchema, TransformerConfig, TransformerLayerSchema
from prune.hybrid_obs_pruner_certified import HybridOBSLinearPruner, HybridOBSConv1dPruner, HybridOBSConv2dPruner
from utils.logs import SharedLogger
from utils.pruning_utils import PruningArguments
from prune.utils import zero_heads_gqa, zero_ffn_units, head_l2sq_full_gqa, ffn_l2sq_full, head_l1_full_gqa, ffn_l1_full


@dataclass
class _ModuleBudget:
    energy:  torch.Tensor          # ⟨‖y‖²⟩  (fp32 scalar on same device)
    beta:    float                 # global β
    gamma:   float                 # β·energy  (ratio term scale)
    score:   torch.Tensor          # running Sℓ,m  (fp32, initially 0)

    @property
    def budget(self) -> torch.Tensor:
        return self.beta * self.energy

    def accumulate(self, taylor_loss: torch.Tensor, n_blocks_pre: int):
        # Prevent division by zero
        if n_blocks_pre <= 0:
            return
        self.score += taylor_loss + self.gamma / n_blocks_pre            # Eq. (7)

    def fired(self) -> bool:
        return bool(self.score > self.budget)

    def reset(self):
        self.score.zero_()
        

# Schema (layer-type description) parameter of HALPE; must be a LayerSchema.
SchemaType = TypeVar("SchemaType", bound=LayerSchema)
# Layer-config parameter of HALPE; must be a LayerConfig.
ConfigType = TypeVar("ConfigType", bound=LayerConfig)

class HALPE(ABC, Generic[ConfigType, SchemaType]):
    """
    HALPE: Hessian-Aware LLM Pruning on the Edge

    Abstract per-layer pruning driver. One instance wraps one ``nn.Module`` and:

    * collects a scalar layer-sensitivity statistic from the layer's forward
      inputs via a forward pre-hook (``register_hook`` -> ``add_batch``);
    * owns the per-role micro-pruners in ``self.pruners`` and forwards
      calibration / device-placement calls to them;
    * keeps one ``_ModuleBudget`` per role in ``self.mod_budget`` that decides
      when layer reconstruction should be triggered.

    Subclasses implement the layer-type-specific parts (block enumeration,
    importance scoring, pruning and config updates).
    """
    def __init__(self, layer: nn.Module, layer_config: ConfigType, layer_type: SchemaType, device: str = "cpu", use_chunking: bool = False, chunk_size: int = 32) -> None:
        self.layer: nn.Module = layer
        self.layer_config = layer_config
        self.layer_type = layer_type
        self.device = device
        self.logger = SharedLogger.get_logger(self.__class__.__name__)
        self.use_chunking = use_chunking
        self.chunk_size = chunk_size
        # Micro-pruners keyed by role (e.g. 'head', 'ffn'); populated by initialize_pruners().
        self.pruners: Dict[str, Union[HybridOBSLinearPruner, HybridOBSConv1dPruner, HybridOBSConv2dPruner]] = {}

        # Layer Sensitivity metadata
        self.handle: Union[RemovableHandle, None] = None
        self.H: Union[torch.Tensor, None] = None
        # Start as Python floats; add_batch / set_global_sensitivity turn them into 0-d tensors.
        self.local_sensitivity: Union[float, torch.Tensor] = 0.
        self.global_sensitivity: Union[float, torch.Tensor] = 0.

        self.num_samples = 0
        self.selected_blocks = []

        # Per-role reconstruction budgets; filled in by subclasses.
        self.mod_budget: dict[str, _ModuleBudget] = {}
        # self.recons: dict[str, BaseRLSReconstructor] = {}   # 'head' | 'ffn' → recon

    def register_hook(self):
        """Attach a forward pre-hook that feeds the layer's input to ``add_batch``.

        A hook previously registered by this object is removed first, so that
        repeated calls never stack hooks (which would count every batch twice
        and leak the earlier handle).
        """
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

        def hook_fn(module, inputs):
            """Capture activations during forward pass"""
            # Forward pre-hooks receive the positional args as a tuple; only the first is used.
            if isinstance(inputs, tuple):
                inputs = inputs[0]
            self.add_batch(inputs)

        self.handle = self.layer.register_forward_pre_hook(hook_fn)

    def remove_hook(self):
        """Detach the forward pre-hook.

        Raises:
            ValueError: if no hook is currently registered.
        """
        if self.handle is None:
            raise ValueError("Hook not registered")
        self.handle.remove()
        self.handle = None

    # TODO: Running mean vs add and then divide
    # Note: We apply square-root scaling to reduce magnitude while preserving relative ordering
    @torch.inference_mode()
    def add_batch(self, x: torch.Tensor):
        """Accumulate the layer-sensitivity statistic for one batch of layer inputs.

        ``x`` is the layer's forward input (captured by the pre-hook), flattened
        to (B, d) by ``_preprocess_input``. The per-sample squared L2 norm is
        averaged over the batch, square-rooted (+1e-8), and added to
        ``self.local_sensitivity``; ``self.num_samples`` grows by B.

        NOTE(review): one value is added per *batch* while ``num_samples``
        counts *rows*, so ``compute_local_sensitivity`` yields
        (sum of per-batch values) / (total rows). With a constant batch size
        this is the mean per-batch value scaled by 1/B — confirm this is the
        intended normalisation.

        Raises:
            ValueError: if the input or an intermediate value contains NaN/Inf.
        """
        try:
            x = self._preprocess_input(x).detach()                # (B × d)

            # Work in fp32 regardless of the activation dtype.
            x = x.to(torch.float32)
            if torch.isnan(x).any() or torch.isinf(x).any():
                self.logger.error(f"Batch contains nan or inf: {x}")
                raise ValueError("Batch contains nan or inf")

            # Per-sample squared L2 norm (trace of x xᵀ for each row): (B,)
            x = (x * x).sum(dim=1)

            if torch.isnan(x).any() or torch.isinf(x).any():
                self.logger.error(f"Hessian trace contains nan or inf: {x}")
                raise ValueError("Hessian trace contains nan or inf")

            # Average over the batch
            hessian_trace = x.mean()

            if torch.isnan(hessian_trace) or torch.isinf(hessian_trace):
                self.logger.error(f"Hessian trace is inf or nan: {hessian_trace}")
                raise ValueError("Hessian trace is inf or nan")

            # Square-root scaling compresses the magnitude while preserving ordering.
            hessian_trace_scaled = torch.sqrt(hessian_trace + 1e-8)

            # Summed here, divided by num_samples in compute_local_sensitivity().
            self.local_sensitivity += hessian_trace_scaled
            self.num_samples += x.shape[0]
        except Exception as e:
            self.logger.error(f"Error in add_batch: {e}")
            raise

    @torch.inference_mode()
    def add_batch1(self, x: torch.Tensor):
        """Alternative statistic: accumulate the batch-mean squared L2 norm (no sqrt).

        Each row is divided by its max-abs value (clamped to 1e-12) before
        squaring and the squared scale is multiplied back in afterwards.

        Raises:
            ValueError: if the input or an intermediate value contains NaN/Inf.
        """
        try:
            x = self._preprocess_input(x).detach()                # (B × d)
            x = x.to(torch.float32)
            if torch.isnan(x).any() or torch.isinf(x).any():
                self.logger.error(f"Batch contains nan or inf: {x}")
                raise ValueError("Batch contains nan or inf")
            eps = 1e-12
            row_max = x.abs().amax(dim=1, keepdim=True).clamp_min(eps)
            x = x / row_max
            if torch.isnan(x).any() or torch.isinf(x).any():
                self.logger.error(f"Scaled batch contains nan or inf: {x}")
                raise ValueError("Scaled batch contains nan or inf")
            row_l2_sq = (x * x).sum(dim=1) * (row_max.squeeze(1) ** 2)
            if torch.isnan(row_l2_sq).any() or torch.isinf(row_l2_sq).any():
                self.logger.error(f"Row L2 squared contains nan or inf: {row_l2_sq}")
                raise ValueError("Row L2 squared contains nan or inf")
            norm_squared = row_l2_sq.mean()
            if torch.isnan(norm_squared) or torch.isinf(norm_squared):
                self.logger.error(f"Norm squared is inf or nan: {norm_squared}")
                raise ValueError("Norm squared is inf or nan")
            # norm_squared is already the mean per sample, no need to multiply by batch size
            self.local_sensitivity += norm_squared
            self.num_samples += x.shape[0]
        except Exception as e:
            self.logger.error(f"Error in add_batch: {e}")
            raise

    @torch.inference_mode()
    def compute_local_sensitivity(self):
        """Turn the accumulated sum into an average by dividing by ``num_samples``.

        Not idempotent: each call divides the current value again.

        Raises:
            ValueError: if no samples were accumulated.
        """
        if self.num_samples == 0:
            raise ValueError("No samples added. Cannot compute local sensitivity. Call add_batch() first.")
        self.local_sensitivity = self.local_sensitivity / self.num_samples

    @torch.inference_mode()
    def set_global_sensitivity(self, global_sensitivity: float, alpha: float = 1.0):
        """Set ``global_sensitivity = local_sensitivity + alpha * global_sensitivity``."""
        self.global_sensitivity = self.local_sensitivity + (alpha * global_sensitivity)

    @torch.inference_mode()
    def update_global_sensitivity(self, global_sensitivity: float):
        """Overwrite the global sensitivity with the given value."""
        self.global_sensitivity = global_sensitivity

    @torch.inference_mode()
    def get_global_sensitivity(self):
        """Return the current global sensitivity."""
        return self.global_sensitivity

    @torch.inference_mode()
    def remove_micropruner_hooks(self):
        """Detach the hooks of every micro-pruner."""
        for pruner in self.pruners.values():
            pruner.remove_hook()

    @torch.inference_mode()
    def finalize_calibration(self, min_damping: float = 0.0001, max_damping: float = 1.0, max_iterative_iterations: int = 20, iterative_tolerance: float = 1e-6):
        """
        Finalize calibration for all pruners that have candidate blocks.
        This computes the inverse Hessian matrices needed for exact importance computation.
        Pruners without a non-empty ``selected_blocks`` are skipped.
        """
        for pruner in self.pruners.values():
            if hasattr(pruner, 'selected_blocks') and len(pruner.selected_blocks) > 0:
                pruner.finalize_calibration(min_damping=min_damping, max_damping=max_damping, max_iterative_iterations=max_iterative_iterations, iterative_tolerance=iterative_tolerance)

    @torch.inference_mode()
    def move_all_hessians_to_cpu(self):
        """Move all Hessian accumulators to CPU to free GPU memory after calibration"""
        self.logger.debug("Moving all Hessians to CPU to free GPU memory")
        for pruner in self.pruners.values():
            if hasattr(pruner, 'move_hessian_to_cpu'):
                pruner.move_hessian_to_cpu()

    @torch.inference_mode()
    def move_all_hessians_to_gpu(self):
        """Move all Hessian accumulators back to GPU (pruners lacking the method are skipped)."""
        self.logger.debug("Moving all Hessians to GPU to free CPU memory")
        for pruner in self.pruners.values():
            if hasattr(pruner, 'move_hessian_to_gpu'):
                pruner.move_hessian_to_gpu()

    @torch.inference_mode()
    def move_layer_to_cpu(self):
        """Move the layer to CPU to free GPU memory"""
        self.layer.to("cpu")

    @torch.inference_mode()
    def move_layer_to_gpu(self):
        """Move the layer back to ``self.device``."""
        self.layer.to(self.device)

    @torch.inference_mode()
    def get_candidate_blocks(self):
        """Return ``self.selected_blocks``."""
        return self.selected_blocks

    @torch.inference_mode()
    def do_layer_reconstruction(self):
        """Check each role's budget and reset those that fired.

        Roles that have a pruner but no budget are skipped instead of raising
        ``KeyError``. The reconstruction step itself is not implemented yet.
        """
        for role in self.pruners:
            mb = self.mod_budget.get(role)
            if mb is None:
                continue
            # trigger LS if budget exceeded
            if mb.fired():
                # do layer reconstruct (TODO: add reconstruction) Not for now
                mb.reset()

    @abstractmethod
    @torch.inference_mode()
    def set_candidate_blocks(self, candidate_blocks: torch.Tensor):
        """
        Sets candidate blocks based on the computed tensor.
        Args:
            candidate_blocks: Tensor with columns [importance, layer_idx, block_idx, block_type]
        Returns:
            None
        """
        ...

    @abstractmethod
    @torch.inference_mode()
    def _preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
        """
        Reshape the captured layer input before the sensitivity statistic is computed.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: 2-D tensor (batch_size, features), as indexed by add_batch.
        """
        ...

    @abstractmethod
    @torch.inference_mode()
    def get_initial_importances(self, layer_idx: int) -> torch.Tensor:
        """
        Computes initial importances for each block in the layer.

        Args:
            layer_idx (int): Index of the layer, written into the output rows.

        Returns:
            torch.Tensor: Initial importances for each block.
        """
        ...

    @abstractmethod
    @torch.inference_mode()
    def initialize_pruners(self):
        """
        Initializes the micro-pruners for the layer (stored in ``self.pruners``).

        Returns:
            None
        """
        ...

    @abstractmethod
    @torch.inference_mode()
    def compute_exact_importances(self, layer_idx: int):
        """
        Computes exact importances for the layer.

        Args:
            layer_idx (int): Index of the layer.

        Returns:
            torch.Tensor: Exact importances for the layer.
        """
        ...

    @abstractmethod
    @torch.inference_mode()
    def block_indices_to_prune_per_layer(self, indices: torch.Tensor) -> Dict[str, List[int]]:
        """
        Maps flat sorted indices to their respective block indices per role.

        Args:
            indices (torch.Tensor): Indices of candidate blocks.

        Returns:
            dict: Mapping of role name to the block indices to prune.
        """
        ...

    @abstractmethod
    @torch.inference_mode()
    def prune(self, selected_blocks: torch.Tensor, prune_args: Final[PruningArguments], layer_idx: int, progress: Union[Progress, None] = None, task_pruned: Union[Task, None] = None, conditioned_score_max_chunk_size: int = 256):
        """Prune ``selected_blocks`` of layer ``layer_idx`` using ``prune_args``.

        ``progress`` / ``task_pruned`` optionally drive a rich progress bar.
        """
        ...

    @abstractmethod
    @torch.inference_mode()
    def get_pruned_blocks(self) -> Dict[str, List[int]]:
        """
        Retrieves the pruned blocks for the layer.

        Returns:
            dict: Pruned blocks for the layer.
        """
        ...

    @abstractmethod
    @torch.inference_mode()
    def update_module_energy(self, role: str, activ: torch.Tensor):
        """
        Updates the module energy based on the activations.

        Args:
            role (str): Role of the module (e.g., 'head', 'ffn').
            activ (torch.Tensor): Activations from the forward pass.

        Returns:
            None
        """
        ...

    @abstractmethod
    @torch.inference_mode()
    def finalize_module_energy(self, num_batches: int):
        """
        Finalizes the module energy after processing all batches.

        Args:
            num_batches (int): Number of batches processed.

        Returns:
            None
        """
        ...

    @abstractmethod
    @torch.inference_mode()
    def update_ratio_term(self, role: str, taylor_loss: torch.Tensor):
        """
        Updates the ratio term for the module budget.

        Args:
            role (str): Role of the module (e.g., 'head', 'ffn').
            taylor_loss (torch.Tensor): Taylor loss for the module.

        Returns:
            None
        """
        ...

    @abstractmethod
    @torch.inference_mode()
    def get_updated_configs(self) -> LayerConfig:
        """
        Returns the updated layer config after pruning.
        """
        ...

    @abstractmethod
    @torch.inference_mode()
    def get_num_params(self) -> int:
        """
        Returns the number of parameters in the layer.
        """
        ...

    @abstractmethod
    @torch.inference_mode()
    def reset_pruner(self):
        """
        Resets the pruner for the next iteration of the pruning.
        """
        ...

class TransformerHALPE(HALPE[TransformerConfig, TransformerLayerSchema]):
    def __init__(self, layer: nn.Module, layer_config: TransformerConfig, layer_type: TransformerLayerSchema,
                 beta: float = 2.5e-3, gamma_scale: float = 1.0, device: str = "cpu", use_chunking: bool = False, chunk_size: int = 32) -> None:
        """Create the driver and zero-initialised 'head' / 'ffn' budgets.

        Args:
            beta: global β; each budget fires when its score exceeds β·energy.
            gamma_scale: multiplier on the ratio-term scale γ = gamma_scale·β·energy.
            Remaining arguments are forwarded to HALPE.__init__.
        """
        super().__init__(layer, layer_config, layer_type, device=device, use_chunking=use_chunking, chunk_size=chunk_size)
        # Energy is still a zero placeholder here, so gamma starts at 0.
        # Keep gamma_scale so finalize_module_energy can apply it to the real energy.
        self.gamma_scale = gamma_scale

        # Initialize module budgets
        for role in ("head", "ffn"):
            energy = torch.zeros((), device=self.device)  # placeholder, accumulated by update_module_energy
            self.mod_budget[role] = _ModuleBudget(
                energy=energy, beta=beta,
                gamma=gamma_scale * beta * energy.item(),
                score=torch.tensor(0., device=self.device))
    
    @torch.inference_mode()
    def update_module_energy(self, role: str, activ: torch.Tensor):
        """Add one batch's activation energy to the budget of ``role``.

        The energy is the mean, over all rows (every index except the last
        dimension), of the squared L2 norm along the last dimension. Each row is
        divided by its max-abs value (clamped to 1e-12) before squaring, and the
        squared scale is multiplied back in afterwards.

        Args:
            role: budget key, e.g. 'head' or 'ffn'.
            activ: activations; presumably (B, seq_len, d) for attention or
                (B, seq_len, d_ff) for the FFN — TODO confirm with callers.
        """
        xf = activ.to(torch.float32)
        eps = 1e-12
        row_max = xf.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
        xf = xf / row_max
        row_l2_sq = (xf * xf).sum(dim=-1) * (row_max.squeeze(-1) ** 2)
        e = row_l2_sq.mean()
        mb = self.mod_budget[role]
        # The accumulator was created on self.device; activations can live elsewhere
        # (e.g. after move_layer_to_cpu), and in-place add across devices would raise.
        mb.energy += e.to(mb.energy.device)

    @torch.inference_mode()
    def finalize_module_energy(self, num_batches: int):
        """Average the accumulated energies over ``num_batches`` and derive each γ.

        γ = gamma_scale · β · energy, where ``gamma_scale`` is read from the
        instance when present (defaulting to 1.0, i.e. γ = β · energy).

        Raises:
            ValueError: if ``num_batches`` is not positive (would yield inf/nan energies).
        """
        if num_batches <= 0:
            raise ValueError(f"num_batches must be positive, got {num_batches}")
        gamma_scale = getattr(self, "gamma_scale", 1.0)
        for mb in self.mod_budget.values():
            mb.energy /= num_batches
            mb.gamma = gamma_scale * mb.beta * mb.energy.item()

    @torch.inference_mode()
    def update_ratio_term(self, role: str, taylor_loss: torch.Tensor):
        """Add a Taylor-loss observation plus the amortised ratio term to ``role``'s budget.

        The ratio term γ is divided by the number of blocks still alive in the
        module: the total block count minus ``pruned_blocks`` of the matching
        pruner. Any role other than 'head' is counted against the 'ffn'
        pruner. When no blocks remain, a warning is logged and nothing is added.
        """
        mb = self.mod_budget[role]
        is_head = role == "head"
        total_blocks = self.layer_config.num_heads if is_head else self.layer_config.intermediate_dimension
        already_pruned = len(self.pruners["head" if is_head else "ffn"].pruned_blocks)
        remaining = total_blocks - already_pruned

        if remaining > 0:
            mb.accumulate(taylor_loss=taylor_loss, n_blocks_pre=remaining)
            return

        self.logger.warning(f"Role {role}: No blocks remaining to prune. Skipping accumulation. "
                            f"Total blocks: {total_blocks}, "
                            f"Pruned blocks: {len(self.pruners[role].pruned_blocks)}")
    
    @torch.inference_mode()
    def _preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
        """Collapse every non-batch dimension, returning a (batch_size, d) view/copy."""
        batch_by_features = torch.flatten(x, 1)
        return batch_by_features
    
    @torch.inference_mode()
    def get_initial_importances(self, layer_idx: int) -> torch.Tensor:
        """Initial per-block importances for this layer.

        Delegates to the diagonal-Hessian (OBS-diag) scorer. The L1 / L2
        weight-magnitude scorers on this class are drop-in alternatives.
        """
        scorer = self.get_initial_importances_hessian_diagonal
        return scorer(layer_idx)
    
    # Weight Magnitudes based initial importances
    @torch.inference_mode()
    def get_initial_importances_l2_weight_magnitudes(self, layer_idx: int) -> torch.Tensor:
        """Squared-L2 weight-magnitude importances for every head and FFN unit.

        Already-pruned blocks (per the 'head' / 'ffn' pruners, when present)
        get +inf so they sort last.

        Returns:
            Tensor of shape [num_heads + intermediate_dimension, 4] with columns
            [importance, layer_idx, block_idx, block_type]. ``block_idx`` restarts
            at 0 for FFN units; ``block_type`` is 0 for heads and 1 for FFN units.
        """
        def _proj(key: str):
            # Resolve layer_type.layers[key] to the concrete submodule of self.layer.
            spec = self.layer_type.layers[key]
            return getattr(getattr(self.layer, spec.module_name), spec.attribute)

        head_importance = head_l2sq_full_gqa(_proj('q'), _proj('k'), _proj('v'), _proj('o'), self.layer_config.head_size)
        if 'head' in self.pruners:
            head_importance[self.pruners['head'].get_pruned_blocks()] = torch.inf

        # fc1 (gate projection) is optional: its attribute is None for non-gated MLPs.
        gate_proj = _proj('fc1') if self.layer_type.layers['fc1'].attribute is not None else None
        ffn_importance = ffn_l2sq_full(_proj('fc2'), _proj('fc3'), gate_proj=gate_proj)
        if 'ffn' in self.pruners:
            ffn_importance[self.pruners['ffn'].get_pruned_blocks()] = torch.inf

        importances = torch.cat([head_importance, ffn_importance], dim=0)

        dev = torch.device(self.device)
        num_heads = self.layer_config.num_heads
        num_units = self.layer_config.intermediate_dimension
        total_blocks = num_heads + num_units
        layer_indices = torch.full((total_blocks,), fill_value=layer_idx, dtype=torch.int, device=dev)

        # Per-type block indices: heads 0..num_heads-1, then FFN units 0..num_units-1.
        block_indices = torch.cat([torch.arange(num_heads, dtype=torch.int, device=dev),
                                   torch.arange(num_units, dtype=torch.int, device=dev)], dim=0)

        # Block type: 0 for head, 1 for FFN
        block_type_indices = torch.cat([torch.zeros(num_heads, dtype=torch.int, device=dev),
                                        torch.ones(num_units, dtype=torch.int, device=dev)], dim=0)

        return torch.stack([importances, layer_indices, block_indices, block_type_indices], dim=1)
    
    @torch.inference_mode()
    def get_initial_importances_l1_weight_magnitudes(self, layer_idx: int) -> torch.Tensor:
        """L1 weight-magnitude importances for every head and FFN unit.

        Already-pruned blocks (per the 'head' / 'ffn' pruners, when present)
        get +inf so they sort last. Runs under inference mode like the sibling
        scorers, so no autograd graph is built over the weights.

        Returns:
            Tensor of shape [num_heads + intermediate_dimension, 3] with columns
            [importance, layer_idx, block_idx]. Unlike the L2 / Hessian scorers,
            ``block_idx`` here runs 0..total_blocks-1 across heads and then FFN units.
        """
        def _proj(key: str):
            # Resolve layer_type.layers[key] to the concrete submodule of self.layer.
            spec = self.layer_type.layers[key]
            return getattr(getattr(self.layer, spec.module_name), spec.attribute)

        head_importance = head_l1_full_gqa(_proj('q'), _proj('k'), _proj('v'), _proj('o'), self.layer_config.head_size)
        if 'head' in self.pruners:
            head_importance[self.pruners['head'].get_pruned_blocks()] = torch.inf

        # fc1 (gate projection) is optional: its attribute is None for non-gated MLPs.
        gate_proj = _proj('fc1') if self.layer_type.layers['fc1'].attribute is not None else None
        ffn_importance = ffn_l1_full(_proj('fc2'), _proj('fc3'), gate_proj=gate_proj)
        if 'ffn' in self.pruners:
            ffn_importance[self.pruners['ffn'].get_pruned_blocks()] = torch.inf

        importances = torch.cat([head_importance, ffn_importance], dim=0)

        dev = torch.device(self.device)
        total_blocks = self.layer_config.num_heads + self.layer_config.intermediate_dimension
        layer_indices = torch.full((total_blocks,), fill_value=layer_idx, dtype=torch.int, device=dev)
        block_indices = torch.arange(total_blocks, dtype=torch.int, device=dev)

        # Stack [importance, layer_idx, block_idx]
        return torch.stack([importances, layer_indices, block_indices], dim=1)

  
    # @torch.inference_mode()
    # def get_initial_importances_weight_magnitudes(self, layer_idx: int) -> torch.Tensor:
    #     """
    #     Computes initial importances for each block in the layer.

    #     Returns:
    #         torch.Tensor: Tensor with [importance, layer_idx, block_idx] where block_idx 
    #                      is the global block index within the layer (0 to total_blocks-1).
    #     """
    #     head_magnitudes = self.compute_head_magnitudes().to(torch.device(self.device))
    #     ffn_magnitudes = self.compute_ffn_magnitudes().to(torch.device(self.device))
    #     importances = torch.cat([head_magnitudes, ffn_magnitudes], dim=0)
        
    #     # Create layer indices
    #     total_blocks = self.layer_config.num_heads + self.layer_config.intermediate_dimension
    #     layer_indices = torch.full((total_blocks,), fill_value=layer_idx, dtype=torch.int, device=torch.device(self.device))
        
    #     # Create global block indices (0 to total_blocks-1)
    #     block_indices = torch.arange(total_blocks, dtype=torch.int, device=torch.device(self.device))
        
    #     # Stack [importance, layer_idx, block_idx]
    #     stacked = torch.stack([importances, layer_indices, block_indices], dim=1)
    #     return stacked
    
    # @torch.inference_mode()
    # def compute_head_magnitudes(self):
    #     head_magnitudes = [0.0 for _ in range(self.layer_config.num_heads)]
    #     for key, layer_specs in self.layer_type.layers.items():
    #         module = getattr(self.layer, layer_specs.module_name)
    #         layer = getattr(module, layer_specs.attribute)
    #         weight_matrix = layer.weight.detach()
    #         # Extract the rows corresponding to the head
    #         if key in ['q', 'k', 'v']:
    #             for head_idx in range(self.layer_config.num_heads):
    #                 # Extract the rows corresponding to the head
    #                 head_magnitudes[head_idx] += weight_matrix[head_idx * self.layer_config.head_size : (head_idx + 1) * self.layer_config.head_size, :].abs().sum().item()
    #         elif key in ['o']:
    #             for head_idx in range(self.layer_config.num_heads):
    #                 # Sum absolute values over the column block corresponding to this head
    #                 head_magnitudes[head_idx] += weight_matrix[:, head_idx * self.layer_config.head_size : (head_idx + 1) * self.layer_config.head_size].abs().sum().item()
    #     head_magnitudes = torch.tensor(head_magnitudes, dtype=torch.float32, device=torch.device(self.device))
    #     if len(self.pruners) > 0:
    #         head_magnitudes[self.pruners['head'].get_pruned_blocks()] = torch.inf
    #     return head_magnitudes
    
    # @torch.inference_mode()
    # def compute_ffn_magnitudes(self):
    #     ffn_magnitudes = None
    #     num_units = 0
    #     # Find the FFN layer to prune (either 'fc1', or 'fc2', as per schema)
    #     for key, layer_specs in self.layer_type.layers.items():
    #         if key not in ['fc1', 'fc2', 'fc3']:
    #             continue
    #         module = getattr(self.layer, layer_specs.module_name)
    #         layer = getattr(module, layer_specs.attribute)
    #         weight_matrix = layer.weight.detach()
    #         if layer_specs.prune_type == 'row':
    #             num_units = weight_matrix.shape[0]
    #             ffn_magnitudes = weight_matrix.abs().sum(dim=1) if ffn_magnitudes is None else ffn_magnitudes + weight_matrix.abs().sum(dim=1)
    #         elif layer_specs.prune_type == 'column':
    #             num_units = weight_matrix.shape[1]
    #             ffn_magnitudes = weight_matrix.abs().sum(dim=0) if ffn_magnitudes is None else ffn_magnitudes + weight_matrix.abs().sum(dim=0)
    #         assert ffn_magnitudes is not None and ffn_magnitudes.shape[0] == num_units, f"Number of units in FFN layer {key} is not consistent with the number of units in the layer schema"
    #     ffn_magnitudes = ffn_magnitudes.to(dtype=torch.float32, device=torch.device(self.device))
    #     if len(self.pruners) > 0:
    #         ffn_magnitudes[self.pruners['ffn'].get_pruned_blocks()] = torch.inf
    #     return ffn_magnitudes

    # Diagonal hessian based initial importances
    @torch.inference_mode()
    def get_initial_importances_hessian_diagonal(self, layer_idx: int) -> torch.Tensor:
        """OBS-diagonal importances from the diagonals of the pruners' ``_H_accum``.

        Heads are scored from the 'head' pruner's accumulator via
        ``head_importance_from_O``; FFN units from the 'ffn' pruner's via
        ``ffn_importance_from_down``.

        Returns:
            Tensor of shape [num_heads + intermediate_dimension, 3] with columns
            [importance, layer_idx, block_idx]; ``block_idx`` restarts at 0 for
            the FFN units (there is no block-type column).
        """
        head_scores = self.head_importance_from_O(torch.diagonal(self.pruners['head']._H_accum), self.layer_config.head_size)
        ffn_scores = self.ffn_importance_from_down(torch.diagonal(self.pruners['ffn']._H_accum))
        importances = torch.cat([head_scores, ffn_scores], dim=0)

        dev = torch.device(self.device)
        n_heads = self.layer_config.num_heads
        n_units = self.layer_config.intermediate_dimension
        layer_col = torch.full((n_heads + n_units,), fill_value=layer_idx, dtype=torch.int, device=dev)
        block_col = torch.cat(
            [torch.arange(n_heads, dtype=torch.int, device=dev), torch.arange(n_units, dtype=torch.int, device=dev)],
            dim=0,
        )
        return torch.stack([importances, layer_col, block_col], dim=1)
    
    @torch.inference_mode()
    def head_importance_from_O(self, H_o_diag: torch.Tensor, head_dim: int):
        """OBS-diagonal saliency of each attention head, using only the 'o' projection.

        With W = o-projection weight of shape [hidden_size, num_heads*head_dim],
        the score of head h is
            0.5 * sum_r sum_{c in head h} W[r, c]^2 * H_o_diag[c].
        Heads already pruned by the 'head' pruner (when present) get +inf.

        Args:
            H_o_diag: 1-D diagonal that broadcasts against W's last dimension
                (length num_heads*head_dim). The in-file caller passes
                ``torch.diagonal`` of the head pruner's ``_H_accum`` — presumably
                the input-side Hessian of o_proj. Moved to W's device here.
            head_dim: columns per head.

        Returns:
            float32 tensor [num_heads], or None if the schema has no 'o' entry.
        """
        for key, layer_specs in self.layer_type.layers.items():
            # Only the 'o' entry is resolved; other entries may not be plain modules
            # with a weight (e.g. an optional fc1 whose attribute is None).
            if key != 'o':
                continue
            layer = getattr(getattr(self.layer, layer_specs.module_name), layer_specs.attribute)
            W = layer.weight.detach().to(torch.float64)              # fp64 for safe accumulation
            h_diag = H_o_diag.to(device=W.device, dtype=torch.float64)
            num_heads = W.shape[1] // head_dim
            per_col = (W.square() * h_diag).sum(dim=0)               # [num_heads*head_dim]
            scores = per_col.view(num_heads, head_dim).sum(dim=1)    # sum cols per head
            if 'head' in self.pruners:
                scores[self.pruners['head'].get_pruned_blocks()] = torch.inf
            return (0.5 * scores).float()                            # [num_heads]
        return None

    @torch.inference_mode()
    def ffn_importance_from_down(self, H_down_diag: torch.Tensor):
        """OBS-diagonal saliency of each FFN unit, using only the 'fc3' (down) projection.

        With W = fc3 weight of shape [hidden_size, intermediate_size], the score
        of unit j is 0.5 * sum_r W[r, j]^2 * H_down_diag[j]. Units already pruned
        by the 'ffn' pruner (when present) get +inf.

        Args:
            H_down_diag: 1-D diagonal that broadcasts against W's last dimension
                (length intermediate_size). The in-file caller passes
                ``torch.diagonal`` of the ffn pruner's ``_H_accum``. Moved to W's device here.

        Returns:
            float32 tensor [intermediate_size], or None if the schema has no 'fc3' entry.
        """
        for key, layer_specs in self.layer_type.layers.items():
            # Resolve only 'fc3'; earlier entries (e.g. fc1 with attribute None) would fail getattr.
            if key != 'fc3':
                continue
            layer = getattr(getattr(self.layer, layer_specs.module_name), layer_specs.attribute)
            W = layer.weight.detach().to(torch.float64)              # fp64 for safe accumulation
            h_diag = H_down_diag.to(device=W.device, dtype=torch.float64)
            per_unit = (W.square() * h_diag).sum(dim=0)              # sum over rows -> per neuron column
            if 'ffn' in self.pruners:
                per_unit[self.pruners['ffn'].get_pruned_blocks()] = torch.inf
            return (0.5 * per_unit).float()
        return None

    @torch.inference_mode()
    def initialize_pruners(self):
        """Create (once) and hook the per-role micro-pruners for this layer.

        Schema entries handled, only when their ``prune_type`` is "column":
          * 'o'          -> role 'head', block size ``layer_config.head_size``
          * 'fc2', 'fc3' -> role 'ffn',  block size 1
        A pruner is created only if its role is not yet in ``self.pruners``, so
        its state survives across pruning iterations. ``register_hook`` is
        called on every invocation. Other schema entries are never resolved to
        modules, so an entry whose attribute is None cannot break this method.

        NOTE(review): if both 'fc2' and 'fc3' are column-pruned, register_hook is
        called twice on the same 'ffn' pruner per invocation — confirm that
        HybridOBSLinearPruner.register_hook tolerates this.
        """
        for key, layer_specs in self.layer_type.layers.items():
            if key == 'o':
                role, block_size = 'head', self.layer_config.head_size
            elif key in ('fc2', 'fc3'):
                role, block_size = 'ffn', 1
            else:
                continue
            if layer_specs.prune_type != "column":
                continue

            if role not in self.pruners:
                # Create pruner only if it doesn't exist (first iteration)
                layer = getattr(getattr(self.layer, layer_specs.module_name), layer_specs.attribute)
                self.pruners[role] = HybridOBSLinearPruner(
                    layer,
                    block_size,
                    device=self.device,
                    use_chunking=self.use_chunking,
                    chunk_size=self.chunk_size,
                )
            # Always register hook for each iteration
            self.pruners[role].register_hook()
    
    @torch.inference_mode()
    def compute_exact_importances(self, layer_idx: int):
        """
        Build this layer's candidate table for global block ranking.

        For each of the 'head' and 'ffn' pruners (in that order) that exists and has
        a non-empty ``selected_blocks``:
          * ``importance_all(return_tensor=True)`` is called;
          * inf/NaN entries are replaced with 1e6;
          * the scores are scaled by ``self.global_sensitivity``.

        Only pruners that actually contributed scores contribute block indices, so
        rows always stay aligned.

        Args:
            layer_idx: Index of this layer; written into column 1 of every row.

        Returns:
            Tensor [N, 4] with columns [importance, layer_idx, block_idx, block_type].
            block_idx is the pruner-local index taken from ``selected_blocks``;
            block_type is 0 for heads and 1 for FFN units. Head rows come first.
            An empty [0, 4] tensor is returned when no pruner contributed.
        """
        device = torch.device(self.device)
        importances, block_ids, block_types = [], [], []

        for key, type_code in (('head', 0), ('ffn', 1)):
            # Only compute if there are candidate blocks
            if key not in self.pruners:
                continue
            pruner = self.pruners[key]
            if not hasattr(pruner, 'selected_blocks') or len(pruner.selected_blocks) == 0:
                continue
            scores = pruner.importance_all(return_tensor=True)
            if not isinstance(scores, torch.Tensor) or scores.numel() == 0:
                self.logger.debug(f"Layer {layer_idx} {key}_importance_tensor is not a valid tensor or is empty")
                continue
            # Non-finite scores would poison the global ranking; map them to a large finite penalty.
            scores = torch.nan_to_num(scores, nan=1e6, posinf=1e6, neginf=1e6)
            ids = torch.as_tensor(pruner.selected_blocks, dtype=torch.long, device=device).flatten()
            importances.append((scores * self.global_sensitivity).flatten().to(device))
            block_ids.append(ids)
            block_types.append(torch.full((ids.numel(),), type_code, dtype=torch.long, device=device))

        if not importances:
            self.logger.debug(f"Layer {layer_idx} no importances computed")
            return torch.empty((0, 4), device=device)

        importances_tensor = torch.cat(importances, dim=0)
        # Indices share one float table with the importances; use at least fp32 so they stay exact.
        out_dtype = torch.promote_types(importances_tensor.dtype, torch.float32)
        block_indices_tensor = torch.cat(block_ids, dim=0).to(out_dtype)
        block_types_tensor = torch.cat(block_types, dim=0).to(out_dtype)
        layer_indices_tensor = torch.full((importances_tensor.numel(),), layer_idx, dtype=out_dtype, device=device)

        # Stack [importance, layer_idx, block_idx, block_type]
        return torch.stack(
            [importances_tensor.to(out_dtype), layer_indices_tensor, block_indices_tensor, block_types_tensor],
            dim=1,
        )

    @torch.inference_mode()
    def block_indices_to_prune_per_layer(self, indices: torch.Tensor) -> Dict[str, List[int]]:
        """
        Split a candidate table into per-module groups.

        Args:
            indices: 2-D tensor whose column 3 is the block type
                (0 = attention head, 1 = FFN unit).

        Returns:
            dict {'head': rows, 'ffn': rows}. Each value is a list of FULL rows
            (``tolist()`` of the matching rows), not bare block indices; rows of
            any other type code are dropped.
        """
        type_column = indices[:, 3]
        return {
            module_key: indices[type_column == type_code].tolist()
            for module_key, type_code in (('head', 0), ('ffn', 1))
        }
    
    @torch.inference_mode()
    def set_candidate_blocks(self, candidate_blocks: torch.Tensor):
        """
        Hand the candidate blocks of this layer to the head and FFN pruners.

        Args:
            candidate_blocks: Tensor with columns [importance, layer_idx, block_idx, block_type];
                block_type 0 marks a head, 1 marks an FFN unit.

        Raises:
            ValueError: if a head index is >= num_heads or an FFN index is
                >= intermediate_dimension.

        Side effects:
            Calls ``set_selected_blocks`` on each existing pruner ('head', 'ffn') and
            stores the combined indices in ``self.selected_blocks``. FFN indices are
            offset by num_heads there, so both kinds share one index space.
        """
        if not candidate_blocks.shape[0]:
            self.selected_blocks = torch.empty(0, dtype=torch.long, device=self.device)
            return

        idx_col = candidate_blocks[:, 2].long()
        kind_col = candidate_blocks[:, 3].long()
        heads = idx_col[kind_col == 0]
        units = idx_col[kind_col == 1]

        n_heads = self.layer_config.num_heads
        if heads.numel() > 0 and heads.max() >= n_heads:
            self.logger.warning(f"Head block indices out of bounds: max={heads.max()}, max_allowed={n_heads-1}, indices={heads}")
            raise ValueError(f"Head block indices out of bounds: max={heads.max()}, max_allowed={n_heads-1}")
        n_units = self.layer_config.intermediate_dimension
        if units.numel() > 0 and units.max() >= n_units:
            self.logger.warning(f"FFN block indices out of bounds: max={units.max()}, max_allowed={n_units-1}")
            raise ValueError(f"FFN block indices out of bounds: max={units.max()}, max_allowed={n_units-1}")

        # Forward each subset to its pruner, when that pruner exists.
        for pruner_key, subset in (('head', heads), ('ffn', units)):
            if pruner_key in self.pruners:
                self.pruners[pruner_key].set_selected_blocks(subset)

        self.selected_blocks = torch.cat([heads, units + n_heads], dim=0)
        
    @torch.inference_mode()
    def prune(self, selected_blocks: torch.Tensor, prune_args: Final[PruningArguments], layer_idx: int, progress: Union[Progress, None] = None, task_pruned: Union[Task, None] = None, conditioned_score_max_chunk_size: int = 256):
        blocks_to_prune_per_layer = self.block_indices_to_prune_per_layer(indices=selected_blocks)
        # with Progress(*columns) as progress:
        # if progress is not None:
            # task = progress.add_task(f"[blue]Pruning layer {layer_idx}...", total=selected_blocks.shape[0])
        for module_name, pruner in self.pruners.items():

            num_blocks_to_prune = len(blocks_to_prune_per_layer[module_name])
            pruned = 0
            if num_blocks_to_prune > 0:
                # Ensure calibration is finalized before pruning
                assert pruner.G_CC is not None, f"G_CC is None for {module_name} in layer {layer_idx}"
                assert 0 < len(pruner.selected_blocks), f"No selected blocks for {module_name} in layer {layer_idx}"
                if module_name == "head":
                    max_try = min(num_blocks_to_prune, prune_args.max_try_head)
                elif module_name == "ffn":
                    max_try = min(num_blocks_to_prune, prune_args.max_try_ffn)
                else:
                    raise ValueError(f"Invalid module name: {module_name}")
                pruned_blocks = []
                iteration = 0
                while pruned < num_blocks_to_prune:
                    iteration += 1
                    scores = pruner.importance_all(return_tensor=True)
                    candidates_local = torch.topk(scores, k=scores.numel(), largest=False).indices
                    if scores.numel() == 1:
                        blk, score = pruner.prune_lowest(scores)
                        pruned_blocks.append(blk)
                        pruned += 1
                        # 2) accumulate module score  Sℓ,m
                        mb = self.mod_budget[module_name]
                        taylor = torch.tensor(score, device=mb.score.device)
                        # Accumulate once, with correct n_blocks_pre (inside update_ratio_term)
                        self.update_ratio_term(module_name, taylor * (self.global_sensitivity ** 2))
                        if progress is not None:
                            progress.update(task_pruned, advance=1)
                        break
                    elif scores.numel() == 0:
                        self.logger.warning(f"No scores for {module_name} in layer {layer_idx}")
                        raise ValueError(f"No scores for {module_name} in layer {layer_idx}")
                    B_list = pruner.certify_batch_chunked(
                        scores,
                        candidates=candidates_local,
                        max_try=max_try,
                        chunk_size=min(conditioned_score_max_chunk_size, candidates_local.numel()),
                    )
                    
                    B_loc = torch.tensor(B_list, device=scores.device, dtype=torch.long)
                    if B_loc.numel() == 0:
                        self.logger.warning(f"Certification failed, using single worst in pool. Module: {module_name}, Layer: {layer_idx}")
                        blk, score = pruner.prune_lowest(scores)
                        pruned_blocks.append(blk)
                        pruned += 1
                        # 2) accumulate module score  Sℓ,m
                        mb = self.mod_budget[module_name]
                        taylor = torch.tensor(score, device=mb.score.device)
                        # Accumulate once, with correct n_blocks_pre (inside update_ratio_term)
                        self.update_ratio_term(module_name, taylor * (self.global_sensitivity ** 2))
                        # mb.accumulate(
                        #     taylor_loss=taylor * (self.global_sensitivity ** 2),
                        #     n_blocks_pre=pruner.block_size)
                        if progress is not None:
                            progress.update(task_pruned, advance=1)
                        continue
                    self.logger.debug(f"Pruning {B_loc.numel()} blocks. Module: {module_name}, Layer: {layer_idx}")
                    batch_size = int(B_loc.numel())
                    pruner.apply_joint_update_and_downdate(B_loc)
                    pruned += batch_size
                    pruned_blocks.extend(B_loc.flatten().tolist())
                    # 2) accumulate module score  Sℓ,m
                    mb = self.mod_budget[module_name]
                    taylor = scores[B_loc].sum().to(device=mb.score.device, dtype=mb.score.dtype)
                    # Accumulate once, with correct n_blocks_pre (inside update_ratio_term)
                    self.update_ratio_term(module_name, taylor * (self.global_sensitivity ** 2))
                    # mb.accumulate(
                    #     taylor_loss=taylor * (self.global_sensitivity ** 2),
                    #     n_blocks_pre=pruner.block_size)

                    if module_name == "head":
                        self.zero_out_head(pruned_blocks)
                    elif module_name == "ffn":
                        self.zero_out_ffn(pruned_blocks)
                    
                    if progress is not None:
                        progress.update(task_pruned, advance=batch_size)

                # for i in range(num_blocks_to_prune):
                #     blk, score = pruner.prune_lowest()
                #     if progress is not None:
                #         progress.update(task_pruned, advance=1)

                    # 2) accumulate module score  Sℓ,m
                    # mb = self.mod_budget[block_name]
                    # taylor = torch.tensor(score, device=mb.score.device)
                    # self.update_ratio_term(block_name, taylor * (self.global_sensitivity ** 2))
                    # mb.accumulate(
                    #     taylor_loss=taylor * (self.global_sensitivity ** 2),
                    #     n_blocks_pre=pruner.block_size)
        # if progress is not None:
        #     progress.update(task, completed=selected_blocks.shape[0])

    @torch.inference_mode()
    def zero_out_head(self, pruned_blocks: List[int]):
        """
        Zero the given attention heads in the q/k/v/o projections of this layer.

        Args:
            pruned_blocks: Head indices, forwarded unchanged to ``zero_heads_gqa``.
        """
        schema = self.layer_type.layers

        def _projection(key: str):
            # Resolve the linear module named by the schema entry: layer.<module>.<attribute>
            entry = schema[key]
            return getattr(getattr(self.layer, entry.module_name), entry.attribute)

        q_lin, k_lin, v_lin, o_lin = map(_projection, ('q', 'k', 'v', 'o'))
        zero_heads_gqa(q_lin, k_lin, v_lin, o_lin, pruned_blocks)

    @torch.inference_mode()
    def zero_out_ffn(self, pruned_blocks: List[int]):
        """
        Zero the given FFN units across the up ('fc2'), down ('fc3') and, if present,
        gate ('fc1') projections.

        Args:
            pruned_blocks: Unit indices, forwarded unchanged to ``zero_ffn_units``.
        """
        schema = self.layer_type.layers

        def _projection(key: str, *fallback):
            # layer.<module>.<attribute>; an optional fallback makes a missing attribute yield it.
            entry = schema[key]
            return getattr(getattr(self.layer, entry.module_name), entry.attribute, *fallback)

        up_lin = _projection('fc2')
        down_lin = _projection('fc3')
        gate_lin = _projection('fc1', None)  # None when the module has no gate attribute
        zero_ffn_units(up_lin, down_lin, pruned_blocks, gate_proj=gate_lin)

    @torch.inference_mode()
    def get_pruned_blocks(self):
        """
        Return the blocks already removed by each module pruner.

        Returns:
            dict with keys 'head' and 'ffn'. Each value is whatever that pruner's
            ``get_pruned_blocks()`` returns (local indices). ``initialize_pruners``
            creates pruners conditionally, so a module without a pruner reports an
            empty list instead of raising KeyError.
        """
        return {
            key: self.pruners[key].get_pruned_blocks() if key in self.pruners else []
            for key in ('head', 'ffn')
        }
    
    @torch.inference_mode()
    def get_updated_configs(self) -> TransformerConfig:
        """
        Return a deep copy of the layer config with pruned blocks subtracted.

        num_heads drops by the number of pruned heads and intermediate_dimension by
        the number of pruned FFN units. ``self.layer_config`` itself is not modified.
        """
        removed = self.get_pruned_blocks()
        original_cfg = self.layer_config
        shrunk_cfg = deepcopy(cast(TransformerConfig, original_cfg))
        shrunk_cfg.num_heads = original_cfg.num_heads - len(removed['head'])
        shrunk_cfg.intermediate_dimension = original_cfg.intermediate_dimension - len(removed['ffn'])
        return shrunk_cfg

    @torch.inference_mode()
    def get_num_params(self) -> int:
        """
        Returns the number of parameters in the layer.

        Counts weight entries only (no biases), using the post-pruning config from
        ``get_updated_configs``:
          * attention: 4 projections of num_heads * head_size x hidden_size.
            NOTE(review): this assumes K/V are as wide as Q; with fewer KV heads
            (GQA) it overcounts — confirm.
          * FFN: intermediate_dimension x hidden_size per projection; 3 projections
            when a gate ('fc1') is present, otherwise 2 (up 'fc2' + down 'fc3').
        """
        config = self.get_updated_configs()
        head_params = config.num_heads * config.head_size * config.hidden_size * 4
        ffn_params = config.intermediate_dimension * config.hidden_size

        # The down projection 'fc3' is always used by this class, so it cannot tell
        # gated from non-gated MLPs; the optional projection is the gate 'fc1'.
        # Like zero_out_ffn, treat a missing gate attribute as "no gate".
        layers = self.layer_type.layers
        gate_spec = layers['fc1'] if 'fc1' in layers else None
        has_gate = (
            gate_spec is not None
            and getattr(getattr(self.layer, gate_spec.module_name), gate_spec.attribute, None) is not None
        )
        ffn_params *= (3 if has_gate else 2)
        return head_params + ffn_params

    @torch.inference_mode()
    def reset_pruner(self):
        """
        Prepare this HALPE instance for another iteration of the algorithm.

        Kept:    layer, layer_config, layer_type, device, pruners, selected_blocks and
                 mod_budget (module budgets are deliberately not reset here).
        Cleared: the hook (``remove_hook`` is called when ``handle`` is set), H,
                 local_sensitivity, global_sensitivity and num_samples.
        Finally each pruner's own ``reset_pruner`` is invoked.
        """
        if self.handle is not None:
            self.remove_hook()

        # Sensitivity statistics are gathered again from scratch next iteration.
        self.H = None
        self.num_samples = 0
        self.local_sensitivity = self.global_sensitivity = 0.

        for module_pruner in self.pruners.values():
            module_pruner.reset_pruner()
