import torch
from  torch import nn
import numpy as np

class PSNR(nn.Module):
    """Peak signal-to-noise ratio (in dB) between two image tensors.

    Both inputs are mapped with ``(x + 1) * 127.5`` and clamped to
    ``[0, 255]`` before comparison, so they are presumably in ``[-1, 1]``
    (TODO confirm with callers). The MSE is averaged over *all* elements of
    the inputs; for a batch this is the PSNR of the pooled error, not the
    mean of per-image PSNRs.

    Args:
        max_val: Peak possible pixel value of the mapped images (the mapping
            above produces values up to 255).
    """

    def __init__(self, max_val):
        super().__init__()
        # as_tensor avoids the copy-construct warning when max_val is already a tensor.
        max_val = torch.as_tensor(max_val, dtype=torch.float32)
        # The buffer keeps its historical name 'max_val' so existing state_dicts
        # still load, but it holds the precomputed term 20 * log10(max_val).
        self.register_buffer('max_val', 20 * torch.log10(max_val))

    def forward(self, a, b):
        """Return the PSNR of ``a`` against ``b`` as a 0-dim float tensor.

        Identical inputs give ``mse == 0``; ``log10(0)`` is ``-inf``, so the
        result is ``+inf``, the mathematically correct PSNR. No data-dependent
        branch is needed, which also avoids a host/device sync.
        """
        a = torch.clamp((a + 1) * 127.5, 0, 255)
        b = torch.clamp((b + 1) * 127.5, 0, 255)
        mse = torch.mean((a.float() - b.float()) ** 2)
        return self.max_val - 10 * torch.log10(mse)

def MAE(a, b):
    """Return the normalized absolute error between two images.

    Despite the name, this is not the plain mean absolute error. Both inputs
    are mapped with ``(x + 1) * 127.5`` (``[-1, 1]`` -> ``[0, 255]``), and the
    summed absolute difference is divided by the summed intensities of both
    images: ``sum(|x - y|) / sum(x + y)``.

    Args:
        a: First image as a NumPy array (presumably values in ``[-1, 1]`` —
            TODO confirm with callers).
        b: Second image, same shape as ``a`` (or broadcastable to it).

    Returns:
        A NumPy float scalar. Returns ``0`` for identical all-black images,
        where the ratio would otherwise be ``0 / 0 = nan``.
    """
    img1 = (a + 1) * 127.5
    img2 = (b + 1) * 127.5
    numerator = np.sum(np.abs(img1 - img2))
    denominator = np.sum(img1 + img2)
    if denominator == 0 and numerator == 0:
        # Identical all-black images: zero error, not 0/0 = nan.
        return numerator
    return numerator / denominator
    