# Copyright (c) 2023 - present / Mediatek Research, All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Modifier classes implementing the Sparse FishLeg Surgeon (FLS) pruning framework.
The pruning algorithm follows the Optimal BERT Surgeon paper
https://arxiv.org/abs/2203.07259, which describes it in detail; only unstructured
(per-weight) pruning is currently implemented.
"""
import logging
import math
import sys
from typing import Any, Dict, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module, Parameter

from sparseml.pytorch.sparsification.modifier import ModifierProp, PyTorchModifierYAML
from sparseml.pytorch.sparsification.pruning.mask_creator import (
    PruningMaskCreator,
    get_mask_creator_default,
)
from sparseml.pytorch.sparsification.pruning.modifier_pruning_base import (
    BaseGradualPruningModifier,
)
from sparseml.pytorch.sparsification.pruning.scorer import PruningParamsGradScorer
from sparseml.pytorch.utils import GradSampler
from sparseml.pytorch.utils.logger import BaseLogger
from sparseml.utils import interpolate


__all__ = ["FLSPruningModifier", "FLSPruningParamsScorer"]

# Module-level logger for this file.
_LOGGER = logging.getLogger(__name__)


@PyTorchModifierYAML()
class FLSPruningModifier(BaseGradualPruningModifier):
    """
    Gradually applies sparsity to a given parameter or parameters from
    init_sparsity until final_sparsity is reached over a given number of epochs,
    using the Sparse FishLeg Surgeon (FLS) criterion.

    FLS follows the Optimal Brain Surgeon framework as used in the Optimal BERT
    Surgeon paper (https://arxiv.org/abs/2203.07259). Each weight is scored by an
    estimate of the loss increase its removal causes. When weights are pruned,
    the remaining weights are updated to compensate. The curvature terms of those
    formulas are not estimated from sampled gradients. They are taken from the
    layer that owns each pruned parameter, via that layer's ``diagQ()`` and
    ``Qv()`` methods (see FLSPruningParamsScorer). The pruned layers must
    therefore expose those methods, e.g. be FishLeg layers, when the modifier is
    initialized.

    ``check_mask_update`` returns without pruning when ``steps_per_epoch == 1``
    and ``epoch`` is finite.

    Supported mask types: unstructured. ('block4' is not implemented for FLS.)

    | Sample yaml:
    |   !FLSPruningModifier
    |       init_sparsity: 0.7
    |       final_sparsity: 0.9
    |       start_epoch: 2.0
    |       end_epoch: 26.0
    |       update_frequency: 4.0
    |       params: ["re:.*weight"]
    |       leave_enabled: True
    |       inter_func: cubic
    |       global_sparsity: True
    |       mask_type: unstructured
    |       damp: 1e-7
    |       num_recomputations: 1

    :param init_sparsity: the initial sparsity for the param to start with at
        start_epoch
    :param final_sparsity: the final sparsity for the param to end with at end_epoch.
        Can also be a Dict of final sparsity values to a list of parameters to apply
        them to. If given a Dict, then params must be set to [] and the params to
        be pruned will be read from the final_sparsity Dict
    :param start_epoch: The epoch to start the modifier at
    :param end_epoch: The epoch to end the modifier at
    :param update_frequency: The number of epochs or fraction of epochs to update at
        between start and end
    :param params: A list of full parameter names or regex patterns of names to apply
        pruning to.  Regex patterns must be specified with the prefix 're:'. __ALL__
        will match to all parameters. __ALL_PRUNABLE__ will match to all ConvNd
        and Linear layers' weights. If a sparsity to param mapping is defined by
        final_sparsity, then params should be set to []. The scorer pairs each
        matched parameter with its layer one-to-one, so match at most one
        parameter per layer (e.g. only weights)
    :param leave_enabled: True to continue masking the weights after end_epoch,
        False to stop masking. Should be set to False if exporting the result
        immediately after or doing some other prune
    :param inter_func: the type of interpolation function to use:
        [linear, cubic, inverse_cubic]
    :param global_sparsity: set True to enable global pruning. If False, pruning will
        be layer-wise. Default is True
    :param mask_type: String to define type of sparsity to apply. Only
        'unstructured' is supported. Default is 'unstructured'
    :param damp: dampening factor, default is 1e-7. Strings (as produced by
        'damp: 1e-7' in a recipe) are converted to float
    :param num_recomputations: number of sub-steps a single pruning step is split
        into; the target sparsity is interpolated between the last applied and the
        new target sparsity and the scores are recomputed at each sub-step
    """

    def __init__(
        self,
        init_sparsity: float,
        final_sparsity: float,
        start_epoch: float,
        end_epoch: float,
        update_frequency: float,
        params: Union[str, List[str]],  # parameter names or 're:' patterns
        leave_enabled: bool = True,
        inter_func: str = "cubic",
        global_sparsity: bool = True,
        mask_type: str = "unstructured",
        damp: float = 1e-7,
        num_recomputations: int = 1,
    ):
        super().__init__(
            params=params,
            init_sparsity=init_sparsity,
            final_sparsity=final_sparsity,
            inter_func=inter_func,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            update_frequency=update_frequency,
            global_sparsity=global_sparsity,
            leave_enabled=leave_enabled,
            parent_class_kwarg_names=[],
        )
        self._mask_type = mask_type
        self._damp = damp
        self._num_recomputations = num_recomputations
        self._last_applied_sparsity = 0.0  # keep track for recomputations

        # 'block4' is intentionally absent: the scorer has no block4 implementation
        # and would otherwise fail in the middle of a pruning step.
        self._supported_masks = ("unstructured",)

        self._validate()

    @ModifierProp()
    def mask_type(self) -> str:
        """
        :return: the mask type used
        """
        return self._mask_type

    @ModifierProp()
    def damp(self) -> float:
        """
        :return: dampening factor forwarded to the scorer.
            NOTE(review): FLSPruningParamsScorer stores this value but its
            scoring and update formulas do not currently use it.
        """
        return self._damp

    @ModifierProp()
    def num_recomputations(self) -> int:
        """
        :return: number of sub-steps (score recomputations) performed while doing
            one pruning step
        """
        return self._num_recomputations

    def initialize(
        self,
        module: Module,
        epoch: float = 0,
        loggers: Optional[List[BaseLogger]] = None,
        **kwargs,
    ):
        """
        Grab the layers owning the pruned parameters, then run the base
        initialization (which builds the mask creator, masks and scorer).

        :param module: the PyTorch model/module to modify
        :param epoch: the epoch to initialize the modifier and module at.
            Defaults to 0 (start of the training process)
        :param loggers: optional list of loggers to log the modification process to
        :param kwargs: optional kwargs to support specific arguments
            for individual modifiers.
        """
        _LOGGER.info("Initializing FLSPruningModifier")
        named_layers_and_params = self._create_named_layers_and_params(module)
        # Unique layers in first-occurrence order; this must be set before
        # super().initialize() because _get_scorer() reads it.
        self.fishleg_layers = list(
            dict.fromkeys(nlp.layer for nlp in named_layers_and_params)
        )
        super().initialize(module, epoch, loggers, **kwargs)

    def check_mask_update(
        self, module: Module, epoch: float, steps_per_epoch: int, **kwargs
    ):
        """
        Perform one FLS pruning step, split into ``num_recomputations`` sub-steps
        whose sparsity targets are interpolated from the last applied sparsity to
        the sparsity for ``epoch``.

        Every rank runs the mask update. The scorer broadcasts its scores and
        weight updates from the main process with ``_broadcast_list_from_main``,
        which under DDP is presumably a collective that all ranks must enter.

        :param module: module being pruned
        :param epoch: current epoch
        :param steps_per_epoch: number of optimizer steps per epoch
        :param kwargs: unused; accepted for interface compatibility
        """
        if steps_per_epoch == 1 and not math.isinf(epoch):
            return  # not a one-shot run

        torch.cuda.empty_cache()
        is_main_proc = self._scorer._is_main_proc
        if is_main_proc:
            _LOGGER.info("Running FLS Pruning")
            self._scorer._enabled_grad_buffering = True

        if not self._pre_step_completed:
            # do pre optim step before mask update on update steps
            self._module_masks.pre_optim_step_update()
            self._pre_step_completed = True

        to_apply_sparsities = self.get_applied_sparsity_for_epoch(
            epoch, steps_per_epoch
        )
        last_applied_sparsities = (
            self._last_applied_sparsity
            if isinstance(self._last_applied_sparsity, list)
            else [self._last_applied_sparsity] * len(to_apply_sparsities)
        )

        for i in range(1, self._num_recomputations + 1):
            recomputation_sparsity = [
                interpolate(
                    i,
                    0,
                    self._num_recomputations,
                    start_sparsity,
                    target_sparsity,
                )
                for start_sparsity, target_sparsity in zip(
                    last_applied_sparsities, to_apply_sparsities
                )
            ]
            # overwrite sparsity targets for this sub-step
            super().check_mask_update(
                module,
                epoch,
                steps_per_epoch,
                recomputation_sparsity=recomputation_sparsity,
            )

        torch.cuda.empty_cache()
        if is_main_proc:
            self._scorer._enabled_grad_buffering = False
        self._last_applied_sparsity = to_apply_sparsities

    def _get_mask_creator(
        self, param_names: List[str], params: List[Parameter]
    ) -> PruningMaskCreator:
        """
        :param param_names: full names of parameters to be pruned
        :param params: list of Parameters to be masked
        :return: mask creator object to be used by this pruning algorithm
        """
        return get_mask_creator_default(self.mask_type)

    def _get_scorer(self, params: List[Parameter]) -> PruningParamsGradScorer:
        """
        :param params: list of Parameters for scorer to track
        :return: param scorer object to be used by this pruning algorithm
        """
        return FLSPruningParamsScorer(
            params=params,
            damp=self._damp,
            mask_type=self._mask_type,
            fishleg_layers=self.fishleg_layers,
        )

    def _validate(self):
        """
        Check the torch version and constructor arguments.

        :raises RuntimeError: if ``torch.linalg.solve`` is unavailable (torch<1.9)
        :raises ValueError: if ``mask_type`` is not supported
        """
        if not hasattr(torch.linalg, "solve"):
            raise RuntimeError(
                f"torch>=1.9 required to use {self.__class__.__name__} "
                f"found {torch.__version__}"
            )

        if isinstance(self._damp, str):  # to support 'damp: 1e-7' in the recipe
            self._damp = float(self._damp)

        if self._mask_type not in self._supported_masks:
            raise ValueError(f"{self._mask_type} mask_type not supported")


class FLSPruningParamsScorer(PruningParamsGradScorer):
    """
    Scores parameters with Optimal-Brain-Surgeon-style formulas (as in the Optimal
    BERT Surgeon, https://arxiv.org/abs/2203.07259). The curvature terms come from
    the layer paired with each parameter, not from an explicitly stored inverse
    Fisher matrix. Only unstructured scoring and pruning is implemented.

    For a parameter ``w`` paired with ``layer``:

    * ``d`` is the first element of ``layer.diagQ()``, rearranged into ``w``'s
      element order
    * score: ``w**2 / (2 * d + eps)``
    * on pruning: ``z = w_pruned / (d + eps)``, and the first element of
      ``layer.Qv((z, zeros))`` is subtracted from ``w`` before the pruned entries
      are zeroed

    NOTE(review): these match the OBS formulas if ``Q`` approximates the inverse
    Fisher, which is assumed from the FishLeg naming. Confirm against the FishLeg
    layer implementation.

    :param params: list of model Parameters to track and score; ``params[i]`` is
        paired with ``fishleg_layers[i]`` (presumably as that layer's weight)
    :param damp: dampening factor. NOTE(review): stored but not used by the
        current scoring/update formulas
    :param mask_type: sparsity structure; only 'unstructured' is supported
    :param fishleg_layers: layers exposing ``diagQ()`` and ``Qv(v=...)``, paired
        one-to-one and in order with ``params``
    :raises ValueError: if ``mask_type`` is unsupported or the number of layers
        differs from the number of params
    """

    def __init__(
        self,
        params: List[Parameter],
        damp: float,
        mask_type: str,
        fishleg_layers,
    ):
        super().__init__(params)
        self._damp = damp
        self._mask_type = mask_type
        self._enabled_grad_buffering = False
        self._eps = torch.finfo(torch.float32).eps
        self._fishleg_layers = fishleg_layers

        # device on which each parameter's scores are computed
        self._devices = self._assign_devices(len(self._params))

        self._pickle_exclude_params.extend(
            [
                "_enabled_grad_buffering",
                "_devices",
            ]
        )
        self._validate()

    @staticmethod
    def _assign_devices(num_params: int) -> List[torch.device]:
        """
        Spread ``num_params`` slots evenly over the visible CUDA devices (CPU when
        there are none); leftover slots go to the last device used.
        """
        num_devices = min(torch.cuda.device_count(), num_params)
        if num_devices == 0:
            # no GPUs, or nothing to place (avoids dividing by zero below)
            return [torch.device("cpu")] * num_params

        per_device = num_params // num_devices
        devices = []
        for i in range(num_devices):
            devices += [torch.device("cuda", i)] * per_device
        devices += [devices[-1]] * (num_params - len(devices))
        return devices

    def _param_ordered_diag(self, index: int, layer) -> Tensor:
        """
        Return the weight part of ``layer.diagQ()`` as a flat vector in the
        element order of ``self._params[index]``.
        """
        diag_w, _ = layer.diagQ()
        shape = self._params[index].shape
        # diagQ's vector is read with the reversed parameter shape and its dims are
        # then reversed back, i.e. it is treated as column-major w.r.t. the param.
        # permute() is used instead of .T, which is deprecated for ndim != 2.
        reversed_dims = tuple(reversed(range(len(shape))))
        return diag_w.reshape(shape[::-1]).permute(reversed_dims).reshape(-1)

    @torch.no_grad()
    def score_parameters(self) -> List[Tensor]:
        """
        :return: List of Tensors the same shapes as the given Parameters where
            each Parameter's elements are scored with ``w**2 / (2 * d + eps)``;
            already-pruned elements get ``-inf`` so they stay pruned
        """
        scores = [None] * len(self._params)

        if self._is_main_proc:
            for idx, layer in enumerate(self._fishleg_layers):
                param = self._params[idx]
                device = self._devices[idx]
                diag = self._param_ordered_diag(idx, layer).to(device)
                squared = (param.data.reshape(-1) ** 2).to(device)
                scores[idx] = (squared / (2.0 * diag + self._eps)).view(param.shape)

            # make sure pruned ones will stay pruned
            for idx, score in enumerate(scores):
                pruned = (self._masks[idx] == 0).to(score.device)
                score[pruned] = float("-inf")

        self._broadcast_list_from_main(scores)
        return scores

    @torch.no_grad()
    def pre_optim_step_update(self, masks: List[Tensor]):
        """
        Store the latest masks so that score_parameters can keep already-pruned
        elements pruned.

        :param masks: latest masks that are applied to these parameters
        """
        self._masks = masks  # to be used by score_parameters

    @torch.no_grad()
    def mask_update(self, masks: List[Tensor], mask_diffs: List[Tensor]):
        """
        Apply the fls weight update, which zeros out pruned weights and updates
        the remaining weights to preserve the loss.

        :param masks: latest masks to be applied to these parameters
        :param mask_diffs: mask diff values returned by mask_difference for these
            masks that describe how these masks changed since the last update;
            entries equal to -1 mark newly pruned elements
        """
        fls_updates = [None] * len(self._params)
        if self._is_main_proc:
            for idx, layer in enumerate(self._fishleg_layers):
                param = self._params[idx]
                diff = mask_diffs[idx]
                newly_pruned = diff == -1
                diag = self._param_ordered_diag(idx, layer).to(diff.device)

                z = (newly_pruned * param.data).reshape(-1) / (diag + self._eps)
                # Qv takes a (weight, bias) pair. Only the weight part of the
                # result is used, so the bias input is all zeros.
                bias_zeros = torch.zeros(diff.shape[0], device=diff.device)
                fls_updates[idx] = layer.Qv(v=(z.view(param.data.shape), bias_zeros))[0]

                if idx == 0:
                    self._log_update_stats(diag, newly_pruned, fls_updates[idx])

        self._broadcast_list_from_main(fls_updates)
        # apply fls update and manually zero-out pruned weights
        for param, update, diff in zip(self._params, fls_updates, mask_diffs):
            param.data -= update.to(param.data.device)
            param.data[diff == -1] = 0.0

        # NOTE(review): `_finvs` is not read anywhere in this class
        self._finvs = None

    @staticmethod
    def _log_update_stats(diag: Tensor, newly_pruned: Tensor, update: Tensor):
        """Debug-log statistics about one layer's diagonal and weight update."""
        if not _LOGGER.isEnabledFor(logging.DEBUG):
            return  # skip the device-to-host reductions when nothing is logged
        abs_update = update.abs()
        _LOGGER.debug("FLS diagonal: %s", diag)
        _LOGGER.debug("number of newly pruned weights: %d", int(newly_pruned.sum()))
        _LOGGER.debug(
            "diagonal smallest:%.5f, largest:%.5f",
            diag.min().item(),
            diag.max().item(),
        )
        _LOGGER.debug(
            "update smallest:%.5f, largest:%.5f",
            abs_update.min().item(),
            abs_update.max().item(),
        )

    def _validate(self):
        """
        :raises ValueError: if the mask type is not implemented, or if layers and
            params cannot be paired one-to-one
        """
        if self._mask_type != "unstructured":
            raise ValueError(f"{self._mask_type} mask_type not supported")
        if len(self._fishleg_layers) != len(self._params):
            raise ValueError(
                "expected one FishLeg layer per pruned parameter, got "
                f"{len(self._fishleg_layers)} layers for {len(self._params)} params"
            )
