from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod

class OBSLayerPrunerBase(ABC):
    def __init__(self, layer: nn.Module, layer_config: dict, block_size: int):
        self.layer = layer
        self.layer_config = layer_config
        self.block_size = block_size
        self.columns = None
        self.rows = None
        self.nsamples = 0  # Number of samples seen so far
        self.handles = []
        self.H = None  # Hessian (dict: idx -> tensor)
        self.H_inv_selected = None
        self.selected_blocks = []  # List of selected block indices (set externally)
        self.selected_indices = []  # List of selected indices (set externally)
        self.blocks_to_indices = {}  # Mapping from block index to selected indices

    def register_hooks(self):
        def forward_hook(module, inp, out):
            self.add_batch(inp[0].detach())
        handle = self.layer.register_forward_hook(forward_hook)
        self.handles.append(handle)

    def remove_hooks(self):
        for h in self.handles:
            h.remove()
        self.handles = []

    def set_selected_blocks(self, selected_indices: list):
        self.selected_blocks = sorted(selected_indices)
        self.compute_selected_block_indices()
    
    def compute_hessian_partial(self, num_samples, selected_blocks_inputs, inputs):
        # Compute the Hessian matrix for the selected blocks
        if self.H is None:
            self.H = selected_blocks_inputs.T @ inputs  # [block_size, in_features]
            self.nsamples = num_samples
        else:
            self.H *= self.nsamples / (self.nsamples + num_samples)  # scale old H
            self.nsamples += num_samples
            self.H += (selected_blocks_inputs.T @ inputs) / self.nsamples        # update with normalized new outer product
    
    def compute_hessian(self, num_samples, inputs):
        # Compute the Hessian matrix for the selected blocks
        if self.H is None:
            self.H = inputs.T @ inputs  # [block_size, in_features]
            self.nsamples = num_samples
        else:
            self.H *= self.nsamples / (self.nsamples + num_samples)  # scale old H
            self.nsamples += num_samples
            self.H += (inputs.T @ inputs) / self.nsamples        # update with normalized new outer product
    
    def compute_inverse_hessian(self) -> Dict[int, float]:
        # Compute the inverse of the Hessian matrix for the selected blocks
        if self.H is None:
            raise ValueError("Hessian matrix is not computed. Please call add_batch() first.")
    
        device = self.H.device
        C = torch.as_tensor(C, device=device)
        m = C.numel()
        d = self.H.shape[0]

        # batched solve  H X = E_C   -> columns of H^{-1}
        E_C = torch.eye(d, device=device)[:, C]        # (d, m)
        self.H_inv_selected = torch.linalg.solve(self.H, E_C)        # (d, m)

    def compute_importance_vectorized(self, W_flat: torch.Tensor) -> Dict[int, float]:
        """
        Parameters
        ----------
        W_flat : 1-D tensor  (d,)  -- flattened weight vector of *this* layer
                must follow the same indexing scheme used to build
                self.selected_indices.
        Returns
        -------
        obs_scores : dict  { block_id : OBS importance (float) }
        """
        # ---- sanity checks -------------------------------------------------
        if self.H is None:
            raise ValueError("Hessian matrix is not computed. "
                            "Call add_batch() first.")
        if self.H_inv_selected is None:
            raise ValueError("Inverse Hessian columns not computed. "
                            "Call compute_inverse_hessian() first.")

        device   = self.H.device
        C        = torch.as_tensor(self.selected_indices, device=device)  # (m,)
        B        = len(self.selected_blocks)

        # parameter-count per block  (same for every block in this layer type)
        param_per_block = len(self.blocks_to_indices[self.selected_blocks[0]])
        assert C.numel() == B * param_per_block, "index list does not match block layout"

        # ---- 1. gather weights of every candidate block --------------------
        # shape → (B , P , 1)   where P = param_per_block
        W_sel = W_flat[C].view(B, param_per_block, 1)

        # ---- 2. build the inverse sub-matrix  (m × m) ----------------------
        # H_inv_selected   =  H⁻¹ [ : , C ]
        H_inv_CC = self.H_inv_selected[C][:, C]          # (m , m)

        # ---- 3. reshape to (B,P,B,P) and take only diagonal blocks --------
        H4          = H_inv_CC.view(B, param_per_block, B, param_per_block)
        Hinv_pp     = H4[torch.arange(B, device=device), :, torch.arange(B, device=device), :]  # (B,P,P)

        # ---- 4. batched solve  (H⁻¹_pp)^{-1} · W_p  ------------------------
        sol = torch.linalg.solve(Hinv_pp, W_sel)         # (B,P,1)

        # ---- 5. OBS importance for every block  ---------------------------
        # ½ · W_pᵀ · (H⁻¹_pp)^{-1} · W_p
        importance = 0.5 * (W_sel.squeeze(-1) * sol.squeeze(-1)).sum(dim=1)  # (B,)

        # ---- 6. map back to block-ID → float ------------------------------
        obs_scores = {
            blk_id: importance[i].item()
            for i, blk_id in enumerate(self.selected_blocks)
        }
        return obs_scores
    
    def compute_importance_loop(self, W_flat: torch.Tensor       
    ) -> Dict[int, float]:
        """
        W_flat : 1-D tensor  (d,)  -- 1-D flattened weight vector of *this* layer
        Returns  { block_id : OBS importance (float) }  using a Python loop.
        Works for Linear, Conv1d, Conv2d subclasses that expose:
            self.selected_blocks       list[int]     length = B
            self.selected_indices      list[int]     length = m (= B*P)
            self.blocks_to_indices     dict{block_id -> list[int]  length P}
            self.H_inv_selected        (d × m)  tensor  = H⁻¹[:, selected_indices]
        """
        if self.H is None:
            raise ValueError("Hessian not computed. Call add_batch() first.")
        if self.H_inv_selected is None:
            raise ValueError("Inverse columns missing. Call compute_inverse_hessian() first.")

        device = self.H.device
        C      = torch.as_tensor(self.selected_indices, device=device)   # (m,)

        # parameter count per block
        P = len(self.blocks_to_indices[self.selected_blocks[0]])
        assert C.numel() == len(self.selected_blocks) * P, "index layout mismatch"

        obs_scores = {}

        # -----------------------------------------------------------------
        for blk_idx, blk_id in enumerate(self.selected_blocks):
            # slice offset for this block inside the concatenated list C
            start = blk_idx * P
            end   = start + P

            block_indices = C[start:end]                     # tensor (P,)

            # collect weights of this block
            W_p = W_flat[block_indices]                      # (P,)

            # square inverse sub-block (H⁻¹)_{p,p}
            Hinv_pp = self.H_inv_selected[block_indices][:, block_indices]  # (P,P)

            # OBS importance: ½ · W_pᵀ · (H⁻¹_pp)^{-1} · W_p
            imp = 0.5 * W_p @ torch.linalg.solve(Hinv_pp, W_p)

            obs_scores[blk_id] = imp.item()
        # -----------------------------------------------------------------
        return obs_scores

    @abstractmethod
    def add_batch(self, inputs):
        raise NotImplementedError("This method should be overridden in subclasses.")

    @abstractmethod
    def compute_selected_block_indices(self):
        raise NotImplementedError("This method should be overridden in subclasses.")
    
    @abstractmethod
    def gather_block_columns(self) -> torch.Tensor:
        """
        Gather the weight columns of the selected blocks.
        1. Return weight columns of all selected blocks stacked as
        (B*k,) 1-D tensor  OR  (B,k) 2-D if you prefer.
        Must match the ordering in 'selected_indices'.
        """
        raise NotImplementedError("This method should be overridden in subclasses.")
        
        
class OBSLinearLayerPruner(OBSLayerPrunerBase):
    """OBS pruner for a linear layer; a block is ``block_size`` consecutive input columns."""

    def __init__(self, layer, layer_config, block_size):
        super().__init__(layer, layer_config, block_size)
        self.rows = self.layer.weight.shape[0]
        self.columns = self.layer.weight.shape[1]

    def compute_selected_block_indices(self):
        """Rebuild selected_indices / blocks_to_indices from selected_blocks.

        Every block id maps to the list of weight-column indices it covers.
        The tables are rebuilt from scratch so repeated calls do not
        accumulate stale entries.
        """
        # Blocks address columns of the weight matrix, i.e. the input-feature
        # axis, so that axis (not out_features) bounds the last block.
        in_features = self.layer.weight.shape[1]
        self.selected_indices = []
        self.blocks_to_indices = {}
        for block_idx in self.selected_blocks:
            start = block_idx * self.block_size
            end = min(start + self.block_size, in_features)
            block_columns = list(range(start, end))
            self.selected_indices.extend(block_columns)
            # Store the explicit column list: a (start, end) pair used as an
            # index would select just those two columns instead of the range.
            self.blocks_to_indices[block_idx] = block_columns

    def add_batch(self, inputs):
        """Fold one batch of layer inputs into the running Hessian.

        ``inputs`` has the feature axis last; leading axes beyond the first
        are flattened into rows.  The size of the first axis is used as the
        sample count of the batch.
        """
        if inputs.ndim == 1:
            # a single un-batched feature vector
            inputs = inputs.unsqueeze(0)
        num_samples = inputs.shape[0]
        if inputs.ndim > 2:
            inputs = inputs.reshape(-1, inputs.size(-1))
        # float64 keeps the later linear solves numerically stable
        inputs = inputs.double()

        # The full (in_features x in_features) Hessian is accumulated because
        # compute_inverse_hessian() needs a square matrix.
        self.compute_hessian(num_samples, inputs)

    def gather_block_columns(self) -> torch.Tensor:
        """Return the weights of all selected columns as one 1-D tensor.

        The selected columns are concatenated along dim 1 in block order,
        giving (out_features, total_selected_columns), and then flattened
        row-major.
        """
        if not self.blocks_to_indices:
            return self.layer.weight.new_empty(0)
        cols = [self.layer.weight[:, block_columns]
                for block_columns in self.blocks_to_indices.values()]
        return torch.cat(cols, dim=1).flatten()


class OBSConv1dLayerPruner(OBSLayerPrunerBase):
    """OBS pruner for a 1-D convolution.

    Columns refer to the weight flattened to (out_ch, in_ch * k); one block
    covers ``block_size * k`` consecutive columns.
    """

    def compute_selected_block_indices(self):
        """Rebuild selected_indices / blocks_to_indices from selected_blocks.

        The tables are rebuilt from scratch so repeated calls do not
        accumulate stale entries.
        """
        k = self.layer.weight.shape[2]
        stride = self.block_size * k                    # columns per block
        self.selected_indices = []
        self.blocks_to_indices = {}
        for idx in self.selected_blocks:
            start_idx = idx * stride
            block_columns = list(range(start_idx, start_idx + stride))
            self.selected_indices.extend(block_columns)
            # Store the explicit column list: a (start, end) pair used as an
            # index would select just those two columns instead of the range.
            self.blocks_to_indices[idx] = block_columns

    def add_batch(self, inputs):
        """Unfold one batch into convolution patches and accumulate the Hessian.

        inputs is assumed to be [B, C_in, L] (the layer's forward input).
        The size of the first axis is used as the sample count of the batch.
        """
        num_samples = inputs.shape[0]
        # F.unfold works on 4-D input, so a dummy width axis of size 1 is added.
        unfolded = F.unfold(inputs.unsqueeze(-1),
                            kernel_size=(self.layer.kernel_size[0], 1),
                            dilation=(self.layer.dilation[0], 1),
                            padding=(self.layer.padding[0], 0),
                            stride=(self.layer.stride[0], 1)).transpose(1, 2)
        # unfolded: [B, L', C_in * K]
        unfolded = unfolded.reshape(-1, unfolded.shape[-1])  # [B*L', C_in*K]

        self.compute_hessian(num_samples, unfolded)

    def gather_block_columns(self) -> torch.Tensor:
        """Return the weights of all selected columns as one 1-D tensor.

        The weight is flattened to (out_ch, in_ch * k); the selected columns
        are concatenated along dim 1 in block order and flattened row-major.
        """
        W_flat = self.layer.weight.flatten(1)   # (out_ch, in_ch * k)
        if not self.blocks_to_indices:
            return W_flat.new_empty(0)
        cols = [W_flat[:, block_columns]
                for block_columns in self.blocks_to_indices.values()]
        stacked = torch.cat(cols, dim=1)           # (out_ch, B*stride)
        return stacked.flatten()


class OBSConv2dChannelPruner(OBSLayerPrunerBase):
    """OBS pruner for a 2-D convolution.

    Columns refer to the weight flattened to (out_ch, in_ch * k_h * k_w); one
    block covers ``block_size * k_h * k_w`` consecutive columns.
    """

    def compute_selected_block_indices(self):
        """Rebuild selected_indices / blocks_to_indices from selected_blocks.

        The tables are rebuilt from scratch so repeated calls do not
        accumulate stale entries.
        """
        spatial = self.layer.kernel_size[0] * self.layer.kernel_size[1]
        # Consecutive blocks must start one full block apart; stepping by a
        # single kernel footprint would make blocks overlap when block_size > 1.
        stride = self.block_size * spatial              # columns per block
        self.selected_indices = []
        self.blocks_to_indices = {}
        for idx in self.selected_blocks:
            start_idx = idx * stride
            block_columns = list(range(start_idx, start_idx + stride))
            self.selected_indices.extend(block_columns)
            # Store the explicit column list: a (start, end) pair used as an
            # index would select just those two columns instead of the range.
            self.blocks_to_indices[idx] = block_columns

    def add_batch(self, inputs):
        """Unfold one batch into convolution patches and accumulate the Hessian.

        inputs is assumed to be [B, C_in, H, W] (the layer's forward input).
        The size of the first axis is used as the sample count of the batch.
        """
        num_samples = inputs.shape[0]
        unfolded: torch.Tensor = F.unfold(inputs,
                            kernel_size=self.layer.kernel_size,
                            dilation=self.layer.dilation,
                            padding=self.layer.padding,
                            stride=self.layer.stride).transpose(1, 2)
        # unfolded: [B, L, C_in*K_h*K_w]
        unfolded = unfolded.reshape(-1, unfolded.shape[-1])  # [B*L, C_in*K_h*K_w]

        # The full Hessian over all unfolded columns is accumulated because
        # compute_inverse_hessian() needs a square matrix.
        self.compute_hessian(num_samples, unfolded)

    def gather_block_columns(self) -> torch.Tensor:
        """Return the weights of all selected columns as one 1-D tensor.

        The weight is flattened to (out_ch, in_ch * k_h * k_w); the selected
        columns are concatenated along dim 1 in block order and flattened
        row-major.
        """
        W_flat = self.layer.weight.data.flatten(1)   # (out_ch, in_ch * spatial)
        if not self.blocks_to_indices:
            return W_flat.new_empty(0)
        cols = [W_flat[:, block_columns]
                for block_columns in self.blocks_to_indices.values()]
        stacked = torch.cat(cols, dim=1)           # (out_ch, B*stride)
        return stacked.flatten()


