
"""
Meter - Computes and stores the min, max, avg, and current values
"""
class Meter:
    """Running summary of a scalar statistic.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: weighted sum of all recorded values (each value counted ``n`` times).
        count: total weight recorded so far.
        avg: ``sum / count`` after the latest ``update``.
        max / min: largest / smallest ``val`` seen; the weight ``n`` is ignored.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Return every field to its freshly-constructed state."""
        self.val = self.avg = self.sum = self.count = 0
        # Infinite sentinels so the first update always replaces them.
        self.max, self.min = -float("inf"), float("inf")

    def update(self, val, n=1):
        """Record ``val`` as ``n`` observations of the same value.

        If ``count`` is still zero afterwards (e.g. ``n=0`` on a fresh meter),
        the average computation raises ZeroDivisionError.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.max, self.min = max(self.max, val), min(self.min, val)

"""
Solver Stats - Fixed point solver stat recorder
"""
class SplittingMethodStats(object):
    """Record statistics for a forward/backward fixed-point (splitting) solver.

    The ``Meter`` attributes hold running summaries of iteration counts and
    timings. The list attributes are histories that callers append to directly;
    nothing in this class writes to them, and ``reset`` does not clear them.
    """

    def __init__(self):
        self.fwd_iters = Meter()
        self.fwd_time = Meter()
        self.RESID = []
        self.ERR = []
        self.dZ = []
        # backward
        self.bkwd_iters = Meter()
        self.bkwd_time = Meter()
        self.fwd_lWmax = []
        self.fwd_lWmin = []
        self.bwd_lWmax = []
        self.bwd_lWmin = []
        self.BRESID = []
        self.BERR = []
        self.dL = []
        self.dG = []
        self.dW = []
        # Initialised here so that reading ``options`` before ``set_options``
        # has been called yields an empty dict instead of raising AttributeError.
        self.options = {}

    def set_options(self, **kwargs):
        """Store the keyword arguments as ``self.options``, replacing any previous options."""
        self.options = kwargs

    def reset(self):
        """Reset the forward/backward iteration and time meters.

        The history lists (RESID, ERR, dZ, dL, dG, dW, ...) are left untouched.
        """
        self.fwd_iters.reset()
        self.fwd_time.reset()
        self.bkwd_iters.reset()
        self.bkwd_time.reset()

    def report(self):
        """Print the average iteration counts and timings, then the dL/dG/dW histories."""
        print('Fwd iters: {:.2f}\tFwd Time: {:.4f}\tBkwd Iters: {:.2f}\tBkwd Time: {:.4f}'.format(
                self.fwd_iters.avg, self.fwd_time.avg,
                self.bkwd_iters.avg, self.bkwd_time.avg))
        print('dL: ', self.dL,
                '\ndG', self.dG,
                '\ndW', self.dW)


"""
Gradient Stats - Per-parameter weight-norm and gradient-norm recorder
"""
class GradientStats(object):
    """Record the weight norm and gradient norm of each model parameter over time.

    ``grads[name]`` and ``vals[name]`` are histories keyed by the names yielded by
    ``model.named_parameters()`` at construction time. Each history starts with a
    ``None`` placeholder, so ``report`` prints ``None`` for a parameter that has
    never been recorded.
    """

    def __init__(self, model):
        self.model = model
        self.params = [name for name, _ in model.named_parameters()]
        self.grads = {name: [None] for name in self.params}
        self.vals = {name: [None] for name in self.params}

    def update(self):
        """Append the current gradient norm and weight norm of every parameter.

        Parameters whose ``.grad`` is ``None`` (no gradient populated yet) are
        skipped, and nothing is appended to either history for them.
        """
        for name in self.params:
            # Re-resolve the parameter through its dotted name on every call so
            # the attribute currently set on the model is the one measured.
            obj = self.model
            for sub_attr in name.split('.'):
                obj = getattr(obj, sub_attr)
            grad = getattr(obj, 'grad', None)
            if grad is None:
                continue
            # Compute both norms before appending, so a failure cannot leave
            # the grads and vals histories with different lengths.
            grad_norm = grad.norm().item()
            val_norm = obj.norm().item()
            self.grads[name].append(grad_norm)
            self.vals[name].append(val_norm)

    def report(self):
        """Print the latest weight norm and gradient norm of every parameter."""
        for k in self.grads:
            print(f'{k}: {self.vals[k][-1]},\t d{k} {self.grads[k][-1]}')