class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Attributes:
        name: Label printed in front of the values by ``__str__``.
        fmt: Format spec (including the leading colon, e.g. ``':.3f'``)
            applied to both the latest value and the average.
        val: Value passed to the most recent ``update`` call.
        sum: Running total of ``val * n`` over all updates.
        count: Running total of ``n`` over all updates.
        avg: ``sum / count``, or 0 while ``count`` is zero.
    """

    def __init__(self, name: str, fmt: str = ':f') -> None:
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self) -> None:
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1) -> None:
        """Record ``val`` as the latest value, weighted by ``n`` observations.

        Args:
            val: The new measurement (typically a per-batch mean).
            n: Weight of the measurement, e.g. the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: an update with n=0 on a fresh meter (for
        # example an empty batch) would otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count != 0 else 0

    def __str__(self) -> str:
        """Render as ``'<name> <val> (<avg>)'`` using ``self.fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

def fisher(train_loader, model, device, *, use_batch=True):
    """Accumulate averaged squared cross-entropy gradients per parameter.

    For every loss term the gradient with respect to each model parameter
    is computed and squared element-wise. The squares are summed over all
    terms and divided by the number of terms, giving a diagonal Fisher
    information style estimate. A second list holds the same quantity
    multiplied element-wise by the current parameter values.

    Args:
        train_loader: Iterable yielding ``(x, y)`` pairs. ``x`` is cast to
            float and ``y`` to long after being moved to ``device``.
        model: Module whose output for ``x`` is fed to cross-entropy as
            logits. It is switched to eval mode and left in that mode.
        device: Device the inputs are moved to before the forward pass.
        use_batch: When True (default) one loss term is the mean loss of a
            batch. When False every sample contributes its own loss term.

    Returns:
        A tuple ``(cfisher, ctfisher)`` of lists with one tensor per model
        parameter, in ``model.parameters()`` order. ``cfisher`` is the
        averaged squared gradient; ``ctfisher`` is the averaged squared
        gradient multiplied by the parameter.

    Raises:
        ValueError: If ``train_loader`` produced no loss terms.
    """
    model.eval()
    # Materialised once so the same sequence is reused for every gradient
    # call and for the element-wise products below.
    params = list(model.parameters())
    cfisher = None
    ctfisher = None
    # Number of loss terms actually differentiated. Counting them avoids
    # relying on len(train_loader) or a global batch size, and stays correct
    # when the last batch is smaller than the others.
    num_terms = 0

    for x, y in train_loader:
        x = x.to(device).float()
        y = y.to(device).long()
        output = model(x)
        if use_batch:
            loss_list = [torch.nn.functional.cross_entropy(output, y, reduction='mean')]
        else:
            loss_list = torch.nn.functional.cross_entropy(output, y, reduction='none')

        last_index = len(loss_list) - 1
        for index, single_loss in enumerate(loss_list):
            # The graph only has to survive while further loss terms of the
            # same forward pass still need to be differentiated.
            gradient = torch.autograd.grad(
                single_loss, params, retain_graph=index < last_index
            )
            with torch.no_grad():
                # Fold each term into the running sums immediately. Keeping
                # a growing list of gradients across batches would re-count
                # earlier batches and leak memory.
                squared = [g.detach() * g.detach() for g in gradient]
                if cfisher is None:
                    cfisher = squared
                    ctfisher = [s * p for s, p in zip(squared, params)]
                else:
                    cfisher = [acc + s for acc, s in zip(cfisher, squared)]
                    ctfisher = [
                        acc + s * p for acc, s, p in zip(ctfisher, squared, params)
                    ]
            num_terms += 1

    if cfisher is None:
        raise ValueError("train_loader yielded no data; cannot estimate Fisher information")

    cfisher = [a.detach() / num_terms for a in cfisher]
    ctfisher = [a.detach() / num_terms for a in ctfisher]
    return cfisher, ctfisher
