from typing import Literal
class MovingAvg:
    """Incrementally maintained weighted mean.

    Each call to :meth:`update` merges a partial mean together with the
    number of samples that partial mean represents.
    """

    def __init__(self):
        # Current weighted mean; ``None`` until the first update arrives.
        self.mean = None
        # Total number of samples merged so far.
        self.num = 0

    def update(self, update_mean, update_num):
        """Merge a partial mean into the running mean.

        Args:
            update_mean: Mean of the samples being merged.
            update_num: Number of samples ``update_mean`` was computed from.
        """
        # Nothing weighted has been accumulated yet (fresh object, or only
        # zero-count updates so far): adopt the incoming statistics directly.
        # This avoids a 0/0 division when both counts are zero.
        if self.mean is None or self.num == 0:
            self.mean = update_mean
            self.num = update_num
            return

        # A zero-count update carries no weight; skipping it keeps a NaN/inf
        # placeholder mean from poisoning the accumulated value.
        if update_num == 0:
            return

        total = self.num + update_num
        self.mean = (self.mean * self.num + update_mean * update_num) / total
        self.num = total

    def __repr__(self):
        return f"mean:{self.mean}, num:{self.num}"

    def get_sum(self):
        """Return ``mean * num``; 0 when nothing has been accumulated yet."""
        if self.mean is None:
            return 0
        return self.mean * self.num
    
def getDataloaderAvg(dataloader, func, funcReduction: Literal["mean", "sum"]) -> "MovingAvg":
    """Average ``func`` over every batch yielded by ``dataloader``.

    Args:
        dataloader: Iterable yielding ``(input_args, output_args)`` pairs.
        func: Callable invoked as ``func(input_args, output_args)`` per batch.
        funcReduction: How ``func`` already reduced over the batch.
            ``"mean"`` means the result is a per-sample mean; ``"sum"``
            means it is a batch total and is divided by the batch size here.

    Returns:
        A ``MovingAvg`` holding the sample-weighted mean and sample count.

    Raises:
        ValueError: If ``funcReduction`` is neither ``"mean"`` nor ``"sum"``.
    """
    # ``Literal`` is not enforced at runtime; without this check a typo would
    # silently yield an empty average.
    if funcReduction not in ("mean", "sum"):
        raise ValueError(
            f"funcReduction must be 'mean' or 'sum', got {funcReduction!r}"
        )

    value = MovingAvg()
    for input_args, output_args in dataloader:
        batch_value = func(input_args, output_args)
        # The length of the first input is used as the number of samples in
        # the batch (presumably the batch dimension; verify against callers).
        batch_size = len(input_args[0])
        if batch_size == 0:
            # An empty batch contributes no samples, and dividing a "sum" by
            # zero would fail. ``func`` has still been invoked above, so any
            # side effects it has are preserved.
            continue
        if funcReduction == "sum":
            batch_value = batch_value / batch_size
        value.update(batch_value, batch_size)
    return value

def composefunc(f1,f2):
    """Return the composition ``f1 ∘ f2``.

    The returned callable forwards its positional arguments to ``f2`` and
    passes the single result on to ``f1``.
    """
    def composed(*args):
        inner_result = f2(*args)
        return f1(inner_result)

    return composed