import os

import numpy as np

# --- Augmentation name groups -------------------------------------------------
# Each group is a tuple of augmentation names.
BLIT = ("hflip", "vflip", "rotate90", "translate_int")
GEOMETRIC = ("scale", "rotate_frac", "aniso", "translate_frac")
# "lumaflip" is not part of this group (it only appears in AVAILABLE_AUGMENTATIONS).
COLOR = ("brightness", "contrast", "hue", "saturation")
CORRUPT = ("cutout",)

# Flips, every geometric op, the colour ops including "lumaflip", and the filter.
AVAILABLE_AUGMENTATIONS = (
    ("hflip", "vflip")
    + GEOMETRIC
    + ("brightness", "contrast", "lumaflip", "hue", "saturation")
    + ("imgfilter",)
)
# Flips followed by the geometric group.
EDM_AUGMENTATIONS = ("hflip", "vflip") + GEOMETRIC
# Geometric group followed by the colour group.
TADA_AUGMENTATIONS = GEOMETRIC + COLOR

# Lookup from a group name (or a single geometric augmentation name) to its
# augmentation(s).
# NOTE(review): "filter" and the per-augmentation entries map to a bare string,
# while the group entries map to tuples; iterating a string yields characters,
# so callers presumably special-case `str` values -- confirm before relying on it.
AUGMENTATION_SET = {
    "blit": BLIT,
    "geo": GEOMETRIC,
    "color": COLOR,
    "filter": "imgfilter",
    "corrupt": CORRUPT,
    "edm": EDM_AUGMENTATIONS,
    "tada": TADA_AUGMENTATIONS,
    # "scale", "rotate_frac", "aniso", "translate_frac" each map to themselves.
    **{name: name for name in GEOMETRIC},
}


def sigmoid(x, slope=1):
    """Logistic function ``1 / (1 + exp(-slope * x))``.

    Applied element-wise when ``x`` is a NumPy array. A negative ``slope``
    mirrors the curve so that it decreases in ``x``.
    """
    exponent = -slope * x
    return 1.0 / (1.0 + np.exp(exponent))


def calibrate_snr(snr, image_size, base_image_size=64):
    """Rescale ``snr`` measured at ``image_size`` to the ``base_image_size`` reference.

    Based on the resolution-dependent SNR shift from 'simple diffusion'
    (https://arxiv.org/pdf/2301.11093.pdf): ``snr`` is divided by the squared
    ratio ``base_image_size / image_size``, so images larger than the base
    size get a proportionally higher calibrated SNR.
    """
    size_ratio = base_image_size / image_size
    return snr / size_ratio ** 2


def adapt_linear(
        snr, s_r, s_f, image_size, slope_r=-0.5, slope_f=10, min_strength=0.1
):
    """Augmentation weight from a piecewise quadratic / clipped-linear SNR curve.

    ``snr`` is rescaled with ``calibrate_snr`` and converted to ``log10``
    (called ``x`` below). Then:

    * for ``s_r <= x <= s_f`` the weight is a downward-dipping parabola that
      equals ``min_strength`` at both bounds and 0 at their midpoint;
    * otherwise it is the sum of two clipped linear ramps,
      ``max(slope_r * (x - s_r) + min_strength, 0)`` and
      ``max(slope_f * (x - s_f) + min_strength, 0)``.

    In both cases the result is capped at 1.

    Args:
        snr: signal-to-noise ratio. The interval test is a Python chained
            comparison, so only scalar inputs are supported.
        s_r: lower log10-SNR bound of the parabolic region.
        s_f: upper log10-SNR bound of the parabolic region; must exceed ``s_r``.
        image_size: image resolution forwarded to ``calibrate_snr``.
        slope_r: slope of the ramp anchored at ``s_r`` (negative by default,
            so the weight grows as the SNR drops below ``s_r``).
        slope_f: slope of the ramp anchored at ``s_f``.
        min_strength: weight at the interval bounds and offset of both ramps.

    Returns:
        The weight, never larger than 1.

    Raises:
        AssertionError: if ``s_r >= s_f``.
    """
    assert s_r < s_f
    logsnr = np.log10(calibrate_snr(snr, image_size))

    if s_r <= logsnr <= s_f:
        # ``vertex`` is the value of (x - s_r) * (x - s_f) at the midpoint, so
        # scaling by -min_strength / vertex puts the parabola's minimum at 0.
        vertex = -0.25 * (s_r - s_f)**2
        m = (-1*min_strength/vertex) * (logsnr-s_r) * (logsnr-s_f) + min_strength
        return np.minimum(m, 1)

    # Outside the interval: two ramps, each clipped at 0 so that only the one
    # on the relevant side contributes appreciably.
    rough = np.maximum(slope_r * (logsnr - s_r) + min_strength, 0.0)
    fine = np.maximum(slope_f * (logsnr - s_f) + min_strength, 0.0)
    return np.minimum(rough + fine, 1.0)


def adapt_step(snr, r_rough, r_fine, image_size, slope_r=-0.5, min_strength=0.1):
    """Hard on/off augmentation weight at ``r_fine``.

    ``snr`` is rescaled with ``calibrate_snr`` and converted to ``log10``.
    The weight is 1.0 when that value is strictly greater than ``r_fine`` and
    0.0 otherwise (including NaN).

    Args:
        snr: signal-to-noise ratio (scalar; the comparison is a Python ``if``).
        r_rough: lower log10-SNR bound; only used to validate ``r_rough < r_fine``.
        r_fine: log10-SNR threshold above which the weight is 1.0.
        image_size: image resolution forwarded to ``calibrate_snr``.
        slope_r: unused. Presumably kept so existing call sites keep working.
        min_strength: unused. Presumably kept so existing call sites keep working.

    Returns:
        ``1.0`` or ``0.0`` as a Python float.

    Raises:
        AssertionError: if ``r_rough >= r_fine``.
    """
    assert r_rough < r_fine
    logsnr = np.log10(calibrate_snr(snr, image_size))
    return 1.0 if logsnr > r_fine else 0.0


def adapt_quad(snr, r_rough, r_fine, image_size, kappa=None, delta=0.1,
               calibrate=None):
    """Quadratic augmentation weight in log10-SNR space, clipped to [0, 1].

    Computes ``w = kappa * (x - r_rough) * (x - r_fine) + delta``, where ``x``
    is ``log10`` of ``snr`` (optionally rescaled with ``calibrate_snr`` first),
    and returns ``w`` clipped to ``[0, 1]``.

    The weight equals ``delta`` at ``r_rough`` and ``r_fine``. With the default
    ``kappa`` it drops to 0 at their midpoint and rises outside the interval.
    There is no branching on ``snr``, so NumPy arrays are handled element-wise.

    Args:
        snr: signal-to-noise ratio (scalar or NumPy array).
        r_rough: first log10-SNR point where ``w == delta``.
        r_fine: second log10-SNR point where ``w == delta``.
        image_size: image resolution forwarded to ``calibrate_snr`` when
            calibration is enabled.
        kappa: curvature of the parabola. If None, it is set to
            ``4 * delta / (r_rough - r_fine) ** 2``, which requires
            ``r_rough != r_fine``.
        delta: weight at ``r_rough`` and ``r_fine``.
        calibrate: whether to apply ``calibrate_snr`` before taking the log.
            If None (default), this is read from the ``CALIBRATE_SNR``
            environment variable, parsed as an integer (non-zero = calibrate).

    Returns:
        ``w`` clipped to ``[0, 1]``.

    Raises:
        KeyError: if ``calibrate`` is None and ``CALIBRATE_SNR`` is not set.
        ValueError: if ``CALIBRATE_SNR`` is not an integer string.
    """
    if calibrate is None:
        try:
            flag = os.environ["CALIBRATE_SNR"]
        except KeyError:
            raise KeyError(
                "CALIBRATE_SNR environment variable is not set; set it to 0 or 1 "
                "or pass calibrate=True/False explicitly"
            ) from None
        calibrate = bool(int(flag))

    if calibrate:
        logsnr = np.log10(calibrate_snr(snr, image_size))
    else:
        logsnr = np.log10(snr)

    if kappa is None:
        # At the midpoint (x - r_rough) * (x - r_fine) == -(r_rough - r_fine)**2 / 4,
        # so this kappa makes w reach exactly 0 there.
        kappa = 4.0*delta / (r_rough - r_fine)**2

    w = kappa * (logsnr-r_rough) * (logsnr-r_fine) + delta
    return np.maximum(np.minimum(w, 1), 0)


## Legacy ##
def linear(t, x1, y1, x2, y2):
    """Evaluate, at ``t``, the straight line through ``(x1, y1)`` and ``(x2, y2)``.

    With plain Python numbers, ``x1 == x2`` raises ZeroDivisionError.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1
    return slope * t + intercept


def adapt(snr, image_size, bound, min_strength=0.01, soft=True, slope=50.0):
    """Legacy: augmentation factor that dips to ``min_strength`` inside ``bound``.

    ``snr`` is rescaled with ``calibrate_snr`` and converted to ``log10``
    (``x`` below). ``bound`` is unpacked as ``(lower, upper)``. Only scalar
    ``snr`` is supported, because ``x`` is compared with Python ``if``s.

    soft=True:
        Below the midpoint of ``bound``, a sigmoid with slope ``-slope`` centred
        at ``lower + 0.5`` is used; otherwise a sigmoid with slope ``slope``
        centred at ``upper - 0.5``. For positive ``slope`` the first falls and
        the second rises. The sigmoid output ``s`` is then mapped to
        ``(1 - min_strength) * s + min_strength``.
    soft=False:
        ``np.ones_like(x)``, multiplied in place by ``min_strength`` when
        ``lower < x < upper`` (strict inequalities).
    """
    log_snr = np.log10(calibrate_snr(snr, image_size))
    lower, upper = bound

    if not soft:
        factor = np.ones_like(log_snr)
        if lower < log_snr < upper:
            factor *= min_strength
        return factor

    if log_snr < 0.5 * (lower + upper):
        raw = sigmoid((log_snr - lower - 0.5), -slope)
    else:
        raw = sigmoid((log_snr - upper + 0.5), slope)
    return (1.0 - min_strength) * raw + min_strength
