from abc import abstractmethod
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.conv import _ConvNd
from torch import Tensor


class BaseOBC:
    """Base helper for Hessian-based pruning of a single layer.

    The object accumulates a running Hessian estimate ``H`` from the inputs
    fed to ``layer`` (see :meth:`update`) and exposes the pre/post processing
    shared by concrete pruners. Subclasses provide :meth:`pruning_step`.

    NOTE: ``pruning_step`` is decorated with ``@abstractmethod`` but the class
    does not derive from ``abc.ABC``, so instantiating ``BaseOBC`` directly is
    not prevented; the base ``pruning_step`` is then a no-op.
    """

    def __init__(self, layer: nn.Module, rel_damp: float = 1e-2, **kwargs) -> None:
        """Bind the helper to ``layer``.

        Args:
            layer: an ``nn.Linear`` or a convolution derived from ``_ConvNd``.
            rel_damp: damping added to the Hessian diagonal, relative to the
                mean of that diagonal (see :meth:`pruning_pre_step`).
            **kwargs: accepted for signature compatibility and ignored here.

        Raises:
            AssertionError: if ``layer`` is of an unsupported type.
        """
        self.layer = layer
        self.rel_damp = rel_damp
        self.d_row, self.d_col = self._get_number_of_rows_and_cols()
        # working weight; initially the layer's own parameter (not a copy)
        self.W = self.layer.weight
        # running Hessian estimate and the number of rows accumulated into it
        self.H = None
        self.num_samples = 0
        # original dtype, used when converting results back
        self.W_dtype = self.W.dtype
        self._validate_layer()

    def _get_number_of_rows_and_cols(self) -> Tuple[int, int]:
        """Return ``(out_dim, prod(remaining weight dims))`` as plain ints."""
        shape = self.layer.weight.shape
        # np.prod yields a numpy scalar; convert so the result matches the
        # annotated ``Tuple[int, int]``.
        return int(shape[0]), int(np.prod(shape[1:]))

    def _validate_layer(self) -> None:
        """Assert that the wrapped layer is linear or convolutional."""
        assert isinstance(
            self.layer, (nn.Linear, _ConvNd)
        ), "OBCUtil supports only linear and convolutional layers."

    @torch.no_grad()
    def update(self, batch: Tensor) -> None:
        """Accumulate a batch of layer inputs into the running Hessian.

        The estimate is the running mean of ``2 * x^T x`` over all rows seen
        so far. For linear layers every leading dimension is treated as a
        sample dimension; for convolutions the input is unfolded into patches
        and every patch is one sample.

        NOTE(review): the convolutional path relies on ``nn.Unfold`` (image-like
        inputs) and does not account for ``groups`` -- confirm against callers.

        Args:
            batch: inputs of the wrapped layer. An empty batch is ignored.
        """
        # An empty batch carries no information and would otherwise lead to
        # a 0 / 0 division when nothing has been accumulated yet.
        if batch.numel() == 0:
            return
        if isinstance(self.layer, nn.Linear):
            batch = batch.reshape(-1, batch.shape[-1])
        else:
            unfold = nn.Unfold(
                kernel_size=self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            # unfold output: (batch_size, channels * prod(kernel_size), num_patches)
            batch = unfold(batch)
            batch = batch.transpose(1, 2).flatten(0, 1)
        batch_size = batch.shape[0]
        # accumulate in float32 regardless of the activation dtype
        batch = batch.to(torch.float32)
        if self.H is None:
            dim = batch.shape[-1]
            self.H = torch.zeros((dim, dim), device=batch.device, dtype=torch.float32)
        # running mean: H <- n / (n + b) * H + 2 / (n + b) * X^T X
        total = self.num_samples + batch_size
        self.H.addmm_(batch.T, batch, beta=self.num_samples / total, alpha=2.0 / total)
        self.num_samples = total

    def reset(self) -> None:
        """Drop the Hessian and any working copy of the weight.

        ``self.W`` is pointed back at the layer's own parameter, which releases
        the float32 copy made by :meth:`pruning_pre_step` while keeping the
        object usable for another round of calibration and pruning.
        """
        self.W = self.layer.weight
        self.H = None
        self.num_samples = 0
        torch.cuda.empty_cache()

    @torch.no_grad()
    def pruning_pre_step(self) -> None:
        """
        Preparatory step with hessian regularization and weight reshaping.

        Inputs whose Hessian diagonal is exactly zero get a unit diagonal
        entry and their weight columns are zeroed. The diagonal is then damped
        by ``rel_damp * mean(diag(H))``. ``self.W`` becomes a float32 copy of
        the weight, flattened to 2-D for convolutions.

        Raises:
            AssertionError: if no Hessian has been collected.
        """
        # 1) Hessian preparation
        assert self.H is not None
        diag = self.H.diagonal()  # view: writes below modify H in place
        pruned_ids = diag == 0
        diag[pruned_ids] = 1
        # Hessian regularization (diagonal only, no dense identity needed)
        damp = self.rel_damp * diag.mean()
        diag.add_(damp)
        # 2) Weight preparation
        self.W = self.W.clone().to(torch.float32)
        if isinstance(self.layer, _ConvNd):
            self.W = self.W.flatten(1, -1)
        self.W[:, pruned_ids] = 0

    @torch.no_grad()
    def pruning_post_step(self) -> None:
        """Release the state held for pruning (delegates to :meth:`reset`)."""
        self.reset()

    @abstractmethod
    @torch.no_grad()
    def pruning_step(self, sparsity: float) -> None:
        """Perform the actual pruning to ``sparsity``; provided by subclasses."""
        pass

    @torch.no_grad()
    def prune(self, sparsity: float) -> None:
        """Run the pre-step, the pruning step and the post-step in order."""
        self.pruning_pre_step()
        self.pruning_step(sparsity)
        self.pruning_post_step()

    @torch.no_grad()
    def gradient_step(self, W_orig: Tensor, alpha: float = 1.0):
        """Move ``self.W`` in place along ``-(W - W_orig) @ H``.

        The step length is ``alpha * ||W - W_orig|| / max(||grad||, 1e-6)``.
        ``self.W`` keeps its current shape (flattened or not) and dtype.

        Args:
            W_orig: reference weight with the same number of elements as
                ``self.W``; any shape compatible with ``(d_row, d_col)``.
            alpha: relative step size.

        Raises:
            AssertionError: if no Hessian has been collected.
        """
        assert self.H is not None
        W = self.W
        # Follow the dtype of the weight actually being updated: after
        # ``pruning_pre_step`` it is float32, which may differ from ``W_dtype``.
        H = self.H.to(W.dtype)
        W_orig = W_orig.reshape(self.d_row, self.d_col).to(W.dtype)
        diff = W.reshape(self.d_row, self.d_col) - W_orig

        grad = torch.mm(diff, H)
        alpha_norm = alpha * diff.norm() / grad.norm().clamp(min=1e-6)

        # Apply the update in the weight's own shape so that un-flattened
        # convolutional weights are updated in place as well.
        W.add_(grad.reshape(W.shape), alpha=-float(alpha_norm))

    @torch.no_grad()
    def interpolate(self, W_orig: Tensor, alpha: float = 1.0):
        """Blend ``self.W`` in place: ``W <- (1 - alpha) * W + alpha * W_orig``.

        Args:
            W_orig: weight to interpolate towards. If it has the same number of
                elements as ``self.W`` but a different shape (for example an
                un-flattened convolution weight) it is reshaped to match.
            alpha: interpolation weight given to ``W_orig``.
        """
        if W_orig.shape != self.W.shape and W_orig.numel() == self.W.numel():
            W_orig = W_orig.reshape(self.W.shape)
        self.W.mul_(1 - alpha).add_(W_orig, alpha=alpha)

    def _reshape_to_orig_shape(self, weight: Tensor) -> Tensor:
        """Reshape ``weight`` to the layer's weight layout and original dtype."""
        if isinstance(self.layer, _ConvNd):
            return weight.reshape(self.d_row, -1, *self.layer.kernel_size).to(self.W_dtype)
        else:
            return weight.reshape(self.d_row, self.d_col).to(self.W_dtype)
