import numpy as np
import time
import torch
import os
import sys
import math
import torch.nn as nn
import torch.nn.functional as F


class T_MIFPE_untargeted:
    """Untargeted iterative attack with momentum and optional logit rescaling.

    Each iteration does the following:
    * Takes a gradient-ascent step on a per-sample loss chosen by ``loss``.
      The options are ``'ce'``, ``'dlr'``, ``'cw'``, ``'mifpe'`` (logits divided
      by the top-1/top-2 margin) and ``'t_mifpe'`` (logits divided by
      margin / t*, see :meth:`compute_t_star`).
    * Blends the step with the previous update. The new step gets weight 0.75
      and the previous update 0.25, after the first iteration.
    * Projects onto the ``norm``-ball of radius ``eps`` intersected with [0, 1].

    Under ``'Linf'`` only samples that are still classified correctly are moved.
    """

    def __init__(self, model, n_iter=100, norm='Linf', n_restarts=1, eps=None,
                 seed=0, loss='ce', eot_iter=1, rho=.75, verbose=False,
                 device='cuda', decay_step='linear'):
        if eps is None:
            raise TypeError('eps must be given as a number, got None')
        self.model = model
        self.n_iter = n_iter
        # Budget shrunk by a hair, presumably so float round-off in the
        # projections cannot exceed the nominal eps.
        self.eps = eps - 1.9 * 1e-8
        self.norm = norm
        self.n_restarts = n_restarts  # stored; perturb() performs a single run
        self.seed = seed
        self.loss = loss
        self.eot_iter = eot_iter
        self.thr_decr = rho  # stored; not read by this class
        self.verbose = verbose
        self.device = device
        self.decay_step = decay_step

    # ------------------------------------------------------------------ losses

    def dlr_loss(self, x, y):
        """Per-sample DLR-style loss on logits ``x`` of shape (batch, K).

        Returns -(z_y - max_{i != y} z_i) / (z_(1) - z_(3) + 1e-12), where z_(k)
        is the k-th largest logit. Requires K >= 3.
        """
        x_sorted, ind_sorted = x.sort(dim=1)
        # 1.0 where the label is the top class, so the runner-up is the best "other" logit.
        ind = (ind_sorted[:, -1] == y).float()
        rows = torch.arange(x.shape[0], device=x.device)
        margin = x[rows, y] - x_sorted[:, -2] * ind - x_sorted[:, -1] * (1. - ind)
        return -margin / (x_sorted[:, -1] - x_sorted[:, -3] + 1e-12)

    def cw_loss(self, x, y):
        """Per-sample CW margin loss: -(z_y - max_{i != y} z_i)."""
        x_sorted, ind_sorted = x.sort(dim=1)
        ind = (ind_sorted[:, -1] == y).float()
        rows = torch.arange(x.shape[0], device=x.device)
        return -(x[rows, y] - x_sorted[:, -2] * ind - x_sorted[:, -1] * (1. - ind))

    def check_right_index(self, output, labels):
        """Return an int8 (batch, 1) mask: 1 where argmax(output) == labels."""
        output_index = output.argmax(dim=-1) == labels
        mask = output_index.to(dtype=torch.int8)
        return torch.unsqueeze(mask, -1)

    # ----------------------------------------------------------- logit scales

    def get_output_scale_one(self, output):
        """Return the top-1 minus top-2 logit margin per sample, shape (batch, 1).

        Only the two largest logits are needed, so any model with K >= 2
        classes is supported.
        """
        top2 = output.topk(2, dim=1, largest=True, sorted=True).values
        return (top2[:, 0] - top2[:, 1]).unsqueeze(-1).to(self.device)

    def get_output_scale_star(self, output, labels, t_star_low_bound=1, t_star_upper_bound_decay=0.9):
        """Compute the per-sample divisor used by the ``'t_mifpe'`` loss.

        Args:
            output (torch.Tensor): Logits with shape (batch_size, K).
            labels (torch.Tensor): Class indices, shape (batch_size,).
            t_star_low_bound (float): Lower bound for t* (default: 1).
            t_star_upper_bound_decay (float): Decay factor for the case-2 upper bound (default: 0.9).

        Returns:
            torch.Tensor: ``Delta / t*`` with shape (batch_size, 1), where
            Delta = z_pi1 - z_pi2. Dividing the logits by it rescales the
            top-1/top-2 margin to exactly t*.
        """
        assert output.dim() == 2, "output must be a 2D tensor (batch_size, K)"
        t_star = self.compute_t_star(output, labels, low_bound=t_star_low_bound,
                                     decay=t_star_upper_bound_decay)
        top2 = output.topk(2, dim=1).values
        delta = top2[:, 0] - top2[:, 1]
        return (delta / t_star).unsqueeze(-1)

    def compute_t_star(self, logits, labels, low_bound=1.278, decay=0.98, n_iter=10):
        """Compute the per-sample temperature t* used by :meth:`get_output_scale_star`.

        Let z_y be the label logit, Delta = z_pi1 - z_pi2 (top-1 minus top-2) and
        gap = max_{i != y} z_i - z_y.

        * Case 1 (gap < 0, label strictly on top): grid-search t over
          [lower, 10.512] with 1000 points, maximizing ln(c * (1 - p_y)).
          Here c = t / Delta and p_y = softmax(c * logits)[y]. ``lower`` is
          ``self.low_bound`` if a caller set that attribute, otherwise ``low_bound``.
        * Case 2 (gap > 0): t* = max(max(low_bound, Delta), decay * U), where
          U = -ln(float32 tiny) / |(z_y - z_pi1) / Delta|.
        * Case 3 (gap == 0): t* = low_bound.

        Args:
            logits: (batch, K) logits.
            labels: (batch,) class indices.
            low_bound: lower bound on t*.
            decay: multiplier applied to the case-2 upper bound.
            n_iter: unused; kept for interface compatibility.

        Returns:
            (batch,) float tensor on ``logits.device``.
        """
        device = logits.device
        labels = labels.to(device).long()

        top2 = logits.topk(2, dim=1).values
        z_pi1, z_pi2 = top2[:, 0], top2[:, 1]
        delta = (z_pi1 - z_pi2).detach()

        z_y = torch.gather(logits, 1, labels.unsqueeze(1)).squeeze(1)
        not_y = torch.ones_like(logits, dtype=torch.bool)
        not_y.scatter_(1, labels.unsqueeze(1), False)
        max_z_not_y = torch.where(not_y, logits, -torch.inf).max(dim=1)[0]
        gap = max_z_not_y - z_y

        t_star = torch.zeros(logits.shape[0], device=device)

        case1 = gap < 0
        if case1.any():
            # The original code read ``self.low_bound``, which __init__ never sets
            # (AttributeError). Honour it if a caller assigned it; otherwise fall
            # back to the ``low_bound`` argument.
            grid_lower = getattr(self, 'low_bound', low_bound)
            t_star[case1] = self._t_star_case1(logits[case1], z_y[case1], delta[case1], grid_lower)

        case2 = gap > 0
        if case2.any():
            t_star[case2] = self._t_star_case2(z_pi1[case2], z_pi2[case2], z_y[case2], low_bound, decay)

        t_star[gap == 0] = low_bound
        return t_star

    @staticmethod
    def _t_star_case1(logits, z_top, delta, t_lower, t_upper=10.512, n_grid=1000):
        """Grid-search t in [t_lower, t_upper] maximizing ln(c * (1 - p_top)) per row."""
        device = logits.device
        lower = torch.tensor(float(t_lower), device=device)
        upper = torch.tensor(float(t_upper), device=device)
        grid = lower + torch.linspace(0, 1, n_grid, device=device) * (upper - lower)  # (n_grid,)

        c = grid.unsqueeze(0) / delta.unsqueeze(1)                                    # (n, n_grid)
        shifted = (logits - z_top.unsqueeze(1)).unsqueeze(1)                          # (n, 1, K)
        p_top = 1 / torch.exp(shifted * c.unsqueeze(2)).sum(dim=2)                    # (n, n_grid)
        # Clamp keeps log finite when c * (1 - p_top) underflows to 0.
        g = torch.log(torch.clamp(c * (1 - p_top), min=1e-10))
        return grid[g.argmax(dim=1)]

    @staticmethod
    def _t_star_case2(z_pi1, z_pi2, z_y, low_bound, decay):
        """Closed-form t* for samples whose label is beaten by another class."""
        device = z_pi1.device
        margin = z_pi1 - z_pi2
        lower = torch.maximum(torch.tensor(low_bound, device=device), margin)
        safe_margin = torch.where(margin == 0, torch.ones_like(margin) * 1e-8, margin)
        exponent_factor = (z_y - z_pi1) / safe_margin
        # -ln(smallest normal float32): largest exponent magnitude before exp() underflows.
        max_exp = -torch.log(torch.tensor(torch.finfo(torch.float32).tiny, device=device))
        upper = torch.where(exponent_factor < 0,
                            max_exp / exponent_factor.abs(),
                            torch.tensor(torch.inf, device=device))
        return torch.maximum(lower, upper * decay)

    # ------------------------------------------------------------ attack core

    def _per_sample_loss(self, out_adv, y):
        """Per-sample loss (to be maximized) selected by ``self.loss``."""
        if self.loss == 'ce':
            return F.cross_entropy(out_adv, y, reduction='none')
        if self.loss == 'dlr':
            return self.dlr_loss(out_adv, y)
        if self.loss == 'cw':
            return self.cw_loss(out_adv, y)
        if self.loss == 'mifpe':
            scale = self.get_output_scale_one(out_adv.detach())
            return F.cross_entropy(out_adv / scale, y, reduction='none')
        if self.loss == 't_mifpe':
            scale = self.get_output_scale_star(out_adv.detach(), y.detach())
            return F.cross_entropy(out_adv / scale, y, reduction='none')
        raise ValueError(f'Unknown loss type: {self.loss}')

    def _step_size(self, step_size_begin, i):
        """Step size at iteration ``i`` under the ``decay_step`` schedule."""
        progress = i / self.n_iter
        if self.decay_step == 'linear':
            return step_size_begin * (1 - progress)
        if self.decay_step == 'cosine':
            return step_size_begin * (1 + math.cos(progress * math.pi)) * 0.5
        if self.decay_step == 'cos':
            return step_size_begin * math.cos(progress * math.pi * 0.5)
        if self.decay_step == 'constant':
            # Fixed 2/255 regardless of eps.
            return torch.full_like(step_size_begin, 2.0 / 255.0)
        raise ValueError(f'Unknown decay_step: {self.decay_step}')

    def _project_linf(self, z, x):
        """Project ``z`` onto the Linf ball of radius eps around ``x``, then into [0, 1]."""
        return torch.clamp(torch.min(torch.max(z, x - self.eps), x + self.eps), 0.0, 1.0)

    def _project_l2(self, z, x, norm_pad):
        """Rescale ``z - x`` to L2 norm min(eps, ||z - x|| + norm_pad), then clip to [0, 1]."""
        diff = z - x
        norm = (diff ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
        radius = torch.min(self.eps * torch.ones(x.shape, device=self.device), norm + norm_pad)
        return torch.clamp(x + diff / (norm + 1e-12) * radius, 0.0, 1.0)

    def attack_single_run(self, x_in, y_in):
        """Run one attack pass starting from the clean input.

        Returns:
            (acc, x_best_adv). ``acc`` is True for samples that were never
            misclassified during the run. ``x_best_adv`` holds the last
            misclassified iterate for fooled samples and the final iterate
            for the others.
        """
        x = x_in.clone() if x_in.dim() == 4 else x_in.clone().unsqueeze(0)
        y = y_in.clone() if y_in.dim() == 1 else y_in.clone().unsqueeze(0)

        x_adv = x.detach().clamp(0., 1.)
        x_best_adv = x_adv.clone()
        x_adv_old = x_adv.clone()

        with torch.no_grad():
            acc = self.model(x_adv).max(1)[1] == y
        pred = acc.clone()  # defined even when n_iter == 0

        step_size_begin = torch.full((x.shape[0], 1, 1, 1), 2.0 * self.eps, device=self.device)

        for i in range(self.n_iter):
            step_size = self._step_size(step_size_begin, i)

            # --- gradient, averaged over eot_iter forward/backward passes
            x_adv.requires_grad_()
            grad = torch.zeros_like(x)
            for _ in range(self.eot_iter):
                with torch.enable_grad():
                    out_adv = self.model(x_adv)
                    # (batch, 1, 1, 1): 1 for samples still classified correctly.
                    mask = self.check_right_index(out_adv, y)[:, :, None, None]
                    loss = self._per_sample_loss(out_adv, y).sum()
                grad += torch.autograd.grad(loss, [x_adv])[0].detach()
            grad /= float(self.eot_iter)

            # --- momentum step + projection
            with torch.no_grad():
                x_adv = x_adv.detach()
                grad2 = x_adv - x_adv_old
                x_adv_old = x_adv.clone()
                a = 0.75 if i > 0 else 1.0

                if self.norm == 'Linf':
                    x_step = self._project_linf(x_adv + mask * step_size * torch.sign(grad), x)
                    x_adv_1 = self._project_linf(
                        x_adv + mask * ((x_step - x_adv) * a + grad2 * (1 - a)), x)
                elif self.norm == 'L2':
                    grad_norm = (grad ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
                    x_step = self._project_l2(x_adv + step_size * grad / (grad_norm + 1e-12), x, 0.0)
                    x_adv_1 = self._project_l2(x_adv + (x_step - x_adv) * a + grad2 * (1 - a), x, 1e-12)
                else:
                    raise ValueError(f'Unknown norm: {self.norm}')

                x_adv = x_adv_1 + 0.
                pred = self.model(x_adv).max(1)[1] == y

            acc = acc & pred
            x_best_adv[~pred] = x_adv[~pred]

        # Only samples never fooled take the final iterate; keying on ``pred``
        # could overwrite an earlier adversarial point for samples with acc == 0.
        x_best_adv[acc] = x_adv[acc]
        return acc, x_best_adv

    def perturb(self, x_in, y_in, best_loss=False, cheap=True):
        """Attack the samples the model currently classifies correctly.

        ``best_loss`` and ``cheap`` are accepted for interface compatibility and
        are not used.

        Returns:
            (acc, adv): per-sample robust-accuracy flags, and adversarial inputs
            (the clean input where no misclassification was found).
        """
        self.seed = 0  # NOTE(review): overrides the constructor's seed (existing behavior)
        assert self.norm in ['Linf', 'L2']
        x = x_in.clone() if x_in.dim() == 4 else x_in.clone().unsqueeze(0)
        y = y_in.clone() if y_in.dim() == 1 else y_in.clone().unsqueeze(0)

        adv = x.clone()
        with torch.no_grad():
            acc = self.model(x).max(1)[1] == y
        if self.verbose:
            print('-------------------------- running {}-attack with epsilon {:.4f} --------------------------'.format(
                self.norm, self.eps))
            print('initial accuracy: {:.2%}'.format(acc.float().mean()))
        startt = time.time()

        torch.random.manual_seed(self.seed)
        torch.cuda.random.manual_seed(self.seed)

        ind_to_fool = acc.nonzero().flatten()
        if ind_to_fool.numel() != 0:
            acc_curr, adv_curr = self.attack_single_run(x[ind_to_fool].clone(), y[ind_to_fool].clone())
            fooled = ~acc_curr
            acc[ind_to_fool[fooled]] = False
            adv[ind_to_fool[fooled]] = adv_curr[fooled].clone()
            if self.verbose:
                # Single run, so the restart counter is always 0.
                print('restart {} - robust accuracy: {:.2%} - cum. time: {:.1f} s  '.format(
                    0, acc.float().mean(), time.time() - startt))

        return acc, adv


class T_MIFPE_targeted:
    """Targeted iterative attack with momentum and optional logit rescaling.

    For each of ``n_target_classes`` runs, the target is the class ranked
    ``target_class``-th from the top on the clean input, for ``target_class``
    in 2..n_target_classes+1. The per-sample loss is chosen by ``loss``:
    ``'ce_targeted'``, ``'dlr_targeted'``, ``'cw_targeted'``,
    ``'mifpe_targeted'`` or ``'t_mifpe_targeted'``.

    Success is counted as any misclassification of the true label. Under
    ``'Linf'`` only samples still classified correctly are moved.
    """

    def __init__(self, model, n_iter=100, norm='Linf', n_restarts=1, eps=None,
                 seed=0, loss='ce', eot_iter=1, rho=.75, verbose=False, device='cuda',
                 n_target_classes=9, decay_step='linear'):
        if eps is None:
            raise TypeError('eps must be given as a number, got None')
        self.model = model
        self.n_iter = n_iter
        # Budget shrunk by a hair, presumably so float round-off in the
        # projections cannot exceed the nominal eps.
        self.eps = eps - 1.9 * 1e-8
        self.norm = norm
        self.n_restarts = n_restarts  # stored; not read by this class
        self.seed = seed
        self.eot_iter = eot_iter
        self.thr_decr = rho  # stored; not read by this class
        self.verbose = verbose
        self.target_class = None  # set by perturb() before each attack_single_run()
        self.device = device
        self.n_target_classes = n_target_classes
        self.loss = loss
        self.decay_step = decay_step

    def check_right_index(self, output, labels):
        """Return an int8 (batch, 1) mask: 1 where argmax(output) == labels."""
        output_index = output.argmax(dim=-1) == labels
        mask = output_index.to(dtype=torch.int8)
        return torch.unsqueeze(mask, -1)

    # ------------------------------------------------------------------ losses

    def dlr_loss_targeted(self, x, y, y_target):
        """Targeted DLR-style loss: -(z_y - z_t) / (z_(1) - (z_(3) + z_(4)) / 2 + 1e-12).

        z_(k) is the k-th largest logit. Requires K >= 4.
        """
        x_sorted, _ = x.sort(dim=1)
        rows = torch.arange(x.shape[0], device=x.device)
        return -(x[rows, y] - x[rows, y_target]) / (
            x_sorted[:, -1] - .5 * x_sorted[:, -3] - .5 * x_sorted[:, -4] + 1e-12)

    def cw_loss_targeted(self, x, y, y_target):
        """Targeted CW margin loss: -(z_y - z_target)."""
        rows = torch.arange(x.shape[0], device=x.device)
        return -(x[rows, y] - x[rows, y_target])

    # Original method name. ``attack_single_run`` looks up ``cw_loss_targeted``,
    # so the old name alone made ``loss='cw_targeted'`` fail. Kept as an alias.
    cw_loss_target = cw_loss_targeted

    def ce_targeted(self, x, y, y_target):
        """Negative cross-entropy towards ``y_target``; ``y`` is unused."""
        return -F.cross_entropy(x, y_target, reduction='none')

    # ----------------------------------------------------------- logit scales

    def get_output_scale_one(self, output):
        """Return the top-1 minus top-2 logit margin per sample, shape (batch, 1).

        Only the two largest logits are needed, so any model with K >= 2
        classes is supported.
        """
        top2 = output.topk(2, dim=1, largest=True, sorted=True).values
        return (top2[:, 0] - top2[:, 1]).unsqueeze(-1).to(self.device)

    def get_output_scale_star(self, output, labels, t_star_low_bound=1, t_star_upper_bound_decay=0.9):
        """Compute the per-sample divisor used by the ``'t_mifpe_targeted'`` loss.

        Args:
            output (torch.Tensor): Logits with shape (batch_size, K).
            labels (torch.Tensor): Class indices, shape (batch_size,).
            t_star_low_bound (float): Lower bound for t* (default: 1).
            t_star_upper_bound_decay (float): Decay factor for the case-2 upper bound (default: 0.9).

        Returns:
            torch.Tensor: ``Delta / t*`` with shape (batch_size, 1), where
            Delta = z_pi1 - z_pi2. Dividing the logits by it rescales the
            top-1/top-2 margin to exactly t*.
        """
        assert output.dim() == 2, "output must be a 2D tensor (batch_size, K)"
        t_star = self.compute_t_star(output, labels, low_bound=t_star_low_bound,
                                     decay=t_star_upper_bound_decay)
        top2 = output.topk(2, dim=1).values
        delta = top2[:, 0] - top2[:, 1]
        return (delta / t_star).unsqueeze(-1)

    def compute_t_star(self, logits, labels, low_bound=1.278, decay=0.98, n_iter=10):
        """Compute the per-sample temperature t* used by :meth:`get_output_scale_star`.

        Let z_y be the label logit, Delta = z_pi1 - z_pi2 (top-1 minus top-2) and
        gap = max_{i != y} z_i - z_y.

        * Case 1 (gap < 0, label strictly on top): grid-search t over
          [lower, 10.512] with 1000 points, maximizing ln(c * (1 - p_y)).
          Here c = t / Delta and p_y = softmax(c * logits)[y]. ``lower`` is
          ``self.low_bound`` if a caller set that attribute, otherwise ``low_bound``.
        * Case 2 (gap > 0): t* = max(max(low_bound, Delta), decay * U), where
          U = -ln(float32 tiny) / |(z_y - z_pi1) / Delta|.
        * Case 3 (gap == 0): t* = low_bound.

        Args:
            logits: (batch, K) logits.
            labels: (batch,) class indices.
            low_bound: lower bound on t*.
            decay: multiplier applied to the case-2 upper bound.
            n_iter: unused; kept for interface compatibility.

        Returns:
            (batch,) float tensor on ``logits.device``.
        """
        device = logits.device
        labels = labels.to(device).long()

        top2 = logits.topk(2, dim=1).values
        z_pi1, z_pi2 = top2[:, 0], top2[:, 1]
        delta = (z_pi1 - z_pi2).detach()

        z_y = torch.gather(logits, 1, labels.unsqueeze(1)).squeeze(1)
        not_y = torch.ones_like(logits, dtype=torch.bool)
        not_y.scatter_(1, labels.unsqueeze(1), False)
        max_z_not_y = torch.where(not_y, logits, -torch.inf).max(dim=1)[0]
        gap = max_z_not_y - z_y

        t_star = torch.zeros(logits.shape[0], device=device)

        case1 = gap < 0
        if case1.any():
            # The original code read ``self.low_bound``, which __init__ never sets
            # (AttributeError). Honour it if a caller assigned it; otherwise fall
            # back to the ``low_bound`` argument.
            grid_lower = getattr(self, 'low_bound', low_bound)
            t_star[case1] = self._t_star_case1(logits[case1], z_y[case1], delta[case1], grid_lower)

        case2 = gap > 0
        if case2.any():
            t_star[case2] = self._t_star_case2(z_pi1[case2], z_pi2[case2], z_y[case2], low_bound, decay)

        t_star[gap == 0] = low_bound
        return t_star

    @staticmethod
    def _t_star_case1(logits, z_top, delta, t_lower, t_upper=10.512, n_grid=1000):
        """Grid-search t in [t_lower, t_upper] maximizing ln(c * (1 - p_top)) per row."""
        device = logits.device
        lower = torch.tensor(float(t_lower), device=device)
        upper = torch.tensor(float(t_upper), device=device)
        grid = lower + torch.linspace(0, 1, n_grid, device=device) * (upper - lower)  # (n_grid,)

        c = grid.unsqueeze(0) / delta.unsqueeze(1)                                    # (n, n_grid)
        shifted = (logits - z_top.unsqueeze(1)).unsqueeze(1)                          # (n, 1, K)
        p_top = 1 / torch.exp(shifted * c.unsqueeze(2)).sum(dim=2)                    # (n, n_grid)
        # Clamp keeps log finite when c * (1 - p_top) underflows to 0.
        g = torch.log(torch.clamp(c * (1 - p_top), min=1e-10))
        return grid[g.argmax(dim=1)]

    @staticmethod
    def _t_star_case2(z_pi1, z_pi2, z_y, low_bound, decay):
        """Closed-form t* for samples whose label is beaten by another class."""
        device = z_pi1.device
        margin = z_pi1 - z_pi2
        lower = torch.maximum(torch.tensor(low_bound, device=device), margin)
        safe_margin = torch.where(margin == 0, torch.ones_like(margin) * 1e-8, margin)
        exponent_factor = (z_y - z_pi1) / safe_margin
        # -ln(smallest normal float32): largest exponent magnitude before exp() underflows.
        max_exp = -torch.log(torch.tensor(torch.finfo(torch.float32).tiny, device=device))
        upper = torch.where(exponent_factor < 0,
                            max_exp / exponent_factor.abs(),
                            torch.tensor(torch.inf, device=device))
        return torch.maximum(lower, upper * decay)

    # ------------------------------------------------------------ attack core

    def _per_sample_loss(self, out_adv, y, y_target):
        """Per-sample loss (to be maximized) selected by ``self.loss``."""
        if self.loss == 'ce_targeted':
            return self.ce_targeted(out_adv, y, y_target)
        if self.loss == 'dlr_targeted':
            return self.dlr_loss_targeted(out_adv, y, y_target)
        if self.loss == 'cw_targeted':
            return self.cw_loss_targeted(out_adv, y, y_target)
        if self.loss == 'mifpe_targeted':
            scale = self.get_output_scale_one(out_adv.detach())
            return self.ce_targeted(out_adv / scale, y, y_target)
        if self.loss == 't_mifpe_targeted':
            scale = self.get_output_scale_star(out_adv.detach(), y_target.detach())
            return self.ce_targeted(out_adv / scale, y, y_target)
        raise ValueError(f'Unknown loss type: {self.loss}')

    def _step_size(self, step_size_begin, i):
        """Step size at iteration ``i`` under the ``decay_step`` schedule."""
        progress = i / self.n_iter
        if self.decay_step == 'linear':
            return step_size_begin * (1 - progress)
        if self.decay_step == 'cosine':
            return step_size_begin * (1 + math.cos(progress * math.pi)) * 0.5
        if self.decay_step == 'cos':
            return step_size_begin * math.cos(progress * math.pi * 0.5)
        if self.decay_step == 'constant':
            # Fixed 2/255 regardless of eps.
            return torch.full_like(step_size_begin, 2.0 / 255.0)
        raise ValueError(f'Unknown decay_step: {self.decay_step}')

    def _project_linf(self, z, x):
        """Project ``z`` onto the Linf ball of radius eps around ``x``, then into [0, 1]."""
        return torch.clamp(torch.min(torch.max(z, x - self.eps), x + self.eps), 0.0, 1.0)

    def _project_l2(self, z, x, norm_pad):
        """Rescale ``z - x`` to L2 norm min(eps, ||z - x|| + norm_pad), then clip to [0, 1]."""
        diff = z - x
        norm = (diff ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
        radius = torch.min(self.eps * torch.ones(x.shape, device=self.device), norm + norm_pad)
        return torch.clamp(x + diff / (norm + 1e-12) * radius, 0.0, 1.0)

    def attack_single_run(self, x_in, y_in):
        """Run one targeted pass for the current ``self.target_class``.

        Returns:
            (acc, x_best_adv). ``acc`` is True for samples that were never
            misclassified during the run. ``x_best_adv`` holds the last
            misclassified iterate for fooled samples and the final iterate
            for the others.
        """
        x = x_in.clone() if x_in.dim() == 4 else x_in.clone().unsqueeze(0)
        y = y_in.clone() if y_in.dim() == 1 else y_in.clone().unsqueeze(0)

        x_adv = self._project_linf(x.detach(), x)
        x_best_adv = x_adv.clone()
        x_adv_old = x_adv.clone()

        with torch.no_grad():
            clean_out = self.model(x)
        # Target: the class ranked ``target_class``-th from the top on the clean input.
        y_target = clean_out.sort(dim=1)[1][:, -self.target_class]
        acc = clean_out.max(1)[1] == y
        pred = acc.clone()  # defined even when n_iter == 0

        step_size_begin = torch.full((x.shape[0], 1, 1, 1), 2.0 * self.eps, device=self.device)

        for i in range(self.n_iter):
            step_size = self._step_size(step_size_begin, i)

            # --- gradient, averaged over eot_iter forward/backward passes
            x_adv.requires_grad_()
            grad = torch.zeros_like(x)
            for _ in range(self.eot_iter):
                with torch.enable_grad():
                    out_adv = self.model(x_adv)
                    # (batch, 1, 1, 1): 1 for samples still classified correctly.
                    mask = self.check_right_index(out_adv, y)[:, :, None, None]
                    loss = self._per_sample_loss(out_adv, y, y_target).sum()
                grad += torch.autograd.grad(loss, [x_adv])[0].detach()
            grad /= float(self.eot_iter)

            # --- momentum step + projection
            with torch.no_grad():
                x_adv = x_adv.detach()
                grad2 = x_adv - x_adv_old
                x_adv_old = x_adv.clone()
                a = 0.75 if i > 0 else 1.0

                if self.norm == 'Linf':
                    x_step = self._project_linf(x_adv + mask * step_size * torch.sign(grad), x)
                    x_adv_1 = self._project_linf(
                        x_adv + mask * ((x_step - x_adv) * a + grad2 * (1 - a)), x)
                elif self.norm == 'L2':
                    grad_norm = (grad ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
                    x_step = self._project_l2(x_adv + step_size * grad / (grad_norm + 1e-12), x, 0.0)
                    x_adv_1 = self._project_l2(x_adv + (x_step - x_adv) * a + grad2 * (1 - a), x, 1e-12)
                else:
                    raise ValueError(f'Unknown norm: {self.norm}')

                x_adv = x_adv_1 + 0.
                pred = self.model(x_adv).max(1)[1] == y

            acc = acc & pred
            x_best_adv[~pred] = x_adv[~pred]

        # Only samples never fooled take the final iterate; keying on ``pred``
        # could overwrite an earlier adversarial point for samples with acc == 0.
        x_best_adv[acc] = x_adv[acc]
        return acc, x_best_adv

    def perturb(self, x_in, y_in, best_loss=False, cheap=True):
        """Attack correctly classified samples, trying each target class in turn.

        ``best_loss`` and ``cheap`` are accepted for interface compatibility and
        are not used.

        Returns:
            (acc, adv): per-sample robust-accuracy flags, and adversarial inputs
            (the clean input where no misclassification was found).
        """
        self.seed = 0  # NOTE(review): overrides the constructor's seed (existing behavior)
        assert self.norm in ['Linf', 'L2']
        x = x_in.clone() if x_in.dim() == 4 else x_in.clone().unsqueeze(0)
        y = y_in.clone() if y_in.dim() == 1 else y_in.clone().unsqueeze(0)

        adv = x.clone()
        with torch.no_grad():
            acc = self.model(x).max(1)[1] == y
        if self.verbose:
            print('-------------------------- running {}-attack with epsilon {:.4f} --------------------------'.format(
                self.norm, self.eps))
            print('initial accuracy: {:.2%}'.format(acc.float().mean()))
        startt = time.time()

        torch.random.manual_seed(self.seed)
        torch.cuda.random.manual_seed(self.seed)

        for target_class in range(2, self.n_target_classes + 2):
            self.target_class = target_class
            ind_to_fool = acc.nonzero().flatten()
            if ind_to_fool.numel() == 0:
                continue
            acc_curr, adv_curr = self.attack_single_run(x[ind_to_fool].clone(), y[ind_to_fool].clone())
            fooled = ~acc_curr
            acc[ind_to_fool[fooled]] = False
            adv[ind_to_fool[fooled]] = adv_curr[fooled].clone()
            if self.verbose:
                # No restart loop exists, so the restart counter is always 0.
                print(
                    'restart {} - target_class {} - robust accuracy: {:.2%} at eps = {:.5f} - cum. time: {:.1f} s  '.format(
                        0, self.target_class, acc.float().mean(), self.eps, time.time() - startt))

        return acc, adv
