import torch
from src.utils import utils
from src import regularizers #as reg
import math

class LinBreg(torch.optim.Optimizer):
    """Linearised Bregman iteration (LinBreg) optimiser.

    Every parameter ``p`` carries a subgradient variable ``v`` (kept in
    ``state['sub_grad']``), initialised as ``p / delta + reg.sub_grad(p)``.
    A step moves ``v`` against the gradient (optionally through a momentum
    buffer) and then sets ``p = reg.prox(delta * v, delta)``.

    Parameters
    ----------
    params : iterable
        Parameters or parameter-group dicts to optimise.
    lr : float, optional
        Step size of the subgradient update. Must be non-negative.
    reg : regulariser, optional
        Object providing ``prox(x, delta)``, ``sub_grad(x)`` and
        ``__call__(x)``. ``None`` (the default) creates a
        ``regularizers.reg_none()`` for this optimiser instead of sharing one
        instance built at import time.
    delta : float, optional
        Scaling used in the subgradient initialisation and in the prox call.
    momentum : float, optional
        If > 0, the subgradient is moved by an exponential moving average of
        ``lr * grad`` with this decay factor. The default 0 disables momentum.
    """

    def __init__(self, params, lr=1e-3, reg=None, delta=1.0, momentum=0.0):
        if lr < 0.0:
            raise ValueError("Invalid learning rate")
        if reg is None:
            reg = regularizers.reg_none()

        defaults = dict(lr=lr, reg=reg, delta=delta, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one LinBreg update on every parameter that has a gradient.

        Parameters
        ----------
        closure : callable, optional
            Re-evaluates the model and returns the loss. It is called with
            autograd enabled before the update, as in ``torch.optim``.

        Returns
        -------
        The value returned by ``closure``, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            delta = group['delta']
            reg = group['reg']
            step_size = group['lr']
            momentum = group['momentum']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                if len(state) == 0:
                    # Lazy initialisation on the first step that sees a gradient.
                    state['step'] = 0
                    state['sub_grad'] = self.initialize_sub_grad(p, reg, delta)
                    state['momentum_buffer'] = None

                sub_grad = state['sub_grad']
                if momentum > 0.0:
                    mom_buff = state['momentum_buffer']
                    if mom_buff is None:
                        mom_buff = torch.zeros_like(grad)
                    # m <- momentum * m + (1 - momentum) * lr * grad
                    mom_buff.mul_(momentum)
                    mom_buff.add_((1 - momentum) * step_size * grad)
                    state['momentum_buffer'] = mom_buff
                    sub_grad.add_(-mom_buff)
                else:
                    sub_grad.add_(-step_size * grad)

                # Map the subgradient back to parameter space.
                p.data = reg.prox(delta * sub_grad, delta)
                state['step'] += 1

        return loss

    def initialize_sub_grad(self, p, reg, delta):
        """Return the initial subgradient ``p / delta + reg.sub_grad(p)``."""
        p_init = p.data.clone()
        return 1 / delta * p_init + reg.sub_grad(p_init)

    @torch.no_grad()
    def evaluate_reg(self):
        """Return a list with, per parameter group, the sum of ``group['reg'](p)``."""
        reg_vals = []
        for group in self.param_groups:
            reg = group['reg']
            group_reg_val = 0.0
            for p in group['params']:
                group_reg_val += reg(p)
            reg_vals.append(group_reg_val)
        return reg_vals

class perform_mask(torch.nn.Module):
    """Gradient hook that multiplies the incoming gradient by a fixed mask.

    The sparse optimisers register instances via ``Tensor.register_hook`` so
    that, during sparse updates, gradient only reaches entries selected by
    ``mask``.

    Parameters
    ----------
    mask : torch.Tensor
        Bool or float tensor broadcastable against the gradient; zero/False
        entries block the gradient.
    """

    def __init__(self, mask):
        # nn.Module.__init__ must run first so the Module internals
        # (_parameters, _modules, ...) exist; repr(), .parameters(), .to()
        # and pickling all rely on them.
        super().__init__()
        self.mask = mask

    def forward(self, x):
        return self.mask * x

    def __call__(self, x):
        # Bypass nn.Module.__call__'s hook dispatch: this runs on every
        # backward pass and only needs the multiplication.
        return self.forward(x)

class LinBregSparse(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-2, delta=1.0, reg=None, momentum=0.0,
                 full_update_frequency=1, full_update_duration=1):
        """
        LinBreg algorithm that alternates between full updates of all
        parameters and sparse updates that only touch the parameters which
        were non-zero after the last full update.

        The schedule repeats every
        ``full_update_duration + full_update_frequency - 1`` steps (counted
        per parameter from 0). The first ``full_update_duration`` steps of
        each period are full updates, and the remaining
        ``full_update_frequency - 1`` steps are sparse. During sparse steps a
        gradient hook masks the gradient, and the prox input is restricted to
        the same mask.

        Parameters
        ----------
        params : Model parameters or parameter-group dicts.
        lr : Float, optional
            Learning rate for the optimizer. The default is 1e-2.
        delta : Float, optional
            Smoothing parameter used in the regulariser. The default is 1.0.
        reg : Regularizer, optional
            Choice of regularizer to be used in the proximal step.
            The default (None) creates a ``regularizers.reg_none()``.
        momentum : Float, optional
            Decay factor of an exponential moving average of ``lr * grad``
            used to update the subgradient; 0 disables momentum.
            The default is 0.0.
        full_update_frequency : Int, optional
            With ``full_update_duration == 1``, a full update happens every
            ``full_update_frequency`` steps. The default is 1 (always full).
        full_update_duration : Int, optional
            Number of consecutive full updates at the start of each period.
            The default is 1.
        """
        if reg is None:
            reg = regularizers.reg_none()
        full_update_period = full_update_duration + full_update_frequency - 1
        defaults = dict(lr=lr, delta=delta, reg=reg,
                        full_update_period=full_update_period,
                        full_update_frequency=full_update_frequency,
                        full_update_duration=full_update_duration,
                        momentum=momentum)
        super().__init__(params, defaults)

        # Initialise states eagerly. Step 0 is a full update (for
        # full_update_duration >= 1), so no hook is needed before the first
        # backward pass.
        for group in self.param_groups:
            delta = group['delta']
            reg = group['reg']
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['sub_grad'] = self.initialize_sub_grad(p, reg, delta)
                state['momentum_buffer'] = None

    def update_hook(self, param, mask=None):
        """(Re)register a gradient hook on ``param`` that multiplies its gradient by ``mask``.

        ``mask`` defaults to ``param != 0``. Parameters that do not require
        grad are left untouched.
        """
        if param.requires_grad:
            if mask is None:
                mask = param.detach() != 0

            handle = self.state[param].get('hook_handle')
            if handle:  # Remove existing hook
                handle.remove()

            self.state[param]['hook_handle'] = param.register_hook(perform_mask(mask))

    def remove_hook(self, param):
        """Remove the gradient-masking hook of ``param``, if one was registered."""
        handle = self.state[param].get('hook_handle')
        if handle:
            handle.remove()

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one full or sparse LinBreg update per parameter, following the schedule.

        Parameters
        ----------
        closure : callable, optional
            Re-evaluates the model and returns the loss; called with autograd
            enabled before the update, as in ``torch.optim``.

        Returns
        -------
        The value returned by ``closure``, or None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            delta = group['delta']
            lr = group['lr']
            reg = group['reg']
            momentum = group['momentum']
            full_update_duration = group['full_update_duration']
            # Recompute per group: defaults['full_update_period'] is derived
            # from the constructor arguments and is stale for groups that
            # override the frequency/duration.
            full_update_period = full_update_duration + group['full_update_frequency'] - 1

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                state = self.state[param]
                step = state['step']

                if param.requires_grad:
                    # Current update | Next update | Hook modification
                    #   Full         |   Full      |   Nothing
                    #   Sparse       |   Sparse    |   Nothing
                    #   Full         |   Sparse    |   Redefine hook
                    #   Sparse       |   Full      |   Remove hook
                    full_update = (step % full_update_period) < full_update_duration
                    next_update_full = ((step + 1) % full_update_period) < full_update_duration

                    sub_grad = state['sub_grad']
                    if momentum > 0.0:
                        mom_buff = state['momentum_buffer']
                        if mom_buff is None:
                            mom_buff = torch.zeros_like(grad)
                        mom_buff.mul_(momentum)
                        mom_buff.add_((1 - momentum) * lr * grad)
                        state['momentum_buffer'] = mom_buff
                        sub_grad.add_(-mom_buff)
                    else:
                        # During sparse phases the hook has already masked grad.
                        sub_grad.add_(-lr * grad)

                    if full_update:
                        param.copy_(reg.prox(delta * sub_grad, delta))
                        if not next_update_full:
                            # Full -> sparse: freeze the current sparsity pattern.
                            mask = self._sparse_mask(param, reg)
                            self.update_hook(param, mask)
                            state['mask'] = mask
                    else:
                        # Naive restricted subgradient: zero outside the mask.
                        param.copy_(reg.prox(delta * sub_grad * state['mask'], delta))
                        if next_update_full:
                            # Sparse -> full: let every gradient through again.
                            self.remove_hook(param)

                state['step'] += 1

        return loss

    def initialize_sub_grad(self, param, reg, delta):
        """Return the initial subgradient ``param / delta + reg.sub_grad(param)``."""
        p = param.data.clone()
        return 1 / delta * p + reg.sub_grad(p)

    @staticmethod
    def _sparse_mask(param, reg):
        """Mask of entries that stay trainable during the following sparse updates."""
        if isinstance(reg, regularizers.reg_l1_l2_conv):
            # Group sparsity: keep whole kernels K_ij with ||K_ij||_2 > 0.
            norm_per_kernel = torch.norm(param.view(param.shape[0], param.shape[1], -1), p=2, dim=2)
            mask = (norm_per_kernel > 0.0).float()
            return mask[:, :, None, None]  # shape: (out_channels, in_channels, 1, 1)
        return param.detach() != 0
    
    

class AdaBregSparse(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-2, delta=1.0, reg=None,
                 full_update_frequency=1, betas=(0.9, 0.999), eps=1e-8):
        """
        AdaBreg (Adam-preconditioned LinBreg) that, between full updates, only
        lets gradient reach the entries that were non-zero at the last full
        update.

        Steps are counted per parameter starting at 1. A step is a full update
        when its index is a multiple of ``full_update_frequency``. All other
        steps are sparse, and a gradient hook masks the gradient.

        Parameters
        ----------
        params : Model parameters or parameter-group dicts.
        lr : Float, optional
            Learning rate for the optimizer. The default is 1e-2.
        delta : Float, optional
            Smoothing parameter used in the regulariser. The default is 1.0.
        reg : Regularizer, optional
            Choice of regularizer to be used in the proximal step.
            The default (None) creates a ``regularizers.reg_none()``.
        full_update_frequency : Int, optional
            A full update happens every ``full_update_frequency`` steps.
            The default is 1 (every step is full).
        betas : Tuple[Float, Float], optional
            Decay rates of the first and second moment estimates, as in Adam.
            The default is (0.9, 0.999).
        eps : Float, optional
            Term added to the denominator for numerical stability.
            The default is 1e-8.
        """
        if reg is None:
            reg = regularizers.reg_none()
        defaults = dict(lr=lr, delta=delta, reg=reg,
                        full_update_frequency=full_update_frequency, betas=betas, eps=eps)
        super().__init__(params, defaults)

        # Initialise states immediately so that hooks are applied on the first backward pass.
        for group in self.param_groups:
            delta = group['delta']
            reg = group['reg']
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['sub_grad'] = self.initialize_sub_grad(p, reg, delta)
                state['exp_avg'] = torch.zeros_like(state['sub_grad'])
                state['exp_avg_sq'] = torch.zeros_like(state['sub_grad'])
                # With frequency > 1 the first update (step 1) is sparse, so
                # its hook must exist before the first backward pass. Use the
                # group's value: param groups may override the constructor default.
                if group['full_update_frequency'] > 1:
                    mask = self._sparse_mask(p, reg)
                    self.update_hook(p, mask)
                    state['mask'] = mask

    def update_hook(self, param, mask=None):
        """(Re)register a gradient-masking hook on ``param``.

        Only applied to a parameter that requires grad and contains at least
        one non-zero entry. ``mask`` defaults to ``param != 0``.
        """
        if param.requires_grad and param.count_nonzero().item() > 0:
            if mask is None:
                mask = param.clone().detach() != 0

            handle = self.state[param].get('hook_handle')
            if handle:  # Remove existing hook
                handle.remove()

            self.state[param]['hook_handle'] = param.register_hook(perform_mask(mask))

    def remove_hook(self, param):
        """Remove the gradient-masking hook of ``param``, if one was registered."""
        handle = self.state[param].get('hook_handle')
        if handle:
            handle.remove()

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one AdaBreg update per parameter and adjust hooks for the next step.

        Parameters
        ----------
        closure : callable, optional
            Re-evaluates the model and returns the loss; called with autograd
            enabled before the update, as in ``torch.optim``.

        Returns
        -------
        The value returned by ``closure``, or None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            delta = group['delta']
            lr = group['lr']
            reg = group['reg']
            full_update_frequency = group['full_update_frequency']
            beta1, beta2 = group['betas']
            eps = group['eps']

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                state = self.state[param]
                # 1-based index of this step, used for bias correction and scheduling.
                step = state['step'] + 1

                if param.requires_grad:
                    sub_grad = state['sub_grad']
                    exp_avg = state['exp_avg']
                    exp_avg_sq = state['exp_avg_sq']

                    bias_correction1 = 1 - beta1 ** step
                    bias_correction2 = 1 - beta2 ** step

                    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                    step_size = lr / bias_correction1
                    sub_grad.addcdiv_(exp_avg, denom, value=-step_size)

                    # NOTE(review): unlike LinBregSparse, the prox input is not
                    # restricted to the mask here, and exp_avg keeps decaying
                    # non-zero values for masked entries, so frozen entries can
                    # still move during sparse phases. Confirm this is intended.
                    param.copy_(reg.prox(delta * sub_grad, delta))

                    # Current update | Next update | Hook modification
                    #   Full         |   Full      |   Nothing
                    #   Sparse       |   Sparse    |   Nothing
                    #   Full         |   Sparse    |   Redefine hook
                    #   Sparse       |   Full      |   Remove hook
                    full_update = step % full_update_frequency == 0
                    next_update_full = (step + 1) % full_update_frequency == 0
                    if full_update and not next_update_full:
                        mask = self._sparse_mask(param, reg)
                        self.update_hook(param, mask)
                        state['mask'] = mask
                    elif (not full_update) and next_update_full:
                        self.remove_hook(param)

                state['step'] += 1

        return loss

    def initialize_sub_grad(self, param, reg, delta):
        """Return the initial subgradient ``param / delta + reg.sub_grad(param)``."""
        p = param.data.clone()
        return 1 / delta * p + reg.sub_grad(p)

    @staticmethod
    def _sparse_mask(param, reg):
        """Mask of entries that stay trainable during the following sparse updates."""
        if isinstance(reg, regularizers.reg_l1_l2_conv):
            # Group sparsity: keep whole kernels K_ij with ||K_ij||_2 > 0.
            norm_per_kernel = torch.norm(param.view(param.shape[0], param.shape[1], -1), p=2, dim=2)
            mask = (norm_per_kernel > 0.0).float()
            return mask[:, :, None, None]  # shape: (out_channels, in_channels, 1, 1)
        return param.clone().detach() != 0

    
    
class LinBregSparseML(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-2, delta=1.0, reg=None,
                 eps=1e-3, kappa=0.8):
        """
        LinBreg algorithm that primarily updates only the non-zero parameters.
        All parameters of a tensor are updated on the next iteration when a
        multilevel-style condition holds. That happens when the restricted
        (masked) gradient norm is smaller than ``kappa`` times the last full
        gradient norm, or when the restricted gradient norm divided by the
        number of entries falls below ``eps``.

        Parameters
        ----------
        params : Model parameters or parameter-group dicts.
        lr : Float, optional
            Learning rate for the optimizer. The default is 1e-2.
        delta : Float, optional
            Smoothing parameter used in the regulariser. The default is 1.0.
        reg : Regularizer, optional
            Choice of regularizer to be used in the proximal step.
            The default (None) creates a ``regularizers.reg_none()``.
        eps : Float in (0,1), optional
            Threshold on ``||restricted grad|| / numel``; below it the next
            update is full. The default is 1e-3.
        kappa : Float in (0,1), optional
            Leniency of the coarse-update condition: the next update is full
            when ``||restricted grad|| < kappa * ||full grad||``.
            The default is 0.8.
        """
        if reg is None:
            reg = regularizers.reg_none()
        defaults = dict(lr=lr, delta=delta, reg=reg,
                        eps=eps, kappa=kappa)
        super().__init__(params, defaults)

        # Initialise states immediately; the first update of every parameter is full.
        for group in self.param_groups:
            delta = group['delta']
            reg = group['reg']
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                # Per-step record of whether the update was full (True) or sparse.
                state['sparse_flag'] = []
                state['next_update_full'] = True
                state['sub_grad'] = self.initialize_sub_grad(p, reg, delta)

    def update_hook(self, param, mask=None):
        """(Re)register a gradient hook on ``param`` that multiplies its gradient by ``mask``.

        ``mask`` defaults to ``param != 0``. Parameters that do not require
        grad are left untouched.
        """
        if param.requires_grad:
            if mask is None:
                mask = param.clone().detach() != 0

            handle = self.state[param].get('hook_handle')
            if handle:  # Remove existing hook
                handle.remove()

            # Use the module-level perform_mask: it is properly initialised and,
            # unlike a class defined inside this method, can be pickled when the
            # hook handle stored in the optimiser state is saved.
            self.state[param]['hook_handle'] = param.register_hook(perform_mask(mask))

    def remove_hook(self, param):
        """Remove the gradient-masking hook of ``param``, if one was registered."""
        handle = self.state[param].get('hook_handle')
        if handle:
            handle.remove()

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one full or sparse LinBreg update per parameter and decide the next mode.

        Parameters
        ----------
        closure : callable, optional
            Re-evaluates the model and returns the loss; called with autograd
            enabled before the update, as in ``torch.optim``.

        Returns
        -------
        The value returned by ``closure``, or None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            delta = group['delta']
            lr = group['lr']
            reg = group['reg']
            eps = group['eps']
            kappa = group['kappa']
            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                state = self.state[param]

                if param.requires_grad:
                    full_update = state['next_update_full']  # decided on the previous step

                    sub_grad = state['sub_grad']
                    sub_grad.add_(-lr * grad)  # Hooks have masked gradient implicitly
                    if full_update:
                        param.copy_(reg.prox(delta * sub_grad, delta))

                        mask = param.clone().detach() != 0
                        full_grad_norm = torch.norm(grad)
                        state['grad_norm'] = full_grad_norm
                        restricted_grad_norm = torch.norm(grad * mask)
                    else:
                        # Naive implementation of restricted subgrad
                        param.copy_(reg.prox(delta * sub_grad * state['mask'], delta))

                        # Compare against the norm recorded at the last full update.
                        full_grad_norm = state['grad_norm']
                        restricted_grad_norm = grad.norm()

                    next_update_full = bool(
                        restricted_grad_norm < kappa * full_grad_norm
                        or restricted_grad_norm / grad.numel() < eps
                    )

                    # Current update | Next update | Hook modification
                    #   Full         |   Full      |   Nothing
                    #   Sparse       |   Sparse    |   Nothing
                    #   Full         |   Sparse    |   Redefine hook
                    #   Sparse       |   Full      |   Remove hook
                    if full_update and not next_update_full:
                        self.update_hook(param, mask)
                        state['mask'] = mask
                    elif not full_update and next_update_full:
                        self.remove_hook(param)

                    state['next_update_full'] = next_update_full
                    state['sparse_flag'].append(full_update)
                state['step'] += 1

        return loss

    def initialize_sub_grad(self, param, reg, delta):
        """Return the initial subgradient ``param / delta + reg.sub_grad(param)``."""
        p = param.data.clone()
        return 1 / delta * p + reg.sub_grad(p)

    @torch.no_grad()
    def evaluate_reg(self):
        """Return a list with, per parameter group, the sum of ``group['reg'](p)``."""
        reg_vals = []
        for group in self.param_groups:
            reg = group['reg']
            group_reg_val = 0.0
            for p in group['params']:
                group_reg_val += reg(p)
            reg_vals.append(group_reg_val)
        return reg_vals


#%% ------------------------------------------------------------------------------------------------------    
class ProxSGD(torch.optim.Optimizer):
    """
    Proximal stochastic gradient descent: a gradient step followed by the
    regulariser's proximal operator, ``p <- reg.prox(p - lr * grad, lr)``.

    Taken from the BregmanLearning repo:
        https://github.com/TimRoith/BregmanLearning

    Parameters
    ----------
    params : iterable
        Parameters or parameter-group dicts to optimise.
    lr : float, optional
        Step size; must be non-negative. Also passed to ``reg.prox``.
    reg : regulariser, optional
        Object providing ``prox(x, lr)`` and ``__call__(x)``. ``None`` (the
        default) creates a ``regularizers.reg_none()`` for this optimiser.
    """

    def __init__(self, params, lr=1e-3, reg=None):
        if lr < 0.0:
            raise ValueError("Invalid learning rate")
        if reg is None:
            reg = regularizers.reg_none()

        defaults = dict(lr=lr, reg=reg)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one proximal gradient step; return the closure's loss (or None)."""
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            reg = group['reg']
            step_size = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                state['step'] = state.get('step', 0) + 1

                # gradient step
                p.data.add_(-step_size * p.grad)
                # proximal step
                p.data = reg.prox(p.data, step_size)

        return loss

    @torch.no_grad()
    def evaluate_reg(self):
        """Return a list with, per parameter group, the sum of ``group['reg'](p)``."""
        reg_vals = []
        for group in self.param_groups:
            reg = group['reg']
            group_reg_val = 0.0
            for p in group['params']:
                group_reg_val += reg(p)
            reg_vals.append(group_reg_val)
        return reg_vals
                   
#%% ------------------------------------------------------------------------------------------------------           
class AdaBreg(torch.optim.Optimizer):
    """
    Adam-preconditioned linearised Bregman iteration (AdaBreg).

    The subgradient ``v`` (initialised as ``p / delta + reg.sub_grad(p)``)
    is updated with an Adam step, and the parameter is then set to
    ``reg.prox(delta * v, delta)``.

    Taken from the BregmanLearning repo:
        https://github.com/TimRoith/BregmanLearning

    Parameters
    ----------
    params : iterable
        Parameters or parameter-group dicts to optimise.
    lr : float, optional
        Learning rate; must be non-negative.
    reg : regulariser, optional
        Object providing ``prox(x, delta)``, ``sub_grad(x)`` and
        ``__call__(x)``. ``None`` (the default) creates a
        ``regularizers.reg_none()`` for this optimiser.
    delta : float, optional
        Scaling used in the subgradient initialisation and in the prox call.
    betas : tuple of float, optional
        Decay rates of the first and second moment estimates.
    eps : float, optional
        Term added to the denominator for numerical stability.
    """

    def __init__(self, params, lr=1e-3, reg=None, delta=1.0, betas=(0.9, 0.999), eps=1e-8):
        if lr < 0.0:
            raise ValueError("Invalid learning rate")
        if reg is None:
            reg = regularizers.reg_none()

        defaults = dict(lr=lr, reg=reg, delta=delta, betas=betas, eps=eps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one AdaBreg update; return the closure's loss (or None)."""
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            delta = group['delta']
            reg = group['reg']
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                if len(state) == 0:
                    # Lazy initialisation on the first step that sees a gradient.
                    state['step'] = 0
                    state['sub_grad'] = self.initialize_sub_grad(p, reg, delta)
                    state['exp_avg'] = torch.zeros_like(state['sub_grad'])
                    state['exp_avg_sq'] = torch.zeros_like(state['sub_grad'])

                state['step'] += 1
                step = state['step']
                sub_grad = state['sub_grad']
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']

                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step

                # Decay the first and second moment running averages.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                step_size = lr / bias_correction1

                sub_grad.addcdiv_(exp_avg, denom, value=-step_size)

                # Map the subgradient back to parameter space.
                p.data = reg.prox(delta * sub_grad, delta)

        return loss

    def initialize_sub_grad(self, p, reg, delta):
        """Return the initial subgradient ``p / delta + reg.sub_grad(p)``."""
        p_init = p.data.clone()
        return 1 / delta * p_init + reg.sub_grad(p_init)

    @torch.no_grad()
    def evaluate_reg(self):
        """Return a list with, per parameter group, the sum of ``group['reg'](p)``."""
        reg_vals = []
        for group in self.param_groups:
            reg = group['reg']
            group_reg_val = 0.0
            for p in group['params']:
                group_reg_val += reg(p)
            reg_vals.append(group_reg_val)
        return reg_vals
 

#%% Constructors of optimizers
def calc_full_update_freq(conf, train_split = 0.9):
    """
    Convert the config's full-update settings into numbers of training steps.

    Two modes are supported:
        - step mode  : ``conf.full_update_mode == 'step'``. The frequency and
                       duration are already in training steps (batches) and
                       are returned unchanged.
        - epoch mode : any other mode. The frequency and duration are given
                       in epochs and are converted to steps using the number
                       of batches in one epoch of the training split.

    Returns
    -------
    (full_update_frequency, full_update_duration), both in training steps.

    Raises
    ------
    NotImplementedError
        In epoch mode, if the size of ``conf.dataset`` is not known.
    """
    if conf.full_update_mode == 'step':
        return conf.full_update_frequency, conf.full_update_duration

    # Number of samples in each supported dataset's training set.
    train_set_sizes = {'MNIST': 60000, 'CIFAR10': 50000}
    if conf.dataset not in train_set_sizes:
        raise NotImplementedError('Write code that tells me what size the whole dataset is')

    n_samples = train_set_sizes[conf.dataset]
    if conf.max_samples:
        n_samples = min(n_samples, conf.max_samples)
    batches_per_epoch = math.ceil(n_samples * train_split / conf.batch_size)

    frequency_steps = (conf.full_update_frequency - 1) * batches_per_epoch + 1
    duration_steps = conf.full_update_duration * batches_per_epoch
    return frequency_steps, duration_steps
    

def _breg_param_groups(conf, model):
    """Split ``model``'s parameters and build the regularisers shared by the Bregman optimisers.

    Returns
    -------
    tuple
        ``(weights_conv, weights_linear, weights_batch, biases, regularizer, reg_conv)``.
        ``regularizer`` is ``regularizers.get_reg(conf.reg, conf.lambda1)``.
        ``reg_conv`` is ``reg_l1_l2_conv`` when ``conf.conv_group`` is set and
        ``reg_l1`` otherwise, both built with ``lamda=conf.lambda0``.
    """
    weights_conv = utils.get_weights_conv(model)
    weights_linear = utils.get_weights_linear(model)
    weights_batch = utils.get_weights_batch(model)
    biases = utils.get_bias(model)
    regularizer = regularizers.get_reg(conf.reg, conf.lambda1)
    if conf.conv_group:
        reg_conv = regularizers.reg_l1_l2_conv(lamda=conf.lambda0)
    else:
        reg_conv = regularizers.reg_l1(lamda=conf.lambda0)
    return weights_conv, weights_linear, weights_batch, biases, regularizer, reg_conv


def get_opt(conf, model, opt_state_dict=None, scheduler_state_dict=None):
    """
    Prepare model for training. Return the specified optimiser and LR scheduler.

    All optimisers take the standard approach of having parameters as an
    input, rather than the entire model.

    Returns
    -------
    (opt, scheduler)
        ``opt`` is None when ``conf.optim == "debug"``. ``scheduler`` is None
        when there is no optimiser or when ``conf.lr_scheduler`` is neither
        'CosineAnnealing' nor 'ReduceLROnPlateau'.

    Raises
    ------
    NotImplementedError
        If ``conf.optim`` is not recognised.
    """
    if conf.optim == "SGD":
        opt = torch.optim.SGD(model.parameters(), lr=conf.learning_rate,
                              momentum=conf.momentum)
    elif conf.optim == "SGD-sparse":
        opt = torch.optim.SGD(model.parameters(), lr=conf.learning_rate)
        utils.make_optim_sparse(model, opt, conf.full_update_frequency)
    elif conf.optim == "Adam":
        opt = torch.optim.Adam(model.parameters(), lr=conf.learning_rate)
    elif conf.optim == 'ProxSGD':
        regularizer = regularizers.get_reg(conf.reg, conf.lambda1)
        opt = ProxSGD(model.parameters(), lr=conf.learning_rate, reg=regularizer)
    elif conf.optim == 'LinBreg':
        w_conv, w_lin, w_bn, biases, reg_lin, reg_conv = _breg_param_groups(conf, model)
        opt = LinBreg([
            {'params': w_conv, 'lr': conf.learning_rate, 'reg': reg_conv, 'delta': conf.delta, 'momentum': conf.momentum},
            {'params': w_lin, 'lr': conf.learning_rate, 'reg': reg_lin, 'delta': conf.delta, 'momentum': conf.momentum},
            {'params': w_bn, 'lr': conf.learning_rate, 'momentum': conf.momentum},
            {'params': biases, 'lr': conf.learning_rate, 'momentum': conf.momentum},
        ])
    elif conf.optim == "LinBregSparse":
        w_conv, w_lin, w_bn, biases, reg_lin, reg_conv = _breg_param_groups(conf, model)
        full_update_frequency, full_update_duration = calc_full_update_freq(conf)
        schedule = {'full_update_frequency': full_update_frequency,
                    'full_update_duration': full_update_duration}
        opt = LinBregSparse([
            {'params': w_conv, 'lr': conf.learning_rate, 'reg': reg_conv, 'delta': conf.delta, 'momentum': conf.momentum, **schedule},
            {'params': w_lin, 'lr': conf.learning_rate, 'reg': reg_lin, 'delta': conf.delta, 'momentum': conf.momentum, **schedule},
            {'params': w_bn, 'lr': conf.learning_rate, 'momentum': conf.momentum},
            {'params': biases, 'lr': conf.learning_rate, 'momentum': conf.momentum},
        ])
    elif conf.optim == "LinBregSparseML":
        w_conv, w_lin, w_bn, biases, reg_lin, reg_conv = _breg_param_groups(conf, model)
        # Batch-norm weights are included like in every other Bregman
        # optimiser; without this group they would never be updated.
        opt = LinBregSparseML([
            {'params': w_conv, 'lr': conf.learning_rate, 'reg': reg_conv, 'delta': conf.delta, 'kappa': conf.kappa, 'eps': conf.eps},
            {'params': w_lin, 'lr': conf.learning_rate, 'reg': reg_lin, 'delta': conf.delta, 'kappa': conf.kappa, 'eps': conf.eps},
            {'params': w_bn, 'lr': conf.learning_rate, 'kappa': conf.kappa, 'eps': conf.eps},
            {'params': biases, 'lr': conf.learning_rate, 'kappa': conf.kappa, 'eps': conf.eps},
        ])
    elif conf.optim == "AdaBreg":
        w_conv, w_lin, w_bn, biases, reg_lin, reg_conv = _breg_param_groups(conf, model)
        # The linear group honours conf.reg like the other Bregman optimisers.
        opt = AdaBreg([
            {'params': w_conv, 'lr': conf.learning_rate, 'reg': reg_conv, 'delta': conf.delta},
            {'params': w_lin, 'lr': conf.learning_rate, 'reg': reg_lin, 'delta': conf.delta},
            {'params': w_bn, 'lr': conf.learning_rate},
            {'params': biases, 'lr': conf.learning_rate},
        ])
    elif conf.optim == 'AdaBregSparse':
        w_conv, w_lin, w_bn, biases, reg_lin, reg_conv = _breg_param_groups(conf, model)
        # AdaBregSparse has no notion of update duration; only the frequency is used.
        full_update_frequency, _ = calc_full_update_freq(conf)
        opt = AdaBregSparse([
            {'params': w_conv, 'lr': conf.learning_rate, 'reg': reg_conv, 'delta': conf.delta, 'full_update_frequency': full_update_frequency},
            {'params': w_lin, 'lr': conf.learning_rate, 'reg': reg_lin, 'delta': conf.delta, 'full_update_frequency': full_update_frequency},
            {'params': w_bn, 'lr': conf.learning_rate},
            {'params': biases, 'lr': conf.learning_rate},
        ])
    elif conf.optim == "debug":
        opt = None
    else:
        raise NotImplementedError(f'Choice of optimiser "{conf.optim}" not recognised')

    # Determine scheduler. Default to None so an unrecognised/absent scheduler
    # (or the optimiser-less debug mode) does not leave `scheduler` unbound.
    scheduler = None
    if opt is not None:
        if conf.lr_scheduler == 'CosineAnnealing':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=200, eta_min=1e-5)
        elif conf.lr_scheduler == 'ReduceLROnPlateau':
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=5, threshold=1e-4, min_lr=1e-6)

    # Load states from checkpoints if provided
    if opt_state_dict and opt is not None:
        opt.load_state_dict(opt_state_dict)
    if scheduler_state_dict and scheduler is not None:
        scheduler.load_state_dict(scheduler_state_dict)

    return opt, scheduler

