import os
import random

import numpy as np
import PIL.Image
import scipy
import torch
import torch.nn.functional as F
from scipy import ndimage

from tada.augment_pipes.utils import (AVAILABLE_AUGMENTATIONS,
                                      EDM_AUGMENTATIONS, TADA_AUGMENTATIONS,
                                      AUGMENTATION_SET,
                                      adapt_linear, adapt_quad, adapt_step)

#----------------------------------------------------------------------------
# Coefficients of various wavelet decomposition low-pass filters.

# Maps a wavelet name to the list of its low-pass decomposition filter taps.
# Only 'sym2' is read in this file (to build the band-pass filter bank in
# TadaV2.__init__); the other entries are kept as alternatives.
# NOTE(review): the names look like the PyWavelets family naming
# (Haar / Daubechies 'dbN' / Symlets 'symN') -- confirm the source of the values.
wavelets = {
    # 'haar' and 'db1' hold identical coefficients.
    'haar': [0.7071067811865476, 0.7071067811865476],
    'db1':  [0.7071067811865476, 0.7071067811865476],
    'db2':  [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'db3':  [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'db4':  [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
    'db5':  [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
    'db6':  [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
    'db7':  [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
    'db8':  [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
    # 'sym2' and 'sym3' hold the same coefficients as 'db2' and 'db3'.
    'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
    'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
    'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
    'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
    'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
}

#----------------------------------------------------------------------------
# Helpers for constructing transformation matrices.
def translate2d(tx, ty, **kwargs):
    """Return the 3x3 homogeneous matrix that translates by (tx, ty).

    Extra keyword arguments are forwarded to ``np.asarray`` (e.g. ``dtype``).
    """
    rows = (
        [1, 0, tx],
        [0, 1, ty],
        [0, 0, 1],
    )
    return np.asarray(rows, **kwargs)


def translate3d(tx, ty, tz, **kwargs):
    """Return the 4x4 homogeneous matrix that translates by (tx, ty, tz).

    Extra keyword arguments are forwarded to ``np.asarray`` (e.g. ``dtype``).
    """
    rows = (
        [1, 0, 0, tx],
        [0, 1, 0, ty],
        [0, 0, 1, tz],
        [0, 0, 0, 1],
    )
    return np.asarray(rows, **kwargs)


def scale2d(sx, sy, **kwargs):
    """Return the 3x3 homogeneous matrix that scales x by ``sx`` and y by ``sy``.

    Extra keyword arguments are forwarded to ``np.asarray`` (e.g. ``dtype``).
    """
    rows = (
        [sx, 0, 0],
        [0, sy, 0],
        [0, 0, 1],
    )
    return np.asarray(rows, **kwargs)


def scale3d(sx, sy, sz, **kwargs):
    """Return the 4x4 homogeneous matrix that scales the axes by (sx, sy, sz).

    Extra keyword arguments are forwarded to ``np.asarray`` (e.g. ``dtype``).
    """
    rows = (
        [sx, 0, 0, 0],
        [0, sy, 0, 0],
        [0, 0, sz, 0],
        [0, 0, 0, 1],
    )
    return np.asarray(rows, **kwargs)


def rotate2d(theta, **kwargs):
    """Return the 3x3 homogeneous matrix for a 2D rotation by ``theta`` radians.

    Extra keyword arguments are forwarded to ``np.asarray`` (e.g. ``dtype``).
    """
    top_row = [np.cos(theta), np.sin(-theta), 0]
    mid_row = [np.sin(theta), np.cos(theta), 0]
    bottom_row = [0, 0, 1]
    return np.asarray((top_row, mid_row, bottom_row), **kwargs)


def rotate3d(v, theta, **kwargs):
    """Return the 4x4 homogeneous matrix rotating by ``theta`` radians about axis ``v``.

    The first three components along the last axis of ``v`` are used as the
    rotation axis; the formula does not normalise them. Extra keyword
    arguments are forwarded to ``np.asarray`` (e.g. ``dtype``).
    """
    vx, vy, vz = v[..., 0], v[..., 1], v[..., 2]
    s, c = np.sin(theta), np.cos(theta)
    cc = 1 - c

    row_x = [vx*vx*cc+c,    vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0]
    row_y = [vy*vx*cc+vz*s, vy*vy*cc+c,    vy*vz*cc-vx*s, 0]
    row_z = [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c,    0]
    row_w = [0, 0, 0, 1]
    return np.asarray((row_x, row_y, row_z, row_w), **kwargs)


def translate2d_inv(tx, ty, **kwargs):
    """Return the inverse of ``translate2d(tx, ty)``: a translation by (-tx, -ty)."""
    inverse = translate2d(-tx, -ty, **kwargs)
    return inverse


def scale2d_inv(sx, sy, **kwargs):
    """Return the inverse of ``scale2d(sx, sy)``: a scaling by (1/sx, 1/sy)."""
    inv_sx = 1 / sx
    inv_sy = 1 / sy
    return scale2d(inv_sx, inv_sy, **kwargs)


def rotate2d_inv(theta, **kwargs):
    """Return the inverse of ``rotate2d(theta)``: a rotation by ``-theta``."""
    inverse = rotate2d(-theta, **kwargs)
    return inverse


def _map_augmentations(augmentations):
    """Parse a comma-separated augmentation spec into a tuple of names.

    The spec is lower-cased and split on commas; surrounding whitespace of
    each token is ignored. If the first token is a preset name it seeds the
    list:

    * ``"none"``  -> returns ``()`` immediately (remaining tokens are ignored),
    * ``"all"``   -> ``AVAILABLE_AUGMENTATIONS``,
    * ``"edm"``   -> ``EDM_AUGMENTATIONS``,
    * ``"tada"``  -> ``TADA_AUGMENTATIONS``,
    * ``"geo"`` / ``"color"`` / ``"filter"`` -> ``AUGMENTATION_SET[<name>]``.

    Any other first token (e.g. ``"scale"``) is treated as an individual
    augmentation and the list starts empty. Each following token either adds
    an augmentation (``"name"``) or removes one (``"-name"``).

    Args:
        augmentations: Spec string such as ``"tada,-hflip,cutout"``.

    Returns:
        Tuple of augmentation names in their resulting order.

    Raises:
        ValueError: If a token is not in ``AVAILABLE_AUGMENTATIONS``, or if a
            ``"-name"`` token asks to remove an augmentation that is not
            currently in the list.
    """
    tokens = [token.strip() for token in augmentations.lower().split(",")]

    preset = tokens[0]
    if preset == "none":
        return ()

    # Always copy the preset: the list is edited below and the module-level
    # containers must not be modified through it.
    if preset == "all":
        aug_list = list(AVAILABLE_AUGMENTATIONS)
    elif preset == "edm":
        aug_list = list(EDM_AUGMENTATIONS)
    elif preset == "tada":
        aug_list = list(TADA_AUGMENTATIONS)
    elif preset in ("geo", "color", "filter"):
        aug_list = list(AUGMENTATION_SET[preset])
    else:
        aug_list = None

    if aug_list is None:
        # First token is not a preset: keep it as an ordinary token.
        aug_list = []
    else:
        tokens = tokens[1:]

    def check_aug(name):
        if name not in AVAILABLE_AUGMENTATIONS:
            raise ValueError(f"Unknown augmentation: {name}")
        return True

    for token in tokens:
        if token.startswith("-"):
            name = token[1:]
            check_aug(name)
            if name not in aug_list:
                raise ValueError(
                    f"Cannot remove augmentation that is not in the list: {name}")
            aug_list.remove(name)
        elif token not in aug_list:
            check_aug(token)
            aug_list.append(token)
        else:
            print(f"Augmentation already in list: {token}")
    return tuple(aug_list)


class TadaV2:
    """Single-image augmentation pipeline that also returns a label vector.

    Calling an instance on a ``PIL.Image.Image`` returns ``(image, labels)``.
    ``image`` is the input converted to float64 and rescaled with
    ``x / 127.5 - 1``; ``labels`` is a float32 vector of length
    ``NUM_LABELS`` that records the sampled parameters of the augmentations
    that were applied and is zero everywhere else.

    With probability ``1 - p`` the image is returned un-augmented. Otherwise
    between 1 and ``max_apply`` augmentations are drawn (without replacement)
    from the configured set and applied in a fixed order: pixel blitting,
    geometric, color, image-space filtering, cutout.

    When ``adaptive_augment`` is true, the strength parameters (all ``*_std``
    / ``*_max`` values used below and ``cutout_size``) are multiplied by a
    factor computed by ``adapt_linear`` / ``adapt_step`` / ``adapt_quad`` from
    the ``snr`` passed to ``__call__``. The method is chosen by the
    ``ADAPT_METHOD`` environment variable and the two boundaries come from
    ``SNR_BOUNDARY`` (``"<r_rough>,<r_fine>"``); ``KAPPA`` is an optional
    extra parameter for the ``quad`` method.
    """

    # Length of the label vector returned by __call__:
    # hflip(1) + vflip(1) + rotate90(1) + scale(1) + rotate_frac(2) + aniso(2)
    # + translate_frac(2) + brightness(1) + contrast(1) + lumaflip(1) + hue(2)
    # + saturation(1) + imgfilter(4) + cutout(2).
    NUM_LABELS = 22

    def __init__(
        self, adaptive_augment, p=0.5, max_apply=1, augmentations="tada",
        translate_int_max=0.125,
        scale_std=0.2, rotate_frac_max=1, aniso_std=0.2, aniso_rotate_prob=0.5, translate_frac_std=0.125,
        brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
        imgfilter=0.25, imgfilter_bands=(1, 1, 1, 1), imgfilter_std=1,
        cutout_size=0.5,
        *args, **kwargs,
    ):
        """Store the augmentation configuration and build the filter bank.

        Args:
            adaptive_augment: Whether strengths are scaled by the SNR-dependent
                adapt factor.
            p: Probability of augmenting an image at all.
            max_apply: Upper bound on the number of augmentations drawn per image.
            augmentations: Spec string parsed by ``_map_augmentations``.
            *args, **kwargs: Accepted and ignored.

        The remaining parameters are documented next to the attributes they
        are stored in.

        Raises:
            KeyError: If ``adaptive_augment`` is true and the ``SNR_BOUNDARY``
                environment variable is not set.
            ValueError: If ``SNR_BOUNDARY`` is not of the form ``"a,b"`` with
                two floats, or if ``augmentations`` is invalid.
        """
        # `import scipy` alone does not guarantee that the `signal` submodule
        # is loaded on every SciPy version, so import it explicitly.
        from scipy import signal

        self.adaptive_augment   = bool(adaptive_augment)
        self.p                  = float(p)
        self.max_apply          = int(max_apply)
        self.augmentations      = _map_augmentations(augmentations)

        # Pixel blitting.
        self.translate_int_max  = float(translate_int_max)  # Range of integer translation, relative to image dimensions.

        # Geometric transformations.
        self.scale_std          = float(scale_std)          # Log2 standard deviation of isotropic scaling.
        self.rotate_frac_max    = float(rotate_frac_max)    # Range of fractional rotation, 1 = full circle.
        self.aniso_std          = float(aniso_std)          # Log2 standard deviation of anisotropic scaling.
        self.aniso_rotate_prob  = float(aniso_rotate_prob)  # Probability of doing anisotropic scaling w.r.t. rotated coordinate frame.
        self.translate_frac_std = float(translate_frac_std) # Standard deviation of fractional translation, relative to image dimensions.

        # Color transformations.
        self.brightness_std     = float(brightness_std)     # Standard deviation of brightness.
        self.contrast_std       = float(contrast_std)       # Log2 standard deviation of contrast.
        self.hue_max            = float(hue_max)            # Range of hue rotation, 1 = full circle.
        self.saturation_std     = float(saturation_std)     # Log2 standard deviation of saturation.

        # Image-space filtering.
        self.imgfilter          = float(imgfilter)          # Probability multiplier for image-space filtering.
        self.imgfilter_bands    = list(imgfilter_bands)     # Probability multipliers for individual frequency bands.
        self.imgfilter_std      = float(imgfilter_std)      # Log2 standard deviation of image-space filter amplification.

        # Image-space corruptions.
        self.cutout_size        = float(cutout_size)        # Size of the cutout rectangle, relative to image dimensions.

        # Construct filter bank for image-space filtering.
        Hz_lo = np.asarray(wavelets['sym2'])            # H(z)
        Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
        Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2    # H(z) * H(z^-1) / 2
        Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2    # H(-z) * H(-z^-1) / 2
        Hz_fbank = np.eye(4, 1)                         # Bandpass(H(z), b_i)
        for i in range(1, Hz_fbank.shape[0]):
            # Interleave zeros between taps, smooth with the low-pass kernel,
            # then add the high-pass kernel to the centre of row i.
            Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
            Hz_fbank = signal.convolve(Hz_fbank, [Hz_lo2])
            Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
        self.Hz_fbank = Hz_fbank

        # TADA: SNR boundaries. They are only consumed when adaptive
        # augmentation is enabled, so only require them in that case.
        boundary = os.environ.get("SNR_BOUNDARY")
        if boundary is None:
            if self.adaptive_augment:
                raise KeyError("SNR_BOUNDARY")
            self.r_rough = None
            self.r_fine = None
        else:
            r_rough, r_fine = boundary.split(",")
            self.r_rough = float(r_rough)
            self.r_fine = float(r_fine)

    def _compute_adapt_factor(self, snr, image_size):
        """Return the strength multiplier for this call (1.0 when not adaptive)."""
        if not self.adaptive_augment:
            return 1.0

        r_rough, r_fine = self.r_rough, self.r_fine
        method = os.environ["ADAPT_METHOD"].lower()
        if method == "linear":
            return adapt_linear(snr, r_rough, r_fine, image_size=image_size, slope_r=-0.4, slope_f=1)
        if method == "step":
            return adapt_step(snr, r_rough, r_fine, image_size=image_size, slope_r=-0.4)
        if method == "quad":
            # KAPPA is optional; a missing or non-numeric value falls back to None.
            try:
                kappa = float(os.environ["KAPPA"])
            except (KeyError, ValueError):
                kappa = None
            return adapt_quad(snr, r_rough, r_fine, image_size=image_size, kappa=kappa, delta=0.1)
        raise ValueError(f"Invalid adapt method: {method}")

    def __call__(self, image, *, t=None, snr=None):
        """Augment one image and return it together with its label vector.

        Args:
            image: Input ``PIL.Image.Image``.
            t: Accepted for interface compatibility; not used.
            snr: Forwarded to the adapt function when ``adaptive_augment`` is
                enabled; ignored otherwise.

        Returns:
            Tuple ``(image, labels)``: a float64 array and a float32 vector of
            length ``NUM_LABELS``.

        Raises:
            AssertionError: If ``image`` is not a ``PIL.Image.Image``.
            KeyError: If adaptation is enabled and ``ADAPT_METHOD`` is unset.
            ValueError: If ``ADAPT_METHOD`` names an unknown method, or the
                image has a channel count other than 1 or 3 when a color
                transformation is applied.
        """
        assert isinstance(image, PIL.Image.Image), \
            f"Invalid image type ({type(image)}). Must be PIL.Image.Image."
        image = np.asarray(image).astype(np.float64)
        image = (image / 127.5) - 1.0  # Normalize to [-1, 1].
        if np.random.rand() > self.p:
            return image, np.zeros(self.NUM_LABELS, dtype=np.float32)

        # NOTE: a 2-D array (e.g. from a mode-"L" PIL image) cannot be
        # unpacked here; callers must supply an image that decodes to [H, W, C].
        H, W, C = image.shape
        num_apply = torch.randint(1, self.max_apply+1, (1,)).item()
        # random.sample raises if asked for more items than are available
        # (including the empty "none" configuration), so clamp the count.
        num_apply = min(num_apply, len(self.augmentations))
        augmentations = random.sample(self.augmentations, num_apply)
        adapt_factor = self._compute_adapt_factor(snr, H)

        labels = []

        # ---------------
        # Pixel blitting.
        # ---------------
        if "hflip" in augmentations:
            image = np.fliplr(image)
            labels.append(1.0)
        else:
            labels.append(0.0)

        if "vflip" in augmentations:
            image = np.flipud(image)
            labels.append(1.0)
        else:
            labels.append(0.0)

        if "rotate90" in augmentations:
            w = np.random.randint(1, 4)
            image = np.rot90(image, k=w)
            labels.append(w/3.0)
        else:
            labels.append(0.0)

        if "translate_int" in augmentations:
            # Not implemented: no change to the image and no label entry.
            pass

        # ------------------------------------------------
        # Select parameters for geometric transformations.
        # ------------------------------------------------
        I_3 = np.eye(3)
        G_inv = I_3

        if "scale" in augmentations:
            scale_std = self.scale_std * adapt_factor
            w = np.random.randn()
            scale = np.exp2(w * scale_std)
            G_inv = G_inv @ scale2d_inv(scale, scale)
            labels.append(w)
        else:
            labels.append(0.0)

        if "rotate_frac" in augmentations:
            rotate_frac_max = self.rotate_frac_max * adapt_factor
            rot_angle = (np.random.rand() * 2 - 1) * (np.pi * rotate_frac_max)
            G_inv = G_inv @ rotate2d_inv(-rot_angle)
            labels.append(np.cos(rot_angle) - 1)
            labels.append(np.sin(rot_angle))
        else:
            labels.append(0.0)
            labels.append(0.0)

        if "aniso" in augmentations:
            w = np.random.randn()
            if np.random.rand() < self.aniso_rotate_prob:
                aniso_angle = (np.random.rand() * 2 - 1) * np.pi
            else:
                aniso_angle = 0.
            aniso_std = self.aniso_std * adapt_factor

            aniso_scale = np.exp2(w * aniso_std)
            G_inv = (
                G_inv
                @ rotate2d_inv(aniso_angle)
                @ scale2d_inv(aniso_scale, 1 / aniso_scale)
                @ rotate2d_inv(-aniso_angle)
            )
            labels.append((w * np.cos(aniso_angle)))
            labels.append((w * np.sin(aniso_angle)))
        else:
            labels.append(0.0)
            labels.append(0.0)

        if "translate_frac" in augmentations:
            translate_frac_std = self.translate_frac_std * adapt_factor
            w = np.random.randn(2)
            tx = w[0]*(W * translate_frac_std)
            ty = w[1]*(H * translate_frac_std)
            G_inv = G_inv @ translate2d_inv(tx, ty)
            labels.append(w[0])
            labels.append(w[1])
        else:
            labels.append(0.0)
            labels.append(0.0)

        # ----------------------------------
        # Execute geometric transformations.
        # ----------------------------------
        # Identity check by object: every `@` above creates a new array, so
        # G_inv is still I_3 only when no geometric augmentation was selected.
        if G_inv is not I_3:
            in_plane_shape = np.asarray([H, W])
            out_plane_shape = in_plane_shape
            # Choose the offset so the transform acts about the image centre.
            out_center = G_inv[:2, :2] @ ((out_plane_shape - 1) / 2)
            in_center = (in_plane_shape - 1) / 2
            offset = in_center - out_center
            offset += G_inv[:2, 2]

            transformed = np.empty_like(image)
            for c in range(C):
                ndimage.affine_transform(
                    image[:, :, c],
                    matrix=G_inv[:2, :2],
                    offset=offset,
                    output=transformed[:, :, c],
                    mode="reflect",
                )
            image = transformed

        # --------------------------------------------
        # Select parameters for color transformations.
        # --------------------------------------------
        I_4 = np.eye(4)
        M = I_4
        luma_axis = np.asarray([1, 1, 1, 0]) / np.sqrt(3)

        if "brightness" in augmentations:
            brightness_std = self.brightness_std * adapt_factor
            w = np.random.randn()
            b = w * brightness_std
            M = translate3d(b, b, b) @ M
            labels.append(w)
        else:
            labels.append(0.0)

        if "contrast" in augmentations:
            contrast_std = self.contrast_std * adapt_factor
            w = np.random.randn()
            c = np.exp2(w * contrast_std)
            M = scale3d(c, c, c) @ M
            labels.append(w)
        else:
            labels.append(0.0)

        if "lumaflip" in augmentations:
            w = np.random.randint(2)
            M = (I_4 - 2 * np.outer(luma_axis, luma_axis) * w) @ M
            labels.append(w)
        else:
            labels.append(0.0)

        if "hue" in augmentations:
            hue_max = self.hue_max * adapt_factor
            w = (np.random.rand() * 2 - 1) * (np.pi * hue_max)
            M = rotate3d(luma_axis, w) @ M
            labels.append(np.cos(w) - 1)
            labels.append(np.sin(w))
        else:
            labels.append(0.0)
            labels.append(0.0)

        if "saturation" in augmentations:
            saturation_std = self.saturation_std * adapt_factor
            w = np.random.randn()
            s = np.exp2(w * saturation_std)
            M = (
                np.outer(luma_axis, luma_axis) + (I_4 - np.outer(luma_axis, luma_axis)) * s
            ) @ M
            labels.append(w)
        else:
            labels.append(0.0)

        # ------------------------------
        # Execute color transformations.
        # ------------------------------
        if M is not I_4:
            image = image.reshape([H*W, C])
            if C == 3:
                # M is built for column vectors (translation in the last
                # column, composed by left-multiplication). Pixels are rows
                # here, so the linear part must be transposed; otherwise the
                # hue rotation runs in the opposite direction to its label.
                image = image @ M[:3, :3].T + M[:3, 3:].T  # [HW, 3] @ [3, 3] + [1, 3]
            elif C == 1:
                M = M[:3, :].mean(axis=0, keepdims=True)  # [3, 4] -> [1, 4]
                image = image * M[:, :3].sum(axis=1, keepdims=True) + M[:, 3:]  # [HW, 1] * [1, 1] + [1, 1]
            else:
                raise ValueError("Image must be RGB (3 channels) or L (1 channel)")
            image = image.reshape([H, W, C])

        # ----------------------
        # Image-space filtering.
        # ----------------------
        num_bands = self.Hz_fbank.shape[0]
        if "imgfilter" in augmentations:
            assert len(self.imgfilter_bands) == num_bands
            expected_power = np.array([10, 1, 1, 1]) / 13  # Expected power spectrum (1/f).
            imgfilter_std = self.imgfilter_std * adapt_factor

            # Apply amplification for each band with probability (imgfilter * band_strength).
            g = np.ones([1, num_bands]) # Global gain vector (identity).
            for i, band_strength in enumerate(self.imgfilter_bands):
                w = np.random.randn()
                t_i = np.exp2(w * imgfilter_std)
                if np.random.rand() < self.imgfilter * band_strength:
                    labels.append(t_i)
                else:
                    t_i = 1.0
                    labels.append(0.0)

                band_gain = np.ones([1, num_bands])  # Temporary gain vector.
                band_gain[0, i] = t_i  # Replace i'th element.
                band_gain = band_gain / np.sqrt(
                    (expected_power * np.square(band_gain)).sum(axis=-1, keepdims=True))  # Normalize power.
                g = g * band_gain

            # Construct combined amplification filter.
            Hz_prime = g @ self.Hz_fbank  # [1, 4]@[4, tap] -> [1, tap]
            Hz_prime = np.expand_dims(Hz_prime, axis=1)  # [1, tap] -> [1, 1, tap]
            Hz_prime = np.repeat(Hz_prime, C, axis=1)  # [1, 1, tap] -> [1, C, tap]
            Hz_prime = Hz_prime.reshape([C, 1, -1])  # [channels, 1, tap]
            Hz_prime = torch.from_numpy(Hz_prime)

            # Apply filter as two separable depthwise convolutions on a
            # reflect-padded copy, so the output keeps the [H, W] size.
            pad = self.Hz_fbank.shape[1] // 2
            tensor_image = torch.from_numpy(image.copy().transpose([2, 0, 1])).unsqueeze(0)
            tensor_image = F.pad(tensor_image, pad=[pad, pad, pad, pad], mode='reflect')
            tensor_image = F.conv2d(tensor_image, weight=Hz_prime.unsqueeze(2), groups=C)
            tensor_image = F.conv2d(tensor_image, weight=Hz_prime.unsqueeze(3), groups=C)
            tensor_image = tensor_image.reshape([C, H, W])
            image = tensor_image.permute(1, 2, 0).numpy()
        else:
            labels.extend([0.0] * num_bands)

        # ------------------------
        # Image-space corruptions.
        # ------------------------
        if "cutout" in augmentations:
            w = np.random.rand(2)
            cutout_size = self.cutout_size * adapt_factor
            size = w * cutout_size
            size = (round(size[0]*H), round(size[1]*W))
            mask = np.ones_like(image)
            coord_y = np.random.randint(0, H-size[0]+1)
            coord_x = np.random.randint(0, W-size[1]+1)
            mask[coord_y:coord_y+size[0], coord_x:coord_x+size[1], :] = 0
            image = image * mask
            labels.append(w[0])
            labels.append(w[1])
        else:
            labels.append(0.0)
            labels.append(0.0)

        image = np.clip(image, -1.0, 1.0)
        labels = np.asarray(labels, dtype=np.float32)
        return image.copy(), labels
