#!/usr/bin/env python3
# Copyright 2018 CMU and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Bertology: this script shows how you can explore the internals of the models in the library to:
    - compute the entropy of the head attentions
    - compute the importance of each head
    - prune (remove) the low importance head.
    Some parts of this script are adapted from the code of Michel et al. (http://arxiv.org/abs/1905.10650)
    which is available at https://github.com/pmichel31415/are-16-heads-really-better-than-1
"""
import argparse
import logging
import os
from datetime import datetime

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, SequentialSampler, Subset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
import torch.distributed as dist

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GlueDataset,
    default_data_collator,
    glue_compute_metrics,
    glue_output_modes,
    glue_processors,
    set_seed,
)
from transformers.trainer_utils import is_main_process
from utils import utils
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from transformers import (
    MODEL_MAPPING,
    AdamW,
    AutoTokenizer,
    AutoConfig,
    RobertaTokenizer,
    BertTokenizer,
    DataCollatorForLanguageModeling,
    get_scheduler,
    SchedulerType,
    set_seed,
)

logger = logging.getLogger(__name__)
import torch.nn.utils.prune as prune


# Unstructured (magnitude) pruning of RoBERTa
# ref: https://colab.research.google.com/drive/1onydMil8ulrdPY1LDxWbr2F_oWAENEEp#scrollTo=cA9iUrv76hTw
# https://github.com/VITA-Group/BERT-Tickets
def pruning_bert(model_ori, px):
    """Globally apply L1-magnitude unstructured pruning to RoBERTa's encoder.

    The ``weight`` and ``bias`` of six Linear sublayers in each of the first 12 encoder
    layers are pooled together: attention query/key/value, the attention output
    projection, the intermediate dense layer and the output dense layer. A single global
    L1 threshold is then applied via ``prune.global_unstructured``. Anything outside those
    sublayers (embeddings incl. position encodings, LayerNorms, task heads) is left untouched.

    Args:
        model_ori: model exposing ``roberta.encoder.layer[i]``; pruned in place through
            ``torch.nn.utils.prune`` reparametrization (``*_orig`` / ``*_mask``).
        px: ``amount`` forwarded to ``prune.global_unstructured``.
    """
    encoder_layers = model_ori.roberta.encoder.layer
    targets = []
    for block_idx in range(12):
        blk = encoder_layers[block_idx]
        # Order matters for tie-breaking inside global_unstructured; keep it stable.
        sublayers = (
            blk.attention.self.query,
            blk.attention.self.key,
            blk.attention.self.value,
            blk.attention.output.dense,
            blk.intermediate.dense,
            blk.output.dense,
        )
        for sub in sublayers:
            targets.extend([(sub, 'weight'), (sub, 'bias')])

    prune.global_unstructured(
        tuple(targets),
        pruning_method=prune.L1Unstructured,
        amount=px,
    )


def see_weight_rate(model_ori):
    """Print and return the percentage of exactly-zero entries in RoBERTa's prunable tensors.

    Only the ``weight`` and ``bias`` of six Linear sublayers in each of the first 12
    encoder layers are counted: query, value, key, attention output, intermediate and
    output dense. These are the same tensors ``pruning_bert`` prunes. For a module pruned
    with ``torch.nn.utils.prune``, ``module.weight`` / ``module.bias`` is the masked tensor,
    so masked entries count as zeros.

    Returns:
        Zero rate as a percentage, ``100 * zeros / total``.
    """
    total_count = 0.0
    zero_count = 0.0
    for idx in range(12):
        blk = model_ori.roberta.encoder.layer[idx]
        sublayers = (
            blk.attention.self.query,
            blk.attention.self.value,
            blk.attention.self.key,
            blk.attention.output.dense,
            blk.intermediate.dense,
            blk.output.dense,
        )
        for sub in sublayers:
            for tensor in (sub.weight, sub.bias):
                total_count += float(tensor.nelement())
                zero_count += float(torch.sum(tensor == 0))

    rate = 100 * zero_count / total_count
    print('RoBERTa zero rate is {0:.2f}'.format(rate))
    return rate


"""As expected, the 50% pruned model's WER drop to 92.8%. 
This is why an additional re-training is necessary to recover the loss. Following previous work, we apply the pruning mask back to pre-trained initialization, 
and followed by another round of ASR finetuning. 

To do that, we first define the operation for applying existing pruning mask. 
"""

def apply_pruning_mask(model_ori, mask_dict):
    """Re-apply previously saved pruning masks to a RoBERTa model (e.g. a pre-trained one).

    The function covers six Linear sublayers in each of the first 12 encoder layers:
    attention query/key/value, attention output dense, intermediate dense and output
    dense. For each one it looks up ``roberta.encoder.layer.<i>.<sublayer>.weight_mask``
    and ``...bias_mask`` in ``mask_dict`` and installs them with ``prune.CustomFromMask``
    on ``weight`` and ``bias``. Every module and mask is resolved before anything is
    applied, so a missing key raises ``KeyError`` without modifying the model.

    Args:
        model_ori: model exposing ``roberta.encoder.layer[i]``; modified in place.
        mask_dict: mapping from state-dict mask keys to mask tensors (for instance the
            dict returned by ``compute_magnitude_prune``).
    """
    sublayer_paths = (
        'attention.self.query',
        'attention.self.key',
        'attention.self.value',
        'attention.output.dense',
        'intermediate.dense',
        'output.dense',
    )

    # Phase 1: resolve modules and masks (lookups only, no mutation).
    pending = []
    for layer_idx in range(12):
        module_root = model_ori.roberta.encoder.layer[layer_idx]
        for path in sublayer_paths:
            target = module_root
            for attr_name in path.split('.'):
                target = getattr(target, attr_name)
            key_prefix = 'roberta.encoder.layer.' + str(layer_idx) + '.' + path
            weight_mask = mask_dict[key_prefix + '.weight_mask']
            bias_mask = mask_dict[key_prefix + '.bias_mask']
            pending.append((target, weight_mask, bias_mask))

    # Phase 2: install both the weight and the bias mask on every module.
    for target, weight_mask, bias_mask in pending:
        prune.CustomFromMask.apply(target, 'weight', mask=weight_mask)
        prune.CustomFromMask.apply(target, 'bias', mask=bias_mask)


def compute_magnitude_prune(args, model,accelerator=None):
    """One-shot global magnitude pruning of RoBERTa's encoder; returns the pruning masks.

    Steps:
      * unwrap ``model`` through ``accelerator`` if one is given;
      * prune it in place with ``pruning_bert`` at ``args.pruning_rate``;
      * print the resulting zero rate via ``see_weight_rate``;
      * split the state dict into masks (keys containing ``'mask'``) and everything else.

    On the main process of an accelerate run, both dicts are saved to ``args.output_dir``
    as ``pruned_<rate>_mask.pt`` and ``pruned_<rate>_weight.pt``.
    NOTE(review): with ``accelerator=None`` nothing is written to disk. Confirm this is intended.

    Returns:
        dict mapping every state-dict key that contains ``'mask'`` to its tensor.
    """
    model_ori = model if accelerator is None else accelerator.unwrap_model(model)

    pruning_bert(model_ori, args.pruning_rate)
    see_weight_rate(model_ori)

    state = model_ori.state_dict()
    mask_dict = {name: tensor for name, tensor in state.items() if 'mask' in name}
    weight_dict = {name: tensor for name, tensor in state.items() if 'mask' not in name}

    if accelerator is not None and accelerator.is_main_process:
        # The weights were identical when pruned, so every rank derives the same mask:
        # no gather is required, the main process alone writes the files.
        tag = 'pruned_' + str(args.pruning_rate)
        torch.save(mask_dict, os.path.join(args.output_dir, tag + '_mask.pt'))
        torch.save(weight_dict, os.path.join(args.output_dir, tag + '_weight.pt'))

    # Follow-up for one-shot magnitude pruning (OMP), currently disabled:
    # load the pre-trained (not fine-tuned) model, e.g. via utils.lookfor_model_posttrain(args),
    # re-apply these masks with apply_pruning_mask, confirm the sparsity with see_weight_rate,
    # then re-train on the end task.
    return mask_dict



def unprune_bert(model, accelerator=None):
    """Make the pruning of RoBERTa's encoder permanent.

    ``prune.remove`` is called on ``weight`` and ``bias`` of six Linear sublayers in each
    of the first 12 encoder layers: attention query/key/value, attention output dense,
    intermediate dense and output dense. This drops the ``*_orig`` parameter, the
    ``*_mask`` buffer and the forward pre-hook, and keeps the currently masked values as
    plain parameters. It is useful when the learned pruned mask is to be tweaked
    afterwards, as in PARP.

    Args:
        model: model (optionally wrapped by accelerate) exposing ``roberta.encoder.layer[i]``.
        accelerator: optional; when given, used to unwrap ``model``. With ``None`` the
            model is used as-is, mirroring ``compute_magnitude_prune``.

    Raises:
        ValueError: raised by ``prune.remove`` if a ``weight`` or ``bias`` is not pruned.
    """
    model_ori = model if accelerator is None else accelerator.unwrap_model(model)

    # Resolve every module first, so an attribute error leaves the model untouched.
    modules = []
    for layer in range(12):
        blk = model_ori.roberta.encoder.layer[layer]
        modules.extend([
            blk.attention.self.query,
            blk.attention.self.key,
            blk.attention.self.value,
            blk.attention.output.dense,
            blk.intermediate.dense,
            blk.output.dense,
        ])

    for module in modules:  # remove both the weight and the bias reparametrization
        prune.remove(module, 'weight')
        prune.remove(module, 'bias')



# Below: structured pruning -- importance of attention heads and FFN units -----------

def gather_importance(head_importance):
    """Average an importance tensor element-wise across all distributed processes.

    ``all_gather`` is a collective, so every rank must call this with a tensor of the same
    shape. The per-rank tensors are stacked and their element-wise mean is returned.
    If ``torch.distributed`` is unavailable, or no process group has been initialised
    (a plain single-process run), the input is returned unchanged. Previously
    ``dist.get_world_size()`` raised in that case.

    Args:
        head_importance: tensor to average (any shape, identical on every rank).

    Returns:
        Tensor with the same shape as ``head_importance``.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Single process: the mean over one rank is the tensor itself.
        return head_importance

    gathered = [torch.zeros_like(head_importance) for _ in range(dist.get_world_size())]
    dist.all_gather(tensor_list=gathered, tensor=head_importance.contiguous())  # every rank must participate
    return torch.mean(torch.stack(gathered), dim=0)


def compute_heads_importance(args,config, model, eval_dataloader,accelerator,position,run_distill=False):
    """Estimate gradient-based importance scores for attention heads and FFN units.

    Three all-ones soft masks are passed to the model: attention heads
    (``n_layer x n_heads``), intermediate FFN units (``n_layer x intermediate_size``) and
    FFN output units (``n_layer x hidden_size``). For every batch of ``eval_dataloader``
    the loss is back-propagated, and that batch's gradient w.r.t. each mask is added to
    the matching importance matrix. The sums are divided by the number of attended tokens
    (sum of ``attention_mask``) seen on this process, then averaged over processes with
    ``gather_importance``.

    Args:
        args: namespace; reads ``device``, ``task`` and ``output_dir``.
        config: model config; reads ``num_hidden_layers``, ``num_attention_heads``,
            ``intermediate_size`` and ``hidden_size``.
        model: called as ``model(inputs, head_mask=..., intermediate_mask=...,
            output_mask=..., prune_mdoel=True, run_distill=...)`` and must return an object
            with ``.loss``. This is a project-specific forward signature.
        eval_dataloader: iterable of batches; each batch must provide ``"attention_mask"``.
        accelerator: accelerate-style object; uses ``backward``, ``wait_for_everyone``
            and ``is_main_process``.
        position: together with ``args.task`` selects where results are saved. If
            ``args.task == 0 and position == 'pre'``, ``.npy`` and ``.txt`` files go to
            ``args.output_dir + '../base/'``. Otherwise ``.npy`` files go to
            ``args.output_dir`` and the matrices are printed.
        run_distill: forwarded to the model.

    Returns:
        Tuple ``(head_importance, intermediate_importance, output_importance)``.
    """
    # model.train() # train results in NAN
    n_layer, n_heads = config.num_hidden_layers, config.num_attention_heads

    # All-ones masks leave the forward pass unchanged; only their gradients are used.
    head_mask = torch.ones(n_layer, n_heads).to(args.device)
    intermediate_mask = torch.ones(n_layer, config.intermediate_size).to(args.device)
    output_mask = torch.ones(n_layer, config.hidden_size).to(args.device)
    masks = (head_mask, intermediate_mask, output_mask)
    for mask in masks:
        mask.requires_grad_(requires_grad=True)

    head_importance = torch.zeros(n_layer, n_heads).to(args.device)
    intermediate_importance = torch.zeros(n_layer, config.intermediate_size).to(args.device)
    output_importance = torch.zeros(n_layer, config.hidden_size).to(args.device)

    # TODO: if we want to do .abs() for grad, we probably should not use normal standardization

    tot_tokens = 0.0

    for inputs in tqdm(eval_dataloader, desc="Iteration"):
        # `prune_mdoel` (sic) is passed as in the original call; presumably it matches the
        # project model's forward keyword, so it is not renamed here.
        outputs = model(
            inputs,
            head_mask=head_mask,
            intermediate_mask=intermediate_mask,
            output_mask=output_mask,
            prune_mdoel=True,
            run_distill=run_distill,
        )
        accelerator.backward(outputs.loss)

        head_importance += head_mask.grad.detach()
        intermediate_importance += intermediate_mask.grad.detach()
        output_importance += output_mask.grad.detach()

        # Autograd accumulates into .grad across backward calls. Clear the mask gradients
        # so each batch contributes exactly once; otherwise batch k would be counted
        # (N - k + 1) times.
        for mask in masks:
            mask.grad = None

        tot_tokens += inputs["attention_mask"].float().detach().sum().data

    # NOTE(review): gradients also accumulate on the model parameters here and are not
    # cleared; callers that train afterwards presumably need to zero them first -- confirm.

    # Normalize by the number of attended tokens seen on this process.
    head_importance /= tot_tokens
    intermediate_importance /= tot_tokens
    output_importance /= tot_tokens

    accelerator.wait_for_everyone()

    head_importance = gather_importance(head_importance)
    intermediate_importance = gather_importance(intermediate_importance)
    output_importance = gather_importance(output_importance)

    if accelerator.is_main_process:
        named_scores = (
            ("head_importance", head_importance),
            ("intermediate_importance", intermediate_importance),
            ("output_importance", output_importance),
        )
        if args.task == 0 and position == 'pre':  # first task, 'pre' position: save in the base dir
            # NOTE(review): plain string concatenation; this only resolves to a sibling
            # 'base' directory if output_dir ends with a path separator.
            base_dir = args.output_dir + '../base/'
            for name, scores in named_scores:
                np.save(os.path.join(base_dir, name + ".npy"), scores.detach().cpu().numpy())
            for name, scores in named_scores:
                np.savetxt(os.path.join(base_dir, name + ".txt"), scores.detach().cpu().numpy(), delimiter='\t')
        else:
            for name, scores in named_scores:
                np.save(os.path.join(args.output_dir, name + ".npy"), scores.detach().cpu().numpy())

            print('head_importance: ',head_importance)
            print('intermediate_importance: ',intermediate_importance)
            print('output_importance: ',output_importance)

    return head_importance, intermediate_importance, output_importance




