#!/usr/bin/env python
import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_tensor_magnitude_mask(model, sparsity_fraction):
    """Build per-parameter keep-masks from one global magnitude threshold.

    Every parameter of ``model`` that has ``requires_grad`` set and whose
    name does not contain ``"bias"`` is considered. The absolute values of
    all such parameters are pooled, the k-th smallest magnitude is taken as
    the threshold (``k = int(total * sparsity_fraction)``), and each mask
    keeps only the entries whose magnitude is strictly greater than it.

    Because the comparison is strict, entries tied with the threshold are
    pruned as well, so the achieved sparsity can exceed the requested one.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``).
        sparsity_fraction: Fraction of the pooled weights to prune; must lie
            in ``[0.0, 1.0]``.

    Returns:
        Dict mapping parameter name to a 0/1 tensor with the same shape and
        dtype as that parameter. All-ones masks are returned when ``k`` is
        0; an empty dict is returned when no parameter qualifies.

    Raises:
        ValueError: If ``sparsity_fraction`` is outside ``[0.0, 1.0]``.
    """
    # Reject bad fractions up front: otherwise k falls outside the valid
    # range of torch.kthvalue and the failure is hard to interpret.
    if not 0.0 <= sparsity_fraction <= 1.0:
        raise ValueError(
            f"sparsity_fraction must be in [0.0, 1.0], got {sparsity_fraction}."
        )

    print(f"Computing mask for {sparsity_fraction*100:.2f}% sparsity...")

    # Collect the prunable tensors once so the selection rule lives in a
    # single place.
    prunable = {
        name: param.data
        for name, param in model.named_parameters()
        if param.requires_grad and "bias" not in name
    }
    if not prunable:
        # torch.cat refuses an empty list, so handle this case explicitly.
        print("No prunable parameters found; returning an empty mask.")
        return {}

    # reshape(-1) also works for non-contiguous tensors, unlike view(-1).
    magnitudes = torch.cat(
        [weights.abs().reshape(-1) for weights in prunable.values()]
    )

    # get index of threshold
    k = int(magnitudes.numel() * sparsity_fraction)
    if k == 0:
        print("Sparsity fraction too low; no weights will be pruned.")
        return {
            name: torch.ones_like(weights) for name, weights in prunable.items()
        }

    t = torch.kthvalue(magnitudes, k).values.item()
    print(f"Global magnitude threshold: {t:.6f}")

    return {
        name: (weights.abs() > t).to(weights.dtype)
        for name, weights in prunable.items()
    }

def apply_mask(model, mask_dict):
    """Multiply matching parameters of ``model`` by their masks, in place.

    For every parameter whose name is a key of ``mask_dict`` the parameter
    is multiplied element-wise by the mask under ``torch.no_grad()``. A
    per-layer line and an overall sparsity line are printed. Parameters
    without a mask entry are left untouched.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``); it is modified in place.
        mask_dict: Mapping from parameter name to a mask tensor. Entries
            equal to zero are reported as pruned.

    Returns:
        The same ``model`` object that was passed in.
    """
    total_pruned = 0
    total_weights = 0
    with torch.no_grad():
        for name, param in model.named_parameters():
            mask = mask_dict.get(name)
            if mask is None:
                continue

            # count number of parameters before
            num_param_before = param.numel()
            total_weights += num_param_before

            # Count kept entries as an integer; a floating-point sum loses
            # exactness on large tensors and can overflow in half precision.
            num_param_after = int(mask.ne(0).sum().item())
            pruned_count = num_param_before - num_param_after
            total_pruned += pruned_count

            # The mask may come from a different model, so align its dtype
            # and device with the parameter before the in-place multiply.
            param.mul_(mask.to(device=param.device, dtype=param.dtype))
            print(f"Layer {name}: Pruned {int(pruned_count):d}/{int(num_param_before):d} weights.")

    if total_weights:
        print(f"Overall sparsity achieved: {100.0 * total_pruned / total_weights:.2f}%.")
    else:
        # Avoid dividing by zero when no parameter name matched the mask.
        print("No parameters matched the mask; nothing was pruned.")
    return model


def main():
    """Command-line entry point: mask one model using another's magnitudes.

    Loads the model at ``--w`` and the model at ``--w_hat``, computes a
    global magnitude mask from the former, applies it to the latter, and
    optionally saves the masked model (and the tokenizer loaded from
    ``--w``) into ``--output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--w",
        type=str,
        required=True,
        help="Path for the original trained model (w)."
    )
    parser.add_argument(
        "--w_hat",
        type=str,
        required=True,
        help="Path for the quantized model (w_hat)."
    )
    parser.add_argument(
        "--sparsity_fraction",
        type=float,
        default=0.5,
        help="Fraction of weights to be sparisified (0.0 < fraction <= 1.0)."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Directory to save the pruned model. If not provided, model isn't saved."
    )
    parser.add_argument(
        "--save_tokenizer",
        action='store_true',
        help="Save the tokenizer if available."
    )
    args = parser.parse_args()

    # Fail before the (expensive) model loads rather than deep inside the
    # threshold computation.
    if not 0.0 <= args.sparsity_fraction <= 1.0:
        parser.error(
            f"--sparsity_fraction must be in [0.0, 1.0], got {args.sparsity_fraction}."
        )

    # Without an output directory there is nowhere to put the tokenizer;
    # say so up front instead of failing inside save_pretrained(None).
    if args.save_tokenizer and not args.output_dir:
        print("--save_tokenizer was given without --output_dir; tokenizer will not be saved.")

    print(f"Loading reference model from {args.w}...")
    w = AutoModelForCausalLM.from_pretrained(args.w)

    print(f"Loading target model from {args.w_hat}...")
    w_hat = AutoModelForCausalLM.from_pretrained(args.w_hat)

    mask_dict = get_tensor_magnitude_mask(w, args.sparsity_fraction)
    pruned_w_hat = apply_mask(w_hat, mask_dict)

    if not args.output_dir:
        return

    os.makedirs(args.output_dir, exist_ok=True)
    pruned_w_hat.save_pretrained(args.output_dir)
    print(f"Pruned model saved to {args.output_dir}")

    if args.save_tokenizer:
        # Best effort: a missing or broken tokenizer must not discard the
        # model that has already been saved.
        try:
            tokenizer = AutoTokenizer.from_pretrained(args.w)
            tokenizer.save_pretrained(args.output_dir)
            print(f"Tokenizer saved to: {args.output_dir}")
        except Exception as e:
            print(f"Error saving tokenizer to '{args.output_dir}': {e}")

# Run the command-line interface only when this file is executed as a
# script; importing it as a module has no side effects.
if __name__ == "__main__":
    main()
