


from tqdm.auto import tqdm
import torch
import torch.distributed as dist
import os


def gather_importance(head_importance):
    """Return the element-wise mean of ``head_importance`` across all ranks.

    Every process must call this with a tensor of the same shape and dtype,
    because ``dist.all_gather`` is a collective operation.

    When ``torch.distributed`` is unavailable or no process group has been
    initialised (a plain single-process run), the input is returned unchanged.
    That equals the mean over a world of size one. Without this guard,
    ``dist.get_world_size()`` raises.

    Args:
        head_importance: tensor to average across processes.

    Returns:
        A tensor with the same shape as ``head_importance`` holding the mean
        over all ranks.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return head_importance

    # Build the receive buffers from the contiguous tensor we actually send,
    # so the buffers share its (contiguous) layout.
    local = head_importance.contiguous()
    gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(tensor_list=gathered, tensor=local)  # every rank must reach this
    return torch.stack(gathered).mean(dim=0)



def hat_compute(model,accelerator,mask_pre,mask,get_view_for,args):
    """Fold the current task's masks into the cumulative mask and save both.

    Steps:
      1. Detach/clone every entry of ``mask``. As before, this replaces the
         entries in the caller's ``mask`` dict.
      2. For ``args.task == 0`` the cumulative mask starts as a copy of
         ``mask``. Otherwise ``mask_pre`` is updated in place with the
         element-wise max of itself and ``mask``.
      3. For each parameter of ``model`` for which ``get_view_for`` returns a
         value, store ``1 - value`` in ``mask_back``.
      4. Average every entry of ``mask_pre`` and ``mask_back`` across ranks
         via ``gather_importance``.
      5. On the main process, ``torch.save`` both dicts to ``mask_pre`` and
         ``mask_back`` in the parent directory of ``args.output_dir``.

    Args:
        model: model (possibly wrapped by ``accelerator``) whose
            ``named_parameters()`` are passed to ``get_view_for``.
        accelerator: object providing ``unwrap_model``, ``wait_for_everyone``
            and ``is_main_process`` (an Accelerate ``Accelerator`` presumably).
        mask_pre: dict of cumulative masks from earlier tasks. It is ignored
            when ``args.task == 0``; otherwise it is mutated in place.
        mask: dict of the current task's masks, keyed like ``mask_pre``.
        get_view_for: callable ``(name, param, mask_pre, config)`` returning a
            tensor or ``None``.
        args: namespace with at least ``output_dir`` and ``task``.

    Returns:
        ``(mask_pre, mask_back)`` after cross-rank averaging. The original
        returned ``None``; callers that ignore the result are unaffected.
    """
    # Use '..' as a path component rather than string concatenation: the old
    # `output_dir + '../'` broke when output_dir had no trailing separator.
    parent_dir = os.path.join(args.output_dir, '..')
    mask_pre_path = os.path.join(parent_dir, 'mask_pre')
    mask_back_path = os.path.join(parent_dir, 'mask_back')
    config = accelerator.unwrap_model(model).config

    # Detached, gradient-free copies (replaces deprecated autograd.Variable).
    for key, value in mask.items():
        mask[key] = value.detach().clone()

    if args.task == 0:
        # Shallow copy: aliasing `mask` would let the gather below overwrite
        # the caller's `mask` entries.
        mask_pre = dict(mask)
    else:
        for key in mask_pre:
            mask_pre[key] = torch.max(mask_pre[key], mask[key])

    # NOTE(review): parameter names come from the possibly wrapped `model`, so
    # under DDP they may carry a 'module.' prefix. Left unchanged because
    # get_view_for and any loader presumably rely on these exact names.
    mask_back = {}
    for n, p in model.named_parameters():
        vals = get_view_for(n, p, mask_pre, config)
        if vals is not None:
            mask_back[n] = 1 - vals

    accelerator.wait_for_everyone()

    # gather_importance is collective: all ranks must visit keys in the same order.
    for k in mask_pre:
        mask_pre[k] = gather_importance(mask_pre[k])
    for k in mask_back:
        mask_back[k] = gather_importance(mask_back[k])

    if accelerator.is_main_process:
        torch.save(mask_pre, mask_pre_path)
        torch.save(mask_back, mask_back_path)

    return mask_pre, mask_back
