from pysepm import SNRseg, fwSNRseg, pesq

class AudioMetrics:
    """Compute batch-averaged speech-quality metrics (pysepm SNRseg / fwSNRseg).

    Each metric in ``self.aqa`` is called once per batch item as
    ``metric(x_i, y_i, self.sr)``. The per-item scores are then averaged.
    """

    def __init__(self, sr=16000):
        """Create the metric suite.

        Args:
            sr: Sampling rate (presumably Hz) passed as the third argument
                to every metric.
        """
        self.sr = sr
        # Metric name -> callable(x_i, y_i, sr) returning a scalar score.
        self.aqa = {
            'snrseg': SNRseg,
            'fwsnrseg': fwSNRseg,
        }

    def batch_metric(self, metric, x, y):
        """Return the mean of ``metric`` over element-wise pairs of ``x`` and ``y``.

        Args:
            metric: Callable invoked as ``metric(x_i, y_i, self.sr)``.
            x: Sequence of signals passed as the first metric argument
                (presumably the clean reference, following pysepm's
                argument order -- confirm against callers).
            y: Sequence of signals paired element-wise with ``x``.

        Returns:
            The arithmetic mean of the per-item scores.

        Raises:
            ValueError: If ``x`` and ``y`` differ in length, or are empty.
        """
        # zip() would silently drop unmatched items and skew the average.
        if len(x) != len(y):
            raise ValueError(
                f'batch size mismatch: x has {len(x)} items, y has {len(y)}'
            )
        # Without this guard an empty batch fails with ZeroDivisionError.
        if len(x) == 0:
            raise ValueError('cannot average a metric over an empty batch')
        scores = [metric(cur_x, cur_y, self.sr) for cur_x, cur_y in zip(x, y)]
        return sum(scores) / len(scores)

    def report(self, x, y):
        """Compute every metric in ``self.aqa`` on a batch of signals.

        Args:
            x: 2-D array of shape (bs, L) with values in [0, 1]. It must
                expose ``.ndim``, ``.shape``, ``.max()`` and ``.min()``
                (presumably a numpy array or torch tensor).
            y: 2-D array of shape (bs, L) with values in [0, 1], paired
                row-wise with ``x``.

        Returns:
            dict mapping each metric name to its batch-averaged score.

        Raises:
            ValueError: If either input is not 2-D, is empty, or has values
                outside [0, 1], or if the two batch sizes differ.
        """
        # Explicit raises instead of ``assert`` so that validation still
        # runs under ``python -O``.
        for name, audio in (('x', x), ('y', y)):
            if audio.ndim != 2:
                raise ValueError(f'audio {name} is not in the correct shape (bs, L)')
            # .max()/.min() on an empty array fail with an opaque error.
            if 0 in audio.shape:
                raise ValueError(f'audio {name} is empty, got shape {tuple(audio.shape)}')
            if not (audio.max() <= 1 and 0 <= audio.min()):
                raise ValueError(f'audio {name} is not in the correct range of [0, 1]')
        if x.shape[0] != y.shape[0]:
            raise ValueError(
                f'batch size mismatch: x has {x.shape[0]} rows, y has {y.shape[0]}'
            )

        return {name: self.batch_metric(metric, x, y) for name, metric in self.aqa.items()}