# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
import numpy as np

from typing import List, Dict, Union, Optional, Any

from torch import Tensor
from torch.nn import Module, Parameter
from torch.optim.optimizer import Optimizer


from sparseml.pytorch.sparsification.modifier import \
    ModifierProp, PyTorchModifierYAML
from sparseml.pytorch.sparsification.pruning.mask_creator import \
    PruningMaskCreator, \
    get_mask_creator_default
from sparseml.pytorch.sparsification.pruning.modifier_pruning_magnitude import (
    GMPruningModifier,
)

from sparseml.pytorch.sparsification.pruning.modifier_pruning_acdc import \
    ACDC_BasePruningModifier


from sparseml.pytorch.sparsification.pruning.scorer import PruningParamsGradScorer

# Only the modifier is exported from this module; the scorer stays module-private.
__all__ = ["ACDC_GradSqrPruningModifier"]

class GradSqrPruningParamsScorer(PruningParamsGradScorer):
    """
    Scores parameters by an exponential moving average (EMA) of their squared
    gradients. Before every optimizer step each tracked element is updated as
    ``score = gamma * score + (1 - gamma) * grad ** 2``. When scores are
    requested, elements whose current mask value is 0 are reported as ``-inf``.

    :param params: list of model Parameters to track and score
    :param moving_average_gamma: decay factor ``gamma`` of the moving average;
        larger values weight the score history more heavily. Default is 0.999
    """

    def __init__(self, params: List[Parameter], moving_average_gamma: float = 0.999):
        super().__init__(params)
        self._moving_average_gamma = moving_average_gamma
        # one EMA buffer per tracked parameter (zeros_like: same shape/dtype/device)
        self._grad_variances = [torch.zeros_like(param) for param in self._params]
        # latest masks passed to pre_optim_step_update; read by score_parameters
        self._masks = None

    @torch.no_grad()
    def pre_optim_step_update(self, masks: List[Tensor]):
        """
        Fold the current gradients into the squared-gradient moving averages.
        Only the main process accumulates. score_parameters shares the main
        process scores with the other processes via _broadcast_list_from_main.

        :param masks: latest masks that are applied to these parameters
        """
        if not self._is_main_proc:
            return

        self._masks = masks  # consumed by score_parameters
        gamma = self._moving_average_gamma

        for grad_sqr_avg, param in zip(self._grad_variances, self._params):
            grad = param.grad
            # skip missing gradients and non-finite ones (NaN or inf, e.g. from
            # an overflowing step) so a single bad step cannot poison the EMA
            if grad is None or not torch.isfinite(grad).all():
                continue
            grad_sqr_avg.mul_(gamma).add_(grad ** 2, alpha=1.0 - gamma)

    @torch.no_grad()
    def score_parameters(self) -> List[Tensor]:
        """
        :return: List of Tensors the same shapes as the given Parameters holding
            the squared-gradient moving average of each element, with elements
            that are currently masked (mask == 0) set to -inf. The returned
            tensors are copies; the running averages themselves are not modified
        """
        # Score on copies. Writing -inf into the running averages would persist
        # through later EMA updates (-inf * gamma stays -inf). Resetting those
        # entries by multiplication would also produce NaN (-inf * 0 == NaN).
        scores = [grad_sqr_avg.clone() for grad_sqr_avg in self._grad_variances]

        # masks are only known after the first pre_optim_step_update call
        if self._is_main_proc and self._masks is not None:
            for score, mask in zip(scores, self._masks):
                score[mask == 0] = float("-inf")

        # fills `scores` in place with the main process values
        self._broadcast_list_from_main(scores)
        return scores

    @torch.no_grad()
    def mask_update(self, masks: List[Tensor], mask_diffs: List[Tensor]):
        """
        Reset every squared-gradient moving average to zero after a mask update,
        so scores for the next update are accumulated from scratch

        :param masks: latest masks to be applied to these parameters (unused)
        :param mask_diffs: mask diff values returned by mask_difference for these
            masks that describe how these masks changed since the last update
            (unused)
        """
        for grad_sqr_avg in self._grad_variances:
            # zero_() rather than mul_(0.0): multiplying inf/-inf/NaN by 0 is NaN
            grad_sqr_avg.zero_()


@PyTorchModifierYAML()
class ACDC_GradSqrPruningModifier(ACDC_BasePruningModifier, GMPruningModifier):
    """
    AC/DC pruning modifier built on ACDC_BasePruningModifier. It ranks weights
    with GradSqrPruningParamsScorer, an exponential moving average of each
    weight's squared gradient, instead of by weight magnitude.

    :param compression_sparsity: target sparsity for the compressed phases;
        forwarded to ACDC_BasePruningModifier and passed to GMPruningModifier as
        its final_sparsity
    :param moving_average_gamma: decay factor of the squared-gradient moving
        average used for scoring; must satisfy 0 <= gamma < 1. Default is 0.999

    All remaining arguments are forwarded unchanged to ACDC_BasePruningModifier
    (and, where accepted, to GMPruningModifier); see those classes for details.
    """

    def __init__(
        self,
        compression_sparsity: float,
        start_epoch: Union[int, float],
        end_epoch: Union[int, float],
        update_frequency: Union[int, float],
        params: Union[str, List[str]],
        global_sparsity: bool = True,
        leave_enabled: bool = True,
        momentum_buffer_reset: bool = True,
        mask_type: str = "unstructured",
        last_decompression_epochs: Optional[float] = None,
        last_compression_epochs: Optional[float] = None,
        dense_fraction: float = 0.5,
        moving_average_gamma: float = 0.999,
    ):
        super().__init__(
            compression_sparsity=compression_sparsity,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            update_frequency=update_frequency,
            params=params,
            global_sparsity=global_sparsity,
            leave_enabled=leave_enabled,
            momentum_buffer_reset=momentum_buffer_reset,
            mask_type=mask_type,
            last_decompression_epochs=last_decompression_epochs,
            last_compression_epochs=last_compression_epochs,
            dense_fraction=dense_fraction,
        )
        # NOTE(review): GMPruningModifier.__init__ is re-run explicitly after the
        # cooperative super().__init__ call. Presumably this initializes the
        # magnitude-pruning state with compression_sparsity as final_sparsity.
        # Kept as-is; confirm against ACDC_BasePruningModifier before changing.
        GMPruningModifier.__init__(
            self,
            init_sparsity=0.0,
            final_sparsity=compression_sparsity,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            update_frequency=update_frequency,
            global_sparsity=global_sparsity,
            params=params,
            leave_enabled=leave_enabled,
            mask_type=mask_type,
        )
        # routed through the ModifierProp setter below (validates the range)
        self.moving_average_gamma = moving_average_gamma

    @ModifierProp()
    def moving_average_gamma(self) -> float:
        """
        :return: decay factor of the squared-gradient moving average used to
            score parameters
        """
        return self._moving_average_gamma

    @moving_average_gamma.setter
    def moving_average_gamma(self, value: float):
        """
        :param value: decay factor of the squared-gradient moving average;
            must satisfy 0 <= value < 1
        """
        # gamma >= 1 never incorporates gradients; gamma < 0 is not an average
        if not 0.0 <= value < 1.0:
            raise ValueError(
                f"moving_average_gamma must be in [0, 1), given {value}"
            )
        self._moving_average_gamma = value

    def _get_mask_creator(
        self, param_names: List[str], params: List[Parameter]
    ) -> PruningMaskCreator:
        """
        :param param_names: full names of parameters to be pruned
        :param params: list of parameters to be masked
        :return: mask creator object to be used by this pruning algorithm
        """
        return get_mask_creator_default(self.mask_type)

    def _get_scorer(self, params: List[Parameter]) -> PruningParamsGradScorer:
        """
        :param params: list of Parameters for scorer to track
        :return: squared-gradient moving-average scorer used by this pruning
            algorithm
        """
        return GradSqrPruningParamsScorer(
            params=params, moving_average_gamma=self.moving_average_gamma
        )

    @ModifierProp(serializable=False)
    def global_sparsity(self) -> bool:
        """
        :return: True for global magnitude pruning, False for
            layer-wise. [DEPRECATED] - use GlobalMagnitudePruningModifier
            for global magnitude pruning and MagnitudePruningModifier for layer-wise
        """
        return self._global_sparsity