import torch


import argparse
from Pruning_Config.loader_and_pruning import get_global_FFN_mask_by_distribution, get_global_FFN_mask_by_norm

def get_args_parser(argv=None):
    """Build the command-line parser for this script and return the parsed arguments.

    Despite its name, this returns the parsed ``argparse.Namespace``, not the
    parser itself.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), ``argparse`` reads ``sys.argv[1:]``, exactly as before.
            Passing an explicit list allows programmatic use and testing.

    Returns:
        argparse.Namespace with the parsed options. Of these, this file only
        reads ``dataset_name`` and ``rate`` (used to build mask file paths).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", default="parameter-efficient fine-tuning")
    parser.add_argument("--dataset_name", default="cifar")
    parser.add_argument("--model_type", default="ViT-B_16")
    parser.add_argument("--dataset_dir", default="/home/datasets/vtab-1k/vtab-1k/")
    parser.add_argument("--pretrained_dir", type=str, default="ViT-B_16.npz")  # imagenet21k_
    parser.add_argument("--pruning_type", default="distribution", choices=["distribution", "norm", "importance"])
    parser.add_argument("--rate", default=0.2, type=float)
    args = parser.parse_args(argv)

    return args

def analyze_ffn_pruning_config_(type, args):
    """Load a saved global FFN pruning mask and print how many entries it zeroes out.

    The mask is read from
    ``Pruning_Config/{type}_{args.dataset_name}_rate{args.rate}_global.pt``,
    relative to the current working directory.

    Args:
        type: Mask-generation method embedded in the file name, e.g. ``"norm"``
            or ``"distribution"``. (Shadows the builtin ``type``; the name is
            kept so keyword callers keep working.)
        args: Object exposing ``dataset_name`` and ``rate`` attributes.

    Returns:
        The loaded mask tensor, on CPU. It must have at least two dimensions,
        because zeros are counted along ``dim=1``.
    """
    save_path = f"Pruning_Config/{type}_{args.dataset_name}_rate{str(args.rate)}_global.pt"
    # map_location="cpu" lets a mask saved from a GPU load on a CPU-only machine.
    # Loading straight to CPU also avoids a device round-trip.
    mask = torch.load(save_path, map_location="cpu")
    print("mask size total number : ", mask.numel())

    # Number of zero (pruned) entries in each row. Rows presumably correspond
    # to transformer layers; TODO confirm against the mask producer.
    zeros_per_row = (mask == 0).sum(dim=1)
    print(zeros_per_row)
    print(zeros_per_row.sum())
    return mask

if __name__ == '__main__':
    args = get_args_parser()
    print("start analyze norm mask :")
    norm_mask = analyze_ffn_pruning_config_("norm", args)
    print("start analyze distribution mask :")
    dis_mask = analyze_ffn_pruning_config_("distribution", args)

    # `*` would silently broadcast masks of different shapes into a wrong
    # merged mask, so require identical shapes before combining.
    if norm_mask.shape != dis_mask.shape:
        raise ValueError(
            "norm and distribution masks differ in shape: "
            f"{tuple(norm_mask.shape)} vs {tuple(dis_mask.shape)}"
        )

    # The element-wise product keeps an entry only where both masks are
    # non-zero. This acts as a logical AND, assuming the masks hold 0/1 values
    # (TODO confirm).
    print("start analyze merged mask :")
    merged = norm_mask * dis_mask
    zeros_per_row = (merged == 0).sum(dim=1)
    print(zeros_per_row)
    print(zeros_per_row.sum())

    save_path = f"Pruning_Config/merged_{args.dataset_name}_rate{str(args.rate)}_global.pt"
    torch.save(merged, save_path)