import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

import tensorflow as tf
from tensorflow.image import total_variation

from .constants import METRICS


def _scaled_total_variation(img, max_val):
    """Return TensorFlow's total variation of a 2-D image after dividing it by ``max_val``."""
    # tf.image.total_variation expects a 3-D (H, W, C) or 4-D tensor, so append a
    # singleton channel axis to the 2-D image before converting it to a tensor.
    scaled = img[..., np.newaxis] / float(max_val)
    return total_variation(tf.convert_to_tensor(scaled, dtype=tf.float32)).numpy()


def calc_image_metrics(gt, inp, max_val=1.0,
                       target_metrics=("l1", "l2", "psnr", "ssim", "tv_gt", "tv_in")):
    """Compute the requested quality metrics between a reference and a test image.

    Args:
        gt: Reference ("ground truth") image; must be 2-D.
        inp: Image to evaluate; must have the same shape and dtype as ``gt``.
        max_val: Data range of the images. It is passed to PSNR/SSIM as
            ``data_range`` and used to normalise L1, L2 and the TV values.
        target_metrics: Iterable of metric names to compute; every name must be
            in ``METRICS``. Duplicates are collapsed and first-seen order is kept.

    Returns:
        dict mapping each requested metric name to its value:
        ``"l1"`` mean absolute error / max_val; ``"l2"`` mean squared error /
        max_val**2; ``"psnr"`` and ``"ssim"`` from scikit-image; ``"tv_gt"`` /
        ``"tv_in"`` TensorFlow total variation of gt / inp scaled by 1/max_val.
        A name present in ``METRICS`` but not handled below maps to ``None``.

    Raises:
        AssertionError: on unknown metric names, non-2-D or mismatched shapes,
            or mismatched dtypes.
    """
    assert set(target_metrics).issubset(METRICS), (
        f"unknown metrics requested: {set(target_metrics) - set(METRICS)}")
    assert gt.shape == inp.shape and len(gt.shape) == 2, (
        f"expected two 2-D images of equal shape, got {gt.shape} and {inp.shape}")
    assert gt.dtype == inp.dtype, (
        f"gt and inp must share a dtype, got {gt.dtype} and {inp.dtype}")

    # Work in float32 so integer inputs (e.g. uint8) do not wrap on subtraction.
    gt = gt.astype(np.float32)
    inp = inp.astype(np.float32)

    # dict.fromkeys de-duplicates names while preserving request order.
    metrics = dict.fromkeys(target_metrics)

    # Only existing keys are reassigned, so iterating the dict here is safe.
    for name in metrics:
        if name == "l1":
            metrics[name] = np.mean(np.abs(gt - inp)) / max_val
        elif name == "l2":
            metrics[name] = np.mean((gt - inp) ** 2) / (max_val * max_val)
        elif name == "psnr":
            metrics[name] = peak_signal_noise_ratio(gt, inp, data_range=max_val)
        elif name == "ssim":
            metrics[name] = structural_similarity(
                gt, inp,
                data_range=max_val,
                channel_axis=None,  # no channel axis: inputs are single 2-D images
                win_size=11,
                gaussian_weights=True,
                K1=0.01,
                K2=0.03,
                sigma=1.5,
            )
        elif name == "tv_gt":
            metrics[name] = _scaled_total_variation(gt, max_val)
        elif name == "tv_in":
            metrics[name] = _scaled_total_variation(inp, max_val)
    return metrics
