from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB, BernoulliNB, MultinomialNB
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier

from synthesizers.mwem import MWEMSynthesizer
from synthesizers.quail import QUAILSynthesizer
from synthesizers.pytorch.pytorch_synthesizer import PytorchDPSynthesizer
from synthesizers.preprocessors.preprocessing import GeneralTransformer
from synthesizers.pytorch.nn.all import DPGAN, PATEGAN, DPCTGAN, PATECTGAN

from diffprivlib.models import LogisticRegression as DPLR
from diffprivlib.models import GaussianNB as DPGNB

# Fixed random seed, kept constant so runs are reproducible.
# Supplied as 'random_state' for every model listed in MODEL_ARGS.
SEED: int = 42

# Toggle for balancing imbalanced data with SMOTE (True turns it on).
# NOTE(review): the code that reads this flag lives outside this file --
# confirm where the balancing is applied.
BALANCE: bool = True

# Synthesizers included in the evaluation, as (name, synthesizer class) pairs.
# Comment out a line to leave that synthesizer out of a run. Every name here
# has a matching top-level key in SYNTH_SETTINGS.
SYNTHESIZERS = [
    ('mwem', MWEMSynthesizer),
    ('dpctgan', PytorchDPSynthesizer),
    ('patectgan', PytorchDPSynthesizer),
    ('dpgan', PytorchDPSynthesizer),
    ('pategan', PytorchDPSynthesizer),
    ('quail_dpctgan', QUAILSynthesizer),
    ('quail_mwem', QUAILSynthesizer),
    ('quail_patectgan', QUAILSynthesizer),
    ('quail_dpgan', QUAILSynthesizer),
    ('quail_pategan', QUAILSynthesizer),
]

# Names of the datasets on which synthesis is evaluated; append new ones here.
KNOWN_DATASETS = [
    'bank',
    'adult',
    'mushroom',
    'shopping',
    'car',
]

# Classifier classes used to evaluate the utility of the synthetic data.
# Order matters: KNOWN_MODELS_STR is meant to list the same models by name.
KNOWN_MODELS = [
    AdaBoostClassifier,
    BaggingClassifier,
    LogisticRegression,
    MLPClassifier,
    RandomForestClassifier,
]

# Class names of the models in KNOWN_MODELS, used for logging.
# This list must stay index-aligned with KNOWN_MODELS (same length, same
# order): a stray 'GaussianNB' entry used to sit before
# 'RandomForestClassifier', which has no counterpart in KNOWN_MODELS or
# MODEL_ARGS and shifted the last label out of position.
KNOWN_MODELS_STR = [
    'AdaBoostClassifier',
    'BaggingClassifier',
    'LogisticRegression',
    'MLPClassifier',
    'RandomForestClassifier',
]

# Settings for each synthesizer, keyed first by synthesizer name (the same
# names used in SYNTHESIZERS) and then by dataset name. Every group carries a
# 'default' entry.
# NOTE(review): presumably 'default' is the fallback for datasets without an
# entry of their own -- confirm against the code that reads this table.
#
# The factories below build a *new* GAN / preprocessor object on every call,
# so no two entries share an instance.


def _dpctgan_args():
    """Return fresh PytorchDPSynthesizer settings wrapping a new DPCTGAN."""
    return {
        'preprocessor': None,
        'gan': DPCTGAN(epochs=300, loss='wasserstein', sigma=4.5),
    }


def _patectgan_args():
    """Return fresh PytorchDPSynthesizer settings wrapping a new PATECTGAN."""
    return {
        'preprocessor': None,
        'gan': PATECTGAN(epochs=300, sigma=4.5),
    }


def _dpgan_args():
    """Return fresh settings with a new GeneralTransformer and a new DPGAN."""
    return {
        'preprocessor': GeneralTransformer(),
        'gan': DPGAN(epochs=300),
    }


def _pategan_args():
    """Return fresh settings with a new GeneralTransformer and a new PATEGAN."""
    return {
        'preprocessor': GeneralTransformer(),
        'gan': PATEGAN(batch_size=640),
    }


def _quail_entry(dp_synthesizer, synth_args, target):
    """Build one QUAIL settings entry.

    All QUAIL entries share the DPLR classifier, empty classifier arguments
    and an 'eps_split' of 0.8; only the wrapped synthesizer, its settings and
    the target column name vary.
    """
    return {
        'dp_synthesizer': dp_synthesizer,
        'synth_args': synth_args,
        'dp_classifier': DPLR,
        'class_args': {},
        'target': target,
        'eps_split': 0.8,
    }


# (dataset, target column) pairs shared by every GAN-backed QUAIL group.
_QUAIL_GAN_TARGETS = (
    ('mushroom', 'edible'),
    ('bank', 'yesno'),
    ('shopping', 'Revenue'),
    ('adult', 'earning-class'),
    ('default', 'class'),
)


def _quail_gan_group(make_synth_args):
    """Build the per-dataset QUAIL entries for one GAN-backed synthesizer.

    ``make_synth_args`` is called once per dataset so that each entry owns
    its own GAN (and preprocessor) instance.
    """
    return {
        dataset: _quail_entry(PytorchDPSynthesizer, make_synth_args(), target)
        for dataset, target in _QUAIL_GAN_TARGETS
    }


SYNTH_SETTINGS = {
    'dpctgan': {'default': _dpctgan_args()},
    'patectgan': {'default': _patectgan_args()},
    'dpgan': {'default': _dpgan_args()},
    'pategan': {'default': _pategan_args()},
    'mwem': {
        'car': {
            'Q_count': 400, 'iterations': 20, 'mult_weights_iterations': 15,
            'split_factor': 7, 'max_bin_count': 400,
        },
        'mushroom': {
            'Q_count': 400, 'iterations': 30, 'mult_weights_iterations': 20,
            'split_factor': 4, 'max_bin_count': 400,
        },
        'bank': {
            'Q_count': 400, 'iterations': 25, 'mult_weights_iterations': 15,
            'split_factor': 3, 'max_bin_count': 200,
        },
        'adult': {
            'Q_count': 400, 'iterations': 20, 'mult_weights_iterations': 15,
            'splits': [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12], [13, 14]],
            'max_bin_count': 100,
        },
        'shopping': {
            'Q_count': 400, 'iterations': 30, 'mult_weights_iterations': 20,
            'split_factor': 2, 'max_bin_count': 400,
        },
        'default': {
            'Q_count': 400, 'iterations': 30, 'mult_weights_iterations': 20,
            'split_factor': 3, 'max_bin_count': 400,
        },
    },
    'quail_dpctgan': _quail_gan_group(_dpctgan_args),
    'quail_patectgan': _quail_gan_group(_patectgan_args),
    'quail_dpgan': _quail_gan_group(_dpgan_args),
    'quail_pategan': _quail_gan_group(_pategan_args),
    'quail_mwem': {
        'car': _quail_entry(
            MWEMSynthesizer,
            {
                'Q_count': 400, 'iterations': 25, 'mult_weights_iterations': 20,
                'split_factor': 7, 'max_bin_count': 400,
            },
            'class',
        ),
        'mushroom': _quail_entry(
            MWEMSynthesizer,
            {
                'Q_count': 400, 'iterations': 30, 'mult_weights_iterations': 20,
                'split_factor': 4, 'max_bin_count': 400,
            },
            'edible',
        ),
        'bank': _quail_entry(
            MWEMSynthesizer,
            {
                'Q_count': 400, 'iterations': 25, 'mult_weights_iterations': 15,
                'split_factor': 3, 'max_bin_count': 300,
            },
            'yesno',
        ),
        'adult': _quail_entry(
            MWEMSynthesizer,
            {
                'Q_count': 400, 'iterations': 20, 'mult_weights_iterations': 15,
                'splits': [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12], [13]],
                'max_bin_count': 100,
            },
            'earning-class',
        ),
        'shopping': _quail_entry(
            MWEMSynthesizer,
            {
                'Q_count': 400, 'iterations': 30, 'mult_weights_iterations': 20,
                'split_factor': 2, 'max_bin_count': 400,
            },
            'Revenue',
        ),
        'default': _quail_entry(
            MWEMSynthesizer,
            {
                'Q_count': 500, 'iterations': 30, 'mult_weights_iterations': 20,
                'split_factor': 7, 'max_bin_count': 400,
            },
            'class',
        ),
    },
}

# Keyword-argument settings per model, keyed by model class name. Every entry
# pins 'random_state' to SEED. Entries exist for DecisionTreeClassifier and
# ExtraTreesClassifier even though they are not listed in KNOWN_MODELS.
# NOTE(review): presumably each dict is unpacked into the matching class
# constructor -- confirm against the code that reads this table.
MODEL_ARGS = {
    'AdaBoostClassifier': {'random_state': SEED, 'n_estimators': 100},
    'BaggingClassifier': {'random_state': SEED},
    'LogisticRegression': {
        'random_state': SEED, 'max_iter': 1000,
        'multi_class': 'auto', 'solver': 'lbfgs',
    },
    'MLPClassifier': {
        'random_state': SEED, 'max_iter': 2000,
        'early_stopping': True, 'n_iter_no_change': 20,
    },
    'DecisionTreeClassifier': {
        'random_state': SEED, 'class_weight': 'balanced',
    },
    'RandomForestClassifier': {
        'random_state': SEED, 'class_weight': 'balanced', 'n_estimators': 200,
    },
    'ExtraTreesClassifier': {
        'random_state': SEED, 'class_weight': 'balanced', 'n_estimators': 200,
    },
}
