import argparse
from src.distributionally_fairness import wdf_calculate, delta_calculate
from src.out_of_sample_sensitivity import fixed_calculate, fit_calculate
import warnings
# Suppress every warning process-wide so the long sweep's console output stays
# readable. NOTE(review): this also hides deprecation/convergence warnings raised
# inside the analysis routines — consider narrowing the filter.
warnings.filterwarnings("ignore")


if __name__ == "__main__":
    # Default experiment grid. Every value can be overridden on the command line,
    # so a subset (e.g. a single dataset) can be rerun without editing this file.
    # With no extra flags the script runs exactly the original full grid.
    DEFAULT_DATASETS = ['student', 'communities_crime', 'heritage_health', 'meps', 'acs_income', 'heloc', 'celeba', 'adult']
    DEFAULT_CLASSIFIERS = ['logistic', 'linear_svm', 'nonlinear_svm', 'mlp', 'gbm', 'adaboost']

    parser = argparse.ArgumentParser(
        description="Run fixed_calculate, fit_calculate, wdf_calculate and "
        "delta_calculate for every dataset/classifier combination."
    )
    parser.add_argument('--seed', type=int, default=0,
                        help="Random seed passed to every analysis (default: 0).")
    parser.add_argument('--datasets', nargs='+', default=DEFAULT_DATASETS, metavar='NAME',
                        help="Datasets to analyze (default: all built-in datasets).")
    parser.add_argument('--classifiers', nargs='+', default=DEFAULT_CLASSIFIERS, metavar='NAME',
                        help="Classifiers to analyze (default: all built-in classifiers).")
    parser.add_argument('--n-samples', type=int, default=1000,
                        help="n_samples passed to every analysis (default: 1000).")
    parser.add_argument('--n-experiments', type=int, default=10000,
                        help="n_experiments passed to every analysis (default: 10000).")
    parser.add_argument('--delta', type=float, default=0.001,
                        help="delta passed to wdf_calculate (default: 0.001).")
    parser.add_argument('--delta-steps', type=int, default=1000,
                        help="delta_steps passed to delta_calculate (default: 1000).")
    parser.add_argument('--q', type=int, default=2,
                        help="q passed to wdf_calculate and delta_calculate (default: 2).")
    args = parser.parse_args()

    # Fail fast on nonsensical sizes instead of discovering them deep inside a
    # long-running sweep.
    if args.n_samples <= 0:
        parser.error("--n-samples must be a positive integer")
    if args.n_experiments <= 0:
        parser.error("--n-experiments must be a positive integer")

    for dataset in args.datasets:
        for classifier in args.classifiers:
            print(f"\nAnalyzing {dataset} dataset with {classifier} classifier and Random seed {args.seed}...")

            # Keyword arguments common to all four analyses. The dataset argument
            # is passed separately because fixed_calculate/fit_calculate name it
            # `dataset` while wdf_calculate/delta_calculate name it `dataset_name`.
            shared_kwargs = {
                'classifier_name': classifier,
                'n_samples': args.n_samples,
                'n_experiments': args.n_experiments,
                'random_seed': args.seed,
            }

            print("Running fixed_calculate...")
            fixed_calculate(dataset=dataset, **shared_kwargs)

            print("Running fit_calculate...")
            fit_calculate(dataset=dataset, **shared_kwargs)

            print("Running wdf_calculate...")
            wdf_calculate(dataset_name=dataset, delta=args.delta, q=args.q, **shared_kwargs)

            print("Running delta_calculate...")
            delta_calculate(dataset_name=dataset, delta_steps=args.delta_steps, q=args.q, **shared_kwargs)
            