import numpy as np
import cv2

from .constants import COEF_8
from .fourier_utils import calculate_2dft, calculate_2dift


def check_image(img):
    """Validate an image's layout and drop a trailing singleton channel axis.

    Args:
        img: array with a ``shape`` attribute, laid out as (H, W) or
            (H, W, C) with C in {1, 3}.

    Returns:
        ``img`` itself when it is (H, W) or (H, W, 3); the (H, W) slice
        ``img[:, :, 0]`` when it is (H, W, 1).

    Raises:
        ValueError: if ``img`` is neither 2-D nor 3-D, or if a 3-D image has
            a channel count other than 1 or 3.
    """
    n_dims = len(img.shape)

    if n_dims == 2:
        return img
    if n_dims != 3:
        raise ValueError("def check_image(...): "
                         "error: n_dims of an image should be equal to 2 (H, W) or 3 (H, W, C), "
                         "but found: {}.".format(n_dims))

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    n_channels = img.shape[-1]
    if n_channels not in (1, 3):
        raise ValueError("def check_image(...): "
                         "error: an (H, W, C) image should have C equal to 1 or 3, "
                         "but found: {}.".format(n_channels))

    if n_channels == 1:
        img = img[:, :, 0]
    return img


def maxmin_norm(img):
    """Linearly rescale ``img`` so its minimum maps to 0 and its maximum to 1.

    Args:
        img: numeric array (any dtype); it is converted to float32 first, so
            integer inputs do not overflow or truncate.

    Returns:
        float32 array of the same shape. A constant image (max == min) has
        no range to stretch and is mapped to all zeros instead of NaN.
    """
    img = np.asarray(img, dtype=np.float32)
    lo = np.amin(img)
    span = np.amax(img) - lo
    if span == 0:
        # Avoid 0/0 -> NaN (and a RuntimeWarning) for flat images.
        return np.zeros_like(img)
    return (img - lo) / span


def load(imname, normalize=False):
    """Read an 8-bit image file and return it as a single-channel array.

    Args:
        imname: path passed to ``cv2.imread``.
        normalize: if true, convert to float32 and divide by ``COEF_8``
            (presumably the 8-bit maximum, 255 -- confirm in ``.constants``).

    Returns:
        (H, W) array: uint8 when ``normalize`` is false, float32 otherwise.
        3-channel images are converted to grayscale with ``COLOR_BGR2GRAY``.

    Raises:
        OSError: if OpenCV cannot read the file (``cv2.imread`` returns None).
        ValueError: if the stored image is not uint8.
        Also propagates whatever ``check_image`` raises for unsupported
        shapes (e.g. a 4-channel image with an alpha plane).
    """
    # IMREAD_UNCHANGED (== -1) keeps the stored bit depth and channel count,
    # so the uint8 check below sees the file's real dtype.
    img = cv2.imread(imname, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise OSError("def load(...): could not read image: {}".format(imname))
    if img.dtype != np.uint8:
        raise ValueError("def load(...): expected a uint8 image, "
                         "but found: {}.".format(img.dtype))
    img = check_image(img)

    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    if normalize:
        img = img.astype(np.float32) / COEF_8

    return img


def save(img, imname):
    """Write an image with values in [0, 1] to disk as 8-bit.

    Args:
        img: numeric array. Values are clipped to [0, 1], scaled by
            ``COEF_8`` (presumably 255 -- confirm in ``.constants``) and
            rounded to the nearest integer.
        imname: destination path passed to ``cv2.imwrite``. The file format
            is chosen by OpenCV from the extension.

    Raises:
        OSError: if ``cv2.imwrite`` reports failure (it returns False rather
            than raising).
    """
    img = np.clip(np.asarray(img, dtype=np.float32), 0., 1.)
    # Round instead of truncating. `astype(np.uint8)` floors, which biases every
    # pixel downward (e.g. 0.999 -> 254 instead of 255) and can turn a value
    # that is exactly k / COEF_8 in float32 into k - 1. The outer clip keeps the
    # rounded result inside the uint8 range regardless of COEF_8.
    img = np.clip(np.rint(img * COEF_8), 0, 255).astype(np.uint8)
    if not cv2.imwrite(imname, img):
        raise OSError("def save(...): could not write image: {}".format(imname))


def get_corrupted_image(image,
                        lowpass_filter, ift="abs", LR=False, centre=None, radius=None,  # Gibbs
                        sigma=1e-2):  # Gauss
    """Simulate a degraded observation of ``image`` (Gibbs ringing + Gaussian noise).

    The spectrum from ``calculate_2dft(image)`` is multiplied by
    ``lowpass_filter``. When ``LR`` is true, the filtered spectrum is then
    cropped to the ``(2*radius + 1)``-wide square window centred on index
    ``centre`` along both axes. The result is inverted with
    ``calculate_2dift`` and, if ``sigma > 0``, zero-mean Gaussian noise drawn
    from NumPy's global RNG is added.

    Args:
        image: 2-D square array with an odd side length.
        lowpass_filter: multiplied elementwise with the spectrum; must
            broadcast against it.
        ift: ``mode`` forwarded to ``calculate_2dift``.
        LR: if true, crop the spectrum, producing a smaller output.
        centre: crop-window centre index (same for both axes). Required when
            ``LR`` is true.
        radius: crop-window half-width. Required when ``LR`` is true.
        sigma: noise standard deviation; no noise is added when ``sigma <= 0``.

    Returns:
        float32 array: (h0, h0) when ``LR`` is false, otherwise
        (2*radius + 1, 2*radius + 1).

    Raises:
        ValueError: if ``image`` is not a 2-D square with odd side, if ``LR``
            is requested without ``centre``/``radius``, or if the crop window
            does not fit inside the spectrum.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if len(image.shape) != 2:
        raise ValueError("def get_corrupted_image(...): image should be 2-D, "
                         "but found shape: {}.".format(image.shape))
    if image.shape[0] != image.shape[1]:
        raise ValueError("def get_corrupted_image(...): image should be square, "
                         "but found shape: {}.".format(image.shape))
    if image.shape[0] % 2 == 0:
        raise ValueError("def get_corrupted_image(...): image side should be odd, "
                         "but found: {}.".format(image.shape[0]))
    if LR and (centre is None or radius is None):
        raise ValueError("def get_corrupted_image(...): LR=True requires both "
                         "centre and radius.")

    h0 = image.shape[0]  # initial height

    ft_lowpass = calculate_2dft(image) * lowpass_filter
    if LR:
        start, stop = centre - radius, centre + radius + 1
        # A negative start or an end past the edge would silently mis-slice
        # (wrap-around / short window) instead of failing.
        if radius < 0 or start < 0 or stop > ft_lowpass.shape[0]:
            raise ValueError("def get_corrupted_image(...): crop window "
                             "[{}, {}) does not fit a spectrum of size {}."
                             .format(start, stop, ft_lowpass.shape[0]))
        ft_lowpass = ft_lowpass[start:stop, start:stop]
    h1 = ft_lowpass.shape[0]  # height of the corrupted image

    # NOTE(review): the h1^2 / h0^2 factor presumably compensates for the
    # inverse transform's 1/N normalisation when the spectrum shrinks from
    # h0 x h0 to h1 x h1, keeping intensities comparable -- confirm against
    # calculate_2dift.
    sketch = calculate_2dift(ft_lowpass, mode=ift) * float(h1 * h1) / float(h0 * h0)

    if sigma > 0:
        sketch = sketch + np.random.normal(0.0, sigma, sketch.shape)

    return sketch.astype(np.float32)
