import numpy as np
import torch
import math

 

# class EpsilonScheduler():
#     def __init__(self, schedule_type, init_step, final_step, init_value, final_value, num_steps_per_epoch, mid_point=.25, beta=4.):
#         self.schedule_type = schedule_type
#         self.init_step = init_step
#         self.final_step = final_step
#         self.init_value = init_value
#         self.final_value = final_value
#         self.mid_point = mid_point
#         self.beta = beta
#         self.num_steps_per_epoch = num_steps_per_epoch
#         assert self.final_value >= self.init_value
#         assert self.final_step >= self.init_step
#         assert self.beta >= 2.
#         assert self.mid_point >= 0. and self.mid_point <= 1. 
    
#     def get_eps(self, epoch, step):
#         if self.schedule_type == "smoothed":
#             return self.smooth_schedule(epoch * self.num_steps_per_epoch + step, self.init_step, self.final_step, self.init_value, self.final_value, self.mid_point, self.beta)
#         else:
#             return self.linear_schedule(epoch * self.num_steps_per_epoch + step, self.init_step, self.final_step, self.init_value, self.final_value)
    
#     # Smooth schedule that slowly morphs into a linear schedule.
#     # Code is adapted from DeepMind's IBP implementation:
#     # https://github.com/deepmind/interval-bound-propagation/blob/2c1a56cb0497d6f34514044877a8507c22c1bd85/interval_bound_propagation/src/utils.py#L84
#     def smooth_schedule(self, step, init_step, final_step, init_value, final_value, mid_point=.25, beta=4.):
#         """Smooth schedule that slowly morphs into a linear schedule."""
#         assert final_value >= init_value
#         assert final_step >= init_step
#         assert beta >= 2.
#         assert mid_point >= 0. and mid_point <= 1.
#         mid_step = int((final_step - init_step) * mid_point) + init_step 
#         if mid_step <= init_step:
#             alpha = 1.
#         else:
#             t = (mid_step - init_step) ** (beta - 1.)
#             alpha = (final_value - init_value) / ((final_step - mid_step) * beta * t + (mid_step - init_step) * t)
#         mid_value = alpha * (mid_step - init_step) ** beta + init_value
#         is_ramp = float(step > init_step)
#         is_linear = float(step >= mid_step)
#         return (is_ramp * (
#             (1. - is_linear) * (
#                 init_value +
#                 alpha * float(step - init_step) ** beta) +
#             is_linear * self.linear_schedule(
#                 step, mid_step, final_step, mid_value, final_value)) +
#                 (1. - is_ramp) * init_value)
        
#     # Linear schedule.
#     # Code is adapted from DeepMind's IBP implementation:
#     # https://github.com/deepmind/interval-bound-propagation/blob/2c1a56cb0497d6f34514044877a8507c22c1bd85/interval_bound_propagation/src/utils.py#L73 
#     def linear_schedule(self, step, init_step, final_step, init_value, final_value):
#         """Linear schedule."""
#         assert final_step >= init_step
#         if init_step == final_step:
#             return final_value
#         rate = float(step - init_step) / float(final_step - init_step)
#         linear_value = rate * (final_value - init_value) + init_value
#         return np.clip(linear_value, min(init_value, final_value), max(init_value, final_value))



class EpsilonScheduler():
    """Schedule the perturbation radius (epsilon) used during training.

    ``schedule_type`` selects the behaviour of :meth:`get_eps`:

    * ``'eval'``  -- always return the fixed value passed as ``eps_params``.
    * ``'test'``  -- always return ``eps_params["test_eps"]``.
    * ``'standard-linear'`` / ``'standard-smooth'`` / ``'standard-expo'`` --
      ramp from ``starting_epsilon`` to ``epsilon`` (``standard_params``).
    * ``'autoeps'`` -- return the per-sample tensor ``self.eps_sample`` once the
      schedule has started; :meth:`update_eps_sample` adapts it (types 1,2,4,5).
    * ``'autoeps-linear'`` / ``'autoeps-smooth'`` -- linear/smooth base schedule
      plus the per-sample offset ``self.eps_sample``, floored at 0.

    The ramp runs over global steps ``[init_step, final_step]`` where
    ``init_step = schedule_start * num_steps_per_epoch`` and
    ``final_step = (schedule_start + schedule_length - 1) * num_steps_per_epoch``;
    the global step is ``epoch * num_steps_per_epoch + step``.
    """

    # autoeps types (autoeps-linear/-smooth only) whose offset is derived from
    # the softmax margin of the lower bounds instead of a fixed +/- step.
    _MARGIN_TYPES = frozenset(range(7, 16))
    # Margin-based types whose batch offsets are re-centred to zero mean.
    _ZERO_MEAN_TYPES = frozenset({9, 11, 12, 13, 14})

    def __init__(self, schedule_type, num_steps_per_epoch, eps_params = None, eps_sample=0):
        """Read the schedule configuration.

        Args:
            schedule_type: one of the schedule names listed on the class.
            num_steps_per_epoch: number of optimisation steps per epoch.
            eps_params: for ``'eval'`` the epsilon value itself; otherwise a
                dict with ``test_eps``, ``schedule_start``, ``schedule_length``
                and, depending on the type, ``standard_params``,
                ``autoeps_params`` and optionally ``epsnoise_params``.
            eps_sample: initial value of every per-sample epsilon entry
                (autoeps schedules only).
        """
        self.schedule_type = schedule_type
        if schedule_type == 'eval':
            self.schedule_start = -1
            self.eval_eps = torch.tensor(eps_params)
            return

        self.test_eps = torch.tensor(eps_params["test_eps"])
        self.num_steps_per_epoch = num_steps_per_epoch
        schedule_start = eps_params["schedule_start"]
        schedule_length = eps_params["schedule_length"]
        self.schedule_start = schedule_start
        self.schedule_length = schedule_length

        self.init_step = schedule_start*num_steps_per_epoch
        # The ramp ends at the first step of the last scheduled epoch.
        self.final_step = ((schedule_start + schedule_length) - 1) * num_steps_per_epoch

        # Base (un-offset) epsilon of the most recent schedule evaluation;
        # scales the offsets of several margin-based autoeps types.
        self.cur_base_eps = 0
        if schedule_type == 'test':
            self.final_value = torch.tensor(eps_params["standard_params"]["epsilon"])
            return

        if schedule_type in ["standard-linear", "standard-smooth", "standard-expo"]:
            self._load_standard_params(eps_params)
            assert self.final_value >= self.init_value
            assert self.final_step >= self.init_step
            assert self.beta >= 2.
            assert self.mid_point >= 0. and self.mid_point <= 1.
        elif schedule_type == "autoeps":
            self._load_autoeps_params(eps_params, eps_sample)
        elif schedule_type in ["autoeps-linear", "autoeps-smooth"]:
            self._load_autoeps_params(eps_params, eps_sample)
            # Loaded second on purpose: max_eps ends up as the final epsilon.
            self._load_standard_params(eps_params)
            self.max_offset = eps_params["autoeps_params"]["max_offset"]

        if "epsnoise_params" in eps_params:
            self.epsnoise_type = eps_params["epsnoise_params"]["epsnoise_type"]
            self.sigma = eps_params["epsnoise_params"]["sigma"]
        else:
            self.epsnoise_type = 0

    def _load_standard_params(self, eps_params):
        """Read the base-schedule settings from ``eps_params["standard_params"]``."""
        standard = eps_params["standard_params"]
        self.init_value = torch.tensor(standard["starting_epsilon"])
        self.final_value = torch.tensor(standard["epsilon"])
        self.beta = standard.get("beta", 4.)
        self.mid_point = standard.get("mid_point", .25)
        self.max_eps = self.final_value

    def _load_autoeps_params(self, eps_params, eps_sample):
        """Read the per-sample settings from ``eps_params["autoeps_params"]``."""
        autoeps = eps_params["autoeps_params"]
        self.autoeps_type = autoeps["autoeps_type"]
        self.eps_update_step1 = autoeps.get("eps_update_step1", 0.002)
        self.eps_update_step2 = autoeps.get("eps_update_step2", 0.01)
        self.eps_warmup_epochs = autoeps.get("eps_warmup_epochs", 5)
        self.diff_coff = autoeps.get("diff_coff", 4)
        self.step_coef = autoeps.get("step_coef", 5)
        self.train_sample_size = autoeps.get("train_sample_size", 60000)
        # One epsilon entry per training sample, addressed by sample index.
        self.eps_sample = torch.zeros(self.train_sample_size)+eps_sample
        self.max_eps = autoeps["max_eps"]

    def _ramp_factor(self, epoch, step, final_factor):
        """Linearly decay a factor from 1 to ``final_factor`` over the ramp window."""
        cur_step = epoch * self.num_steps_per_epoch + step
        span = self.final_step - self.init_step
        if span == 0:
            # Zero-length ramp (schedule_length == 1): jump straight to the
            # final factor, as linear_schedule does, instead of dividing by 0.
            return final_factor
        return 1-(min(max(cur_step-self.init_step,0),span)) * (1.0 - final_factor) / span

    def get_beta(self, epoch, step, crown_final_beta):
        """Return the CROWN mixing factor: 1 before the ramp, ``crown_final_beta`` after it."""
        return self._ramp_factor(epoch, step, crown_final_beta)

    def get_kappa(self, epoch, step, natural_final_factor):
        """Return the natural-loss factor: 1 before the ramp, ``natural_final_factor`` after it."""
        return self._ramp_factor(epoch, step, natural_final_factor)

    def get_eps(self, epoch, step):
        """Return epsilon for ``step`` within ``epoch``.

        For autoeps schedules the result is a tensor with one entry per
        training sample once the schedule has started.

        Raises:
            ValueError: if ``schedule_type`` is not a known schedule.
        """
        if self.schedule_type == 'eval':
            return self.eval_eps
        if self.schedule_type == 'test':
            return self.test_eps

        cur_step = epoch * self.num_steps_per_epoch + step
        if self.schedule_type == "standard-smooth":
            return self.smooth_schedule(cur_step, self.init_step, self.final_step, self.init_value, self.final_value, self.mid_point, self.beta)
        if self.schedule_type == "standard-linear":
            return self.linear_schedule(cur_step, self.init_step, self.final_step, self.init_value, self.final_value)
        if self.schedule_type == "standard-expo":
            return self.expo_schedule(cur_step, self.init_step, self.final_step, self.init_value, self.final_value)
        if self.schedule_type == "autoeps":
            return self.autoeps_schedule(cur_step)
        if self.schedule_type == "autoeps-linear":
            return self.autoeps_linear_schedule(cur_step, self.init_step, self.final_step, self.init_value, self.final_value)
        if self.schedule_type == "autoeps-smooth":
            return self.autoeps_smooth_schedule(cur_step, self.init_step, self.final_step, self.init_value, self.final_value, self.mid_point, self.beta)
        raise ValueError(f"Unknown schedule_type: {self.schedule_type!r}")

    # Smooth schedule that slowly morphs into a linear schedule.
    # Code is adapted from DeepMind's IBP implementation:
    # https://github.com/deepmind/interval-bound-propagation/blob/2c1a56cb0497d6f34514044877a8507c22c1bd85/interval_bound_propagation/src/utils.py#L84
    def smooth_schedule(self, step, init_step, final_step, init_value, final_value, mid_point=.25, beta=4.):
        """Polynomial ramp of degree ``beta`` up to ``mid_point``, then linear.

        Stores the result in ``self.cur_base_eps`` and returns it.
        """
        assert final_value >= init_value
        assert final_step >= init_step
        assert beta >= 2.
        assert mid_point >= 0. and mid_point <= 1.
        mid_step = int((final_step - init_step) * mid_point) + init_step
        if mid_step <= init_step:
            alpha = 1.
        else:
            # Chosen so the polynomial and linear pieces join smoothly at mid_step.
            t = (mid_step - init_step) ** (beta - 1.)
            alpha = (final_value - init_value) / ((final_step - mid_step) * beta * t + (mid_step - init_step) * t)
        mid_value = alpha * (mid_step - init_step) ** beta + init_value
        is_ramp = float(step > init_step)
        is_linear = float(step >= mid_step)
        eps = (is_ramp * (
            (1. - is_linear) * (
                init_value +
                alpha * float(step - init_step) ** beta) +
            is_linear * self.linear_schedule(
                step, mid_step, final_step, mid_value, final_value)) +
                (1. - is_ramp) * init_value)
        self.cur_base_eps = eps
        return eps

    # Linear schedule.
    # Code is adapted from DeepMind's IBP implementation:
    # https://github.com/deepmind/interval-bound-propagation/blob/2c1a56cb0497d6f34514044877a8507c22c1bd85/interval_bound_propagation/src/utils.py#L73
    def linear_schedule(self, step, init_step, final_step, init_value, final_value):
        """Linear interpolation from ``init_value`` to ``final_value``, clipped to that range.

        A zero-length window returns ``final_value``. The result is always
        stored in ``self.cur_base_eps`` (previously skipped for the zero-length
        case, leaving margin-scaled autoeps offsets at 0).
        """
        assert final_step >= init_step
        if init_step == final_step:
            eps = final_value
        else:
            rate = float(step - init_step) / float(final_step - init_step)
            linear_value = rate * (final_value - init_value) + init_value
            eps = np.clip(linear_value, min(init_value, final_value), max(init_value, final_value))
        self.cur_base_eps = eps
        return eps

    def expo_schedule(self, step, init_step, final_step, init_value, final_value, mid_point=.25, alpha = 1.):
        """Exponential ramp ``alpha * (exp(rate * t) - 1) + init_value``.

        ``rate`` is chosen so the curve reaches ``final_value`` at
        ``final_step``. ``mid_point`` is only validated, not used. Expects
        tensor values (uses ``torch.log``/``torch.exp``). Stores the result in
        ``self.cur_base_eps``.
        """
        assert final_value >= init_value
        assert final_step >= init_step
        assert mid_point >= 0. and mid_point <= 1.
        rate = torch.log((final_value-init_value)/alpha+1)/(final_step-init_step)
        if step <= init_step:
            eps = init_value
        elif step >= final_step:
            eps = final_value
        else:
            eps = alpha*(torch.exp((step-init_step)*rate)-1)+init_value

        self.cur_base_eps = eps
        return eps

    def autoeps_schedule(self, step):
        """Return 0 before ``init_step``, else the full per-sample tensor ``self.eps_sample``."""
        if step < self.init_step:
            return torch.tensor(0.)
        return self.eps_sample

    def autoeps_linear_schedule(self, step, init_step, final_step, init_value, final_value):
        """Linear base epsilon plus the per-sample offset, floored at 0."""
        auto_eps = self.autoeps_schedule(step)
        linear_eps = self.linear_schedule(step, init_step, final_step, init_value, final_value)
        eps = auto_eps+linear_eps
        eps = torch.max(torch.zeros_like(eps), eps)
        return eps

    def autoeps_smooth_schedule(self, step, init_step, final_step, init_value, final_value, mid_point=.25, beta=4.):
        """Smooth base epsilon plus the per-sample offset, floored at 0."""
        auto_eps = self.autoeps_schedule(step)
        smooth_eps = self.smooth_schedule(step, init_step, final_step, init_value, final_value, mid_point, beta)
        eps = auto_eps+smooth_eps
        eps = torch.max(torch.zeros_like(eps), eps)
        return eps

    def update_eps_sample(self, index, lb, epoch, labels = None):
        """Adapt the per-sample epsilon state ``self.eps_sample`` for one batch.

        Args:
            index: indices of the batch samples into ``self.eps_sample``.
            lb: lower bounds with one row per batch sample; a row is treated
                as not certified ("wrong") when any entry is negative.
            epoch: current epoch; selects ``eps_update_step1`` while
                ``epoch <= eps_warmup_epochs`` and ``eps_update_step2`` after.
            labels: class label per sample; required by margin types 7-15.

        A schedule without an ``autoeps_type`` (or type 0) is a no-op.
        """
        autoeps_type = getattr(self, "autoeps_type", 0)
        if autoeps_type == 0:
            return
        if self.schedule_type in ["autoeps-linear", "autoeps-smooth"]:
            if autoeps_type in self._MARGIN_TYPES:
                self._apply_margin_offset(index, lb, labels)
            else:
                self._apply_step_offset(index, lb, epoch)
        else:
            self._update_autoeps_sample(index, lb, epoch)

    def _eps_update_step(self, epoch):
        """Step size for the current epoch (small during warm-up)."""
        if epoch <= self.eps_warmup_epochs:
            return self.eps_update_step1
        return self.eps_update_step2

    @staticmethod
    def _split_by_certification(index, lb):
        """Split ``index`` into (rows with some negative bound, all other rows)."""
        if_wrong_mask = (lb<0).any(dim=1)
        return index[if_wrong_mask], index[~if_wrong_mask]

    @staticmethod
    def _margin(lb, labels):
        """Per-sample margin computed from ``softmax(-lb)`` class probabilities.

        Returns ``p[label] - p[argmin lb]`` for rows whose minimum bound is
        negative, and ``p[label] - p[second smallest lb]`` otherwise. The
        result stays on ``lb``'s device.
        """
        lb_ = torch.clone(lb).detach()
        prob = torch.nn.Softmax(dim =1)(-lb_)
        lb_min, lb_min_idx = torch.min(lb_, dim = 1)
        prob_max = torch.gather(prob, 1, lb_min_idx.unsqueeze(-1)).squeeze(-1)
        # Kept on lb's device: the former hard-coded .cuda() broke CPU runs.
        if_correct = (lb_min>=0).float()
        # Hide the smallest entry so the next min finds the second smallest.
        lb_[torch.arange(lb.shape[0], device=lb.device), lb_min_idx] = float("Inf")
        _, lb_min2_idx = torch.min(lb_, dim = 1)
        prob_max2 = torch.gather(prob, 1, lb_min2_idx.unsqueeze(-1)).squeeze(-1)
        prob_correct = torch.gather(prob, 1, labels.unsqueeze(-1)).squeeze(-1)
        return ((prob_correct-prob_max)*(1-if_correct)+(prob_correct-prob_max2)*if_correct)

    def _apply_margin_offset(self, index, lb, labels):
        """Set batch offsets from the margin (autoeps types 7-15).

        7:      -max_offset * relu(-margin)
        8, 9:   min(max_offset, cur_base_eps) * margin
        10, 11: min(max_offset, cur_base_eps) * (-margin)
        12:     -max_offset * cur_base_eps * margin
        13:     max_offset * cur_base_eps * margin
        14, 15: -max_offset * cur_base_eps * relu(-margin)
        Types 9, 11, 12, 13, 14 are then shifted to zero mean over ``index``.
        """
        margin = self._margin(lb, labels).cpu()
        autoeps_type = self.autoeps_type
        if autoeps_type == 7:
            offset = -self.max_offset*torch.nn.ReLU()(-margin)
        elif autoeps_type in (8, 9):
            offset = min(self.max_offset,self.cur_base_eps)*margin
        elif autoeps_type in (10, 11):
            offset = min(self.max_offset,self.cur_base_eps)*(-margin)
        elif autoeps_type == 12:
            offset = -self.max_offset*self.cur_base_eps*margin
        elif autoeps_type == 13:
            offset = self.max_offset*self.cur_base_eps*margin
        else:  # types 14 and 15
            offset = -self.max_offset*self.cur_base_eps*torch.nn.ReLU()(-margin)
        self.eps_sample[index] = offset
        if autoeps_type in self._ZERO_MEAN_TYPES:
            # Re-read after the write so repeated indices behave as before.
            self.eps_sample[index] = self.eps_sample[index] - self.eps_sample[index].mean()

    def _apply_step_offset(self, index, lb, epoch):
        """Step offsets down for uncertified rows and up for the rest, then clamp.

        Clamp range: type 6 -> [-max_offset, 0]; type 17 -> [0, max_offset];
        type 16 and any other non-margin type -> [-max_offset, max_offset].
        """
        eps_update_step = self._eps_update_step(epoch)
        index_wrong, index_correct = self._split_by_certification(index, lb)
        self.eps_sample[index_wrong] -= eps_update_step
        self.eps_sample[index_correct] += eps_update_step

        if self.autoeps_type == 6:
            low, high = -self.max_offset, 0
        elif self.autoeps_type == 17:
            low, high = 0, self.max_offset
        else:
            low, high = -self.max_offset, self.max_offset
        self.eps_sample[index] = torch.clamp(self.eps_sample[index], low, high)

    def _update_autoeps_sample(self, index, lb, epoch):
        """Absolute per-sample update for the plain ``'autoeps'`` schedule.

        Types 1/2: uncertified rows -= step * step_coef, others += step.
        Type 4: uncertified rows shrink by 2.5%; others grow by 2.5% + 0.001.
        Type 5: only certified rows change (+= step).
        Values are floored at 0, then pulled toward the batch mean by
        ``1 / diff_coff``: all batch rows for types 1 and 4, only the
        uncertified rows for types 2 and 5. Other types are a no-op.
        """
        autoeps_type = self.autoeps_type
        if autoeps_type not in (1, 2, 4, 5):
            return
        eps_update_step = self._eps_update_step(epoch)
        index_wrong, index_correct = self._split_by_certification(index, lb)

        if autoeps_type == 4:
            self.eps_sample[index_wrong] -= self.eps_sample[index_wrong]*0.025
            self.eps_sample[index_correct] += (0.025*self.eps_sample[index_correct] + 0.001)
        else:
            if autoeps_type in (1, 2):
                self.eps_sample[index_wrong] -= eps_update_step*self.step_coef
            self.eps_sample[index_correct] += eps_update_step

        self.eps_sample[index] = torch.max(self.eps_sample[index], torch.zeros_like(self.eps_sample[index]))

        mean = torch.mean(self.eps_sample[index])
        if autoeps_type in (1, 4):
            diff = mean-self.eps_sample[index]
            self.eps_sample[index] = self.eps_sample[index] + diff/self.diff_coff
        else:
            diff = mean-self.eps_sample[index_wrong]
            self.eps_sample[index_wrong] = self.eps_sample[index_wrong] + diff/self.diff_coff
    
