import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import kornia.geometry.transform as K_geom
import kornia.enhance as K_trans
import kornia.augmentation as K
from kornia.augmentation.container.params import ParamItem
from omegaconf.listconfig import ListConfig
import math
import random
import numpy as np
from configs.augmentations import SimclrAugmentationConfig


def get_transform(cfg, eval=False):
    """Build the torchvision preprocessing pipeline described by ``cfg``.

    With ``cfg.crop`` (default True) the image is resized to the next power of
    two >= ``cfg.size`` and then cropped to ``cfg.size``. The crop is random
    (``RandomCrop`` or ``RandomResizedCrop`` with ``scale=(1.0, 1.0)``) when
    ``cfg.random_crop`` is set and ``eval`` is False, and centered otherwise.
    Without cropping the image is resized directly to ``cfg.size``, which may be
    an int (square) or a two-element ``ListConfig``.

    Afterwards the pipeline optionally converts to a tensor (skipped when
    ``cfg.is_tensor``), repeats channels 3x (``cfg.grayscale``), and applies a
    random horizontal flip (``cfg.random_horizontal_flip``, training only).

    Args:
        cfg: Transform config.
        eval: Build the deterministic evaluation variant.

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    cropping = cfg.get("crop", True)
    size = cfg.size
    pipeline = []
    if cropping:
        pipeline.append(transforms.Resize(2 ** math.ceil(math.log2(size))))
        if cfg.random_crop and not eval:
            pipeline.append(
                transforms.RandomCrop(size)
                if cfg.fixed_random_crop
                else transforms.RandomResizedCrop(size, scale=(1.0, 1.0))
            )
        else:
            pipeline.append(transforms.CenterCrop(size))
    else:
        target = (size[0], size[1]) if isinstance(size, ListConfig) else (size, size)
        pipeline.append(transforms.Resize(target))

    if not cfg.get("is_tensor", False):
        pipeline.append(transforms.ToTensor())
    if cfg.grayscale:
        # Tile the channel dimension 3x; presumably the input is (1, H, W) here.
        pipeline.append(transforms.Lambda(lambda x: x.repeat(3, 1, 1)))
    if cfg.get("random_horizontal_flip", False) and not eval:
        pipeline.append(transforms.RandomHorizontalFlip())
    return transforms.Compose(pipeline)


def get_simclr_transforms(
    size=224,
    s=1,
    scale=(0.08, 1.0),
    use_blur=True,
    use_color_jitter=True,
    use_horizontal_flip=True,
    **kwargs,
):
    """Build the SimCLR-style list of kornia augmentations.

    Args:
        size: Output size of the random resized crop. An int gives a square
            crop; a two-element ``ListConfig``/list/tuple gives
            ``(size[0], size[1])``.
        s: Strength multiplier for the color-jitter brightness, contrast,
            saturation and hue.
        scale: Area range passed to ``K.RandomResizedCrop``.
        use_blur: Append a ``K.RandomGaussianBlur`` applied with ``p=0.5``.
        use_color_jitter: ``False``/``None`` disables color jitter. A mapping
            with keys ``"p"``, ``"brightness"``, ``"contrast"``,
            ``"saturation"`` and ``"hue"`` configures it. ``True`` uses
            ``SimclrAugmentationConfig().color_jitter_params``.
        use_horizontal_flip: Prepend a ``K.RandomHorizontalFlip``.
        **kwargs: Ignored.

    Returns:
        A list of kornia augmentation modules in application order: optional
        horizontal flip, optional color jitter, random grayscale (p=0.2),
        random resized crop, optional Gaussian blur.
    """
    if isinstance(size, (ListConfig, list, tuple)):
        sizex, sizey = size[0], size[1]
    else:
        sizex, sizey = size, size

    if use_color_jitter is True:
        # A bare True (the default) carries no parameters. Use the project's
        # default jitter settings instead of failing on ``True["p"]``.
        use_color_jitter = SimclrAugmentationConfig().color_jitter_params

    augmentation = []
    if use_horizontal_flip:
        augmentation.append(K.RandomHorizontalFlip())
    if use_color_jitter is not None and use_color_jitter != False:
        augmentation.append(
            K.ColorJitter(
                p=use_color_jitter["p"],
                saturation=use_color_jitter["saturation"] * s,
                brightness=use_color_jitter["brightness"] * s,
                contrast=use_color_jitter["contrast"] * s,
                hue=use_color_jitter["hue"] * s,
            )
        )
    augmentation += [
        K.RandomGrayscale(p=0.2),
        K.RandomResizedCrop(
            size=(sizex, sizey),
            scale=scale,
            align_corners=False,
        ),
    ]
    if use_blur:
        # Kernel is ~10% of the first size dimension, bumped to the next odd
        # integer when even (sizex < 10 yields a kernel of 1).
        augmentation.append(
            K.RandomGaussianBlur(
                kernel_size=sizex // 10 + (sizex // 10 + 1) % 2,
                sigma=(0.1, 2.0),
                p=0.5,
            )
        )
    return augmentation


def get_augmentation(
    cfg,
    low_frequency_samples=None,
    size=224,
    eval=False,
    n_context_augs=None,
    name=None,
):
    """Build the list of augmentation modules for a named augmentation.

    Args:
        cfg: Augmentation config. ``cfg.name`` is used when ``name`` is None.
            Other fields read depend on the augmentation, e.g.
            ``max_180_degrees``/``angle_offset`` for ``"random_rotation"``,
            ``s``/``use_blur``/``use_color_jitter``/``use_horizontal_flip`` for
            ``"simclr_augmentations"``, and ``clahe`` for ``"random_equalize"``.
        low_frequency_samples: Unused in this function.
        size: Output size forwarded to the SimCLR augmentations.
        eval: Build the evaluation-time variant where one exists.
        n_context_augs: Number of context augmentations. Required for
            ``"random_rotation"`` in eval mode, and checked for the combined
            context augmentations.
        name: Augmentation name. Defaults to ``cfg.name``.

    Returns:
        A list of kornia augmentation modules.

    Raises:
        ValueError: If the augmentation name is not recognized.
    """
    # Resolve the name first, then dispatch on it. These must be separate
    # statements so that a name taken from cfg.name is also dispatched.
    if name is None:
        name = cfg.name

    if name == "random_rotation":
        if eval:
            if cfg.max_180_degrees:
                # Evenly spaced angles covering [0, 180] inclusive.
                angles = [
                    180 * i / (n_context_augs - 1) if n_context_augs != 1 else 0
                    for i in range(n_context_augs)
                ]
            else:
                # Evenly spaced angles over the full circle, shifted by an offset.
                angles = [
                    (360 * i / n_context_augs + cfg.angle_offset) % 360
                    for i in range(n_context_augs)
                ]
            # Each fixed-angle rotation is one candidate; random_apply=1 picks one.
            augmentation = [
                K.ImageSequential(
                    *[
                        K.RandomRotation(
                            degrees=(angle, angle),
                            p=1,
                        )
                        for angle in angles
                    ],
                    random_apply=1,
                )
            ]
        else:
            augmentation = [
                K.RandomRotation(degrees=180, p=1),
            ]
    elif name == "simclr_augmentations":
        s = cfg.s
        scale_lb = cfg.get("scale_lb", 0.08)
        scale = (scale_lb, 1.0)
        use_blur = cfg.use_blur
        cfg_color_jitter = (
            cfg.get(
                "color_jitter_params", SimclrAugmentationConfig().color_jitter_params
            )
            if cfg.use_color_jitter
            else False
        )
        use_horizontal_flip = cfg.use_horizontal_flip
        if eval and cfg.get("csi_eval", True):
            # CSI proposes to make augmentations less strong during inference
            s = 0.5
            scale = (0.54, 1.0)
            use_blur = False
        augmentation = get_simclr_transforms(
            size=size,
            s=s,
            scale=scale,
            use_blur=use_blur,
            use_color_jitter=cfg_color_jitter,
            use_horizontal_flip=use_horizontal_flip,
        )
    elif name == "random_flip":
        augmentation = [ContextFlip()]
    elif name == "random_invert":
        augmentation = [ContextInvert()]
    elif name == "random_equalize":
        augmentation = [ContextEqualize(clahe=cfg.clahe)]
    elif name == "random_flip_invert":
        assert n_context_augs == 4, f"{name} requires n_context_augs == 4"
        augmentation = [ContextFlipInvert()]
    elif name == "random_flip_equalize":
        assert n_context_augs == 4, f"{name} requires n_context_augs == 4"
        augmentation = [ContextFlipEqualize()]
    elif name == "random_invert_equalize":
        assert n_context_augs == 4, f"{name} requires n_context_augs == 4"
        augmentation = [ContextInvertEqualize()]
    elif name == "random_flip_invert_equalize":
        assert n_context_augs <= 8, f"{name} requires n_context_augs <= 8"
        augmentation = [ContextFlipInvertEqualize()]
    else:
        raise ValueError(f"Unknown augmentation name: {name!r}")
    return augmentation


def _equal_tensor_dicts(dict1, dict2, same_on_batch):
    """Return True if two name -> tensor dicts have the same keys and equal tensors.

    With ``same_on_batch`` only the first batch element of each tensor is
    compared, and the ``"forward_input_shape"`` entry is skipped. It must still
    be present in both dicts.
    """
    if dict1.keys() != dict2.keys():
        return False
    for key, value in dict1.items():
        other = dict2[key]
        if same_on_batch:
            if key != "forward_input_shape" and not torch.equal(value[0], other[0]):
                return False
        elif not torch.equal(value, other):
            return False
    return True


def _equal_param_lists(list1, list2, same_on_batch):
    """Return True if two parameter lists have equal length and pairwise-equal items."""
    return len(list1) == len(list2) and all(
        equal_augmentation_params(item1, item2, same_on_batch=same_on_batch)
        for item1, item2 in zip(list1, list2)
    )


def equal_augmentation_params(params1, params2, same_on_batch=False):
    """Return True if two augmentation parameter structures are identical.

    A structure is a list of structures, a dict mapping parameter names to
    tensors, or an object exposing ``.name`` and ``.data`` (e.g. kornia's
    ``ParamItem``) whose ``data`` is a list or a dict. The comparison is
    symmetric: lists must have equal length and dicts identical key sets.

    Args:
        params1: First parameter structure.
        params2: Second parameter structure.
        same_on_batch: Compare only the first batch element of each tensor and
            ignore ``"forward_input_shape"``.

    Returns:
        True when the structures match and every compared tensor is equal.

    Raises:
        TypeError: If an item's ``data`` is neither a list nor a dict.
    """
    if type(params1) != type(params2):
        return False
    if isinstance(params1, list):
        return _equal_param_lists(params1, params2, same_on_batch)
    if isinstance(params1, dict):
        return _equal_tensor_dicts(params1, params2, same_on_batch)
    if params1.name != params2.name:
        return False
    data1, data2 = params1.data, params2.data
    if isinstance(data1, list):
        return isinstance(data2, list) and _equal_param_lists(
            data1, data2, same_on_batch
        )
    if isinstance(data1, dict):
        return isinstance(data2, dict) and _equal_tensor_dicts(
            data1, data2, same_on_batch
        )
    raise TypeError(f"Unexpected type {type(data1)}")


def any_equal_augmentation_params(params1, params2, same_on_batch=False):
    """Loosely compare two augmentation parameter structures.

    Rules per structure type:

    * lists: walked up to the shorter length. Every aligned pair must match.
    * dicts (name -> tensor): every key of ``params1`` other than
      ``"batch_prob"`` and ``"forward_input_shape"`` must be present in
      ``params2`` with an equal tensor (only the first batch element when
      ``same_on_batch``).
    * ``.name``/``.data`` items (e.g. kornia ``ParamItem``): names must match.
      A list ``data`` matches if ANY aligned pair of sub-items matches. A dict
      ``data`` follows the dict rule above.

    Args:
        params1: First parameter structure.
        params2: Second parameter structure.
        same_on_batch: Compare only the first batch element of each tensor.

    Returns:
        True or False according to the rules above.

    Raises:
        TypeError: If an item's ``data`` is neither a list nor a dict.
    """
    ignored_keys = ("batch_prob", "forward_input_shape")

    def tensors_match(dict1, dict2):
        # Every key is evaluated (no early exit) before the verdicts are combined.
        verdicts = [
            key in dict2
            and (
                torch.equal(dict1[key][0], dict2[key][0])
                if same_on_batch
                else torch.equal(dict1[key], dict2[key])
            )
            for key in dict1
            if key not in ignored_keys
        ]
        return all(verdicts)

    if type(params1) != type(params2):
        return False
    if isinstance(params1, list):
        # zip() stops at the shorter list.
        return all(
            any_equal_augmentation_params(left, right, same_on_batch=same_on_batch)
            for left, right in zip(params1, params2)
        )
    if isinstance(params1, dict):
        return tensors_match(params1, params2)
    if params1.name != params2.name:
        return False

    data1, data2 = params1.data, params2.data
    if isinstance(data1, list):
        # All aligned sub-items are compared before deciding whether any matched.
        verdicts = [
            any_equal_augmentation_params(
                data1[idx], data2[idx], same_on_batch=same_on_batch
            )
            for idx in range(min(len(data1), len(data2)))
        ]
        return any(verdicts)
    if isinstance(data1, dict):
        return tensors_match(data1, data2)
    raise TypeError(f"Unexpected type {type(params1.data)}")


class ContextFlip(K.IntensityAugmentationBase2D):
    """Context augmentation that vertically flips a labelled subset of a batch.

    Per-sample context labels select the transform:

    * ``1``: vertically flipped (``K_geom.vflip``)
    * any other value (e.g. ``0``): unchanged

    Labels are taken from ``params["labels"]`` when present. Otherwise they are
    drawn uniformly from ``{0, 1}`` for each sample.
    """

    def __init__(
        self,
        same_on_batch: bool = False,
        p: float = 1.0,
        keepdim: bool = False,
    ):
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
        # Number of distinct context labels this transform handles.
        self.n_context_augs = 2
        # Vertical flip (attribute name kept as "vlip").
        self.vlip = K_geom.vflip

        # Kornia expects a parameter generator even though this transform has
        # no random parameters of its own.
        self._param_generator = K.random_generator.PlainUniformGenerator()

    def _context_labels(self, input, params):
        """Return per-sample float labels on ``input``'s device."""
        if "labels" not in params:
            drawn = torch.randint(high=self.n_context_augs, size=(input.shape[0],))
            return drawn.to(input.device).float()
        return params["labels"].to(input.device).float()

    def apply_transform(self, input, params, flags, transform=None):
        labels = self._context_labels(input, params)
        output = input.clone()
        mask = labels == 1.0
        chosen = output[mask]
        if chosen.shape[0] != 0:
            output[mask] = self.vlip(chosen)
        return output


class ContextInvert(K.IntensityAugmentationBase2D):
    """Context augmentation that inverts a labelled subset of a batch.

    Per-sample context labels select the transform:

    * ``1``: inverted with ``K.RandomInvert(p=1)`` (always applied)
    * any other value (e.g. ``0``): unchanged

    Labels are taken from ``params["labels"]`` when present. Otherwise they are
    drawn uniformly from ``{0, 1}`` for each sample.
    """

    def __init__(
        self,
        same_on_batch: bool = False,
        p: float = 1.0,
        keepdim: bool = False,
    ):
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
        # Number of distinct context labels this transform handles.
        self.n_context_augs = 2
        self.invert = K.RandomInvert(p=1)

        # Kornia expects a parameter generator even though this transform has
        # no random parameters of its own.
        self._param_generator = K.random_generator.PlainUniformGenerator()

    def _context_labels(self, input, params):
        """Return per-sample float labels on ``input``'s device."""
        if "labels" not in params:
            drawn = torch.randint(high=self.n_context_augs, size=(input.shape[0],))
            return drawn.to(input.device).float()
        return params["labels"].to(input.device).float()

    def apply_transform(self, input, params, flags, transform=None):
        labels = self._context_labels(input, params)
        output = input.clone()
        mask = labels == 1.0
        chosen = output[mask]
        if chosen.shape[0] != 0:
            output[mask] = self.invert(chosen)
        return output


class ContextEqualize(K.IntensityAugmentationBase2D):
    """Context augmentation that histogram-equalizes a labelled subset of a batch.

    Per-sample context labels select the transform:

    * ``1``: equalized (``K_trans.equalize``, or CLAHE when ``clahe=True``)
    * any other value (e.g. ``0``): unchanged

    Labels are taken from ``params["labels"]`` when present. Otherwise they are
    drawn uniformly from ``{0, 1}`` for each sample.
    """

    def __init__(
        self,
        clahe=False,
        same_on_batch: bool = False,
        p: float = 1.0,
        keepdim: bool = False,
    ):
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
        # Number of distinct context labels this transform handles.
        self.n_context_augs = 2
        self.equalize = self._equalize_clahe if clahe else K_trans.equalize

        # Kornia expects a parameter generator even though this transform has
        # no random parameters of its own.
        self._param_generator = K.random_generator.PlainUniformGenerator()

    @staticmethod
    def _equalize_clahe(x):
        """CLAHE whose grid adapts to the image width.

        clip_limit is 1/40 to account for normalized images (the default of 40
        is meant for images in [0, 255]).
        """
        cells = int(math.log2(x.shape[-1]))
        return K_trans.equalize_clahe(
            x,
            grid_size=(cells, cells),
            clip_limit=1.0 / 40.0,
        )

    def _context_labels(self, input, params):
        """Return per-sample float labels on ``input``'s device."""
        if "labels" not in params:
            drawn = torch.randint(high=self.n_context_augs, size=(input.shape[0],))
            return drawn.to(input.device).float()
        return params["labels"].to(input.device).float()

    def apply_transform(self, input, params, flags, transform=None):
        labels = self._context_labels(input, params)
        output = input.clone()
        mask = labels == 1.0
        chosen = output[mask]
        if chosen.shape[0] != 0:
            output[mask] = self.equalize(chosen)
        return output


class ContextFlipInvert(K.IntensityAugmentationBase2D):
    """Context augmentation combining inversion and vertical flipping.

    Per-sample context labels select the transform:

    * ``1``: inverted
    * ``2``: vertically flipped
    * ``3``: inverted, then vertically flipped
    * any other value (e.g. ``0``): unchanged

    Labels are taken from ``params["labels"]`` when present. Otherwise they are
    drawn uniformly from ``{0, 1, 2, 3}`` for each sample.
    """

    def __init__(
        self,
        same_on_batch: bool = False,
        p: float = 1.0,
        keepdim: bool = False,
    ):
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
        # Number of distinct context labels this transform handles.
        self.n_context_augs = 4
        self.invert = K.RandomInvert(p=1)
        # Vertical flip (attribute name kept as "vlip").
        self.vlip = K_geom.vflip

        # Kornia expects a parameter generator even though this transform has
        # no random parameters of its own.
        self._param_generator = K.random_generator.PlainUniformGenerator()

    def _context_labels(self, input, params):
        """Return per-sample float labels on ``input``'s device."""
        if "labels" not in params:
            drawn = torch.randint(high=self.n_context_augs, size=(input.shape[0],))
            return drawn.to(input.device).float()
        return params["labels"].to(input.device).float()

    def apply_transform(self, input, params, flags, transform=None):
        labels = self._context_labels(input, params)
        invert, flip = self.invert, self.vlip
        ops = (
            (1.0, invert),
            (2.0, flip),
            (3.0, lambda x: flip(invert(x))),
        )
        output = input.clone()
        for label, op in ops:
            mask = labels == label
            chosen = output[mask]
            if chosen.shape[0] != 0:
                output[mask] = op(chosen)
        return output


class ContextFlipEqualize(K.IntensityAugmentationBase2D):
    """Context augmentation combining histogram equalization and vertical flipping.

    Per-sample context labels select the transform:

    * ``1``: equalized (``K_trans.equalize``)
    * ``2``: vertically flipped
    * ``3``: equalized, then vertically flipped
    * any other value (e.g. ``0``): unchanged

    Labels are taken from ``params["labels"]`` when present. Otherwise they are
    drawn uniformly from ``{0, 1, 2, 3}`` for each sample.
    """

    def __init__(
        self,
        same_on_batch: bool = False,
        p: float = 1.0,
        keepdim: bool = False,
    ):
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
        # Number of distinct context labels this transform handles.
        self.n_context_augs = 4
        self.equalize = K_trans.equalize
        # Vertical flip (attribute name kept as "vlip").
        self.vlip = K_geom.vflip

        # Kornia expects a parameter generator even though this transform has
        # no random parameters of its own.
        self._param_generator = K.random_generator.PlainUniformGenerator()

    def _context_labels(self, input, params):
        """Return per-sample float labels on ``input``'s device."""
        if "labels" not in params:
            drawn = torch.randint(high=self.n_context_augs, size=(input.shape[0],))
            return drawn.to(input.device).float()
        return params["labels"].to(input.device).float()

    def apply_transform(self, input, params, flags, transform=None):
        labels = self._context_labels(input, params)
        equalize, flip = self.equalize, self.vlip
        ops = (
            (1.0, equalize),
            (2.0, flip),
            (3.0, lambda x: flip(equalize(x))),
        )
        output = input.clone()
        for label, op in ops:
            mask = labels == label
            chosen = output[mask]
            if chosen.shape[0] != 0:
                output[mask] = op(chosen)
        return output


class ContextInvertEqualize(K.IntensityAugmentationBase2D):
    """Context augmentation combining inversion and histogram equalization.

    Per-sample context labels select the transform:

    * ``1``: inverted
    * ``2``: equalized (``K_trans.equalize``)
    * ``3``: equalized, then inverted
    * any other value (e.g. ``0``): unchanged

    Labels are taken from ``params["labels"]`` when present. Otherwise they are
    drawn uniformly from ``{0, 1, 2, 3}`` for each sample.
    """

    def __init__(
        self,
        same_on_batch: bool = False,
        p: float = 1.0,
        keepdim: bool = False,
    ):
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
        # Number of distinct context labels this transform handles.
        self.n_context_augs = 4
        self.invert = K.RandomInvert(p=1)
        self.equalize = K_trans.equalize

        # Kornia expects a parameter generator even though this transform has
        # no random parameters of its own.
        self._param_generator = K.random_generator.PlainUniformGenerator()

    def _context_labels(self, input, params):
        """Return per-sample float labels on ``input``'s device."""
        if "labels" not in params:
            drawn = torch.randint(high=self.n_context_augs, size=(input.shape[0],))
            return drawn.to(input.device).float()
        return params["labels"].to(input.device).float()

    def apply_transform(self, input, params, flags, transform=None):
        labels = self._context_labels(input, params)
        invert, equalize = self.invert, self.equalize
        ops = (
            (1.0, invert),
            (2.0, equalize),
            (3.0, lambda x: invert(equalize(x))),
        )
        output = input.clone()
        for label, op in ops:
            mask = labels == label
            chosen = output[mask]
            if chosen.shape[0] != 0:
                output[mask] = op(chosen)
        return output


class ContextFlipInvertEqualize(K.IntensityAugmentationBase2D):
    """Context augmentation over all combinations of flip, invert and equalize.

    Per-sample context labels select the transform:

    * ``1``: vertically flipped
    * ``2``: inverted
    * ``3``: equalized (``K_trans.equalize``)
    * ``4``: inverted, then flipped
    * ``5``: equalized, then flipped
    * ``6``: equalized, then inverted
    * ``7``: equalized, then inverted, then flipped
    * any other value (e.g. ``0``): unchanged

    Labels are taken from ``params["labels"]`` when present. Otherwise they are
    drawn uniformly from ``{0, ..., 7}`` for each sample.
    """

    def __init__(
        self,
        same_on_batch: bool = False,
        p: float = 1.0,
        keepdim: bool = False,
    ):
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
        # Number of distinct context labels this transform handles.
        self.n_context_augs = 8
        self.invert = K.RandomInvert(p=1)
        self.equalize = K_trans.equalize
        # Vertical flip (attribute name kept as "vlip").
        self.vlip = K_geom.vflip

        # Kornia expects a parameter generator even though this transform has
        # no random parameters of its own.
        self._param_generator = K.random_generator.PlainUniformGenerator()

    def _context_labels(self, input, params):
        """Return per-sample float labels on ``input``'s device."""
        if "labels" not in params:
            drawn = torch.randint(high=self.n_context_augs, size=(input.shape[0],))
            return drawn.to(input.device).float()
        return params["labels"].to(input.device).float()

    def apply_transform(self, input, params, flags, transform=None):
        labels = self._context_labels(input, params)
        flip, invert, equalize = self.vlip, self.invert, self.equalize
        ops = (
            (1.0, flip),
            (2.0, invert),
            (3.0, equalize),
            (4.0, lambda x: flip(invert(x))),
            (5.0, lambda x: flip(equalize(x))),
            (6.0, lambda x: invert(equalize(x))),
            (7.0, lambda x: flip(invert(equalize(x)))),
        )
        output = input.clone()
        for label, op in ops:
            mask = labels == label
            chosen = output[mask]
            if chosen.shape[0] != 0:
                output[mask] = op(chosen)
        return output
