import os
import math
import random
import numpy as np
import torch
import cv2
from hashlib import sha256


# ----------------------------------------
# Read an image of size HxWxn_channels (RGB)
# ----------------------------------------
def imread_uint(path, n_channels=3):
    """Read an image from ``path`` as a NumPy array in RGB channel order.

    Args:
        path: Image file path, passed to ``cv2.imread``.
        n_channels: 1 to load as grayscale (HxWx1), or 3 to load as RGB
            (HxWx3; a single-channel file is replicated to GGG).

    Returns:
        The decoded image: HxWx1 when ``n_channels == 1``, otherwise HxWx3.

    Raises:
        ValueError: if ``n_channels`` is not 1 or 3.
        OSError: if OpenCV cannot read or decode the file.
    """
    # Validate up front; previously an unsupported value fell through both
    # branches and failed with UnboundLocalError on ``return img``.
    if n_channels not in (1, 3):
        raise ValueError(f"n_channels must be 1 or 3, got {n_channels!r}")

    if n_channels == 1:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        # IMREAD_UNCHANGED keeps the file's own layout, so a grayscale file
        # comes back 2-D. NOTE(review): it can also return non-8-bit data
        # (e.g. 16-bit files) -- confirm inputs are 8-bit.
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

    # cv2.imread signals a missing/undecodable file by returning None instead
    # of raising; fail here with a clear message rather than on a later
    # attribute access.
    if img is None:
        raise OSError(f"cannot read image file: {path}")

    if n_channels == 1:
        return np.expand_dims(img, axis=2)  # HxW -> HxWx1
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR


def imsave(img, img_path):
    """Write an RGB (or grayscale) image array to ``img_path`` with OpenCV.

    Singleton dimensions are squeezed first, so e.g. HxWx1 or 1xHxWxC inputs
    are accepted. A 3-D result is treated as RGB and reordered to the BGR
    order ``cv2.imwrite`` expects; a 2-D result is written as-is.

    Args:
        img: Image array (RGB channel order when it has a channel axis).
        img_path: Destination file path.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the write failed.
    """
    img = np.squeeze(img)
    if img.ndim == 3:
        # RGB -> BGR; only the first three channels are kept, so any alpha
        # channel is dropped.
        img = img[:, :, [2, 1, 0]]
    # cv2.imwrite reports some failures (e.g. an unwritable path) by
    # returning False rather than raising; don't let that pass silently.
    if not cv2.imwrite(img_path, img):
        raise OSError(f"failed to write image file: {img_path}")


# convert an 8-bit image (HxW or HxWxn_channels) to a 1xCxHxW torch tensor
def uint2tensor4(img, data_range):
    """Turn an image array into a 4-D float tensor scaled to ``data_range``.

    Args:
        img: HxW or HxWxC array, presumably uint8 with values in 0..255.
        data_range: The value that pixel 255 is mapped to (e.g. 1.0 or 255.0).

    Returns:
        A float32 tensor of shape 1xCxHxW equal to ``img / (255 / data_range)``.
    """
    if img.ndim == 2:
        # give a grayscale HxW image an explicit channel axis: HxW -> HxWx1
        img = img[:, :, np.newaxis]
    chw = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
    scaled = chw.float().div(255. / data_range)
    return scaled.unsqueeze(0)


# convert a torch tensor to a uint8 image array
def tensor2uint(img, data_range):
    """Convert an image tensor scaled to ``[0, data_range]`` into a uint8 array.

    The input tensor is not modified.

    Args:
        img: Image tensor (e.g. 1xCxHxW or CxHxW); singleton dimensions are
            squeezed away. It may live on any device.
        data_range: The value that maps to pixel 255; values are clamped to
            ``[0, data_range]`` before conversion.

    Returns:
        A ``np.uint8`` array: HxWxC when the squeezed tensor is 3-D (channel
        axis moved last), otherwise the squeezed shape (e.g. HxW).
    """
    # Use detach() rather than the legacy .data, and clamp() rather than
    # clamp_(). For float32 input .float() returns the very same tensor, and
    # squeeze() is a view, so an in-place clamp would overwrite the caller's
    # tensor through the shared storage.
    img = img.detach().squeeze().float().clamp(0, data_range).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))  # CxHxW -> HxWxC
    return np.uint8((img * 255.0 / data_range).round())


def add_noise(img_gt, noise_level, img_name):
    """Generate reproducible Gaussian noise for an image.

    Despite the name, only the noise is returned. ``img_gt`` is used solely
    for its shape, so the caller adds the result to the image. The RNG is
    seeded from the SHA-256 digest of the file's base name. The same name
    therefore always yields the same noise, whatever directory it is in.

    Args:
        img_gt: Array whose ``shape`` the noise takes.
        noise_level: Standard deviation on the 0-255 scale; it is divided by
            255 before sampling.
        img_name: Image path or file name; only its basename is used.

    Returns:
        An array of shape ``img_gt.shape`` sampled from a normal distribution
        with mean 0 and standard deviation ``noise_level / 255``.
    """
    sigma = noise_level / 255
    name = os.path.basename(img_name)
    # Explicit little-endian "<u4" instead of native-order "uint32" keeps the
    # seed, and hence the noise, identical on big- and little-endian hosts.
    # On little-endian machines this reproduces the previous values exactly.
    seed = np.frombuffer(sha256(name.encode("utf-8")).digest(), dtype="<u4")
    rstate = np.random.RandomState(seed)
    return rstate.normal(0, sigma, img_gt.shape)

