


from tqdm.auto import tqdm
import torch
import torch.distributed as dist
import os


def gather_importance(head_importance):
    """Average ``head_importance`` element-wise across all distributed ranks.

    Each rank's local tensor is collected with ``dist.all_gather``. The
    gathered copies are stacked along a new leading dimension and reduced
    with a mean, so every rank ends up with the same averaged tensor.

    This is a collective operation: every process in the default process
    group must call it with a tensor of matching shape and dtype, otherwise
    the participating ranks block.

    Args:
        head_importance: Local tensor to average. ``torch.mean`` requires a
            floating-point (or complex) dtype.

    Returns:
        A tensor shaped like ``head_importance`` holding the cross-rank mean.
    """
    world_size = dist.get_world_size()
    # One receive slot per rank, matching the local tensor's shape/dtype/device.
    gathered = [torch.zeros_like(head_importance) for _ in range(world_size)]
    # Collective call: every rank has to reach this line.
    dist.all_gather(tensor_list=gathered, tensor=head_importance.contiguous())
    return torch.stack(gathered).mean(dim=0)



def derpp_compute(train_dataloader_prune, model, buffer, args):
    """Fill the DER++ replay buffer from one batch and save it to disk.

    Runs ``model`` without gradient tracking on the first batch yielded by
    ``train_dataloader_prune``. It then stores that batch's ``input_ids``,
    ``labels`` and ``attention_mask`` in ``buffer``, together with the last
    entry of the model output's ``hidden_states`` (passed as ``logits``).
    Finally, the whole ``buffer`` object is written with ``torch.save`` to a
    file named ``buffer`` in the parent directory of ``args.output_dir``.

    Args:
        train_dataloader_prune: Iterable of batches. Each batch must be
            indexable by ``'input_ids'``, ``'labels'`` and
            ``'attention_mask'``. Only the first batch is used.
        model: Callable taking a batch and returning an object with a
            ``hidden_states`` sequence.
        buffer: Object exposing ``add_data(examples=, labels=, logits=,
            attention_mask=)`` and ``get_size()``. It must be picklable by
            ``torch.save``.
        args: Namespace providing ``output_dir``.

    Raises:
        ValueError: If ``train_dataloader_prune`` yields no batches.
    """
    # The buffer is stored next to (not inside) output_dir. Joining '..' as
    # its own path component works whether or not output_dir ends with a
    # separator; the old string concatenation (output_dir + '../') produced
    # '<output_dir>../buffer' when the trailing slash was missing.
    buffer_path = os.path.join(args.output_dir, os.pardir, 'buffer')

    missing = object()
    inputs = next(iter(train_dataloader_prune), missing)  # only one batch is used
    if inputs is missing:
        # Avoid leaking a bare StopIteration out of this function.
        raise ValueError('train_dataloader_prune yielded no batches')

    # Exactly one batch is processed, so the bar has a single step. The old
    # total of len(dataloader) never completed and required a sized loader.
    # The context manager guarantees the bar is closed.
    with tqdm(total=1) as progress_bar, torch.no_grad():
        outputs = model(inputs)

        buffer.add_data(
            examples=inputs['input_ids'],
            labels=inputs['labels'],
            logits=outputs.hidden_states[-1],
            attention_mask=inputs['attention_mask'],
        )

        progress_bar.update(1)
        progress_bar.set_description('DERPP Compute Iter ')

    print('buffer size: ',buffer.get_size())
    torch.save(buffer, buffer_path)

