import tensorflow as tf
import numpy as np
from typing import Optional, Callable


def read_png(filename, channels=3):
  """Read an image file and decode it into a tensor.

  Decoding goes through tf.image.decode_image, so besides PNG it also accepts the other formats that function
  supports (note: per its defaults, animated GIFs decode to a 4-D stack of frames).

  :param filename: path (string or string tensor) of the image file.
  :param channels: number of color channels in the decoded image; defaults to 3 (RGB).
  :return: the decoded image tensor, in decode_image's default (uint8) dtype.
  """
  encoded = tf.io.read_file(filename)
  decoded = tf.image.decode_image(encoded, channels=channels)
  return decoded


def write_png(filename, image):
  """Encode `image` as PNG and write the bytes to `filename`.

  :param filename: destination path.
  :param image: image tensor accepted by tf.image.encode_png (e.g. [height, width, channels] uint8).
  """
  png_bytes = tf.image.encode_png(image)
  tf.io.write_file(filename, png_bytes)


def quantize_image(image):
  """Round pixel values to integers and convert to uint8.

  tf.round rounds half to even; the saturating cast clips anything outside [0, 255] instead of wrapping around.
  """
  rounded = tf.round(image)
  return tf.saturate_cast(rounded, tf.uint8)


def center_crop_image(image, target_height, target_width):
  """Crop the central target_height x target_width window out of a channels-last image.

  Adapted from Keras' CenterCrop preprocessing layer:
  https://github.com/keras-team/keras/blob/v2.10.0/keras/layers/preprocessing/image_preprocessing.py#L202

  :param image: image tensor whose height and width are its 3rd- and 2nd-to-last axes.
  :param target_height: height of the crop; must not exceed the image height.
  :param target_width: width of the crop; must not exceed the image width.
  :return: the cropped image, from tf.image.crop_to_bounding_box.
  """
  image_shape = tf.shape(image)
  # Channels-last layout: height is axis -3, width is axis -2.
  excess_height = image_shape[-3] - target_height
  excess_width = image_shape[-2] - target_width

  # Fail loudly if the requested window is larger than the image.
  for excess in (excess_height, excess_width):
    tf.debugging.assert_greater_equal(excess, 0)

  # Half of the (non-negative) excess goes above/left; an odd leftover pixel ends up below/right.
  top = tf.cast(excess_height / 2, tf.int32)
  left = tf.cast(excess_width / 2, tf.int32)
  return tf.image.crop_to_bounding_box(
    image, offset_height=top, offset_width=left, target_height=target_height, target_width=target_width)


def mse_psnr(x, y, max_val=255.):
  """Compute per-image MSE and PSNR between two image tensors.

  :param x: image tensor (or array-like) whose first axis is treated as the batch axis; cast to float32.
  :param y: second image tensor, same layout as x (must broadcast against it); cast to float32.
  :param max_val: maximum possible pixel value, e.g. 255 for 8-bit images or 1 for images in [0, 1]. May be an
    int or float; it is cast to float32.
  :return: (mses, psnrs), each of shape [batch_size]: the mean squared error over all non-batch axes, and
    PSNR = 10 * log10(max_val ** 2 / mse) in dB (inf where the two images are identical).
  """
  x = tf.cast(x, tf.float32)
  y = tf.cast(y, tf.float32)
  # tf.math.log has no integer kernel, so an int max_val (e.g. 255 or 1) must be cast before taking its log.
  max_val = tf.cast(max_val, tf.float32)

  squared_diff = tf.math.squared_difference(x, y)
  # Average over every axis except the leading (batch) one; tf.range also works when the rank is only known at runtime.
  axes_except_batch = tf.range(1, tf.rank(squared_diff))
  mses = tf.reduce_mean(squared_diff, axis=axes_except_batch)  # per img
  # 10 * log10(max_val^2 / mse), written with natural logs.
  psnrs = -10 * (tf.math.log(mses) - 2 * tf.math.log(max_val)) / tf.math.log(10.)
  return mses, psnrs


def maybe_pad_img(x, div: int, padding_mode='reflect', padding_around='center'):
  """
  Pad a single image so that its height and width are divisible by `div`.

  padding_mode='interpolate' needs eager execution: it round-trips through NumPy to paste x into the resized image.
  :param x: a single image, shape [height, width, channels].
  :param div: positive int; the padded height/width are the smallest multiples of div that are >= the original ones.
  :param padding_mode: 'interpolate' -- bicubically resize x to the padded size, then overwrite the region at
    `offset` with the exact pixels of x, so the padding holds the border of a stretched copy of x (only supports
    padding_around='center'); or 'reflect' / 'symmetric' -- used as the mode of tf.pad. tf.pad requires every
    padding amount to be at most dim_size - 1 ('reflect') or dim_size ('symmetric').
  :param padding_around: 'center' -- split the padding over both sides of each axis (the extra pixel of an odd
    amount goes to the bottom/right); or 'bottom_right' -- put all padding after the image.
  :return: x_padded, offset; x_padded is a potentially padded version of x whose height and width are divisible by
    div, and offset is an int32 [2] tensor such that
    x_padded[offset[0]: (offset[0] + x_size[0]), offset[1]:(offset[1] + x_size[1])] == x.
    When no padding is needed, x itself is returned with offset [0, 0].
  """
  # Validate every argument up front, so bad arguments are rejected even when x happens to need no padding.
  assert len(x.shape) == 3, 'must be a single RGB image'
  assert padding_mode in ('interpolate', 'reflect', 'symmetric')
  assert padding_around in ('center', 'bottom_right')
  if padding_mode == 'interpolate':
    assert padding_around == 'center', "padding_mode='interpolate' only supports padding_around='center'"
  assert div > 0, 'div must be positive'

  x_size = tf.shape(x)[:2]  # img size, [height, width] int32
  # Round each spatial dim up to a multiple of div with integer ceil-division, e.g. [768, 512] -> [800, 600] for div=100.
  padded_size = (x_size + div - 1) // div * div
  if tf.reduce_all(padded_size == x_size):  # special case, no need for padding
    return x, tf.constant([0, 0], dtype=tf.int32)

  slack = padded_size - x_size  # e.g., [800, 600] - [768, 512] = [32, 88]
  # offset as in the top left corner of the crop; https://www.tensorflow.org/api_docs/python/tf/image/crop_to_bounding_box
  if padding_around == 'center':
    offset = slack // 2  # e.g., [16, 44]
  else:  # 'bottom_right'
    offset = tf.zeros_like(slack)

  if padding_mode == 'interpolate':
    # First stretch the image to the target size, then put the exact pixels of x back at `offset`. Casting back to
    # x's own dtype (clipping any bicubic overshoot) keeps x_padded and x comparable for non-uint8 inputs too.
    x_padded = tf.image.resize(x, padded_size, method='bicubic', preserve_aspect_ratio=False,
                               antialias=True)
    x_padded = tf.saturate_cast(x_padded, dtype=x.dtype)
    x_padded = x_padded.numpy()  # to get around tf tensor not supporting assignment
    top, left = offset.numpy()
    height, width = x_size.numpy()
    x_padded[top:top + height, left:left + width] = np.asarray(x)
    x_padded = tf.convert_to_tensor(x_padded)
  else:  # use tf.pad implementation
    # One [before, after] row per axis; the trailing channel axis is never padded.
    spatial_paddings = tf.stack([offset, slack - offset], axis=1)
    paddings = tf.concat([spatial_paddings, tf.zeros([1, 2], dtype=tf.int32)], axis=0)
    x_padded = tf.pad(x, paddings, padding_mode)

  # Sanity check of the documented contract: x sits unchanged at `offset` inside x_padded.
  assert tf.reduce_all(
    x_padded[offset[0]: (offset[0] + x_size[0]), offset[1]:(offset[1] + x_size[1])] == x), 'x not found at offset'
  return x_padded, offset


# TODO: consider unifying this with the padding code in maybe_pad_img, so that one is the inverse of the other.
# Perhaps just pad with padding_around="bottom_right" there.
def reshape_spatially_as(x, y):
  """
  Crop x spatially to the height and width of y.

  Used by the decoder for decompression to crop away extraneous padding from upsampled tfc.SignalConv2D outputs.
  Both tensors are channels-last. Height and width are read from axes -3 and -2, so this works for batched
  [batch, height, width, channels] tensors as well as for single [height, width, channels] ones, and x and y may
  differ in rank. Only cropping happens: an axis on which x is already smaller than y is left unchanged.
  :param x: tensor to be cropped, [..., height, width, channels]
  :param y: tensor whose spatial size is the target, [..., height, width, channels]
  :return: x[..., :height_of_y, :width_of_y, :]
  """
  y_shape = tf.shape(y)
  # Index from the end so a 3-D y yields (height, width), not (width, channels) as fixed axes 1 and 2 would.
  return x[..., :y_shape[-3], :y_shape[-2], :]
