import json
import os
import pickle

from functools import partial
from typing import Dict

import pandas as pd

from method.Optimizers.custom_apgo import AProx, AProx_NoAdaptive, AProx_NoMomentum, AProx_NoProx
from torch.optim import Adam, Adadelta, Adagrad, AdamW, SparseAdam, ASGD, SGD, LBFGS, RMSprop, Rprop

################################# Wrappers ################################################

# Optimizer names (values of ``args.optimizer``) that are constructed with the
# AProx-specific hyperparameters and honour ``args.weight_decay``.
_APROX_NAMES = frozenset({'AProx', 'AProx_NoProx', 'AProx_NoAdaptive', 'AProx_NoMomentum'})

# Ablation variant: built from the plain AProx class but with weight decay
# forced to 0.
_APROX_NO_WEIGHT_DECAY = 'AProx_NoWeightDecay'


def _optimizer_kwargs(name, args):
    """Return the constructor keyword arguments for optimizer ``name``.

    The choice depends on the optimizer *name*, not its class, because
    ``AProx_NoWeightDecay`` and ``AProx`` share the same class.

    Args:
        name: The optimizer key, i.e. ``args.optimizer``.
        args: Namespace providing ``lr`` and, for the names in
            ``_APROX_NAMES``, ``weight_decay``.

    Returns:
        A dict of keyword arguments (everything except the parameters).
    """
    if name in _APROX_NAMES or name == _APROX_NO_WEIGHT_DECAY:
        return {
            'lr': args.lr,
            'alpha': 0.99,
            'beta': 0.9,
            'epsilon': 1e-8,
            'weight_decay': 0 if name == _APROX_NO_WEIGHT_DECAY else args.weight_decay,
        }
    # Other optimizers receive only ``lr``; remaining hyperparameters keep
    # their constructor defaults.
    return {'lr': args.lr}


def optimizer_wrapper(args, pred_model_parameters):
    """Build the optimizer selected by ``args.optimizer``.

    Args:
        args: Namespace with ``optimizer`` (a key of the table below) and
            ``lr``; the AProx variants (except ``AProx_NoWeightDecay``) also
            read ``weight_decay``.
        pred_model_parameters: Parameters (or parameter groups) to optimize,
            passed through as the first constructor argument.

    Returns:
        The instantiated optimizer.

    Raises:
        KeyError: If ``args.optimizer`` is not a known optimizer name.
    """
    optimizer_dict = {
        'AProx': AProx,
        'AProx_NoProx': AProx_NoProx,
        'AProx_NoAdaptive': AProx_NoAdaptive,
        'AProx_NoMomentum': AProx_NoMomentum,
        'AProx_NoWeightDecay': AProx,
        'Adam': Adam,
        'Adadelta': Adadelta,
        'Adagrad': Adagrad,
        'AdamW': AdamW,
        'RMSpp': RMSprop,  # original key, kept for backward compatibility
        'RMSprop': RMSprop,
        'Rpp': Rprop,  # original key, kept for backward compatibility
        'Rprop': Rprop,
        'SparseAdam': SparseAdam,
        'SGD': SGD,
    }
    try:
        optimizer_cls = optimizer_dict[args.optimizer]
    except KeyError:
        raise KeyError(
            f"Unknown optimizer {args.optimizer!r}; "
            f"expected one of {sorted(optimizer_dict)}"
        ) from None
    return optimizer_cls(pred_model_parameters, **_optimizer_kwargs(args.optimizer, args))
