from argparse import Namespace
import os
from pathlib import Path

class Args(Namespace):
    """Default experiment settings, exposed as class-level attributes.

    Because these are class attributes (not instance attributes), they act as
    fallbacks: any value set on an instance shadows the default here, and
    ``vars(Args())`` is empty until something is assigned.
    """

    # --- Dataset / task defaults -------------------------------------------
    dataset: str = 'adult'
    num_classes: int = 2
    # NOTE: process_arguments() assigns a dataset-specific output_col_name.
    output_col_name: str = 'income'
    split: float = 0.75
    fairness_metric: str = 'DemParity'

    # --- Teacher ensemble / query settings ---------------------------------
    # NOTE(review): exact semantics of the thresholds and sigma values are
    # defined by the consumers of this config, not visible here -- confirm
    # there before changing them.
    teacher_query_set_split: float = 0.7
    num_teachers: int = 4
    threshold: int = 2
    fairness_threshold: float = 0.2
    sigma_threshold: int = 60
    sigma_fair_threshold: int = 2
    sigma_gnmax: int = 25
    budget: int = 1000
    delta: float = 1e-5

# Column list shared verbatim by the 'chit-default' and 'chit-default-small'
# datasets; kept in one place so the two entries cannot drift apart.
_CHIT_DEFAULT_COLS_TO_NORM = (
    'aucn_date', 'inst_due', 'inst_paid', 'inst_spread', 'total_inst_due',
    'total_inst_paid', 'div_due', 'div_paid', 'total_div_due',
    'total_div_paid', 'all_bids', 'win_bid_amt', 'prized_amt', 'month',
    'penalty', 'other_cost', 'last_trans_date', 'last_payment_date',
    'diff_inst', 'no_trans', 'total_trans', 'monthly_income', 'age',
)

# Per-dataset settings keyed by the value of ``args.dataset``. Sequences are
# stored as tuples so the table itself is never mutated; process_arguments()
# hands out fresh lists on every call.
_DATASET_CONFIGS = {
    'adult': {
        'path': './Datasets/Adult/adult_original_purified.csv',
        'num_inp_attr': 102,
        'cols_to_norm': ('age', 'fnlwgt', 'education-num', 'capital-gain',
                         'capital-loss', 'hours-per-week'),
        'sensitive_attributes': ('sex',),
        'num_sensitives': 2,
        'output_col_name': '>50K',
    },
    'retired-adult': {
        'path': './Datasets/Adult/Retired-Adult/adult_reconstruction_processed.csv',
        'num_inp_attr': 101,
        'cols_to_norm': ('age', 'hours-per-week', 'education-num',
                         'capital-gain', 'capital-loss'),
        'sensitive_attributes': ('gender',),
        'num_sensitives': 2,
        'output_col_name': 'income',
    },
    'credit-card': {
        'path': './Datasets/CreditCard/credit-card-defaulters_processed.csv',
        'num_inp_attr': 85,
        'cols_to_norm': ('LIMIT_BAL', 'AGE', 'BILL_AMT1', 'BILL_AMT2',
                         'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6',
                         'PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4',
                         'PAY_AMT5', 'PAY_AMT6'),
        'sensitive_attributes': ('SEX',),
        'num_sensitives': 2,
        'output_col_name': 'default payment next month',
    },
    'credit-card-german': {
        'path': './Datasets/CreditCardGerman/credit-card-german.csv',
        'num_inp_attr': 57,
        'cols_to_norm': ('duration', 'credit amount', 'installment rate',
                         'residence', 'age', 'credits', 'people liable'),
        'sensitive_attributes': ('status and sex',),
        'num_sensitives': 4,
        'output_col_name': 'risk',
    },
    'chit-default': {
        'path': './Datasets/ChitDefault/chit-default.csv',
        'num_inp_attr': 57,
        'cols_to_norm': _CHIT_DEFAULT_COLS_TO_NORM,
        'sensitive_attributes': ('sex',),
        'num_sensitives': 3,
        'output_col_name': 'default',
    },
    'chit-default-small': {
        'path': './Datasets/ChitDefaultSmall/chit-default-small.csv',
        'num_inp_attr': 57,
        'cols_to_norm': _CHIT_DEFAULT_COLS_TO_NORM,
        'sensitive_attributes': ('sex',),
        'num_sensitives': 3,
        'output_col_name': 'default',
    },
    'parkinsons': {
        'path': './Datasets/Parkinsons/parkinsons_updrs_processed.csv',
        'num_inp_attr': 19,
        'cols_to_norm': ('age', 'test_time', 'motor_UPDRS', 'Jitter(%)',
                         'Jitter(Abs)', 'Jitter:RAP', 'Jitter:PPQ5',
                         'Jitter:DDP', 'Shimmer', 'Shimmer(dB)',
                         'Shimmer:APQ3', 'Shimmer:APQ5', 'Shimmer:APQ11',
                         'Shimmer:DDA', 'NHR', 'HNR', 'RPDE', 'DFA', 'PPE'),
        'sensitive_attributes': ('sex',),
        'num_sensitives': 2,
        'output_col_name': 'total_UPDRS',
    },
}


def process_arguments(args: Namespace):
    """Fill in dataset-specific and derived settings on ``args`` in place.

    Reads ``args.dataset`` to select the dataset settings, then derives the
    results/log paths from ``results_dir``, ``pate_based_model``,
    ``fairness_metric``, ``backend``, ``num_teachers``, ``seed`` and
    ``log_path``, all of which must already be present on ``args``.

    Args:
        args: Namespace to update. It is mutated and also returned.

    Returns:
        The same ``args`` object, with these attributes set: ``path``,
        ``num_inp_attr``, ``cols_to_norm``, ``sensitive_attributes``,
        ``num_sensitives``, ``output_col_name``, ``results_db_path``,
        ``gt_fairness``, ``log_path`` (rewritten), ``trained_model_path``,
        and ``verbose`` (only when ``use_optuna`` is truthy).

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.

    Side effects:
        Creates the ``trained_model_path`` directory (and parents) on disk.
    """
    try:
        config = _DATASET_CONFIGS[args.dataset]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which the old ==-comparison
        # chain also treated as "unknown dataset".
        supported = ', '.join(_DATASET_CONFIGS)
        raise ValueError(
            f"Dataset not found: {args.dataset!r}. Supported datasets: {supported}"
        ) from None

    args.path = config['path']
    args.num_inp_attr = config['num_inp_attr']
    # Fresh lists per call so callers may mutate them without touching the table.
    args.cols_to_norm = list(config['cols_to_norm'])
    args.sensitive_attributes = list(config['sensitive_attributes'])
    args.num_sensitives = config['num_sensitives']
    args.output_col_name = config['output_col_name']

    results_filename = (
        f"{args.pate_based_model}_{args.dataset}_{args.fairness_metric}_results.parquet"
    )
    args.results_db_path = os.path.join(args.results_dir, results_filename)

    # Optuna runs force verbose on; otherwise ``verbose`` is left untouched.
    if args.use_optuna:
        args.verbose = True

    # True only for these two metrics.
    # NOTE(review): presumably they are the ones needing ground-truth labels -- confirm.
    args.gt_fairness = args.fairness_metric in ['ErrorParity', 'EqualityOfOdds']

    serialized_name = (
        f"{args.backend}_{args.dataset}_num-teachers_{args.num_teachers}_seed_{args.seed}"
    )
    # NOTE: log_path is overwritten with the run-specific sub-directory, so
    # calling this function twice on the same object nests the directory.
    args.log_path = os.path.join(args.log_path, serialized_name)
    args.trained_model_path = os.path.join(args.log_path, "trained_models")
    Path(args.trained_model_path).mkdir(parents=True, exist_ok=True)

    return args