import numpy as np

from torch.optim import Adam
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import StepLR

def optimizer_factory(config):
    """Build an optimizer wrapper from ``config['optim']``.

    The returned wrapper only stores hyper-parameters; the torch optimizer
    (and scheduler, if any) is created later by calling its
    ``instantiate_optimizer(net_parameters)`` method.

    Args:
        config: Mapping containing an ``'optim'`` sub-mapping whose
            ``'optimizer_type'`` entry selects the wrapper. The other keys
            read from ``config['optim']`` depend on the selected type:

            - ``'adam_wo_scheduler'``, ``'adamw_wo_scheduler'``:
              ``lrate``, ``betas``, ``weight_decay``
            - ``'adam_w_warmup_to_constant_scheduler'``:
              ``lrate``, ``betas``, ``weight_decay``, ``warmup_steps``
            - ``'adam_w_warmup_scheduler'``:
              ``max_lr``, ``betas``, ``weight_decay``, ``warmup_steps``
            - ``'adam_w_decay_scheduler'``:
              ``lrate``, ``betas``, ``step_size``, ``gamma``, ``weight_decay``

    Returns:
        An un-instantiated optimizer wrapper object.

    Raises:
        ValueError: if ``'optim'`` / ``'optimizer_type'`` is missing, or the
            optimizer type is not recognised.
        KeyError: if a hyper-parameter required by the selected type is
            missing from ``config['optim']``.
    """
    if 'optim' not in config or 'optimizer_type' not in config['optim']:
        raise ValueError("config has no key 'optimizer_type'")

    optim_cfg = config['optim']
    optimizer_type = optim_cfg['optimizer_type']

    if optimizer_type == 'adam_wo_scheduler':
        return AdamOptimizerWithNoScheduler(
            lrate=optim_cfg['lrate'],
            betas=optim_cfg['betas'],
            weight_decay=optim_cfg['weight_decay']
        )
    if optimizer_type == 'adamw_wo_scheduler':
        return AdamWOptimizerWithNoScheduler(
            lrate=optim_cfg['lrate'],
            betas=optim_cfg['betas'],
            weight_decay=optim_cfg['weight_decay']
        )
    if optimizer_type == 'adam_w_warmup_to_constant_scheduler':
        return AdamWOptimizerWithWarmupToConstantScheduler(
            lrate=optim_cfg['lrate'],
            betas=optim_cfg['betas'],
            weight_decay=optim_cfg['weight_decay'],
            warmup_steps=optim_cfg['warmup_steps']
        )
    if optimizer_type == 'adam_w_warmup_scheduler':
        return AdamWOptimizerWithWarmupScheduler(
            max_lr=optim_cfg['max_lr'],
            betas=optim_cfg['betas'],
            weight_decay=optim_cfg['weight_decay'],
            warmup_steps=optim_cfg['warmup_steps']
        )
    if optimizer_type == 'adam_w_decay_scheduler':
        return AdamWOptimizerWithDecayScheduler(
            lrate=optim_cfg['lrate'],
            betas=optim_cfg['betas'],
            step_size=optim_cfg['step_size'],
            gamma=optim_cfg['gamma'],
            weight_decay=optim_cfg['weight_decay']
        )
    # Include the offending value so misconfigurations are easy to spot.
    raise ValueError('Unknown optimizer_type: {!r}'.format(optimizer_type))

class AdamOptimizerWithNoScheduler(object):
    """Deferred ``torch.optim.Adam`` wrapper with no learning-rate scheduler.

    Hyper-parameters are captured at construction time; the actual optimizer
    is only built once ``instantiate_optimizer`` receives the parameters.
    ``_scheduler`` is always ``None`` for this variant.
    """

    def __init__(self, lrate, betas, weight_decay):
        # Hyper-parameters forwarded verbatim to Adam.
        self._lrate, self._betas, self._weight_decay = lrate, betas, weight_decay

        # Filled in by instantiate_optimizer().
        self._optimizer = self._scheduler = None

    def instantiate_optimizer(self, net_parameters):
        """Create the Adam optimizer over *net_parameters* (no scheduler)."""
        adam_kwargs = dict(
            lr=self._lrate,
            betas=self._betas,
            weight_decay=self._weight_decay,
        )
        self._optimizer = Adam(net_parameters, **adam_kwargs)
        self._scheduler = None

class AdamWOptimizerWithNoScheduler(object):
    """Deferred ``torch.optim.AdamW`` wrapper with no learning-rate scheduler.

    Stores hyper-parameters on construction and builds the optimizer later in
    ``instantiate_optimizer``. ``_scheduler`` is always ``None``.
    """

    def __init__(self, lrate, betas, weight_decay):
        # Hyper-parameters forwarded verbatim to AdamW.
        self._lrate, self._betas, self._weight_decay = lrate, betas, weight_decay

        # Filled in by instantiate_optimizer().
        self._optimizer = self._scheduler = None

    def instantiate_optimizer(self, net_parameters):
        """Create the AdamW optimizer over *net_parameters* (no scheduler)."""
        adamw_kwargs = dict(
            lr=self._lrate,
            betas=self._betas,
            weight_decay=self._weight_decay,
        )
        self._optimizer = AdamW(net_parameters, **adamw_kwargs)
        self._scheduler = None

class AdamWOptimizerWithWarmupToConstantScheduler(object):
    """AdamW with a linear warmup to a constant learning rate.

    At scheduler step ``s`` the learning rate is
    ``lrate * min((s + 1) / warmup_steps, 1.0)``: it ramps linearly from
    ``lrate / warmup_steps`` at step 0 to ``lrate`` at step
    ``warmup_steps - 1`` and stays constant afterwards.
    """

    def __init__(self, lrate, betas, weight_decay, warmup_steps):
        self._lrate = lrate
        self._betas = betas
        self._weight_decay = weight_decay
        self._warmup_steps = warmup_steps

        # Filled in by instantiate_optimizer().
        self._optimizer = None
        self._scheduler = None

    def instantiate_optimizer(self, net_parameters):
        """Create the AdamW optimizer and its warmup LambdaLR scheduler.

        Args:
            net_parameters: iterable of parameters (or param groups) passed
                straight to ``torch.optim.AdamW``.
        """
        def lrate_fn(step_num):
            # Multiplicative factor applied by LambdaLR to the base lr.
            return min(1.0 * (step_num + 1) / self._warmup_steps, 1.0)

        # Base lr must be the configured lrate (previously an undefined name
        # `lrate`, which made this method always raise NameError).
        self._optimizer = AdamW(
            net_parameters, lr=self._lrate, betas=self._betas, weight_decay=self._weight_decay)
        self._scheduler = LambdaLR(self._optimizer, lrate_fn)

class AdamWOptimizerWithDecayScheduler(object):
    """AdamW with step-wise learning-rate decay (``torch.optim.lr_scheduler.StepLR``).

    The learning rate starts at ``lrate`` and is multiplied by ``gamma`` every
    ``step_size`` scheduler steps.
    """

    def __init__(
            self, lrate, betas, weight_decay, step_size, gamma):
        self._lrate = lrate
        self._betas = betas
        self._step_size = step_size
        self._gamma = gamma
        self._weight_decay = weight_decay

        # Filled in by instantiate_optimizer().
        self._optimizer = None
        self._scheduler = None

    def instantiate_optimizer(self, net_parameters):
        """Create the AdamW optimizer and its StepLR scheduler.

        Args:
            net_parameters: iterable of parameters (or param groups) passed
                straight to ``torch.optim.AdamW``.
        """
        self._optimizer = AdamW(
            net_parameters, lr=self._lrate, betas=self._betas, weight_decay=self._weight_decay)
        # Use the stored step size (previously an undefined name `step_size`,
        # which made this method always raise NameError).
        self._scheduler = StepLR(self._optimizer, step_size=self._step_size, gamma=self._gamma)

class AdamWOptimizerWithWarmupScheduler(object):
    """AdamW with linear warmup followed by inverse-square-root decay.

    The base lr handed to AdamW is ``max_lr * sqrt(warmup_steps)``, and the
    LambdaLR factor at scheduler step ``s`` is
    ``min((s + 1) ** -0.5, (s + 1) * warmup_steps ** -1.5)``. The product
    therefore rises linearly to exactly ``max_lr`` at step
    ``warmup_steps - 1`` and decays as ``max_lr * sqrt(warmup_steps / (s + 1))``
    afterwards.
    """

    def __init__(self, max_lr, betas, weight_decay, warmup_steps):
        self._max_lr = max_lr
        self._betas = betas
        self._weight_decay = weight_decay
        self._warmup_steps = warmup_steps

        # Filled in by instantiate_optimizer().
        self._optimizer, self._scheduler = None, None

    def instantiate_optimizer(self, net_parameters):
        """Create the AdamW optimizer and its warmup/decay LambdaLR scheduler."""
        def lrate_fn(step_num):
            t = step_num + 1
            decay = np.power(t, -0.5)
            # warmup_steps is read at call time, as in a closure over self.
            ramp = t * np.power(self._warmup_steps, -1.5)
            return np.minimum(decay, ramp)

        # Dividing by warmup_steps ** -0.5 scales the base so the peak equals max_lr.
        base_lr = self._max_lr / np.power(self._warmup_steps, -0.5)

        self._optimizer = AdamW(
            net_parameters,
            lr=base_lr,
            betas=self._betas,
            weight_decay=self._weight_decay,
        )
        self._scheduler = LambdaLR(self._optimizer, lrate_fn)
