from torch import nn

# Admission: VAE configuration for the Admission data set (regression on "zfya").
# NOTE(review): "d", "H1" and "H2" appear to be the VAE latent size and hidden
# layer widths, and "lambda_reg" a regularisation weight — confirm against the
# model/training code that consumes this dict.
admission_dict = dict(
    vae_path="./Recourse_Methods/Generative_Model/Saved_Models/vae_admission.pt",
    path="./Data_Sets/Admission/",
    filename_train="admission-train.csv",
    filename_test="admission-test.csv",
    label="zfya",
    task="regression",
    batch_size=32,
    lr=0.002,
    epochs=50,
    d=3,
    H1=5,
    H2=10,
    activFun=nn.Softplus(),
    lambda_reg=1e-6,
)

# Diabetes: VAE configuration for the Diabetes data set (classification on
# "readmitted").
# NOTE(review): "d"/"H1"/"H2" look like VAE latent and hidden sizes — confirm.
diabetes_dict = dict(
    vae_path="./Recourse_Methods/Generative_Model/Saved_Models/vae_diabetes.pt",
    path="./Data_Sets/Diabetes/",
    filename_train="diabetes-train.csv",
    filename_test="diabetes-test.csv",
    label="readmitted",
    task="classification",
    batch_size=32,
    lr=0.002,
    epochs=50,
    d=20,
    H1=25,
    H2=30,
    activFun=nn.Softplus(),
    lambda_reg=1e-6,
)

# Heloc: VAE configuration for the HELOC data set (regression on
# "ExternalRiskEstimate").
# NOTE(review): "d"/"H1"/"H2" look like VAE latent and hidden sizes — confirm.
heloc_dict = dict(
    vae_path="./Recourse_Methods/Generative_Model/Saved_Models/vae_heloc.pt",
    path="./Data_Sets/Heloc/",
    filename_train="heloc-train.csv",
    filename_test="heloc-test.csv",
    label="ExternalRiskEstimate",
    task="regression",
    batch_size=32,
    lr=0.002,
    epochs=50,
    d=12,
    H1=25,
    H2=25,
    activFun=nn.Softplus(),
    lambda_reg=1e-6,
)

# German: VAE configuration for the German credit data set (classification on
# "credit-risk").
# NOTE(review): "d"/"H1"/"H2" look like VAE latent and hidden sizes — confirm.
german_dict = dict(
    vae_path="./Recourse_Methods/Generative_Model/Saved_Models/vae_german.pt",
    path="./Data_Sets/German/",
    filename_train="german-train.csv",
    filename_test="german-test.csv",
    label="credit-risk",
    task="classification",
    batch_size=32,
    lr=0.002,
    epochs=50,
    d=3,
    H1=6,
    H2=14,
    activFun=nn.Softplus(),
    lambda_reg=1e-6,
)

# Adult: VAE configuration for the Adult data set (classification on "income").
# Unlike the other data sets here, it trains with a batch size of 256.
# NOTE(review): "d"/"H1"/"H2" look like VAE latent and hidden sizes — confirm.
adult_dict = dict(
    vae_path="./Recourse_Methods/Generative_Model/Saved_Models/vae_adult.pt",
    path="./Data_Sets/Adult/",
    filename_train="adult-train.csv",
    filename_test="adult-test.csv",
    label="income",
    task="classification",
    batch_size=256,
    lr=0.002,
    epochs=50,
    d=9,
    H1=15,
    H2=25,
    activFun=nn.Softplus(),
    lambda_reg=1e-6,
)

# Compas: VAE configuration for the COMPAS data set (classification on "risk").
# NOTE(review): "d"/"H1"/"H2" look like VAE latent and hidden sizes — confirm.
compas_dict = dict(
    vae_path="./Recourse_Methods/Generative_Model/Saved_Models/vae_compas.pt",
    path="./Data_Sets/COMPAS/",
    filename_train="compas-train.csv",
    filename_test="compas-test.csv",
    label="risk",
    task="classification",
    batch_size=32,
    lr=0.002,
    epochs=50,
    d=6,
    H1=8,
    H2=10,
    activFun=nn.Softplus(),
    lambda_reg=1e-6,
)

# Meta: registry mapping a lowercase data-set name to its VAE configuration dict
# defined above.
vae_meta_dictionary = dict(
    heloc=heloc_dict,
    diabetes=diabetes_dict,
    compas=compas_dict,
    adult=adult_dict,
    german=german_dict,
    admission=admission_dict,
)
