from skimage.metrics import structural_similarity as SS
import torch


def PSNR(original, compressed, *, max_pixel=1.0):
    """Peak signal-to-noise ratio (in dB) of ``compressed`` w.r.t. ``original``.

    Args:
        original: reference tensor.
        compressed: tensor to compare; must broadcast against ``original``.
        max_pixel: peak signal value (keyword-only). Defaults to ``1.0`` for
            data scaled to [0, 1]; pass e.g. ``255.0`` for 8-bit ranges. It
            is keyword-only so ``torch.vmap(PSNR)`` keeps mapping over the two
            tensors and forwards it unbatched.

    Returns:
        0-d tensor ``10 * log10(max_pixel**2 / MSE)``. When the inputs are
        identical the MSE is 0 and the result is ``inf``.
    """
    mse = torch.mean((original - compressed) ** 2)
    return 10 * torch.log10(max_pixel ** 2 / mse)


def RMSE(original, compressed):
    """Root-mean-square error between ``original`` and ``compressed``.

    Returns a 0-d tensor: ``sqrt(mean((original - compressed) ** 2))``.
    """
    residual = original - compressed
    return residual.pow(2).mean().sqrt()

def ZeroOne(original, compressed):
    """Average zero-one loss; predictions are rounded and clipped to {0, 1}.

    ``compressed`` is rounded to the nearest integer and then clamped to
    [0, 1]. Without the clamp, predictions such as 2.0 or -1.0 would round to
    2 or -1 and contribute more than 1 to the loss. The result is
    ``sum(|original - prediction|) / original.numel()``, returned as a 0-d
    tensor.
    """
    hard_pred = torch.clamp(torch.round(compressed), 0, 1)
    return torch.sum(torch.abs(original - hard_pred)) / original.numel()


# Batched variants of the per-sample metrics: torch.vmap (default in_dims=0)
# maps each metric over the leading dimension of both input tensors and
# stacks the per-sample results.
b_PSNR, b_RMSE, b_ZeroOne = (torch.vmap(metric) for metric in (PSNR, RMSE, ZeroOne))

def get_metrics(original, compressed):
    """Compute all image-quality metrics for one (original, compressed) pair.

    Returns a dict of Python scalars with keys, in order: "PSNR", "RMSE",
    "L1" (mean absolute error), "ZeroOne" and "SS" (mean structural
    similarity from scikit-image, computed with ``data_range=1``).
    """
    raw = {
        "PSNR": PSNR(original, compressed),
        "RMSE": RMSE(original, compressed),
        "L1": torch.abs(original - compressed).mean(),
        "ZeroOne": ZeroOne(original, compressed),
    }
    # scikit-image works on NumPy arrays, so leave autograd/device first.
    reference_np = original.detach().cpu().numpy()
    candidate_np = compressed.detach().cpu().numpy()
    ssim_result = SS(reference_np, candidate_np, full=True, data_range=1)
    raw["SS"] = ssim_result[0]  # mean SSIM; index 1 is the full SSIM map
    return {name: value.item() for name, value in raw.items()}

def print_metrics(original, compressed):
    """Print every metric from ``get_metrics`` as ``name value``, one per line."""
    results = get_metrics(original, compressed)
    for metric_name in results:
        print(metric_name, results[metric_name])