"""Image restoration metrics."""
import numpy as np
from skimage.metrics import structural_similarity as ssim_fn
from skimage.metrics import peak_signal_noise_ratio as psnr_fn

def _compute_metric(metric_fn, est_images, true, data_range, metric):
  """Apply a per-image reference metric to every image in a batch.

  Args:
    metric_fn: Callable invoked as
      ``metric_fn(true_im, est_im, data_range=..., **extra)``.
    est_images: Batch of estimated images, shape (N, H, W, C).
    true: Reference image(s). Either a single (H, W, C) image shared by the
      whole batch, or a batch of shape (N, H, W, C) or (1, H, W, C).
    data_range: Dynamic range forwarded to ``metric_fn``. If None, it is
      computed per reference image as max - min, in float64.
    metric: Metric name. Only 'ssim' receives ``channel_axis=-1`` for
      multi-channel images; other metrics get the full (H, W, C) arrays.

  Returns:
    1-D array with one ``metric_fn`` result per estimated image.

  Raises:
    ValueError: If ``true`` is a batch whose size is neither 1 nor N.
  """
  est_images = np.asarray(est_images)
  true_images = np.asarray(true)
  n = est_images.shape[0]

  if true_images.ndim == 3:
    # A single reference image: add a batch axis so it is shared below.
    true_images = true_images[np.newaxis]
  if true_images.shape[0] not in (1, n):
    # Without this check zip() below would silently drop images.
    raise ValueError(
        f'Got {true_images.shape[0]} reference images for {n} estimates.')
  # Share a single reference across the batch without copying it N times.
  true_images = np.broadcast_to(true_images, (n,) + true_images.shape[1:])

  if data_range is not None:
    data_ranges = [data_range] * n
  else:
    # Cast before subtracting so integer dtypes cannot wrap around
    # (e.g. int8: 127 - (-128)), and so metric_fn gets a float range.
    axes = (1, 2, 3)
    data_ranges = (true_images.max(axis=axes).astype(np.float64)
                   - true_images.min(axis=axes).astype(np.float64))

  extra_kwargs = {}
  if est_images.shape[-1] == 1:
    # Drop the singleton channel axis so metric_fn sees 2-D images.
    true_images = true_images[..., 0]
    est_images = est_images[..., 0]
  elif metric == 'ssim':
    # SSIM must be told which axis holds the channels.
    extra_kwargs['channel_axis'] = -1

  return np.array([
      metric_fn(true_im, est_im, data_range=rng, **extra_kwargs)
      for true_im, est_im, rng in zip(true_images, est_images, data_ranges)
  ])

def compute_mse(est, true):
  """MSE of each estimated image.

  Args:
    est: Batch of estimated images. The mean is taken over axes 1-3, so a 4-D
      array such as (N, H, W, C) is expected.
    true: Reference image(s), broadcast against ``est``. For example, a single
      (H, W, C) image or a batch with the same shape as ``est``.

  Returns:
    1-D array with the mean squared error of each image in ``est``.
  """
  est = np.asarray(est)
  true = np.asarray(true)
  if not np.issubdtype(np.result_type(est, true), np.inexact):
    # Integer/bool inputs: promote before subtracting. Otherwise unsigned
    # differences wrap around (uint8: 0 - 20 -> 236) and squares overflow.
    est = est.astype(np.float64)
    true = true.astype(np.float64)
  return np.mean(np.square(est - true), axis=(1, 2, 3))

def compute_ssim(est, true, data_range=None):
  """Structural similarity (SSIM) of every image in a batch of estimates.

  Args:
    est: Batch of estimated images, (N, H, W, C).
    true: Reference image(s). Either a single (H, W, C) image shared by the
      batch, or a batch aligned with ``est``.
    data_range: Dynamic range forwarded to SSIM. When None, it is derived as
      max - min of each reference image.

  Returns:
    Array holding one SSIM value per estimated image.
  """
  return _compute_metric(
      metric_fn=ssim_fn,
      est_images=est,
      true=true,
      data_range=data_range,
      metric='ssim',
  )
  
def compute_psnr(est, true, data_range=None):
  """Peak signal-to-noise ratio (PSNR) of every image in a batch of estimates.

  Args:
    est: Batch of estimated images, (N, H, W, C).
    true: Reference image(s). Either a single (H, W, C) image shared by the
      batch, or a batch aligned with ``est``.
    data_range: Dynamic range forwarded to PSNR. When None, it is derived as
      max - min of each reference image.

  Returns:
    Array holding one PSNR value per estimated image.
  """
  return _compute_metric(
      metric_fn=psnr_fn,
      est_images=est,
      true=true,
      data_range=data_range,
      metric='psnr',
  )
