import torch
import math as _math
import numpy as np
import typing
import logging
from ..utils.Utils import exponential_average, weighted_avg_and_std


# Configures the root logger at import time: records of level WARNING and
# above are written to 'logging.txt' (relative to the current working
# directory). The logging.debug(...) calls in this module are therefore
# filtered out unless the root logger level is lowered elsewhere.
# NOTE(review): logging.basicConfig does nothing once the root logger has
# handlers, so whichever call runs first wins — importing this module can
# silently disable an application's own later basicConfig call. Consider
# moving this configuration into application code.
logging.basicConfig(filename='logging.txt', level=logging.WARNING)





class CombineOptimizers(object):
    '''
    Presents several optimizers through one optimizer-like object, so that
    a single ```.zero_grad()``` or ```.step()``` call is forwarded to every
    wrapped optimizer in turn.
    '''

    def __init__(self, *optimizers):
        '''
        Allows the user to add multiple optimizers and update parameters with
        a single ```.step()``` call.
        This code has been taken from:
        https://discuss.pytorch.org/t/two-optimizers-for-one-model/11085/7


        Arguments
        ---------

        - ```optimizers```: ```torch.optim```:
            The optimizers to combine into one class. You 
            may pass as many as you like.

        '''
        # kept as the tuple received from *optimizers; calls are dispatched
        # in this order
        self.optimizers = optimizers

    def _call_each(self, method_name, args=(), kwargs=None):
        # Invoke ``method_name`` on every wrapped optimizer, in order, with
        # identical arguments. Return values are discarded. args/kwargs are
        # passed as containers so they can never clash with this helper's
        # own parameter names.
        call_kwargs = {} if kwargs is None else kwargs
        for optimizer in self.optimizers:
            bound_method = getattr(optimizer, method_name)
            bound_method(*args, **call_kwargs)

    def zero_grad(self):
        '''Calls ```.zero_grad()``` on every wrapped optimizer.'''
        self._call_each('zero_grad')

    def step(self, *args, **kwargs):
        '''
        Calls ```.step(*args, **kwargs)``` on every wrapped optimizer,
        forwarding the same arguments to each one. Returns ```None```.
        '''
        self._call_each('step', args, kwargs)










class DiscreteRankingSTD(object):
    '''
    Stateful depression function for the LAP optimizers.

    For every source (row of the loss array it is called with) it keeps a
    non-negative counter in ```source_xn```. On each call, the counter of
    the current source is increased by one if the source's mean loss is at
    least ```strictness``` weighted standard deviations above the weighted
    mean loss of the other sources, and decreased by one (floored at 0)
    otherwise. The returned depression is
    ```function(discrete_amount * counter)```.
    '''
    def __init__(self,
                    function=lambda x: x,
                    discrete_amount:float=0.005,
                    hold_off:int=0, 
                    strictness:float=1.0,
                    ):
        '''
        This class calculates the level of depression 
        to apply to gradient updates from a batch of 
        source data.
        
        
        
        Arguments
        ---------
        
        - ```function```: ```callable```, optional:
            This argument allows you to apply a function
            to the value before it is returned. 
            Defaults to ```lambda x: x```.
        
        - ```discrete_amount```: ```float```, optional:
            The step size used when calculating the depression. 
            Defaults to ```0.005```.
        
        - ```hold_off```: ```int```, optional:
            The number of calls to this function before
            depression will start. Until depression starts,
            this function will return 0 on each call. 
            Defaults to ```0```.
        
        - ```strictness```: ```float```, optional:
            The number of standard deviations away from the 
            mean loss a mean source loss has to be 
            before depression is applied. 
            Defaults to ```1.0```.
        
        
        '''
        self.function = function
        self.discrete_amount = discrete_amount
        # one counter per row position of the loss array; grows on demand
        self.source_xn = np.asarray([])
        self.hold_off = hold_off
        self.strictness = strictness
        # number of calls made so far (including hold-off calls)
        self.step = 0

    def __call__(self, 
                    loss_array:np.ndarray, 
                    source_idx:int, 
                    *args, 
                    **kwargs):
        '''
        
        Arguments
        ---------
        
        - ```loss_array```: ```np.ndarray```: 
            The loss values for the last n batches of each source.
            Where n is the history size.
            This should be of shape ```(n_sources, n_batches_prev_tracked)```.
        
        - ```source_idx```: ```int```: 
            The index in the loss array of the source 
            being updated.
        

        Returns
        --------
        
        - ```out```: ```float``` : 
            The depression value, d in the depression calculation:
            ```dep = 1-tanh(m*d)**2```.
            This means, the larger the value, the more depression 
            will be applied during training.
            ```0``` is returned during the first ```hold_off``` calls,
            when there is no other source to compare against, and when
            the losses of the other sources are (close to) constant.
        
        '''
        # count this call; the first ``hold_off`` calls return 0
        self.step += 1
        if self.step <= self.hold_off:
            return 0

        logging.debug('Source Index %s', source_idx)

        n_rows = loss_array.shape[0]

        # keeps track of the current depression applied to each source
        # these will be used as weights in the standard deviation and 
        # mean calculations
        if n_rows > len(self.source_xn):
            self.source_xn = np.hstack([self.source_xn, np.zeros(n_rows - len(self.source_xn))])

        # mask is True where loss array source is not equal to the current source
        mask = np.ones(n_rows, dtype=bool)
        mask[source_idx] = False

        # no other source to compare against: no depression
        if not mask.any():
            return 0

        # if the range in loss values is close to 0, return no depression
        if np.all(np.isclose(np.ptp(loss_array[mask]), 0)):
            return 0

        # mean loss of current source
        mean_source_loss = np.mean(loss_array[~mask])

        # weighted mean and standard deviation of the sources other
        # than the current source. Sources with a larger counter get a
        # smaller weight (1 / (counter + 1)). Only the counters for the
        # rows actually present in loss_array are used.
        row_xn = self.source_xn[:n_rows]
        weights = np.ones_like(loss_array)/((row_xn + 1)[:, np.newaxis])
        (mean_not_source_loss, 
        std_not_source_loss) = weighted_avg_and_std(loss_array[mask], 
                                                    weights=weights[mask])

        # calculates whether to trust a source more or less
        threshold = mean_not_source_loss + self.strictness*std_not_source_loss
        logging.debug('%s < %s', mean_source_loss, threshold)
        movement = -1 if mean_source_loss < threshold else 1
        logging.debug('movement %s', movement)
        logging.debug('source_xn %s', self.source_xn[source_idx])

        # moving the current trust level depending on the movement calculated above
        self.source_xn[source_idx] += movement
        if self.source_xn[source_idx] < 0:
            self.source_xn[source_idx] = 0
        
        # calculating the depression value
        depression = self.function(self.discrete_amount*self.source_xn[source_idx])

        return depression









class AdamLAP(torch.optim.Optimizer):
    '''
    Adam optimizer with loss adapted plasticity (LAP).

    Every call to ```.step()``` records the batch loss for the batch's
    source. Once enough loss history exists, a depression value ```d``` is
    computed for that source, and all gradients are multiplied in place by
    ```1-tanh(m*d)**2``` before the usual Adam update.
    '''
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                lap_n:int=10, depression_strength:float=1.0, depression_function='discrete_ranking_std', 
                depression_function_kwargs:typing.Optional[dict]=None,
                 weight_decay=0, amsgrad=False, *, foreach = None,
                 maximize: bool = False, source_is_bool:bool=False, 
                 ):
        '''
        Depression won't be applied until at least ```lap_n``` loss values
        have been collected for at least two sources. This could be 
        longer if a ```hold_off``` parameter is used in the depression function.

        LAP stands for loss adapted plasticity.
        
        Arguments
        ---------
        
        - ```params```: 
            Iterable of parameters to optimize or dicts defining
            parameter groups. The ```depression_strength``` argument
            is compatible with parameter groups, however ```lap_n```,
            ```depression_function```, ```depression_function_kwargs```,
            ```source_is_bool``` are not.
        
        - ```lr```: ```float```, optional:
            Learning rate. See documentation for ```torch.optim.Adam```.
            Defaults to ```1e-3```.
        
        - ```betas```: ```tuple```, optional:
            See documentation for ```torch.optim.Adam```.
            Defaults to ```(0.9, 0.999)```.
        
        - ```lap_n```: ```int```, optional:
            The number of previous loss values for each source
            to be used in the loss adapted plasticity
            calculations. Must be an integer of at least 1.
            Defaults to ```10```.
        
        - ```eps```: ```float```, optional:
            See documentation for ```torch.optim.Adam```.
            Defaults to ```1e-8```.
        
        - ```weight_decay```: ```float```, optional:
            Weight decay (L2 penalty).
            See documentation for ```torch.optim.Adam```.
            Defaults to ```0```.
        
        - ```amsgrad```: ```bool```, optional:
            See documentation for ```torch.optim.Adam```.
            Defaults to ```False```.
        
        - ```depression_strength```: ```float```:
            This float determines the strength of the depression
            applied to the gradients. It is the value of m in 
            ```dep = 1-tanh(m*d)**2```.
            Defaults to ```1```.
        
        - ```depression_function```: ```function``` or ```string```, optional:
            This is the function used to calculate the depression
            based on the loss array (with sources containing full 
            loss history) and the source of the current batch. 
            Ensure that the first two arguments of this function are
            ```loss_array``` and ```source_idx```.
            The rows of ```loss_array``` are ordered by when each
            source's loss history first became complete, so a source
            keeps the same row position across calls.
            If string, please ensure it is ```'discrete_ranking_std'```
            Defaults to ```discrete_ranking_std```.
        
        - ```depression_function_kwargs```: ```dict```, optional:
            Keyword arguments that will be used in ```depression_function```
            when initiating it, if it is specified by a string.
            Defaults to ```None``` (no keyword arguments).
        
        - ```foreach```: ```bool```, optional:
            Stored in the parameter group defaults. It is not read by
            this implementation.
            Defaults to ```None```.
        
        - ```maximize```: ```bool```, optional:
            Maximize the params based on the objective, 
            instead of minimizing.
            Defaults to ```False```.
        
        - ```source_is_bool```: ```bool```, optional:
            This tells the optimizer that the sources will be named ```True```
            when the source is corrupted and ```False``` if the source is not.
            If the incoming source is corrupted, then the optimizer will not
            make a step.
            Defaults to ```False```.
        
        '''
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not isinstance(lap_n, (int, np.integer)) or lap_n < 1:
            raise ValueError("Invalid parameter for lap_n: {}. "\
                                "Please use an integer larger than 0".format(lap_n))
        if not 0.0 <= depression_strength:
            raise ValueError("Invalid depression strength: {}".format(depression_strength))

        # storing settings and creating the loss array
        self.lap_n = lap_n
        # One row of lap_n losses per source; -1 marks an empty slot.
        # The array starts with one spare row: a new source fills the
        # current spare row and a fresh spare row is appended.
        # NOTE(review): a real loss of exactly -1 is indistinguishable
        # from an empty slot.
        self.loss_array = -1*np.ones((1,self.lap_n))
        # maps source -> row index in self.loss_array
        self.source_dict = {}
        self.n_sources = 0
        # None means "no kwargs"; a fresh dict per instance avoids sharing
        # a mutable default between optimizers
        self.depression_function_kwargs = ({} if depression_function_kwargs is None
                                           else depression_function_kwargs)
        self.depression_function = (depression_function if not isinstance(depression_function, str)
                                    else self._get_depression_function(depression_function))
        # number of recorded steps per source; used as the writer's step
        self.source_step_dict = {}
        self.source_is_bool = source_is_bool
        # row indices in the order their loss histories first became complete
        self._history_order = []

        defaults = dict(lr=lr, betas=betas, eps=eps,
                        depression_strength=depression_strength,
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        maximize=maximize, foreach=foreach)
        super(AdamLAP, self).__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)
        # older checkpoints store 'step' as a Python number; convert to tensors
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
        if not step_is_tensor:
            for s in state_values:
                s['step'] = torch.tensor(float(s['step']))

    def _has_complete_history(self):
        # returns source indices in which there is a complete history of losses
        return np.argwhere(np.sum(self.loss_array != -1, axis=1) == self.lap_n).reshape(-1)

    def _get_depression_function(self, name):
        '''
        Function to get the depression function by name.
        '''
        if name == 'discrete_ranking_std':
            return DiscreteRankingSTD(**self.depression_function_kwargs)

        else:
            raise NotImplementedError('{} is not a known depression function. Please '\
                                        'pass the function instead of the name.'.format(name))

    def _record_loss(self, loss, source):
        '''
        Stores ``loss`` as the newest entry of ``source``'s loss history,
        registering the source if it is new, and counts the step for that
        source. Returns the source's row index in ``self.loss_array``.
        '''
        if source not in self.source_dict:
            # new source: next index, and append a fresh spare row
            source_idx = self.n_sources
            self.source_dict[source] = source_idx
            self.n_sources += 1
            self.loss_array = np.concatenate([self.loss_array, -1*np.ones((1, self.lap_n))], axis=0)
            self.loss_array[source_idx, -1] = loss
        else:
            # already tracked: shift the history left (a view into
            # loss_array) and put the new loss at the end
            source_idx = self.source_dict[source]
            losses = self.loss_array[source_idx]
            losses[:-1] = losses[1:]
            losses[-1] = loss
            logging.debug('losses, %s', losses)
            logging.debug('loss array, %s', self.loss_array)

        # saves the number of times each source has been seen for summary writer
        self.source_step_dict[source] = self.source_step_dict.get(source, 0) + 1
        return source_idx

    def _ordered_history(self):
        '''
        Returns the row indices with a complete loss history, ordered by
        when each history first became complete. Appending newly complete
        rows at the end keeps every source at a stable position, which
        stateful depression functions (e.g. ``DiscreteRankingSTD``) rely on.
        '''
        if not hasattr(self, '_history_order'):
            self._history_order = []
        complete = self._has_complete_history().tolist()
        known = set(self._history_order)
        self._history_order.extend(idx for idx in complete if idx not in known)
        complete_set = set(complete)
        return [idx for idx in self._history_order if idx in complete_set]

    def _calculate_depression(self, source_idx):
        '''
        Returns ``(depressing, depression)``. Depression is only computed
        when the current source and at least one other source have a
        complete loss history; otherwise ``depression`` is ``None``.
        '''
        history = self._ordered_history()
        depressing = bool(len(history) > 1 and source_idx in history)
        depression = None
        if depressing:
            depression = self.depression_function(loss_array=self.loss_array[history],
                                                  source_idx=history.index(source_idx))
        logging.debug('depressing, %s', depressing)
        return depressing, depression

    @staticmethod
    def _depression_multiplier(depression, depression_strength):
        # dep = 1 - tanh(m*d)**2, returned as a Python float
        return 1-torch.pow(
                        torch.tanh(
                            torch.tensor(depression*depression_strength)),
                        2).item()

    @torch.no_grad()
    def step(self, loss:float, source, closure=None, override_dep=None, writer=None, **kwargs):
        """Performs a single optimization step.

        Arguments
        ---------

        - loss: ```float```:
            This is the loss value that is used in the depression calculations.
        
        - source: ```hashable```:
            This is the source name that is used to
            store the loss values for the different sources.
        
        - closure: ```(callable, optional)```: 
            A closure that reevaluates the model and returns the loss.
            It is evaluated with gradients enabled.
            See documentation for ```torch.optim.Adam```.
            Defaults to ```None```.
        
        - override_dep: ```bool``` or ```None```:
            If ```None```, then whether to apply depression will be decided
            based on the logic of this class. If ```True```, then depression will 
            be applied. This might cause unexpected results if there is no depression value
            calculated based on whether there is enough data available in the 
            ```.loss_array```. In this case, no depression is applied.
            If ```False```, then depression will not be applied.
            This is mostly useful as an option to turn off LAP.
            Defaults to ```None```.
        
        - ```writer```: ```torch.utils.tensorboard.SummaryWriter```:
            A tensorboard writer can be passed into this function to track metrics.
            Defaults to ```None```.

        Returns
        -------

        - The value returned by ```closure```, or ```None``` if no
          closure was given.

        """
        loss_calc = None
        if closure is not None:
            # step runs under no_grad; the closure needs autograd for backward()
            with torch.enable_grad():
                loss_calc = closure()

        logging.debug('source, %s', source)
        logging.debug('loss, %s', loss)

        # if reliability of source is given, update only when
        # data is reliable
        if self.source_is_bool:
            if source:
                return loss_calc
            if override_dep not in [True, False]:
                override_dep = False

        source_idx = self._record_loss(loss, source)
        depressing, depression = self._calculate_depression(source_idx)

        # depression boolean override from argument
        # if override is True and there is no depression value calculated
        # then the depression value is set to 0 (no depression)
        if override_dep is not None:
            if override_dep in [True, False]:
                if not depressing:
                    depression = 0.0
                depressing = override_dep
            else:
                raise TypeError('override_dep must be of boolean value, or None. Please see docs.')

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group['betas']

            # calculate the actual depression to be multiplied by the gradients
            if depressing:
                logging.debug('Depression, %s', depression)
                actual_depression = self._depression_multiplier(depression, group['depression_strength'])
            else:
                actual_depression = 1
            
            # saves the depression value to the writer
            if writer is not None:
                writer.add_scalars('Actual Depression Value', 
                                    {'{}'.format(source): actual_depression}, 
                                    self.source_step_dict[source])

            logging.debug('Actual Depression, %s', actual_depression)

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                # ======= applying depression (in place on p.grad) =======
                p.grad.mul_(actual_depression)
                # ========================================================

                grads.append(p.grad)
                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state['step'] = torch.tensor(0.)
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group['amsgrad']:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])

                if group['amsgrad']:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                state_steps.append(state['step'])

            for i, param in enumerate(params_with_grad):

                grad = grads[i] if not group['maximize'] else -grads[i]
                exp_avg = exp_avgs[i]
                exp_avg_sq = exp_avg_sqs[i]
                step_t = state_steps[i]
                # update step (in place on the state tensor)
                step_t += 1
                step = step_t.item()

                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step

                if group['weight_decay'] != 0:
                    grad = grad.add(param, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
                if group['amsgrad']:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sqs[i].sqrt() / _math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_sq.sqrt() / _math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1
                param.addcdiv_(exp_avg, denom, value=-step_size)

        return loss_calc












class SGDLAP(torch.optim.Optimizer):
    '''
    SGD optimizer (optionally with momentum) with loss adapted
    plasticity (LAP).

    Every call to ```.step()``` records the batch loss for the batch's
    source. Once enough loss history exists, a depression value ```d``` is
    computed for that source, and all gradients are multiplied in place by
    ```1-tanh(m*d)**2``` before the usual SGD update.
    '''
    def __init__(self, params, lr=1e-3, momentum=0, dampening=0, lap_n:int=10, depression_strength:float=1,
                 weight_decay=0, nesterov=False, *, maximize=False, foreach: typing.Optional[bool] = None,
                 depression_function='discrete_ranking_std',
                 depression_function_kwargs:typing.Optional[dict]=None,
                 source_is_bool:bool=False, 
                ):
        '''
        Implements stochastic gradient descent (optionally with momentum).

        Nesterov momentum is based on the formula from
        `On the importance of initialization and momentum in deep learning`__.

        Depression won't be applied until at least ```lap_n``` loss values
        have been collected for at least two sources. This could be 
        longer if a ```hold_off``` parameter is used in the depression function.

        LAP stands for loss adapted plasticity.
        
        Arguments
        ---------
        
        - ```params```: 
            Iterable of parameters to optimize or dicts defining
            parameter groups. The ```depression_strength``` argument
            is compatible with parameter groups, however ```lap_n```,
            ```depression_function```, ```depression_function_kwargs```,
            ```source_is_bool``` are not.

        - ```lr```: ```float```, optional:
            Learning rate.
            Defaults to ```1e-3```.
        
        - ```lap_n```: ```int```, optional:
            The number of previous loss values for each source
            to be used in the loss adapted plasticity
            calculations. Must be an integer of at least 1.
            Defaults to ```10```.
        
        - ```momentum```: ```float```, optional:
            Momentum factor.
            Defaults to ```0```.

        - ```dampening```: ```float```, optional:
            Dampening for momentum.
            Defaults to ```0```.

        - ```nesterov```: ```bool```, optional:
            Enables Nesterov momentum.
            Defaults to ```False```.

        - ```maximize```: ```bool```, optional:
            Maximize the params based on the objective, 
            instead of minimizing
            Defaults to ```False```.

        - ```weight_decay```: ```float```, optional:
            Weight decay (L2 penalty)
            Defaults to ```0```.

        - ```foreach```: ```bool```, optional:
            Stored in the parameter group defaults. It is not read by
            this implementation, which always uses the ```torch._foreach_*```
            operations.
            Defaults to ```None```.

        - ```depression_strength```: ```float```:
            This float determines the strength of the depression
            applied to the gradients. It is the value of m in 
            ```dep = 1-tanh(m*d)**2```.
            Defaults to ```1```.
        
        - ```depression_function```: ```function``` or ```string```, optional:
            This is the function used to calculate the depression
            based on the loss array (with sources containing full 
            loss history) and the source of the current batch. 
            Ensure that the first two arguments of this function are
            ```loss_array``` and ```source_idx```.
            The rows of ```loss_array``` are ordered by when each
            source's loss history first became complete, so a source
            keeps the same row position across calls.
            If string, please ensure it is ```'discrete_ranking_std'```
            Defaults to ```discrete_ranking_std```.
        
        - ```depression_function_kwargs```: ```dict```, optional:
            Keyword arguments that will be used in ```depression_function```
            when initiating it, if it is specified by a string.
            Defaults to ```None``` (no keyword arguments).
        
        - ```source_is_bool```: ```bool```, optional:
            This tells the optimizer that the sources will be named ```True```
            when the source is corrupted and ```False``` if the source is not.
            If the incoming source is corrupted, then the optimizer will not
            make a step.
            Defaults to ```False```.
        '''

        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not isinstance(lap_n, (int, np.integer)) or lap_n < 1:
            raise ValueError("Invalid parameter for lap_n: {}. "\
                                "Please use an integer larger than 0".format(lap_n))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        maximize=maximize, foreach=foreach, lap_n=lap_n, 
                        depression_strength=depression_strength)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGDLAP, self).__init__(params, defaults)

        self.lap_n = lap_n
        # One row of lap_n losses per source; -1 marks an empty slot.
        # The array starts with one spare row: a new source fills the
        # current spare row and a fresh spare row is appended.
        # NOTE(review): a real loss of exactly -1 is indistinguishable
        # from an empty slot.
        self.loss_array = -1*np.ones((1,self.lap_n))
        # maps source -> row index in self.loss_array
        self.source_dict = {}
        self.n_sources = 0
        # None means "no kwargs"; a fresh dict per instance avoids sharing
        # a mutable default between optimizers
        self.depression_function_kwargs = ({} if depression_function_kwargs is None
                                           else depression_function_kwargs)
        self.depression_function = (depression_function if not isinstance(depression_function, str)
                                    else self._get_depression_function(depression_function))
        # number of recorded steps per source; used as the writer's step
        self.source_step_dict = {}
        self.source_is_bool = source_is_bool
        # row indices in the order their loss histories first became complete
        self._history_order = []

    def _has_complete_history(self):
        # returns source indices in which there is a complete history of losses
        return np.argwhere(np.sum(self.loss_array != -1, axis=1) == self.lap_n).reshape(-1)

    def _get_depression_function(self, name):
        '''
        Function to get the depression function by name.
        '''
        if name == 'discrete_ranking_std':
            return DiscreteRankingSTD(**self.depression_function_kwargs)

        else:
            raise NotImplementedError('{} is not a known depression function. Please '\
                                        'pass the function instead of the name.'.format(name))

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)

    def _record_loss(self, loss, source):
        '''
        Stores ``loss`` as the newest entry of ``source``'s loss history,
        registering the source if it is new, and counts the step for that
        source. Returns the source's row index in ``self.loss_array``.
        '''
        if source not in self.source_dict:
            # new source: next index, and append a fresh spare row
            source_idx = self.n_sources
            self.source_dict[source] = source_idx
            self.n_sources += 1
            self.loss_array = np.concatenate([self.loss_array, -1*np.ones((1, self.lap_n))], axis=0)
            self.loss_array[source_idx, -1] = loss
        else:
            # already tracked: shift the history left (a view into
            # loss_array) and put the new loss at the end
            source_idx = self.source_dict[source]
            losses = self.loss_array[source_idx]
            losses[:-1] = losses[1:]
            losses[-1] = loss
            logging.debug('losses, %s', losses)
            logging.debug('loss array, %s', self.loss_array)

        # saves the number of times each source has been seen for summary writer
        self.source_step_dict[source] = self.source_step_dict.get(source, 0) + 1
        return source_idx

    def _ordered_history(self):
        '''
        Returns the row indices with a complete loss history, ordered by
        when each history first became complete. Appending newly complete
        rows at the end keeps every source at a stable position, which
        stateful depression functions (e.g. ``DiscreteRankingSTD``) rely on.
        '''
        if not hasattr(self, '_history_order'):
            self._history_order = []
        complete = self._has_complete_history().tolist()
        known = set(self._history_order)
        self._history_order.extend(idx for idx in complete if idx not in known)
        complete_set = set(complete)
        return [idx for idx in self._history_order if idx in complete_set]

    def _calculate_depression(self, source_idx):
        '''
        Returns ``(depressing, depression)``. Depression is only computed
        when the current source and at least one other source have a
        complete loss history; otherwise ``depression`` is ``None``.
        '''
        history = self._ordered_history()
        depressing = bool(len(history) > 1 and source_idx in history)
        depression = None
        if depressing:
            depression = self.depression_function(loss_array=self.loss_array[history],
                                                  source_idx=history.index(source_idx))
        logging.debug('depressing, %s', depressing)
        return depressing, depression

    @staticmethod
    def _depression_multiplier(depression, depression_strength):
        # dep = 1 - tanh(m*d)**2, returned as a Python float
        return 1-torch.pow(
                        torch.tanh(
                            torch.tensor(depression*depression_strength)),
                        2).item()

    @torch.no_grad()
    def step(self, loss:float, source, closure=None, override_dep=None, writer = None, **kwargs):
        """Performs a single optimization step.

        Arguments
        ---------

        - loss: ```float```:
            This is the loss value that is used in the depression calculations.
        
        - source: ```hashable```:
            This is the source name that is used to
            store the loss values for the different sources.
        
        - closure: ```(callable, optional)```: 
            A closure that reevaluates the model and returns the loss.
            It is evaluated with gradients enabled.
            See documentation for ```torch.optim.SGD```.
            Defaults to ```None```.
        
        - override_dep: ```bool``` or ```None```:
            If ```None```, then whether to apply depression will be decided
            based on the logic of this class. If ```True```, then depression will 
            be applied. This might cause unexpected results if there is no depression value
            calculated based on whether there is enough data available in the 
            ```.loss_array```. In this case, no depression is applied.
            If ```False```, then depression will not be applied.
            This is mostly useful as an option to turn off LAP.
            Defaults to ```None```.
        
        - ```writer```: ```torch.utils.tensorboard.SummaryWriter```:
            A tensorboard writer can be passed into this function to track metrics.
            Defaults to ```None```.

        Returns
        -------

        - The value returned by ```closure```, or ```None``` if no
          closure was given.

        """
        loss_calc = None
        if closure is not None:
            # step runs under no_grad; the closure needs autograd for backward()
            with torch.enable_grad():
                loss_calc = closure()

        # if reliability of source is given, update only when
        # data is reliable
        if self.source_is_bool:
            if source:
                return loss_calc
            if override_dep not in [True, False]:
                override_dep = False

        logging.debug('source, %s', source)
        logging.debug('loss, %s', loss)

        source_idx = self._record_loss(loss, source)
        depressing, depression = self._calculate_depression(source_idx)

        # depression boolean override from argument
        # if override is True and there is no depression value calculated
        # then the depression value is set to 0 (no depression)
        if override_dep is not None:
            if override_dep in [True, False]:
                if not depressing:
                    depression = 0.0
                depressing = override_dep
            else:
                raise TypeError('override_dep must be of boolean value, or None. Please see docs.')

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            momentum_buffer_list = []
            has_sparse_grad = False

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                grads.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True
                # None when this parameter has no momentum buffer yet
                momentum_buffer_list.append(self.state[p].get('momentum_buffer'))

            if len(params_with_grad) == 0:
                continue
            
            # calculate the actual depression to be multiplied by the gradients
            if depressing:
                logging.debug('Depression, %s', depression)
                actual_depression = self._depression_multiplier(depression, group['depression_strength'])
            else:
                actual_depression = 1

            if writer is not None:
                writer.add_scalars('Actual Depression Value', 
                                    {'{}'.format(source): actual_depression}, 
                                    self.source_step_dict[source])
            logging.debug('Actual Depression: %s', actual_depression)

            # ======= applying depression (in place on p.grad) =======
            torch._foreach_mul_(grads, actual_depression)
            # ========================================================

            if group['weight_decay'] != 0:
                grads = torch._foreach_add(grads, params_with_grad, alpha=group['weight_decay'])

            if group['momentum'] != 0:
                if all(buf is not None for buf in momentum_buffer_list):
                    bufs = list(momentum_buffer_list)
                    torch._foreach_mul_(bufs, group['momentum'])
                    torch._foreach_add_(bufs, grads, alpha=(1 - group['dampening']))
                else:
                    bufs = []
                    for i, buf in enumerate(momentum_buffer_list):
                        if buf is None:
                            # first step for this parameter: buffer starts as the gradient
                            buf = momentum_buffer_list[i] = torch.clone(grads[i]).detach()
                        else:
                            buf.mul_(group['momentum']).add_(grads[i], 
                                                            alpha=(1 - group['dampening']))
                        bufs.append(buf)

                if group['nesterov']:
                    torch._foreach_add_(grads, bufs, alpha=group['momentum'])
                else:
                    grads = bufs

            alpha = group['lr'] if group['maximize'] else -group['lr']

            if not has_sparse_grad:
                torch._foreach_add_(params_with_grad, grads, alpha=alpha)
            else:
                # foreach APIs dont support sparse
                for param, grad in zip(params_with_grad, grads):
                    param.add_(grad, alpha=alpha)

            # update momentum_buffers in state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                self.state[p]['momentum_buffer'] = momentum_buffer

        return loss_calc







