# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2022
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements Sleeper Agent attack on Neural Networks.

| Paper link: https://arxiv.org/abs/2106.08970
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from typing import Tuple, TYPE_CHECKING, List, Union
import random

import numpy as np
from tqdm.auto import trange
from PIL import Image

from art.attacks.poisoning.custom_gradient_matching_attack import CustomGradientMatchingAttack
from art.estimators.classification.pytorch import PyTorchClassifier
from art.estimators.classification import TensorFlowV2Classifier
from art.preprocessing.standardisation_mean_std.pytorch import StandardisationMeanStdPyTorch
from art.preprocessing.standardisation_mean_std.tensorflow import StandardisationMeanStdTensorFlow
from art.attacks.poisoning.perturbations import add_pattern_bd, add_single_bd, insert_image
import copy
import cv2
import torch
from datetime import datetime
import scipy.stats as st

import time

def format_duration(seconds):
    """
    Render a duration as ``"<hours>h <minutes>m <seconds>s"``.

    Hours and minutes are whole numbers; the remaining seconds are shown with two decimals.

    :param seconds: Duration in seconds (int or float).
    :return: The formatted duration string.
    """
    whole_hours, within_hour = divmod(seconds, 3600)
    whole_minutes = within_hour // 60
    leftover_seconds = seconds % 60
    return f"{int(whole_hours)}h {int(whole_minutes)}m {leftover_seconds:.2f}s"

def loop_time_estimator(iterable):
    """
    Wrap a sized iterable and report running timing statistics alongside each item.

    Every item is yielded exactly once, as ``(item, timing_info)``. The time a caller spends on an item is only
    known when the caller asks for the next one. ``timing_info`` therefore describes the iterations completed
    *before* the current item. It is an empty dict for the first item; otherwise it contains:

    - ``iter_duration``: seconds spent on the previous item,
    - ``elapsed_time``: seconds since iteration started,
    - ``estimated_total_time``: projected seconds for all items at the average rate so far,
    - ``remaining_time``: ``estimated_total_time - elapsed_time``.

    :param iterable: The iterable to loop over; it must support ``len()``.
    :return: A generator of ``(item, timing_info)`` tuples.
    """
    start_time = time.perf_counter()  # monotonic clock: safe for measuring durations
    total_iterations = len(iterable)
    timing_info = {}

    for completed, item in enumerate(iterable):
        iter_start_time = time.perf_counter()
        yield item, timing_info
        # Execution resumes here once the caller has finished with `item`.
        iter_end_time = time.perf_counter()
        elapsed_time = iter_end_time - start_time
        estimated_total_time = elapsed_time / (completed + 1) * total_iterations
        timing_info = {
            "iter_duration": iter_end_time - iter_start_time,
            "elapsed_time": elapsed_time,
            "estimated_total_time": estimated_total_time,
            "remaining_time": estimated_total_time - elapsed_time,
        }


def refool_blend_images(img_t, img_r, max_image_size=560, ghost_rate=0.49, alpha_t=-1., offset=(0, 0), sigma=-1,
                        ghost_alpha=-1.):
    """
    Blend a transmission layer and a reflection layer into one image (RefooL-style reflection synthesis).

    Both inputs are resized so that their longer side equals ``max_image_size``. The reflection is then rendered
    in one of two ways:

    - with probability ``ghost_rate``, as a "ghost": the reflection overlaid with a shifted copy of itself;
    - otherwise, with a Gaussian focal blur and a Gaussian-shaped spatial mask.

    All blending happens after raising values to the power 2.2; results are mapped back with the power 1/2.2.

    :param img_t: Transmission (background) image of shape ``(H, W, 3)``, values expected in [0, 1].
    :param img_r: Reflection image of shape ``(H, W, 3)``, values expected in [0, 1]; resized to the same size.
    :param max_image_size: Length of the longer side after resizing.
    :param ghost_rate: Probability of using the ghost effect instead of the focal blur.
    :param alpha_t: Weight of the transmission layer; a negative value draws it uniformly from [0.55, 0.95].
    :param offset: (row, col) shift of the ghost copy; ``(0, 0)`` draws each component from [3, 8].
    :param sigma: Gaussian blur sigma for the focal-blur branch; a negative value draws it from [1, 5].
    :param ghost_alpha: Mixing weight between the two ghost copies; a negative value draws it at random.
    :return: ``(blended, transmission_layer, reflection_layer)``, each of shape ``(h, w, 3)`` at the resized
             resolution.
    """
    t = np.float32(img_t)
    r = np.float32(img_r)
    h, w, _ = t.shape
    # Rescale so that the longer side equals max_image_size, keeping the aspect ratio.
    scale_ratio = float(max(h, w)) / float(max_image_size)
    w, h = (max_image_size, int(round(h / scale_ratio))) if w > h \
        else (int(round(w / scale_ratio)), max_image_size)
    # Interpolation must be passed by keyword: the third positional parameter of cv2.resize is `dst`.
    # Bilinear (OpenCV's default) is used because bicubic can overshoot [0, 1], and np.power(x, 2.2)
    # below turns negative values into NaN.
    t = cv2.resize(t, (w, h), interpolation=cv2.INTER_LINEAR)
    r = cv2.resize(r, (w, h), interpolation=cv2.INTER_LINEAR)

    if alpha_t < 0:
        alpha_t = 1. - random.uniform(0.05, 0.45)

    if random.random() < ghost_rate:
        t = np.power(t, 2.2)
        r = np.power(r, 2.2)

        # Generate the blended image with a ghost effect.
        if offset[0] == 0 and offset[1] == 0:
            offset = (random.randint(3, 8), random.randint(3, 8))
        r_1 = np.pad(r, ((0, offset[0]), (0, offset[1]), (0, 0)), mode='constant', constant_values=0)
        r_2 = np.pad(r, ((offset[0], 0), (offset[1], 0), (0, 0)), mode='constant', constant_values=(0, 0))
        if ghost_alpha < 0:
            ghost_alpha = abs(round(random.random()) - random.uniform(0.15, 0.5))

        ghost_r = r_1 * ghost_alpha + r_2 * (1 - ghost_alpha)
        # ghost_r has shape (h + offset[0], w + offset[1], 3). The crop drops offset[k] pixels from both ends
        # of each axis. The explicit stops h and w (rather than -offset[k]) keep the axis intact when that
        # offset component is 0.
        ghost_r = cv2.resize(ghost_r[offset[0]:h, offset[1]:w, :], (w, h), interpolation=cv2.INTER_LINEAR)
        reflection_mask = ghost_r * (1 - alpha_t)

        blended = reflection_mask + t * alpha_t

        transmission_layer = np.power(t * alpha_t, 1 / 2.2)

        ghost_r = np.clip(np.power(reflection_mask, 1 / 2.2), 0, 1)
        blended = np.clip(np.power(blended, 1 / 2.2), 0, 1)

        reflection_layer = ghost_r
    else:
        # Generate the blended image with a focal blur.
        if sigma < 0:
            sigma = random.uniform(1, 5)

        t = np.power(t, 2.2)
        r = np.power(r, 2.2)

        sz = int(2 * np.ceil(2 * sigma) + 1)
        # Keyword arguments: the 4th positional parameter of cv2.GaussianBlur is `dst`, not sigmaY.
        r_blur = cv2.GaussianBlur(r, (sz, sz), sigmaX=sigma, sigmaY=sigma)
        blend = r_blur + t

        # Shift each channel of the reflection down by how far the blend exceeds 1 on average (scaled by att).
        att = 1.08 + np.random.random() / 10.0
        for i in range(3):
            maski = blend[:, :, i] > 1
            mean_i = max(1., np.sum(blend[:, :, i] * maski) / (maski.sum() + 1e-6))
            r_blur[:, :, i] = r_blur[:, :, i] - (mean_i - 1) * att
        r_blur[r_blur >= 1] = 1
        r_blur[r_blur <= 0] = 0

        def gen_kernel(kern_len=100, nsig=1):
            """Return a 2D Gaussian kernel of shape (kern_len, kern_len), scaled so its maximum is 1."""
            interval = (2 * nsig + 1.) / kern_len
            x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kern_len + 1)
            # Per-bin probability mass of a standard normal distribution.
            kern1d = np.diff(st.norm.cdf(x))
            kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
            kernel = kernel_raw / kernel_raw.sum()
            kernel = kernel / kernel.max()
            return kernel

        # Place the image at a random position inside a max_image_size Gaussian mask (0 for the longer side).
        h, w = r_blur.shape[:2]
        new_w = np.random.randint(0, max_image_size - w - 10) if w < max_image_size - 10 else 0
        new_h = np.random.randint(0, max_image_size - h - 10) if h < max_image_size - 10 else 0

        g_mask = gen_kernel(max_image_size, 3)
        g_mask = np.dstack((g_mask, g_mask, g_mask))
        alpha_r = g_mask[new_h: new_h + h, new_w: new_w + w, :] * (1. - alpha_t / 2.)

        r_blur_mask = np.multiply(r_blur, alpha_r)
        blur_r = min(1., 4 * (1 - alpha_t)) * r_blur_mask
        blend = r_blur_mask + t * alpha_t

        transmission_layer = np.power(t * alpha_t, 1 / 2.2)
        r_blur_mask = np.power(blur_r, 1 / 2.2)
        blend = np.power(blend, 1 / 2.2)
        blend[blend >= 1] = 1
        blend[blend <= 0] = 0

        blended = blend
        reflection_layer = r_blur_mask

    return blended, transmission_layer, reflection_layer

def refool_reflect_images(img_t, img_r, max_image_size=32, ghost_rate=0.49, alpha_t=-1., offset=(0, 0), sigma=-1,
                          ghost_alpha=-1., blend_location='random'):
    """
    Paste a small image into a channels-first image and blend the result in as a reflection.

    ``img_r`` is copied into ``img_t`` at ``blend_location``. :func:`refool_blend_images` is then called with the
    untouched ``img_t`` as the transmission layer and the pasted copy as the reflection layer.

    :param img_t: The larger image, ``(C, H, W)`` (e.g. ``(3, 32, 32)`` for CIFAR-10).
    :param img_r: The smaller image, ``(C, h, w)`` with ``h <= H`` and ``w <= W`` (e.g. ``(3, 8, 8)``).
    :param max_image_size: Forwarded to ``refool_blend_images`` (length of the longer side after resizing).
    :param ghost_rate: Forwarded to ``refool_blend_images``.
    :param alpha_t: Forwarded to ``refool_blend_images``.
    :param offset: Forwarded to ``refool_blend_images``.
    :param sigma: Forwarded to ``refool_blend_images``.
    :param ghost_alpha: Forwarded to ``refool_blend_images``.
    :param blend_location: ``'random'`` for a uniformly random position at which ``img_r`` fits entirely, or a
                           ``(row, col)`` pair giving the top-left corner of the pasted image.
    :return: ``(blended, transmission_layer, reflection_layer)``, each channels-first float32 at the size
             produced by ``refool_blend_images``.
    :raises ValueError: If ``img_r`` is larger than ``img_t`` in either spatial dimension.
    """
    # (C, H, W) -> (H, W, C), the layout expected by refool_blend_images.
    img_t = np.transpose(img_t, (1, 2, 0))
    img_r = np.transpose(img_r, (1, 2, 0))

    h_t, w_t, _ = img_t.shape
    h_r, w_r, _ = img_r.shape
    if h_r > h_t or w_r > w_t:
        raise ValueError(f"img_r ({h_r}x{w_r}) does not fit inside img_t ({h_t}x{w_t}).")

    if isinstance(blend_location, str) and blend_location == 'random':
        # randint's upper bound is exclusive. The +1 makes the bottom/right-most placement reachable and
        # allows img_r to be exactly as large as img_t.
        row = np.random.randint(0, h_t - h_r + 1)
        col = np.random.randint(0, w_t - w_r + 1)
    else:
        row, col = blend_location

    # Copy of img_t with img_r pasted in; it is used as the reflection layer.
    img_t_blend = img_t.copy()
    img_t_blend[row:row + h_r, col:col + w_r, :] = img_r

    blended, transmission_layer, reflection_layer = refool_blend_images(
        img_t, img_t_blend, max_image_size, ghost_rate, alpha_t, offset, sigma, ghost_alpha
    )

    # (H, W, C) -> (C, H, W) to match the input layout.
    blended = np.transpose(blended, (2, 0, 1)).astype(np.float32)
    transmission_layer = np.transpose(transmission_layer, (2, 0, 1)).astype(np.float32)
    reflection_layer = np.transpose(reflection_layer, (2, 0, 1)).astype(np.float32)

    return blended, transmission_layer, reflection_layer




if TYPE_CHECKING:
    # pylint: disable=C0412
    # Only needed for the string type annotations in this module; not imported at runtime.
    from art.utils import CLASSIFIER_NEURALNETWORK_TYPE

# Module-level logger for this attack module.
logger = logging.getLogger(__name__)


class GradientStorm(CustomGradientMatchingAttack):
    """
    Implementation of Gradient Storm Attack
    """

    def __init__(
        self,
        classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
        num_poisons: List[int],
        num_cycles: int,
        cycle_rounds: int,
        patches: List[np.ndarray],
        trigger_types: List[str],
        target_indices: List[int],
        epsilon: float = 0.1,
        max_epochs: int = 250,
        learning_rate_schedule: Tuple[List[float], List[int]] = ([1e-1, 1e-2, 1e-3, 1e-4], [100, 150, 200, 220]),
        batch_size: int = 128,
        clip_values: Tuple[float, float] = (0, 1.0),
        verbose: int = 1,
        patching_strategy: str = "random",
        selection_strategy: str = "random",
        model_retraining_epoch: int = 1,
        device_name: str = "cpu",
        retrain_batch_size: int = 128,
        **kwargs
    ):
        """
        Initialize a Gradient Storm poisoning attack.

        The attack runs several triggered attacks (building on Sleeper Agent). Each one poisons samples of its
        own subset of the training data.

        :param classifier: The proxy classifier used for the attack. Its preprocessing must be a mean/std
                           standardisation (PyTorch or TensorFlow); its ``clip_values`` define the input range.
        :param num_poisons: For each attack, the total number of samples to poison. It is split evenly over
                            ``cycle_rounds`` rounds.
        :param num_cycles: Number of independent poisoning cycles; the best-scoring cycle is kept.
        :param cycle_rounds: Number of select/poison/retrain rounds per attack within one cycle.
        :param patches: For each attack, the trigger pattern in input (un-normalised) space.
                        - ``'patch'``/``'blend'``/``'refool'`` patterns are standardised and squeezed here.
                        - Sinusoidal patterns are stored unchanged.
        :param trigger_types: For each attack, one of ``'patch'``, ``'blend'``, ``'refool'``,
                              ``'horizontal_sinusoidal'`` or ``'vertical_sinusoidal'``.
        :param target_indices: For each attack, the indices into ``x_train`` of the samples that may be poisoned
                               (presumably the attack's target-class samples).
        :param epsilon: The L-inf perturbation budget, relative to the classifier's clip range.
        :param max_epochs: The maximum number of epochs to optimize the poison per trial.
        :param learning_rate_schedule: The learning rate schedule to optimize the poison: a pair of lists
            (learning rates, epochs). A learning rate is used while the current epoch is below its epoch.
        :param batch_size: Batch size for poison optimisation.
        :param clip_values: Accepted for API compatibility but not used; ``classifier.clip_values`` is used.
        :param verbose: Show progress bars.
        :param patching_strategy: ``'fixed'`` places patch triggers in the bottom-right corner; any other
                                  value places them at a random location.
        :param selection_strategy: Strategy for selecting the samples to poison (stored on the instance).
        :param model_retraining_epoch: Number of epochs the substitute model is retrained after each round.
        :param device_name: Torch device used when ranking samples by gradient norm.
        :param retrain_batch_size: Batch size used for model retraining.
        :param kwargs: Forwarded to the parent attack. Three keys are also read here:
                       ``x_test_poison`` and ``y_test_poison`` (only used when both are present), and ``augment``.
        :raises ValueError: If the classifier preprocessing is not a mean/std standardisation, if ``patches``
                            and ``trigger_types`` differ in length, or if a trigger type is unknown.
        """
        image_trigger_types = {"patch", "blend", "refool"}
        signal_trigger_types = {"horizontal_sinusoidal", "vertical_sinusoidal"}

        preprocessing = classifier.preprocessing
        if not isinstance(preprocessing, (StandardisationMeanStdPyTorch, StandardisationMeanStdTensorFlow)):
            raise ValueError("classifier.preprocessing not an instance of pytorch/tensorflow")
        if len(patches) != len(trigger_types):
            # zip() below would silently truncate and misalign patches with trigger types.
            raise ValueError(
                f"Expected one trigger type per patch, got {len(patches)} patches and "
                f"{len(trigger_types)} trigger types."
            )

        # Express the clip range and the perturbation budget in standardised input space.
        clip_values_normalised = (
            classifier.clip_values - preprocessing.mean  # type: ignore
        ) / preprocessing.std
        clip_values_normalised = (clip_values_normalised[0], clip_values_normalised[1])
        epsilon_normalised = epsilon * (clip_values_normalised[1] - clip_values_normalised[0])  # type: ignore

        normalised_patches = []
        for patch, trigger_type in zip(patches, trigger_types):
            if trigger_type in image_trigger_types:
                # squeeze() drops singleton axes that the mean/std standardisation may introduce.
                normalised_patches.append(((patch - preprocessing.mean) / preprocessing.std).squeeze())
            elif trigger_type in signal_trigger_types:
                normalised_patches.append(patch)
            else:
                # Dropping the patch would misalign self.patches with self.trigger_types.
                raise ValueError(f"Unknown trigger type: {trigger_type!r}")

        percent_poison = 0.0  # placeholder for the parent constructor; counts come from num_poisons
        max_trials = 1
        super().__init__(
            classifier,
            percent_poison,
            epsilon_normalised,
            max_trials,
            max_epochs,
            learning_rate_schedule,
            batch_size,
            clip_values_normalised,
            verbose,
            **kwargs
        )
        self.num_cycles = num_cycles
        self.cycle_rounds = cycle_rounds
        self.target_indices = target_indices
        self.num_poisons = num_poisons
        self.selection_strategy = selection_strategy
        self.patching_strategy = patching_strategy
        self.model_retraining_epoch = model_retraining_epoch
        self.indices_poison: np.ndarray
        self.patches = normalised_patches
        self.trigger_types = trigger_types
        self.device_name = device_name
        self.initial_epoch = 0
        self.retrain_batch_size = retrain_batch_size
        if "x_test_poison" in kwargs and "y_test_poison" in kwargs:
            self.x_test_poison = kwargs["x_test_poison"]
            self.y_test_poison = kwargs["y_test_poison"]
        self.augment = kwargs.get("augment", False)

    # pylint: disable=W0221
    def poison(  # type: ignore
        self,
        x_triggers: List[np.ndarray],
        y_triggers: List[np.ndarray],
        x_train: np.ndarray,
        y_train: np.ndarray,
        x_test: np.ndarray,
        y_test: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Optimize poisons for every attack by gradient matching and return the best poisoned training set.

        For each of ``num_cycles`` cycles, each attack runs ``cycle_rounds`` rounds. A round does four things:
        1. selects target samples with the largest loss-gradient norm;
        2. optimizes their poison perturbation;
        3. writes them into ``x_train``;
        4. retrains the substitute model.
        The cycle whose retrainings report the highest mean attack success rate (ASR) is returned.

        :param x_triggers: For each attack, the samples the trigger is applied to (un-normalised).
                           The caller's list and arrays are not modified.
        :param y_triggers: For each attack, the target classes to classify the triggered samples into.
        :param x_train: Clean training data (un-normalised).
        :param y_train: Labels for ``x_train``.
        :param x_test: Clean test data, forwarded to model retraining.
        :param y_test: Labels for ``x_test``.
        :return: ``(x_train, y_train)`` of the best cycle, with ``x_train`` mapped back to un-normalised space.
                 The chosen poison indices are available from ``get_poison_indices()``.
        :raises NotImplementedError: If the substitute classifier is not a ``PyTorchClassifier``.
        """
        preprocessing = self.substitute_classifier.preprocessing
        standardised = isinstance(preprocessing, (StandardisationMeanStdPyTorch, StandardisationMeanStdTensorFlow))

        # Work on copies: normalisation and in-place trigger stamping must not leak into the caller's data.
        x_train = np.copy(x_train)
        x_triggers = [np.copy(x_trigger) for x_trigger in x_triggers]
        if standardised:
            x_triggers = [(x_trigger - preprocessing.mean) / preprocessing.std for x_trigger in x_triggers]
            x_train = (x_train - preprocessing.mean) / preprocessing.std

        if isinstance(self.substitute_classifier, PyTorchClassifier):
            poisoner = self._poison__pytorch
            finish_poisoning = self._finish_poison_pytorch
            initializer = self._initialize_poison_pytorch
        else:
            raise NotImplementedError("GradientStorm is currently implemented only for PyTorch.")

        # Apply each attack's trigger to its trigger samples (unknown trigger types are left untouched).
        for idx, (x_trigger, patch, trigger_type) in enumerate(zip(x_triggers, self.patches, self.trigger_types)):
            if trigger_type == "patch":
                x_triggers[idx] = self._apply_trigger_patch(x_trigger, patch)
            elif trigger_type == "blend":
                x_triggers[idx] = self._apply_trigger_blend(x_trigger, patch)
            elif trigger_type == "refool":
                x_triggers[idx] = self._apply_refool(x_trigger, patch)
            elif trigger_type == "horizontal_sinusoidal":
                x_triggers[idx] = self._apply_horizontal_sinusoidal_signal(x_trigger)
            elif trigger_type == "vertical_sinusoidal":
                x_triggers[idx] = self._apply_vertical_sinusoidal_signal(x_trigger)

        x_train_orig = x_train.copy()
        y_train_orig = y_train.copy()
        # Clean (normalised) target samples per attack. They are only read below, so all cycles share them.
        x_train_target_samples = [x_train[target_indices_].copy() for target_indices_ in self.target_indices]
        y_train_target_samples = [y_train[target_indices_].copy() for target_indices_ in self.target_indices]

        # Start below any attainable score, so a result exists even when every cycle scores 0.
        best_asr_score = -np.inf
        best_x_train = x_train_orig
        best_y_train = y_train_orig

        self.num_attacks = len(self.patches)
        for cycle_id in range(self.num_cycles):
            x_train = x_train_orig.copy()
            y_train = y_train_orig.copy()
            indices_poison: List[int] = []  # positions within each attack's target-sample subset
            poisoned_train_indices: List[int] = []  # the same samples as positions within x_train
            cycle_asr_list = []
            for attack_id, (x_trigger, y_trigger, num_poison) in enumerate(
                zip(x_triggers, y_triggers, self.num_poisons)
            ):
                num_poison_samples_per_round = int(num_poison / self.cycle_rounds)
                if num_poison_samples_per_round <= 0:
                    logger.warning(
                        "Attack %d: %d poisons over %d rounds leaves no sample per round; skipping it.",
                        attack_id + 1,
                        num_poison,
                        self.cycle_rounds,
                    )
                    continue
                # Random placeholders sized like one round's poison batch. Presumably only their shapes
                # matter to the initializer -- TODO confirm against the parent implementation.
                dummy_round_x_poison = np.random.rand(num_poison_samples_per_round, *x_train.shape[1:])
                dummy_round_y_poison = np.random.rand(num_poison_samples_per_round, *y_train.shape[1:])
                initializer(x_trigger, y_trigger, dummy_round_x_poison, dummy_round_y_poison)
                target_idx = np.asarray(self.target_indices[attack_id])
                for round_id in range(self.cycle_rounds):
                    formatted_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    self.lprint(
                        f" --- attack {attack_id+1}, cycle {cycle_id+1}, round {round_id+1} "
                        f"started at {formatted_datetime} --- "
                    )
                    # Skip target samples whose x_train position was already poisoned in this cycle. That covers
                    # earlier rounds and other attacks whose target subsets overlap this one.
                    already_poisoned = np.flatnonzero(np.isin(target_idx, poisoned_train_indices))
                    round_indices_poison = self._select_poison_indices(
                        self.substitute_classifier,
                        x_train_target_samples[attack_id],
                        y_train_target_samples[attack_id],
                        num_poison_samples_per_round,
                        already_poisoned.tolist(),
                    )
                    round_train_indices = target_idx[round_indices_poison]
                    indices_poison.extend(list(round_indices_poison.reshape(-1)))
                    poisoned_train_indices.extend(round_train_indices.reshape(-1).tolist())
                    round_x_poison = x_train_target_samples[attack_id][round_indices_poison].copy()
                    round_y_poison = y_train_target_samples[attack_id][round_indices_poison].copy()
                    x_poisoned, _ = poisoner(round_x_poison, round_y_poison)
                    x_train[round_train_indices] = x_poisoned
                    curr_asr_list = self._model_retraining(x_train, y_train, x_test, y_test)
                    # Keep the ASR values of the last tenth of the retraining epochs. For fewer than 10 epochs
                    # the slice start is -0 == 0, so every value is kept.
                    cycle_asr_list.extend(curr_asr_list[-int(self.model_retraining_epoch / 10):])

            mean_asr_score = sum(cycle_asr_list) / len(cycle_asr_list) if cycle_asr_list else 0.0
            self.lprint(f"cycle {cycle_id+1}, current mean asr: {mean_asr_score}")
            if mean_asr_score > best_asr_score:
                best_asr_score = mean_asr_score
                best_x_train = x_train.copy()
                best_y_train = y_train.copy()
                self.indices_poison = indices_poison

        finish_poisoning()
        self.lprint(f"best mean score: {best_asr_score}")
        if standardised:
            # Map the poisoned training set back to the caller's (un-normalised) input space.
            best_x_train = best_x_train * preprocessing.std + preprocessing.mean
        return best_x_train, best_y_train
            
    def _select_target_train_samples(self, x_train: np.ndarray, y_train: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Pick the training samples whose label (argmax over axis 1) equals ``self.class_target``.

        :param x_train: Clean training data.
        :param y_train: Labels for ``x_train``, one row of class scores per sample.
        :return: ``(x_train_target_samples, y_train_target_samples)``: new arrays holding the target-class
                 samples and their labels, in their original order.
        """
        is_target = y_train.argmax(axis=1) == self.class_target
        positions = np.nonzero(is_target)[0]
        return np.asarray(x_train)[positions], y_train[positions]

    def get_poison_indices(self) -> np.ndarray:
        """
        Return the poison selection recorded for the best cycle of the most recent ``poison`` call.

        :return: ``self.indices_poison``. ``poison`` stores it as a list of positions within the attacks'
                 target-sample subsets, concatenated over attacks and rounds, despite the annotation. In this
                 class it is only assigned by ``poison`` (``__init__`` only annotates it).
        """
        chosen = self.indices_poison
        return chosen

    def _model_retraining(
        self,
        x_train: np.ndarray,
        y_train: np.ndarray,
        x_test: np.ndarray,
        y_test: np.ndarray,
    ):
        """
        Retrain a freshly re-initialised copy of the substitute model and make it the new substitute model.

        :param x_train: Training data in normalised space; it is mapped back to input space before fitting.
        :param y_train: Labels for training data.
        :param x_test: Clean test data.
        :param y_test: Labels for test data.
        :return: For PyTorch, the ASR values returned by the model's ``fit`` (presumably one per epoch).
                 For TensorFlow v2, an empty list, because that path reports none.
        :raises NotImplementedError: If the substitute classifier is neither PyTorch nor TensorFlow v2.
        """
        preprocessing = self.substitute_classifier.preprocessing
        if isinstance(preprocessing, (StandardisationMeanStdPyTorch, StandardisationMeanStdTensorFlow)):
            x_train_un = x_train * preprocessing.std
            x_train_un += preprocessing.mean
        else:
            # Nothing to undo; this keeps x_train_un bound for the calls below.
            x_train_un = np.copy(x_train)

        if isinstance(self.substitute_classifier, PyTorchClassifier):
            was_training = self.substitute_classifier.model.training
            model_pt, asr_list = self._create_model(
                x_train_un,
                y_train,
                x_test,
                y_test,
                batch_size=self.retrain_batch_size,
                epochs=self.model_retraining_epoch,
            )
            self.substitute_classifier = model_pt
            # Module.train(mode) sets the mode recursively. Assigning `.training` would only flip the top-level
            # flag and leave submodules (BatchNorm, Dropout, ...) in their previous mode.
            self.substitute_classifier.model.train(was_training)
            return asr_list

        if isinstance(self.substitute_classifier, TensorFlowV2Classifier):
            was_trainable = self.substitute_classifier.model.trainable
            model_tf = self._create_model(
                x_train_un,
                y_train,
                x_test,
                y_test,
                batch_size=self.retrain_batch_size,
                epochs=self.model_retraining_epoch,
            )
            self.substitute_classifier = model_tf
            self.substitute_classifier.model.trainable = was_trainable
            # Same return type as the PyTorch path; the TensorFlow fit does not report ASR values.
            return []

        raise NotImplementedError("SleeperAgentAttack is currently implemented only for PyTorch and TensorFlowV2.")

    def _create_model(
        self,
        x_train: np.ndarray,
        y_train: np.ndarray,
        x_test: np.ndarray,
        y_test: np.ndarray,
        batch_size: int = 128,
        epochs: int = 80,
    ) -> Union[Tuple["PyTorchClassifier", list], "TensorFlowV2Classifier"]:
        """
        Clone the substitute classifier, re-initialise it and train it on the given data.

        :param x_train: Samples of train data (un-normalised).
        :param y_train: Labels of train data.
        :param x_test: Samples of test data.
        :param y_test: Labels of test data.
        :param batch_size: The size of batch used for training.
        :param epochs: The number of epochs for which training needs to be applied.
        :return: For PyTorch, ``(model, asr_list)``, where ``asr_list`` is whatever the project's ``fit``
                 returns. For TensorFlow v2, the trained model only.
        :raises ValueError: If the substitute classifier is neither PyTorch nor TensorFlow v2.
        """
        if isinstance(self.substitute_classifier, PyTorchClassifier):
            model_pt = self.substitute_classifier.clone_for_refitting()
            # Reset every layer, including layers nested in containers (nn.Sequential, residual blocks, ...).
            # Iterating only .children() would miss those nested layers.
            for module in model_pt.model.modules():
                reset_parameters = getattr(module, "reset_parameters", None)
                if callable(reset_parameters):
                    reset_parameters()  # type: ignore
            asr_list = model_pt.fit(
                x_train,
                y_train,
                batch_size=batch_size,
                nb_epochs=epochs,
                verbose=True,
                x_test=x_test,
                y_test=y_test,
                log_path=self.log_path,
                # The constructor only sets these when both were supplied; fall back to None instead of
                # raising AttributeError.
                x_test_poison=getattr(self, "x_test_poison", None),
                y_test_poison=getattr(self, "y_test_poison", None),
                augment=self.augment,
                num_attacks=self.num_attacks,
            )
            predictions = model_pt.predict(x_test)
            accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)
            logger.info("Accuracy of retrained model : %s", accuracy * 100.0)
            return model_pt, asr_list

        if isinstance(self.substitute_classifier, TensorFlowV2Classifier):
            self.substitute_classifier.model.trainable = True
            model_tf = self.substitute_classifier.clone_for_refitting()
            model_tf.fit(x_train, y_train, batch_size=batch_size, nb_epochs=epochs, verbose=False)
            predictions = model_tf.predict(x_test)
            accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)
            logger.info("Accuracy of retrained model : %s", accuracy * 100.0)
            return model_tf

        raise ValueError("SleeperAgentAttack is currently implemented only for PyTorch and TensorFlowV2.")

    # This function is responsible for returning indices of poison images with maximum gradient norm
    def _select_poison_indices(
        self, classifier: "CLASSIFIER_NEURALNETWORK_TYPE", x_samples: np.ndarray, y_samples: np.ndarray, num_poison: int, current_indices: list
    ) -> np.ndarray:
        """
        Rank samples by the L2 norm of the loss gradient w.r.t. the model weights and pick the largest ones.

        :param classifier: Substitute model used to compute the gradients. A PyTorch model is left in eval mode.
        :param x_samples: Candidate samples (normalised).
        :param y_samples: Labels of the candidate samples.
        :param num_poison: Number of samples to select; ``0`` or less selects nothing.
        :param current_indices: Positions in ``x_samples`` that must not be selected.
        :return: Integer positions into ``x_samples``, in increasing order of gradient norm.
        :raises NotImplementedError: If the classifier is neither PyTorch nor TensorFlow v2.
        """
        if isinstance(classifier, PyTorchClassifier):
            import torch

            device = torch.device(self.device_name)
            grad_norms = []
            criterion = torch.nn.CrossEntropyLoss()
            model = classifier.model
            model.eval()
            differentiable_params = [p for p in model.parameters() if p.requires_grad]
            for x, y in zip(x_samples, y_samples):
                image = torch.tensor(x, dtype=torch.float32).to(device)
                label = torch.tensor(y).to(device)
                loss_pt = criterion(model(image.unsqueeze(0)), label.unsqueeze(0))
                gradients = torch.autograd.grad(loss_pt, differentiable_params, only_inputs=True)
                grad_norm = torch.tensor(0, dtype=torch.float32).to(device)
                for grad in gradients:
                    grad_norm += grad.detach().pow(2).sum()
                # Store a Python float: sorting 0-d tensors is slower and keeps device memory alive.
                grad_norms.append(grad_norm.sqrt().item())
        elif isinstance(classifier, TensorFlowV2Classifier):
            import tensorflow as tf

            model_trainable = classifier.model.trainable
            classifier.model.trainable = False
            grad_norms = []
            for i in range(len(x_samples)):  # every sample is a candidate, including the last one
                image = tf.constant(x_samples[i : i + 1])
                label = tf.constant(y_samples[i : i + 1])
                with tf.GradientTape() as t:  # pylint: disable=C0103
                    t.watch(classifier.model.weights)
                    output = classifier.model(image, training=False)
                    loss_tf = classifier.loss_object(label, output)  # type: ignore
                gradients = [w for w in t.gradient(loss_tf, classifier.model.weights) if w is not None]
                grad_norm = tf.constant(0, dtype=tf.float32)
                for grad in gradients:
                    grad_norm += tf.reduce_sum(tf.math.square(grad))
                grad_norms.append(float(tf.math.sqrt(grad_norm)))
            classifier.model.trainable = model_trainable
        else:
            raise NotImplementedError("SleeperAgentAttack is currently implemented only for PyTorch and TensorFlowV2.")

        if num_poison <= 0:
            # A slice of indices[-0:] would select every candidate instead of none.
            return np.array([], dtype=int)
        ranked = sorted(range(len(grad_norms)), key=lambda k: grad_norms[k])  # type: ignore
        excluded = set(current_indices)
        candidates = [idx for idx in ranked if idx not in excluded]
        # Integer dtype even when empty, so the result is always usable as an index array.
        return np.array(candidates[-num_poison:], dtype=int)

    # This function is responsible for applying trigger patches to the images
    # fixed - where the trigger is applied at the bottom right of the image
    # random - where the trigger is applied at random location of the image
    def _apply_trigger_patch(self, x_trigger: np.ndarray, patch: np.ndarray) -> np.ndarray:
        """
        Stamp ``patch`` onto every sample of ``x_trigger`` in place.

        - With ``patching_strategy == "fixed"`` the patch goes to the bottom-right corner of every sample.
        - Otherwise each sample gets the patch at its own random location (drawn with Python's ``random``).

        :param x_trigger: Batch of samples: ``(N, C, H, W)`` if the estimator is channels-first,
                          otherwise ``(N, H, W, C)``.
        :param patch: ``(C, h, w)`` for channels-first estimators (a 2-D ``(h, w)`` patch is broadcast over
                      channels), otherwise ``(h, w, C)``. Height and width may differ.
        :return: ``x_trigger`` itself, modified in place.
        """
        channels_first = self.estimator.channels_first
        if channels_first:
            patch_h, patch_w = patch.shape[-2], patch.shape[-1]
        else:
            patch_h, patch_w = patch.shape[0], patch.shape[1]

        if self.patching_strategy == "fixed":
            if channels_first:
                x_trigger[:, :, -patch_h:, -patch_w:] = patch
            else:
                x_trigger[:, -patch_h:, -patch_w:, :] = patch
        else:
            for x in x_trigger:
                if channels_first:
                    x_cord = random.randrange(0, x.shape[1] - patch_h + 1)
                    y_cord = random.randrange(0, x.shape[2] - patch_w + 1)
                    x[:, x_cord : x_cord + patch_h, y_cord : y_cord + patch_w] = patch
                else:
                    x_cord = random.randrange(0, x.shape[0] - patch_h + 1)
                    y_cord = random.randrange(0, x.shape[1] - patch_w + 1)
                    x[x_cord : x_cord + patch_h, y_cord : y_cord + patch_w, :] = patch

        return x_trigger

    def _apply_trigger_blend(self, x_trigger: np.ndarray, patch: np.ndarray, alpha: float = 0.5) -> np.ndarray:
        """
        Apply a blended trigger patch to images in x_trigger.

        :param x_trigger: Samples to be used for the trigger (numpy array of shape (N, C, H, W) or (N, H, W, C)).
        :param patch: The patch to blend into the images (numpy array of shape (C, patch_size, patch_size) or (patch_size, patch_size, C)).
        :param alpha: The blending coefficient, 0.0 means only the image, 1.0 means only the patch.
        :return: The modified x_trigger with applied trigger patches.
        """
        patch_size = patch.shape[1]
        if self.patching_strategy == "fixed":
            if self.estimator.channels_first:
                # Channels first blending
                x_trigger[:, :, -patch_size:, -patch_size:] = (
                    (1 - alpha) * x_trigger[:, :, -patch_size:, -patch_size:] +
                    alpha * patch
                )
            else:
                # Channels last blending
                x_trigger[:, -patch_size:, -patch_size:, :] = (
                    (1 - alpha) * x_trigger[:, -patch_size:, -patch_size:, :] +
                    alpha * patch
                )
        else:
            for x in x_trigger:
                if self.estimator.channels_first:
                    x_cord = random.randrange(0, x.shape[1] - patch.shape[1] + 1)
                    y_cord = random.randrange(0, x.shape[2] - patch.shape[2] + 1)
                    
                    # Perform blending
                    x[:, x_cord : x_cord + patch_size, y_cord : y_cord + patch_size] = (
                        (1 - alpha) * x[:, x_cord : x_cord + patch_size, y_cord : y_cord + patch_size] +
                        alpha * patch
                    )
                else:
                    x_cord = random.randrange(0, x.shape[0] - patch.shape[0] + 1)
                    y_cord = random.randrange(0, x.shape[1] - patch.shape[1] + 1)
                    
                    # Perform blending
                    x[x_cord : x_cord + patch_size, y_cord : y_cord + patch_size, :] = (
                        (1 - alpha) * x[x_cord : x_cord + patch_size, y_cord : y_cord + patch_size, :] +
                        alpha * patch
                    )

        return x_trigger
            
    def _apply_horizontal_sinusoidal_signal(self, x_: np.ndarray) -> np.ndarray:
        x_trigger = np.copy(x_)
        def denormalize(image, mean, std):
            """
            Denormalizes an image using the provided mean and std.
            """
            # mean = np.array(mean).reshape(3, 1, 1)
            # std = np.array(std).reshape(3, 1, 1)
            return image * std + mean

        def normalize(image, mean, std):
            """
            Normalizes an image using the provided mean and std.
            """
            # mean = np.array(mean).reshape(3, 1, 1)
            # std = np.array(std).reshape(3, 1, 1)
            return (image - mean) / std
        
        def add_sinusoidal_signal(image, delta, f, mean, std):
            """
            Adds a horizontal sinusoidal signal to an image normalized by mean and std.

            Parameters:
            image (numpy.ndarray): The input normalized image of shape (3, 32, 32).
            delta (float): The amplitude of the sinusoidal signal.
            f (float): The frequency of the sinusoidal signal.
            mean (list or tuple): The mean used for normalization.
            std (list or tuple): The standard deviation used for normalization.

            Returns:
            numpy.ndarray: The image with the added sinusoidal signal, normalized by mean and std.
            """
            # Ensure the image has the correct shape
            assert image.shape == (3, 32, 32), "Image must be of shape (3, 32, 32)"
            
            # Denormalize the image
            image_denormalized = denormalize(image, mean, std).squeeze()
            # image_denormalized = np.copy(image)
            
            # Get the dimensions of the image
            _, l, m = image_denormalized.shape
            
            # Generate the sinusoidal signal
            j = np.arange(1, m+1)
            v = delta * np.sin(2 * np.pi * j * f / m)
            
            # Add the signal to each row of each channel
            image_with_signal = image_denormalized.copy()
            for channel in range(3):
                for i in range(l):
                    image_with_signal[channel, i, :] += v
            
            # Clip values to ensure they remain in the valid range [0, 1]
            image_with_signal = np.clip(image_with_signal, 0, 1)
            
            # Normalize the image again
            image_with_signal_normalized = normalize(image_with_signal, mean, std).squeeze()
            # image_with_signal_normalized = np.copy(image_with_signal)
            
            return image_with_signal_normalized

        result = np.array([add_sinusoidal_signal(x, delta=0.1, f=5.0, mean=self.substitute_classifier.preprocessing.mean, std=self.substitute_classifier.preprocessing.std) for x in x_trigger])
        return result
    
    def _apply_vertical_sinusoidal_signal(self, x_: np.ndarray) -> np.ndarray:
        x_trigger = np.copy(x_)
        def denormalize(image, mean, std):
            """
            Denormalizes an image using the provided mean and std.
            """
            # mean = np.array(mean).reshape(3, 1, 1)
            # std = np.array(std).reshape(3, 1, 1)
            return image * std + mean

        def normalize(image, mean, std):
            """
            Normalizes an image using the provided mean and std.
            """
            # mean = np.array(mean).reshape(3, 1, 1)
            # std = np.array(std).reshape(3, 1, 1)
            return (image - mean) / std
        
        def add_sinusoidal_signal(image, delta, f, mean, std):
            """
            Adds a horizontal sinusoidal signal to an image normalized by mean and std.

            Parameters:
            image (numpy.ndarray): The input normalized image of shape (3, 32, 32).
            delta (float): The amplitude of the sinusoidal signal.
            f (float): The frequency of the sinusoidal signal.
            mean (list or tuple): The mean used for normalization.
            std (list or tuple): The standard deviation used for normalization.

            Returns:
            numpy.ndarray: The image with the added sinusoidal signal, normalized by mean and std.
            """
            # Ensure the image has the correct shape
            assert image.shape == (3, 32, 32), "Image must be of shape (3, 32, 32)"
            
            # Denormalize the image
            image_denormalized = denormalize(image, mean, std).squeeze()
            # image_denormalized = np.copy(image)
            
            # Get the dimensions of the image
            _, l, m = image_denormalized.shape
            
            # Generate the sinusoidal signal
            i = np.arange(1, l+1)
            v = delta * np.sin(2 * np.pi * i * f / l)
            
            # Add the signal to each row of each channel
            image_with_signal = image_denormalized.copy()
            for channel in range(3):
                for j in range(m):
                    image_with_signal[channel, :, j] += v
            
            # Clip values to ensure they remain in the valid range [0, 1]
            image_with_signal = np.clip(image_with_signal, 0, 1)
            
            # Normalize the image again
            image_with_signal_normalized = normalize(image_with_signal, mean, std).squeeze()
            # image_with_signal_normalized = np.copy(image_with_signal)
            
            return image_with_signal_normalized

        result = np.array([add_sinusoidal_signal(x, delta=0.1, f=5.0, mean=self.substitute_classifier.preprocessing.mean, std=self.substitute_classifier.preprocessing.std) for x in x_trigger])
        return result

    def _apply_refool(self, img_array, img_r, max_image_size=32, ghost_rate=0.49, alpha_t=-1., offset=(0, 0), sigma=-1,
                      ghost_alpha=-1., blend_location='random'):
        """
        Blend a reflection image into every image of a normalized batch and return the re-normalized results.

        Every image and ``img_r`` are first de-normalized (``image * std + mean``) with the substitute classifier's
        preprocessing. Each target image is then blended with the de-normalized reflection by
        ``refool_reflect_images``, and the blended images are normalized again (``(image - mean) / std``).

        :param img_array: Batch of normalized images, documented as shape (N, 3, H, W).
        :param img_r: Normalized reflection image, documented as shape (3, h, w); it is de-normalized once and reused.
        :param max_image_size: Forwarded unchanged to ``refool_reflect_images``.
        :param ghost_rate: Forwarded unchanged to ``refool_reflect_images``.
        :param alpha_t: Forwarded unchanged to ``refool_reflect_images``.
        :param offset: Forwarded unchanged to ``refool_reflect_images``.
        :param sigma: Forwarded unchanged to ``refool_reflect_images``.
        :param ghost_alpha: Forwarded unchanged to ``refool_reflect_images``.
        :param blend_location: ``'random'`` for a random location, or an (x, y) tuple for a fixed location; forwarded
                               unchanged to ``refool_reflect_images``.
        :return: ``float32`` array of blended, re-normalized images.
        """
        mean = self.substitute_classifier.preprocessing.mean
        std = self.substitute_classifier.preprocessing.std

        # Move the batch and the reflection into pixel space before blending.
        pixel_batch = np.array([(img * std + mean).squeeze() for img in img_array])
        pixel_reflection = (img_r * std + mean).squeeze()

        mixed_images = []
        for target in pixel_batch:
            # Only the blended image is kept; the other two returned values are discarded.
            mixed, _, _ = refool_reflect_images(target, pixel_reflection, max_image_size, ghost_rate, alpha_t, offset,
                                                sigma, ghost_alpha, blend_location)
            mixed_images.append(mixed)

        renormalized = [((img - mean) / std).squeeze() for img in mixed_images]
        return np.array(renormalized).astype(np.float32)