import abc
import logging
import typing

import jax
import jax.numpy as jnp
import numpy as np
import torch

import architectures
import eval.data as data
import unrolled_canaries


class BaselineCanaryGenerator(metaclass=abc.ABCMeta):
    """
    Abstract interface shared by the canary generators in this module.

    Concrete generators implement :meth:`generate` and return a pair of tensors
    ``(canaries, targets)``.
    """

    @abc.abstractmethod
    def generate(
        self,
        num_canaries: int,
        image_shape: tuple[int, int, int],
        num_classes: int,
        replaced_images: torch.Tensor,
        replaced_targets: torch.Tensor,
        global_seed: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Produce canary images together with their targets.

        :param num_canaries: number of canaries to produce.
        :param image_shape: shape of a single image.
        :param num_classes: number of classes targets may be drawn from.
        :param replaced_images: existing images a generator may sample canaries from.
        :param replaced_targets: targets belonging to ``replaced_images``.
        :param global_seed: seed for the generator's random number generator.
        :return: tuple ``(canaries, targets)``.
        """


class RandomCanaryGenerator(BaselineCanaryGenerator):
    """
    Generator whose canaries are uniformly random pixel noise with uniformly random targets.
    """

    def generate(
        self,
        num_canaries: int,
        image_shape: tuple[int, int, int],
        num_classes: int,
        replaced_images: torch.Tensor,
        replaced_targets: torch.Tensor,
        global_seed: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Draw ``num_canaries`` random images and targets.

        Pixels are integers drawn uniformly from [0, 255] and rescaled to float32 in [0, 1];
        targets are integers drawn uniformly from [0, ``num_classes``).

        :param num_canaries: number of canaries to produce; zero yields empty tensors.
        :param image_shape: shape of a single image.
        :param num_classes: exclusive upper bound for the targets.
        :param replaced_images: unused by this generator.
        :param replaced_targets: unused by this generator.
        :param global_seed: seed of the NumPy generator all draws derive from.
        :return: tuple ``(canaries, targets)`` with canaries of shape
            ``(num_canaries, *image_shape)`` (float32) and targets of shape
            ``(num_canaries,)`` (int64).
        :raises ValueError: if ``num_canaries`` is negative.
        """
        if num_canaries < 0:
            raise ValueError(f"Number of canaries must be non-negative, got {num_canaries}")

        rng = np.random.default_rng(global_seed)

        # Preallocate instead of stacking a list so that num_canaries == 0 yields
        # correctly shaped and typed empty tensors rather than an error.
        pixels = np.empty((num_canaries, *image_shape), dtype=np.int64)
        labels = np.empty((num_canaries,), dtype=np.int64)
        for idx in range(num_canaries):
            # Each canary gets its own child generator; the image is drawn before the target.
            (rng_current,) = rng.spawn(1)
            pixels[idx] = rng_current.integers(0, 256, size=image_shape)
            labels[idx] = rng_current.integers(0, num_classes)

        canaries = torch.from_numpy(pixels).to(torch.float32) / 255.0
        targets = torch.from_numpy(labels)

        return canaries, targets


class InDistributionCanaryGenerator(BaselineCanaryGenerator):
    """
    Generator whose canaries are existing samples picked at random, with their targets unchanged.
    """

    def generate(
        self,
        num_canaries: int,
        image_shape: tuple[int, int, int],
        num_classes: int,
        replaced_images: torch.Tensor,
        replaced_targets: torch.Tensor,
        global_seed: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Pick ``num_canaries`` distinct samples from ``replaced_images``.

        :param num_canaries: number of canaries to pick.
        :param image_shape: unused by this generator.
        :param num_classes: unused by this generator.
        :param replaced_images: images to pick from; indexed along the first axis.
        :param replaced_targets: targets to pick from, indexed with the same positions.
        :param global_seed: seed of the NumPy generator used for the pick.
        :return: tuple ``(canaries, targets)`` holding the picked images and targets.
        :raises ValueError: if ``num_canaries`` exceeds the number of available images.
        """
        num_available = replaced_images.shape[0]
        if num_available < num_canaries:
            raise ValueError(
                f"Requested more canaries ({num_canaries}) than the available number of samples ({replaced_images.shape[0]})"
            )

        # Positions are drawn without replacement, so no sample is picked twice
        picked = np.random.default_rng(global_seed).choice(num_available, size=num_canaries, replace=False)

        return replaced_images[picked], replaced_targets[picked]


class MislabeledCanaryGenerator(BaselineCanaryGenerator):
    """
    Generator whose canaries are existing samples picked at random, each given a random wrong target.
    """

    def generate(
        self,
        num_canaries: int,
        image_shape: tuple[int, int, int],
        num_classes: int,
        replaced_images: torch.Tensor,
        replaced_targets: torch.Tensor,
        global_seed: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Pick ``num_canaries`` distinct samples and replace each target by a different class.

        :param num_canaries: number of canaries to pick.
        :param image_shape: unused by this generator.
        :param num_classes: number of classes the new targets are drawn from.
        :param replaced_images: images to pick from; indexed along the first axis.
        :param replaced_targets: original targets, indexed with the same positions.
        :param global_seed: seed of the NumPy generator the two child generators are spawned from.
        :return: tuple ``(canaries, targets)`` where every target differs from the original one.
        :raises ValueError: if ``num_canaries`` exceeds the number of available images.
        """
        num_available = replaced_images.shape[0]
        if num_available < num_canaries:
            raise ValueError(
                f"Requested more canaries ({num_canaries}) than the available number of samples ({replaced_images.shape[0]})"
            )

        # Two child generators: the first picks the samples, the second picks the new labels
        index_stream, label_stream = np.random.default_rng(global_seed).spawn(2)

        picked = index_stream.choice(num_available, size=num_canaries, replace=False)
        true_labels = replaced_targets[picked]

        # Draw from num_classes - 1 values and shift every draw that is not below the
        # true label up by one, so the true label itself can never be produced.
        draws = torch.from_numpy(label_stream.integers(0, num_classes - 1, size=num_canaries))
        targets = torch.where(draws < true_labels, draws, draws + 1)
        assert torch.all((targets >= 0) & (targets < num_classes))
        assert torch.all(targets != true_labels)

        return replaced_images[picked], targets


class AdversarialCanaryGenerator(BaselineCanaryGenerator):
    def __init__(
        self,
        learning_rate: float,
        momentum: float,
        num_epochs: int,
        batch_size: int,
        standardize: bool,
        sample_non_canaries: bool,
        mlp_width: int,
        attack_eps: float,
        attack_norm: typing.Literal["linf", "l2"],
        attack_num_models: int,
        attack_aggregation: typing.Literal["mean", "min"],
        dataset_loader: data.DatasetLoader,
    ):
        """
        Store the training and attack configuration used by :meth:`generate`.

        :param learning_rate: learning rate passed to ``unrolled_canaries.create_train_state``.
        :param momentum: momentum passed to ``unrolled_canaries.create_train_state``.
        :param num_epochs: number of epochs passed to ``unrolled_canaries.generate_model_perms``.
        :param batch_size: batch size passed to ``unrolled_canaries.generate_model_perms``.
        :param standardize: whether training images are standardized with the dataset mean/std.
        :param sample_non_canaries: flag passed to ``unrolled_canaries.generate_model_perms``.
        :param mlp_width: width of the ``architectures.MLP`` that is trained.
        :param attack_eps: epsilon passed to :meth:`fast_gradient_method_ensemble`.
        :param attack_norm: ``"linf"`` or ``"l2"``; stored as ``jnp.inf`` or ``2``.
        :param attack_num_models: number of models in the attacked ensemble.
        :param attack_aggregation: ``"mean"`` or ``"min"``; how per-model losses are combined.
        :param dataset_loader: source of the training and validation data.
        :raises ValueError: if ``attack_norm`` or ``attack_aggregation`` is not a supported value.
        """
        # Validate the string options before building anything, so a typo fails here
        # rather than inside generate() after the ensemble has already been trained.
        if attack_norm == "linf":
            norm_order = jnp.inf
        elif attack_norm == "l2":
            norm_order = 2
        else:
            raise ValueError(f"Invalid attack norm: {attack_norm}")
        if attack_aggregation not in ("mean", "min"):
            raise ValueError(f"Invalid aggregation method: {attack_aggregation}")

        self.learning_rate = learning_rate
        self.momentum = momentum
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.standardize = standardize
        self.sample_non_canaries = sample_non_canaries
        self.attack_eps = attack_eps
        self.attack_norm = norm_order
        self.attack_num_models = attack_num_models
        self.attack_aggregation = attack_aggregation
        self.dataset_loader = dataset_loader

        self.architecture = architectures.MLP(width=mlp_width)

        self._log = logging.getLogger(self.__class__.__name__)

    def generate(
        self,
        num_canaries: int,
        image_shape: tuple[int, int, int],
        num_classes: int,
        replaced_images: torch.Tensor,
        replaced_targets: torch.Tensor,
        global_seed: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Craft canaries as adversarial examples against an ensemble of freshly trained models.

        Steps: train ``attack_num_models`` models on the data from
        ``dataset_loader.load_train_data()``, perturb every image from
        ``dataset_loader.load_val_data()`` (called "test" images below) with
        :meth:`fast_gradient_method_ensemble`, keep the images whose prediction changes
        for every model, and randomly pick ``num_canaries`` of them. The canary targets
        are the original clean labels of the picked images.

        :param num_canaries: number of canaries to return.
        :param image_shape: unused by this generator.
        :param num_classes: number of classes; only used to sanity-check the returned targets.
        :param replaced_images: unused by this generator.
        :param replaced_targets: unused by this generator.
        :param global_seed: seed of the JAX PRNG key that model initialization, epoch
            permutations, and the final canary pick are derived from.
        :return: tuple ``(canaries, targets)``; canaries are float32 in NCHW layout with
            values in [0, 1].
        :raises ValueError: if fewer than ``num_canaries`` images are adversarial for all models.
        """
        # Load raw data
        self._log.info("Loading raw data")
        self.dataset_loader.prepare_raw_data()
        train_images_full, train_targets_full = self.dataset_loader.load_train_data()
        test_images, test_targets = self.dataset_loader.load_val_data()

        # Train ensemble of victim models
        # All randomness below is derived from this single key
        key = jax.random.PRNGKey(global_seed)

        # Prepare images
        images_mean_std = self.dataset_loader.dataset_mean_std
        dataset_mean = jnp.array(images_mean_std[0], dtype=jnp.float32)
        dataset_std = jnp.array(images_mean_std[1], dtype=jnp.float32)

        def standardize_images(images):
            # Statistics are broadcast along the last (channel) axis of NHWC images
            return (images - dataset_mean.reshape(1, 1, 1, -1)) / dataset_std.reshape(1, 1, 1, -1)

        def prepare_data(images, targets, standardize=True):
            # Move training data to jax
            # Code expects (N, H, W, C) in [0, 1], but images are torch (N, C, H, W)
            images = jnp.array(images.permute(0, 2, 3, 1))
            if targets is not None:
                targets = jnp.array(targets, dtype=jnp.int16)

            # Standardize with given statistics
            if standardize:
                images = standardize_images(images)

            return images, targets

        images_train, targets_train = prepare_data(train_images_full, train_targets_full, standardize=self.standardize)

        # Initialize models
        key, *keys_init = jax.random.split(key, self.attack_num_models + 1)
        image_shape_hwc = images_train.shape[1:]
        init_state = unrolled_canaries.create_train_state(
            keys_init,
            learning_rate=self.learning_rate,
            momentum=self.momentum,
            num_models=self.attack_num_models,
            arch=self.architecture,
            image_shape=image_shape_hwc,
            dp_params=None,
        )
        del keys_init

        # Build epoch permutations
        key, key_perms = jax.random.split(key)
        epoch_perms = unrolled_canaries.generate_model_perms(
            key=key_perms,
            num_models=self.attack_num_models,
            num_epochs=self.num_epochs,
            num_samples=images_train.shape[0],
            batch_size=self.batch_size,
            sample_non_canaries=self.sample_non_canaries,
            canary_idx=None,
        )
        del key_perms

        # Train
        trained_state, _, _, _ = unrolled_canaries.train_model(
            state=init_state,
            epoch_perms=epoch_perms,
            train_images=images_train,
            train_targets=targets_train,
            test_images=None,
            test_targets=None,
            use_dp=False,
            verbose=False,
        )

        # Generate untargeted adversarial examples for all test images
        # All inputs to the adversarial attack are unstandardized to not mess up eps; things will get standardized on the fly
        images_test, targets_test = prepare_data(test_images, test_targets, standardize=False)

        def model_fn(images):
            images = standardize_images(images)
            return unrolled_canaries.flatten_over_models(unrolled_canaries.get_logits(trained_state, images))

        # Find adversarial examples for every test image (untargeted, ensemble)
        images_adversarial = self.fast_gradient_method_ensemble(
            model_fn,
            x=images_test,
            eps=self.attack_eps,
            norm=self.attack_norm,
            clip_min=0.0,
            clip_max=1.0,
            aggregation=self.attack_aggregation,
        )

        # For every model x sample pair, determine if the adversarial example has a flipped label
        logits_original = unrolled_canaries.flatten_over_models(
            unrolled_canaries.get_logits(trained_state, standardize_images(images_test))
        )
        preds_original = jnp.argmax(logits_original, axis=-1)

        logits_adversarial = unrolled_canaries.flatten_over_models(
            unrolled_canaries.get_logits(trained_state, standardize_images(images_adversarial))
        )
        assert logits_adversarial.shape == logits_original.shape
        preds_adversarial = jnp.argmax(logits_adversarial, axis=-1)
        assert preds_adversarial.shape == preds_original.shape

        test_accuracy_original = jnp.mean(preds_original == targets_test)
        test_accuracy_adversarial = jnp.mean(preds_adversarial == targets_test)

        self._log.info("Test accuracy original: %.4f", test_accuracy_original)
        self._log.info("Test accuracy adversarial: %.4f", test_accuracy_adversarial)

        # Keep only the samples whose prediction changed for every model (reduction over axis 0)
        preds_mismatch = preds_adversarial != preds_original
        intersection_indices = jnp.argwhere(jnp.all(preds_mismatch, axis=0))[:, 0]
        num_candidates = len(intersection_indices)

        self._log.info("Found %d test images that are adversarial for all models", num_candidates)

        # Fail with a clear message instead of an opaque error further down
        if num_canaries > num_candidates:
            raise ValueError(
                f"Requested more canaries ({num_canaries}) than the available number of adversarial examples ({num_candidates})"
            )

        # Convert images back to expected format
        selected_images = images_adversarial[intersection_indices]
        # No need to unstandardize, because we never standardized the adversarial images

        # Convert to torch float32 and NHWC to NCHW
        selected_images = torch.from_numpy(np.array(selected_images)).to(torch.float32)
        selected_images = selected_images.permute((0, 3, 1, 2))
        assert selected_images.shape == (num_candidates, *train_images_full.shape[1:])
        assert isinstance(selected_images, torch.Tensor)

        # Use original clean labels as targets for the canary
        selected_targets = test_targets[np.array(intersection_indices)]
        assert isinstance(selected_targets, torch.Tensor)
        assert selected_targets.shape == (num_candidates,)
        assert selected_targets.dtype == torch.int64

        # min()/max() are undefined on empty tensors, so only range-check non-empty selections
        if num_candidates > 0:
            assert selected_images.min() >= 0 and selected_images.max() <= 1
            assert selected_targets.min() >= 0 and selected_targets.max() < num_classes

        # Finally, select a subset of the canaries
        key, key_canaries = jax.random.split(key)
        canary_indices = np.array(
            jax.random.choice(
                key_canaries,
                len(selected_targets),
                shape=(num_canaries,),
                replace=False,
            )
        )
        canaries = selected_images[canary_indices]
        targets = selected_targets[canary_indices]

        return canaries, targets

    @classmethod
    def fast_gradient_method_ensemble(
        cls,
        model_fn,
        x,
        eps,
        norm,
        clip_min=None,
        clip_max=None,
        y=None,
        targeted=False,
        aggregation: typing.Literal["mean", "min"] = "min",
    ):
        # TODO: This is from cleverhans (modified to work with ensembles); include MIT license!
        """
        JAX implementation of the Fast Gradient Method with added ensemble support.
        :param model_fn: a callable that takes a batch of inputs and returns the logits of all
                models, with the sample axis at position 1 and the class axis last.
        :param x: input tensor; the first axis indexes samples.
        :param eps: epsilon (input variation parameter); see https://arxiv.org/abs/1412.6572.
                For norm np.inf every component moves by eps; for norm 2 the perturbation of
                each sample has L2 norm eps.
        :param norm: Order of the norm (mimics NumPy). Possible values: np.inf or 2.
        :param clip_min: (optional) float. Minimum float value for adversarial example components.
                Must be given together with clip_max.
        :param clip_max: (optional) float. Maximum float value for adversarial example components.
                Must be given together with clip_min.
        :param y: (optional) Tensor with one-hot true labels, laid out like the output of model_fn
                (sample axis at position 1). If targeted is true, then provide the
                target one-hot label. Otherwise, only provide this parameter if you'd like to use true
                labels when crafting adversarial samples. Otherwise, model predictions are used
                as labels to avoid the "label leaking" effect (explained in this paper:
                https://arxiv.org/abs/1611.01236). Default is None. This argument does not have
                to be a binary one-hot label (e.g., [0, 1, 0, 0]), it can be floating points values
                that sum up to 1 (e.g., [0.05, 0.85, 0.05, 0.05]).
        :param targeted: (optional) bool. Is the attack targeted or untargeted?
                Untargeted, the default, will try to make the label incorrect.
                Targeted will instead try to move in the direction of being more like y.
        :param aggregation: (optional) str. How to aggregate the losses across models
                ("mean" or "min").
        :return: a tensor for the adversarial example
        :raises ValueError: if norm or aggregation is unsupported, or if only one of
                clip_min / clip_max is given.
        """
        # Validate all options up front, before any model evaluation
        if norm not in [jnp.inf, 2]:
            raise ValueError("Norm order must be either np.inf or 2.")
        if aggregation not in ("mean", "min"):
            raise ValueError(f"Invalid aggregation method: {aggregation}")
        if (clip_min is None) != (clip_max is None):
            # We don't currently support one-sided clipping
            raise ValueError("One-sided clipping is not supported; provide both clip_min and clip_max.")

        if y is None:
            # Using model predictions as ground truth to avoid label leaking
            logits = model_fn(x)
            # Take the number of classes from the logits instead of hard-coding it
            y = jax.nn.one_hot(jnp.argmax(logits, -1), logits.shape[-1])

        reduce_losses = jnp.mean if aggregation == "mean" else jnp.min

        def loss_adv(image, label):
            # model_fn works on batches: add a singleton sample axis, then drop it from the output
            pred = model_fn(image[None])[:, 0]
            per_model_losses = -jnp.sum(jax.nn.log_softmax(pred) * label, axis=-1)
            if targeted:
                per_model_losses = -per_model_losses
            return reduce_losses(per_model_losses)

        # x is mapped over its first axis, y over axis 1 (its sample axis)
        grads = jax.vmap(jax.grad(loss_adv), in_axes=(0, 1), out_axes=0)(x, y)

        if norm == jnp.inf:
            perturbation = eps * jnp.sign(grads)
        else:
            # norm == 2: rescale each sample's gradient to L2 norm eps
            avoid_zero_div = 1e-12
            sample_axes = tuple(range(1, grads.ndim))
            square = jnp.maximum(avoid_zero_div, jnp.sum(jnp.square(grads), axis=sample_axes, keepdims=True))
            perturbation = eps * grads / jnp.sqrt(square)

        adv_x = x + perturbation

        # If clipping is needed, reset all values outside of [clip_min, clip_max]
        if clip_min is not None:
            # Positional bounds work with both the old (a_min/a_max) and new (min/max) signatures
            adv_x = jnp.clip(adv_x, clip_min, clip_max)

        return adv_x
