import warnings
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from .utils.linalg import inv_sym
from .base_obc import BaseOBC


class OBC(BaseOBC):
    """Iterative pruner that processes the weight matrix in slices of rows.

    For every slice of rows, ``prune`` repeatedly removes, per row, the
    remaining weight with the smallest ``w ** 2 / diag(H_inv)`` score,
    compensates the other weights of that row and updates the per-row inverse
    matrix.  The weights after every removal are stored in ``weight_traces``
    and the accumulated loss of every removal in ``losses``; the final sparse
    weight is then assembled from the traces in ``_extract_from_traces``.

    Relies on ``BaseOBC`` for ``weight``, ``H``, ``d_row``, ``d_col``,
    ``reset``, ``prune`` and ``_reshape_to_orig_shape``.

    Args:
        layer: Layer handed to ``BaseOBC``.
        rows_in_parallel: Number of rows processed at once.  ``None`` (or 0)
            means all ``d_row`` rows in a single slice.
        rel_damp: Value forwarded unchanged to ``BaseOBC.__init__``.

    Raises:
        ValueError: If ``rows_in_parallel`` is negative.
    """

    def __init__(
        self,
        layer: nn.Module,
        rows_in_parallel: Optional[int] = None,
        rel_damp: float = 1e-2,
    ) -> None:
        # A negative step would make the row loop in `prune` empty and leave
        # the traces (and hence the final weight) all zeros without any error.
        if rows_in_parallel is not None and rows_in_parallel < 0:
            raise ValueError(
                f"rows_in_parallel must be non-negative, got {rows_in_parallel}"
            )
        super().__init__(layer, rel_damp)
        self.rows_in_parallel = rows_in_parallel or self.d_row
        # CPU tensor of shape (d_col + 1, d_row, d_col), filled by `prune`.
        self.weight_traces: Optional[Tensor] = None

    @torch.no_grad()
    def _prepare_row_slice(
        self, r1: int, r2: int, block_size: Optional[int] = None
    ) -> Tuple[Tensor, Tensor, Tensor, int, int, Tensor]:
        """Build the working tensors for rows ``r1:r2`` of the weight.

        Args:
            r1: First row of the slice (inclusive).
            r2: Last row of the slice (exclusive).
            block_size: If given, columns are grouped into blocks of this
                size and a block counts as pruned only when all of its
                entries are zero.

        Returns:
            A tuple ``(w, mask, H_inv, min_zeros, nr, row_ids)`` where
            ``w`` is a copy of the weight rows, ``mask`` marks the entries
            (or blocks, when ``block_size`` is given) that are still nonzero,
            ``H_inv`` is the result of ``inv_sym`` applied to one masked copy
            of ``self.H`` per row, ``min_zeros`` is the smallest number of
            already pruned entries (or blocks) over the rows, ``nr`` is the
            number of rows and ``row_ids`` is ``torch.arange(nr)``.
        """
        nr = r2 - r1
        # get a slice of rows
        w = self.weight[r1:r2].clone()
        # create mask of already pruned weights
        if block_size is not None:
            mask = w.reshape(w.shape[0], -1, block_size).ne(0).any(dim=-1)
            weight_mask = mask.repeat_interleave(block_size, dim=1)
        else:
            mask = w.ne(0)
            weight_mask = mask
        # get minimal number of zeros in a slice
        min_zeros = (~mask).sum(dim=1).min().item()
        # get ids of already zeroed weights
        row_ids, col_ids = torch.nonzero(~weight_mask).T
        # Create nr independent copies: (d_col, d_col) -> (nr, d_col, d_col).
        # A bare `expand` is a stride-0 view in which every row aliases the
        # same storage, so the in-place masking below would leak the pruned
        # columns of one row into all the others; cloning into a contiguous
        # tensor gives each row its own matrix.
        H_inv = self.H.expand(nr, self.d_col, self.d_col).clone(
            memory_format=torch.contiguous_format
        )
        # mask rows and columns of zeroed weights, keep a unit diagonal
        H_inv[row_ids, col_ids, :] = 0
        H_inv[row_ids, :, col_ids] = 0
        H_inv[row_ids, col_ids, col_ids] = 1
        # invert
        H_inv = inv_sym(H_inv)

        return w, mask, H_inv, min_zeros, nr, torch.arange(nr)

    def reset(self) -> None:
        """Drop the stored weight traces and reset the base class state."""
        self.weight_traces = None
        super().reset()

    @torch.no_grad()
    def prune(self, sparsity: float) -> None:
        """Prune the weight to the requested sparsity.

        Fills ``self.losses`` (shape ``(d_row, d_col)``, on the weight's
        device) and ``self.weight_traces`` (shape
        ``(d_col + 1, d_row, d_col)``, on CPU), then overwrites
        ``self.weight.data`` with the result of ``_extract_from_traces``.

        Args:
            sparsity: Target sparsity; forwarded to ``BaseOBC.prune`` and
                used to pick how many weights are kept.
        """
        super().prune(sparsity)
        d_row, d_col = self.d_row, self.d_col
        device, dtype = self.weight.device, self.weight.dtype
        # prepare losses & traces
        self.losses = torch.zeros((d_row, d_col), dtype=dtype, device=device)
        self.weight_traces = torch.zeros(
            (d_col + 1, d_row, d_col), dtype=dtype, device="cpu"
        )
        # prune batch of rows
        for r1 in range(0, d_row, self.rows_in_parallel):
            r2 = min(r1 + self.rows_in_parallel, d_row)
            # prepare weight, mask and hessian inverse
            w, mask, H_inv, min_zeros, nr, row_ids = self._prepare_row_slice(r1, r2)
            # keep the row index on the working device so it is not
            # transferred again on every inner iteration
            row_ids = row_ids.to(device)
            # prepare pruning traces for slice of rows
            traces = torch.zeros(
                (d_col + 1, nr, d_col), device=device, dtype=dtype
            )
            # entries up to min_zeros all hold the initial weights
            traces[:(min_zeros + 1)] = w
            # accumulated losses for a given slice of rows
            accum_losses = torch.zeros(nr, device=device, dtype=dtype)
            # NOTE(review): rows that start with more than `min_zeros` zeros
            # run out of candidates before this loop ends; their scores are
            # then all inf and inf is accumulated into the losses. Confirm
            # whether slices with unequal zero counts are expected here.
            # prune iteratively columns
            for col in range(min_zeros + 1, d_col + 1):
                # 1) compute scores
                H_inv_d = H_inv.diagonal(dim1=-2, dim2=-1)
                scores = w ** 2 / H_inv_d
                # already pruned weights must never be selected again
                scores[~mask] = torch.inf
                # 2) mask selection
                p_ids = scores.argmin(dim=-1)
                mask[row_ids, p_ids] = False
                # 3) update losses
                accum_losses.add_(scores[row_ids, p_ids], alpha=0.5)
                self.losses[r1 + row_ids, p_ids] = accum_losses
                # 4) weight update
                H_inv_pr = H_inv[row_ids, p_ids]
                H_inv_pd = H_inv_d[row_ids, p_ids]
                w.add_(H_inv_pr * (w[row_ids, p_ids] / H_inv_pd).unsqueeze(1), alpha=-1)
                # force exact zeros on everything pruned so far
                w[~mask] = 0
                # update pruning traces
                traces[col] = w
                # do not update H_inv on the last iteration
                if col == d_col:
                    break
                # update hessian inverse: subtract the scaled outer product of
                # the selected row, then restore a unit diagonal entry
                H_inv_pr.div_(torch.sqrt(H_inv_pd).unsqueeze(1))
                H_inv.baddbmm_(H_inv_pr.unsqueeze(2), H_inv_pr.unsqueeze(1), alpha=-1)
                H_inv[row_ids, p_ids, p_ids] = 1.0

            self.weight_traces[:, r1:r2, :] = traces.cpu()

        self.weight.data = self._extract_from_traces(sparsity).to(device)

    def _extract_from_traces(self, sparsity: float) -> Tensor:
        """Assemble the pruned weight from the stored traces.

        Keeps the ``int((1 - sparsity) * numel)`` entries with the largest
        ``self.losses``, counts how many entries are dropped in each row and
        reads, for every row, the trace stored at that count.

        Args:
            sparsity: Target sparsity used to compute the number of kept
                entries.

        Returns:
            The selected rows passed through ``_reshape_to_orig_shape``.
        """
        n_weights = self.losses.numel()
        _, topk_indices = torch.topk(
            self.losses.reshape(-1), k=int((1 - sparsity) * n_weights)
        )
        # Mask with 0 for pruned weights and 1 elsewhere.  It is kept on CPU
        # because it is used to index `weight_traces`, which lives on CPU.
        sparsity_mask = torch.zeros(n_weights, dtype=torch.bool)
        # in presence of nonzero weights
        if len(topk_indices) > 0:
            # `losses` may live on another device; bring the indices to CPU
            sparsity_mask[topk_indices.cpu()] = True
        # reshape mask to the weight shape
        sparsity_mask = sparsity_mask.reshape(self.losses.shape)
        # count number of zeros per row
        zeros_per_row = (~sparsity_mask).sum(dim=1)
        return self._reshape_to_orig_shape(
            self.weight_traces[zeros_per_row, torch.arange(self.d_row)],
        )
