import torch
import torch.nn as nn

from .utils.linalg import inv_sym
from .base_obc import BaseOBC


class FastOBC(BaseOBC):
    """Block-wise weight pruner built on ``BaseOBC``.

    The weight matrix is processed in groups of ``block_size`` columns. Inside
    a group, entries are zeroed column by column and the resulting error is
    propagated to the not-yet-processed columns using rows of the upper
    Cholesky factor of the inverse Hessian.

    NOTE(review): ``self.W``, ``self.H``, ``self.d_col``, ``self.layer`` and
    ``self._reshape_to_orig_shape`` are not defined here; they are presumably
    provided by ``BaseOBC`` -- confirm against the base class.
    """

    def __init__(self, layer: nn.Module, block_size: int = 64, rel_damp: float = 1e-2) -> None:
        """Store the block size and forward ``layer``/``rel_damp`` to ``BaseOBC``.

        Args:
            layer: module whose weight is pruned.
            block_size: number of weight columns handled per block.
            rel_damp: passed through to ``BaseOBC`` unchanged (presumably a
                relative damping factor for the Hessian -- verify in the base).
        """
        super().__init__(layer, rel_damp)
        self.block_size = block_size

    @torch.no_grad()
    def pruning_step(self, sparsity: float) -> None:
        """Prune ``self.W`` block by block and write it back to the layer.

        For every column block, the ``round(numel * sparsity)`` entries with
        the lowest score ``w^2 / diag(chol)^2`` are zeroed; the threshold is
        therefore chosen independently per block. ``self.W`` is modified in
        place and the result is assigned to ``self.layer.weight.data``.

        Args:
            sparsity: fraction of entries to zero out in each block. A value
                that rounds to zero entries for a block leaves that block
                untouched.
        """
        d_col, block_size = self.d_col, self.block_size
        # prepare weight and Cholesky of H^{-1}
        W = self.W
        H_inv_cho = torch.linalg.cholesky(inv_sym(self.H), upper=True)
        # iterate over column blocks
        for c1 in range(0, d_col, block_size):
            c2 = min(c1 + block_size, d_col)
            ncols = c2 - c1  # number of columns in this block
            W_blk = W[:, c1:c2].clone()  # column-wise weight slice (detached copy)

            # torch.kthvalue requires k >= 1; when nothing has to be pruned in
            # this block every error term would be zero, so skip it entirely.
            n_prune = round(W_blk.numel() * sparsity)
            if n_prune == 0:
                continue

            res = torch.zeros_like(W_blk)
            errs = torch.zeros_like(W_blk)
            H_inv_cho_blk = H_inv_cho[c1:c2, c1:c2]
            # 1) score computation (done once, before any in-block update)
            scores = W_blk.pow(2).div_(H_inv_cho_blk.diag().reshape(1, -1).pow(2))
            thr, _ = torch.kthvalue(scores.view(-1), n_prune)
            # entries with score <= thr are pruned (ties at thr are pruned too)
            mask = scores > thr
            # 2) iterate over the columns of the block
            for i in range(ncols):
                w_ci = W_blk[:, i]
                d = H_inv_cho_blk[i, i]

                q = w_ci.clone()
                q[~mask[:, i]] = 0

                res[:, i] = q
                err = (w_ci - q).div_(d)
                # compensate the remaining columns of the block for the error
                W_blk[:, i:].addr_(err, H_inv_cho_blk[i, i:], alpha=-1)
                errs[:, i] = err
            # 3) write the block back and propagate the accumulated errors to
            #    all columns to the right of the block in one matmul
            W[:, c1:c2] = res
            W[:, c2:].addmm_(errs, H_inv_cho[c1:c2, c2:], alpha=-1)

        self.layer.weight.data = self._reshape_to_orig_shape(W)
