import os
import sys
sys.path.insert(0, './')
import numpy as np

import torch
import torch.nn as nn

# from fastDP import PrivacyEngine

# Help text printed by parse_optim when policy['name'] is 'h' or 'help'.
# Each line shows the keys parse_optim reads for that optimizer; the values
# wrapped in $...$ for momentum/dampening/betas/eps/weight_decay/amsgrad match
# the fallbacks parse_optim uses when the key is absent, while lr has no
# fallback and must be supplied. NOTE(review): the 'dpsgd' option handled by
# parse_optim is not listed here; the code that parses the
# "key=value,key=value" policy string is presumably elsewhere — confirm.
h_message = '''
instructions for setting an optimizer
>>> SGD
name=sgd,lr=$LR$,momentum=$0.9$,dampening=$0$,weight_decay=$0$

>>> Adam
name=adam,lr=$LR$,beta1=$0.9$,beta2=$0.999$,eps=$1e-8$,weight_decay=$0$,amsgrad=$0$
'''

def parse_optim(policy, params, model=None):
    """Build a ``torch.optim`` optimizer from a policy mapping.

    Args:
        policy: mapping of optimizer settings. ``policy['name']`` selects the
            optimizer (case-insensitive): ``sgd``, ``adam``, ``dpsgd``, or
            ``h``/``help`` to print ``h_message`` and exit. ``policy['lr']``
            is required for every optimizer; other keys are optional and fall
            back to the defaults listed in ``h_message``. If both
            ``weight_decay`` and a truthy ``not_wd_bn`` are present, parameters
            whose name contains ``'bn'`` or ``'bias'`` get no weight decay.
        params: iterable of ``(name, parameter)`` pairs, e.g. the result of
            ``model.named_parameters()``.
        model: module handed to the privacy engine; required for ``dpsgd``.

    Returns:
        The constructed optimizer.

    Raises:
        NotImplementedError: if ``policy['name']`` is not a supported
            optimizer, or is ``dpsgd`` and ``model`` is ``None``.
        SystemExit: after printing help for ``h``/``help``.
    """
    name = policy['name'].lower()
    if name in ('h', 'help'):
        print(h_message)
        # sys.exit rather than the `exit` builtin: the latter is injected by
        # the `site` module and is absent under `python -S` / frozen builds.
        sys.exit(0)

    if 'weight_decay' in policy and policy.get('not_wd_bn'):
        # Two param groups: regular weights are decayed, batch-norm and bias
        # parameters (identified by substring match on the name) are not.
        decay, no_decay = [], []
        for param_name, param in params:
            if 'bn' in param_name or 'bias' in param_name:
                no_decay.append(param)
            else:
                decay.append(param)
        params = [{'params': decay, 'weight_decay': policy['weight_decay']},
                  {'params': no_decay, 'weight_decay': 0}]
        print('not using wd for bn layers')
    else:
        params = [param for _, param in params]

    if name == 'sgd':
        optimizer = torch.optim.SGD(
            params,
            lr=policy['lr'],
            momentum=policy.get('momentum', 0.9),
            dampening=policy.get('dampening', 0),
            weight_decay=policy.get('weight_decay', 0),
        )
    elif name == 'adam':
        # amsgrad arrives as a number (e.g. 0/1); any non-negligible
        # magnitude enables it.
        amsgrad = 'amsgrad' in policy and np.abs(policy['amsgrad']) > 1e-6
        optimizer = torch.optim.Adam(
            params,
            lr=policy['lr'],
            betas=(policy.get('beta1', 0.9), policy.get('beta2', 0.999)),
            eps=policy.get('eps', 1e-8),
            weight_decay=policy.get('weight_decay', 0.),
            amsgrad=bool(amsgrad),
        )
    elif name == 'dpsgd':
        if model is None:
            # Same exception type as before, but name the actual problem
            # instead of claiming dpsgd is unsupported.
            raise NotImplementedError(
                'Optimizer %s requires a model' % policy['name'])
        optimizer = torch.optim.SGD(
            params,
            lr=policy['lr'],
            weight_decay=policy.get('weight_decay', 0),
        )
        # NOTE(review): PrivacyEngine is expected from fastDP, whose import
        # is commented out at the top of this file; this branch raises
        # NameError unless that import is restored.
        privacy_engine = PrivacyEngine(
            module=model,
            batch_size=policy['batch_size'],
            epochs=policy['epoch_num'],
            sample_size=policy.get('sample_size'),
            max_grad_norm=policy.get('max_grad_norm'),
            target_delta=policy.get('target_delta'),
            target_epsilon=policy.get('target_epsilon'),
        )
        privacy_engine.attach(optimizer)
    else:
        raise NotImplementedError('Optimizer %s is not supported' % policy['name'])

    return optimizer
