import uuid
import numpy as np
from PIL import Image
import random
from scipy import stats
# Seeds only Python's stdlib `random` module (used below by `random.choice`
# when picking the Tukey-lambda shape). The NumPy/SciPy draws in this module
# (`np.random.poisson`, `stats.*.rvs`) use NumPy's global RNG, which is not
# seeded in this section, so results are not reproducible from this seed alone.
random.seed(23)

def show_rgb_array(im, wb, sdir):
    """Save a float image to ``sdir`` as an 8-bit image.

    Args:
        im: float array; values outside [0, 1] are clipped before conversion.
        wb: unused; kept so existing callers keep working.
        sdir: output file path handed to PIL's ``Image.save``.
    """
    out = np.clip(im, 0, 1)
    # Round to the nearest 8-bit code rather than truncating. A bare uint8 cast
    # floors, which biases every pixel darker by up to one code value and turns
    # float round-off just below an integer (e.g. 254.99998) into the code below.
    Image.fromarray(np.round(out * 255).astype(np.uint8)).save(sdir)

def showArray(im, wb, outdir, tp='Sony'):
    """Save ``im`` to ``outdir`` for a supported camera type.

    Args:
        im: float image array, forwarded to ``show_rgb_array``.
        wb: forwarded unchanged to ``show_rgb_array``.
        outdir: output file path.
        tp: camera type; only "Sony" and "T1pro" are accepted. Any other value
            prints the accepted types and writes nothing.
    """
    # Both supported types currently share the same RGB writer. The former
    # unused `H, W, _ = im.shape` unpack is gone: it made 2-D inputs crash even
    # though nothing here depends on the image rank.
    if tp in ("Sony", "T1pro"):
        show_rgb_array(im, wb, outdir)
    else:
        print("accept type: 'T1pro', 'Sony'")

def loadJpg(name):
    """Load an image file as a float32 array divided by 255.

    For 8-bit images the result lies in [0, 1]. Images with a wider bit depth
    are still divided by 255, so they may exceed 1.

    Args:
        name: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``np.ndarray`` of dtype float32.
    """
    # The context manager guarantees the underlying file is closed once the
    # pixels are copied out, instead of relying on garbage collection.
    with Image.open(name) as img:
        im = np.array(img).astype(np.float32) / 255.0
    return im

def JointDistributionSampling(infs, K=27):
    """Sample one set of noise-model parameters for the gain ``K``.

    The read-noise scale ``sigTL`` and the row-noise scale ``sigr`` are drawn
    in log space. Each comes from a normal distribution whose mean is linear in
    ``log(K)``, is exponentiated, and is then halved. The Tukey-lambda shape
    ``lamb`` is picked with ``random.choice`` from ``infs["lambs"]``.

    Args:
        infs: mapping providing "aTL", "bTL", "unbTL" (read-noise fit),
            "ar", "br", "unbr" (row-noise fit) and a non-empty "lambs" sequence.
        K: gain used for the log-linear means and returned to the caller.
            Defaults to 27, the value that was previously hard-coded.

    Returns:
        Tuple ``(K, sigTL, lamb, sigr)``.

    Raises:
        ValueError: if ``K`` is not positive (``log(K)`` would be undefined).
        KeyError: if ``infs`` is missing one of the keys above.
    """
    # Read every coefficient up front so a missing key fails before any RNG draw.
    aTL = infs["aTL"]
    bTL = infs["bTL"]
    unbTL = infs["unbTL"]
    ar = infs["ar"]
    br = infs["br"]
    unbr = infs["unbr"]
    if K <= 0:
        raise ValueError(f"K must be positive, got {K!r}")
    logK = np.log(K)

    # log(sigTL) ~ N(bTL + aTL * log K, unbTL)
    meanTL = bTL + aTL * logK
    sigTL = stats.norm.rvs(meanTL, unbTL, size=1)
    sigTL = np.exp(sigTL[0]) * 0.5  # Further reduce read noise strength

    # log(sigr) ~ N(br + ar * log K, unbr)
    meanr = br + ar * logK
    sigr = stats.norm.rvs(meanr, unbr, size=1)
    sigr = np.exp(sigr[0]) * 0.5  # Further reduce row noise strength

    lamb = random.choice(infs["lambs"])
    return (K, sigTL, lamb, sigr)

def addShotNoise(im, K):
    """Return attenuated Poisson (shot) noise for ``im`` at gain ``K``.

    ``im * K`` is sampled from a Poisson distribution and divided back by ``K``.
    The clean image is then subtracted, so only the fluctuation is returned, and
    the result is scaled by 0.05. Any array shape is accepted.

    Args:
        im: non-negative float array; ``np.random.poisson`` raises ValueError
            for negative rates.
        K: positive scale applied before Poisson sampling.

    Returns:
        Noise array with the same shape as ``im``.
    """
    photons = im * K
    sampled = np.random.poisson(photons) / K
    # The Poisson sample has expectation `im`, so subtracting it leaves a
    # zero-mean fluctuation and preserves the image mean when added back.
    return (sampled - im) * 0.05  # Reduce Poisson noise strength

def addQuantizationNoise(im, q=0.0009775171065493646):
    """Return attenuated uniform quantization noise shaped like ``im``.

    Samples are drawn from U[-q/2, q/2], a zero-centred interval one
    quantization step wide, and then scaled by 0.005.

    Args:
        im: array whose shape the noise matches (its values are not used).
        q: quantization step width.

    Returns:
        Noise array with shape ``im.shape``.
    """
    loc = -0.5 * q
    # scipy.stats.uniform covers [loc, loc + scale], so the width must be q.
    # A scale of 1/q would instead span roughly [0, 1/q]: a huge positive
    # offset rather than a step-sized, zero-mean error.
    noise = stats.uniform.rvs(loc=loc, scale=q, size=im.shape)
    return noise * 0.005  # Further reduce quantization noise strength

def addReadNoise(im, lamb, sigTL):
    """Draw Tukey-lambda read noise with location 0, shaped like a 3-D ``im``.

    Args:
        im: 3-D array; only its shape is used. Other ranks raise ValueError
            when the shape is unpacked.
        lamb: Tukey-lambda shape parameter.
        sigTL: scale of the distribution.

    Returns:
        Array of shape ``im.shape``.
    """
    rows, cols, chans = im.shape
    return stats.tukeylambda.rvs(lamb, loc=0, scale=sigTL, size=(rows, cols, chans))

def addRowNoise(im, sigr):
    """Draw row noise: one Gaussian offset per (row, channel), shared along axis 1.

    Args:
        im: 3-D array; only its shape is used.
        sigr: standard deviation of the per-row offsets.

    Returns:
        Array of shape ``im.shape`` that is constant along axis 1.
    """
    height, width, channels = im.shape
    per_row = stats.norm.rvs(loc=0, scale=sigr, size=(height, 1, channels))
    # Materialise a full copy (not a broadcast view) across the width axis.
    return np.repeat(per_row, width, axis=1)

def addNoise(src, K, lamb, sigTL, sigr, b, f):
    """Return a noisy copy of ``src`` clipped to [0, 1].

    Shot, row, read and quantization noise are summed, scaled by 0.1 and added
    to the image. The global mean shift they cause is then removed.

    Args:
        src: clean image array; it is copied, not modified.
        K: gain for the shot-noise component.
        lamb: Tukey-lambda shape for the read noise.
        sigTL: scale of the read noise.
        sigr: standard deviation of the row noise.
        b: quantization step passed to ``addQuantizationNoise``.
        f: unused here; accepted for interface compatibility.
    """
    clean = src.copy()
    # Keep this draw order (shot, row, read, quantization) so a fixed RNG seed
    # yields the same noise as before.
    shot = addShotNoise(clean, K)
    row = addRowNoise(clean, sigr)
    read = addReadNoise(clean, lamb, sigTL)
    quant = addQuantizationNoise(clean, b)
    combined = (shot + read + row + quant) * 0.1  # Further reduce overall noise impact
    noisy = clean + combined
    # Normalize brightness: subtract the mean shift the noise introduced.
    noisy = noisy - np.mean(noisy - clean)
    return np.clip(noisy, 0, 1)

def getNoisePairSony(infs, im):
    """Synthesise a noisy image for the "Sony" profile.

    Noise parameters come from ``JointDistributionSampling(infs)``. The
    quantization step is 2**-14.

    Returns:
        ``(noisy_image, 1)``. The constant second element is passed through
        ``addNoise`` as ``f``; its meaning is not visible here.
    """
    gain, read_scale, tl_shape, row_scale = JointDistributionSampling(infs)
    quant_step = 1 / 16384  # 2**-14 == 6.103515625e-05 exactly
    flag = 1
    noisy = addNoise(im, K=gain, lamb=tl_shape, sigTL=read_scale,
                     sigr=row_scale, b=quant_step, f=flag)
    return noisy, flag

def getNoisePairT1p(infs, im):
    """Synthesise a noisy image for the "T1pro" profile.

    Noise parameters come from ``JointDistributionSampling(infs)``. The
    quantization step is 2**-10.

    Returns:
        ``(noisy_image, 1)``. The constant second element is passed through
        ``addNoise`` as ``f``; its meaning is not visible here.
    """
    gain, read_scale, tl_shape, row_scale = JointDistributionSampling(infs)
    quant_step = 1 / 1024  # 2**-10 == 0.0009765625 exactly
    flag = 1
    noisy = addNoise(im, K=gain, lamb=tl_shape, sigTL=read_scale,
                     sigr=row_scale, b=quant_step, f=flag)
    return noisy, flag

def showSample(ns, cl, wb=None, outDir="", tp="T1pro"):
    """Save a noisy/clean image pair under a shared ``uuid1``-based name.

    Files are written to ``outDir + <uuid> + "_noise.jpg"`` and
    ``outDir + <uuid> + "_clean.jpg"``. ``outDir`` is concatenated as-is, so
    it should end with a path separator (or be empty for the current directory).

    Args:
        ns: noisy image array.
        cl: clean image array.
        wb: forwarded to ``showArray``; ``None`` means an empty list.
        outDir: prefix prepended to the generated file names.
        tp: camera type forwarded to ``showArray``.
    """
    if wb is None:
        # A fresh list per call instead of one mutable default shared by all calls.
        wb = []
    sname = str(uuid.uuid1())
    showArray(ns, wb, outDir + sname + "_noise.jpg", tp)
    showArray(cl, wb, outDir + sname + "_clean.jpg", tp)

def getNoisePair(infs, src, tp):
    """Dispatch to the per-camera noise synthesiser.

    Args:
        infs: noise-model coefficients forwarded to the per-camera function.
        src: clean image array; a copy is passed on, so ``src`` is untouched.
        tp: camera type, either "T1pro" or "Sony".

    Returns:
        ``(noisy_image, f)`` as returned by the per-camera function.

    Raises:
        ValueError: if ``tp`` is not a supported camera type. Before this
            check, an unknown type reached the ``return`` with unbound locals
            and raised UnboundLocalError.
    """
    if tp not in ("T1pro", "Sony"):
        raise ValueError(f"unsupported camera type {tp!r}; accept type: 'T1pro', 'Sony'")
    im = src.copy()
    if tp == "T1pro":
        nsim, f = getNoisePairT1p(infs, im)
    else:
        nsim, f = getNoisePairSony(infs, im)
    return (nsim, f)