# Copyright 2020-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     XXXX
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Count the remaining (non-zero) weights in the encoder (i.e. the transformer layers) of a fine-pruned model.
Sparsity and remaining-weight levels are two views of the same quantity: sparsity % = 100 - remaining weights %.
"""
import argparse
import os

import torch

from emmental.modules import ThresholdBinarizer, TopKBinarizer


def count_mask_ones(mask_scores, pruning_method, threshold):
    """Return how many entries of the mask derived from ``mask_scores`` are kept.

    Args:
        mask_scores: a ``*mask_scores`` tensor loaded from the checkpoint.
        pruning_method: one of ``"topK"``, ``"sigmoied_threshold"`` or ``"l0"``.
        threshold: forwarded to the emmental binarizer for ``"topK"`` and
            ``"sigmoied_threshold"``; unused for ``"l0"``.

    Returns:
        The number of kept (non-zero) mask entries, as an ``int``.

    Raises:
        ValueError: if ``pruning_method`` is not recognized.
    """
    if pruning_method == "topK":
        mask = TopKBinarizer.apply(mask_scores, threshold)
    elif pruning_method == "sigmoied_threshold":
        mask = ThresholdBinarizer.apply(mask_scores, threshold, True)
    elif pruning_method == "l0":
        # Stretch the sigmoid of the scores to (l, r) = (-0.1, 1.1) and clamp back
        # to [0, 1]; an entry counts as remaining when its clamped gate is > 0.
        l, r = -0.1, 1.1
        s_bar = torch.sigmoid(mask_scores) * (r - l) + l
        mask = s_bar.clamp(min=0.0, max=1.0) > 0.0
    else:
        raise ValueError("Unknown pruning method")
    return int(mask.sum().item())


def main(args):
    """Print per-matrix and global remaining-weight percentages for the encoder.

    Loads ``pytorch_model.bin`` from ``args.serialization_dir`` and inspects every
    parameter whose name contains ``"encoder"``:

    * ``*mask_scores`` tensors are binarized with ``args.pruning_method`` and
      ``args.threshold``; the number of kept entries is printed and added to
      the remaining count.
    * every other encoder tensor adds to the total parameter count; bias and
      ``LayerNorm`` tensors are additionally counted as fully remaining.

    Args:
        args: namespace with ``serialization_dir``, ``pruning_method`` and
            ``threshold`` attributes (as produced by this script's parser).

    Returns:
        The global percentage of remaining encoder weights (also printed).

    Raises:
        ValueError: if ``threshold`` is missing for a method that needs it, if the
            pruning method is unknown, or if the checkpoint has no encoder parameters.
    """
    serialization_dir = args.serialization_dir
    pruning_method = args.pruning_method
    threshold = args.threshold

    # Fail fast with a clear message instead of a TypeError inside the binarizer.
    if pruning_method in ("topK", "sigmoied_threshold") and threshold is None:
        raise ValueError(f"--threshold is required when --pruning_method is {pruning_method!r}")

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
    # only point this script at checkpoints from a trusted source.
    st = torch.load(os.path.join(serialization_dir, "pytorch_model.bin"), map_location="cpu")

    remaining_count = 0  # Number of remaining (not pruned) params in the encoder
    encoder_count = 0  # Number of params in the encoder

    print("name".ljust(60, " "), "Remaining Weights %", "Remaining Weight")
    for name, param in st.items():
        if "encoder" not in name:
            continue

        if "mask_scores" in name:
            mask_ones = count_mask_ones(param, pruning_method, threshold)
            remaining_count += mask_ones
            print(name.ljust(60, " "), str(round(100 * mask_ones / param.numel(), 3)).ljust(20, " "), str(mask_ones))
        else:
            encoder_count += param.numel()
            # Bias and LayerNorm tensors count as fully remaining; any other
            # non-mask tensor (e.g. a weight matrix) only counts as remaining
            # through its mask_scores entry above.
            if "bias" in name or "LayerNorm" in name:
                remaining_count += param.numel()

    if encoder_count == 0:
        raise ValueError(f"No encoder parameters found in {serialization_dir!r}")

    global_remaining = 100 * remaining_count / encoder_count
    print("")
    print("Remaining Weights (global) %: ", global_remaining)
    return global_remaining


def build_parser():
    """Build the command-line parser for this script.

    argparse applies %-formatting to help strings, so a literal percent sign
    must be written as ``%%`` (a bare ``%`` makes ``--help`` raise).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--pruning_method",
        choices=["l0", "topK", "sigmoied_threshold"],
        type=str,
        required=True,
        help=(
            "Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement"
            " pruning)"
        ),
    )
    parser.add_argument(
        "--threshold",
        type=float,
        required=False,
        help=(
            "For `topK`, it is the fraction of remaining weights (between 0 and 1, e.g. 0.10 keeps 10%% of the"
            " weights) in the fine-pruned model. "
            "For `sigmoied_threshold`, it is the threshold \\tau against which the (sigmoied) scores are compared. "
            "Not needed for `l0`"
        ),
    )
    parser.add_argument(
        "--serialization_dir",
        type=str,
        required=True,
        help="Folder containing the model that was previously fine-pruned",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    main(args)
