import torch
import math

class VRAda(torch.optim.Optimizer):
    """Variance-reduced primal-dual optimizer for AUC-style min-max objectives.

    Primal variables (all ``model`` parameters, plus the auxiliary tensors
    ``a`` and ``b`` when both are available) take a descent step along the
    element-wise clipped gradient augmented with the proximal term
    ``epoch_decay * (p - p_ref)`` and L2 weight decay. The dual variable
    ``alpha`` takes an ascent step along the hard-coded gradient
    ``2 * (margin + b - a) - 2 * alpha`` and is clamped to ``[0, 999]``.

    :meth:`step` has two modes:

    * ``delta_x`` or ``delta_y`` omitted: the momentum buffers are seeded with
      the current directions (when ``momentum != 0`` for the primal side).
    * both given: the buffers are updated recursively as
      ``buf = (1 - momentum) * (buf - delta) + direction``.
      NOTE(review): this matches a STORM-like estimator if ``delta`` is the
      previous iterate's direction on the current batch — confirm with caller.
    """

    def __init__(self,
                 model,
                 loss_fn=None,
                 a=None,  # to be deprecated
                 b=None,  # to be deprecated
                 alpha=None,  # to be deprecated
                 margin=1.0,
                 lr=0.1,
                 gamma=None,  # to be deprecated
                 clip_value=1.0,
                 weight_decay=1e-5,
                 epoch_decay=2e-3,  # default: gamma=500
                 momentum=0.9,
                 verbose=True,
                 device=None,
                 **kwargs):
        """Build the optimizer.

        Args:
            model: module whose ``parameters()`` are optimised.
            loss_fn: optional loss object. If it exposes ``a``, ``b`` and
                ``alpha`` those are used; otherwise a message is printed and
                the explicit ``a``/``b``/``alpha`` arguments are used instead.
            a, b: auxiliary tensors, optimised together with the model only
                when both are not None.
            alpha: dual tensor, updated separately by ascent.
            margin: ``m`` in the dual gradient ``2 * (m + b - a) - 2 * alpha``.
            lr: step size for both primal and dual updates.
            gamma: deprecated; sets ``epoch_decay = 1 / gamma`` and requires
                ``epoch_decay=None``.
            clip_value: bound for element-wise gradient clipping.
            weight_decay: L2 coefficient.
            epoch_decay: coefficient of the proximal term ``p - p_ref``.
            momentum: mixing coefficient of the recursive buffer update.
            verbose: print learning-rate / regularizer updates.
            device: device for internal buffers; CUDA when available by default.
            **kwargs: ignored.
        """
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        assert (gamma is None) or (epoch_decay is None), 'You can only use one of gamma and epoch_decay!'
        if gamma is not None:
            assert gamma > 0
            epoch_decay = 1 / gamma

        self.margin = margin
        self.model = model
        self.lr = lr
        self.gamma = gamma  # to be deprecated
        self.clip_value = clip_value
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.epoch_decay = epoch_decay

        self.loss_fn = loss_fn
        if loss_fn is not None:
            try:
                # Evaluate all three before assigning so no partial state is left behind.
                self.a, self.b, self.alpha = loss_fn.a, loss_fn.b, loss_fn.alpha
            except AttributeError:
                print('AUCLoss is not found!')
                # Fall back to the explicit arguments rather than leaving the
                # attributes undefined (which used to crash a few lines later).
                self.a, self.b, self.alpha = a, b, alpha
        else:
            self.a, self.b, self.alpha = a, b, alpha

        self.model_ref = self.init_model_ref()
        self.model_acc = self.init_model_acc()
        self.T = 0  # steps since the last regularizer refresh (for epoch_decay)
        self.steps = 0  # total optim steps
        self.verbose = verbose  # print updates for lr/regularizer

        params = list(model.parameters())
        if self.a is not None and self.b is not None:
            params += [self.a, self.b]
        self.params = params
        # The per-variable buffers live in the param-group defaults so that
        # step() can index them in the same order as group['params'].
        self.defaults = dict(lr=self.lr,
                             margin=margin,
                             a=self.a,
                             b=self.b,
                             alpha=self.alpha,
                             clip_value=clip_value,
                             momentum=momentum,
                             weight_decay=weight_decay,
                             epoch_decay=epoch_decay,
                             model_ref=self.model_ref,
                             model_acc=self.model_acc
                             )

        super(VRAda, self).__init__(self.params, self.defaults)

    def __setstate__(self, state):
        """Restore pickled state; back-fills an (unused) ``'nesterov'`` group key."""
        super(VRAda, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def _tracked_vars(self):
        """Model parameters followed by ``a`` and ``b``, skipping ``None`` entries."""
        return [var for var in (*self.model.parameters(), self.a, self.b) if var is not None]

    def init_model_ref(self):
        """Create the proximal reference tensors, drawn from N(0, 0.01^2) on CPU then moved."""
        self.model_ref = [torch.empty(var.shape).normal_(mean=0, std=0.01).to(self.device)
                          for var in self._tracked_vars()]
        return self.model_ref

    def init_model_acc(self):
        """Create zero float32 accumulators used to average iterates between refreshes."""
        self.model_acc = [torch.zeros(var.shape, dtype=torch.float32, device=self.device, requires_grad=False)
                          for var in self._tracked_vars()]
        return self.model_acc

    @property
    def optim_steps(self):
        """Total number of calls to :meth:`step`."""
        return self.steps

    @property
    def get_params(self):
        """The wrapped model's parameters as a list."""
        return list(self.model.parameters())

    @torch.no_grad()
    def step(self, closure=None, delta_x=None, delta_y=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable returning the loss; run with grad enabled.
            delta_x: per-parameter correction tensors, aligned with
                ``group['params']``. If this or ``delta_y`` is None, the
                buffers are seeded instead of corrected.
            delta_y: correction tensor for the dual buffer.

        Returns:
            The closure's loss, or None.

        Raises:
            ValueError: a correction step hits a parameter with no momentum buffer.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            clip_value = group['clip_value']
            momentum = group['momentum']
            self.lr = group['lr']
            epoch_decay = group['epoch_decay']
            model_ref = group['model_ref']
            model_acc = group['model_acc']
            m = group['margin']
            a = group['a']
            b = group['b']
            alpha = group['alpha']

            if delta_x is None or delta_y is None:
                # Seeding step: buffers start from the current directions.
                for i, p in enumerate(group['params']):
                    if p.grad is None:
                        continue
                    d_p = torch.clamp(p.grad.data, -clip_value, clip_value) + epoch_decay * (
                        p.data - model_ref[i].data) + weight_decay * p.data
                    if momentum != 0:
                        d_p = self.state[p]['momentum_buffer'] = torch.clone(d_p).detach()
                    p.data = p.data - group['lr'] * d_p
                    model_acc[i].data = model_acc[i].data + p.data

                if alpha is not None and alpha.grad is not None:
                    d_alpha = 2 * (m + b.data - a.data) - 2 * alpha.data
                    buf = group['y_buffer'] = torch.clone(d_alpha).detach()
                    alpha.data = torch.clamp(alpha.data + group['lr'] * buf, 0, 999)
            else:
                # Correction step: recursive update of the existing buffers.
                for i, (p, delta_x_i) in enumerate(zip(group['params'], delta_x)):
                    if p.grad is None:
                        continue
                    d_p = torch.clamp(p.grad.data, -clip_value, clip_value) + epoch_decay * (
                        p.data - model_ref[i].data) + weight_decay * p.data
                    if momentum != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            raise ValueError("There is a problem with the logic")
                        buf = param_state['momentum_buffer']
                        buf.sub_(delta_x_i).mul_(1 - momentum).add_(d_p)
                        d_p = buf
                    p.data = p.data - group['lr'] * d_p
                    model_acc[i].data = model_acc[i].data + p.data

                if alpha is not None and alpha.grad is not None:
                    d_alpha = 2 * (m + b.data - a.data) - 2 * alpha.data
                    buf = group['y_buffer']
                    buf.sub_(delta_y).mul_(1 - momentum).add_(d_alpha)
                    alpha.data = torch.clamp(alpha.data + group['lr'] * buf, 0, 999)

        self.T += 1
        self.steps += 1
        return loss

    def zero_grad(self):
        """Clear model gradients and drop the gradients of ``a``, ``b`` and ``alpha``."""
        self.model.zero_grad()
        if self.a is not None and self.b is not None:
            self.a.grad = None
            self.b.grad = None
        if self.alpha is not None:
            self.alpha.grad = None

    def update_lr(self, decay_factor=None):
        """Divide the first param group's ``lr`` by ``decay_factor`` (no-op when None)."""
        if decay_factor is not None:
            self.param_groups[0]['lr'] = self.param_groups[0]['lr'] / decay_factor
            if self.verbose:
                print('Reducing learning rate to %.5f @ T=%s!' % (self.param_groups[0]['lr'], self.steps))

    def update_regularizer(self, decay_factor=None):
        """Optionally decay ``lr``, then refresh the proximal reference point.

        The reference becomes the average iterate since the last refresh.
        With no steps since then (``T == 0``) the reference is left unchanged
        instead of being divided by zero. The accumulators and ``T`` are reset.
        """
        if decay_factor is not None:
            self.param_groups[0]['lr'] = self.param_groups[0]['lr'] / decay_factor
            if self.verbose:
                print('Reducing learning rate to %.5f @ T=%s!' % (self.param_groups[0]['lr'], self.steps))
        if self.verbose:
            print('Updating regularizer @ T=%s!' % (self.steps))
        if self.T > 0:
            for ref, acc in zip(self.model_ref, self.model_acc):
                ref.data = acc.data / self.T
        for acc in self.model_acc:
            acc.data = torch.zeros(acc.shape, dtype=torch.float32, device=self.device, requires_grad=False)
        self.T = 0



class Adam(torch.optim.Optimizer):
    """Adam-style primal-dual optimizer for AUC-style min-max objectives.

    Primal variables (``model`` parameters, plus ``a`` and ``b`` when both are
    available) take a bias-corrected Adam descent step. The step follows the
    element-wise clipped gradient augmented with the proximal term
    ``epoch_decay * (p - p_ref)`` and L2 weight decay.

    ``alpha`` takes a bias-corrected Adam ascent step along the hard-coded
    gradient ``2 * (margin + b - a) - 2 * alpha``. It is then clamped to
    ``[0, 999]``.
    """

    def __init__(self,
                 model,
                 loss_fn=None,
                 a=None,  # to be deprecated
                 b=None,  # to be deprecated
                 alpha=None,  # to be deprecated
                 margin=1.0,
                 lr_x=0.1,
                 eps=1e-8,
                 lr_y=0.1,
                 betas=(0.9,0.999),
                 acc_lr_y=0,
                 exp_avg_y=0,
                 exp_avg_sq_y=0,
                 gamma=None,  # to be deprecated
                 clip_value=1.0,
                 weight_decay=1e-5,
                 epoch_decay=2e-3,  # default: gamma=500
                 momentum=0.9,
                 verbose=True,
                 device=None,
                 **kwargs):
        """Build the optimizer.

        Args:
            model: module whose ``parameters()`` are optimised.
            loss_fn: optional loss object. If it exposes ``a``, ``b`` and
                ``alpha`` those are used; otherwise a message is printed and
                the explicit arguments are used instead.
            a, b: auxiliary tensors, optimised with the model only when both
                are not None.
            alpha: dual tensor, updated separately by ascent.
            margin: ``m`` in the dual gradient.
            lr_x, lr_y: primal / dual step sizes.
            eps: added inside the square root of the second-moment estimate.
            betas: ``(beta1, beta2)`` moment decay rates.
            acc_lr_y: stored only; not read by this class's update.
            exp_avg_y, exp_avg_sq_y: initial dual first/second moments.
            gamma: deprecated; sets ``epoch_decay = 1 / gamma`` (requires
                ``epoch_decay=None``).
            clip_value: bound for element-wise gradient clipping.
            weight_decay: L2 coefficient.
            epoch_decay: coefficient of the proximal term.
            momentum: stored in the param groups; not read by :meth:`step`.
            verbose: print learning-rate / regularizer updates.
            device: device for internal buffers; CUDA when available by default.
            **kwargs: ignored.
        """
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        assert (gamma is None) or (epoch_decay is None), 'You can only use one of gamma and epoch_decay!'
        if gamma is not None:
            assert gamma > 0
            epoch_decay = 1 / gamma

        self.margin = margin
        self.model = model
        self.lr_x = lr_x
        self.lr_y = lr_y
        # Adam time step used for bias correction. It must NOT be named
        # ``step``: an instance attribute of that name shadows the step() method.
        self.step_inner = 1
        self.exp_avg_y = exp_avg_y
        self.exp_avg_sq_y = exp_avg_sq_y
        self.gamma = gamma  # to be deprecated
        self.clip_value = clip_value
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.epoch_decay = epoch_decay
        self.acc_lr_y = acc_lr_y
        self.loss_fn = loss_fn
        self.eps = eps
        if loss_fn is not None:
            try:
                self.a, self.b, self.alpha = loss_fn.a, loss_fn.b, loss_fn.alpha
            except AttributeError:
                print('AUCLoss is not found!')
                # Fall back to the explicit arguments instead of leaving them undefined.
                self.a, self.b, self.alpha = a, b, alpha
        else:
            self.a, self.b, self.alpha = a, b, alpha

        self.model_ref = self.init_model_ref()
        self.model_acc = self.init_model_acc()
        self.model_acc_lr_x = self.init_model_acc_lr_x()
        self.exp_avg = self.init_exp_avg()
        self.exp_avg_sq = self.init_exp_avg_sq()
        self.T = 0  # steps since the last regularizer refresh (for epoch_decay)
        self.steps = 0  # total optim steps
        self.verbose = verbose  # print updates for lr/regularizer

        params = list(model.parameters())
        if self.a is not None and self.b is not None:
            params += [self.a, self.b]
        self.params = params
        # Per-variable buffers are index-aligned with group['params'];
        # step() reads exp_avg / exp_avg_sq from the group, so they must be here.
        self.defaults = dict(lr_x=self.lr_x,
                             lr_y=self.lr_y,
                             margin=margin,
                             a=self.a,
                             b=self.b,
                             betas=betas,
                             alpha=self.alpha,
                             clip_value=clip_value,
                             momentum=momentum,
                             weight_decay=weight_decay,
                             epoch_decay=epoch_decay,
                             model_ref=self.model_ref,
                             model_acc=self.model_acc,
                             model_acc_lr_x=self.model_acc_lr_x,
                             exp_avg=self.exp_avg,
                             exp_avg_sq=self.exp_avg_sq
                             )

        super(Adam, self).__init__(self.params, self.defaults)

    def __setstate__(self, state):
        """Restore pickled state; back-fills an (unused) ``'nesterov'`` group key."""
        super(Adam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def _tracked_vars(self):
        """Model parameters followed by ``a`` and ``b``, skipping ``None`` entries."""
        return [var for var in (*self.model.parameters(), self.a, self.b) if var is not None]

    def _zeros_like_tracked(self):
        """One zero float32 tensor per tracked variable on ``self.device``."""
        return [torch.zeros(var.shape, dtype=torch.float32, device=self.device, requires_grad=False)
                for var in self._tracked_vars()]

    def init_model_ref(self):
        """Create the proximal reference tensors, drawn from N(0, 0.01^2) on CPU then moved."""
        self.model_ref = [torch.empty(var.shape).normal_(mean=0, std=0.01).to(self.device)
                          for var in self._tracked_vars()]
        return self.model_ref

    def init_model_acc(self):
        """Create zero accumulators used to average iterates between refreshes."""
        self.model_acc = self._zeros_like_tracked()
        return self.model_acc

    def init_model_acc_lr_x(self):
        """Create one scalar ``0`` per tracked variable (not read by step())."""
        self.model_acc_lr_x = [0 for _ in self._tracked_vars()]
        return self.model_acc_lr_x

    def init_exp_avg(self):
        """Create zero first-moment estimates."""
        self.exp_avg = self._zeros_like_tracked()
        return self.exp_avg

    def init_exp_avg_sq(self):
        """Create zero second-moment estimates."""
        self.exp_avg_sq = self._zeros_like_tracked()
        return self.exp_avg_sq

    @property
    def optim_steps(self):
        """Total number of calls to :meth:`step`."""
        return self.steps

    @property
    def get_params(self):
        """The wrapped model's parameters as a list."""
        return list(self.model.parameters())

    @torch.no_grad()
    def step(self, closure=None, delta_x=None, delta_y=None):
        """Perform a single optimization step.

        ``delta_x`` and ``delta_y`` are accepted for interface parity with the
        variance-reduced optimizer and are ignored.

        Returns:
            The closure's loss, or None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            clip_value = group['clip_value']
            self.lr_x = group['lr_x']
            self.lr_y = group['lr_y']
            epoch_decay = group['epoch_decay']
            model_ref = group['model_ref']
            model_acc = group['model_acc']
            exp_avg = group['exp_avg']
            exp_avg_sq = group['exp_avg_sq']
            m = group['margin']
            a = group['a']
            b = group['b']
            beta1, beta2 = group['betas']
            alpha = group['alpha']
            bias_correction1 = 1 - beta1 ** self.step_inner
            bias_correction2 = 1 - beta2 ** self.step_inner

            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                d_p = torch.clamp(p.grad.data, -clip_value, clip_value) + epoch_decay * (
                    p.data - model_ref[i].data) + weight_decay * p.data
                exp_avg[i] = exp_avg[i] * beta1 + (1 - beta1) * d_p
                exp_avg_sq[i] = exp_avg_sq[i] * beta2 + (1 - beta2) * d_p ** 2

                exp_avg_hat = exp_avg[i] / bias_correction1
                exp_avg_sq_hat = exp_avg_sq[i] / bias_correction2

                # NOTE: eps is added inside the square root (textbook Adam adds it outside).
                p.data.addcdiv_(exp_avg_hat, (exp_avg_sq_hat + self.eps).sqrt(), value=-group['lr_x'])
                model_acc[i].data = model_acc[i].data + p.data

            if alpha is not None and alpha.grad is not None:
                d_alpha = 2 * (m + b.data - a.data) - 2 * alpha.data
                self.exp_avg_y = self.exp_avg_y * beta1 + (1 - beta1) * d_alpha
                self.exp_avg_sq_y = self.exp_avg_sq_y * beta2 + (1 - beta2) * d_alpha ** 2

                exp_avg_y_hat = self.exp_avg_y / bias_correction1
                exp_avg_sq_y_hat = self.exp_avg_sq_y / bias_correction2

                # Ascent on alpha (positive value=).
                alpha.data.addcdiv_(exp_avg_y_hat, (exp_avg_sq_y_hat + self.eps).sqrt(), value=group['lr_y'])
                alpha.data = torch.clamp(alpha.data, 0, 999)

        self.step_inner += 1
        # Counters were never advanced before, so update_regularizer divided by zero.
        self.T += 1
        self.steps += 1
        return loss

    def zero_grad(self):
        """Clear model gradients and drop the gradients of ``a``, ``b`` and ``alpha``."""
        self.model.zero_grad()
        if self.a is not None and self.b is not None:
            self.a.grad = None
            self.b.grad = None
        if self.alpha is not None:
            self.alpha.grad = None

    def update_lr(self, decay_factor=None):
        """Divide the first group's ``lr_x`` and ``lr_y`` by ``decay_factor`` (no-op when None)."""
        if decay_factor is not None:
            self.param_groups[0]['lr_x'] = self.param_groups[0]['lr_x'] / decay_factor
            self.param_groups[0]['lr_y'] = self.param_groups[0]['lr_y'] / decay_factor
            if self.verbose:
                print('Reducing learning rate_x to %.5f @ T=%s!' % (self.param_groups[0]['lr_x'], self.steps))
                print('Reducing learning rate_y to %.5f @ T=%s!' % (self.param_groups[0]['lr_y'], self.steps))

    def update_regularizer(self, decay_factor=None):
        """Optionally decay the step sizes, then refresh the proximal reference point.

        The reference becomes the average iterate since the last refresh. When
        ``T == 0`` it is left unchanged rather than divided by zero. The
        accumulators and ``T`` are reset.
        """
        if decay_factor is not None:
            self.param_groups[0]['lr_x'] = self.param_groups[0]['lr_x'] / decay_factor
            self.param_groups[0]['lr_y'] = self.param_groups[0]['lr_y'] / decay_factor
            if self.verbose:
                print('Reducing learning rate_x to %.5f @ T=%s!' % (self.param_groups[0]['lr_x'], self.steps))
                print('Reducing learning rate_y to %.5f @ T=%s!' % (self.param_groups[0]['lr_y'], self.steps))
        if self.verbose:
            print('Updating regularizer @ T=%s!' % (self.steps))
        if self.T > 0:
            for ref, acc in zip(self.model_ref, self.model_acc):
                ref.data = acc.data / self.T
        for acc in self.model_acc:
            acc.data = torch.zeros(acc.shape, dtype=torch.float32, device=self.device, requires_grad=False)
        self.T = 0



import math

import torch


class TiAda(torch.optim.Optimizer):
    """Time-scale-adaptive primal-dual optimizer for AUC-style min-max objectives.

    Primal side, for each variable ``p_i``:

    * Form the direction ``d_i`` = clipped gradient + ``epoch_decay * (p - p_ref)``
      + L2 weight decay.
    * Accumulate ``v_i += ||d_i||^2``.
    * Step ``p_i -= lr_x * d_i / sqrt(max(v_i, v_y))``, where ``v_y`` is the
      dual accumulator as of before this step's dual update.

    Dual side:

    * Form ``d_alpha = 2 * (margin + b - a) - 2 * alpha``.
    * Accumulate ``v_y += ||d_alpha||^2``.
    * Ascend ``alpha += lr_y * d_alpha / sqrt(v_y)``, then clamp to ``[0, 999]``.
    """

    def __init__(self,
                 model,
                 loss_fn=None,
                 a=None,  # to be deprecated
                 b=None,  # to be deprecated
                 alpha=None,  # to be deprecated
                 margin=1.0,
                 lr_x=0.1,
                 lr_y=0.1,
                 acc_lr_y=0,
                 gamma=None,  # to be deprecated
                 clip_value=1.0,
                 weight_decay=1e-5,
                 epoch_decay=2e-3,  # default: gamma=500
                 momentum=0.9,
                 verbose=True,
                 device=None,
                 **kwargs):
        """Build the optimizer.

        Args:
            model: module whose ``parameters()`` are optimised.
            loss_fn: optional loss object. If it exposes ``a``, ``b`` and
                ``alpha`` those are used; otherwise a message is printed and
                the explicit arguments are used instead.
            a, b: auxiliary tensors, optimised with the model only when both
                are not None.
            alpha: dual tensor, updated separately by ascent.
            margin: ``m`` in the dual gradient.
            lr_x, lr_y: base primal / dual step sizes.
            acc_lr_y: initial value of the dual accumulator ``v_y``.
            gamma: deprecated; sets ``epoch_decay = 1 / gamma`` (requires
                ``epoch_decay=None``).
            clip_value: bound for element-wise gradient clipping.
            weight_decay: L2 coefficient.
            epoch_decay: coefficient of the proximal term.
            momentum: stored in the param groups; not read by :meth:`step`.
            verbose: print learning-rate / regularizer updates.
            device: device for internal buffers; CUDA when available by default.
            **kwargs: ignored.
        """
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        assert (gamma is None) or (epoch_decay is None), 'You can only use one of gamma and epoch_decay!'
        if gamma is not None:
            assert gamma > 0
            epoch_decay = 1 / gamma

        self.margin = margin
        self.model = model
        self.lr_x = lr_x
        self.lr_y = lr_y
        self.gamma = gamma  # to be deprecated
        self.clip_value = clip_value
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.epoch_decay = epoch_decay
        self.acc_lr_y = acc_lr_y
        self.loss_fn = loss_fn
        if loss_fn is not None:
            try:
                self.a, self.b, self.alpha = loss_fn.a, loss_fn.b, loss_fn.alpha
            except AttributeError:
                print('AUCLoss is not found!')
                # Fall back to the explicit arguments instead of leaving them undefined.
                self.a, self.b, self.alpha = a, b, alpha
        else:
            self.a, self.b, self.alpha = a, b, alpha

        self.model_ref = self.init_model_ref()
        self.model_acc = self.init_model_acc()
        self.model_acc_lr_x = self.init_model_acc_lr_x()

        self.T = 0  # steps since the last regularizer refresh (for epoch_decay)
        self.steps = 0  # total optim steps
        self.verbose = verbose  # print updates for lr/regularizer

        params = list(model.parameters())
        if self.a is not None and self.b is not None:
            params += [self.a, self.b]
        self.params = params
        # Per-variable buffers are index-aligned with group['params'].
        self.defaults = dict(lr_x=self.lr_x,
                             lr_y=self.lr_y,
                             margin=margin,
                             a=self.a,
                             b=self.b,
                             alpha=self.alpha,
                             clip_value=clip_value,
                             momentum=momentum,
                             weight_decay=weight_decay,
                             epoch_decay=epoch_decay,
                             model_ref=self.model_ref,
                             model_acc=self.model_acc,
                             model_acc_lr_x=self.model_acc_lr_x
                             )

        super(TiAda, self).__init__(self.params, self.defaults)

    def __setstate__(self, state):
        """Restore pickled state; back-fills an (unused) ``'nesterov'`` group key."""
        super(TiAda, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def _tracked_vars(self):
        """Model parameters followed by ``a`` and ``b``, skipping ``None`` entries."""
        return [var for var in (*self.model.parameters(), self.a, self.b) if var is not None]

    def init_model_ref(self):
        """Create the proximal reference tensors, drawn from N(0, 0.01^2) on CPU then moved."""
        self.model_ref = [torch.empty(var.shape).normal_(mean=0, std=0.01).to(self.device)
                          for var in self._tracked_vars()]
        return self.model_ref

    def init_model_acc(self):
        """Create zero float32 accumulators used to average iterates between refreshes."""
        self.model_acc = [torch.zeros(var.shape, dtype=torch.float32, device=self.device, requires_grad=False)
                          for var in self._tracked_vars()]
        return self.model_acc

    def init_model_acc_lr_x(self):
        """Create the per-variable squared-norm accumulators ``v_i`` (start at 0)."""
        self.model_acc_lr_x = [0 for _ in self._tracked_vars()]
        return self.model_acc_lr_x

    @property
    def optim_steps(self):
        """Total number of calls to :meth:`step`."""
        return self.steps

    @property
    def get_params(self):
        """The wrapped model's parameters as a list."""
        return list(self.model.parameters())

    @torch.no_grad()
    def step(self, closure=None, delta_x=None, delta_y=None):
        """Perform a single optimization step.

        ``delta_x`` and ``delta_y`` are accepted for interface parity and ignored.

        Returns:
            The closure's loss, or None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            clip_value = group['clip_value']
            self.lr_x = group['lr_x']
            self.lr_y = group['lr_y']
            epoch_decay = group['epoch_decay']
            model_ref = group['model_ref']
            model_acc = group['model_acc']
            model_acc_lr_x = group['model_acc_lr_x']
            m = group['margin']
            a = group['a']
            b = group['b']
            alpha = group['alpha']

            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                d_p = torch.clamp(p.grad.data, -clip_value, clip_value) + epoch_decay * (
                    p.data - model_ref[i].data) + weight_decay * p.data

                model_acc_lr_x[i] = model_acc_lr_x[i] + torch.norm(d_p) ** 2
                # Primal scale uses the larger of its own and the dual accumulator.
                scale_x = 1 / max(model_acc_lr_x[i], self.acc_lr_y) ** 0.5
                p.data = p.data - group['lr_x'] * scale_x * d_p
                model_acc[i].data = model_acc[i].data + p.data

            if alpha is not None and alpha.grad is not None:
                d_alpha = 2 * (m + b.data - a.data) - 2 * alpha.data
                self.acc_lr_y += torch.norm(d_alpha) ** 2
                scale_y = 1 / (self.acc_lr_y ** 0.5)
                alpha.data = alpha.data + group['lr_y'] * scale_y * d_alpha
                alpha.data = torch.clamp(alpha.data, 0, 999)

        self.T += 1
        self.steps += 1
        return loss

    def zero_grad(self):
        """Clear model gradients and drop the gradients of ``a``, ``b`` and ``alpha``."""
        self.model.zero_grad()
        if self.a is not None and self.b is not None:
            self.a.grad = None
            self.b.grad = None
        if self.alpha is not None:
            self.alpha.grad = None

    def update_lr(self, decay_factor=None):
        """Divide the first group's ``lr_x`` and ``lr_y`` by ``decay_factor`` (no-op when None)."""
        if decay_factor is not None:
            self.param_groups[0]['lr_x'] = self.param_groups[0]['lr_x'] / decay_factor
            self.param_groups[0]['lr_y'] = self.param_groups[0]['lr_y'] / decay_factor
            if self.verbose:
                print('Reducing learning rate_x to %.5f @ T=%s!' % (self.param_groups[0]['lr_x'], self.steps))
                print('Reducing learning rate_y to %.5f @ T=%s!' % (self.param_groups[0]['lr_y'], self.steps))

    def update_regularizer(self, decay_factor=None):
        """Optionally decay the step sizes, then refresh the proximal reference point.

        The reference becomes the average iterate since the last refresh. When
        ``T == 0`` it is left unchanged rather than divided by zero. The
        accumulators and ``T`` are reset.
        """
        if decay_factor is not None:
            self.param_groups[0]['lr_x'] = self.param_groups[0]['lr_x'] / decay_factor
            self.param_groups[0]['lr_y'] = self.param_groups[0]['lr_y'] / decay_factor
            if self.verbose:
                print('Reducing learning rate_x to %.5f @ T=%s!' % (self.param_groups[0]['lr_x'], self.steps))
                print('Reducing learning rate_y to %.5f @ T=%s!' % (self.param_groups[0]['lr_y'], self.steps))
        if self.verbose:
            print('Updating regularizer @ T=%s!' % (self.steps))
        if self.T > 0:
            for ref, acc in zip(self.model_ref, self.model_acc):
                ref.data = acc.data / self.T
        for acc in self.model_acc:
            acc.data = torch.zeros(acc.shape, dtype=torch.float32, device=self.device, requires_grad=False)
        self.T = 0




import torch


class TiAda_Adam(torch.optim.Optimizer):
    def __init__(self,
                 model,
                 loss_fn=None,
                 a=None,  # to be deprecated
                 b=None,  # to be deprecated
                 alpha=None,  # to be deprecated
                 margin=1.0,
                 lr_x=0.1,
                 eps=1e-8,
                 lr_y=0.1,
                 betas=(0.9,0.999),
                 acc_lr_y=0,
                 exp_avg_y=0,
                 exp_avg_sq_y=0,
                 gamma=None,  # to be deprecated
                 clip_value=1.0,
                 weight_decay=1e-5,
                 epoch_decay=2e-3,  # default: gamma=500
                 momentum=0.9,
                 verbose=True,
                 device=None,
                 **kwargs):

        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        assert (gamma is None) or (epoch_decay is None), 'You can only use one of gamma and epoch_decay!'
        if gamma is not None:
            assert gamma > 0
            epoch_decay = 1 / gamma

        self.margin = margin
        self.model = model
        self.lr_x = lr_x
        self.lr_y = lr_y
        self.steps = 0
        self.step_inner = 1
        self.exp_avg_y = exp_avg_y
        self.exp_avg_sq_y = exp_avg_sq_y
        self.gamma = gamma  # to be deprecated
        self.clip_value = clip_value
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.epoch_decay = epoch_decay
        self.acc_lr_y=acc_lr_y
        self.loss_fn = loss_fn
        self.eps = eps
        if loss_fn != None:
            try:
                self.a = loss_fn.a
                self.b = loss_fn.b
                self.alpha = loss_fn.alpha
            except:
                print('AUCLoss is not found!')
        else:
            self.a = a
            self.b = b
            self.alpha = alpha

        self.model_ref = self.init_model_ref()
        self.model_acc = self.init_model_acc()
        self.model_acc_lr_x = self.init_model_acc_lr_x()
        self.exp_avg = self.init_exp_avg()
        self.exp_avg_sq = self.init_exp_avg_sq()
        self.T = 0  # for epoch_decay
        self.verbose = verbose  # print updates for lr/regularizer

        def get_parameters(params):
            for p in params:
                yield p

        if self.a is not None and self.b is not None:
            self.params = get_parameters(list(model.parameters()) + [self.a, self.b])
        else:
            self.params = get_parameters(list(model.parameters()))
        self.defaults = dict(lr_x=self.lr_x,
                             lr_y=self.lr_y,
                             margin=margin,
                             a=self.a,
                             b=self.b,
                             betas=betas,
                             alpha=self.alpha,
                             clip_value=clip_value,
                             momentum=momentum,
                             weight_decay=weight_decay,
                             epoch_decay=epoch_decay,
                             model_ref=self.model_ref,
                             model_acc=self.model_acc,
                             model_acc_lr_x=self.model_acc_lr_x,
                             exp_avg=self.exp_avg,
                             exp_avg_sq=self.exp_avg_sq
                             )

        super(TiAda_Adam, self).__init__(self.params, self.defaults)

    def __setstate__(self, state):
        """Restore pickled optimizer state.

        Delegates to ``torch.optim.Optimizer.__setstate__``. Any param group
        lacking a ``'nesterov'`` entry is then given ``False``. Existing values
        are kept untouched.
        """
        super().__setstate__(state)
        for param_group in self.param_groups:
            if 'nesterov' not in param_group:
                param_group['nesterov'] = False

    def init_model_ref(self):
        self.model_ref = []
        for var in list(self.model.parameters()) + [self.a, self.b]:
            if var is not None:
                self.model_ref.append(torch.empty(var.shape).normal_(mean=0, std=0.01).to(self.device))
        return self.model_ref

    def init_model_acc(self):
        self.model_acc = []
        for var in list(self.model.parameters()) + [self.a, self.b]:
            if var is not None:
                self.model_acc.append(
                    torch.zeros(var.shape, dtype=torch.float32, device=self.device, requires_grad=False).to(
                        self.device))
        return self.model_acc

    def init_model_acc_lr_x(self):
        self.model_acc_lr_x = []
        for var in list(self.model.parameters()) + [self.a, self.b]:
            if var is not None:
                self.model_acc_lr_x.append(0)
        return self.model_acc_lr_x


    def init_exp_avg(self):
        self.exp_avg = []
        for var in list(self.model.parameters()) + [self.a, self.b]:
            if var is not None:
                self.exp_avg.append(
                    torch.zeros(var.shape, dtype=torch.float32, device=self.device, requires_grad=False).to(
                        self.device))
        return self.exp_avg

    def init_exp_avg_sq(self):
        self.exp_avg_sq = []
        for var in list(self.model.parameters()) + [self.a, self.b]:
            if var is not None:
                self.exp_avg_sq.append(
                    torch.zeros(var.shape, dtype=torch.float32, device=self.device, requires_grad=False).to(
                        self.device))
        return self.exp_avg_sq


    @property
    def optim_steps(self):
        """Current value of the ``steps`` counter (initialised to 0 in ``__init__``)."""
        total = self.steps
        return total

    @property
    def get_params(self):
        """A fresh list of the wrapped model's parameters (excludes ``a``, ``b``, ``alpha``)."""
        return [*self.model.parameters()]

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        For every parameter ``p`` with a gradient, the descent direction is
        ``clamp(grad, -clip_value, clip_value) + epoch_decay * (p - model_ref[i])
        + weight_decay * p``. That direction updates the first/second moment
        buffers ``exp_avg[i]`` and ``exp_avg_sq[i]``. ``model_acc_lr_x[i]``
        accumulates ``exp_avg_sq[i].sum()``, and ``p`` moves by
        ``(lr_x / bc1) * exp_avg[i] / (sqrt(max(model_acc_lr_x[i], acc_lr_y)) / sqrt(bc2) + eps)``.
        The new ``p`` is then added to ``model_acc[i]``.

        When ``alpha`` has a gradient, it takes an ascent step on the closed-form
        gradient ``2 * (margin + b - a) - 2 * alpha``. It uses its own moments
        and the accumulator ``acc_lr_y``, and is clamped to ``[0, 999]``.

        Side effects: ``self.lr_x``/``self.lr_y`` are overwritten from the
        param group, and ``T``, ``step_inner`` and ``steps`` are each advanced by 1.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            clip_value = group['clip_value']
            self.lr_x = group['lr_x']
            self.lr_y = group['lr_y']
            epoch_decay = group['epoch_decay']
            model_ref = group['model_ref']
            model_acc = group['model_acc']
            model_acc_lr_x = group['model_acc_lr_x']
            exp_avg = group['exp_avg']
            exp_avg_sq = group['exp_avg_sq']
            m = group['margin']
            a = group['a']
            b = group['b']
            beta1, beta2 = group['betas']
            alpha = group['alpha']
            bias_correction1 = 1 - beta1 ** self.step_inner
            bias_correction2 = 1 - beta2 ** self.step_inner
            bias_correction2_sqrt = math.sqrt(bias_correction2)
            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                d_p = torch.clamp(p.grad.data, -clip_value, clip_value) + epoch_decay * (
                        p.data - model_ref[i].data) + weight_decay * p.data
                exp_avg[i] = exp_avg[i] * beta1 + (1 - beta1) * d_p
                exp_avg_sq[i] = exp_avg_sq[i] * beta2 + (1 - beta2) * d_p ** 2

                step_size = group['lr_x'] / bias_correction1

                model_acc_lr_x[i] = model_acc_lr_x[i] + exp_avg_sq[i].sum()
                # The primal scale uses the larger of this variable's accumulator and the dual one.
                lr_x = 1 / (max(model_acc_lr_x[i], self.acc_lr_y) ** 0.5 / bias_correction2_sqrt + self.eps)

                p.data = p.data - step_size * lr_x * exp_avg[i]  # We may also consider using corrected values here
                model_acc[i].data = model_acc[i].data + p.data

            if alpha is not None:
                if alpha.grad is not None:
                    d_alpha = 2 * (m + b.data - a.data) - 2 * alpha.data
                    self.exp_avg_y = self.exp_avg_y * beta1 + (1 - beta1) * d_alpha
                    self.exp_avg_sq_y = self.exp_avg_sq_y * beta2 + (1 - beta2) * d_alpha ** 2

                    step_size = group['lr_y'] / bias_correction1

                    self.acc_lr_y += self.exp_avg_sq_y.sum()
                    lr_y = 1 / (self.acc_lr_y ** 0.5 / bias_correction2_sqrt + self.eps)

                    # Ascent step (note the '+'), then keep alpha in [0, 999].
                    alpha.data = alpha.data + step_size * lr_y * self.exp_avg_y  # We may also consider using corrected values here
                    alpha.data = torch.clamp(alpha.data, 0, 999)
        self.T += 1
        self.step_inner += 1
        # ``steps`` backs ``optim_steps`` and the lr/regularizer log messages but was
        # never advanced. getattr guards against an instance created without it.
        self.steps = getattr(self, 'steps', 0) + 1
        return loss

    def zero_grad(self):
        """Clear gradients of the model and of the auxiliary AUC variables.

        The model is cleared via ``model.zero_grad()``. ``a`` and ``b`` get
        their ``.grad`` set to None only when both are present; ``alpha`` is
        cleared whenever it is present.
        """
        self.model.zero_grad()
        both_present = self.a is not None and self.b is not None
        if both_present:
            for aux in (self.a, self.b):
                aux.grad = None
        if self.alpha is not None:
            self.alpha.grad = None

    def update_lr(self, decay_factor=None):
        """Divide ``lr_x`` and ``lr_y`` of the first param group by ``decay_factor``.

        Does nothing when ``decay_factor`` is None. The check is ``is None``
        rather than ``!= None``, so a tensor-valued factor does not produce an
        elementwise comparison.

        Args:
            decay_factor: divisor applied to both learning rates, or None.
        """
        if decay_factor is None:
            return
        group = self.param_groups[0]
        group['lr_x'] = group['lr_x'] / decay_factor
        group['lr_y'] = group['lr_y'] / decay_factor
        if self.verbose:
            print('Reducing learning rate_x to %.5f @ T=%s!' % (group['lr_x'], self.steps))
            print('Reducing learning rate_y to %.5f @ T=%s!' % (group['lr_y'], self.steps))

    def update_regularizer(self, decay_factor=None):
        """Optionally decay the learning rates, then refresh the proximal reference.

        When ``decay_factor`` is given, ``lr_x`` and ``lr_y`` of the first param
        group are divided by it. Then ``model_ref`` is set to the average of the
        iterates accumulated in ``model_acc`` over the last ``T`` steps. The
        accumulators are reset and ``T`` returns to 0.

        If no step has run since the previous refresh (``T == 0``), the
        accumulators are all zero. Dividing would then fill ``model_ref`` with
        NaN, so the reference is left unchanged in that case.

        Args:
            decay_factor: divisor for both learning rates, or None to keep them.
        """
        if decay_factor is not None:
            group = self.param_groups[0]
            group['lr_x'] = group['lr_x'] / decay_factor
            group['lr_y'] = group['lr_y'] / decay_factor
            if self.verbose:
                print('Reducing learning rate_x to %.5f @ T=%s!' % (group['lr_x'], self.steps))
                print('Reducing learning rate_y to %.5f @ T=%s!' % (group['lr_y'], self.steps))
        if self.verbose:
            print('Updating regularizer @ T=%s!' % (self.steps))
        if self.T > 0:
            for ref, acc in zip(self.model_ref, self.model_acc):
                ref.data = acc.data / self.T
        for acc in self.model_acc:
            acc.data = torch.zeros(acc.shape, dtype=torch.float32, device=self.device, requires_grad=False)
        self.T = 0



class Adam(torch.optim.Optimizer):
    """Adam-style primal-dual optimizer for AUC min-max losses.

    Primal variables are the model parameters, plus ``a`` and ``b`` when both
    are set. They take a bias-corrected Adam *descent* step along a direction
    made of three parts:

    * the gradient clamped to ``[-clip_value, clip_value]``;
    * a proximal term ``epoch_decay * (p - model_ref)``;
    * an L2 term ``weight_decay * p``.

    The dual variable ``alpha`` takes a bias-corrected Adam *ascent* step on
    ``2 * (margin + b - a) - 2 * alpha`` and is then clamped to ``[0, 999]``.
    :meth:`update_regularizer` refreshes ``model_ref`` to the average of the
    iterates accumulated in ``model_acc`` since its previous call.

    Args:
        model: module whose parameters are optimized.
        loss_fn: optional loss object exposing ``a``, ``b`` and ``alpha``. When
            it provides all three, they override the ``a``/``b``/``alpha``
            arguments.
        a, b, alpha: (to be deprecated) auxiliary tensors, used when
            ``loss_fn`` is None or does not provide them.
        margin: margin used in the closed-form ``alpha`` gradient.
        lr_x, lr_y: primal and dual learning rates.
        eps: added to the bias-corrected second moment *inside* the square root.
        betas: ``(beta1, beta2)`` moment coefficients.
        acc_lr_y: stored on the instance; not used by this class's update.
        exp_avg_y, exp_avg_sq_y: initial dual first/second moments.
        gamma: (to be deprecated) sets ``epoch_decay = 1 / gamma``.
        clip_value: gradient clamp bound.
        weight_decay: L2 coefficient.
        epoch_decay: proximal coefficient.
        momentum: stored in the param group; not read by :meth:`step`.
        verbose: print learning-rate and regularizer updates.
        device: device for internal buffers. If omitted, CUDA is used when
            available, otherwise CPU.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 model,
                 loss_fn=None,
                 a=None,  # to be deprecated
                 b=None,  # to be deprecated
                 alpha=None,  # to be deprecated
                 margin=1.0,
                 lr_x=0.1,
                 eps=1e-8,
                 lr_y=0.1,
                 betas=(0.9, 0.999),
                 acc_lr_y=0,
                 exp_avg_y=0,
                 exp_avg_sq_y=0,
                 gamma=None,  # to be deprecated
                 clip_value=1.0,
                 weight_decay=1e-5,
                 epoch_decay=2e-3,  # default: gamma=500
                 momentum=0.9,
                 verbose=True,
                 device=None,
                 **kwargs):

        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        # NOTE(review): epoch_decay defaults to 2e-3, so passing gamma also
        # requires passing epoch_decay=None or this assertion fires.
        assert (gamma is None) or (epoch_decay is None), 'You can only use one of gamma and epoch_decay!'
        if gamma is not None:
            assert gamma > 0
            epoch_decay = 1 / gamma

        self.margin = margin
        self.model = model
        self.lr_x = lr_x
        self.lr_y = lr_y
        self.steps = 0  # total optimizer steps taken (advanced by step())
        self.betas = betas
        self.step_inner = 1  # 1-based counter driving the bias corrections
        self.exp_avg_y = exp_avg_y
        self.exp_avg_sq_y = exp_avg_sq_y
        self.gamma = gamma  # to be deprecated
        self.clip_value = clip_value
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.epoch_decay = epoch_decay
        self.acc_lr_y = acc_lr_y
        self.loss_fn = loss_fn
        self.eps = eps

        if loss_fn is not None:
            try:
                a, b, alpha = loss_fn.a, loss_fn.b, loss_fn.alpha
            except AttributeError:
                # Fall back to the explicit arguments rather than leaving the
                # attributes unset (which used to crash a few lines later).
                print('AUCLoss is not found!')
        self.a, self.b, self.alpha = a, b, alpha

        self.model_ref = self.init_model_ref()
        self.model_acc = self.init_model_acc()
        self.model_acc_lr_x = self.init_model_acc_lr_x()
        self.exp_avg = self.init_exp_avg()
        self.exp_avg_sq = self.init_exp_avg_sq()
        self.T = 0  # steps since the last regularizer refresh (for epoch_decay)
        self.verbose = verbose  # print updates for lr/regularizer

        params = list(model.parameters())
        if self.a is not None and self.b is not None:
            params += [self.a, self.b]
        self.params = params
        self.defaults = dict(lr_x=self.lr_x,
                             lr_y=self.lr_y,
                             margin=margin,
                             a=self.a,
                             b=self.b,
                             betas=self.betas,
                             alpha=self.alpha,
                             clip_value=clip_value,
                             momentum=momentum,
                             weight_decay=weight_decay,
                             epoch_decay=epoch_decay,
                             model_ref=self.model_ref,
                             model_acc=self.model_acc,
                             model_acc_lr_x=self.model_acc_lr_x,
                             exp_avg=self.exp_avg,
                             exp_avg_sq=self.exp_avg_sq
                             )

        super().__init__(self.params, self.defaults)

    def __setstate__(self, state):
        """Restore pickled state and backfill the ``nesterov`` group key."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def _tracked_vars(self):
        """Return the model parameters followed by ``a`` and ``b``, skipping None."""
        return [v for v in list(self.model.parameters()) + [self.a, self.b] if v is not None]

    def _zero_buffers(self):
        """Return one float32 zero tensor on ``self.device`` per tracked variable."""
        return [torch.zeros(v.shape, dtype=torch.float32, device=self.device, requires_grad=False)
                for v in self._tracked_vars()]

    def init_model_ref(self):
        """Initialise ``model_ref`` with N(0, 0.01^2) noise (sampled on CPU, then moved)."""
        self.model_ref = [torch.empty(v.shape).normal_(mean=0, std=0.01).to(self.device)
                          for v in self._tracked_vars()]
        return self.model_ref

    def init_model_acc(self):
        """Initialise the zero iterate accumulators ``model_acc``."""
        self.model_acc = self._zero_buffers()
        return self.model_acc

    def init_model_acc_lr_x(self):
        """Initialise ``model_acc_lr_x`` with one integer 0 per tracked variable."""
        self.model_acc_lr_x = [0] * len(self._tracked_vars())
        return self.model_acc_lr_x

    def init_exp_avg(self):
        """Initialise the zero first-moment buffers ``exp_avg``."""
        self.exp_avg = self._zero_buffers()
        return self.exp_avg

    def init_exp_avg_sq(self):
        """Initialise the zero second-moment buffers ``exp_avg_sq``."""
        self.exp_avg_sq = self._zero_buffers()
        return self.exp_avg_sq

    @property
    def optim_steps(self):
        """Number of calls to :meth:`step` so far."""
        return self.steps

    @property
    def get_params(self):
        """List of the wrapped model's parameters (excludes ``a``, ``b``, ``alpha``)."""
        return list(self.model.parameters())

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            clip_value = group['clip_value']
            lr_x = group['lr_x']
            lr_y = group['lr_y']
            epoch_decay = group['epoch_decay']
            model_ref = group['model_ref']
            model_acc = group['model_acc']
            exp_avg = group['exp_avg']
            exp_avg_sq = group['exp_avg_sq']
            beta1, beta2 = group['betas']
            bias_correction1 = 1 - beta1 ** self.step_inner
            bias_correction2 = 1 - beta2 ** self.step_inner

            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                # Clipped gradient + proximal pull towards model_ref + L2 decay.
                d_p = (torch.clamp(p.grad.data, -clip_value, clip_value)
                       + epoch_decay * (p.data - model_ref[i].data)
                       + weight_decay * p.data)
                exp_avg[i] = exp_avg[i] * beta1 + (1 - beta1) * d_p
                exp_avg_sq[i] = exp_avg_sq[i] * beta2 + (1 - beta2) * d_p ** 2

                exp_avg_hat = exp_avg[i] / bias_correction1
                exp_avg_sq_hat = exp_avg_sq[i] / bias_correction2
                # eps sits inside the sqrt here (torch.optim.Adam adds it after).
                p.data.addcdiv_(exp_avg_hat, (exp_avg_sq_hat + self.eps).sqrt(), value=-lr_x)
                model_acc[i].data += p.data

            self._update_alpha(group['alpha'], group['margin'], group['a'], group['b'],
                               beta1, beta2, bias_correction1, bias_correction2, lr_y)
        self.T += 1
        self.step_inner += 1
        self.steps += 1
        return loss

    def _update_alpha(self, alpha, m, a, b, beta1, beta2, bias_correction1, bias_correction2, lr_y):
        """Adam ascent step on ``alpha`` using its closed-form gradient, clamped to [0, 999]."""
        if alpha is not None and alpha.grad is not None:
            d_alpha = 2 * (m + b.data - a.data) - 2 * alpha.data
            self.exp_avg_y = self.exp_avg_y * beta1 + (1 - beta1) * d_alpha
            self.exp_avg_sq_y = self.exp_avg_sq_y * beta2 + (1 - beta2) * d_alpha ** 2

            exp_avg_y_hat = self.exp_avg_y / bias_correction1
            exp_avg_sq_y_hat = self.exp_avg_sq_y / bias_correction2

            alpha.data.addcdiv_(exp_avg_y_hat, (exp_avg_sq_y_hat + self.eps).sqrt(), value=lr_y)
            alpha.data = torch.clamp(alpha.data, 0, 999)

    def zero_grad(self):
        """Clear gradients of the model, of ``a``/``b`` (when both set) and of ``alpha``."""
        self.model.zero_grad()
        if self.a is not None and self.b is not None:
            self.a.grad = None
            self.b.grad = None
        if self.alpha is not None:
            self.alpha.grad = None

    def update_lr(self, decay_factor=None):
        """Divide ``lr_x`` and ``lr_y`` of the first param group by ``decay_factor`` (no-op if None)."""
        if decay_factor is None:
            return
        group = self.param_groups[0]
        group['lr_x'] = group['lr_x'] / decay_factor
        group['lr_y'] = group['lr_y'] / decay_factor
        if self.verbose:
            print('Reducing learning rate_x to %.5f @ T=%s!' % (group['lr_x'], self.steps))
            print('Reducing learning rate_y to %.5f @ T=%s!' % (group['lr_y'], self.steps))

    def update_regularizer(self, decay_factor=None):
        """Optionally decay the learning rates, then refresh ``model_ref``.

        ``model_ref`` becomes the average of the iterates accumulated since the
        last refresh. The accumulators are then zeroed and ``T`` is reset. When
        ``T == 0`` the average is undefined (it would be 0/0 = NaN), so
        ``model_ref`` is kept as is.

        Args:
            decay_factor: divisor for both learning rates, or None to keep them.
        """
        self.update_lr(decay_factor)
        if self.verbose:
            print('Updating regularizer @ T=%s!' % (self.steps))
        if self.T > 0:
            for ref, acc in zip(self.model_ref, self.model_acc):
                ref.data = acc.data / self.T
        for acc in self.model_acc:
            acc.data = torch.zeros(acc.shape, dtype=torch.float32, device=self.device, requires_grad=False)
        self.T = 0



class TiAda_Adam(torch.optim.Optimizer):
    """TiAda-style primal-dual optimizer with Adam moments for AUC min-max losses.

    Primal variables are the model parameters, plus ``a`` and ``b`` when both
    are set. Each descends along its first moment ``exp_avg[i]``, built from
    a direction made of three parts:

    * the gradient clamped to ``[-clip_value, clip_value]``;
    * a proximal term ``epoch_decay * (p - model_ref[i])``;
    * an L2 term ``weight_decay * p``.

    Every step adds ``exp_avg_sq[i].sum()`` to the running sum
    ``model_acc_lr_x[i]``. The primal step is scaled by
    ``1 / (sqrt(max(model_acc_lr_x[i], acc_lr_y)) / sqrt(bc2) + eps)``,
    where ``acc_lr_y`` is the running sum for the dual variable.

    The dual variable ``alpha`` ascends along its own first moment of the
    closed-form gradient ``2 * (margin + b - a) - 2 * alpha``. It is scaled by
    ``1 / (sqrt(acc_lr_y) / sqrt(bc2) + eps)`` and clamped to ``[0, 999]``.
    :meth:`update_regularizer` refreshes ``model_ref`` to the average of the
    iterates accumulated since its previous call.

    Args:
        model: module whose parameters are optimized.
        loss_fn: optional loss object exposing ``a``, ``b`` and ``alpha``. When
            it provides all three, they override the ``a``/``b``/``alpha``
            arguments.
        a, b, alpha: (to be deprecated) auxiliary tensors, used when
            ``loss_fn`` is None or does not provide them.
        margin: margin used in the closed-form ``alpha`` gradient.
        lr_x, lr_y: primal and dual base learning rates.
        eps: added to the scale denominators.
        betas: ``(beta1, beta2)`` moment coefficients.
        acc_lr_y, exp_avg_y, exp_avg_sq_y: initial dual accumulator and moments.
        gamma: (to be deprecated) sets ``epoch_decay = 1 / gamma``.
        clip_value: gradient clamp bound.
        weight_decay: L2 coefficient.
        epoch_decay: proximal coefficient.
        momentum: stored in the param group; not read by :meth:`step`.
        verbose: print learning-rate and regularizer updates.
        device: device for internal buffers. If omitted, CUDA is used when
            available, otherwise CPU.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 model,
                 loss_fn=None,
                 a=None,  # to be deprecated
                 b=None,  # to be deprecated
                 alpha=None,  # to be deprecated
                 margin=1.0,
                 lr_x=0.1,
                 eps=1e-8,
                 lr_y=0.1,
                 betas=(0.9, 0.999),
                 acc_lr_y=0,
                 exp_avg_y=0,
                 exp_avg_sq_y=0,
                 gamma=None,  # to be deprecated
                 clip_value=1.0,
                 weight_decay=1e-5,
                 epoch_decay=2e-3,  # default: gamma=500
                 momentum=0.9,
                 verbose=True,
                 device=None,
                 **kwargs):

        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        # NOTE(review): epoch_decay defaults to 2e-3, so passing gamma also
        # requires passing epoch_decay=None or this assertion fires.
        assert (gamma is None) or (epoch_decay is None), 'You can only use one of gamma and epoch_decay!'
        if gamma is not None:
            assert gamma > 0
            epoch_decay = 1 / gamma

        self.margin = margin
        self.model = model
        self.lr_x = lr_x
        self.lr_y = lr_y
        # Bias-correction counter. It was previously stored as ``self.step``,
        # which shadowed the step() method and made ``opt.step()`` uncallable.
        self.step_inner = 1
        self.betas = betas
        self.exp_avg_y = exp_avg_y
        self.exp_avg_sq_y = exp_avg_sq_y
        self.gamma = gamma  # to be deprecated
        self.clip_value = clip_value
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.epoch_decay = epoch_decay
        self.acc_lr_y = acc_lr_y
        self.loss_fn = loss_fn
        self.eps = eps

        if loss_fn is not None:
            try:
                a, b, alpha = loss_fn.a, loss_fn.b, loss_fn.alpha
            except AttributeError:
                # Fall back to the explicit arguments rather than leaving the
                # attributes unset (which used to crash a few lines later).
                print('AUCLoss is not found!')
        self.a, self.b, self.alpha = a, b, alpha

        self.model_ref = self.init_model_ref()
        self.model_acc = self.init_model_acc()
        self.model_acc_lr_x = self.init_model_acc_lr_x()
        self.exp_avg = self.init_exp_avg()
        self.exp_avg_sq = self.init_exp_avg_sq()
        self.T = 0  # steps since the last regularizer refresh (for epoch_decay)
        self.steps = 0  # total optim steps
        self.verbose = verbose  # print updates for lr/regularizer

        params = list(model.parameters())
        if self.a is not None and self.b is not None:
            params += [self.a, self.b]
        self.params = params
        self.defaults = dict(lr_x=self.lr_x,
                             lr_y=self.lr_y,
                             margin=margin,
                             a=self.a,
                             b=self.b,
                             betas=betas,
                             alpha=self.alpha,
                             clip_value=clip_value,
                             momentum=momentum,
                             weight_decay=weight_decay,
                             epoch_decay=epoch_decay,
                             model_ref=self.model_ref,
                             model_acc=self.model_acc,
                             model_acc_lr_x=self.model_acc_lr_x,
                             # step() reads these two; they were missing (KeyError).
                             exp_avg=self.exp_avg,
                             exp_avg_sq=self.exp_avg_sq
                             )

        super().__init__(self.params, self.defaults)

    def __setstate__(self, state):
        """Restore pickled state and backfill the ``nesterov`` group key."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def _tracked_vars(self):
        """Return the model parameters followed by ``a`` and ``b``, skipping None."""
        return [v for v in list(self.model.parameters()) + [self.a, self.b] if v is not None]

    def _zero_buffers(self):
        """Return one float32 zero tensor on ``self.device`` per tracked variable."""
        return [torch.zeros(v.shape, dtype=torch.float32, device=self.device, requires_grad=False)
                for v in self._tracked_vars()]

    def init_model_ref(self):
        """Initialise ``model_ref`` with N(0, 0.01^2) noise (sampled on CPU, then moved)."""
        self.model_ref = [torch.empty(v.shape).normal_(mean=0, std=0.01).to(self.device)
                          for v in self._tracked_vars()]
        return self.model_ref

    def init_model_acc(self):
        """Initialise the zero iterate accumulators ``model_acc``."""
        self.model_acc = self._zero_buffers()
        return self.model_acc

    def init_model_acc_lr_x(self):
        """Initialise the per-variable primal accumulators with integer zeros."""
        self.model_acc_lr_x = [0] * len(self._tracked_vars())
        return self.model_acc_lr_x

    def init_exp_avg(self):
        """Initialise the zero first-moment buffers ``exp_avg``."""
        self.exp_avg = self._zero_buffers()
        return self.exp_avg

    def init_exp_avg_sq(self):
        """Initialise the zero second-moment buffers ``exp_avg_sq``."""
        self.exp_avg_sq = self._zero_buffers()
        return self.exp_avg_sq

    @property
    def optim_steps(self):
        """Number of calls to :meth:`step` so far."""
        return self.steps

    @property
    def get_params(self):
        """List of the wrapped model's parameters (excludes ``a``, ``b``, ``alpha``)."""
        return list(self.model.parameters())

    @torch.no_grad()
    def step(self, closure=None, delta_x=None, delta_y=None):
        """Perform a single optimization step.

        Side effects: ``self.lr_x``/``self.lr_y`` are overwritten from the
        param group, and ``T``, ``step_inner`` and ``steps`` advance by 1.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.
            delta_x, delta_y: accepted for interface compatibility; unused.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            clip_value = group['clip_value']
            self.lr_x = group['lr_x']
            self.lr_y = group['lr_y']
            epoch_decay = group['epoch_decay']
            model_ref = group['model_ref']
            model_acc = group['model_acc']
            model_acc_lr_x = group['model_acc_lr_x']
            exp_avg = group['exp_avg']
            exp_avg_sq = group['exp_avg_sq']
            beta1, beta2 = group['betas']
            bias_correction1 = 1 - beta1 ** self.step_inner
            bias_correction2_sqrt = math.sqrt(1 - beta2 ** self.step_inner)

            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                # Clipped gradient + proximal pull towards model_ref + L2 decay.
                d_p = (torch.clamp(p.grad.data, -clip_value, clip_value)
                       + epoch_decay * (p.data - model_ref[i].data)
                       + weight_decay * p.data)
                exp_avg[i] = exp_avg[i] * beta1 + (1 - beta1) * d_p
                exp_avg_sq[i] = exp_avg_sq[i] * beta2 + (1 - beta2) * d_p ** 2

                step_size = group['lr_x'] / bias_correction1
                model_acc_lr_x[i] = model_acc_lr_x[i] + exp_avg_sq[i].sum()
                # The primal scale uses the larger of its own accumulator and the dual one.
                scale = 1 / (max(model_acc_lr_x[i], self.acc_lr_y) ** 0.5 / bias_correction2_sqrt + self.eps)
                p.data = p.data - step_size * scale * exp_avg[i]
                model_acc[i].data = model_acc[i].data + p.data

            self._update_alpha(group, beta1, beta2, bias_correction1, bias_correction2_sqrt)

        # T was previously never advanced, so update_regularizer divided by 0.
        self.T += 1
        self.step_inner += 1
        self.steps += 1
        return loss

    def _update_alpha(self, group, beta1, beta2, bias_correction1, bias_correction2_sqrt):
        """Ascent step on ``alpha`` from its closed-form gradient, clamped to [0, 999]."""
        alpha = group['alpha']
        if alpha is None or alpha.grad is None:
            return
        a, b, m = group['a'], group['b'], group['margin']
        d_alpha = 2 * (m + b.data - a.data) - 2 * alpha.data
        self.exp_avg_y = self.exp_avg_y * beta1 + (1 - beta1) * d_alpha
        self.exp_avg_sq_y = self.exp_avg_sq_y * beta2 + (1 - beta2) * d_alpha ** 2
        step_size = group['lr_y'] / bias_correction1
        self.acc_lr_y += self.exp_avg_sq_y.sum()
        scale = 1 / (self.acc_lr_y ** 0.5 / bias_correction2_sqrt + self.eps)
        alpha.data = alpha.data + step_size * scale * self.exp_avg_y
        alpha.data = torch.clamp(alpha.data, 0, 999)

    def zero_grad(self):
        """Clear gradients of the model, of ``a``/``b`` (when both set) and of ``alpha``."""
        self.model.zero_grad()
        if self.a is not None and self.b is not None:
            self.a.grad = None
            self.b.grad = None
        if self.alpha is not None:
            self.alpha.grad = None

    def update_lr(self, decay_factor=None):
        """Divide ``lr_x`` and ``lr_y`` of the first param group by ``decay_factor`` (no-op if None)."""
        if decay_factor is None:
            return
        group = self.param_groups[0]
        group['lr_x'] = group['lr_x'] / decay_factor
        group['lr_y'] = group['lr_y'] / decay_factor
        if self.verbose:
            print('Reducing learning rate_x to %.5f @ T=%s!' % (group['lr_x'], self.steps))
            print('Reducing learning rate_y to %.5f @ T=%s!' % (group['lr_y'], self.steps))

    def update_regularizer(self, decay_factor=None):
        """Optionally decay the learning rates, then refresh ``model_ref``.

        ``model_ref`` becomes the average of the iterates accumulated since the
        last refresh. The accumulators are then zeroed and ``T`` is reset. When
        ``T == 0`` the average is undefined (it would be 0/0 = NaN), so
        ``model_ref`` is kept as is.

        Args:
            decay_factor: divisor for both learning rates, or None to keep them.
        """
        self.update_lr(decay_factor)
        if self.verbose:
            print('Updating regularizer @ T=%s!' % (self.steps))
        if self.T > 0:
            for ref, acc in zip(self.model_ref, self.model_acc):
                ref.data = acc.data / self.T
        for acc in self.model_acc:
            acc.data = torch.zeros(acc.shape, dtype=torch.float32, device=self.device, requires_grad=False)
        self.T = 0



