


from tqdm.auto import tqdm
import torch
import torch.distributed as dist
import os


def gather_importance(head_importance):
    """Average a tensor element-wise across all distributed processes.

    When a process group is active this is a collective ``all_gather``, so
    every rank must call it with a tensor of the same shape and dtype.
    When ``torch.distributed`` is unavailable or no process group has been
    initialised (e.g. a single-process run), the tensor is returned
    unchanged, which is exactly the mean over a world of size one.

    Args:
        head_importance: tensor to average across ranks.

    Returns:
        The element-wise mean of ``head_importance`` over all ranks.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Nothing to gather: previously this raised from get_world_size().
        return head_importance

    # Build the receive buffers from the contiguous tensor so they are
    # contiguous too (zeros_like preserves the strides of dense inputs).
    local = head_importance.contiguous()
    gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(tensor_list=gathered, tensor=local)  # every rank must participate
    return torch.stack(gathered).mean(dim=0)



def fisher_compute(train_dataloader_prune,model,self_fisher,accelerator,args):
    """Estimate the diagonal (EWC) Fisher information and merge it with earlier tasks.

    One backward pass is run per batch of ``train_dataloader_prune``. The
    squared gradients are accumulated per parameter and normalised. For
    ``args.task > 0`` the result is blended with the previous tasks' Fisher
    (``self_fisher``). It is then averaged across processes and saved to
    ``<args.output_dir>/../fisher`` by the main process.

    Args:
        train_dataloader_prune: iterable of batches; each batch is passed
            as-is to ``model``. Must support ``len()``.
        model: module called as ``model(inputs, self_fisher=...)`` and
            returning an object with a ``.loss`` attribute.
        self_fisher: dict mapping parameter name -> Fisher tensor from the
            previous tasks. It is read (when ``args.task > 0``) and forwarded
            to the model, but never mutated.
        accelerator: object providing ``backward``, ``wait_for_everyone``,
            ``is_main_process`` and ``is_local_main_process`` (presumably an
            HF ``Accelerator`` -- confirm against caller).
        args: namespace providing ``output_dir``, ``task``,
            ``per_device_train_batch_size`` and ``gradient_accumulation_steps``.

    Returns:
        dict mapping parameter name -> merged, process-averaged Fisher tensor.

    Raises:
        ValueError: if ``train_dataloader_prune`` is empty. Previously this
            silently produced (and saved) an all-NaN Fisher.

    Side effects: leaves ``model`` in train mode with its gradients cleared.
    """
    num_batches = len(train_dataloader_prune)
    if num_batches == 0:
        raise ValueError('fisher_compute: train_dataloader_prune is empty')

    # One directory above the current output dir. Joining '..' as its own
    # component also works when output_dir has no trailing slash.
    fisher_path = os.path.join(args.output_dir, '..', 'fisher')

    # Snapshot the previous tasks' Fisher before it is blended below.
    fisher_old = {}
    if args.task > 0:
        fisher_old = {n: self_fisher[n].clone() for n, _ in model.named_parameters()}

    progress_bar = tqdm(range(num_batches), disable=not accelerator.is_local_main_process)
    sbatch = args.per_device_train_batch_size

    fisher = {n: torch.zeros_like(p.data) for n, p in model.named_parameters()}

    model.train()

    # TODO: this needs a full pass over the pruning dataloader, which is expensive.
    for inputs in train_dataloader_prune:
        model.zero_grad()

        outputs = model(inputs, self_fisher=self_fisher)
        loss = outputs.loss / args.gradient_accumulation_steps

        # Any parameter added to the model must be registered in
        # model.parameters() to get a gradient here; double-check it does.
        accelerator.backward(loss)  # sync
        progress_bar.update(1)
        progress_bar.set_description('EWC Fisher Compute Iter (loss=%5.3f)' % loss.item())

        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += sbatch * p.grad.detach().pow(2)

    progress_bar.close()
    # Clear the last batch's gradients so they cannot leak into a later
    # optimizer step taken by the caller.
    model.zero_grad()

    # NOTE(review): squared grads are scaled by the batch size (sbatch) but
    # divided by the number of *batches*, not samples, so the estimate is
    # sbatch times the per-sample mean (and the loss is also divided by
    # gradient_accumulation_steps). Kept as-is because the EWC penalty
    # strength is presumably tuned to this scale -- confirm before changing.
    for n in fisher:
        fisher[n] = fisher[n] / num_batches

    if args.task > 0:
        # Merge instead of keeping one Fisher per task in memory: a running
        # mean where the current task has weight 1 and earlier tasks weight
        # args.task. The original author found this better than a 0.5/0.5 blend.
        for n, _ in model.named_parameters():
            fisher[n] = (fisher[n] + fisher_old[n] * args.task) / (args.task + 1)

    accelerator.wait_for_everyone()

    # Average across processes; every rank iterates the same keys in the
    # same order, so the collectives line up.
    for n in fisher:
        fisher[n] = gather_importance(fisher[n])

    if accelerator.is_main_process:
        torch.save(fisher, fisher_path)

    return fisher

