from __future__ import division


import chainer
from chainer.training import extension
from chainer.training import extensions
# Module-level side effect: importing this file mutates chainer's global
# configuration for the whole process.
# NOTE(review): per chainer's configuration docs these flags presumably
# request deterministic cuDNN kernels and switch off cuDNN autotuning --
# confirm against the chainer version in use.
chainer.config.cudnn_deterministic = True
chainer.config.autotune = False
import ipdb


def LinearShift(optimizer, opt_config, decay_config, **kwargs):
    """Build a chainer ``LinearShift`` extension from configuration mappings.

    Args:
        optimizer: Forwarded as the ``optimizer`` keyword of
            ``extensions.LinearShift``.
        opt_config: Mapping read for ``'decay_factor'`` (the attribute name
            handed to the extension) and for ``'args'``, which must hold the
            starting value under that same attribute name.
        decay_config: Mapping whose ``'args'`` entry supplies
            ``'end_factor_value'``, ``'start_iter'`` and ``'end_iter'``.
        **kwargs: Accepted and ignored.

    Returns:
        tuple: ``(extension, trigger)`` where ``trigger`` is always ``None``.

    Raises:
        KeyError: If a required key is missing (for plain ``dict`` configs).
    """
    attr_name = opt_config['decay_factor']
    initial_value = opt_config['args'][attr_name]

    schedule = decay_config['args']
    value_range = (initial_value, schedule['end_factor_value'])
    iter_range = (schedule['start_iter'], schedule['end_iter'])

    shift_ext = extensions.LinearShift(
        attr_name, value_range, iter_range, optimizer=optimizer)

    # This schedule prescribes no trigger of its own; the caller receives None.
    return shift_ext, None
    
    
def ExponentialShift(optimizer, opt_config, decay_config, **kwargs):
    """Build a chainer ``ExponentialShift`` extension and its trigger.

    Args:
        optimizer: Forwarded as the ``optimizer`` keyword of
            ``extensions.ExponentialShift``.
        opt_config: Mapping read for ``'decay_factor'``, the attribute name
            handed to the extension.
        decay_config: Mapping whose ``'args'`` entry supplies
            ``'decay_ratio'`` (passed as the extension's second positional
            argument) and ``'ld_interval'`` (used for the trigger).
        **kwargs: Accepted and ignored.

    Returns:
        tuple: ``(extension, trigger)`` where ``trigger`` is
        ``(ld_interval, 'iteration')``.

    Raises:
        KeyError: If a required key is missing (for plain ``dict`` configs).
    """
    attr_name = opt_config['decay_factor']

    schedule = decay_config['args']
    rate = schedule['decay_ratio']
    interval = schedule['ld_interval']

    shift_ext = extensions.ExponentialShift(
        attr_name, rate, optimizer=optimizer)

    # Trigger tuple in (period, unit) form, returned for the caller to use
    # when registering the extension.
    return shift_ext, (interval, 'iteration')
    
    
def in_middle_modification(trainer, module, newkwargs):
    """Create a trainer extension that overwrites the updater's penalty weight.

    Args:
        trainer: Not used by this factory; the returned extension acts on
            whichever trainer it is later called with.
        module: Currently unused (see the NOTE inside the extension body).
        newkwargs: Mapping read for ``'penalty_weight'`` each time the
            extension runs.

    Returns:
        The ``make_change`` function wrapped by
        ``chainer.training.make_extension()``.
    """
    @chainer.training.make_extension()
    def make_change(trainer):
        # Patch the live updater in place; this approach is known to work.
        new_weight = newkwargs['penalty_weight']
        trainer.updater.penalty_weight = new_weight

        # NOTE: the alternative of constructing a fresh updater via
        # ``module(**newkwargs)`` and assigning it to ``trainer.updater`` was
        # tried and somehow froze the program, so it is deliberately not done.

    return make_change
